PYTHON-2389 Add session support to find_raw_batches and aggregate_raw_batches (#658)
This commit is contained in:
parent
b823b95de1
commit
0e0c4fd944
@ -1484,21 +1484,16 @@ class Collection(common.BaseObject):
|
||||
>>> for batch in cursor:
|
||||
... print(bson.decode_all(batch))
|
||||
|
||||
.. note:: find_raw_batches does not support sessions or auto
|
||||
encryption.
|
||||
.. note:: find_raw_batches does not support auto encryption.
|
||||
|
||||
.. versionchanged:: 3.12
|
||||
Instead of ignoring the user-specified read concern, this method
|
||||
now sends it to the server when connected to MongoDB 3.6+.
|
||||
|
||||
Added session support.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
# OP_MSG with document stream returns is required to support
|
||||
# sessions.
|
||||
if "session" in kwargs:
|
||||
raise ConfigurationError(
|
||||
"find_raw_batches does not support sessions")
|
||||
|
||||
# OP_MSG is required to support encryption.
|
||||
if self.__database.client._encrypter:
|
||||
raise InvalidOperation(
|
||||
@ -2256,7 +2251,7 @@ class Collection(common.BaseObject):
|
||||
explicit_session=session is not None,
|
||||
**kwargs)
|
||||
|
||||
def aggregate_raw_batches(self, pipeline, **kwargs):
|
||||
def aggregate_raw_batches(self, pipeline, session=None, **kwargs):
|
||||
"""Perform an aggregation and retrieve batches of raw BSON.
|
||||
|
||||
Similar to the :meth:`aggregate` method but returns a
|
||||
@ -2273,28 +2268,25 @@ class Collection(common.BaseObject):
|
||||
>>> for batch in cursor:
|
||||
... print(bson.decode_all(batch))
|
||||
|
||||
.. note:: aggregate_raw_batches does not support sessions or auto
|
||||
encryption.
|
||||
.. note:: aggregate_raw_batches does not support auto encryption.
|
||||
|
||||
.. versionchanged:: 3.12
|
||||
Added session support.
|
||||
|
||||
.. versionadded:: 3.6
|
||||
"""
|
||||
# OP_MSG with document stream returns is required to support
|
||||
# sessions.
|
||||
if "session" in kwargs:
|
||||
raise ConfigurationError(
|
||||
"aggregate_raw_batches does not support sessions")
|
||||
|
||||
# OP_MSG is required to support encryption.
|
||||
if self.__database.client._encrypter:
|
||||
raise InvalidOperation(
|
||||
"aggregate_raw_batches does not support auto encryption")
|
||||
|
||||
return self._aggregate(_CollectionRawAggregationCommand,
|
||||
pipeline,
|
||||
RawBatchCommandCursor,
|
||||
session=None,
|
||||
explicit_session=False,
|
||||
**kwargs)
|
||||
with self.__database.client._tmp_session(session, close=False) as s:
|
||||
return self._aggregate(_CollectionRawAggregationCommand,
|
||||
pipeline,
|
||||
RawBatchCommandCursor,
|
||||
session=s,
|
||||
explicit_session=session is not None,
|
||||
**kwargs)
|
||||
|
||||
def watch(self, pipeline=None, full_document=None, resume_after=None,
|
||||
max_await_time_ms=None, batch_size=None, collation=None,
|
||||
|
||||
@ -468,6 +468,8 @@ class _RawBatchQuery(_Query):
|
||||
|
||||
class _RawBatchGetMore(_GetMore):
|
||||
def use_command(self, sock_info):
|
||||
# Compatibility checks.
|
||||
super(_RawBatchGetMore, self).use_command(sock_info)
|
||||
if sock_info.max_wire_version >= 8:
|
||||
# MongoDB 4.2+ supports exhaust over OP_MSG
|
||||
return True
|
||||
|
||||
@ -27,6 +27,7 @@ sys.path[0:0] = [""]
|
||||
|
||||
from bson import decode_all
|
||||
from bson.code import Code
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
from pymongo import (ASCENDING,
|
||||
DESCENDING,
|
||||
@ -44,6 +45,7 @@ from test import (client_context,
|
||||
unittest,
|
||||
IntegrationTest)
|
||||
from test.utils import (EventListener,
|
||||
OvertCommandListener,
|
||||
ignore_deprecations,
|
||||
rs_or_single_client,
|
||||
WhiteListEventListener)
|
||||
@ -1466,6 +1468,76 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_find_raw_transaction(self):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
with client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
batches = list(client[self.db.name].test.find_raw_batches(
|
||||
session=session).sort('_id'))
|
||||
cmd = listener.results['started'][0]
|
||||
self.assertEqual(cmd.command_name, 'find')
|
||||
self.assertEqual(cmd.command['$clusterTime'],
|
||||
decode_all(session.cluster_time.raw)[0])
|
||||
self.assertEqual(cmd.command['startTransaction'], True)
|
||||
self.assertEqual(cmd.command['txnNumber'], 1)
|
||||
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@client_context.require_sessions
|
||||
@client_context.require_failCommand_fail_point
|
||||
def test_find_raw_retryable_reads(self):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener],
|
||||
retryReads=True)
|
||||
with self.fail_point({
|
||||
'mode': {'times': 1}, 'data': {'failCommands': ['find'],
|
||||
'closeConnection': True}}):
|
||||
batches = list(
|
||||
client[self.db.name].test.find_raw_batches().sort('_id'))
|
||||
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
self.assertEqual(len(listener.results['started']), 2)
|
||||
for cmd in listener.results['started']:
|
||||
self.assertEqual(cmd.command_name, 'find')
|
||||
|
||||
@client_context.require_version_min(5, 0, 0)
|
||||
@client_context.require_no_standalone
|
||||
def test_find_raw_snapshot_reads(self):
|
||||
c = self.db.get_collection(
|
||||
"test", write_concern=WriteConcern(w="majority"))
|
||||
c.drop()
|
||||
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener],
|
||||
retryReads=True)
|
||||
db = client[self.db.name]
|
||||
with client.start_session(snapshot=True) as session:
|
||||
db.test.distinct('x', {}, session=session)
|
||||
batches = list(db.test.find_raw_batches(
|
||||
session=session).sort('_id'))
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
find_cmd = listener.results['started'][1].command
|
||||
self.assertEqual(find_cmd['readConcern']['level'], 'snapshot')
|
||||
self.assertIsNotNone(find_cmd['readConcern']['atClusterTime'])
|
||||
|
||||
def test_explain(self):
|
||||
c = self.db.test
|
||||
c.insert_one({})
|
||||
@ -1590,6 +1662,75 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_aggregate_raw_transaction(self):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener])
|
||||
with client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
batches = list(client[self.db.name].test.aggregate_raw_batches(
|
||||
[{'$sort': {'_id': 1}}], session=session))
|
||||
cmd = listener.results['started'][0]
|
||||
self.assertEqual(cmd.command_name, 'aggregate')
|
||||
self.assertEqual(cmd.command['$clusterTime'], session.cluster_time)
|
||||
self.assertEqual(cmd.command['startTransaction'], True)
|
||||
self.assertEqual(cmd.command['txnNumber'], 1)
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@client_context.require_sessions
|
||||
@client_context.require_failCommand_fail_point
|
||||
def test_aggregate_raw_retryable_reads(self):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener],
|
||||
retryReads=True)
|
||||
with self.fail_point({
|
||||
'mode': {'times': 1}, 'data': {'failCommands': ['aggregate'],
|
||||
'closeConnection': True}}):
|
||||
batches = list(client[self.db.name].test.aggregate_raw_batches(
|
||||
[{'$sort': {'_id': 1}}]))
|
||||
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
self.assertEqual(len(listener.results['started']), 3)
|
||||
cmds = listener.results['started']
|
||||
self.assertEqual(cmds[0].command_name, 'aggregate')
|
||||
self.assertEqual(cmds[1].command_name, 'aggregate')
|
||||
|
||||
@client_context.require_version_min(5, 0, -1)
|
||||
@client_context.require_no_standalone
|
||||
def test_aggregate_raw_snapshot_reads(self):
|
||||
c = self.db.get_collection(
|
||||
"test", write_concern=WriteConcern(w="majority"))
|
||||
c.drop()
|
||||
docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)]
|
||||
c.insert_many(docs)
|
||||
|
||||
listener = OvertCommandListener()
|
||||
client = rs_or_single_client(event_listeners=[listener],
|
||||
retryReads=True)
|
||||
db = client[self.db.name]
|
||||
with client.start_session(snapshot=True) as session:
|
||||
db.test.distinct('x', {}, session=session)
|
||||
batches = list(db.test.aggregate_raw_batches(
|
||||
[{'$sort': {'_id': 1}}], session=session))
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
find_cmd = listener.results['started'][1].command
|
||||
self.assertEqual(find_cmd['readConcern']['level'], 'snapshot')
|
||||
self.assertIsNotNone(find_cmd['readConcern']['atClusterTime'])
|
||||
|
||||
def test_server_error(self):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
|
||||
@ -824,6 +824,12 @@ class TestCausalConsistency(unittest.TestCase):
|
||||
lambda coll, session: coll.count_documents({}, session=session))
|
||||
self._test_reads(
|
||||
lambda coll, session: coll.distinct('foo', session=session))
|
||||
self._test_reads(
|
||||
lambda coll, session: list(coll.aggregate_raw_batches(
|
||||
[], session=session)))
|
||||
self._test_reads(
|
||||
lambda coll, session: list(coll.find_raw_batches(
|
||||
{}, session=session)))
|
||||
|
||||
# SERVER-40938 removed support for casually consistent mapReduce.
|
||||
map_reduce_exc = None
|
||||
@ -841,16 +847,6 @@ class TestCausalConsistency(unittest.TestCase):
|
||||
'function() {}', 'function() {}', session=session),
|
||||
exception=map_reduce_exc)
|
||||
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self._test_reads,
|
||||
lambda coll, session: list(
|
||||
coll.aggregate_raw_batches([], session=session)))
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self._test_reads,
|
||||
lambda coll, session: list(
|
||||
coll.find_raw_batches({}, session=session)))
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self._test_reads,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user