diff --git a/mongo.py b/mongo.py index 35b9694ae..95b17ab8f 100644 --- a/mongo.py +++ b/mongo.py @@ -22,6 +22,9 @@ class InvalidCollection(ValueError): """Raised when an invalid collection name is used. """ +_ZERO = "\x00\x00\x00\x00" +_ONE = "\x01\x00\x00\x00" + class Mongo(object): """A connection to a Mongo database. """ @@ -62,7 +65,8 @@ class Mongo(object): self.__connection.connect((self.__host, self.__port)) except socket.error: raise ConnectionException("could not connect to %s:%s, got: %s" % - (self.__host, self.__port, traceback.format_exc())) + (self.__host, self.__port, + traceback.format_exc())) def _send_message(self, operation, data): """Say something to the database. @@ -185,7 +189,8 @@ class Collection(object): def __cmp__(self, other): if isinstance(other, Collection): - return cmp((self.__database, self.__collection_name), (other.__database, other.__collection_name)) + return cmp((self.__database, self.__collection_name), + (other.__database, other.__collection_name)) return NotImplemented def _full_name(self): @@ -195,7 +200,9 @@ class Collection(object): """Wrap up a message and send it. """ # reserved int, full collection name, message data - message = "\x00\x00\x00\x00%s%s" % (bson._make_c_string(self._full_name()), data) + message = _ZERO + message += bson._make_c_string(self._full_name()) + message += data return self.__database._send_message(operation, message) def database(self): @@ -212,17 +219,47 @@ class Collection(object): if not isinstance(to_save, (types.DictType, SON)): raise TypeError("cannot save object of type %s" % type(to_save)) - if hasattr(to_save, "_id"): + if "_id" in to_save: assert isinstance(to_save["_id"], ObjectId), "'_id' must be an ObjectId" else: to_save["_id"] = ObjectId() # TODO possibly add _ns? - self._send_message(2002, bson.BSON.from_dict(to_save)) + if to_save["_id"].is_new(): + to_save["_id"]._use() + self._send_message(2002, bson.BSON.from_dict(to_save)) + else: + self._update({"_id": to_save["_id"]}, to_save, True) return to_save["_id"] + def _update(self, spec, document, upsert=False): + """Update an object(s) in this collection. + + Raises TypeError if either spec or document isn't an instance of + (dict, SON) or upsert isn't an instance of bool. + + - `spec`: a SON object specifying elements which must be present for a + document to be updated + - `document`: a SON object specifying the fields to be changed in the + selected document(s), or (in the case of an upsert) the document to + be inserted. + - `upsert` (optional): perform an upsert operation + """ + if not isinstance(spec, (types.DictType, SON)): + raise TypeError("spec must be an instance of (dict, SON)") + if not isinstance(document, (types.DictType, SON)): + raise TypeError("document must be an instance of (dict, SON)") + if not isinstance(upsert, types.BooleanType): + raise TypeError("upsert must be an instance of bool") + + message = upsert and _ONE or _ZERO + message += bson.BSON.from_dict(spec) + message += bson.BSON.from_dict(document) + + self._send_message(2001, message) + def remove(self, spec_or_object_id): """Remove an object(s) from this collection. @@ -241,7 +278,7 @@ class Collection(object): if not isinstance(spec, (types.DictType, SON)): raise TypeError("spec must be an instance of (dict, SON)") - self._send_message(2006, "\x00\x00\x00\x00" + bson.BSON.from_dict(spec)) + self._send_message(2006, _ZERO + bson.BSON.from_dict(spec)) def find_one(self, spec_or_object_id=SON()): """Get a single object from the database. @@ -444,6 +481,20 @@ class TestMongo(unittest.TestCase): self.assertEqual(a_doc, db.test.find_one({"hello": u"world"})) self.assertEqual(None, db.test.find_one({"hello": u"test"})) + b = db.test.find_one() + self.assertFalse(b["_id"].is_new()) + b["hello"] = u"mike" + db.test.save(b) + + self.assertNotEqual(a_doc, db.test.find_one(a_key)) + self.assertEqual(b, db.test.find_one(a_key)) + self.assertEqual(b, db.test.find_one()) + + count = 0 + for _ in db.test.find(): + count += 1 + self.assertEqual(count, 1) + def test_remove(self): db = Mongo("test", self.host, self.port) db.test.remove({})