PYTHON-821 - Implement replace_one, update_one, and update_many.

This commit is contained in:
Bernie Hackett 2015-02-11 20:01:16 -08:00
parent b9cd7b627b
commit e7866dbd19
7 changed files with 301 additions and 74 deletions

View File

@ -33,6 +33,9 @@
.. autoattribute:: write_concern
.. automethod:: with_options
.. automethod:: insert_one
.. automethod:: replace_one
.. automethod:: update_one
.. automethod:: update_many
.. automethod:: insert(doc_or_docs[, manipulate=True[, check_keys=True[, continue_on_error=False[, **kwargs]]]])
.. automethod:: save(to_save[, manipulate=True[, check_keys=True[, **kwargs]]])
.. automethod:: update(spec, document[, upsert=False[, manipulate=False[, multi=False[, check_keys=True[, **kwargs]]]]])

View File

@ -4,3 +4,4 @@
.. automodule:: pymongo.results
:synopsis: Result class definitions
:members:
:inherited-members:

View File

@ -23,6 +23,7 @@ import collections
from bson.objectid import ObjectId
from bson.son import SON
from pymongo import helpers
from pymongo.errors import (BulkWriteError,
DocumentTooLarge,
InvalidOperation,
@ -208,14 +209,7 @@ class _Bulk(object):
def add_update(self, selector, update, multi=False, upsert=False):
"""Create an update document and add it to the list of ops.
"""
if not isinstance(update, collections.Mapping):
raise TypeError('update must be a mapping type.')
# Update can not be {}
if not update:
raise ValueError('update only works with $ operators')
first = next(iter(update))
if not first.startswith('$'):
raise ValueError('update only works with $ operators')
helpers._check_ok_for_update(update)
cmd = SON([('q', selector), ('u', update),
('multi', multi), ('upsert', upsert)])
self.ops.append((_UPDATE, cmd))
@ -223,13 +217,7 @@ class _Bulk(object):
def add_replace(self, selector, replacement, upsert=False):
"""Create a replace document and add it to the list of ops.
"""
if not isinstance(replacement, collections.Mapping):
raise TypeError('replacement must be a mapping type.')
# Replacement can be {}
if replacement:
first = next(iter(replacement))
if first.startswith('$'):
raise ValueError('replacement can not include $ operators')
helpers._check_ok_for_replace(replacement)
cmd = SON([('q', selector), ('u', replacement),
('multi', False), ('upsert', upsert)])
self.ops.append((_UPDATE, cmd))

View File

@ -37,7 +37,7 @@ from pymongo.helpers import _check_write_command_response, _command
from pymongo.message import _INSERT, _UPDATE, _DELETE
from pymongo.options import ReturnDocument, _WriteOp
from pymongo.read_preferences import ReadPreference
from pymongo.results import InsertOneResult, BulkWriteResult
from pymongo.results import BulkWriteResult, InsertOneResult, UpdateResult
from pymongo.write_concern import WriteConcern
@ -647,15 +647,6 @@ class Collection(common.BaseObject):
raise TypeError("spec must be a mapping type")
if not isinstance(document, collections.Mapping):
raise TypeError("document must be a mapping type")
if not isinstance(upsert, bool):
raise TypeError("upsert must be an instance of bool")
if manipulate:
document = self.__database._fix_incoming(document, self)
concern = kwargs or self.write_concern.document
safe = concern.get("w") != 0
if document:
# If a top level key begins with '$' this is a modify operation
# and we should skip key validation. It doesn't matter which key
@ -666,6 +657,73 @@ class Collection(common.BaseObject):
if first.startswith('$'):
check_keys = False
write_concern = None
if kwargs:
write_concern = WriteConcern(**kwargs)
return self.__update(spec, document, upsert,
check_keys, multi, manipulate, write_concern)
def replace_one(self, filter, replacement, upsert=False):
"""Replace a single document matching the filter.
:Parameters:
- `filter`: A query that matches the document to replace.
- `replacement`: The new document.
- `upsert` (optional): If ``True``, perform an insert if no documents
match the filter.
:Returns:
- An instance of :class:`~pymongo.results.UpdateResult`.
"""
helpers._check_ok_for_replace(replacement)
result = self.__update(filter, replacement, upsert)
return UpdateResult(result, self.write_concern.acknowledged)
def update_one(self, filter, update, upsert=False):
"""Update a single document matching the filter.
:Parameters:
- `filter`: A query that matches the document to update.
- `update`: The modifications to apply.
- `upsert` (optional): If ``True``, perform an insert if no documents
match the filter.
:Returns:
- An instance of :class:`~pymongo.results.UpdateResult`.
"""
helpers._check_ok_for_update(update)
result = self.__update(filter, update, upsert, False)
return UpdateResult(result, self.write_concern.acknowledged)
def update_many(self, filter, update, upsert=False):
"""Update one or more documents that match the filter.
:Parameters:
- `filter`: A query that matches the documents to update.
- `update`: The modifications to apply.
- `upsert` (optional): If ``True``, perform an insert if no documents
match the filter.
:Returns:
- An instance of :class:`~pymongo.results.UpdateResult`.
"""
helpers._check_ok_for_update(update)
result = self.__update(filter, update, upsert, False, True)
return UpdateResult(result, self.write_concern.acknowledged)
def __update(self, filter, document, upsert=False, check_keys=True,
multi=False, manipulate=False, write_concern=None):
"""Internal update / replace helper."""
if not isinstance(filter, collections.Mapping):
raise TypeError("filter must be a mapping type")
if not isinstance(upsert, bool):
raise TypeError("upsert must be an instance of bool")
if manipulate:
document = self.__database._fix_incoming(document, self)
concern = (write_concern or self.write_concern).document
safe = concern.get("w") != 0
client = self.database.connection
if client._writable_max_wire_version() > 1 and safe:
# Update command
@ -673,7 +731,7 @@ class Collection(common.BaseObject):
if concern:
command['writeConcern'] = concern
docs = [SON([('q', spec), ('u', document),
docs = [SON([('q', filter), ('u', document),
('multi', multi), ('upsert', upsert)])]
results = message._do_batched_write_command(
@ -698,7 +756,7 @@ class Collection(common.BaseObject):
# Legacy OP_UPDATE
return client._send_message(
message.update(self.__full_name, upsert, multi,
spec, document, safe, concern,
filter, document, safe, concern,
check_keys, self.codec_options), safe)
def drop(self):
@ -1886,13 +1944,7 @@ class Collection(common.BaseObject):
as keyword arguments (for example maxTimeMS can be used with
recent server versions).
"""
if not isinstance(replacement, collections.Mapping):
raise TypeError('replacement must be a mapping type.')
# Replacement can be {}
if replacement:
first = next(iter(replacement))
if first.startswith('$'):
raise ValueError('replacement can not include $ operators')
helpers._check_ok_for_replace(replacement)
kwargs['update'] = replacement
return self.__find_and_modify(filter, projection,
sort, upsert, return_document, **kwargs)
@ -1926,14 +1978,7 @@ class Collection(common.BaseObject):
as keyword arguments (for example maxTimeMS can be used with
recent server versions).
"""
if not isinstance(update, collections.Mapping):
raise TypeError('update must be a mapping type.')
# Update can not be {}
if not update:
raise ValueError('update only works with $ operators')
first = next(iter(update))
if not first.startswith('$'):
raise ValueError('update only works with $ operators')
helpers._check_ok_for_update(update)
kwargs['update'] = update
return self.__find_and_modify(filter, projection,
sort, upsert, return_document, **kwargs)

View File

@ -31,6 +31,29 @@ from pymongo.errors import (CursorNotFound,
from pymongo.message import query
def _check_ok_for_replace(replacement):
"""Validate a replacement document."""
if not isinstance(replacement, collections.Mapping):
raise TypeError('replacement must be a mapping type.')
# Replacement can be {}
if replacement:
first = next(iter(replacement))
if first.startswith('$'):
raise ValueError('replacement can not include $ operators')
def _check_ok_for_update(update):
"""Validate an update document."""
if not isinstance(update, collections.Mapping):
raise TypeError('update must be a mapping type.')
# Update can not be {}
if not update:
raise ValueError('update only works with $ operators')
first = next(iter(update))
if not first.startswith('$'):
raise ValueError('update only works with $ operators')
def _index_list(key_or_list, direction=None):
"""Helper to generate a list of (key, direction) pairs.

View File

@ -17,19 +17,19 @@
from pymongo.errors import InvalidOperation
class InsertOneResult(object):
"""The return type for :meth:`Collection.insert_one`."""
class _WriteResult(object):
"""Base class for write result classes."""
__slots__ = ("__inserted_id", "__acknowledged")
def __init__(self, inserted_id, acknowledged):
self.__inserted_id = inserted_id
def __init__(self, acknowledged):
self.__acknowledged = acknowledged
@property
def inserted_id(self):
"""The inserted document's _id."""
return self.__inserted_id
def _raise_if_unacknowledged(self, property_name):
"""Raise an exception on property access if unacknowledged."""
if not self.__acknowledged:
raise InvalidOperation("A value for %s is not available when "
"the write is unacknowledged. Check the "
"acknowledged attribute to avoid this "
"error." % (property_name,))
@property
def acknowledged(self):
@ -37,7 +37,66 @@ class InsertOneResult(object):
return self.__acknowledged
class BulkWriteResult(object):
class InsertOneResult(_WriteResult):
"""The return type for :meth:`Collection.insert_one`."""
__slots__ = ("__inserted_id", "__acknowledged")
def __init__(self, inserted_id, acknowledged):
self.__inserted_id = inserted_id
super(InsertOneResult, self).__init__(acknowledged)
@property
def inserted_id(self):
"""The inserted document's _id."""
return self.__inserted_id
class UpdateResult(_WriteResult):
"""The return type for :meth:`~pymongo.collection.Collection.update_one`
and :meth:`~pymongo.collection.Collection.update_many`"""
__slots__ = ("__raw_result", "__acknowledged")
def __init__(self, raw_result, acknowledged):
self.__raw_result = raw_result
super(UpdateResult, self).__init__(acknowledged)
@property
def raw_result(self):
"""The raw result document returned by the server."""
return self.__raw_result
@property
def matched_count(self):
"""The number of documents matched for this update."""
self._raise_if_unacknowledged("matched_count")
if self.upserted_id is not None:
return 0
return self.__raw_result.get("n", 0)
@property
def modified_count(self):
"""The number of documents modified.
.. note:: modified_count is only reported by MongoDB 2.6 and later.
When connected to an earlier server version, or in certain mixed
version sharding configurations, this attribute will be set to
``None``.
"""
self._raise_if_unacknowledged("modified_count")
return self.__raw_result.get("nModified")
@property
def upserted_id(self):
"""The _id of the inserted document if an upsert took place. Otherwise
``None``.
"""
self._raise_if_unacknowledged("upserted_id")
return self.__raw_result.get("upserted")
class BulkWriteResult(_WriteResult):
"""An object wrapper for bulk API write results."""
__slots__ = ("__bulk_api_result", "__acknowledged")
@ -52,36 +111,23 @@ class BulkWriteResult(object):
:exc:`~pymongo.errors.InvalidOperation`.
"""
self.__bulk_api_result = bulk_api_result
self.__acknowledged = acknowledged
def __raise_if_unacknowledged(self, property_name):
"""Raise an exception on property access if unacknowledged."""
if not self.__acknowledged:
raise InvalidOperation("A value for %s is not available when "
"the write is unacknowledged. Check the "
"acknowledged attribute to avoid this "
"error." % (property_name,))
super(BulkWriteResult, self).__init__(acknowledged)
@property
def bulk_api_result(self):
"""The raw bulk API result."""
return self.__bulk_api_result
@property
def acknowledged(self):
"""Is this the result of an acknowledged bulk write operation?"""
return self.__acknowledged
@property
def inserted_count(self):
"""The number of documents inserted."""
self.__raise_if_unacknowledged("inserted_count")
self._raise_if_unacknowledged("inserted_count")
return self.__bulk_api_result.get("nInserted")
@property
def matched_count(self):
"""The number of documents matched for an update."""
self.__raise_if_unacknowledged("matched_count")
self._raise_if_unacknowledged("matched_count")
return self.__bulk_api_result.get("nMatched")
@property
@ -93,25 +139,25 @@ class BulkWriteResult(object):
version sharding configurations, this attribute will be set to
``None``.
"""
self.__raise_if_unacknowledged("modified_count")
self._raise_if_unacknowledged("modified_count")
return self.__bulk_api_result.get("nModified")
@property
def deleted_count(self):
"""The number of documents deleted."""
self.__raise_if_unacknowledged("deleted_count")
self._raise_if_unacknowledged("deleted_count")
return self.__bulk_api_result.get("nRemoved")
@property
def upserted_count(self):
"""The number of documents upserted."""
self.__raise_if_unacknowledged("upserted_count")
self._raise_if_unacknowledged("upserted_count")
return self.__bulk_api_result.get("nUpserted")
@property
def upserted_ids(self):
"""A map of operation index to the _id of the upserted document."""
self.__raise_if_unacknowledged("upserted_ids")
self._raise_if_unacknowledged("upserted_ids")
if self.__bulk_api_result:
return dict((upsert["index"], upsert["_id"])
for upsert in self.bulk_api_result["upserted"])

View File

@ -48,7 +48,7 @@ from pymongo.errors import (ConfigurationError,
WTimeoutError)
from pymongo.options import ReturnDocument
from pymongo.read_preferences import ReadPreference
from pymongo.results import InsertOneResult
from pymongo.results import InsertOneResult, UpdateResult
from pymongo.son_manipulator import SONManipulator
from pymongo.write_concern import WriteConcern
from test.test_client import IntegrationTest
@ -1072,6 +1072,127 @@ class TestCollection(IntegrationTest):
self.assertEqual(db.test.find_one(id1)["x"], 7)
self.assertEqual(db.test.find_one(id2)["x"], 1)
def test_replace_one(self):
db = self.db
db.drop_collection("test")
self.assertRaises(ValueError,
lambda: db.test.replace_one({}, {"$set": {"x": 1}}))
id1 = db.test.insert_one({"x": 1}).inserted_id
result = db.test.replace_one({"x": 1}, {"y": 1})
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(1, result.matched_count)
self.assertTrue(result.modified_count in (None, 1))
self.assertIsNone(result.upserted_id)
self.assertTrue(result.acknowledged)
self.assertEqual(1, db.test.count({"y": 1}))
self.assertEqual(0, db.test.count({"x": 1}))
self.assertEqual(db.test.find_one(id1)["y"], 1)
result = db.test.replace_one({"x": 2}, {"y": 2}, True)
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(0, result.matched_count)
self.assertTrue(result.modified_count in (None, 0))
self.assertTrue(isinstance(result.upserted_id, ObjectId))
self.assertTrue(result.acknowledged)
self.assertEqual(1, db.test.count({"y": 2}))
db = db.connection.get_database(db.name,
write_concern=WriteConcern(w=0))
result = db.test.replace_one({"x": 0}, {"y": 0})
self.assertTrue(isinstance(result, UpdateResult))
self.assertRaises(InvalidOperation, lambda: result.matched_count)
self.assertRaises(InvalidOperation, lambda: result.modified_count)
self.assertRaises(InvalidOperation, lambda: result.upserted_id)
self.assertFalse(result.acknowledged)
def test_update_one(self):
db = self.db
db.drop_collection("test")
self.assertRaises(ValueError,
lambda: db.test.update_one({}, {"x": 1}))
id1 = db.test.insert_one({"x": 5}).inserted_id
result = db.test.update_one({}, {"$inc": {"x": 1}})
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(1, result.matched_count)
self.assertTrue(result.modified_count in (None, 1))
self.assertIsNone(result.upserted_id)
self.assertTrue(result.acknowledged)
self.assertEqual(db.test.find_one(id1)["x"], 6)
id2 = db.test.insert_one({"x": 1}).inserted_id
result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}})
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(1, result.matched_count)
self.assertTrue(result.modified_count in (None, 1))
self.assertIsNone(result.upserted_id)
self.assertTrue(result.acknowledged)
self.assertEqual(db.test.find_one(id1)["x"], 7)
self.assertEqual(db.test.find_one(id2)["x"], 1)
result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True)
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(0, result.matched_count)
self.assertTrue(result.modified_count in (None, 0))
self.assertTrue(isinstance(result.upserted_id, ObjectId))
self.assertTrue(result.acknowledged)
db = db.connection.get_database(db.name,
write_concern=WriteConcern(w=0))
result = db.test.update_one({"x": 0}, {"$inc": {"x": 1}})
self.assertTrue(isinstance(result, UpdateResult))
self.assertRaises(InvalidOperation, lambda: result.matched_count)
self.assertRaises(InvalidOperation, lambda: result.modified_count)
self.assertRaises(InvalidOperation, lambda: result.upserted_id)
self.assertFalse(result.acknowledged)
def test_update_many(self):
db = self.db
db.drop_collection("test")
self.assertRaises(ValueError,
lambda: db.test.update_many({}, {"x": 1}))
db.test.insert_one({"x": 4, "y": 3})
db.test.insert_one({"x": 5, "y": 5})
db.test.insert_one({"x": 4, "y": 4})
result = db.test.update_many({"x": 4}, {"$set": {"y": 5}})
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(2, result.matched_count)
self.assertTrue(result.modified_count in (None, 2))
self.assertIsNone(result.upserted_id)
self.assertTrue(result.acknowledged)
self.assertEqual(3, db.test.count({"y": 5}))
result = db.test.update_many({"x": 5}, {"$set": {"y": 6}})
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(1, result.matched_count)
self.assertTrue(result.modified_count in (None, 1))
self.assertIsNone(result.upserted_id)
self.assertTrue(result.acknowledged)
self.assertEqual(1, db.test.count({"y": 6}))
result = db.test.update_many({"x": 2}, {"$set": {"y": 1}}, True)
self.assertTrue(isinstance(result, UpdateResult))
self.assertEqual(0, result.matched_count)
self.assertTrue(result.modified_count in (None, 0))
self.assertTrue(isinstance(result.upserted_id, ObjectId))
self.assertTrue(result.acknowledged)
db = db.connection.get_database(db.name,
write_concern=WriteConcern(w=0))
result = db.test.update_many({"x": 0}, {"$inc": {"x": 1}})
self.assertTrue(isinstance(result, UpdateResult))
self.assertRaises(InvalidOperation, lambda: result.matched_count)
self.assertRaises(InvalidOperation, lambda: result.modified_count)
self.assertRaises(InvalidOperation, lambda: result.upserted_id)
self.assertFalse(result.acknowledged)
def test_update_manipulate(self):
db = self.db
db.drop_collection("test")