From 0e0c4fd9443723e5c613a5964143528ee4ffe13d Mon Sep 17 00:00:00 2001 From: Prashant Mital <5883388+prashantmital@users.noreply.github.com> Date: Wed, 30 Jun 2021 19:14:22 -0700 Subject: [PATCH] PYTHON-2389 Add session support to find_raw_batches and aggregate_raw_batches (#658) --- pymongo/collection.py | 38 +++++------- pymongo/message.py | 2 + test/test_cursor.py | 141 ++++++++++++++++++++++++++++++++++++++++++ test/test_session.py | 16 ++--- 4 files changed, 164 insertions(+), 33 deletions(-) diff --git a/pymongo/collection.py b/pymongo/collection.py index 64ec74b75..2fa0d9288 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -1484,21 +1484,16 @@ class Collection(common.BaseObject): >>> for batch in cursor: ... print(bson.decode_all(batch)) - .. note:: find_raw_batches does not support sessions or auto - encryption. + .. note:: find_raw_batches does not support auto encryption. .. versionchanged:: 3.12 Instead of ignoring the user-specified read concern, this method now sends it to the server when connected to MongoDB 3.6+. + Added session support. + .. versionadded:: 3.6 """ - # OP_MSG with document stream returns is required to support - # sessions. - if "session" in kwargs: - raise ConfigurationError( - "find_raw_batches does not support sessions") - # OP_MSG is required to support encryption. if self.__database.client._encrypter: raise InvalidOperation( @@ -2256,7 +2251,7 @@ class Collection(common.BaseObject): explicit_session=session is not None, **kwargs) - def aggregate_raw_batches(self, pipeline, **kwargs): + def aggregate_raw_batches(self, pipeline, session=None, **kwargs): """Perform an aggregation and retrieve batches of raw BSON. Similar to the :meth:`aggregate` method but returns a @@ -2273,28 +2268,25 @@ class Collection(common.BaseObject): >>> for batch in cursor: ... print(bson.decode_all(batch)) - .. note:: aggregate_raw_batches does not support sessions or auto - encryption. + .. note:: aggregate_raw_batches does not support auto encryption. + + .. versionchanged:: 3.12 + Added session support. .. versionadded:: 3.6 """ - # OP_MSG with document stream returns is required to support - # sessions. - if "session" in kwargs: - raise ConfigurationError( - "aggregate_raw_batches does not support sessions") - # OP_MSG is required to support encryption. if self.__database.client._encrypter: raise InvalidOperation( "aggregate_raw_batches does not support auto encryption") - return self._aggregate(_CollectionRawAggregationCommand, - pipeline, - RawBatchCommandCursor, - session=None, - explicit_session=False, - **kwargs) + with self.__database.client._tmp_session(session, close=False) as s: + return self._aggregate(_CollectionRawAggregationCommand, + pipeline, + RawBatchCommandCursor, + session=s, + explicit_session=session is not None, + **kwargs) def watch(self, pipeline=None, full_document=None, resume_after=None, max_await_time_ms=None, batch_size=None, collation=None, diff --git a/pymongo/message.py b/pymongo/message.py index 13cee23cd..79b56ec1a 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -468,6 +468,8 @@ class _RawBatchQuery(_Query): class _RawBatchGetMore(_GetMore): def use_command(self, sock_info): + # Compatibility checks. + super(_RawBatchGetMore, self).use_command(sock_info) if sock_info.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True diff --git a/test/test_cursor.py b/test/test_cursor.py index dc0be8f68..021e4d7cb 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -27,6 +27,7 @@ sys.path[0:0] = [""] from bson import decode_all from bson.code import Code +from bson.raw_bson import RawBSONDocument from bson.son import SON from pymongo import (ASCENDING, DESCENDING, @@ -44,6 +45,7 @@ from test import (client_context, unittest, IntegrationTest) from test.utils import (EventListener, + OvertCommandListener, ignore_deprecations, rs_or_single_client, WhiteListEventListener) @@ -1466,6 +1468,76 @@ class TestRawBatchCursor(IntegrationTest): self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) + @client_context.require_transactions + def test_find_raw_transaction(self): + c = self.db.test + c.drop() + docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + c.insert_many(docs) + + listener = OvertCommandListener() + client = rs_or_single_client(event_listeners=[listener]) + with client.start_session() as session: + with session.start_transaction(): + batches = list(client[self.db.name].test.find_raw_batches( + session=session).sort('_id')) + cmd = listener.results['started'][0] + self.assertEqual(cmd.command_name, 'find') + self.assertEqual(cmd.command['$clusterTime'], + decode_all(session.cluster_time.raw)[0]) + self.assertEqual(cmd.command['startTransaction'], True) + self.assertEqual(cmd.command['txnNumber'], 1) + + self.assertEqual(1, len(batches)) + self.assertEqual(docs, decode_all(batches[0])) + + @client_context.require_sessions + @client_context.require_failCommand_fail_point + def test_find_raw_retryable_reads(self): + c = self.db.test + c.drop() + docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + c.insert_many(docs) + + listener = OvertCommandListener() + client = rs_or_single_client(event_listeners=[listener], + retryReads=True) + with self.fail_point({ + 'mode': {'times': 1}, 'data': {'failCommands': ['find'], + 'closeConnection': True}}): + batches = list( + client[self.db.name].test.find_raw_batches().sort('_id')) + + self.assertEqual(1, len(batches)) + self.assertEqual(docs, decode_all(batches[0])) + self.assertEqual(len(listener.results['started']), 2) + for cmd in listener.results['started']: + self.assertEqual(cmd.command_name, 'find') + + @client_context.require_version_min(5, 0, 0) + @client_context.require_no_standalone + def test_find_raw_snapshot_reads(self): + c = self.db.get_collection( + "test", write_concern=WriteConcern(w="majority")) + c.drop() + docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + c.insert_many(docs) + + listener = OvertCommandListener() + client = rs_or_single_client(event_listeners=[listener], + retryReads=True) + db = client[self.db.name] + with client.start_session(snapshot=True) as session: + db.test.distinct('x', {}, session=session) + batches = list(db.test.find_raw_batches( + session=session).sort('_id')) + self.assertEqual(1, len(batches)) + self.assertEqual(docs, decode_all(batches[0])) + + find_cmd = listener.results['started'][1].command + self.assertEqual(find_cmd['readConcern']['level'], 'snapshot') + self.assertIsNotNone(find_cmd['readConcern']['atClusterTime']) + def test_explain(self): c = self.db.test c.insert_one({}) @@ -1590,6 +1662,75 @@ class TestRawBatchCommandCursor(IntegrationTest): self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) + @client_context.require_transactions + def test_aggregate_raw_transaction(self): + c = self.db.test + c.drop() + docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + c.insert_many(docs) + + listener = OvertCommandListener() + client = rs_or_single_client(event_listeners=[listener]) + with client.start_session() as session: + with session.start_transaction(): + batches = list(client[self.db.name].test.aggregate_raw_batches( + [{'$sort': {'_id': 1}}], session=session)) + cmd = listener.results['started'][0] + self.assertEqual(cmd.command_name, 'aggregate') + self.assertEqual(cmd.command['$clusterTime'], session.cluster_time) + self.assertEqual(cmd.command['startTransaction'], True) + self.assertEqual(cmd.command['txnNumber'], 1) + self.assertEqual(1, len(batches)) + self.assertEqual(docs, decode_all(batches[0])) + + @client_context.require_sessions + @client_context.require_failCommand_fail_point + def test_aggregate_raw_retryable_reads(self): + c = self.db.test + c.drop() + docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + c.insert_many(docs) + + listener = OvertCommandListener() + client = rs_or_single_client(event_listeners=[listener], + retryReads=True) + with self.fail_point({ + 'mode': {'times': 1}, 'data': {'failCommands': ['aggregate'], + 'closeConnection': True}}): + batches = list(client[self.db.name].test.aggregate_raw_batches( + [{'$sort': {'_id': 1}}])) + + self.assertEqual(1, len(batches)) + self.assertEqual(docs, decode_all(batches[0])) + self.assertEqual(len(listener.results['started']), 3) + cmds = listener.results['started'] + self.assertEqual(cmds[0].command_name, 'aggregate') + self.assertEqual(cmds[1].command_name, 'aggregate') + + @client_context.require_version_min(5, 0, -1) + @client_context.require_no_standalone + def test_aggregate_raw_snapshot_reads(self): + c = self.db.get_collection( + "test", write_concern=WriteConcern(w="majority")) + c.drop() + docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + c.insert_many(docs) + + listener = OvertCommandListener() + client = rs_or_single_client(event_listeners=[listener], + retryReads=True) + db = client[self.db.name] + with client.start_session(snapshot=True) as session: + db.test.distinct('x', {}, session=session) + batches = list(db.test.aggregate_raw_batches( + [{'$sort': {'_id': 1}}], session=session)) + self.assertEqual(1, len(batches)) + self.assertEqual(docs, decode_all(batches[0])) + + find_cmd = listener.results['started'][1].command + self.assertEqual(find_cmd['readConcern']['level'], 'snapshot') + self.assertIsNotNone(find_cmd['readConcern']['atClusterTime']) + def test_server_error(self): c = self.db.test c.drop() diff --git a/test/test_session.py b/test/test_session.py index ab78e30f8..41837ab21 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -824,6 +824,12 @@ class TestCausalConsistency(unittest.TestCase): lambda coll, session: coll.count_documents({}, session=session)) self._test_reads( lambda coll, session: coll.distinct('foo', session=session)) + self._test_reads( + lambda coll, session: list(coll.aggregate_raw_batches( + [], session=session))) + self._test_reads( + lambda coll, session: list(coll.find_raw_batches( + {}, session=session))) # SERVER-40938 removed support for casually consistent mapReduce. map_reduce_exc = None @@ -841,16 +847,6 @@ class TestCausalConsistency(unittest.TestCase): 'function() {}', 'function() {}', session=session), exception=map_reduce_exc) - self.assertRaises( - ConfigurationError, - self._test_reads, - lambda coll, session: list( - coll.aggregate_raw_batches([], session=session))) - self.assertRaises( - ConfigurationError, - self._test_reads, - lambda coll, session: list( - coll.find_raw_batches({}, session=session))) self.assertRaises( ConfigurationError, self._test_reads,