save does implicit upsert if not new
This commit is contained in:
parent
8c2ca53c7b
commit
aa2d5d1d02
63
mongo.py
63
mongo.py
@ -22,6 +22,9 @@ class InvalidCollection(ValueError):
|
||||
"""Raised when an invalid collection name is used.
|
||||
"""
|
||||
|
||||
_ZERO = "\x00\x00\x00\x00"
|
||||
_ONE = "\x01\x00\x00\x00"
|
||||
|
||||
class Mongo(object):
|
||||
"""A connection to a Mongo database.
|
||||
"""
|
||||
@ -62,7 +65,8 @@ class Mongo(object):
|
||||
self.__connection.connect((self.__host, self.__port))
|
||||
except socket.error:
|
||||
raise ConnectionException("could not connect to %s:%s, got: %s" %
|
||||
(self.__host, self.__port, traceback.format_exc()))
|
||||
(self.__host, self.__port,
|
||||
traceback.format_exc()))
|
||||
|
||||
def _send_message(self, operation, data):
|
||||
"""Say something to the database.
|
||||
@ -185,7 +189,8 @@ class Collection(object):
|
||||
|
||||
def __cmp__(self, other):
|
||||
if isinstance(other, Collection):
|
||||
return cmp((self.__database, self.__collection_name), (other.__database, other.__collection_name))
|
||||
return cmp((self.__database, self.__collection_name),
|
||||
(other.__database, other.__collection_name))
|
||||
return NotImplemented
|
||||
|
||||
def _full_name(self):
|
||||
@ -195,7 +200,9 @@ class Collection(object):
|
||||
"""Wrap up a message and send it.
|
||||
"""
|
||||
# reserved int, full collection name, message data
|
||||
message = "\x00\x00\x00\x00%s%s" % (bson._make_c_string(self._full_name()), data)
|
||||
message = _ZERO
|
||||
message += bson._make_c_string(self._full_name())
|
||||
message += data
|
||||
return self.__database._send_message(operation, message)
|
||||
|
||||
def database(self):
|
||||
@ -212,17 +219,47 @@ class Collection(object):
|
||||
if not isinstance(to_save, (types.DictType, SON)):
|
||||
raise TypeError("cannot save object of type %s" % type(to_save))
|
||||
|
||||
if hasattr(to_save, "_id"):
|
||||
if "_id" in to_save:
|
||||
assert isinstance(to_save["_id"], ObjectId), "'_id' must be an ObjectId"
|
||||
else:
|
||||
to_save["_id"] = ObjectId()
|
||||
|
||||
# TODO possibly add _ns?
|
||||
|
||||
self._send_message(2002, bson.BSON.from_dict(to_save))
|
||||
if to_save["_id"].is_new():
|
||||
to_save["_id"]._use()
|
||||
self._send_message(2002, bson.BSON.from_dict(to_save))
|
||||
else:
|
||||
self._update({"_id": to_save["_id"]}, to_save, True)
|
||||
|
||||
return to_save["_id"]
|
||||
|
||||
def _update(self, spec, document, upsert=False):
|
||||
"""Update an object(s) in this collection.
|
||||
|
||||
Raises TypeError if either spec or document isn't an instance of
|
||||
(dict, SON) or upsert isn't an instance of bool.
|
||||
|
||||
- `spec`: a SON object specifying elements which must be present for a
|
||||
document to be updated
|
||||
- `document`: a SON object specifying the fields to be changed in the
|
||||
selected document(s), or (in the case of an upsert) the document to
|
||||
be inserted.
|
||||
- `upsert` (optional): perform an upsert operation
|
||||
"""
|
||||
if not isinstance(spec, (types.DictType, SON)):
|
||||
raise TypeError("spec must be an instance of (dict, SON)")
|
||||
if not isinstance(document, (types.DictType, SON)):
|
||||
raise TypeError("document must be an instance of (dict, SON)")
|
||||
if not isinstance(upsert, types.BooleanType):
|
||||
raise TypeError("upsert must be an instance of bool")
|
||||
|
||||
message = upsert and _ONE or _ZERO
|
||||
message += bson.BSON.from_dict(spec)
|
||||
message += bson.BSON.from_dict(document)
|
||||
|
||||
self._send_message(2001, message)
|
||||
|
||||
def remove(self, spec_or_object_id):
|
||||
"""Remove an object(s) from this collection.
|
||||
|
||||
@ -241,7 +278,7 @@ class Collection(object):
|
||||
if not isinstance(spec, (types.DictType, SON)):
|
||||
raise TypeError("spec must be an instance of (dict, SON)")
|
||||
|
||||
self._send_message(2006, "\x00\x00\x00\x00" + bson.BSON.from_dict(spec))
|
||||
self._send_message(2006, _ZERO + bson.BSON.from_dict(spec))
|
||||
|
||||
def find_one(self, spec_or_object_id=SON()):
|
||||
"""Get a single object from the database.
|
||||
@ -444,6 +481,20 @@ class TestMongo(unittest.TestCase):
|
||||
self.assertEqual(a_doc, db.test.find_one({"hello": u"world"}))
|
||||
self.assertEqual(None, db.test.find_one({"hello": u"test"}))
|
||||
|
||||
b = db.test.find_one()
|
||||
self.assertFalse(b["_id"].is_new())
|
||||
b["hello"] = u"mike"
|
||||
db.test.save(b)
|
||||
|
||||
self.assertNotEqual(a_doc, db.test.find_one(a_key))
|
||||
self.assertEqual(b, db.test.find_one(a_key))
|
||||
self.assertEqual(b, db.test.find_one())
|
||||
|
||||
count = 0
|
||||
for _ in db.test.find():
|
||||
count += 1
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
def test_remove(self):
|
||||
db = Mongo("test", self.host, self.port)
|
||||
db.test.remove({})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user