PYTHON-4533 - Convert test/test_transactions.py to async (#1732)

This commit is contained in:
Noah Stapp 2024-07-15 16:45:59 -07:00 committed by GitHub
parent 875688cecc
commit 1b3dea3f03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1311 additions and 25 deletions

View File

@ -154,7 +154,7 @@ class AsyncGridFS:
"""
async with AsyncGridIn(self._collection, **kwargs) as grid_file:
await grid_file.write(data)
return await grid_file._id
return grid_file._id
async def get(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> AsyncGridOut:
"""Get a file from GridFS by ``"_id"``.

View File

@ -144,6 +144,7 @@ from typing import (
Any,
AsyncContextManager,
Callable,
Coroutine,
Mapping,
MutableMapping,
NoReturn,
@ -598,7 +599,7 @@ class AsyncClientSession:
async def with_transaction(
self,
callback: Callable[[AsyncClientSession], _T],
callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]],
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
@ -693,7 +694,7 @@ class AsyncClientSession:
read_concern, write_concern, read_preference, max_commit_time_ms
)
try:
ret = callback(self)
ret = await callback(self)
except Exception as exc:
if self.in_transaction:
await self.abort_transaction()

View File

@ -0,0 +1,585 @@
# Copyright 2018-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Execute Transactions Spec tests."""
from __future__ import annotations
import os
import sys
from io import BytesIO
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
sys.path[0:0] = [""]
from test.asynchronous import async_client_context, unittest
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
from test.utils import (
OvertCommandListener,
async_rs_client,
async_single_client,
wait_until,
)
from typing import List
from bson import encode
from bson.raw_bson import RawBSONDocument
from pymongo import WriteConcern
from pymongo.asynchronous import client_session
from pymongo.asynchronous.client_session import TransactionOptions
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.helpers import anext
from pymongo.errors import (
CollectionInvalid,
ConfigurationError,
ConnectionFailure,
InvalidOperation,
OperationFailure,
)
from pymongo.operations import IndexModel, InsertOne
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
_IS_SYNC = False
_TXN_TESTS_DEBUG = os.environ.get("TRANSACTION_TESTS_DEBUG")
# Max number of operations to perform after a transaction to prove unpinning
# occurs. Chosen so that there's a low false positive rate. With 2 mongoses,
# 50 attempts yields a one in a quadrillion chance of a false positive
# (1/(0.5^50)).
UNPIN_TEST_MAX_ATTEMPTS = 50
class AsyncTransactionsBase(AsyncSpecRunner):
@classmethod
async def _setup_class(cls):
await super()._setup_class()
if async_client_context.supports_transactions():
for address in async_client_context.mongoses:
cls.mongos_clients.append(await async_single_client("{}:{}".format(*address)))
@classmethod
async def _tearDown_class(cls):
for client in cls.mongos_clients:
await client.aclose()
await super()._tearDown_class()
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
if (
"secondary" in self.id()
and not async_client_context.is_mongos
and not async_client_context.has_secondaries
):
raise unittest.SkipTest("No secondaries")
class TestTransactions(AsyncTransactionsBase):
RUN_ON_SERVERLESS = True
@async_client_context.require_transactions
def test_transaction_options_validation(self):
default_options = TransactionOptions()
self.assertIsNone(default_options.read_concern)
self.assertIsNone(default_options.write_concern)
self.assertIsNone(default_options.read_preference)
self.assertIsNone(default_options.max_commit_time_ms)
# No error when valid options are provided.
TransactionOptions(
read_concern=ReadConcern(),
write_concern=WriteConcern(),
read_preference=ReadPreference.PRIMARY,
max_commit_time_ms=10000,
)
with self.assertRaisesRegex(TypeError, "read_concern must be "):
TransactionOptions(read_concern={}) # type: ignore
with self.assertRaisesRegex(TypeError, "write_concern must be "):
TransactionOptions(write_concern={}) # type: ignore
with self.assertRaisesRegex(
ConfigurationError, "transactions do not support unacknowledged write concern"
):
TransactionOptions(write_concern=WriteConcern(w=0))
with self.assertRaisesRegex(TypeError, "is not valid for read_preference"):
TransactionOptions(read_preference={}) # type: ignore
with self.assertRaisesRegex(TypeError, "max_commit_time_ms must be an integer or None"):
TransactionOptions(max_commit_time_ms="10000") # type: ignore
@async_client_context.require_transactions
async def test_transaction_write_concern_override(self):
"""Test txn overrides Client/Database/Collection write_concern."""
client = await async_rs_client(w=0)
self.addAsyncCleanup(client.aclose)
db = client.test
coll = db.test
await coll.insert_one({})
async with client.start_session() as s:
async with await s.start_transaction(write_concern=WriteConcern(w=1)):
self.assertTrue((await coll.insert_one({}, session=s)).acknowledged)
self.assertTrue((await coll.insert_many([{}, {}], session=s)).acknowledged)
self.assertTrue((await coll.bulk_write([InsertOne({})], session=s)).acknowledged)
self.assertTrue((await coll.replace_one({}, {}, session=s)).acknowledged)
self.assertTrue(
(await coll.update_one({}, {"$set": {"a": 1}}, session=s)).acknowledged
)
self.assertTrue(
(await coll.update_many({}, {"$set": {"a": 1}}, session=s)).acknowledged
)
self.assertTrue((await coll.delete_one({}, session=s)).acknowledged)
self.assertTrue((await coll.delete_many({}, session=s)).acknowledged)
await coll.find_one_and_delete({}, session=s)
await coll.find_one_and_replace({}, {}, session=s)
await coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s)
unsupported_txn_writes: list = [
(client.drop_database, [db.name], {}),
(db.drop_collection, ["collection"], {}),
(coll.drop, [], {}),
(coll.rename, ["collection2"], {}),
# Drop collection2 between tests of "rename", above.
(coll.database.drop_collection, ["collection2"], {}),
(coll.create_indexes, [[IndexModel("a")]], {}),
(coll.create_index, ["a"], {}),
(coll.drop_index, ["a_1"], {}),
(coll.drop_indexes, [], {}),
(coll.aggregate, [[{"$out": "aggout"}]], {}),
]
# Creating a collection in a transaction requires MongoDB 4.4+.
if async_client_context.version < (4, 3, 4):
unsupported_txn_writes.extend(
[
(db.create_collection, ["collection"], {}),
]
)
for op in unsupported_txn_writes:
op, args, kwargs = op
async with client.start_session() as s:
kwargs["session"] = s
await s.start_transaction(write_concern=WriteConcern(w=1))
with self.assertRaises(OperationFailure):
await op(*args, **kwargs)
await s.abort_transaction()
@async_client_context.require_transactions
@async_client_context.require_multiple_mongoses
async def test_unpin_for_next_transaction(self):
# Increase localThresholdMS and wait until both nodes are discovered
# to avoid false positives.
client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
self.addAsyncCleanup(client.aclose)
async with client.start_session() as s:
# Session is pinned to Mongos.
async with await s.start_transaction():
await coll.insert_one({}, session=s)
addresses = set()
for _ in range(UNPIN_TEST_MAX_ATTEMPTS):
async with await s.start_transaction():
cursor = await coll.find({}, session=s)
self.assertTrue(await anext(cursor))
addresses.add(cursor.address)
# Break early if we can.
if len(addresses) > 1:
break
self.assertGreater(len(addresses), 1)
@async_client_context.require_transactions
@async_client_context.require_multiple_mongoses
async def test_unpin_for_non_transaction_operation(self):
# Increase localThresholdMS and wait until both nodes are discovered
# to avoid false positives.
client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
self.addAsyncCleanup(client.aclose)
async with client.start_session() as s:
# Session is pinned to Mongos.
async with await s.start_transaction():
await coll.insert_one({}, session=s)
addresses = set()
for _ in range(UNPIN_TEST_MAX_ATTEMPTS):
cursor = await coll.find({}, session=s)
self.assertTrue(await anext(cursor))
addresses.add(cursor.address)
# Break early if we can.
if len(addresses) > 1:
break
self.assertGreater(len(addresses), 1)
@async_client_context.require_transactions
@async_client_context.require_version_min(4, 3, 4)
async def test_create_collection(self):
client = async_client_context.client
db = client.pymongo_test
coll = db.test_create_collection
self.addAsyncCleanup(coll.drop)
# Use with_transaction to avoid StaleConfig errors on sharded clusters.
async def create_and_insert(session):
coll2 = await db.create_collection(coll.name, session=session)
self.assertEqual(coll, coll2)
await coll.insert_one({}, session=session)
async with client.start_session() as s:
await s.with_transaction(create_and_insert)
# Outside a transaction we raise CollectionInvalid on existing colls.
with self.assertRaises(CollectionInvalid):
await db.create_collection(coll.name)
# Inside a transaction we raise the OperationFailure from create.
async with client.start_session() as s:
await s.start_transaction()
with self.assertRaises(OperationFailure) as ctx:
await db.create_collection(coll.name, session=s)
self.assertEqual(ctx.exception.code, 48) # NamespaceExists
@async_client_context.require_transactions
async def test_gridfs_does_not_support_transactions(self):
client = async_client_context.client
db = client.pymongo_test
gfs = AsyncGridFS(db)
bucket = AsyncGridFSBucket(db)
async def gridfs_find(*args, **kwargs):
return await gfs.find(*args, **kwargs).next()
async def gridfs_open_upload_stream(*args, **kwargs):
await (await bucket.open_upload_stream(*args, **kwargs)).write(b"1")
gridfs_ops = [
(gfs.put, (b"123",)),
(gfs.get, (1,)),
(gfs.get_version, ("name",)),
(gfs.get_last_version, ("name",)),
(gfs.delete, (1,)),
(gfs.list, ()),
(gfs.find_one, ()),
(gridfs_find, ()),
(gfs.exists, ()),
(gridfs_open_upload_stream, ("name",)),
(
bucket.upload_from_stream,
(
"name",
b"data",
),
),
(
bucket.download_to_stream,
(
1,
BytesIO(),
),
),
(
bucket.download_to_stream_by_name,
(
"name",
BytesIO(),
),
),
(bucket.delete, (1,)),
(bucket.find, ()),
(bucket.open_download_stream, (1,)),
(bucket.open_download_stream_by_name, ("name",)),
(
bucket.rename,
(
1,
"new-name",
),
),
]
async with client.start_session() as s, await s.start_transaction():
for op, args in gridfs_ops:
with self.assertRaisesRegex(
InvalidOperation,
"GridFS does not support multi-document transactions",
):
await op(*args, session=s) # type: ignore
# Require 4.2+ for large (16MB+) transactions.
@async_client_context.require_version_min(4, 2)
@async_client_context.require_transactions
@unittest.skipIf(sys.platform == "win32", "Our Windows machines are too slow to pass this test")
async def test_transaction_starts_with_batched_write(self):
if "PyPy" in sys.version and async_client_context.tls:
self.skipTest(
"PYTHON-2937 PyPy is so slow sending large "
"messages over TLS that this test fails"
)
# Start a transaction with a batch of operations that needs to be
# split.
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
coll = client[self.db.name].test
await coll.delete_many({})
listener.reset()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(coll.drop)
large_str = "\0" * (1 * 1024 * 1024)
ops: List[InsertOne[RawBSONDocument]] = [
InsertOne(RawBSONDocument(encode({"a": large_str}))) for _ in range(48)
]
async with client.start_session() as session:
async with await session.start_transaction():
await coll.bulk_write(ops, session=session) # type: ignore[arg-type]
# Assert commands were constructed properly.
self.assertEqual(
["insert", "insert", "commitTransaction"], listener.started_command_names()
)
first_cmd = listener.started_events[0].command
self.assertTrue(first_cmd["startTransaction"])
lsid = first_cmd["lsid"]
txn_number = first_cmd["txnNumber"]
for event in listener.started_events[1:]:
self.assertNotIn("startTransaction", event.command)
self.assertEqual(lsid, event.command["lsid"])
self.assertEqual(txn_number, event.command["txnNumber"])
self.assertEqual(48, await coll.count_documents({}))
@async_client_context.require_transactions
async def test_transaction_direct_connection(self):
client = await async_single_client()
self.addAsyncCleanup(client.aclose)
coll = client.pymongo_test.test
# Make sure the collection exists.
await coll.insert_one({})
self.assertEqual(client.topology_description.topology_type_name, "Single")
ops = [
(coll.bulk_write, [[InsertOne[dict]({})]]),
(coll.insert_one, [{}]),
(coll.insert_many, [[{}, {}]]),
(coll.replace_one, [{}, {}]),
(coll.update_one, [{}, {"$set": {"a": 1}}]),
(coll.update_many, [{}, {"$set": {"a": 1}}]),
(coll.delete_one, [{}]),
(coll.delete_many, [{}]),
(coll.find_one_and_replace, [{}, {}]),
(coll.find_one_and_update, [{}, {"$set": {"a": 1}}]),
(coll.find_one_and_delete, [{}, {}]),
(coll.find_one, [{}]),
(coll.count_documents, [{}]),
(coll.distinct, ["foo"]),
(coll.aggregate, [[]]),
(coll.find, [{}]),
(coll.aggregate_raw_batches, [[]]),
(coll.find_raw_batches, [{}]),
(coll.database.command, ["find", coll.name]),
]
for f, args in ops:
async with client.start_session() as s, await s.start_transaction():
res = await f(*args, session=s) # type:ignore[operator]
if isinstance(res, (AsyncCommandCursor, AsyncCursor)):
await res.to_list()
class PatchSessionTimeout:
"""Patches the client_session's with_transaction timeout for testing."""
def __init__(self, mock_timeout):
self.real_timeout = client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT
self.mock_timeout = mock_timeout
def __enter__(self):
client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.mock_timeout
return self
def __exit__(self, exc_type, exc_val, exc_tb):
client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
@async_client_context.require_transactions
async def test_callback_raises_custom_error(self):
class _MyException(Exception):
pass
async def raise_error(_):
raise _MyException
async with self.client.start_session() as s:
with self.assertRaises(_MyException):
await s.with_transaction(raise_error)
@async_client_context.require_transactions
async def test_callback_returns_value(self):
async def callback(_):
return "Foo"
async with self.client.start_session() as s:
self.assertEqual(await s.with_transaction(callback), "Foo")
await self.db.test.insert_one({})
async def callback2(session):
await self.db.test.insert_one({}, session=session)
return "Foo"
async with self.client.start_session() as s:
self.assertEqual(await s.with_transaction(callback2), "Foo")
@async_client_context.require_transactions
async def test_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
coll = client[self.db.name].test
async def callback(session):
await coll.insert_one({}, session=session)
err: dict = {
"ok": 0,
"errmsg": "Transaction 7819 has been aborted.",
"code": 251,
"codeName": "NoSuchTransaction",
"errorLabels": ["TransientTransactionError"],
}
raise OperationFailure(err["errmsg"], err["code"], err)
# Create the collection.
await coll.insert_one({})
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
coll = client[self.db.name].test
async def callback(session):
await coll.insert_one({}, session=session)
# Create the collection.
await coll.insert_one({})
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {
"failCommands": ["commitTransaction"],
"errorCode": 251, # NoSuchTransaction
},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(OperationFailure):
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
@async_client_context.require_test_commands
@async_client_context.require_transactions
async def test_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
coll = client[self.db.name].test
async def callback(session):
await coll.insert_one({}, session=session)
# Create the collection.
await coll.insert_one({})
await self.set_fail_point(
{
"configureFailPoint": "failCommand",
"mode": {"times": 2},
"data": {"failCommands": ["commitTransaction"], "closeConnection": True},
}
)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}
)
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(ConnectionFailure):
await s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic
# retry).
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Tested here because this supports Motor's convenient transactions API.
@async_client_context.require_transactions
async def test_in_transaction_property(self):
client = async_client_context.client
coll = client.test.testcollection
await coll.insert_one({})
self.addAsyncCleanup(coll.drop)
async with client.start_session() as s:
self.assertFalse(s.in_transaction)
await s.start_transaction()
self.assertTrue(s.in_transaction)
await coll.insert_one({}, session=s)
self.assertTrue(s.in_transaction)
await s.commit_transaction()
self.assertFalse(s.in_transaction)
async with client.start_session() as s:
await s.start_transaction()
# commit empty transaction
await s.commit_transaction()
self.assertFalse(s.in_transaction)
async with client.start_session() as s:
await s.start_transaction()
await s.abort_transaction()
self.assertFalse(s.in_transaction)
# Using a callback
async def callback(session):
self.assertTrue(session.in_transaction)
async with client.start_session() as s:
self.assertFalse(s.in_transaction)
await s.with_transaction(callback)
self.assertFalse(s.in_transaction)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,691 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for testing driver specs."""
from __future__ import annotations
import functools
import threading
from collections import abc
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
from test.utils import (
CMAPListener,
CompareType,
EventListener,
OvertCommandListener,
ServerAndTopologyEventListener,
async_rs_client,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
prepare_spec_arguments,
)
from typing import List
from bson import ObjectId, decode, encode
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from gridfs import GridFSBucket
from pymongo.asynchronous import client_session
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class SpecRunnerThread(threading.Thread):
def __init__(self, name):
super().__init__()
self.name = name
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.ops = []
self.stopped = False
def schedule(self, work):
self.ops.append(work)
with self.cond:
self.cond.notify()
def stop(self):
self.stopped = True
with self.cond:
self.cond.notify()
def run(self):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
if self.ops:
try:
work = self.ops.pop(0)
work()
except Exception as exc:
self.exc = exc
self.stop()
class AsyncSpecRunner(AsyncIntegrationTest):
mongos_clients: List
knobs: client_knobs
listener: EventListener
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await super()._tearDown_class()
def setUp(self):
super().setUp()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
async def _set_fail_point(self, client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
await client.admin.command(cmd)
async def set_fail_point(self, command_args):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
await self._set_fail_point(client, command_args)
async def targeted_fail_point(self, session, fail_point):
"""Run the targetedFailPoint test operation.
Enable the fail point on the session's pinned mongos.
"""
clients = {c.address: c for c in self.mongos_clients}
client = clients[session._pinned_address]
await self._set_fail_point(client, fail_point)
self.addAsyncCleanup(self.set_fail_point, {"mode": "off"})
def assert_session_pinned(self, session):
"""Run the assertSessionPinned test operation.
Assert that the given session is pinned.
"""
self.assertIsNotNone(session._transaction.pinned_address)
def assert_session_unpinned(self, session):
"""Run the assertSessionUnpinned test operation.
Assert that the given session is not pinned.
"""
self.assertIsNone(session._pinned_address)
self.assertIsNone(session._transaction.pinned_address)
async def assert_collection_exists(self, database, collection):
"""Run the assertCollectionExists test operation."""
db = self.client[database]
self.assertIn(collection, await db.list_collection_names())
async def assert_collection_not_exists(self, database, collection):
"""Run the assertCollectionNotExists test operation."""
db = self.client[database]
self.assertNotIn(collection, await db.list_collection_names())
async def assert_index_exists(self, database, collection, index):
"""Run the assertIndexExists test operation."""
coll = self.client[database][collection]
self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()])
async def assert_index_not_exists(self, database, collection, index):
"""Run the assertIndexNotExists test operation."""
coll = self.client[database][collection]
self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()])
def assertErrorLabelsContain(self, exc, expected_labels):
labels = [l for l in expected_labels if exc.has_error_label(l)]
self.assertEqual(labels, expected_labels)
def assertErrorLabelsOmit(self, exc, omit_labels):
for label in omit_labels:
self.assertFalse(
exc.has_error_label(label), msg=f"error labels should not contain {label}"
)
async def kill_all_sessions(self):
clients = self.mongos_clients if self.mongos_clients else [self.client]
for client in clients:
try:
await client.admin.command("killAllSessions", [])
except OperationFailure:
# "operation was interrupted" by killing the command's
# own session.
pass
def check_command_result(self, expected_result, result):
# Only compare the keys in the expected result.
filtered_result = {}
for key in expected_result:
try:
filtered_result[key] = result[key]
except KeyError:
pass
self.assertEqual(filtered_result, expected_result)
# TODO: factor the following function with test_crud.py.
def check_result(self, expected_result, result):
if isinstance(result, _WriteResult):
for res in expected_result:
prop = camel_to_snake(res)
# SPEC-869: Only BulkWriteResult has upserted_count.
if prop == "upserted_count" and not isinstance(result, BulkWriteResult):
if result.upserted_id is not None:
upserted_count = 1
else:
upserted_count = 0
self.assertEqual(upserted_count, expected_result[res], prop)
elif prop == "inserted_ids":
# BulkWriteResult does not have inserted_ids.
if isinstance(result, BulkWriteResult):
self.assertEqual(len(expected_result[res]), result.inserted_count)
else:
# InsertManyResult may be compared to [id1] from the
# crud spec or {"0": id1} from the retryable write spec.
ids = expected_result[res]
if isinstance(ids, dict):
ids = [ids[str(i)] for i in range(len(ids))]
self.assertEqual(ids, result.inserted_ids, prop)
elif prop == "upserted_ids":
# Convert indexes from strings to integers.
ids = expected_result[res]
expected_ids = {}
for str_index in ids:
expected_ids[int(str_index)] = ids[str_index]
self.assertEqual(expected_ids, result.upserted_ids, prop)
else:
self.assertEqual(getattr(result, prop), expected_result[res], prop)
return True
else:
def _helper(expected_result, result):
if isinstance(expected_result, abc.Mapping):
for i in expected_result.keys():
self.assertEqual(expected_result[i], result[i])
elif isinstance(expected_result, list):
for i, k in zip(expected_result, result):
_helper(i, k)
else:
self.assertEqual(expected_result, result)
_helper(expected_result, result)
return None
def get_object_name(self, op):
"""Allow subclasses to override handling of 'object'
Transaction spec says 'object' is required.
"""
return op["object"]
@staticmethod
def parse_options(opts):
return parse_spec_options(opts)
async def run_operation(self, sessions, collection, operation):
original_collection = collection
name = camel_to_snake(operation["name"])
if name == "run_command":
name = "command"
elif name == "download_by_name":
name = "open_download_stream_by_name"
elif name == "download":
name = "open_download_stream"
elif name == "map_reduce":
self.skipTest("PyMongo does not support mapReduce")
elif name == "count":
self.skipTest("PyMongo does not support count")
database = collection.database
collection = database.get_collection(collection.name)
if "collectionOptions" in operation:
collection = collection.with_options(
**self.parse_options(operation["collectionOptions"])
)
object_name = self.get_object_name(operation)
if object_name == "gridfsbucket":
# Only create the GridFSBucket when we need it (for the gridfs
# retryable reads tests).
obj = GridFSBucket(database, bucket_name=collection.name)
else:
objects = {
"client": database.client,
"database": database,
"collection": collection,
"testRunner": self,
}
objects.update(sessions)
obj = objects[object_name]
# Combine arguments with options and handle special cases.
arguments = operation.get("arguments", {})
arguments.update(arguments.pop("options", {}))
self.parse_options(arguments)
cmd = getattr(obj, name)
with_txn_callback = functools.partial(
self.run_operations, sessions, original_collection, in_with_transaction=True
)
prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback)
if name == "run_on_thread":
args = {"sessions": sessions, "collection": collection}
args.update(arguments)
arguments = args
result = cmd(**dict(arguments))
# Cleanup open change stream cursors.
if name == "watch":
self.addAsyncCleanup(result.close)
if name == "aggregate":
if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]:
# Read from the primary to ensure causal consistency.
out = collection.database.get_collection(
arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
)
return out.find()
if "download" in name:
result = Binary(result.read())
if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor):
return await result.to_list()
return result
def allowable_errors(self, op):
"""Allow encryption spec to override expected error classes."""
return (PyMongoError,)
async def _run_op(self, sessions, collection, op, in_with_transaction):
expected_result = op.get("result")
if expect_error(op):
with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
await self.run_operation(sessions, collection, op.copy())
exc = context.exception
if expect_error_message(expected_result):
if isinstance(exc, BulkWriteError):
errmsg = str(exc.details).lower()
else:
errmsg = str(exc).lower()
self.assertIn(expected_result["errorContains"].lower(), errmsg)
if expect_error_code(expected_result):
self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
if expect_timeout_error(expected_result):
self.assertIsInstance(exc, PyMongoError)
if not exc.timeout:
# Re-raise the exception for better diagnostics.
raise exc
# Reraise the exception if we're in the with_transaction
# callback.
if in_with_transaction:
raise context.exception
else:
result = await self.run_operation(sessions, collection, op.copy())
if "result" in op:
if op["name"] == "runCommand":
self.check_command_result(expected_result, result)
else:
self.check_result(expected_result, result)
async def run_operations(self, sessions, collection, ops, in_with_transaction=False):
for op in ops:
await self._run_op(sessions, collection, op, in_with_transaction)
# TODO: factor with test_command_monitoring.py
def check_events(self, test, listener, session_ids):
events = listener.started_events
if not len(test["expectations"]):
return
# Give a nicer message when there are missing or extra events
cmds = decode_raw([event.command for event in events])
self.assertEqual(len(events), len(test["expectations"]), cmds)
for i, expectation in enumerate(test["expectations"]):
event_type = next(iter(expectation))
event = events[i]
# The tests substitute 42 for any number other than 0.
if event.command_name == "getMore" and event.command["getMore"]:
event.command["getMore"] = Int64(42)
elif event.command_name == "killCursors":
event.command["cursors"] = [Int64(42)]
elif event.command_name == "update":
# TODO: remove this once PYTHON-1744 is done.
# Add upsert and multi fields back into expectations.
updates = expectation[event_type]["command"]["updates"]
for update in updates:
update.setdefault("upsert", False)
update.setdefault("multi", False)
# Replace afterClusterTime: 42 with actual afterClusterTime.
expected_cmd = expectation[event_type]["command"]
expected_read_concern = expected_cmd.get("readConcern")
if expected_read_concern is not None:
time = expected_read_concern.get("afterClusterTime")
if time == 42:
actual_time = event.command.get("readConcern", {}).get("afterClusterTime")
if actual_time is not None:
expected_read_concern["afterClusterTime"] = actual_time
recovery_token = expected_cmd.get("recoveryToken")
if recovery_token == 42:
expected_cmd["recoveryToken"] = CompareType(dict)
# Replace lsid with a name like "session0" to match test.
if "lsid" in event.command:
for name, lsid in session_ids.items():
if event.command["lsid"] == lsid:
event.command["lsid"] = name
break
for attr, expected in expectation[event_type].items():
actual = getattr(event, attr)
expected = wrap_types(expected)
if isinstance(expected, dict):
for key, val in expected.items():
if val is None:
if key in actual:
self.fail(f"Unexpected key [{key}] in {actual!r}")
elif key not in actual:
self.fail(f"Expected key [{key}] in {actual!r}")
else:
self.assertEqual(
val, decode_raw(actual[key]), f"Key [{key}] in {actual}"
)
else:
self.assertEqual(actual, expected)
def maybe_skip_scenario(self, test):
if test.get("skipReason"):
self.skipTest(test.get("skipReason"))
def get_scenario_db_name(self, scenario_def):
"""Allow subclasses to override a test's database name."""
return scenario_def["database_name"]
def get_scenario_coll_name(self, scenario_def):
"""Allow subclasses to override a test's collection name."""
return scenario_def["collection_name"]
def get_outcome_coll_name(self, outcome, collection):
"""Allow subclasses to override outcome collection."""
return collection.name
async def run_test_ops(self, sessions, collection, test):
"""Added to allow retryable writes spec to override a test's
operation.
"""
await self.run_operations(sessions, collection, test["operations"])
def parse_client_options(self, opts):
"""Allow encryption spec to override a clientOptions parsing."""
# Convert test['clientOptions'] to dict to avoid a Jython bug using
# "**" with ScenarioDict.
return dict(opts)
async def setup_scenario(self, scenario_def):
"""Allow specs to override a test's setup."""
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
documents = scenario_def["data"]
# Setup the collection with as few majority writes as possible.
db = async_client_context.client.get_database(db_name)
coll_exists = bool(await db.list_collection_names(filter={"name": coll_name}))
if coll_exists:
await db[coll_name].delete_many({})
# Only use majority wc only on the final write.
wc = WriteConcern(w="majority")
if documents:
db.get_collection(coll_name, write_concern=wc).insert_many(documents)
elif not coll_exists:
# Ensure collection exists.
await db.create_collection(coll_name, write_concern=wc)
async def run_scenario(self, scenario_def, test):
self.maybe_skip_scenario(test)
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
await self.kill_all_sessions()
self.addAsyncCleanup(self.kill_all_sessions)
await self.setup_scenario(scenario_def)
database_name = self.get_scenario_db_name(scenario_def)
collection_name = self.get_scenario_coll_name(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
await c[database_name][collection_name].distinct("x")
# Configure the fail point before creating the client.
if "failPoint" in test:
fp = test["failPoint"]
await self.set_fail_point(fp)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
listener = OvertCommandListener()
pool_listener = CMAPListener()
server_listener = ServerAndTopologyEventListener()
# Create a new client, to avoid interference from pooled sessions.
client_options = self.parse_client_options(test["clientOptions"])
# MMAPv1 does not support retryable writes.
if (
client_options.get("retryWrites") is True
and async_client_context.storage_engine == "mmapv1"
):
self.skipTest("MMAPv1 does not support retryWrites=True")
use_multi_mongos = test["useMultipleMongoses"]
host = None
if use_multi_mongos:
if async_client_context.load_balancer or async_client_context.serverless:
host = async_client_context.MULTI_MONGOS_LB_URI
elif async_client_context.is_mongos:
host = async_client_context.mongos_seeds()
client = await async_rs_client(
h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
)
self.scenario_client = client
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addAsyncCleanup(client.aclose)
# Create session0 and session1.
sessions = {}
session_ids = {}
for i in range(2):
# Don't attempt to create sessions if they are not supported by
# the running server version.
if not async_client_context.sessions_enabled:
break
session_name = "session%d" % i
opts = camel_to_snake_args(test["sessionOptions"][session_name])
if "default_transaction_options" in opts:
txn_opts = self.parse_options(opts["default_transaction_options"])
txn_opts = client_session.TransactionOptions(**txn_opts)
opts["default_transaction_options"] = txn_opts
s = client.start_session(**dict(opts))
sessions[session_name] = s
# Store lsid so we can access it after end_session, in check_events.
session_ids[session_name] = s.session_id
self.addAsyncCleanup(end_sessions, sessions)
collection = client[database_name][collection_name]
await self.run_test_ops(sessions, collection, test)
await end_sessions(sessions)
self.check_events(test, listener, session_ids)
# Disable fail points.
if "failPoint" in test:
fp = test["failPoint"]
await self.set_fail_point(
{"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
# Assert final state is expected.
outcome = test["outcome"]
expected_c = outcome.get("collection")
if expected_c is not None:
outcome_coll_name = self.get_outcome_coll_name(outcome, collection)
# Read from the primary with local read concern to ensure causal
# consistency.
outcome_coll = async_client_context.client[collection.database.name].get_collection(
outcome_coll_name,
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.
self.assertEqual(wrap_types(expected_c["data"]), actual_data)
def expect_any_error(op):
if isinstance(op, dict):
return op.get("error")
return False
def expect_error_message(expected_result):
if isinstance(expected_result, dict):
return isinstance(expected_result["errorContains"], str)
return False
def expect_error_code(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorCodeName"]
return False
def expect_error_labels_contain(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsContain"]
return False
def expect_error_labels_omit(expected_result):
if isinstance(expected_result, dict):
return expected_result["errorLabelsOmit"]
return False
def expect_timeout_error(expected_result):
if isinstance(expected_result, dict):
return expected_result["isTimeoutError"]
return False
def expect_error(op):
expected_result = op.get("result")
return (
expect_any_error(op)
or expect_error_message(expected_result)
or expect_error_code(expected_result)
or expect_error_labels_contain(expected_result)
or expect_error_labels_omit(expected_result)
or expect_timeout_error(expected_result)
)
async def end_sessions(sessions):
for s in sessions.values():
# Aborts the transaction if it's open.
await s.end_session()
def decode_raw(val):
"""Decode RawBSONDocuments in the given container."""
if isinstance(val, (list, abc.Mapping)):
return decode(encode({"v": val}))["v"]
return val
TYPES = {
"binData": Binary,
"long": Int64,
"int": int,
"string": str,
"objectId": ObjectId,
"object": dict,
"array": list,
}
def wrap_types(val):
"""Support $$type assertion in command results."""
if isinstance(val, list):
return [wrap_types(v) for v in val]
if isinstance(val, abc.Mapping):
typ = val.get("$$type")
if typ:
if isinstance(typ, str):
types = TYPES[typ]
else:
types = tuple(TYPES[t] for t in typ)
return CompareType(types)
d = {}
for key in val:
d[key] = wrap_types(val[key])
return d
return val

View File

@ -19,12 +19,13 @@ import os
import sys
from io import BytesIO
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
sys.path[0:0] = [""]
from test import client_context, unittest
from test.utils import (
OvertCommandListener,
SpecTestCreator,
rs_client,
single_client,
wait_until,
@ -34,7 +35,6 @@ from typing import List
from bson import encode
from bson.raw_bson import RawBSONDocument
from gridfs import GridFS, GridFSBucket
from pymongo import WriteConcern
from pymongo.errors import (
CollectionInvalid,
@ -50,6 +50,9 @@ from pymongo.synchronous import client_session
from pymongo.synchronous.client_session import TransactionOptions
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.helpers import next
_IS_SYNC = True
_TXN_TESTS_DEBUG = os.environ.get("TRANSACTION_TESTS_DEBUG")
@ -62,17 +65,17 @@ UNPIN_TEST_MAX_ATTEMPTS = 50
class TransactionsBase(SpecRunner):
@classmethod
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
if client_context.supports_transactions():
for address in client_context.mongoses:
cls.mongos_clients.append(single_client("{}:{}".format(*address)))
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
for client in cls.mongos_clients:
client.close()
super().tearDownClass()
super()._tearDown_class()
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
@ -124,14 +127,14 @@ class TestTransactions(TransactionsBase):
coll.insert_one({})
with client.start_session() as s:
with s.start_transaction(write_concern=WriteConcern(w=1)):
self.assertTrue(coll.insert_one({}, session=s).acknowledged)
self.assertTrue(coll.insert_many([{}, {}], session=s).acknowledged)
self.assertTrue(coll.bulk_write([InsertOne({})], session=s).acknowledged)
self.assertTrue(coll.replace_one({}, {}, session=s).acknowledged)
self.assertTrue(coll.update_one({}, {"$set": {"a": 1}}, session=s).acknowledged)
self.assertTrue(coll.update_many({}, {"$set": {"a": 1}}, session=s).acknowledged)
self.assertTrue(coll.delete_one({}, session=s).acknowledged)
self.assertTrue(coll.delete_many({}, session=s).acknowledged)
self.assertTrue((coll.insert_one({}, session=s)).acknowledged)
self.assertTrue((coll.insert_many([{}, {}], session=s)).acknowledged)
self.assertTrue((coll.bulk_write([InsertOne({})], session=s)).acknowledged)
self.assertTrue((coll.replace_one({}, {}, session=s)).acknowledged)
self.assertTrue((coll.update_one({}, {"$set": {"a": 1}}, session=s)).acknowledged)
self.assertTrue((coll.update_many({}, {"$set": {"a": 1}}, session=s)).acknowledged)
self.assertTrue((coll.delete_one({}, session=s)).acknowledged)
self.assertTrue((coll.delete_many({}, session=s)).acknowledged)
coll.find_one_and_delete({}, session=s)
coll.find_one_and_replace({}, {}, session=s)
coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s)
@ -260,7 +263,7 @@ class TestTransactions(TransactionsBase):
return gfs.find(*args, **kwargs).next()
def gridfs_open_upload_stream(*args, **kwargs):
bucket.open_upload_stream(*args, **kwargs).write(b"1")
(bucket.open_upload_stream(*args, **kwargs)).write(b"1")
gridfs_ops = [
(gfs.put, (b"123",)),
@ -389,7 +392,7 @@ class TestTransactions(TransactionsBase):
with client.start_session() as s, s.start_transaction():
res = f(*args, session=s) # type:ignore[operator]
if isinstance(res, (CommandCursor, Cursor)):
list(res)
res.to_list()
class PatchSessionTimeout:

View File

@ -47,6 +47,8 @@ from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.cursor import Cursor
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class SpecRunnerThread(threading.Thread):
def __init__(self, name):
@ -88,8 +90,8 @@ class SpecRunner(IntegrationTest):
listener: EventListener
@classmethod
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
@ -97,9 +99,9 @@ class SpecRunner(IntegrationTest):
cls.knobs.enable()
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.knobs.disable()
super().tearDownClass()
super()._tearDown_class()
def setUp(self):
super().setUp()
@ -325,7 +327,7 @@ class SpecRunner(IntegrationTest):
result = Binary(result.read())
if isinstance(result, Cursor) or isinstance(result, CommandCursor):
return list(result)
return result.to_list()
return result
@ -580,7 +582,7 @@ class SpecRunner(IntegrationTest):
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = list(outcome_coll.find(sort=[("_id", 1)]))
actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.

View File

@ -74,6 +74,8 @@ replacements = {
"IsolatedAsyncioTestCase": "TestCase",
"AsyncUnitTest": "UnitTest",
"AsyncMockClient": "MockClient",
"AsyncSpecRunner": "SpecRunner",
"AsyncTransactionsBase": "TransactionsBase",
"async_get_pool": "get_pool",
"async_is_mongos": "is_mongos",
"async_rs_or_single_client": "rs_or_single_client",
@ -141,9 +143,11 @@ converted_tests = [
"__init__.py",
"conftest.py",
"pymongo_mocks.py",
"utils_spec_runner.py",
"test_client.py",
"test_collection.py",
"test_database.py",
"test_transactions.py",
]
sync_test_files = [