mongo-python-driver/test/asynchronous/test_session.py

1396 lines
55 KiB
Python

# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the client_session module."""
from __future__ import annotations
import copy
import sys
import time
from inspect import iscoroutinefunction
from io import BytesIO
from test.asynchronous.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple
sys.path[0:0] = [""]
from test.asynchronous import (
AsyncIntegrationTest,
AsyncUnitTest,
SkipTest,
async_client_context,
unittest,
)
from test.asynchronous.helpers import client_knobs
from test.utils_shared import (
EventListener,
HeartbeatEventListener,
OvertCommandListener,
async_wait_until,
)
from bson import DBRef
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo import ASCENDING, AsyncMongoClient, monitoring
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.helpers import anext
from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
from pymongo.read_concern import ReadConcern
_IS_SYNC = False
# Ignore auth commands like saslStart, so we can assert lsid is in all commands.
class SessionTestListener(EventListener):
    """Event listener that skips commands whose name starts with "sasl".

    Leaving auth commands such as saslStart out lets tests assert that an
    lsid is present in every command that is forwarded to the base listener.
    """

    def started(self, event):
        """Forward a command-started event unless it is a "sasl*" command."""
        if event.command_name.startswith("sasl"):
            return
        super().started(event)

    def succeeded(self, event):
        """Forward a command-succeeded event unless it is a "sasl*" command."""
        if event.command_name.startswith("sasl"):
            return
        super().succeeded(event)

    def failed(self, event):
        """Forward a command-failed event unless it is a "sasl*" command."""
        if event.command_name.startswith("sasl"):
            return
        super().failed(event)

    def first_command_started(self):
        """Return the first entry of ``started_events``.

        Raises AssertionError when ``started_events`` is empty.
        """
        recorded = self.started_events
        assert len(recorded) >= 1, "No command-started events"
        return recorded[0]
def session_ids(client):
    """Return the ``session_id`` of every entry in ``client``'s session pool.

    The pool (``client._topology._session_pool``) is shallow-copied before it
    is iterated.
    """
    pooled = copy.copy(client._topology._session_pool)
    ids = []
    for server_session in pooled:
        ids.append(server_session.session_id)
    return ids
class TestSession(AsyncIntegrationTest):
client2: AsyncMongoClient
sensitive_commands: Set[str]
@async_client_context.require_sessions
async def asyncSetUp(self):
    """Create the two clients and the command listeners used by each test."""
    await super().asyncSetUp()
    # Create a second client so we can make sure clients cannot share
    # sessions.
    self.client2 = await self.async_rs_or_single_client()

    # Redact no commands, so we can test user-admin commands have "lsid".
    # The original set is kept in self.sensitive_commands.
    redacted = monitoring._SENSITIVE_COMMANDS
    self.sensitive_commands = redacted.copy()
    redacted.clear()

    self.listener = SessionTestListener()
    self.session_checker_listener = SessionTestListener()
    listeners = [self.listener, self.session_checker_listener]
    self.client = await self.async_rs_or_single_client(event_listeners=listeners)
    self.db = self.client.pymongo_test

    # Remember which session ids were already pooled before the test body runs.
    pooled = session_ids(self.client)
    self.initial_lsids = {session_id["id"] for session_id in pooled}
async def asyncTearDown(self):
    """Restore redaction, drop the test database and check the session pool.

    Every lsid seen by ``session_checker_listener``, together with the ids
    recorded in ``initial_lsids``, must be among the ids currently in the
    client's session pool.
    """
    monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
    await self.client.drop_database("pymongo_test")

    started = self.session_checker_listener.started_events
    used_lsids = self.initial_lsids.copy()
    used_lsids.update(
        event.command["lsid"]["id"] for event in started if "lsid" in event.command
    )

    pooled_lsids = set()
    for session_id in session_ids(self.client):
        pooled_lsids.add(session_id["id"])

    # On sets, assertLessEqual checks that used_lsids is a subset.
    self.assertLessEqual(used_lsids, pooled_lsids)
    await super().asyncTearDown()
async def _test_ops(self, client, *ops):
    """Run each operation with an explicit, an implicit and a bound session.

    Each entry of ``ops`` is a ``(f, args, kw)`` triple where ``f`` is
    awaited as ``f(*args, **kw)``. For every triple this checks that:

    * with ``session=s`` every started command carries ``s.session_id`` as
      its lsid, the call fails once ``s`` has ended, and a session started
      by ``self.client2`` is rejected;
    * without a session every started command still carries an lsid, and
      (except on Jython/PyPy) that lsid is back in ``client``'s pool;
    * inside ``s.bind()`` the same explicit-session checks hold without
      passing ``session`` at all.

    The first event listener configured on ``client`` is used to observe
    the commands.
    """
    listener = client.options.event_listeners[0]
    for f, args, kw in ops:
        async with client.start_session() as s:
            listener.reset()
            s._materialize()
            last_use = s._server_session.last_use
            start = time.monotonic()
            self.assertLessEqual(last_use, start)
            # In case "f" modifies its inputs.
            args = copy.copy(args)
            kw = copy.copy(kw)
            kw["session"] = s
            await f(*args, **kw)
            self.assertGreaterEqual(len(listener.started_events), 1)
            for event in listener.started_events:
                self.assertIn(
                    "lsid",
                    event.command,
                    f"{f.__name__} sent no lsid with {event.command_name}",
                )
                self.assertEqual(
                    s.session_id,
                    event.command["lsid"],
                    f"{f.__name__} sent wrong lsid with {event.command_name}",
                )
            self.assertFalse(s.has_ended)
        # Leaving the "async with" block ends the session.
        self.assertTrue(s.has_ended)
        with self.assertRaisesRegex(InvalidOperation, "ended session"):
            await f(*args, **kw)
        # Test a session cannot be used on another client.
        async with self.client2.start_session() as s:
            # In case "f" modifies its inputs.
            args = copy.copy(args)
            kw = copy.copy(kw)
            kw["session"] = s
            with self.assertRaisesRegex(
                InvalidOperation,
                "Can only use session with the AsyncMongoClient that started it",
            ):
                await f(*args, **kw)
    # No explicit session.
    for f, args, kw in ops:
        listener.reset()
        await f(*args, **kw)
        self.assertGreaterEqual(len(listener.started_events), 1)
        lsids = []
        for event in listener.started_events:
            self.assertIn(
                "lsid",
                event.command,
                f"{f.__name__} sent no lsid with {event.command_name}",
            )
            lsids.append(event.command["lsid"])
        if not (sys.platform.startswith("java") or "PyPy" in sys.version):
            # Server session was returned to pool. Ignore interpreters with
            # non-deterministic GC.
            for lsid in lsids:
                self.assertIn(
                    lsid,
                    session_ids(client),
                    f"{f.__name__} did not return implicit session to pool",
                )
    # Explicit bound session
    for f, args, kw in ops:
        async with client.start_session() as s:
            async with s.bind():
                listener.reset()
                s._materialize()
                last_use = s._server_session.last_use
                start = time.monotonic()
                self.assertLessEqual(last_use, start)
                # In case "f" modifies its inputs.
                args = copy.copy(args)
                kw = copy.copy(kw)
                # Note: no "session" keyword is passed here; the session is
                # expected to come from the enclosing bind() block.
                await f(*args, **kw)
                self.assertGreaterEqual(len(listener.started_events), 1)
                for event in listener.started_events:
                    self.assertIn(
                        "lsid",
                        event.command,
                        f"{f.__name__} sent no lsid with {event.command_name}",
                    )
                    self.assertEqual(
                        s.session_id,
                        event.command["lsid"],
                        f"{f.__name__} sent wrong lsid with {event.command_name}",
                    )
                self.assertFalse(s.has_ended)
        self.assertTrue(s.has_ended)
        with self.assertRaisesRegex(InvalidOperation, "ended session"):
            async with s.bind():
                await f(*args, **kw)
        # Test a session cannot be used on another client.
        async with self.client2.start_session() as s:
            async with s.bind():
                # In case "f" modifies its inputs.
                args = copy.copy(args)
                kw = copy.copy(kw)
                with self.assertRaisesRegex(
                    InvalidOperation,
                    "Only the client that created the bound session can perform operations within its context block",
                ):
                    await f(*args, **kw)
async def test_implicit_sessions_checkout(self):
    """Concurrent implicit-session operations on a one-connection pool share one lsid.

    "To confirm that implicit sessions only allocate their server session after a
    successful connection checkout" test from Driver Sessions Spec.
    """
    listener = OvertCommandListener()
    client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
    lsid_set = set()
    succeeded = False

    async def target(op, *args):
        # Await coroutine functions; call anything else directly.
        if iscoroutinefunction(op):
            res = await op(*args)
        else:
            res = op(*args)
        # Cursors are drained so their commands are actually sent.
        if isinstance(res, (AsyncCursor, AsyncCommandCursor)):
            await res.to_list()

    # Retry up to 10 times because there is a known race condition that can cause multiple
    # sessions to be used: connection check in happens before session check in
    for _ in range(10):
        cursor = client.db.test.find({})
        ops: List[Tuple[Callable, List[Any]]] = [
            (client.db.test.find_one, [{"_id": 1}]),
            (client.db.test.delete_one, [{}]),
            (client.db.test.update_one, [{}, {"$set": {"x": 2}}]),
            (client.db.test.bulk_write, [[UpdateOne({}, {"$set": {"x": 2}})]]),
            (client.db.test.find_one_and_delete, [{}]),
            (client.db.test.find_one_and_update, [{}, {"$set": {"x": 1}}]),
            (client.db.test.find_one_and_replace, [{}, {}]),
            (client.db.test.aggregate, [[{"$limit": 1}]]),
            (client.db.test.find, []),
            (client.server_info, []),
            (client.db.aggregate, [[{"$listLocalSessions": {}}, {"$limit": 1}]]),
            (cursor.distinct, ["_id"]),
            (client.db.list_collections, []),
        ]
        listener.reset()

        tasks = []
        for op, args in ops:
            task = ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__)
            tasks.append(task)
            await task.start()
        self.assertEqual(len(tasks), len(ops))

        for task in tasks:
            await task.join()
            self.assertIsNone(task.exc)

        # Collect the distinct lsid ids used during this attempt.
        lsid_set.clear()
        for event in listener.started_events:
            lsid = event.command.get("lsid")
            if lsid:
                lsid_set.add(lsid["id"])

        if len(lsid_set) == 1:
            # Break on first success.
            succeeded = True
            break

    self.assertTrue(succeeded, lsid_set)
async def test_pool_lifo(self):
    """The most recently ended session is the first one handed out again.

    "Pool is LIFO" test from Driver Sessions Spec.
    """
    client = self.client
    first = client.start_session()
    second = client.start_session()
    first_id = first.session_id
    second_id = second.session_id
    await first.end_session()
    await second.end_session()

    # "second" was ended last, so its id must come back first.
    reused = client.start_session()
    self.assertEqual(second_id, reused.session_id)
    self.assertNotEqual(first_id, reused.session_id)

    reused_next = client.start_session()
    self.assertEqual(first_id, reused_next.session_id)
    self.assertNotEqual(second_id, reused_next.session_id)

    await reused.end_session()
    await reused_next.end_session()
async def test_end_session(self):
    """end_session() flips has_ended and makes session_id raise InvalidOperation."""
    # We test elsewhere that using an ended session throws InvalidOperation.
    session = self.client.start_session()
    self.assertFalse(session.has_ended)
    self.assertIsNotNone(session.session_id)

    await session.end_session()
    self.assertTrue(session.has_ended)

    with self.assertRaisesRegex(InvalidOperation, "ended session"):
        # Merely reading the attribute is expected to raise.
        session.session_id
async def test_end_sessions(self):
    """Closing a client sends endSessions and empties the session pool.

    With ``_MAX_END_SESSIONS + 1`` pooled sessions, exactly two endSessions
    commands are expected, and a second close() must send nothing.
    """
    # Use a new client so that the tearDown hook does not error.
    listener = SessionTestListener()
    client = await self.async_rs_or_single_client(event_listeners=[listener])

    # Start many sessions.
    session_count = _MAX_END_SESSIONS + 1
    sessions = []
    for _ in range(session_count):
        sessions.append(client.start_session())
    for session in sessions:
        session._materialize()
    for session in sessions:
        await session.end_session()

    # Closing the client should end all sessions and clear the pool.
    self.assertEqual(len(client._topology._session_pool), session_count)
    await client.close()
    self.assertEqual(len(client._topology._session_pool), 0)
    end_sessions = [
        event for event in listener.started_events if event.command_name == "endSessions"
    ]
    self.assertEqual(len(end_sessions), 2)

    # Closing again should not send any commands.
    listener.reset()
    await client.close()
    self.assertEqual(len(listener.started_events), 0)
async def test_client(self):
    """Run the shared session checks against client-level helpers."""
    client = self.client
    await self._test_ops(
        client,
        (client.server_info, [], {}),
        (client.list_database_names, [], {}),
        (client.drop_database, ["pymongo_test"], {}),
    )
async def test_database(self):
    """Run the shared session checks against database-level helpers."""
    client = self.client
    db = client.pymongo_test
    await self._test_ops(
        client,
        (db.command, ["ping"], {}),
        (db.create_collection, ["collection"], {}),
        (db.list_collection_names, [], {}),
        (db.validate_collection, ["collection"], {}),
        (db.drop_collection, ["collection"], {}),
        (db.dereference, [DBRef("collection", 1)], {}),
    )
@staticmethod
def collection_write_ops(coll):
    """Generate database write ops for tests.

    Returns a list of ``(callable, positional_args, keyword_args)`` triples
    bound to ``coll``.
    """
    crud_ops = [
        (coll.drop, [], {}),
        (coll.bulk_write, [[InsertOne({})]], {}),
        (coll.insert_one, [{}], {}),
        (coll.insert_many, [[{}, {}]], {}),
        (coll.replace_one, [{}, {}], {}),
        (coll.update_one, [{}, {"$set": {"a": 1}}], {}),
        (coll.update_many, [{}, {"$set": {"a": 1}}], {}),
        (coll.delete_one, [{}], {}),
        (coll.delete_many, [{}], {}),
        (coll.find_one_and_replace, [{}, {}], {}),
        (coll.find_one_and_update, [{}, {"$set": {"a": 1}}], {}),
        (coll.find_one_and_delete, [{}, {}], {}),
    ]
    rename_ops = [
        (coll.rename, ["collection2"], {}),
        # Drop collection2 between tests of "rename", above.
        (coll.database.drop_collection, ["collection2"], {}),
    ]
    index_and_output_ops = [
        (coll.create_indexes, [[IndexModel("a")]], {}),
        (coll.create_index, ["a"], {}),
        (coll.drop_index, ["a_1"], {}),
        (coll.drop_indexes, [], {}),
        (coll.aggregate, [[{"$out": "aggout"}]], {}),
    ]
    return crud_ops + rename_ops + index_and_output_ops
async def test_collection(self):
    """Run the shared session checks against collection write and read helpers."""
    client = self.client
    coll = client.pymongo_test.collection

    # Test some collection methods - the rest are in test_cursor.
    ops = [
        *self.collection_write_ops(coll),
        (coll.distinct, ["a"], {}),
        (coll.find_one, [], {}),
        (coll.count_documents, [{}], {}),
        (coll.list_indexes, [], {}),
        (coll.index_information, [], {}),
        (coll.options, [], {}),
        (coll.aggregate, [[]], {}),
    ]
    await self._test_ops(client, *ops)
async def test_cursor_clone(self):
    """A cloned cursor keeps an explicit session but gets its own implicit one."""
    coll = self.client.pymongo_test.collection
    # Ensure some batches.
    await coll.insert_many({} for _ in range(10))
    self.addAsyncCleanup(coll.drop)

    async with self.client.start_session() as session:
        explicit_cursor = coll.find(session=session)
        self.assertIs(explicit_cursor.session, session)
        explicit_clone = explicit_cursor.clone()
        self.assertIs(explicit_clone.session, session)

    # No explicit session.
    implicit_cursor = coll.find(batch_size=2)
    await anext(implicit_cursor)
    # Session is "owned" by cursor.
    self.assertIsNone(implicit_cursor.session)
    self.assertIsNotNone(implicit_cursor._session)

    implicit_clone = implicit_cursor.clone()
    await anext(implicit_clone)
    self.assertIsNone(implicit_clone.session)
    self.assertIsNotNone(implicit_clone._session)
    # The clone must not reuse the original cursor's session object.
    self.assertIsNot(implicit_cursor._session, implicit_clone._session)

    await implicit_cursor.close()
    await implicit_clone.close()
async def test_cursor(self):
    """Cursor helpers send the explicit session's lsid, or one consistent implicit lsid."""
    listener = self.listener
    client = self.client
    coll = client.pymongo_test.collection
    await coll.insert_many([{} for _ in range(1000)])

    # Test all cursor methods.
    ops = [
        ("find", lambda session: coll.find(session=session).to_list()),
        ("distinct", lambda session: coll.find(session=session).distinct("a")),
        ("explain", lambda session: coll.find(session=session).explain()),
    ]
    if _IS_SYNC:
        # getitem is only supported in the synchronous API
        ops.insert(1, ("getitem", lambda session: coll.find(session=session)[0]))

    for name, run in ops:
        async with client.start_session() as session:
            listener.reset()
            await run(session=session)
            started = listener.started_events
            self.assertGreaterEqual(len(started), 1)
            for event in started:
                self.assertIn(
                    "lsid",
                    event.command,
                    f"{name} sent no lsid with {event.command_name}",
                )
                self.assertEqual(
                    session.session_id,
                    event.command["lsid"],
                    f"{name} sent wrong lsid with {event.command_name}",
                )

        # The session has ended once the "async with" block is left.
        with self.assertRaisesRegex(InvalidOperation, "ended session"):
            await run(session=session)

    # No explicit session.
    for name, run in ops:
        listener.reset()
        await run(session=None)
        first = listener.first_command_started()
        self.assertIn("lsid", first.command, f"{name} sent no lsid with {first.command_name}")
        expected_lsid = first.command["lsid"]
        # Every later command must reuse the lsid of the first one.
        for event in listener.started_events[1:]:
            self.assertIn(
                "lsid", event.command, f"{name} sent no lsid with {event.command_name}"
            )
            self.assertEqual(
                expected_lsid,
                event.command["lsid"],
                f"{name} sent wrong lsid with {event.command_name}",
            )
async def test_gridfs(self):
    """Run the shared session checks against AsyncGridFS helpers."""
    client = self.client
    fs = AsyncGridFS(client.pymongo_test)
    one_mib = 1024 * 1024

    # The helper names below appear in _test_ops failure messages.
    async def new_file(session=None):
        grid_file = fs.new_file(_id=1, filename="f", session=session)
        # 1 MB, 5 chunks, to test that each chunk is fetched with same lsid.
        await grid_file.write(b"a" * one_mib)
        await grid_file.close()

    async def find(session=None):
        matches = await fs.find({"_id": 1}, session=session).to_list()
        for grid_out in matches:
            await grid_out.read()

    async def get(session=None):
        grid_out = await fs.get(1, session=session)
        await grid_out.read()

    async def get_version(session=None):
        grid_out = await fs.get_version("f", session=session)
        await grid_out.read()

    async def get_last_version(session=None):
        grid_out = await fs.get_last_version("f", session=session)
        await grid_out.read()

    async def find_list(session=None):
        await fs.find(session=session).to_list()

    ops = (
        (new_file, [], {}),
        (fs.put, [b"data"], {}),
        (get, [], {}),
        (get_version, [], {}),
        (get_last_version, [], {}),
        (fs.list, [], {}),
        (fs.find_one, [1], {}),
        (find_list, [], {}),
        (fs.exists, [1], {}),
        (find, [], {}),
        (fs.delete, [1], {}),
    )
    await self._test_ops(client, *ops)
async def test_gridfs_bucket(self):
    """Run the shared session checks against AsyncGridFSBucket helpers."""
    client = self.client
    bucket = AsyncGridFSBucket(client.pymongo_test)
    payload = b"a" * (1024 * 1024)

    # The helper names below appear in _test_ops failure messages.
    async def upload(session=None):
        stream = bucket.open_upload_stream("f", session=session)
        await stream.write(payload)
        await stream.close()

    async def upload_with_id(session=None):
        stream = bucket.open_upload_stream_with_id(1, "f1", session=session)
        await stream.write(payload)
        await stream.close()

    async def open_download_stream(session=None):
        stream = await bucket.open_download_stream(1, session=session)
        await stream.read()

    async def open_download_stream_by_name(session=None):
        stream = await bucket.open_download_stream_by_name("f", session=session)
        await stream.read()

    async def find(session=None):
        matches = await bucket.find({"_id": 1}, session=session).to_list()
        for grid_out in matches:
            await grid_out.read()

    sio = BytesIO()

    ops = (
        (upload, [], {}),
        (upload_with_id, [], {}),
        (bucket.upload_from_stream, ["f", b"data"], {}),
        (bucket.upload_from_stream_with_id, [2, "f", b"data"], {}),
        (open_download_stream, [], {}),
        (open_download_stream_by_name, [], {}),
        (bucket.download_to_stream, [1, sio], {}),
        (bucket.download_to_stream_by_name, ["f", sio], {}),
        (find, [], {}),
        (bucket.rename, [1, "f2"], {}),
        (bucket.rename_by_name, ["f2", "f3"], {}),
        # Delete both files so _test_ops can run these operations twice.
        (bucket.delete, [1], {}),
        (bucket.delete_by_name, ["f"], {}),
    )
    await self._test_ops(client, *ops)
async def test_gridfsbucket_cursor(self):
    """Check session lifetime for AsyncGridFSBucket cursors and the files they yield."""
    client = self.client
    bucket = AsyncGridFSBucket(client.pymongo_test)

    for file_id in (1, 2):
        upload = bucket.open_upload_stream_with_id(file_id, str(file_id))
        await upload.write(b"a" * 1048576)
        await upload.close()

    async with client.start_session() as session:
        cursor = bucket.find(session=session)
        async for grid_out in cursor:
            await grid_out.read()

        self.assertFalse(session.has_ended)

    self.assertTrue(session.has_ended)

    # No explicit session.
    cursor = bucket.find(batch_size=1)
    files = [await cursor.next()]

    implicit_session = cursor._session
    self.assertFalse(implicit_session.has_ended)
    # Finalizing the cursor by hand must end the session it owns.
    cursor.__del__()

    self.assertTrue(implicit_session.has_ended)
    self.assertIsNone(cursor._session)

    # Files are still valid, they use their own sessions.
    for grid_out in files:
        await grid_out.read()

    # Explicit session.
    async with client.start_session() as session:
        cursor = bucket.find(session=session)
        assert cursor.session is not None
        session = cursor.session
        files = await cursor.to_list()
        cursor.__del__()
        # Finalizing the cursor must not end an explicit session.
        self.assertFalse(session.has_ended)

        for grid_out in files:
            await grid_out.read()

    for grid_out in files:
        # Attempt to read the file again.
        await grid_out.seek(0)
        with self.assertRaisesRegex(InvalidOperation, "ended session"):
            await grid_out.read()
async def test_aggregate(self):
    """Run the shared session checks on aggregate, first empty then with documents."""
    client = self.client
    coll = client.pymongo_test.collection

    # The helper name appears in _test_ops failure messages.
    async def agg(session=None):
        cursor = await coll.aggregate([], batchSize=2, session=session)
        await cursor.to_list()

    # With empty collection.
    await self._test_ops(client, (agg, [], {}))

    # Now with documents.
    documents = [{} for _ in range(10)]
    await coll.insert_many(documents)
    self.addAsyncCleanup(coll.drop)
    await self._test_ops(client, (agg, [], {}))
async def test_killcursors(self):
    """Run the shared session checks on a cursor that is closed before exhaustion."""
    client = self.client
    coll = client.pymongo_test.collection
    documents = [{} for _ in range(10)]
    await coll.insert_many(documents)

    # The helper name appears in _test_ops failure messages.
    async def explicit_close(session=None):
        # Fetch one document, then close while more batches remain.
        open_cursor = coll.find(batch_size=2, session=session)
        await anext(open_cursor)
        await open_cursor.close()

    await self._test_ops(client, (explicit_close, [], {}))
async def test_aggregate_error(self):
    """A failing aggregate still returns its implicit session to the pool."""
    client = self.client
    listener = self.listener
    coll = client.pymongo_test.collection
    # 3.6.0 mongos only validates the aggregate pipeline when the
    # database exists.
    await coll.insert_one({})
    listener.reset()

    bad_pipeline = [{"$badOperation": {"bar": 1}}]
    with self.assertRaises(OperationFailure):
        await coll.aggregate(bad_pipeline)

    first = listener.first_command_started()
    self.assertEqual(first.command_name, "aggregate")
    # Session was returned to pool despite error.
    self.assertIn(first.command["lsid"], session_ids(client))
async def _test_cursor_helper(self, create_cursor, close_cursor):
    """Check when a cursor's session is returned to the pool.

    ``create_cursor(coll, session)`` and ``close_cursor(cursor)`` are both
    awaited. With ``session=None`` the cursor's own session must stay out of
    the pool until ``close_cursor`` has run; an explicit session must not be
    ended by ``close_cursor``.
    """
    client = self.client
    coll = client.pymongo_test.collection
    await coll.insert_many([{} for _ in range(1000)])

    implicit_cursor = await create_cursor(coll, None)
    await anext(implicit_cursor)
    # Session is "owned" by cursor.
    owned_session = implicit_cursor._session
    self.assertIsNotNone(owned_session)
    owned_lsid = owned_session.session_id
    await anext(implicit_cursor)

    # Cursor owns its session unto death.
    self.assertNotIn(owned_lsid, session_ids(client))
    await close_cursor(implicit_cursor)
    self.assertIn(owned_lsid, session_ids(client))

    # An explicit session is not ended by cursor.close() or list(cursor).
    async with client.start_session() as explicit_session:
        explicit_cursor = await create_cursor(coll, explicit_session)
        await anext(explicit_cursor)
        await close_cursor(explicit_cursor)
        self.assertFalse(explicit_session.has_ended)
        explicit_lsid = explicit_session.session_id

    self.assertTrue(explicit_session.has_ended)
    self.assertIn(explicit_lsid, session_ids(client))
async def test_cursor_close(self):
    """Run the cursor session checks with a find cursor released via close()."""

    async def create(coll, session):
        return coll.find(session=session)

    async def close(cursor):
        await cursor.close()

    await self._test_cursor_helper(create, close)
async def test_command_cursor_close(self):
    """Run the cursor session checks with an aggregate cursor released via close()."""

    async def create(coll, session):
        return await coll.aggregate([], session=session)

    async def close(cursor):
        await cursor.close()

    await self._test_cursor_helper(create, close)
async def test_cursor_del(self):
    """Run the cursor session checks with a find cursor released via __del__()."""

    async def create(coll, session):
        return coll.find(session=session)

    async def finalize(cursor):
        # Call the finalizer directly instead of relying on garbage collection.
        return cursor.__del__()

    await self._test_cursor_helper(create, finalize)
async def test_command_cursor_del(self):
    """Run the cursor session checks with an aggregate cursor released via __del__()."""

    async def create(coll, session):
        return await coll.aggregate([], session=session)

    async def finalize(cursor):
        # Call the finalizer directly instead of relying on garbage collection.
        return cursor.__del__()

    await self._test_cursor_helper(create, finalize)
async def test_cursor_exhaust(self):
    """Run the cursor session checks with a find cursor drained via to_list()."""

    async def create(coll, session):
        return coll.find(session=session)

    async def exhaust(cursor):
        await cursor.to_list()

    await self._test_cursor_helper(create, exhaust)
async def test_command_cursor_exhaust(self):
    """Run the cursor session checks with an aggregate cursor drained via to_list()."""

    async def create(coll, session):
        return await coll.aggregate([], session=session)

    async def exhaust(cursor):
        await cursor.to_list()

    await self._test_cursor_helper(create, exhaust)
async def test_cursor_limit_reached(self):
    """Run the cursor session checks with a find cursor capped by limit=4, batch_size=2."""

    async def create(coll, session):
        return coll.find(limit=4, batch_size=2, session=session)

    async def exhaust(cursor):
        await cursor.to_list()

    await self._test_cursor_helper(create, exhaust)
async def test_command_cursor_limit_reached(self):
    """Run the cursor session checks with an aggregate cursor using batchSize=900."""

    async def create(coll, session):
        return await coll.aggregate([], batchSize=900, session=session)

    async def exhaust(cursor):
        await cursor.to_list()

    await self._test_cursor_helper(create, exhaust)
async def _test_unacknowledged_ops(self, client, *ops):
    """Check session handling for operations run through ``client``.

    Each entry of ``ops`` is a ``(f, args, kw)`` triple. With an explicit
    session every operation must raise ConfigurationError without sending a
    command; without a session no command may carry an lsid. The only
    exception is the listCollections command that create_collection runs
    first, which must carry one. The visible caller passes a client created
    with ``w=0``.
    """
    listener = client.options.event_listeners[0]

    for op, op_args, op_kwargs in ops:
        op_name = op.__name__
        async with client.start_session() as session:
            listener.reset()
            # In case "op" modifies its inputs.
            call_args = copy.copy(op_args)
            call_kwargs = copy.copy(op_kwargs)
            call_kwargs["session"] = session
            with self.assertRaises(
                ConfigurationError, msg=f"{op_name} did not raise ConfigurationError"
            ):
                await op(*call_args, **call_kwargs)

            if op_name == "create_collection":
                # create_collection runs listCollections first.
                first = listener.started_events.pop(0)
                self.assertEqual("listCollections", first.command_name)
                self.assertIn(
                    "lsid",
                    first.command,
                    f"{op_name} sent no lsid with {first.command_name}",
                )

            # Should not run any command before raising an error.
            self.assertFalse(listener.started_events, f"{op_name} sent command")

        self.assertTrue(session.has_ended)

    # Unacknowledged write without a session does not send an lsid.
    for op, op_args, op_kwargs in ops:
        op_name = op.__name__
        listener.reset()
        await op(*op_args, **op_kwargs)
        self.assertGreaterEqual(len(listener.started_events), 1)

        if op_name == "create_collection":
            # create_collection runs listCollections first.
            first = listener.started_events.pop(0)
            self.assertEqual("listCollections", first.command_name)
            self.assertIn(
                "lsid",
                first.command,
                f"{op_name} sent no lsid with {first.command_name}",
            )

        for event in listener.started_events:
            self.assertNotIn(
                "lsid", event.command, f"{op_name} sent lsid with {event.command_name}"
            )
async def test_unacknowledged_writes(self):
    """Run the unacknowledged-write session checks with a ``w=0`` client."""
    # Ensure the collection exists.
    await self.client.pymongo_test.test_unacked_writes.insert_one({})

    client = await self.async_rs_or_single_client(w=0, event_listeners=[self.listener])
    db = client.pymongo_test
    coll = db.test_unacked_writes
    ops: list = [
        (client.drop_database, [db.name], {}),
        (db.create_collection, ["collection"], {}),
        (db.drop_collection, ["collection"], {}),
        *self.collection_write_ops(coll),
    ]
    await self._test_unacknowledged_ops(client, *ops)

    async def drop_db():
        # Returns True once the drop succeeds, False when it should be retried.
        try:
            await self.client.drop_database(db.name)
        except OperationFailure as exc:
            # Try again on BackgroundOperationInProgressForDatabase and
            # BackgroundOperationInProgressForNamespace.
            if exc.code not in (12586, 12587):
                raise
            return False
        return True

    await async_wait_until(drop_db, "dropped database after w=0 writes")
async def test_snapshot_incompatible_with_causal_consistency(self):
    """Only causal_consistency=True together with snapshot=True is rejected."""
    allowed = ((False, False), (False, True), (True, False))
    for causal, snapshot in allowed:
        async with self.client.start_session(causal_consistency=causal, snapshot=snapshot):
            pass

    with self.assertRaises(ConfigurationError):
        async with self.client.start_session(causal_consistency=True, snapshot=True):
            pass
async def test_session_not_copyable(self):
    """copy.copy() on a client session raises TypeError."""
    async with self.client.start_session() as session:
        with self.assertRaises(TypeError):
            copy.copy(session)
async def test_nested_session_binding(self):
    """Nested bind() blocks use the innermost session and restore the outer one on exit."""
    coll = self.client.pymongo_test.test
    await coll.insert_one({"x": 1})

    session1 = self.client.start_session()
    session2 = self.client.start_session()
    session1._materialize()
    session2._materialize()

    async def find_one_lsid():
        # Run find_one() without a session argument and return the lsid of
        # the first command the listener recorded for it.
        self.listener.reset()
        await coll.find_one()
        return self.listener.started_events[0].command.get("lsid")

    try:
        # Uses implicit session
        implicit_lsid = await find_one_lsid()
        self.assertIsNotNone(implicit_lsid)
        self.assertNotEqual(implicit_lsid, session1.session_id)
        self.assertNotEqual(implicit_lsid, session2.session_id)

        async with session1.bind(end_session=False):
            # Uses bound session1
            session1_lsid = await find_one_lsid()
            self.assertEqual(session1_lsid, session1.session_id)

            async with session2.bind(end_session=False):
                # Uses bound session2
                session2_lsid = await find_one_lsid()
                self.assertEqual(session2_lsid, session2.session_id)
                self.assertNotEqual(session2_lsid, session1.session_id)

            # Use bound session1 again
            session1_lsid = await find_one_lsid()
            self.assertEqual(session1_lsid, session1.session_id)
            self.assertNotEqual(session1_lsid, session2.session_id)

        # Uses implicit session
        implicit_lsid = await find_one_lsid()
        self.assertIsNotNone(implicit_lsid)
        self.assertNotEqual(implicit_lsid, session1.session_id)
        self.assertNotEqual(implicit_lsid, session2.session_id)
    finally:
        await session1.end_session()
        await session2.end_session()
async def test_session_binding_end_session(self):
    """bind() ends the session on exit unless end_session=False is passed."""
    client = self.client
    coll = client.pymongo_test.test
    await coll.insert_one({"x": 1})

    async with client.start_session().bind() as auto_ended:
        await coll.find_one()
    self.assertTrue(auto_ended.has_ended)

    async with client.start_session().bind(end_session=False) as kept_open:
        await coll.find_one()
    self.assertFalse(kept_open.has_ended)
    # This session was left open by bind(), so end it explicitly.
    await kept_open.end_session()
async def test_getmore_preserves_lsid_after_session_support_lost(self):
    """getMore keeps the session's lsid after connections stop reporting session support.

    A cursor is opened with an explicit session, ``supports_sessions`` is
    then set to False on every connection in ``server.pool.conns``, and the
    remaining batches are fetched. Every getMore must still carry the lsid
    that was sent with the initial find.
    """
    listener = OvertCommandListener()
    client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
    coll = client.pymongo_test.test
    await coll.drop()
    await coll.insert_many([{"x": i} for i in range(10)])
    self.addAsyncCleanup(coll.drop)

    async with client.start_session() as s:
        cursor = coll.find({}, batch_size=2, session=s)
        await anext(cursor)

        # Pass a default to next() so a missing "find" event fails with a
        # clear assertion instead of a StopIteration escaping the coroutine.
        find_event = next(
            (e for e in listener.started_events if e.command_name == "find"), None
        )
        self.assertIsNotNone(find_event, "expected a find command")
        lsid = find_event.command["lsid"]

        # Simulate a node stepping down: mark idle connections as not supporting sessions.
        for server in client._topology._servers.values():
            for conn in server.pool.conns:
                conn.supports_sessions = False

        listener.reset()
        await cursor.to_list()

        getmore_events = [e for e in listener.started_events if e.command_name == "getMore"]
        self.assertGreater(len(getmore_events), 0, "expected at least one getMore command")
        for event in getmore_events:
            self.assertIn(
                "lsid", event.command, "getMore must include lsid when session is materialized"
            )
            self.assertEqual(
                lsid, event.command["lsid"], "getMore lsid must match the session lsid from find"
            )
class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener
client: AsyncMongoClient
@async_client_context.require_sessions
async def asyncSetUp(self):
    """Create a SessionTestListener and a client that reports commands to it."""
    await super().asyncSetUp()
    listener = SessionTestListener()
    self.listener = listener
    self.client = await self.async_rs_or_single_client(event_listeners=[listener])
@async_client_context.require_no_standalone
async def test_core(self):
    """Check how a session tracks operation_time and cluster_time.

    Covers: both times start as None; the first read sends no readConcern
    and sets operation_time from the reply; an operation without the session
    leaves it untouched; a failed command never resets it to None; and
    advance_cluster_time / advance_operation_time validate their argument
    and copy the times into a second session.
    """
    async with self.client.start_session() as sess:
        self.assertIsNone(sess.cluster_time)
        self.assertIsNone(sess.operation_time)
        self.listener.reset()
        await self.client.pymongo_test.test.find_one(session=sess)
        started = self.listener.started_events[0]
        cmd = started.command
        self.assertIsNone(cmd.get("readConcern"))
        op_time = sess.operation_time
        self.assertIsNotNone(op_time)
        succeeded = self.listener.succeeded_events[0]
        reply = succeeded.reply
        self.assertEqual(op_time, reply.get("operationTime"))

        # No explicit session
        await self.client.pymongo_test.test.insert_one({})
        self.assertEqual(sess.operation_time, op_time)
        self.listener.reset()
        # Expect the specific server error rather than swallowing every
        # exception (a bare "except:" would also hide task cancellation).
        with self.assertRaises(OperationFailure):
            await self.client.pymongo_test.command("doesntexist", session=sess)
        failed = self.listener.failed_events[0]
        failed_op_time = failed.failure.get("operationTime")
        # Some older builds of MongoDB 3.5 / 3.6 return None for
        # operationTime when a command fails. Make sure we don't
        # change operation_time to None.
        if failed_op_time is None:
            self.assertIsNotNone(sess.operation_time)
        else:
            self.assertEqual(sess.operation_time, failed_op_time)

        async with self.client.start_session() as sess2:
            self.assertIsNone(sess2.cluster_time)
            self.assertIsNone(sess2.operation_time)
            self.assertRaises(TypeError, sess2.advance_cluster_time, 1)
            self.assertRaises(ValueError, sess2.advance_cluster_time, {})
            self.assertRaises(TypeError, sess2.advance_operation_time, 1)
            # No error
            assert sess.cluster_time is not None
            assert sess.operation_time is not None
            sess2.advance_cluster_time(sess.cluster_time)
            sess2.advance_operation_time(sess.operation_time)
            self.assertEqual(sess.cluster_time, sess2.cluster_time)
            self.assertEqual(sess.operation_time, sess2.operation_time)
async def _test_reads(self, op, exception=None):
    """Check that ``op`` sends the session's operation_time as afterClusterTime.

    ``op(coll, session)`` is awaited after an initial find_one on the same
    session. When ``exception`` is given, ``op`` must raise it.
    """
    coll = self.client.pymongo_test.test
    async with self.client.start_session() as sess:
        await coll.find_one({}, session=sess)
        operation_time = sess.operation_time
        self.assertIsNotNone(operation_time)
        self.listener.reset()

        if not exception:
            await op(coll, sess)
        else:
            with self.assertRaises(exception):
                await op(coll, sess)

        # Inspect the first command that was started by "op".
        first_command = self.listener.started_events[0].command
        read_concern = first_command.get("readConcern", {})
        act = read_concern.get("afterClusterTime")
        self.assertEqual(operation_time, act)
@async_client_context.require_no_standalone
async def test_reads(self):
    """Read helpers send afterClusterTime; estimated_document_count rejects a session."""
    # Make sure the collection exists.
    await self.client.pymongo_test.test.insert_one({})

    async def aggregate(coll, session):
        cursor = await coll.aggregate([], session=session)
        return await cursor.to_list()

    async def aggregate_raw(coll, session):
        cursor = await coll.aggregate_raw_batches([], session=session)
        return await cursor.to_list()

    async def find_raw(coll, session):
        return await coll.find_raw_batches({}, session=session).to_list()

    read_ops = (
        aggregate,
        lambda coll, session: coll.find({}, session=session).to_list(),
        lambda coll, session: coll.find_one({}, session=session),
        lambda coll, session: coll.count_documents({}, session=session),
        lambda coll, session: coll.distinct("foo", session=session),
        aggregate_raw,
        find_raw,
    )
    for read_op in read_ops:
        await self._test_reads(read_op)

    with self.assertRaises(ConfigurationError):
        await self._test_reads(
            lambda coll, session: coll.estimated_document_count(session=session)
        )
async def _test_writes(self, op):
    """Check that a read following ``op`` sends afterClusterTime.

    ``op(coll, session)`` is awaited first; the find_one that follows on the
    same session must send the session's operation_time as afterClusterTime.
    """
    coll = self.client.pymongo_test.test
    async with self.client.start_session() as sess:
        await op(coll, sess)
        operation_time = sess.operation_time
        self.assertIsNotNone(operation_time)

        self.listener.reset()
        await coll.find_one({}, session=sess)

        # Inspect the first command started by the find_one above.
        first_command = self.listener.started_events[0].command
        read_concern = first_command.get("readConcern", {})
        act = read_concern.get("afterClusterTime")
        self.assertEqual(operation_time, act)
@async_client_context.require_no_standalone
async def test_writes(self):
    # Each write helper below is passed through _test_writes, which checks
    # the afterClusterTime sent by the read that follows the write.
    #
    # The sequence is kept exactly as written: later operations appear to
    # rely on earlier ones (e.g. drop_index("foo_1") presumably targets the
    # index made by create_index("foo")).
    write_ops = [
        lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session),
        lambda coll, session: coll.insert_one({}, session=session),
        lambda coll, session: coll.insert_many([{}], session=session),
        lambda coll, session: coll.replace_one({"_id": 1}, {"x": 1}, session=session),
        lambda coll, session: coll.update_one({}, {"$set": {"X": 1}}, session=session),
        lambda coll, session: coll.update_many({}, {"$set": {"x": 1}}, session=session),
        lambda coll, session: coll.delete_one({}, session=session),
        lambda coll, session: coll.delete_many({}, session=session),
        lambda coll, session: coll.find_one_and_replace({"x": 1}, {"y": 1}, session=session),
        lambda coll, session: coll.find_one_and_update(
            {"y": 1}, {"$set": {"x": 1}}, session=session
        ),
        lambda coll, session: coll.find_one_and_delete({"x": 1}, session=session),
        lambda coll, session: coll.create_index("foo", session=session),
        lambda coll, session: coll.create_indexes(
            [IndexModel([("bar", ASCENDING)])], session=session
        ),
        lambda coll, session: coll.drop_index("foo_1", session=session),
        lambda coll, session: coll.drop_indexes(session=session),
    ]
    for write_op in write_ops:
        await self._test_writes(write_op)
async def _test_no_read_concern(self, op):
    """Check that ``op`` sends no ``readConcern`` in its first command.

    A ``find_one`` is issued first so the session has an operation time;
    ``op(coll, session)`` is then awaited and the first command captured
    by the listener must not contain a ``readConcern`` field.

    :param op: callable taking ``(collection, session)`` and returning an
        awaitable that performs the operation under test.
    """
    collection = self.client.pymongo_test.test
    async with self.client.start_session() as session:
        await collection.find_one({}, session=session)
        self.assertIsNotNone(session.operation_time)

        # Only commands sent by ``op`` should be inspected below.
        self.listener.reset()
        await op(collection, session)

        first_command = self.listener.started_events[0].command
        self.assertIsNone(first_command.get("readConcern"))
@async_client_context.require_no_standalone
async def test_writes_do_not_include_read_concern(self):
    # Each operation below is passed through _test_no_read_concern, which
    # checks that the first command it sends has no readConcern field.
    #
    # The sequence is kept exactly as written: later operations appear to
    # rely on earlier ones (e.g. drop_index("foo_1") presumably targets the
    # index made by create_index("foo")).
    ops = [
        lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session),
        lambda coll, session: coll.insert_one({}, session=session),
        lambda coll, session: coll.insert_many([{}], session=session),
        lambda coll, session: coll.replace_one({"_id": 1}, {"x": 1}, session=session),
        lambda coll, session: coll.update_one({}, {"$set": {"X": 1}}, session=session),
        lambda coll, session: coll.update_many({}, {"$set": {"x": 1}}, session=session),
        lambda coll, session: coll.delete_one({}, session=session),
        lambda coll, session: coll.delete_many({}, session=session),
        lambda coll, session: coll.find_one_and_replace({"x": 1}, {"y": 1}, session=session),
        lambda coll, session: coll.find_one_and_update(
            {"y": 1}, {"$set": {"x": 1}}, session=session
        ),
        lambda coll, session: coll.find_one_and_delete({"x": 1}, session=session),
        lambda coll, session: coll.create_index("foo", session=session),
        lambda coll, session: coll.create_indexes(
            [IndexModel([("bar", ASCENDING)])], session=session
        ),
        lambda coll, session: coll.drop_index("foo_1", session=session),
        lambda coll, session: coll.drop_indexes(session=session),
        # Not a write, but explain also doesn't support readConcern.
        lambda coll, session: coll.find({}, session=session).explain(),
    ]
    for op in ops:
        await self._test_no_read_concern(op)
@async_client_context.require_no_standalone
async def test_get_more_does_not_include_read_concern(self):
    # A getMore issued by a cursor running in a session that already has an
    # operation time must not carry a readConcern field.
    coll = self.client.pymongo_test.test
    async with self.client.start_session() as sess:
        # Give the session an operation time before opening the cursor.
        await coll.find_one({}, session=sess)
        operation_time = sess.operation_time
        self.assertIsNotNone(operation_time)

        # At least two documents are needed so that batch_size(1) leaves
        # something for a getMore to fetch.
        await coll.insert_many([{}, {}])

        # Open the cursor in the session created above. Without
        # ``session=sess`` the getMore would not run in this session at all
        # and the assertion below would hold trivially.
        cursor = coll.find({}, session=sess).batch_size(1)
        await anext(cursor)

        # Drop the initial find so the first recorded command is the getMore.
        self.listener.reset()
        await cursor.to_list()
        started = self.listener.started_events[0]
        self.assertEqual(started.command_name, "getMore")
        self.assertIsNone(started.command.get("readConcern"))
async def test_session_not_causal(self):
    # With causal_consistency=False, a read that follows a write in the
    # same session must not send afterClusterTime.
    coll = self.client.pymongo_test.test
    async with self.client.start_session(causal_consistency=False) as s:
        await coll.insert_one({}, session=s)
        self.listener.reset()
        await coll.find_one({}, session=s)

        first_command = self.listener.started_events[0].command
        read_concern = first_command.get("readConcern", {})
        self.assertIsNone(read_concern.get("afterClusterTime"))
@async_client_context.require_standalone
async def test_server_not_causal(self):
    # Runs only against a standalone server: even with
    # causal_consistency=True, the read that follows a write in the same
    # session must not send afterClusterTime.
    coll = self.client.pymongo_test.test
    async with self.client.start_session(causal_consistency=True) as s:
        await coll.insert_one({}, session=s)
        self.listener.reset()
        await coll.find_one({}, session=s)

        first_command = self.listener.started_events[0].command
        read_concern = first_command.get("readConcern", {})
        self.assertIsNone(read_concern.get("afterClusterTime"))
@async_client_context.require_no_standalone
async def test_read_concern(self):
    # Checks the readConcern document sent by find_one in a causally
    # consistent session, first with the collection's default read concern
    # and then with an explicit "majority" read concern.
    async with self.client.start_session(causal_consistency=True) as s:
        default_coll = self.client.pymongo_test.test
        await default_coll.insert_one({}, session=s)

        # Default read concern: afterClusterTime is present, level is not.
        self.listener.reset()
        await default_coll.find_one({}, session=s)
        sent = self.listener.started_events[0].command.get("readConcern")
        self.assertIsNotNone(sent)
        self.assertIsNone(sent.get("level"))
        self.assertIsNotNone(sent.get("afterClusterTime"))

        # Explicit read concern: both level and afterClusterTime are present.
        majority_coll = default_coll.with_options(read_concern=ReadConcern("majority"))
        self.listener.reset()
        await majority_coll.find_one({}, session=s)
        sent = self.listener.started_events[0].command.get("readConcern")
        self.assertIsNotNone(sent)
        self.assertEqual(sent.get("level"), "majority")
        self.assertIsNotNone(sent.get("afterClusterTime"))
@async_client_context.require_no_standalone
async def test_cluster_time_with_server_support(self):
    # On a non-standalone deployment, a find_one that follows an insert
    # must include $clusterTime in the command it sends.
    coll = self.client.pymongo_test.test
    await coll.insert_one({})
    self.listener.reset()
    await coll.find_one({})

    first_command = self.listener.started_events[0].command
    self.assertIsNotNone(first_command.get("$clusterTime"))
@async_client_context.require_standalone
async def test_cluster_time_no_server_support(self):
    # On a standalone server, a find_one that follows an insert must not
    # include $clusterTime in the command it sends.
    coll = self.client.pymongo_test.test
    await coll.insert_one({})
    self.listener.reset()
    await coll.find_one({})

    first_command = self.listener.started_events[0].command
    self.assertIsNone(first_command.get("$clusterTime"))
class TestClusterTime(AsyncIntegrationTest):
    """Tests for how the client sends and tracks ``$clusterTime``.

    Every test is skipped when the server's hello response contains no
    ``$clusterTime`` field.
    """

    async def asyncSetUp(self):
        await super().asyncSetUp()
        if "$clusterTime" not in (await async_client_context.hello):
            raise SkipTest("$clusterTime not supported")

    # Sessions prose test: 3) $clusterTime in commands
    async def test_cluster_time(self):
        listener = SessionTestListener()
        client = await self.async_rs_or_single_client(event_listeners=[listener])
        collection = client.pymongo_test.collection
        # Prepare for tests of find() and aggregate().
        await collection.insert_many([{} for _ in range(10)])
        self.addAsyncCleanup(collection.drop)
        self.addAsyncCleanup(client.pymongo_test.collection2.drop)

        async def rename_and_drop():
            # Ensure collection exists.
            await collection.insert_one({})
            await collection.rename("collection2")
            await client.pymongo_test.collection2.drop()

        async def insert_and_find():
            cursor = collection.find().batch_size(1)
            for _ in range(10):
                # Advance the cluster time.
                await collection.insert_one({})
                await anext(cursor)

            await cursor.close()

        async def insert_and_aggregate():
            cursor = (await collection.aggregate([], batchSize=1)).batch_size(1)
            for _ in range(5):
                # Advance the cluster time.
                await collection.insert_one({})
                await anext(cursor)

            await cursor.close()

        async def aggregate():
            await (await collection.aggregate([])).to_list()

        # (label, zero-argument callable returning an awaitable). The label
        # is what failure messages report: most callables are lambdas, whose
        # ``__name__`` is the uninformative "<lambda>".
        ops = [
            # Tests from Driver Sessions Spec.
            ("ping", lambda: client.admin.command("ping")),
            ("aggregate", aggregate),
            ("find", lambda: collection.find().to_list()),
            ("insert_one", lambda: collection.insert_one({})),
            # Additional PyMongo tests.
            ("insert_and_find", insert_and_find),
            ("insert_and_aggregate", insert_and_aggregate),
            ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})),
            ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})),
            ("delete_one", lambda: collection.delete_one({})),
            ("delete_many", lambda: collection.delete_many({})),
            ("bulk_write", lambda: collection.bulk_write([InsertOne({})])),
            ("rename_and_drop", rename_and_drop),
        ]

        for name, f in ops:
            listener.reset()
            # Call f() twice, insert to advance clusterTime, call f() again.
            await f()
            await f()
            await collection.insert_one({})
            await f()

            self.assertGreaterEqual(len(listener.started_events), 1)
            for i, event in enumerate(listener.started_events):
                self.assertIn(
                    "$clusterTime",
                    event.command,
                    f"{name} sent no $clusterTime with {event.command_name}",
                )

                if i > 0:
                    # Each command after the first must send a cluster time
                    # at least as new as the one in the previous reply.
                    succeeded = listener.succeeded_events[i - 1]
                    self.assertIn(
                        "$clusterTime",
                        succeeded.reply,
                        f"{name} received no $clusterTime with {succeeded.command_name}",
                    )
                    self.assertGreaterEqual(
                        event.command["$clusterTime"]["clusterTime"],
                        succeeded.reply["$clusterTime"]["clusterTime"],
                        f"{name} sent wrong $clusterTime with {event.command_name}",
                    )

    # Sessions prose test: 20) Drivers do not gossip `$clusterTime` on SDAM commands
    async def test_cluster_time_not_used_by_sdam(self):
        heartbeat_listener = HeartbeatEventListener()
        cmd_listener = OvertCommandListener()
        with client_knobs(min_heartbeat_interval=0.01):
            c1 = await self.async_single_client(
                event_listeners=[heartbeat_listener, cmd_listener], heartbeatFrequencyMS=10
            )
            cluster_time = (await c1.admin.command({"ping": 1}))["$clusterTime"]
            self.assertEqual(c1._topology.max_cluster_time(), cluster_time)

            # Advance the server's $clusterTime by performing an insert via another client.
            await self.db.test.insert_one({"advance": "$clusterTime"})
            # Wait until the client C1 processes the next pair of SDAM heartbeat
            # started + succeeded events.
            heartbeat_listener.reset()

            async def next_heartbeat():
                # True once a started event is immediately followed by a
                # succeeded event in the recorded heartbeat events.
                events = heartbeat_listener.events
                return any(
                    isinstance(first, monitoring.ServerHeartbeatStartedEvent)
                    and isinstance(second, monitoring.ServerHeartbeatSucceededEvent)
                    for first, second in zip(events, events[1:])
                )

            await async_wait_until(
                next_heartbeat, "never found pair of heartbeat started + succeeded events"
            )

            # Assert that C1's max $clusterTime is still the same and has not been
            # updated by SDAM.
            cmd_listener.reset()
            await c1.admin.command({"ping": 1})
            started = cmd_listener.started_events[0]
            self.assertEqual(started.command_name, "ping")
            self.assertEqual(started.command["$clusterTime"], cluster_time)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()