PYTHON-2389 Add session support to find_raw_batches and aggregate_raw_batches (#658)

This commit is contained in:
Prashant Mital 2021-06-30 19:14:22 -07:00 committed by GitHub
parent b823b95de1
commit 0e0c4fd944
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 164 additions and 33 deletions

View File

@ -1484,21 +1484,16 @@ class Collection(common.BaseObject):
>>> for batch in cursor:
... print(bson.decode_all(batch))
.. note:: find_raw_batches does not support sessions or auto
encryption.
.. note:: find_raw_batches does not support auto encryption.
.. versionchanged:: 3.12
Instead of ignoring the user-specified read concern, this method
now sends it to the server when connected to MongoDB 3.6+.
Added session support.
.. versionadded:: 3.6
"""
# OP_MSG with document stream returns is required to support
# sessions.
if "session" in kwargs:
raise ConfigurationError(
"find_raw_batches does not support sessions")
# OP_MSG is required to support encryption.
if self.__database.client._encrypter:
raise InvalidOperation(
@ -2256,7 +2251,7 @@ class Collection(common.BaseObject):
explicit_session=session is not None,
**kwargs)
def aggregate_raw_batches(self, pipeline, **kwargs):
def aggregate_raw_batches(self, pipeline, session=None, **kwargs):
"""Perform an aggregation and retrieve batches of raw BSON.
Similar to the :meth:`aggregate` method but returns a
@ -2273,28 +2268,25 @@ class Collection(common.BaseObject):
>>> for batch in cursor:
... print(bson.decode_all(batch))
.. note:: aggregate_raw_batches does not support sessions or auto
encryption.
.. note:: aggregate_raw_batches does not support auto encryption.
.. versionchanged:: 3.12
Added session support.
.. versionadded:: 3.6
"""
# OP_MSG with document stream returns is required to support
# sessions.
if "session" in kwargs:
raise ConfigurationError(
"aggregate_raw_batches does not support sessions")
# OP_MSG is required to support encryption.
if self.__database.client._encrypter:
raise InvalidOperation(
"aggregate_raw_batches does not support auto encryption")
return self._aggregate(_CollectionRawAggregationCommand,
pipeline,
RawBatchCommandCursor,
session=None,
explicit_session=False,
**kwargs)
with self.__database.client._tmp_session(session, close=False) as s:
return self._aggregate(_CollectionRawAggregationCommand,
pipeline,
RawBatchCommandCursor,
session=s,
explicit_session=session is not None,
**kwargs)
def watch(self, pipeline=None, full_document=None, resume_after=None,
max_await_time_ms=None, batch_size=None, collation=None,

View File

@ -468,6 +468,8 @@ class _RawBatchQuery(_Query):
class _RawBatchGetMore(_GetMore):
def use_command(self, sock_info):
# Compatibility checks.
super(_RawBatchGetMore, self).use_command(sock_info)
if sock_info.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG
return True

View File

@ -27,6 +27,7 @@ sys.path[0:0] = [""]
from bson import decode_all
from bson.code import Code
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo import (ASCENDING,
DESCENDING,
@ -44,6 +45,7 @@ from test import (client_context,
unittest,
IntegrationTest)
from test.utils import (EventListener,
OvertCommandListener,
ignore_deprecations,
rs_or_single_client,
WhiteListEventListener)
@ -1466,6 +1468,76 @@ class TestRawBatchCursor(IntegrationTest):
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@client_context.require_transactions
def test_find_raw_transaction(self):
c = self.db.test
c.drop()
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
with client.start_session() as session:
with session.start_transaction():
batches = list(client[self.db.name].test.find_raw_batches(
session=session).sort('_id'))
cmd = listener.results['started'][0]
self.assertEqual(cmd.command_name, 'find')
self.assertEqual(cmd.command['$clusterTime'],
decode_all(session.cluster_time.raw)[0])
self.assertEqual(cmd.command['startTransaction'], True)
self.assertEqual(cmd.command['txnNumber'], 1)
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@client_context.require_sessions
@client_context.require_failCommand_fail_point
def test_find_raw_retryable_reads(self):
c = self.db.test
c.drop()
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener],
retryReads=True)
with self.fail_point({
'mode': {'times': 1}, 'data': {'failCommands': ['find'],
'closeConnection': True}}):
batches = list(
client[self.db.name].test.find_raw_batches().sort('_id'))
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
self.assertEqual(len(listener.results['started']), 2)
for cmd in listener.results['started']:
self.assertEqual(cmd.command_name, 'find')
@client_context.require_version_min(5, 0, 0)
@client_context.require_no_standalone
def test_find_raw_snapshot_reads(self):
c = self.db.get_collection(
"test", write_concern=WriteConcern(w="majority"))
c.drop()
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener],
retryReads=True)
db = client[self.db.name]
with client.start_session(snapshot=True) as session:
db.test.distinct('x', {}, session=session)
batches = list(db.test.find_raw_batches(
session=session).sort('_id'))
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
find_cmd = listener.results['started'][1].command
self.assertEqual(find_cmd['readConcern']['level'], 'snapshot')
self.assertIsNotNone(find_cmd['readConcern']['atClusterTime'])
def test_explain(self):
c = self.db.test
c.insert_one({})
@ -1590,6 +1662,75 @@ class TestRawBatchCommandCursor(IntegrationTest):
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@client_context.require_transactions
def test_aggregate_raw_transaction(self):
c = self.db.test
c.drop()
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener])
with client.start_session() as session:
with session.start_transaction():
batches = list(client[self.db.name].test.aggregate_raw_batches(
[{'$sort': {'_id': 1}}], session=session))
cmd = listener.results['started'][0]
self.assertEqual(cmd.command_name, 'aggregate')
self.assertEqual(cmd.command['$clusterTime'], session.cluster_time)
self.assertEqual(cmd.command['startTransaction'], True)
self.assertEqual(cmd.command['txnNumber'], 1)
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@client_context.require_sessions
@client_context.require_failCommand_fail_point
def test_aggregate_raw_retryable_reads(self):
c = self.db.test
c.drop()
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener],
retryReads=True)
with self.fail_point({
'mode': {'times': 1}, 'data': {'failCommands': ['aggregate'],
'closeConnection': True}}):
batches = list(client[self.db.name].test.aggregate_raw_batches(
[{'$sort': {'_id': 1}}]))
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
self.assertEqual(len(listener.results['started']), 3)
cmds = listener.results['started']
self.assertEqual(cmds[0].command_name, 'aggregate')
self.assertEqual(cmds[1].command_name, 'aggregate')
@client_context.require_version_min(5, 0, -1)
@client_context.require_no_standalone
def test_aggregate_raw_snapshot_reads(self):
c = self.db.get_collection(
"test", write_concern=WriteConcern(w="majority"))
c.drop()
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
c.insert_many(docs)
listener = OvertCommandListener()
client = rs_or_single_client(event_listeners=[listener],
retryReads=True)
db = client[self.db.name]
with client.start_session(snapshot=True) as session:
db.test.distinct('x', {}, session=session)
batches = list(db.test.aggregate_raw_batches(
[{'$sort': {'_id': 1}}], session=session))
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
find_cmd = listener.results['started'][1].command
self.assertEqual(find_cmd['readConcern']['level'], 'snapshot')
self.assertIsNotNone(find_cmd['readConcern']['atClusterTime'])
def test_server_error(self):
c = self.db.test
c.drop()

View File

@ -824,6 +824,12 @@ class TestCausalConsistency(unittest.TestCase):
lambda coll, session: coll.count_documents({}, session=session))
self._test_reads(
lambda coll, session: coll.distinct('foo', session=session))
self._test_reads(
lambda coll, session: list(coll.aggregate_raw_batches(
[], session=session)))
self._test_reads(
lambda coll, session: list(coll.find_raw_batches(
{}, session=session)))
# SERVER-40938 removed support for casually consistent mapReduce.
map_reduce_exc = None
@ -841,16 +847,6 @@ class TestCausalConsistency(unittest.TestCase):
'function() {}', 'function() {}', session=session),
exception=map_reduce_exc)
self.assertRaises(
ConfigurationError,
self._test_reads,
lambda coll, session: list(
coll.aggregate_raw_batches([], session=session)))
self.assertRaises(
ConfigurationError,
self._test_reads,
lambda coll, session: list(
coll.find_raw_batches({}, session=session)))
self.assertRaises(
ConfigurationError,
self._test_reads,