From e7866dbd19ad6ca816d5017e90d5c8a747c892f9 Mon Sep 17 00:00:00 2001 From: Bernie Hackett Date: Wed, 11 Feb 2015 20:01:16 -0800 Subject: [PATCH] PYTHON-821 - Implement replace_one, update_one, and update_many. --- doc/api/pymongo/collection.rst | 3 + doc/api/pymongo/results.rst | 1 + pymongo/bulk.py | 18 +---- pymongo/collection.py | 99 ++++++++++++++++++-------- pymongo/helpers.py | 23 ++++++ pymongo/results.py | 108 ++++++++++++++++++++--------- test/test_collection.py | 123 ++++++++++++++++++++++++++++++++- 7 files changed, 301 insertions(+), 74 deletions(-) diff --git a/doc/api/pymongo/collection.rst b/doc/api/pymongo/collection.rst index 4c58e24ec..f88105992 100644 --- a/doc/api/pymongo/collection.rst +++ b/doc/api/pymongo/collection.rst @@ -33,6 +33,9 @@ .. autoattribute:: write_concern .. automethod:: with_options .. automethod:: insert_one + .. automethod:: replace_one + .. automethod:: update_one + .. automethod:: update_many .. automethod:: insert(doc_or_docs[, manipulate=True[, check_keys=True[, continue_on_error=False[, **kwargs]]]]) .. automethod:: save(to_save[, manipulate=True[, check_keys=True[, **kwargs]]]) .. automethod:: update(spec, document[, upsert=False[, manipulate=False[, multi=False[, check_keys=True[, **kwargs]]]]]) diff --git a/doc/api/pymongo/results.rst b/doc/api/pymongo/results.rst index 70243c285..765e9fb25 100644 --- a/doc/api/pymongo/results.rst +++ b/doc/api/pymongo/results.rst @@ -4,3 +4,4 @@ .. automodule:: pymongo.results :synopsis: Result class definitions :members: + :inherited-members: diff --git a/pymongo/bulk.py b/pymongo/bulk.py index e874f674a..22aefa2f4 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -23,6 +23,7 @@ import collections from bson.objectid import ObjectId from bson.son import SON +from pymongo import helpers from pymongo.errors import (BulkWriteError, DocumentTooLarge, InvalidOperation, @@ -208,14 +209,7 @@ class _Bulk(object): def add_update(self, selector, update, multi=False, upsert=False): """Create an update document and add it to the list of ops. """ - if not isinstance(update, collections.Mapping): - raise TypeError('update must be a mapping type.') - # Update can not be {} - if not update: - raise ValueError('update only works with $ operators') - first = next(iter(update)) - if not first.startswith('$'): - raise ValueError('update only works with $ operators') + helpers._check_ok_for_update(update) cmd = SON([('q', selector), ('u', update), ('multi', multi), ('upsert', upsert)]) self.ops.append((_UPDATE, cmd)) @@ -223,13 +217,7 @@ class _Bulk(object): def add_replace(self, selector, replacement, upsert=False): """Create a replace document and add it to the list of ops. """ - if not isinstance(replacement, collections.Mapping): - raise TypeError('replacement must be a mapping type.') - # Replacement can be {} - if replacement: - first = next(iter(replacement)) - if first.startswith('$'): - raise ValueError('replacement can not include $ operators') + helpers._check_ok_for_replace(replacement) cmd = SON([('q', selector), ('u', replacement), ('multi', False), ('upsert', upsert)]) self.ops.append((_UPDATE, cmd)) diff --git a/pymongo/collection.py b/pymongo/collection.py index 3b3b17c60..acefa12ea 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -37,7 +37,7 @@ from pymongo.helpers import _check_write_command_response, _command from pymongo.message import _INSERT, _UPDATE, _DELETE from pymongo.options import ReturnDocument, _WriteOp from pymongo.read_preferences import ReadPreference -from pymongo.results import InsertOneResult, BulkWriteResult +from pymongo.results import BulkWriteResult, InsertOneResult, UpdateResult from pymongo.write_concern import WriteConcern @@ -647,15 +647,6 @@ class Collection(common.BaseObject): raise TypeError("spec must be a mapping type") if not isinstance(document, collections.Mapping): raise TypeError("document must be a mapping type") - if not isinstance(upsert, bool): - raise TypeError("upsert must be an instance of bool") - - if manipulate: - document = self.__database._fix_incoming(document, self) - - concern = kwargs or self.write_concern.document - safe = concern.get("w") != 0 - if document: # If a top level key begins with '$' this is a modify operation # and we should skip key validation. It doesn't matter which key @@ -666,6 +657,73 @@ class Collection(common.BaseObject): if first.startswith('$'): check_keys = False + write_concern = None + if kwargs: + write_concern = WriteConcern(**kwargs) + return self.__update(spec, document, upsert, + check_keys, multi, manipulate, write_concern) + + def replace_one(self, filter, replacement, upsert=False): + """Replace a single document matching the filter. + + :Parameters: + - `filter`: A query that matches the document to replace. + - `replacement`: The new document. + - `upsert` (optional): If ``True``, perform an insert if no documents + match the filter. + + :Returns: + - An instance of :class:`~pymongo.results.UpdateResult`. + """ + helpers._check_ok_for_replace(replacement) + result = self.__update(filter, replacement, upsert) + return UpdateResult(result, self.write_concern.acknowledged) + + def update_one(self, filter, update, upsert=False): + """Update a single document matching the filter. + + :Parameters: + - `filter`: A query that matches the document to update. + - `update`: The modifications to apply. + - `upsert` (optional): If ``True``, perform an insert if no documents + match the filter. + + :Returns: + - An instance of :class:`~pymongo.results.UpdateResult`. + """ + helpers._check_ok_for_update(update) + result = self.__update(filter, update, upsert, False) + return UpdateResult(result, self.write_concern.acknowledged) + + def update_many(self, filter, update, upsert=False): + """Update one or more documents that match the filter. + + :Parameters: + - `filter`: A query that matches the documents to update. + - `update`: The modifications to apply. + - `upsert` (optional): If ``True``, perform an insert if no documents + match the filter. + + :Returns: + - An instance of :class:`~pymongo.results.UpdateResult`. + """ + helpers._check_ok_for_update(update) + result = self.__update(filter, update, upsert, False, True) + return UpdateResult(result, self.write_concern.acknowledged) + + def __update(self, filter, document, upsert=False, check_keys=True, + multi=False, manipulate=False, write_concern=None): + """Internal update / replace helper.""" + if not isinstance(filter, collections.Mapping): + raise TypeError("filter must be a mapping type") + if not isinstance(upsert, bool): + raise TypeError("upsert must be an instance of bool") + if manipulate: + document = self.__database._fix_incoming(document, self) + + concern = (write_concern or self.write_concern).document + safe = concern.get("w") != 0 + client = self.database.connection if client._writable_max_wire_version() > 1 and safe: # Update command @@ -673,7 +731,7 @@ class Collection(common.BaseObject): if concern: command['writeConcern'] = concern - docs = [SON([('q', spec), ('u', document), + docs = [SON([('q', filter), ('u', document), ('multi', multi), ('upsert', upsert)])] results = message._do_batched_write_command( @@ -698,7 +756,7 @@ class Collection(common.BaseObject): # Legacy OP_UPDATE return client._send_message( message.update(self.__full_name, upsert, multi, - spec, document, safe, concern, + filter, document, safe, concern, check_keys, self.codec_options), safe) def drop(self): @@ -1886,13 +1944,7 @@ class Collection(common.BaseObject): as keyword arguments (for example maxTimeMS can be used with recent server versions). """ - if not isinstance(replacement, collections.Mapping): - raise TypeError('replacement must be a mapping type.') - # Replacement can be {} - if replacement: - first = next(iter(replacement)) - if first.startswith('$'): - raise ValueError('replacement can not include $ operators') + helpers._check_ok_for_replace(replacement) kwargs['update'] = replacement return self.__find_and_modify(filter, projection, sort, upsert, return_document, **kwargs) @@ -1926,14 +1978,7 @@ class Collection(common.BaseObject): as keyword arguments (for example maxTimeMS can be used with recent server versions). """ - if not isinstance(update, collections.Mapping): - raise TypeError('update must be a mapping type.') - # Update can not be {} - if not update: - raise ValueError('update only works with $ operators') - first = next(iter(update)) - if not first.startswith('$'): - raise ValueError('update only works with $ operators') + helpers._check_ok_for_update(update) kwargs['update'] = update return self.__find_and_modify(filter, projection, sort, upsert, return_document, **kwargs) diff --git a/pymongo/helpers.py b/pymongo/helpers.py index aa6c66ca6..3915bb66c 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -31,6 +31,29 @@ from pymongo.errors import (CursorNotFound, from pymongo.message import query +def _check_ok_for_replace(replacement): + """Validate a replacement document.""" + if not isinstance(replacement, collections.Mapping): + raise TypeError('replacement must be a mapping type.') + # Replacement can be {} + if replacement: + first = next(iter(replacement)) + if first.startswith('$'): + raise ValueError('replacement can not include $ operators') + + +def _check_ok_for_update(update): + """Validate an update document.""" + if not isinstance(update, collections.Mapping): + raise TypeError('update must be a mapping type.') + # Update can not be {} + if not update: + raise ValueError('update only works with $ operators') + first = next(iter(update)) + if not first.startswith('$'): + raise ValueError('update only works with $ operators') + + def _index_list(key_or_list, direction=None): """Helper to generate a list of (key, direction) pairs. diff --git a/pymongo/results.py b/pymongo/results.py index 35227b105..60c04e48c 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -17,19 +17,19 @@ from pymongo.errors import InvalidOperation -class InsertOneResult(object): - """The return type for :meth:`Collection.insert_one`.""" +class _WriteResult(object): + """Base class for write result classes.""" - __slots__ = ("__inserted_id", "__acknowledged") - - def __init__(self, inserted_id, acknowledged): - self.__inserted_id = inserted_id + def __init__(self, acknowledged): self.__acknowledged = acknowledged - @property - def inserted_id(self): - """The inserted document's _id.""" - return self.__inserted_id + def _raise_if_unacknowledged(self, property_name): + """Raise an exception on property access if unacknowledged.""" + if not self.__acknowledged: + raise InvalidOperation("A value for %s is not available when " + "the write is unacknowledged. Check the " + "acknowledged attribute to avoid this " + "error." % (property_name,)) @property def acknowledged(self): @@ -37,7 +37,66 @@ class InsertOneResult(object): return self.__acknowledged -class BulkWriteResult(object): +class InsertOneResult(_WriteResult): + """The return type for :meth:`Collection.insert_one`.""" + + __slots__ = ("__inserted_id", "__acknowledged") + + def __init__(self, inserted_id, acknowledged): + self.__inserted_id = inserted_id + super(InsertOneResult, self).__init__(acknowledged) + + @property + def inserted_id(self): + """The inserted document's _id.""" + return self.__inserted_id + + +class UpdateResult(_WriteResult): + """The return type for :meth:`~pymongo.collection.Collection.update_one` + and :meth:`~pymongo.collection.Collection.update_many`""" + + __slots__ = ("__raw_result", "__acknowledged") + + def __init__(self, raw_result, acknowledged): + self.__raw_result = raw_result + super(UpdateResult, self).__init__(acknowledged) + + @property + def raw_result(self): + """The raw result document returned by the server.""" + return self.__raw_result + + @property + def matched_count(self): + """The number of documents matched for this update.""" + self._raise_if_unacknowledged("matched_count") + if self.upserted_id is not None: + return 0 + return self.__raw_result.get("n", 0) + + @property + def modified_count(self): + """The number of documents modified. + + .. note:: modified_count is only reported by MongoDB 2.6 and later. + When connected to an earlier server version, or in certain mixed + version sharding configurations, this attribute will be set to + ``None``. + """ + self._raise_if_unacknowledged("modified_count") + return self.__raw_result.get("nModified") + + @property + def upserted_id(self): + """The _id of the inserted document if an upsert took place. Otherwise + ``None``. + """ + self._raise_if_unacknowledged("upserted_id") + return self.__raw_result.get("upserted") + + +class BulkWriteResult(_WriteResult): """An object wrapper for bulk API write results.""" __slots__ = ("__bulk_api_result", "__acknowledged") @@ -52,36 +111,23 @@ class BulkWriteResult(object): :exc:`~pymongo.errors.InvalidOperation`. """ self.__bulk_api_result = bulk_api_result - self.__acknowledged = acknowledged - - def __raise_if_unacknowledged(self, property_name): - """Raise an exception on property access if unacknowledged.""" - if not self.__acknowledged: - raise InvalidOperation("A value for %s is not available when " - "the write is unacknowledged. Check the " - "acknowledged attribute to avoid this " - "error." % (property_name,)) + super(BulkWriteResult, self).__init__(acknowledged) @property def bulk_api_result(self): """The raw bulk API result.""" return self.__bulk_api_result - @property - def acknowledged(self): - """Is this the result of an acknowledged bulk write operation?""" - return self.__acknowledged - @property def inserted_count(self): """The number of documents inserted.""" - self.__raise_if_unacknowledged("inserted_count") + self._raise_if_unacknowledged("inserted_count") return self.__bulk_api_result.get("nInserted") @property def matched_count(self): """The number of documents matched for an update.""" - self.__raise_if_unacknowledged("matched_count") + self._raise_if_unacknowledged("matched_count") return self.__bulk_api_result.get("nMatched") @property @@ -93,25 +139,25 @@ class BulkWriteResult(object): version sharding configurations, this attribute will be set to ``None``. """ - self.__raise_if_unacknowledged("modified_count") + self._raise_if_unacknowledged("modified_count") return self.__bulk_api_result.get("nModified") @property def deleted_count(self): """The number of documents deleted.""" - self.__raise_if_unacknowledged("deleted_count") + self._raise_if_unacknowledged("deleted_count") return self.__bulk_api_result.get("nRemoved") @property def upserted_count(self): """The number of documents upserted.""" - self.__raise_if_unacknowledged("upserted_count") + self._raise_if_unacknowledged("upserted_count") return self.__bulk_api_result.get("nUpserted") @property def upserted_ids(self): """A map of operation index to the _id of the upserted document.""" - self.__raise_if_unacknowledged("upserted_ids") + self._raise_if_unacknowledged("upserted_ids") if self.__bulk_api_result: return dict((upsert["index"], upsert["_id"]) for upsert in self.bulk_api_result["upserted"]) diff --git a/test/test_collection.py b/test/test_collection.py index 1e8aae6cd..d93f1c1ab 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -48,7 +48,7 @@ from pymongo.errors import (ConfigurationError, WTimeoutError) from pymongo.options import ReturnDocument from pymongo.read_preferences import ReadPreference -from pymongo.results import InsertOneResult +from pymongo.results import InsertOneResult, UpdateResult from pymongo.son_manipulator import SONManipulator from pymongo.write_concern import WriteConcern from test.test_client import IntegrationTest @@ -1072,6 +1072,127 @@ class TestCollection(IntegrationTest): self.assertEqual(db.test.find_one(id1)["x"], 7) self.assertEqual(db.test.find_one(id2)["x"], 1) + def test_replace_one(self): + db = self.db + db.drop_collection("test") + + self.assertRaises(ValueError, + lambda: db.test.replace_one({}, {"$set": {"x": 1}})) + + id1 = db.test.insert_one({"x": 1}).inserted_id + result = db.test.replace_one({"x": 1}, {"y": 1}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(1, result.matched_count) + self.assertTrue(result.modified_count in (None, 1)) + self.assertIsNone(result.upserted_id) + self.assertTrue(result.acknowledged) + self.assertEqual(1, db.test.count({"y": 1})) + self.assertEqual(0, db.test.count({"x": 1})) + self.assertEqual(db.test.find_one(id1)["y"], 1) + + result = db.test.replace_one({"x": 2}, {"y": 2}, True) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(0, result.matched_count) + self.assertTrue(result.modified_count in (None, 0)) + self.assertTrue(isinstance(result.upserted_id, ObjectId)) + self.assertTrue(result.acknowledged) + self.assertEqual(1, db.test.count({"y": 2})) + + db = db.connection.get_database(db.name, + write_concern=WriteConcern(w=0)) + result = db.test.replace_one({"x": 0}, {"y": 0}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertRaises(InvalidOperation, lambda: result.matched_count) + self.assertRaises(InvalidOperation, lambda: result.modified_count) + self.assertRaises(InvalidOperation, lambda: result.upserted_id) + self.assertFalse(result.acknowledged) + + def test_update_one(self): + db = self.db + db.drop_collection("test") + + self.assertRaises(ValueError, + lambda: db.test.update_one({}, {"x": 1})) + + id1 = db.test.insert_one({"x": 5}).inserted_id + result = db.test.update_one({}, {"$inc": {"x": 1}}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(1, result.matched_count) + self.assertTrue(result.modified_count in (None, 1)) + self.assertIsNone(result.upserted_id) + self.assertTrue(result.acknowledged) + self.assertEqual(db.test.find_one(id1)["x"], 6) + + id2 = db.test.insert_one({"x": 1}).inserted_id + result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(1, result.matched_count) + self.assertTrue(result.modified_count in (None, 1)) + self.assertIsNone(result.upserted_id) + self.assertTrue(result.acknowledged) + self.assertEqual(db.test.find_one(id1)["x"], 7) + self.assertEqual(db.test.find_one(id2)["x"], 1) + + result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(0, result.matched_count) + self.assertTrue(result.modified_count in (None, 0)) + self.assertTrue(isinstance(result.upserted_id, ObjectId)) + self.assertTrue(result.acknowledged) + + db = db.connection.get_database(db.name, + write_concern=WriteConcern(w=0)) + result = db.test.update_one({"x": 0}, {"$inc": {"x": 1}}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertRaises(InvalidOperation, lambda: result.matched_count) + self.assertRaises(InvalidOperation, lambda: result.modified_count) + self.assertRaises(InvalidOperation, lambda: result.upserted_id) + self.assertFalse(result.acknowledged) + + def test_update_many(self): + db = self.db + db.drop_collection("test") + + self.assertRaises(ValueError, + lambda: db.test.update_many({}, {"x": 1})) + + db.test.insert_one({"x": 4, "y": 3}) + db.test.insert_one({"x": 5, "y": 5}) + db.test.insert_one({"x": 4, "y": 4}) + + result = db.test.update_many({"x": 4}, {"$set": {"y": 5}}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(2, result.matched_count) + self.assertTrue(result.modified_count in (None, 2)) + self.assertIsNone(result.upserted_id) + self.assertTrue(result.acknowledged) + self.assertEqual(3, db.test.count({"y": 5})) + + result = db.test.update_many({"x": 5}, {"$set": {"y": 6}}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(1, result.matched_count) + self.assertTrue(result.modified_count in (None, 1)) + self.assertIsNone(result.upserted_id) + self.assertTrue(result.acknowledged) + self.assertEqual(1, db.test.count({"y": 6})) + + result = db.test.update_many({"x": 2}, {"$set": {"y": 1}}, True) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertEqual(0, result.matched_count) + self.assertTrue(result.modified_count in (None, 0)) + self.assertTrue(isinstance(result.upserted_id, ObjectId)) + self.assertTrue(result.acknowledged) + + db = db.connection.get_database(db.name, + write_concern=WriteConcern(w=0)) + result = db.test.update_many({"x": 0}, {"$inc": {"x": 1}}) + self.assertTrue(isinstance(result, UpdateResult)) + self.assertRaises(InvalidOperation, lambda: result.matched_count) + self.assertRaises(InvalidOperation, lambda: result.modified_count) + self.assertRaises(InvalidOperation, lambda: result.upserted_id) + self.assertFalse(result.acknowledged) + + def test_update_manipulate(self): db = self.db db.drop_collection("test")