Async client uses tasks instead of threads

PYTHON-4725 - Async client should use tasks for SDAM instead of threads
PYTHON-4860 - Async client should use asyncio.Lock and asyncio.Condition
PYTHON-4941 - Synchronous unified test runner being used in asynchronous tests
PYTHON-4843 - Async test suite should use a single event loop
PYTHON-4945 - Fix test cleanups for mongoses

Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com>
This commit is contained in:
Noah Stapp 2024-11-26 16:55:27 -05:00 committed by GitHub
parent 9b5c0981d9
commit 0e8d70457f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
72 changed files with 1737 additions and 1981 deletions

View File

@ -281,7 +281,7 @@ functions:
"run tests":
- command: subprocess.exec
params:
include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
binary: bash
working_dir: "src"
args:

View File

@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
2) License Notice for _asyncio_lock.py
-----------------------------------------
1. This LICENSE AGREEMENT is between the Python Software Foundation
("PSF"), and the Individual or Organization ("Licensee") accessing and
otherwise using this software ("Python") in source or binary form and
its associated documentation.
2. Subject to the terms and conditions of this License Agreement, PSF hereby
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
analyze, test, perform and/or display publicly, prepare derivative works,
distribute, and otherwise use Python alone or in any derivative version,
provided, however, that PSF's License Agreement and PSF's notice of copyright,
i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved"
are retained in Python alone or in any derivative version prepared by Licensee.
3. In the event Licensee prepares a derivative work that is based on
or incorporates Python or any part thereof, and wants to make
the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to Python.
4. PSF is making Python available to Licensee on an "AS IS"
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between PSF and
Licensee. This License Agreement does not grant permission to use PSF
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.
8. By copying, installing or otherwise using Python, Licensee
agrees to be bound by the terms and conditions of this License
Agreement.
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

309
pymongo/_asyncio_lock.py Normal file
View File

@ -0,0 +1,309 @@
# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
to port 3.13 fixes to older versions of Python.
Can be removed once we drop Python 3.12 support."""
from __future__ import annotations
import collections
import threading
from asyncio import events, exceptions
from typing import Any, Coroutine, Optional
_global_lock = threading.Lock()
class _LoopBoundMixin:
_loop = None
def _get_loop(self) -> Any:
loop = events._get_running_loop()
if self._loop is None:
with _global_lock:
if self._loop is None:
self._loop = loop
if loop is not self._loop:
raise RuntimeError(f"{self!r} is bound to a different event loop")
return loop
class _ContextManagerMixin:
async def __aenter__(self) -> None:
await self.acquire() # type: ignore[attr-defined]
# We have no use for the "as ..." clause in the with
# statement for locks.
return
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release() # type: ignore[attr-defined]
class Lock(_ContextManagerMixin, _LoopBoundMixin):
"""Primitive lock objects.
A primitive lock is a synchronization primitive that is not owned
by a particular task when locked. A primitive lock is in one
of two states, 'locked' or 'unlocked'.
It is created in the unlocked state. It has two basic methods,
acquire() and release(). When the state is unlocked, acquire()
changes the state to locked and returns immediately. When the
state is locked, acquire() blocks until a call to release() in
another task changes it to unlocked, then the acquire() call
resets it to locked and returns. The release() method should only
be called in the locked state; it changes the state to unlocked
and returns immediately. If an attempt is made to release an
unlocked lock, a RuntimeError will be raised.
When more than one task is blocked in acquire() waiting for
the state to turn to unlocked, only one task proceeds when a
release() call resets the state to unlocked; successive release()
calls will unblock tasks in FIFO order.
Locks also support the asynchronous context management protocol.
'async with lock' statement should be used.
Usage:
lock = Lock()
...
await lock.acquire()
try:
...
finally:
lock.release()
Context manager usage:
lock = Lock()
...
async with lock:
...
Lock objects can be tested for locking state:
if not lock.locked():
await lock.acquire()
else:
# lock is acquired
...
"""
def __init__(self) -> None:
self._waiters: Optional[collections.deque] = None
self._locked = False
def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self._locked else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"
def locked(self) -> bool:
"""Return True if lock is acquired."""
return self._locked
async def acquire(self) -> bool:
"""Acquire a lock.
This method blocks until the lock is unlocked, then sets it to
locked and returns True.
"""
# Implement fair scheduling, where thread always waits
# its turn. Jumping the queue if all are cancelled is an optimization.
if not self._locked and (
self._waiters is None or all(w.cancelled() for w in self._waiters)
):
self._locked = True
return True
if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
# Currently the only exception designed be able to occur here.
# Ensure the lock invariant: If lock is not claimed (or about
# to be claimed by us) and there is a Task in waiters,
# ensure that the Task at the head will run.
if not self._locked:
self._wake_up_first()
raise
# assert self._locked is False
self._locked = True
return True
def release(self) -> None:
"""Release a lock.
When the lock is locked, reset it to unlocked, and return.
If any other tasks are blocked waiting for the lock to become
unlocked, allow exactly one of them to proceed.
When invoked on an unlocked lock, a RuntimeError is raised.
There is no return value.
"""
if self._locked:
self._locked = False
self._wake_up_first()
else:
raise RuntimeError("Lock is not acquired.")
def _wake_up_first(self) -> None:
"""Ensure that the first waiter will wake up."""
if not self._waiters:
return
try:
fut = next(iter(self._waiters))
except StopIteration:
return
# .done() means that the waiter is already set to wake up.
if not fut.done():
fut.set_result(True)
class Condition(_ContextManagerMixin, _LoopBoundMixin):
"""Asynchronous equivalent to threading.Condition.
This class implements condition variable objects. A condition variable
allows one or more tasks to wait until they are notified by another
task.
A new Lock object is created and used as the underlying lock.
"""
def __init__(self, lock: Optional[Lock] = None) -> None:
if lock is None:
lock = Lock()
self._lock = lock
# Export the lock's locked(), acquire() and release() methods.
self.locked = lock.locked
self.acquire = lock.acquire
self.release = lock.release
self._waiters: collections.deque = collections.deque()
def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self.locked() else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"
async def wait(self) -> bool:
"""Wait until notified.
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
if not self.locked():
raise RuntimeError("cannot wait on un-acquired lock")
fut = self._get_loop().create_future()
self.release()
try:
try:
self._waiters.append(fut)
try:
await fut
return True
finally:
self._waiters.remove(fut)
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except exceptions.CancelledError as e:
err = e
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self._notify(1)
raise
async def wait_for(self, predicate: Any) -> Coroutine:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
"""By default, wake up one task waiting on this condition, if any.
If the calling task has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up n of the tasks waiting for the condition
variable; if fewer than n are waiting, they are all awoken.
Note: an awakened task does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
if not self.locked():
raise RuntimeError("cannot notify on un-acquired lock")
self._notify(n)
def _notify(self, n: int) -> None:
idx = 0
for fut in self._waiters:
if idx >= n:
break
if not fut.done():
idx += 1
fut.set_result(False)
def notify_all(self) -> None:
"""Wake up all tasks waiting on this condition. This method acts
like notify(), but wakes up all waiting tasks instead of one. If the
calling task has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))

49
pymongo/_asyncio_task.py Normal file
View File

@ -0,0 +1,49 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
from __future__ import annotations
import asyncio
import sys
from typing import Any, Coroutine, Optional
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
class _Task(asyncio.Task):
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
super().__init__(coro, name=name)
self._cancel_requests = 0
asyncio._register_task(self)
def cancel(self, msg: Optional[str] = None) -> bool:
self._cancel_requests += 1
return super().cancel(msg=msg)
def uncancel(self) -> int:
if self._cancel_requests > 0:
self._cancel_requests -= 1
return self._cancel_requests
def cancelling(self) -> int:
return self._cancel_requests
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
if sys.version_info >= (3, 11):
return asyncio.create_task(coro, name=name)
return _Task(coro, name=name)

View File

@ -476,7 +476,6 @@ class _AsyncClientBulk:
if op_type == "delete":
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
full_result[f"{op_type}Results"][original_index] = res
except Exception as exc:
# Attempt to close the cursor, then raise top-level error.
if cmd_cursor.alive:

View File

@ -45,7 +45,7 @@ from pymongo.common import (
)
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _ALock, _create_lock
from pymongo.lock import _async_create_lock
from pymongo.message import (
_CursorAddress,
_GetMore,
@ -77,7 +77,7 @@ class _ConnectionManager:
def __init__(self, conn: AsyncConnection, more_to_come: bool):
self.conn: Optional[AsyncConnection] = conn
self.more_to_come = more_to_come
self._alock = _ALock(_create_lock())
self._lock = _async_create_lock()
def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come

View File

@ -15,6 +15,7 @@
"""Support for explicit client-side field level encryption."""
from __future__ import annotations
import asyncio
import contextlib
import enum
import socket
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged.
raise
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptionError(exc) from exc
@ -200,6 +203,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
@ -722,6 +727,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
await database.create_collection(name=name, **kwargs),
encrypted_fields,
)
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -32,6 +32,7 @@ access:
"""
from __future__ import annotations
import asyncio
import contextlib
import os
import warnings
@ -59,8 +60,8 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, uri_parser
from pymongo.asynchronous import client_session, database, periodic_executor
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo.asynchronous import client_session, database
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
@ -82,7 +83,11 @@ from pymongo.errors import (
WaitQueueTimeoutError,
WriteConcernError,
)
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
from pymongo.lock import (
_HAS_REGISTER_AT_FORK,
_async_create_lock,
_release_locks,
)
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason
@ -842,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
self._default_database_name = dbase
self._lock = _ALock(_create_lock())
self._lock = _async_create_lock()
self._kill_cursors_queue: list = []
self._event_listeners = options.pool_options._event_listeners
@ -908,7 +913,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
await AsyncMongoClient._process_periodic_tasks(client)
return True
executor = periodic_executor.PeriodicExecutor(
executor = periodic_executor.AsyncPeriodicExecutor(
interval=common.KILL_CURSOR_FREQUENCY,
min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target,
@ -1722,7 +1727,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
)
async with operation.conn_mgr._alock:
async with operation.conn_mgr._lock:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
return await server.run_operation(
@ -1970,7 +1975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
try:
if conn_mgr:
async with conn_mgr._alock:
async with conn_mgr._lock:
# Cursor is pinned to LB outside of a transaction.
assert address is not None
assert conn_mgr.conn is not None
@ -2033,6 +2038,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2047,6 +2054,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items():
try:
await self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
@ -2061,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
try:
await self._process_kill_cursors()
await self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return

View File

@ -16,20 +16,20 @@
from __future__ import annotations
import asyncio
import atexit
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common
from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
from pymongo.asynchronous import periodic_executor
from pymongo.asynchronous.periodic_executor import _shutdown_executors
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.lock import _async_create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
@ -76,7 +76,7 @@ class MonitorBase:
await monitor._run() # type:ignore[attr-defined]
return True
executor = periodic_executor.PeriodicExecutor(
executor = periodic_executor.AsyncPeriodicExecutor(
interval=interval, min_interval=min_interval, target=target, name=name
)
@ -112,9 +112,9 @@ class MonitorBase:
"""
self.gc_safe_close()
def join(self, timeout: Optional[int] = None) -> None:
async def join(self, timeout: Optional[int] = None) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
await self._executor.join(timeout)
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@ -139,7 +139,7 @@ class Monitor(MonitorBase):
"""
super().__init__(
topology,
"pymongo_server_monitor_thread",
"pymongo_server_monitor_task",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
await self.close()
finally:
if self._executor._stopped:
await self._rtt_monitor.close()
async def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response.
@ -252,7 +255,9 @@ class Monitor(MonitorBase):
except (OperationFailure, NotPrimaryError) as exc:
# Update max cluster time even when hello fails.
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime"))
await self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except asyncio.CancelledError:
raise
except ReferenceError:
raise
@ -280,7 +285,7 @@ class Monitor(MonitorBase):
await self._reset_connection()
if isinstance(error, _OperationCancelled):
raise
self._rtt_monitor.reset()
await self._rtt_monitor.reset()
# Server type defaults to Unknown.
return ServerDescription(address, error=error)
@ -321,9 +326,9 @@ class Monitor(MonitorBase):
self._conn_id = conn.id
response, round_trip_time = await self._check_with_socket(conn)
if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time)
await self._rtt_monitor.add_sample(round_trip_time)
avg_rtt, min_rtt = self._rtt_monitor.get()
avg_rtt, min_rtt = await self._rtt_monitor.get()
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
if self._publish:
assert self._listeners is not None
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
@ -439,7 +446,7 @@ class _RttMonitor(MonitorBase):
"""
super().__init__(
topology,
"pymongo_server_rtt_thread",
"pymongo_server_rtt_task",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL,
)
@ -447,7 +454,7 @@ class _RttMonitor(MonitorBase):
self._pool = pool
self._moving_average = MovingAverage()
self._moving_min = MovingMinimum()
self._lock = _create_lock()
self._lock = _async_create_lock()
async def close(self) -> None:
self.gc_safe_close()
@ -455,20 +462,20 @@ class _RttMonitor(MonitorBase):
# thread has the socket checked out, it will be closed when checked in.
await self._pool.reset()
def add_sample(self, sample: float) -> None:
async def add_sample(self, sample: float) -> None:
"""Add a RTT sample."""
with self._lock:
async with self._lock:
self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample)
def get(self) -> tuple[Optional[float], float]:
async def get(self) -> tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min."""
with self._lock:
async with self._lock:
return self._moving_average.get(), self._moving_min.get()
def reset(self) -> None:
async def reset(self) -> None:
"""Reset the average RTT."""
with self._lock:
async with self._lock:
self._moving_average.reset()
self._moving_min.reset()
@ -478,10 +485,12 @@ class _RttMonitor(MonitorBase):
# heartbeat protocol (MongoDB 4.4+).
# XXX: Skip check if the server is unknown?
rtt = await self._ping()
self.add_sample(rtt)
await self.add_sample(rtt)
except ReferenceError:
# Topology was garbage-collected.
await self.close()
except asyncio.CancelledError:
raise
except Exception:
await self._pool.reset()
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
shutdown()
atexit.register(_shutdown_resources)
if _IS_SYNC:
atexit.register(_shutdown_resources)

View File

@ -1,219 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Run a target function on a background thread."""
from __future__ import annotations
import asyncio
import sys
import threading
import time
import weakref
from typing import Any, Optional
from pymongo.lock import _ALock, _create_lock
_IS_SYNC = False
class PeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background thread.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying thread.
"""
# threading.Event and its internal condition variable are expensive
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
# The executor's design is constrained by several Python issues, see
# "periodic_executor.rst" in this repository.
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._thread: Optional[threading.Thread] = None
self._name = name
self._skip_sleep = False
self._thread_will_exit = False
self._lock = _ALock(_create_lock())
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
# The default asyncio loop implementation on Windows
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
# We explicitly use a different loop implementation here to prevent that issue
if sys.platform == "win32":
loop = asyncio.SelectorEventLoop()
try:
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
finally:
loop.close()
else:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None:
"""Start. Multiple calls have no effect.
Not safe to call from multiple threads at once.
"""
with self._lock:
if self._thread_will_exit:
# If the background thread has read self._stopped as True
# there is a chance that it has not yet exited. The call to
# join should not block indefinitely because there is no
# other work done outside the while loop in self._run.
try:
assert self._thread is not None
self._thread.join()
except ReferenceError:
# Thread terminated.
pass
self._thread_will_exit = False
self._stopped = False
started: Any = False
try:
started = self._thread and self._thread.is_alive()
except ReferenceError:
# Thread terminated.
pass
if not started:
if _IS_SYNC:
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)
# Mitigation to RuntimeError firing when thread starts on shutdown
# https://github.com/python/cpython/issues/114570
try:
thread.start()
except RuntimeError as e:
if "interpreter shutdown" in str(e) or sys.is_finalizing():
self._thread = None
return
raise
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
def join(self, timeout: Optional[int] = None) -> None:
if self._thread is not None:
try:
self._thread.join(timeout)
except (ReferenceError, RuntimeError):
# Thread already terminated, or not yet started.
pass
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _should_stop(self) -> bool:
async with self._lock:
if self._stopped:
self._thread_will_exit = True
return True
return False
async def _run(self) -> None:
while not await self._should_stop():
try:
if not await self._target():
self._stopped = True
break
except BaseException:
async with self._lock:
self._stopped = True
self._thread_will_exit = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()
def _register_executor(executor: PeriodicExecutor) -> None:
ref = weakref.ref(executor, _on_executor_deleted)
_EXECUTORS.add(ref)
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
_EXECUTORS.remove(ref)
def _shutdown_executors() -> None:
if _EXECUTORS is None:
return
# Copy the set. Stopping threads has the side effect of removing executors.
executors = list(_EXECUTORS)
# First signal all executors to close...
for ref in executors:
executor = ref()
if executor:
executor.close()
# ...then try to join them.
for ref in executors:
executor = ref()
if executor:
executor.join(1)
executor = None

View File

@ -23,7 +23,6 @@ import os
import socket
import ssl
import sys
import threading
import time
import weakref
from typing import (
@ -65,7 +64,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
@ -208,11 +211,6 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error
async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return await condition.wait(timeout)
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
@ -706,6 +704,8 @@ class AsyncConnection:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass
@ -992,8 +992,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
_lock = _create_lock()
self.lock = _ALock(_lock)
self.lock = _async_create_lock()
self._max_connecting_cond = _async_create_condition(self.lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1019,7 +1019,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = _ACondition(threading.Condition(_lock))
self.size_cond = _async_create_condition(self.lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1027,7 +1027,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
self._max_connecting_cond = _async_create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
@ -1466,7 +1466,8 @@ class Pool:
async with self.size_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=True)
while not (self.requests < self.max_pool_size):
if not await _cond_wait(self.size_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self.size_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.requests < self.max_pool_size:
@ -1489,7 +1490,8 @@ class Pool:
async with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self._max_connecting):
if not await _cond_wait(self._max_connecting_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.conns or self._pending < self._max_connecting:

View File

@ -27,8 +27,7 @@ import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared
from pymongo.asynchronous import periodic_executor
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.pool import Pool
@ -44,7 +43,11 @@ from pymongo.errors import (
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
@ -170,9 +173,10 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
_lock = _create_lock()
self._lock = _ALock(_lock)
self._condition = _ACondition(self._settings.condition_class(_lock))
self._lock = _async_create_lock()
self._condition = _async_create_condition(
self._lock, self._settings.condition_class if _IS_SYNC else None
)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
@ -185,7 +189,7 @@ class Topology:
async def target() -> bool:
return process_events_queue(weak)
executor = periodic_executor.PeriodicExecutor(
executor = periodic_executor.AsyncPeriodicExecutor(
interval=common.EVENTS_QUEUE_FREQUENCY,
min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target,
@ -354,7 +358,7 @@ class Topology:
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = time.monotonic()
server_descriptions = self._description.apply_selector(
@ -654,7 +658,7 @@ class Topology:
"""Wake all monitors, wait for at least one to check its server."""
async with self._lock:
self._request_check_all()
await self._condition.wait(wait_time)
await _async_cond_wait(self._condition, wait_time)
def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers.
@ -742,7 +746,7 @@ class Topology:
if self._publish_server or self._publish_tp:
# Make sure the events executor thread is fully closed before publishing the remaining events
self.__events_executor.close()
self.__events_executor.join(1)
await self.__events_executor.join(1)
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
@property

View File

@ -11,15 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal helpers for lock and condition coordination primitives."""
from __future__ import annotations
import asyncio
import collections
import os
import sys
import threading
import time
import weakref
from typing import Any, Callable, Optional, TypeVar
from asyncio import wait_for
from typing import Any, Optional, TypeVar
import pymongo._asyncio_lock
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
@ -28,6 +33,15 @@ _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
_T = TypeVar("_T")
# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202)
# in older versions of Python
if sys.version_info >= (3, 13):
Lock = asyncio.Lock
Condition = asyncio.Condition
else:
Lock = pymongo._asyncio_lock.Lock
Condition = pymongo._asyncio_lock.Condition
def _create_lock() -> threading.Lock:
"""Represents a lock that is tracked upon instantiation using a WeakSet and
@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock:
return lock
def _async_create_lock() -> Lock:
"""Represents an asyncio.Lock."""
return Lock()
def _create_condition(
lock: threading.Lock, condition_class: Optional[Any] = None
) -> threading.Condition:
"""Represents a threading.Condition."""
if condition_class:
return condition_class(lock)
return threading.Condition(lock)
def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition:
"""Represents an asyncio.Condition."""
if condition_class:
return condition_class(lock)
return Condition(lock)
def _release_locks() -> None:
# Completed the fork, reset all the locks in the child.
for lock in _forkable_locks:
@ -46,202 +81,12 @@ def _release_locks() -> None:
lock.release()
# Needed only for synchro.py compat.
def _Lock(lock: threading.Lock) -> threading.Lock:
return lock
class _ALock:
__slots__ = ("_lock",)
def __init__(self, lock: threading.Lock) -> None:
self._lock = lock
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
return self._lock.acquire(blocking=blocking, timeout=timeout)
async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._lock.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
def release(self) -> None:
self._lock.release()
async def __aenter__(self) -> _ALock:
await self.a_acquire()
return self
def __enter__(self) -> _ALock:
self._lock.acquire()
return self
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def _safe_set_result(fut: asyncio.Future) -> None:
# Ensure the future hasn't been cancelled before calling set_result.
if not fut.done():
fut.set_result(False)
class _ACondition:
__slots__ = ("_condition", "_waiters")
def __init__(self, condition: threading.Condition) -> None:
self._condition = condition
self._waiters: collections.deque = collections.deque()
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._condition.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
async def wait(self, timeout: Optional[float] = None) -> bool:
"""Wait until notified.
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._waiters.append((loop, fut))
self.release()
async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool:
try:
try:
try:
await asyncio.wait_for(fut, timeout)
return True
return await wait_for(condition.wait(), timeout)
except asyncio.TimeoutError:
return False # Return false on timeout for sync pool compat.
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except asyncio.exceptions.CancelledError as e:
err = e
return False
self._waiters.remove((loop, fut))
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self.notify(1)
raise
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
"""By default, wake up one coroutine waiting on this condition, if any.
If the calling coroutine has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up at most n of the coroutines waiting for the
condition variable; it is a no-op if no coroutines are waiting.
Note: an awakened coroutine does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
idx = 0
to_remove = []
for loop, fut in self._waiters:
if idx >= n:
break
if fut.done():
continue
try:
loop.call_soon_threadsafe(_safe_set_result, fut)
except RuntimeError:
# Loop was closed, ignore.
to_remove.append((loop, fut))
continue
idx += 1
for waiter in to_remove:
self._waiters.remove(waiter)
def notify_all(self) -> None:
"""Wake up all threads waiting on this condition. This method acts
like notify(), but wakes up all waiting threads instead of one. If the
calling thread has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))
def locked(self) -> bool:
"""Only needed for tests in test_locks."""
return self._condition._lock.locked() # type: ignore[attr-defined]
def release(self) -> None:
self._condition.release()
async def __aenter__(self) -> _ACondition:
await self.acquire()
return self
def __enter__(self) -> _ACondition:
self._condition.acquire()
return self
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool:
return condition.wait(timeout)

View File

@ -29,6 +29,7 @@ from typing import (
)
from pymongo import _csot, ssl_support
from pymongo._asyncio_task import create_task
from pymongo.errors import _OperationCancelled
from pymongo.socket_checker import _errno_from_exception
@ -259,18 +260,19 @@ async def async_receive_data(
sock.settimeout(0.0)
loop = asyncio.get_event_loop()
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
cancellation_task = create_task(_poll_cancellation(conn))
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
else:
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
tasks = [read_task, cancellation_task]
done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
if pending:
await asyncio.wait(pending)
if len(done) == 0:
raise socket.timeout("timed out")

View File

@ -23,9 +23,102 @@ import time
import weakref
from typing import Any, Optional
from pymongo._asyncio_task import create_task
from pymongo.lock import _create_lock
_IS_SYNC = True
_IS_SYNC = False
class AsyncPeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background task.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying task.
"""
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._task: Optional[asyncio.Task] = None
self._name = name
self._skip_sleep = False
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def open(self) -> None:
"""Start. Multiple calls have no effect."""
self._stopped = False
if self._task is None or (
self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined]
):
self._task = create_task(self._run(), name=self._name)
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
async def join(self, timeout: Optional[int] = None) -> None:
if self._task is not None:
try:
await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type]
except asyncio.TimeoutError:
# Task timed out
pass
except asyncio.exceptions.CancelledError:
# Task was already finished, or not yet started.
raise
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _run(self) -> None:
while not self._stopped:
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
raise asyncio.CancelledError
try:
if not await self._target():
self._stopped = True
break
except BaseException:
self._stopped = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
class PeriodicExecutor:
@ -64,19 +157,6 @@ class PeriodicExecutor:
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
# The default asyncio loop implementation on Windows
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
# We explicitly use a different loop implementation here to prevent that issue
if sys.platform == "win32":
loop = asyncio.SelectorEventLoop()
try:
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
finally:
loop.close()
else:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None:
"""Start. Multiple calls have no effect.
@ -104,10 +184,7 @@ class PeriodicExecutor:
pass
if not started:
if _IS_SYNC:
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)

View File

@ -474,7 +474,6 @@ class _ClientBulk:
if op_type == "delete":
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
full_result[f"{op_type}Results"][original_index] = res
except Exception as exc:
# Attempt to close the cursor, then raise top-level error.
if cmd_cursor.alive:

View File

@ -77,7 +77,7 @@ class _ConnectionManager:
def __init__(self, conn: Connection, more_to_come: bool):
self.conn: Optional[Connection] = conn
self.more_to_come = more_to_come
self._alock = _create_lock()
self._lock = _create_lock()
def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come

View File

@ -15,6 +15,7 @@
"""Support for explicit client-side field level encryption."""
from __future__ import annotations
import asyncio
import contextlib
import enum
import socket
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged.
raise
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptionError(exc) from exc
@ -200,6 +203,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
@ -716,6 +721,8 @@ class ClientEncryption(Generic[_DocumentType]):
database.create_collection(name=name, **kwargs),
encrypted_fields,
)
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -32,6 +32,7 @@ access:
"""
from __future__ import annotations
import asyncio
import contextlib
import os
import warnings
@ -58,7 +59,7 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, uri_parser
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo.client_options import ClientOptions
from pymongo.errors import (
AutoReconnect,
@ -74,7 +75,11 @@ from pymongo.errors import (
WaitQueueTimeoutError,
WriteConcernError,
)
from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks
from pymongo.lock import (
_HAS_REGISTER_AT_FORK,
_create_lock,
_release_locks,
)
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason
@ -91,7 +96,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, periodic_executor
from pymongo.synchronous import client_session, database
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
@ -1716,7 +1721,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address=address,
)
with operation.conn_mgr._alock:
with operation.conn_mgr._lock:
with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
return server.run_operation(
@ -1964,7 +1969,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
try:
if conn_mgr:
with conn_mgr._alock:
with conn_mgr._lock:
# Cursor is pinned to LB outside of a transaction.
assert address is not None
assert conn_mgr.conn is not None
@ -2027,6 +2032,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2041,6 +2048,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items():
try:
self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
@ -2055,6 +2064,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
try:
self._process_kill_cursors()
self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return

View File

@ -16,24 +16,24 @@
from __future__ import annotations
import asyncio
import atexit
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common
from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver
from pymongo.synchronous import periodic_executor
from pymongo.synchronous.periodic_executor import _shutdown_executors
if TYPE_CHECKING:
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
self.close()
finally:
if self._executor._stopped:
self._rtt_monitor.close()
def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response.
@ -254,6 +257,8 @@ class Monitor(MonitorBase):
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except asyncio.CancelledError:
raise
except ReferenceError:
raise
except Exception as error:
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
@ -482,6 +489,8 @@ class _RttMonitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
self.close()
except asyncio.CancelledError:
raise
except Exception:
self._pool.reset()
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
shutdown()
atexit.register(_shutdown_resources)
if _IS_SYNC:
atexit.register(_shutdown_resources)

View File

@ -23,7 +23,6 @@ import os
import socket
import ssl
import sys
import threading
import time
import weakref
from typing import (
@ -62,7 +61,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _create_lock, _Lock
from pymongo.lock import (
_cond_wait,
_create_condition,
_create_lock,
)
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
@ -208,11 +211,6 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return condition.wait(timeout)
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
@ -704,6 +702,8 @@ class Connection:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass
@ -988,8 +988,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
_lock = _create_lock()
self.lock = _Lock(_lock)
self.lock = _create_lock()
self._max_connecting_cond = _create_condition(self.lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1015,7 +1015,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = threading.Condition(_lock)
self.size_cond = _create_condition(self.lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1023,7 +1023,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = threading.Condition(_lock)
self._max_connecting_cond = _create_condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
@ -1460,7 +1460,8 @@ class Pool:
with self.size_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=True)
while not (self.requests < self.max_pool_size):
if not _cond_wait(self.size_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not _cond_wait(self.size_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.requests < self.max_pool_size:
@ -1483,7 +1484,8 @@ class Pool:
with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self._max_connecting):
if not _cond_wait(self._max_connecting_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not _cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.conns or self._pending < self._max_connecting:

View File

@ -27,7 +27,7 @@ import weakref
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.errors import (
ConnectionFailure,
InvalidOperation,
@ -39,7 +39,11 @@ from pymongo.errors import (
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _create_lock, _Lock
from pymongo.lock import (
_cond_wait,
_create_condition,
_create_lock,
)
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
@ -56,7 +60,6 @@ from pymongo.server_selectors import (
secondary_server_selector,
writable_server_selector,
)
from pymongo.synchronous import periodic_executor
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.pool import Pool
@ -170,9 +173,10 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
_lock = _create_lock()
self._lock = _Lock(_lock)
self._condition = self._settings.condition_class(_lock)
self._lock = _create_lock()
self._condition = _create_condition(
self._lock, self._settings.condition_class if _IS_SYNC else None
)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
@ -354,7 +358,7 @@ class Topology:
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = time.monotonic()
server_descriptions = self._description.apply_selector(
@ -652,7 +656,7 @@ class Topology:
"""Wake all monitors, wait for at least one to check its server."""
with self._lock:
self._request_check_all()
self._condition.wait(wait_time)
_cond_wait(self._condition, wait_time)
def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers.

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import logging
import multiprocessing
import os
import signal
@ -25,6 +26,7 @@ import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
@ -191,6 +193,8 @@ class ClientContext:
client.close()
def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = self._connect(host, port)
if self.client is not None:
# Return early when connected to dataLake as mongohoused does not
@ -860,6 +864,16 @@ class ClientContext:
client_context = ClientContext()
def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif client_context.client is not None:
client_context.client.close()
client_context.client = None
client_context._init_client()
class PyMongoTestCase(unittest.TestCase):
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1106,26 +1120,10 @@ class PyMongoTestCase(unittest.TestCase):
class UnitTest(PyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
def _setup_class(cls):
def setUp(self) -> None:
pass
@classmethod
def _tearDown_class(cls):
def tearDown(self) -> None:
pass
@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase):
db: Database
credentials: Dict[str, str]
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_connection
def _setup_class(cls):
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
def setUp(self) -> None:
if not _IS_SYNC:
reset_client_context()
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
raise SkipTest("this test does not support serverless")
cls.client = client_context.client
cls.db = cls.client.pymongo_test
self.client = client_context.client
self.db = self.client.pymongo_test
if client_context.auth_enabled:
cls.credentials = {"username": db_user, "password": db_pwd}
self.credentials = {"username": db_user, "password": db_pwd}
else:
cls.credentials = {}
@classmethod
def _tearDown_class(cls):
pass
self.credentials = {}
def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection."""
@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest):
# MockClients tests that use replicaSet, directConnection=True, pass
# multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True.
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_no_load_balancer
def _setup_class(cls):
pass
@classmethod
def _tearDown_class(cls):
pass
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
self.client_knobs.enable()
def tearDown(self):
def tearDown(self) -> None:
self.client_knobs.disable()
super().tearDown()
@ -1253,7 +1211,6 @@ def teardown():
c.drop_database("pymongo_test_mike")
c.drop_database("pymongo_test_bernie")
c.close()
print_running_clients()

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import logging
import multiprocessing
import os
import signal
@ -25,6 +26,7 @@ import subprocess
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
@ -191,6 +193,8 @@ class AsyncClientContext:
await client.close()
async def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = await self._connect(host, port)
if self.client is not None:
# Return early when connected to dataLake as mongohoused does not
@ -862,6 +866,16 @@ class AsyncClientContext:
async_client_context = AsyncClientContext()
async def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif async_client_context.client is not None:
await async_client_context.client.close()
async_client_context.client = None
await async_client_context._init_client()
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1124,26 +1138,10 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
class AsyncUnitTest(AsyncPyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
async def _setup_class(cls):
async def asyncSetUp(self) -> None:
pass
@classmethod
async def _tearDown_class(cls):
async def asyncTearDown(self) -> None:
pass
@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
db: AsyncDatabase
credentials: Dict[str, str]
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
async def asyncSetUp(self) -> None:
if not _IS_SYNC:
await reset_client_context()
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
raise SkipTest("this test does not support serverless")
cls.client = async_client_context.client
cls.db = cls.client.pymongo_test
self.client = async_client_context.client
self.db = self.client.pymongo_test
if async_client_context.auth_enabled:
cls.credentials = {"username": db_user, "password": db_pwd}
self.credentials = {"username": db_user, "password": db_pwd}
else:
cls.credentials = {}
@classmethod
async def _tearDown_class(cls):
pass
self.credentials = {}
async def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection."""
@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest):
# MockClients tests that use replicaSet, directConnection=True, pass
# multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True.
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_no_load_balancer
async def _setup_class(cls):
pass
@classmethod
async def _tearDown_class(cls):
pass
def setUp(self):
super().setUp()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
self.client_knobs.enable()
def tearDown(self):
async def asyncTearDown(self) -> None:
self.client_knobs.disable()
super().tearDown()
await super().asyncTearDown()
async def async_setup():
@ -1271,7 +1229,6 @@ async def async_teardown():
await c.drop_database("pymongo_test_mike")
await c.drop_database("pymongo_test_bernie")
await c.close()
print_running_clients()

View File

@ -22,7 +22,7 @@ def event_loop_policy():
return asyncio.get_event_loop_policy()
@pytest_asyncio.fixture(scope="session", autouse=True)
@pytest_asyncio.fixture(scope="package", autouse=True)
async def test_setup_and_teardown():
await async_setup()
yield

View File

@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
coll: AsyncCollection
coll_w0: AsyncCollection
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
async def asyncSetUp(self):
super().setUp()
await super().asyncSetUp()
self.coll = self.db.test
await self.coll.drop()
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
def assertEqualResponse(self, expected, actual):
"""Compare response from bulk.execute() to expected response."""
@ -787,14 +783,10 @@ class AsyncTestBulk(AsyncBulkTestBase):
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
@classmethod
@async_client_context.require_auth
@async_client_context.require_no_api_version
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
super().setUp()
await super().asyncSetUp()
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
await self.db.command(
"createRole",
@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
w: Optional[int]
secondary: AsyncMongoClient
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.w = async_client_context.w
cls.secondary = None
if cls.w is not None and cls.w > 1:
async def asyncSetUp(self):
await super().asyncSetUp()
self.w = async_client_context.w
self.secondary = None
if self.w is not None and self.w > 1:
for member in (await async_client_context.hello)["hosts"]:
if member != (await async_client_context.hello)["primary"]:
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
self.secondary = await self.async_single_client(*partition_node(member))
break
@classmethod
async def async_tearDownClass(cls):
if cls.secondary:
await cls.secondary.close()
async def asyncTearDown(self):
if self.secondary:
await self.secondary.close()
async def cause_wtimeout(self, requests, ordered):
if not async_client_context.test_commands_enabled:

View File

@ -836,18 +836,16 @@ class ProseSpecTestsMixin:
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
dbs: list
@classmethod
@async_client_context.require_version_min(4, 0, 0, -1)
@async_client_context.require_change_streams
async def _setup_class(cls):
await super()._setup_class()
cls.dbs = [cls.db, cls.client.pymongo_test_2]
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.dbs = [self.db, self.client.pymongo_test_2]
@classmethod
async def _tearDown_class(cls):
for db in cls.dbs:
await cls.client.drop_database(db)
await super()._tearDown_class()
async def asyncTearDown(self):
for db in self.dbs:
await self.client.drop_database(db)
await super().asyncTearDown()
async def change_stream_with_client(self, client, *args, **kwargs):
return await client.watch(*args, **kwargs)
@ -898,11 +896,10 @@ class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
@classmethod
@async_client_context.require_version_min(4, 0, 0, -1)
@async_client_context.require_change_streams
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
async def change_stream_with_client(self, client, *args, **kwargs):
return await client[self.db.name].watch(*args, **kwargs)
@ -988,12 +985,9 @@ class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixi
class TestAsyncCollectionAsyncChangeStream(
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
):
@classmethod
@async_client_context.require_change_streams
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
# Use a new collection for each test.
await self.watched_collection().drop()
await self.watched_collection().insert_one({})
@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
listener: AllowListEventListener
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
await super()._tearDown_class()
def asyncSetUp(self):
super().asyncSetUp()
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener = AllowListEventListener("aggregate", "getMore")
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
self.listener.reset()
async def asyncSetUpCluster(self, scenario_dict):

View File

@ -73,7 +73,6 @@ from test.utils import (
is_greenthread_patched,
lazy_client_trial,
one,
wait_until,
)
import bson
@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
client: AsyncMongoClient
@classmethod
async def _setup_class(cls):
cls.client = await cls.unmanaged_async_rs_or_single_client(
async def asyncSetUp(self) -> None:
self.client = await self.async_rs_or_single_client(
connect=False, serverSelectionTimeoutMS=100
)
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
@ -693,8 +687,8 @@ class TestClient(AsyncIntegrationTest):
# When the reaper runs at the same time as the get_socket, two
# connections could be created and checked into the pool.
self.assertGreaterEqual(len(server._pool.conns), 1)
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -710,8 +704,8 @@ class TestClient(AsyncIntegrationTest):
# When the reaper runs at the same time as the get_socket,
# maxPoolSize=1 should prevent two connections from being created.
self.assertEqual(1, len(server._pool.conns))
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
async def test_max_idle_time_reaper_removes_stale(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -727,7 +721,7 @@ class TestClient(AsyncIntegrationTest):
async with server._pool.checkout() as conn_two:
pass
self.assertIs(conn_one, conn_two)
wait_until(
await async_wait_until(
lambda: len(server._pool.conns) == 0,
"stale socket reaped and new one NOT added to the pool",
)
@ -745,7 +739,7 @@ class TestClient(AsyncIntegrationTest):
server = await (await client._get_topology()).select_server(
readable_server_selector, _Op.TEST
)
wait_until(
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections",
)
@ -753,7 +747,7 @@ class TestClient(AsyncIntegrationTest):
# Assert that if a socket is closed, a new one takes its place
async with server._pool.checkout() as conn:
conn.close_conn(None)
wait_until(
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"a closed socket gets replaced from the pool",
)
@ -939,8 +933,10 @@ class TestClient(AsyncIntegrationTest):
async with eval(the_repr) as client_two:
self.assertEqual(client_two, client)
def test_getters(self):
wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes")
async def test_getters(self):
await async_wait_until(
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
)
async def test_list_databases(self):
cmd_docs = (await self.client.admin.command("listDatabases"))["databases"]
@ -1065,14 +1061,21 @@ class TestClient(AsyncIntegrationTest):
self.assertFalse(client._topology._opened)
# Ensure kill cursors thread has not been started.
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertFalse(kc_task and not kc_task.done())
# Using the client should open topology and start the thread.
await client.admin.command("ping")
self.assertTrue(client._topology._opened)
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertTrue(kc_task and not kc_task.done())
async def test_close_does_not_open_servers(self):
client = await self.async_rs_client(connect=False)
@ -1277,6 +1280,7 @@ class TestClient(AsyncIntegrationTest):
async def test_server_selection_timeout(self):
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
@ -1289,18 +1293,22 @@ class TestClient(AsyncIntegrationTest):
self.assertRaises(
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
)
await client.close()
client = AsyncMongoClient(
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
self.assertAlmostEqual(0, client.options.server_selection_timeout)
await client.close()
# Test invalid timeout in URI ignored and set to default.
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)
@ -1608,7 +1616,7 @@ class TestClient(AsyncIntegrationTest):
await async_client_context.port,
)
await self.async_single_client(uri, event_listeners=[listener])
wait_until(
await async_wait_until(
lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents"
)
@ -1766,16 +1774,16 @@ class TestClient(AsyncIntegrationTest):
pool = await async_get_pool(client)
original_connect = pool.connect
def stall_connect(*args, **kwargs):
time.sleep(2)
return original_connect(*args, **kwargs)
async def stall_connect(*args, **kwargs):
await asyncio.sleep(2)
return await original_connect(*args, **kwargs)
pool.connect = stall_connect
# Un-patch Pool.connect to break the cyclic reference.
self.addCleanup(delattr, pool, "connect")
# Wait for the background thread to start creating connections
wait_until(lambda: len(pool.conns) > 1, "start creating connections")
await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections")
# Assert that application operations do not block.
for _ in range(10):
@ -1858,7 +1866,7 @@ class TestClient(AsyncIntegrationTest):
await client.close()
# Add cursor to kill cursors queue
del cursor
wait_until(
await async_wait_until(
lambda: client._kill_cursors_queue,
"waited for cursor to be added to queue",
)
@ -2232,7 +2240,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
await cursor.to_list()
self.assertTrue(conn.closed)
wait_until(
await async_wait_until(
lambda: len(client._kill_cursors_queue) == 0,
"waited for all killCursor requests to complete",
)
@ -2403,7 +2411,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(await c.address, ("a", 1))
# Fail over.
@ -2430,7 +2438,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
# Total failure.
c.kill_host("a:1")
@ -2472,7 +2480,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION)
c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1)
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
wait_until(lambda: len(c.nodes) == 2, "connect")
await async_wait_until(lambda: len(c.nodes) == 2, "connect")
c.kill_host("a:1")
@ -2544,11 +2552,11 @@ class TestClientPool(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(await c.address, ("a", 1))
self.assertEqual(await c.arbiters, {("c", 3)})
# Assert that we create 2 and only 2 pooled connections.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 2)
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
# Assert that we do not create connections to arbiters.
arbiter = c._topology.get_server_by_address(("c", 3))
@ -2574,10 +2582,10 @@ class TestClientPool(AsyncMockClientTest):
)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 1, "connect")
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(await c.address, ("c", 3))
# Assert that we create 1 pooled connection.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))
self.assertEqual(len(arbiter.pool.conns), 1)

View File

@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest):
warn_context: Any
collation: Collation
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
cls.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings()
cls.warn_context.__enter__()
warnings.simplefilter("ignore", DeprecationWarning)
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.listener = OvertCommandListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
self.db = self.client.pymongo_test
self.collation = Collation("en_US")
self.warn_context = warnings.catch_warnings()
self.warn_context.__enter__()
@classmethod
async def _tearDown_class(cls):
cls.warn_context.__exit__()
cls.warn_context = None
await cls.client.close()
await super()._tearDown_class()
def tearDown(self):
async def asyncTearDown(self) -> None:
self.warn_context.__exit__()
self.warn_context = None
self.listener.reset()
super().tearDown()
await super().asyncTearDown()
def last_command_started(self):
return self.listener.started_events[-1].command

View File

@ -40,7 +40,6 @@ from test.utils import (
async_get_pool,
async_is_mongos,
async_wait_until,
wait_until,
)
from bson import encode
@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest):
db: AsyncDatabase
client: AsyncMongoClient
@classmethod
async def _setup_class(cls):
cls.client = AsyncMongoClient(connect=False)
cls.db = cls.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.client = self.simple_client(connect=False)
self.db = self.client.pymongo_test
def test_collection(self):
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
@ -165,27 +160,14 @@ class TestCollectionNoConnect(AsyncUnitTest):
class AsyncTestCollection(AsyncIntegrationTest):
w: int
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.w = async_client_context.w # type: ignore
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
else:
asyncio.run(cls.async_tearDownClass())
@classmethod
async def async_tearDownClass(cls):
await cls.db.drop_collection("test_large_limit")
async def asyncSetUp(self):
await self.db.test.drop()
await super().asyncSetUp()
self.w = async_client_context.w # type: ignore
async def asyncTearDown(self):
await self.db.test.drop()
await self.db.drop_collection("test_large_limit")
await super().asyncTearDown()
@contextlib.contextmanager
def write_concern_collection(self):
@ -1023,7 +1005,10 @@ class AsyncTestCollection(AsyncIntegrationTest):
await db.test.insert_one({"y": 1}, bypass_document_validation=True)
await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
async def predicate():
return await db_w0.test.find_one({"x": 1})
await async_wait_until(predicate, "find w:0 replaced document")
async def test_update_bypass_document_validation(self):
db = self.db
@ -1871,7 +1856,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
await cur.close()
cur = None
# Wait until the background thread returns the socket.
wait_until(lambda: pool.active_sockets == 0, "return socket")
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
# The socket should be discarded.
self.assertEqual(0, len(pool.conns))

View File

@ -19,7 +19,12 @@ import sys
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous import (
AsyncIntegrationTest,
async_client_context,
reset_client_context,
unittest,
)
from test.asynchronous.helpers import async_repl_set_step_down
from test.utils import (
CMAPListener,
@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
listener: CMAPListener
coll: AsyncCollection
@classmethod
@async_client_context.require_replica_set
async def _setup_class(cls):
await super()._setup_class()
cls.listener = CMAPListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
async def asyncSetUp(self):
self.listener = CMAPListener()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
)
# Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for connections that
# survive a replica set election.
await async_ensure_all_connected(cls.client)
cls.listener.reset()
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
async def asyncSetUp(self):
await async_ensure_all_connected(self.client)
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
# Note that all ops use same write-concern as self.db (majority).
await self.db.drop_collection("step-down")
await self.db.create_collection("step-down")

View File

@ -56,6 +56,9 @@ class TestCreateEntities(AsyncIntegrationTest):
self.assertGreater(len(final_entity_map["events1"]), 0)
for event in final_entity_map["events1"]:
self.assertIn("PoolCreatedEvent", event["name"])
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()
async def test_store_all_others_as_entities(self):
self.scenario_runner = UnifiedSpecTestMixinV1()
@ -122,6 +125,9 @@ class TestCreateEntities(AsyncIntegrationTest):
self.assertEqual(entity_map["failures"], [])
self.assertEqual(entity_map["successes"], 2)
self.assertEqual(entity_map["iterations"], 5)
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()
if __name__ == "__main__":

View File

@ -34,9 +34,9 @@ from test.utils import (
AllowListEventListener,
EventListener,
OvertCommandListener,
async_wait_until,
delay,
ignore_deprecations,
wait_until,
)
from bson import decode_all
@ -1324,8 +1324,8 @@ class TestCursor(AsyncIntegrationTest):
with self.assertRaises(ExecutionTimeout):
await cursor.next()
def assertCursorKilled():
wait_until(
async def assertCursorKilled():
await async_wait_until(
lambda: len(listener.succeeded_events),
"find successful killCursors command",
)
@ -1335,7 +1335,7 @@ class TestCursor(AsyncIntegrationTest):
self.assertEqual(1, len(listener.succeeded_events))
self.assertEqual("killCursors", listener.succeeded_events[0].command_name)
assertCursorKilled()
await assertCursorKilled()
listener.reset()
cursor = await coll.aggregate([], batchSize=1)
@ -1345,7 +1345,7 @@ class TestCursor(AsyncIntegrationTest):
with self.assertRaises(ExecutionTimeout):
await cursor.next()
assertCursorKilled()
await assertCursorKilled()
def test_delete_not_initialized(self):
# Creating a cursor with invalid arguments will not run __init__
@ -1647,10 +1647,6 @@ class TestRawBatchCursor(AsyncIntegrationTest):
class TestRawBatchCommandCursor(AsyncIntegrationTest):
@classmethod
async def _setup_class(cls):
await super()._setup_class()
async def test_aggregate_raw(self):
c = self.db.test
await c.drop()

View File

@ -717,7 +717,8 @@ class TestDatabase(AsyncIntegrationTest):
class TestDatabaseAggregation(AsyncIntegrationTest):
def setUp(self):
async def asyncSetUp(self):
await super().asyncSetUp()
self.pipeline: List[Mapping[str, Any]] = [
{"$listLocalSessions": {}},
{"$limit": 1},

View File

@ -211,11 +211,10 @@ class TestClientOptions(AsyncPyMongoTestCase):
class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
"""Base class for encryption integration tests."""
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@async_client_context.require_version_min(4, 2, -1)
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
def assertEncrypted(self, val):
self.assertIsInstance(val, Binary)
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
class TestClientMaxWireVersion(AsyncIntegrationTest):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
@async_client_context.require_version_max(4, 0, 99)
async def test_raise_max_wire_version_error(self):
@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
"local": None,
}
@classmethod
@unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set",
)
async def _setup_class(cls):
await super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
await cls.client.db.coll.drop()
cls.vault = await create_key_vault(cls.client.keyvault.datakeys)
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener = OvertCommandListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
await self.client.db.coll.drop()
self.vault = await create_key_vault(self.client.keyvault.datakeys)
# Configure the encrypted field via the local schema_map option.
schemas = {
@ -846,25 +843,22 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
}
}
opts = AutoEncryptionOpts(
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
self.KMS_PROVIDERS,
"keyvault.datakeys",
schema_map=schemas,
kms_tls_options=KMS_TLS_OPTS,
)
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
self.client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
cls.client_encryption = cls.unmanaged_create_client_encryption(
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
self.client_encryption = self.create_client_encryption(
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
)
@classmethod
async def _tearDown_class(cls):
await cls.vault.drop()
await cls.client.close()
await cls.client_encrypted.close()
await cls.client_encryption.close()
def setUp(self):
self.listener.reset()
async def asyncTearDown(self) -> None:
await self.vault.drop()
async def run_test(self, provider_name):
# Create data key.
master_key: Any = self.MASTER_KEYS[provider_name]
@ -1011,10 +1005,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
class TestCorpus(AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
@staticmethod
def kms_providers():
@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
client_encrypted: AsyncMongoClient
listener: OvertCommandListener
@classmethod
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
db = async_client_context.client.db
cls.coll = db.coll
await cls.coll.drop()
self.coll = db.coll
await self.coll.drop()
# Configure the encrypted 'db.coll' collection via jsonSchema.
json_schema = json_data("limits", "limits-schema.json")
await db.create_collection(
@ -1211,17 +1203,14 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
await coll.insert_one(json_data("limits", "limits-key.json"))
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
cls.listener = OvertCommandListener()
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[cls.listener]
self.listener = OvertCommandListener()
self.client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[self.listener]
)
cls.coll_encrypted = cls.client_encrypted.db.coll
self.coll_encrypted = self.client_encrypted.db.coll
@classmethod
async def _tearDown_class(cls):
await cls.coll_encrypted.drop()
await cls.client_encrypted.close()
await super()._tearDown_class()
async def asyncTearDown(self) -> None:
await self.coll_encrypted.drop()
async def test_01_insert_succeeds_under_2MiB(self):
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
@ -1245,7 +1234,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
async def test_04_bulk_batch_split(self):
limits_doc = json_data("limits", "limits-doc.json")
@ -1255,7 +1246,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
doc2.update(limits_doc)
self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
async def test_05_insert_succeeds_just_under_16MiB(self):
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
@ -1285,15 +1278,12 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
"""Prose tests for creating data keys with a custom endpoint."""
@classmethod
@unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set",
)
async def _setup_class(cls):
await super()._setup_class()
def setUp(self):
async def asyncSetUp(self):
await super().asyncSetUp()
kms_providers = {
"aws": AWS_CREDS,
"azure": AZURE_CREDS,
@ -1322,10 +1312,6 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
self._kmip_host_error = None
self._invalid_host_error = None
async def asyncTearDown(self):
await self.client_encryption.close()
await self.client_encryption_invalid.close()
async def run_test_expected_success(self, provider_name, master_key):
data_key_id = await self.client_encryption.create_data_key(
provider_name, master_key=master_key
@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
KEYVAULT_COLL = "datakeys"
client: AsyncMongoClient
async def asyncSetUp(self):
async def _setup(self):
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
await create_key_vault(keyvault, self.DEK)
async def _test_explicit(self, expectation):
await self._setup()
client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
async_client_context.client,
OPTS,
)
self.addAsyncCleanup(client_encryption.close)
ciphertext = await client_encryption.encrypt(
"string0",
@ -1523,6 +1509,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")
async def _test_automatic(self, expectation_extjson, payload):
await self._setup()
encrypted_db = "db"
encrypted_coll = "coll"
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
@ -1537,7 +1524,6 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
client = await self.async_rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
)
self.addAsyncCleanup(client.aclose)
coll = client.get_database(encrypted_db).get_collection(
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
@ -1559,13 +1545,12 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
async def _setup_class(cls):
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super()._setup_class()
async def asyncSetUp(self):
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
self.DEK = json_data(BASE, "custom", "azure-dek.json")
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super().asyncSetUp()
async def test_explicit(self):
return await self._test_explicit(
@ -1585,13 +1570,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegratio
class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
async def _setup_class(cls):
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super()._setup_class()
async def asyncSetUp(self):
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super().asyncSetUp()
async def test_explicit(self):
return await self._test_explicit(
@ -1613,6 +1597,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client_test = await self.async_rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
)
@ -1645,7 +1630,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
self.ciphertext = await client_encryption.encrypt(
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
)
await client_encryption.close()
self.client_listener = OvertCommandListener()
self.topology_listener = TopologyEventListener()
@ -1840,6 +1824,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
class TestDecryptProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client
await self.client.db.drop_collection("decryption_events")
await create_key_vault(self.client.keyvault.datakeys)
@ -2275,6 +2260,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client
await create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
@ -2624,8 +2610,6 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
assert isinstance(res["encrypted_indexed"], Binary)
assert isinstance(res["encrypted_unindexed"], Binary)
await client_encryption.close()
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
@ -3089,17 +3073,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
mongocryptd_client: AsyncMongoClient
MONGOCRYPTD_PORT = 27020
@classmethod
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
async def _setup_class(cls):
await super()._setup_class()
start_mongocryptd(cls.MONGOCRYPTD_PORT)
@classmethod
async def _tearDown_class(cls):
await super()._tearDown_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
start_mongocryptd(self.MONGOCRYPTD_PORT)
self.listener = OvertCommandListener()
self.mongocryptd_client = self.simple_client(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]

View File

@ -97,6 +97,7 @@ class AsyncTestGridFileNoConnect(AsyncUnitTest):
class AsyncTestGridFile(AsyncIntegrationTest):
async def asyncSetUp(self):
await super().asyncSetUp()
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
async def test_basic(self):

View File

@ -16,21 +16,20 @@ from __future__ import annotations
import asyncio
import sys
import threading
import unittest
from pymongo.lock import _async_create_condition, _async_create_lock
sys.path[0:0] = [""]
from pymongo.lock import _ACondition
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
if sys.version_info < (3, 13):
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
async def test_wait(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
result = []
async def c1(result):
@ -93,7 +92,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
self.assertTrue(t3.result())
async def test_wait_cancel(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
await cond.acquire()
wait = asyncio.create_task(cond.wait())
@ -104,7 +103,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
self.assertTrue(cond.locked())
async def test_wait_cancel_contested(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
await cond.acquire()
self.assertTrue(cond.locked())
@ -132,7 +131,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
# See bpo-32841
waited = False
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
async def wait_on_cond():
nonlocal waited
@ -155,12 +154,12 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
self.assertTrue(waited)
async def test_wait_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
with self.assertRaises(RuntimeError):
await cond.wait()
async def test_wait_for(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
presult = False
def predicate():
@ -197,7 +196,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
self.assertTrue(t.result())
async def test_wait_for_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
# predicate can return true immediately
res = await cond.wait_for(lambda: [1, 2, 3])
@ -207,7 +206,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
await cond.wait_for(lambda: False)
async def test_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
result = []
async def c1(result):
@ -254,7 +253,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
self.assertTrue(t3.result())
async def test_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
result = []
@ -287,14 +286,14 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
self.assertTrue(t2.result())
async def test_context_manager(self):
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
self.assertFalse(cond.locked())
async with cond:
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
async def test_timeout_in_block(self):
condition = _ACondition(threading.Condition(threading.Lock()))
condition = _async_create_condition(_async_create_lock())
async with condition:
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)
@ -307,7 +306,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
async def func():
nonlocal raised
@ -336,7 +335,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
cond = _async_create_condition(_async_create_lock())
async def func():
nonlocal raised
@ -369,7 +368,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is awaiting initial
# wakeup on the wakeup queue.
condition = _ACondition(threading.Condition(threading.Lock()))
condition = _async_create_condition(_async_create_lock())
state = 0
async def consumer():
@ -415,7 +414,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is acquiring the lock
# again.
condition = _ACondition(threading.Condition(threading.Lock()))
condition = _async_create_condition(_async_create_lock())
state = 0
async def consumer():
@ -459,55 +458,5 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
condition.notify_all()
await c[1]
class TestCondition(unittest.IsolatedAsyncioTestCase):
async def test_multiple_loops_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
def tmain(cond):
async def atmain(cond):
await asyncio.sleep(1)
async with cond:
cond.notify(1)
asyncio.run(atmain(cond))
t = threading.Thread(target=tmain, args=(cond,))
t.start()
async with cond:
self.assertTrue(await cond.wait(30))
t.join()
async def test_multiple_loops_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
results = []
def tmain(cond, results):
async def atmain(cond, results):
await asyncio.sleep(1)
async with cond:
res = await cond.wait(30)
results.append(res)
asyncio.run(atmain(cond, results))
nthreads = 5
threads = []
for _ in range(nthreads):
threads.append(threading.Thread(target=tmain, args=(cond, results)))
for t in threads:
t.start()
await asyncio.sleep(2)
async with cond:
cond.notify_all()
for t in threads:
t.join()
self.assertEqual(results, [True] * nthreads)
if __name__ == "__main__":
if __name__ == "__main__":
unittest.main()

View File

@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
listener: EventListener
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False
)
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
await super()._tearDown_class()
async def asyncTearDown(self):
@async_client_context.require_connection
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.listener.reset()
await super().asyncTearDown()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], retryWrites=False
)
async def test_started_simple(self):
await self.client.pymongo_test.command("ping")
@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
saved_listeners: Any
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
# We plan to call register(), which internally modifies _LISTENERS.
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
monitoring.register(cls.listener)
cls.client = await cls.unmanaged_async_single_client()
# Get one (authenticated) socket in the pool.
await cls.client.pymongo_test.command("ping")
@classmethod
async def _tearDown_class(cls):
monitoring._LISTENERS = cls.saved_listeners
await cls.client.close()
await super()._tearDown_class()
@async_client_context.require_connection
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener.reset()
self.client = await self.async_single_client()
# Get one (authenticated) socket in the pool.
await self.client.pymongo_test.command("ping")
@classmethod
def tearDownClass(cls):
monitoring._LISTENERS = cls.saved_listeners
async def test_simple(self):
await self.client.pymongo_test.command("ping")

View File

@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
RUN_ON_SERVERLESS = True
deprecation_filter: DeprecationFilter
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.deprecation_filter = DeprecationFilter()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.deprecation_filter = DeprecationFilter()
@classmethod
async def _tearDown_class(cls):
cls.deprecation_filter.stop()
await super()._tearDown_class()
async def asyncTearDown(self) -> None:
self.deprecation_filter.stop()
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs
@classmethod
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True)
cls.db = cls.client.pymongo_test
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.client = await self.async_rs_or_single_client(retryWrites=True)
self.db = self.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
async def asyncTearDown(self) -> None:
self.knobs.disable()
@async_client_context.require_no_standalone
async def test_actionable_error_message(self):
@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
listener: OvertCommandListener
knobs: client_knobs
@classmethod
@async_client_context.require_no_mmap
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
retryWrites=True, event_listeners=[cls.listener]
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.listener = OvertCommandListener()
self.client = await self.async_rs_or_single_client(
retryWrites=True, event_listeners=[self.listener]
)
cls.db = cls.client.pymongo_test
self.db = self.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
async def asyncSetUp(self):
if async_client_context.is_rs and async_client_context.test_commands_enabled:
await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
@ -210,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
self.knobs.disable()
async def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()
@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest):
RUN_ON_SERVERLESS = True
fail_insert: dict
@classmethod
@async_client_context.require_replica_set
@async_client_context.require_no_mmap
@async_client_context.require_failCommand_fail_point
async def _setup_class(cls):
await super()._setup_class()
cls.fail_insert = {
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.fail_insert = {
"configureFailPoint": "failCommand",
"mode": {"times": 2},
"data": {

View File

@ -38,7 +38,6 @@ from test.utils import (
ExceptionCatchingThread,
OvertCommandListener,
async_wait_until,
wait_until,
)
from bson import DBRef
@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest):
client2: AsyncMongoClient
sensitive_commands: Set[str]
@classmethod
@async_client_context.require_sessions
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self):
await super().asyncSetUp()
# Create a second client so we can make sure clients cannot share
# sessions.
cls.client2 = await cls.unmanaged_async_rs_or_single_client()
self.client2 = await self.async_rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid".
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
monitoring._SENSITIVE_COMMANDS.clear()
@classmethod
async def _tearDown_class(cls):
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
await cls.client2.close()
await super()._tearDown_class()
async def asyncSetUp(self):
self.listener = SessionTestListener()
self.session_checker_listener = SessionTestListener()
self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener]
)
self.addAsyncCleanup(self.client.close)
self.db = self.client.pymongo_test
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
async def asyncTearDown(self):
"""All sessions used in the test must be returned to the pool."""
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
await self.client.drop_database("pymongo_test")
used_lsids = self.initial_lsids.copy()
for event in self.session_checker_listener.started_events:
@ -122,6 +112,8 @@ class TestSession(AsyncIntegrationTest):
current_lsids = {s["id"] for s in session_ids(self.client)}
self.assertLessEqual(used_lsids, current_lsids)
await super().asyncTearDown()
async def _test_ops(self, client, *ops):
listener = client.options.event_listeners[0]
@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener
client: AsyncMongoClient
@classmethod
async def _setup_class(cls):
cls.listener = SessionTestListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
@async_client_context.require_sessions
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener = SessionTestListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
@async_client_context.require_no_standalone
async def test_core(self):

View File

@ -26,7 +26,7 @@ sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import (
OvertCommandListener,
wait_until,
async_wait_until,
)
from typing import List
@ -162,7 +162,7 @@ class TestTransactions(AsyncTransactionsBase):
client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000
)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
@ -191,7 +191,7 @@ class TestTransactions(AsyncTransactionsBase):
client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000
)
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test
# Create the collection.
await coll.insert_one({})
@ -403,21 +403,12 @@ class PatchSessionTimeout:
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.mongos_clients = []
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.mongos_clients = []
if async_client_context.supports_transactions():
for address in async_client_context.mongoses:
cls.mongos_clients.append(
await cls.unmanaged_async_single_client("{}:{}".format(*address))
)
@classmethod
async def _tearDown_class(cls):
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
async def _set_fail_point(self, client, command_args):
cmd = {"configureFailPoint": "failCommand"}

View File

@ -50,6 +50,7 @@ from test.unified_format_shared import (
)
from test.utils import (
async_get_pool,
async_wait_until,
camel_to_snake,
camel_to_snake_args,
parse_spec_options,
@ -304,7 +305,6 @@ class EntityMapUtil:
kwargs["h"] = uri
client = await self.test.async_rs_or_single_client(**kwargs)
self[spec["id"]] = client
self.test.addAsyncCleanup(client.close)
return
elif entity_type == "database":
client = self[spec["client"]]
@ -479,33 +479,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
await db.create_collection(coll_name, write_concern=wc, **opts)
@classmethod
async def _setup_class(cls):
# super call creates internal client cls.client
await super()._setup_class()
# process file-level runOnRequirements
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
if not await cls.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if async_client_context.storage_engine == "mmapv1":
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
cls.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
cls.mongos_clients = []
if (
async_client_context.supports_transactions()
and not async_client_context.load_balancer
and not async_client_context.serverless
):
for address in async_client_context.mongoses:
cls.mongos_clients.append(
await cls.unmanaged_async_single_client("{}:{}".format(*address))
)
def setUpClass(cls) -> None:
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(
heartbeat_frequency=0.1,
@ -516,17 +490,36 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
cls.knobs.enable()
@classmethod
async def _tearDown_class(cls):
def tearDownClass(cls) -> None:
cls.knobs.disable()
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
async def asyncSetUp(self):
# super call creates internal client cls.client
await super().asyncSetUp()
# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not await self.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if async_client_context.storage_engine == "mmapv1":
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
self.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
self.mongos_clients = []
if (
async_client_context.supports_transactions()
and not async_client_context.load_balancer
and not async_client_context.serverless
):
for address in async_client_context.mongoses:
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
# process schemaVersion
# note: we check major schema version during class generation
# note: we do this here because we cannot run assertions in setUpClass
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
self.assertLessEqual(
version,
@ -1036,7 +1029,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
)
client = await self.async_single_client("{}:{}".format(*session._pinned_address))
self.addAsyncCleanup(client.close)
await self.__set_fail_point(client=client, command_args=spec["failPoint"])
async def _testOperation_createEntities(self, spec):
@ -1137,13 +1129,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
client, event, count = spec["client"], spec["event"], spec["count"]
self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}")
def _testOperation_waitForEvent(self, spec):
async def _testOperation_waitForEvent(self, spec):
"""Run the waitForEvent test operation.
Wait for a number of events to be published, or fail.
"""
client, event, count = spec["client"], spec["event"], spec["count"]
wait_until(
await async_wait_until(
lambda: self._event_count(client, event) >= count,
f"find {count} {event} event(s)",
)

View File

@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest):
knobs: client_knobs
listener: EventListener
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.mongos_clients = []
async def asyncSetUp(self) -> None:
await super().asyncSetUp()
self.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
def setUp(self):
super().setUp()
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
async def asyncTearDown(self) -> None:
self.knobs.disable()
async def _set_fail_point(self, client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
@ -700,8 +692,6 @@ class AsyncSpecRunner(AsyncIntegrationTest):
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addAsyncCleanup(client.close)
# Create session0 and session1.
sessions = {}

View File

@ -20,7 +20,7 @@ def event_loop_policy():
return asyncio.get_event_loop_policy()
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="package", autouse=True)
def test_setup_and_teardown():
setup()
yield

View File

@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest):
coll: Collection
coll_w0: Collection
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
def setUp(self):
super().setUp()
self.coll = self.db.test
self.coll.drop()
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
def assertEqualResponse(self, expected, actual):
"""Compare response from bulk.execute() to expected response."""
@ -785,12 +781,8 @@ class TestBulk(BulkTestBase):
class BulkAuthorizationTestBase(BulkTestBase):
@classmethod
@client_context.require_auth
@client_context.require_no_api_version
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
client_context.create_user(self.db.name, "readonly", "pw", ["read"])
@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase):
w: Optional[int]
secondary: MongoClient
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.w = client_context.w
cls.secondary = None
if cls.w is not None and cls.w > 1:
def setUp(self):
super().setUp()
self.w = client_context.w
self.secondary = None
if self.w is not None and self.w > 1:
for member in (client_context.hello)["hosts"]:
if member != (client_context.hello)["primary"]:
cls.secondary = cls.unmanaged_single_client(*partition_node(member))
self.secondary = self.single_client(*partition_node(member))
break
@classmethod
def async_tearDownClass(cls):
if cls.secondary:
cls.secondary.close()
def tearDown(self):
if self.secondary:
self.secondary.close()
def cause_wtimeout(self, requests, ordered):
if not client_context.test_commands_enabled:

View File

@ -820,18 +820,16 @@ class ProseSpecTestsMixin:
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
dbs: list
@classmethod
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def _setup_class(cls):
super()._setup_class()
cls.dbs = [cls.db, cls.client.pymongo_test_2]
def setUp(self) -> None:
super().setUp()
self.dbs = [self.db, self.client.pymongo_test_2]
@classmethod
def _tearDown_class(cls):
for db in cls.dbs:
cls.client.drop_database(db)
super()._tearDown_class()
def tearDown(self):
for db in self.dbs:
self.client.drop_database(db)
super().tearDown()
def change_stream_with_client(self, client, *args, **kwargs):
return client.watch(*args, **kwargs)
@ -882,11 +880,10 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
@classmethod
@client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams
def _setup_class(cls):
super()._setup_class()
def setUp(self) -> None:
super().setUp()
def change_stream_with_client(self, client, *args, **kwargs):
return client[self.db.name].watch(*args, **kwargs)
@ -968,12 +965,9 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
@classmethod
@client_context.require_change_streams
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
# Use a new collection for each test.
self.watched_collection().drop()
self.watched_collection().insert_one({})
@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest):
RUN_ON_LOAD_BALANCER = True
listener: AllowListEventListener
@classmethod
@client_context.require_connection
def _setup_class(cls):
super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def _tearDown_class(cls):
cls.client.close()
super()._tearDown_class()
def setUp(self):
super().setUp()
self.listener = AllowListEventListener("aggregate", "getMore")
self.client = self.rs_or_single_client(event_listeners=[self.listener])
self.listener.reset()
def setUpCluster(self, scenario_dict):

View File

@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest):
client: MongoClient
@classmethod
def _setup_class(cls):
cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
@classmethod
def _tearDown_class(cls):
cls.client.close()
def setUp(self) -> None:
self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
@ -1039,14 +1034,21 @@ class TestClient(IntegrationTest):
self.assertFalse(client._topology._opened)
# Ensure kill cursors thread has not been started.
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertFalse(kc_task and not kc_task.done())
# Using the client should open topology and start the thread.
client.admin.command("ping")
self.assertTrue(client._topology._opened)
if _IS_SYNC:
kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertTrue(kc_task and not kc_task.done())
def test_close_does_not_open_servers(self):
client = self.rs_client(connect=False)
@ -1241,6 +1243,7 @@ class TestClient(IntegrationTest):
def test_server_selection_timeout(self):
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
client.close()
client = MongoClient(serverSelectionTimeoutMS=0, connect=False)
@ -1251,16 +1254,20 @@ class TestClient(IntegrationTest):
self.assertRaises(
ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
self.assertAlmostEqual(0, client.options.server_selection_timeout)
client.close()
# Test invalid timeout in URI ignored and set to default.
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout)

View File

@ -97,26 +97,19 @@ class TestCollation(IntegrationTest):
warn_context: Any
collation: Collation
@classmethod
@client_context.require_connection
def _setup_class(cls):
super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
cls.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings()
cls.warn_context.__enter__()
warnings.simplefilter("ignore", DeprecationWarning)
def setUp(self) -> None:
super().setUp()
self.listener = OvertCommandListener()
self.client = self.rs_or_single_client(event_listeners=[self.listener])
self.db = self.client.pymongo_test
self.collation = Collation("en_US")
self.warn_context = warnings.catch_warnings()
self.warn_context.__enter__()
@classmethod
def _tearDown_class(cls):
cls.warn_context.__exit__()
cls.warn_context = None
cls.client.close()
super()._tearDown_class()
def tearDown(self):
def tearDown(self) -> None:
self.warn_context.__exit__()
self.warn_context = None
self.listener.reset()
super().tearDown()

View File

@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest):
db: Database
client: MongoClient
@classmethod
def _setup_class(cls):
cls.client = MongoClient(connect=False)
cls.db = cls.client.pymongo_test
@classmethod
def _tearDown_class(cls):
cls.client.close()
def setUp(self) -> None:
super().setUp()
self.client = self.simple_client(connect=False)
self.db = self.client.pymongo_test
def test_collection(self):
self.assertRaises(TypeError, Collection, self.db, 5)
@ -164,27 +160,14 @@ class TestCollectionNoConnect(UnitTest):
class TestCollection(IntegrationTest):
w: int
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.w = client_context.w # type: ignore
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
else:
asyncio.run(cls.async_tearDownClass())
@classmethod
def async_tearDownClass(cls):
cls.db.drop_collection("test_large_limit")
def setUp(self):
self.db.test.drop()
super().setUp()
self.w = client_context.w # type: ignore
def tearDown(self):
self.db.test.drop()
self.db.drop_collection("test_large_limit")
super().tearDown()
@contextlib.contextmanager
def write_concern_collection(self):
@ -1010,7 +993,10 @@ class TestCollection(IntegrationTest):
db.test.insert_one({"y": 1}, bypass_document_validation=True)
db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
def predicate():
return db_w0.test.find_one({"x": 1})
wait_until(predicate, "find w:0 replaced document")
def test_update_bypass_document_validation(self):
db = self.db

View File

@ -19,7 +19,12 @@ import sys
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test import (
IntegrationTest,
client_context,
reset_client_context,
unittest,
)
from test.helpers import repl_set_step_down
from test.utils import (
CMAPListener,
@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
listener: CMAPListener
coll: Collection
@classmethod
@client_context.require_replica_set
def _setup_class(cls):
super()._setup_class()
cls.listener = CMAPListener()
cls.client = cls.unmanaged_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
def setUp(self):
self.listener = CMAPListener()
self.client = self.rs_or_single_client(
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
)
# Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for connections that
# survive a replica set election.
ensure_all_connected(cls.client)
cls.listener.reset()
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
@classmethod
def _tearDown_class(cls):
cls.client.close()
def setUp(self):
ensure_all_connected(self.client)
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
# Note that all ops use same write-concern as self.db (majority).
self.db.drop_collection("step-down")
self.db.create_collection("step-down")

View File

@ -56,6 +56,9 @@ class TestCreateEntities(IntegrationTest):
self.assertGreater(len(final_entity_map["events1"]), 0)
for event in final_entity_map["events1"]:
self.assertIn("PoolCreatedEvent", event["name"])
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
client.close()
def test_store_all_others_as_entities(self):
self.scenario_runner = UnifiedSpecTestMixinV1()
@ -122,6 +125,9 @@ class TestCreateEntities(IntegrationTest):
self.assertEqual(entity_map["failures"], [])
self.assertEqual(entity_map["successes"], 2)
self.assertEqual(entity_map["iterations"], 5)
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
client.close()
if __name__ == "__main__":

View File

@ -1636,10 +1636,6 @@ class TestRawBatchCursor(IntegrationTest):
class TestRawBatchCommandCursor(IntegrationTest):
@classmethod
def _setup_class(cls):
super()._setup_class()
def test_aggregate_raw(self):
c = self.db.test
c.drop()

View File

@ -633,6 +633,7 @@ class TestTypeRegistry(unittest.TestCase):
class TestCollectionWCustomType(IntegrationTest):
def setUp(self):
super().setUp()
self.db.test.drop()
def tearDown(self):
@ -754,6 +755,7 @@ class TestCollectionWCustomType(IntegrationTest):
class TestGridFileCustomType(IntegrationTest):
def setUp(self):
super().setUp()
self.db.drop_collection("fs.files")
self.db.drop_collection("fs.chunks")
@ -917,11 +919,10 @@ class ChangeStreamsWCustomTypesTestMixin:
class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
cls.db.test.delete_many({})
def setUp(self):
super().setUp()
self.db.test.delete_many({})
def tearDown(self):
self.input_target.drop()
@ -935,12 +936,11 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus
class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0)
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
cls.db.test.delete_many({})
def setUp(self):
super().setUp()
self.db.test.delete_many({})
def tearDown(self):
self.input_target.drop()
@ -954,12 +954,11 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto
class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0)
@client_context.require_change_streams
def setUpClass(cls):
super().setUpClass()
cls.db.test.delete_many({})
def setUp(self):
super().setUp()
self.db.test.delete_many({})
def tearDown(self):
self.input_target.drop()

View File

@ -709,6 +709,7 @@ class TestDatabase(IntegrationTest):
class TestDatabaseAggregation(IntegrationTest):
def setUp(self):
super().setUp()
self.pipeline: List[Mapping[str, Any]] = [
{"$listLocalSessions": {}},
{"$limit": 1},

View File

@ -211,11 +211,10 @@ class TestClientOptions(PyMongoTestCase):
class EncryptionIntegrationTest(IntegrationTest):
"""Base class for encryption integration tests."""
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@client_context.require_version_min(4, 2, -1)
def _setup_class(cls):
super()._setup_class()
def setUp(self) -> None:
super().setUp()
def assertEncrypted(self, val):
self.assertIsInstance(val, Binary)
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):
class TestClientMaxWireVersion(IntegrationTest):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
@client_context.require_version_max(4, 0, 99)
def test_raise_max_wire_version_error(self):
@ -816,17 +814,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
"local": None,
}
@classmethod
@unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set",
)
def _setup_class(cls):
super()._setup_class()
cls.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.client.db.coll.drop()
cls.vault = create_key_vault(cls.client.keyvault.datakeys)
def setUp(self):
super().setUp()
self.listener = OvertCommandListener()
self.client = self.rs_or_single_client(event_listeners=[self.listener])
self.client.db.coll.drop()
self.vault = create_key_vault(self.client.keyvault.datakeys)
# Configure the encrypted field via the local schema_map option.
schemas = {
@ -844,25 +841,22 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
}
}
opts = AutoEncryptionOpts(
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
self.KMS_PROVIDERS,
"keyvault.datakeys",
schema_map=schemas,
kms_tls_options=KMS_TLS_OPTS,
)
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
self.client_encrypted = self.rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
cls.client_encryption = cls.unmanaged_create_client_encryption(
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
self.client_encryption = self.create_client_encryption(
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
)
@classmethod
def _tearDown_class(cls):
cls.vault.drop()
cls.client.close()
cls.client_encrypted.close()
cls.client_encryption.close()
def setUp(self):
self.listener.reset()
def tearDown(self) -> None:
self.vault.drop()
def run_test(self, provider_name):
# Create data key.
master_key: Any = self.MASTER_KEYS[provider_name]
@ -1007,10 +1001,9 @@ class TestViews(EncryptionIntegrationTest):
class TestCorpus(EncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
@staticmethod
def kms_providers():
@ -1184,12 +1177,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
client_encrypted: MongoClient
listener: OvertCommandListener
@classmethod
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
db = client_context.client.db
cls.coll = db.coll
cls.coll.drop()
self.coll = db.coll
self.coll.drop()
# Configure the encrypted 'db.coll' collection via jsonSchema.
json_schema = json_data("limits", "limits-schema.json")
db.create_collection(
@ -1207,17 +1199,14 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
coll.insert_one(json_data("limits", "limits-key.json"))
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
cls.listener = OvertCommandListener()
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[cls.listener]
self.listener = OvertCommandListener()
self.client_encrypted = self.rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[self.listener]
)
cls.coll_encrypted = cls.client_encrypted.db.coll
self.coll_encrypted = self.client_encrypted.db.coll
@classmethod
def _tearDown_class(cls):
cls.coll_encrypted.drop()
cls.client_encrypted.close()
super()._tearDown_class()
def tearDown(self) -> None:
self.coll_encrypted.drop()
def test_01_insert_succeeds_under_2MiB(self):
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
@ -1241,7 +1230,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
self.listener.reset()
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
def test_04_bulk_batch_split(self):
limits_doc = json_data("limits", "limits-doc.json")
@ -1251,7 +1242,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
doc2.update(limits_doc)
self.listener.reset()
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
def test_05_insert_succeeds_just_under_16MiB(self):
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
@ -1281,15 +1274,12 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
class TestCustomEndpoint(EncryptionIntegrationTest):
"""Prose tests for creating data keys with a custom endpoint."""
@classmethod
@unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set",
)
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
kms_providers = {
"aws": AWS_CREDS,
"azure": AZURE_CREDS,
@ -1318,10 +1308,6 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
self._kmip_host_error = None
self._invalid_host_error = None
def tearDown(self):
self.client_encryption.close()
self.client_encryption_invalid.close()
def run_test_expected_success(self, provider_name, master_key):
data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key)
encrypted = self.client_encryption.encrypt(
@ -1494,18 +1480,18 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
KEYVAULT_COLL = "datakeys"
client: MongoClient
def setUp(self):
def _setup(self):
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
create_key_vault(keyvault, self.DEK)
def _test_explicit(self, expectation):
self._setup()
client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
client_context.client,
OPTS,
)
self.addCleanup(client_encryption.close)
ciphertext = client_encryption.encrypt(
"string0",
@ -1517,6 +1503,7 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
self.assertEqual(client_encryption.decrypt(ciphertext), "string0")
def _test_automatic(self, expectation_extjson, payload):
self._setup()
encrypted_db = "db"
encrypted_coll = "coll"
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
@ -1531,7 +1518,6 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
client = self.rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
)
self.addCleanup(client.close)
coll = client.get_database(encrypted_db).get_collection(
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
@ -1553,13 +1539,12 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
def _setup_class(cls):
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
super()._setup_class()
def setUp(self):
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
self.DEK = json_data(BASE, "custom", "azure-dek.json")
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
super().setUp()
def test_explicit(self):
return self._test_explicit(
@ -1579,13 +1564,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest
class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
def _setup_class(cls):
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
super()._setup_class()
def setUp(self):
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
super().setUp()
def test_explicit(self):
return self._test_explicit(
@ -1607,6 +1591,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
class TestDeadlockProse(EncryptionIntegrationTest):
def setUp(self):
super().setUp()
self.client_test = self.rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
)
@ -1637,7 +1622,6 @@ class TestDeadlockProse(EncryptionIntegrationTest):
self.ciphertext = client_encryption.encrypt(
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
)
client_encryption.close()
self.client_listener = OvertCommandListener()
self.topology_listener = TopologyEventListener()
@ -1832,6 +1816,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
class TestDecryptProse(EncryptionIntegrationTest):
def setUp(self):
super().setUp()
self.client = client_context.client
self.client.db.drop_collection("decryption_events")
create_key_vault(self.client.keyvault.datakeys)
@ -2267,6 +2252,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
def setUp(self):
super().setUp()
self.client = client_context.client
create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
@ -2608,8 +2594,6 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
assert isinstance(res["encrypted_indexed"], Binary)
assert isinstance(res["encrypted_unindexed"], Binary)
client_encryption.close()
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
class TestRangeQueryProse(EncryptionIntegrationTest):
@ -3071,17 +3055,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest):
mongocryptd_client: MongoClient
MONGOCRYPTD_PORT = 27020
@classmethod
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
def _setup_class(cls):
super()._setup_class()
start_mongocryptd(cls.MONGOCRYPTD_PORT)
@classmethod
def _tearDown_class(cls):
super()._tearDown_class()
def setUp(self) -> None:
super().setUp()
start_mongocryptd(self.MONGOCRYPTD_PORT)
self.listener = OvertCommandListener()
self.mongocryptd_client = self.simple_client(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]

View File

@ -33,19 +33,14 @@ from pymongo.write_concern import WriteConcern
class TestSampleShellCommands(IntegrationTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
# Run once before any tests run.
cls.db.inventory.drop()
@classmethod
def tearDownClass(cls):
cls.client.drop_database("pymongo_test")
def setUp(self):
super().setUp()
self.db.inventory.drop()
def tearDown(self):
# Run after every test.
self.db.inventory.drop()
self.client.drop_database("pymongo_test")
def test_first_three_examples(self):
db = self.db

View File

@ -97,6 +97,7 @@ class TestGridFileNoConnect(UnitTest):
class TestGridFile(IntegrationTest):
def setUp(self):
super().setUp()
self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
def test_basic(self):

View File

@ -75,9 +75,9 @@ class JustRead(threading.Thread):
class TestGridfsNoConnect(unittest.TestCase):
db: Database
@classmethod
def setUpClass(cls):
cls.db = MongoClient(connect=False).pymongo_test
def setUp(self):
super().setUp()
self.db = MongoClient(connect=False).pymongo_test
def test_gridfs(self):
self.assertRaises(TypeError, gridfs.GridFS, "foo")
@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest):
fs: gridfs.GridFS
alt: gridfs.GridFS
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.fs = gridfs.GridFS(cls.db)
cls.alt = gridfs.GridFS(cls.db, "alt")
def setUp(self):
super().setUp()
self.fs = gridfs.GridFS(self.db)
self.alt = gridfs.GridFS(self.db, "alt")
self.cleanup_colls(
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
)
@ -509,10 +506,9 @@ class TestGridfs(IntegrationTest):
class TestGridfsReplicaSet(IntegrationTest):
@classmethod
@client_context.require_secondaries_count(1)
def setUpClass(cls):
super().setUpClass()
def setUp(self):
super().setUp()
@classmethod
def tearDownClass(cls):

View File

@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest):
fs: gridfs.GridFSBucket
alt: gridfs.GridFSBucket
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.fs = gridfs.GridFSBucket(cls.db)
cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt")
def setUp(self):
super().setUp()
self.fs = gridfs.GridFSBucket(self.db)
self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt")
self.cleanup_colls(
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
)
@ -479,10 +476,9 @@ class TestGridfs(IntegrationTest):
class TestGridfsBucketReplicaSet(IntegrationTest):
@classmethod
@client_context.require_secondaries_count(1)
def setUpClass(cls):
super().setUpClass()
def setUp(self):
super().setUp()
@classmethod
def tearDownClass(cls):

View File

@ -29,7 +29,7 @@ from test.utils import (
wait_until,
)
from pymongo.synchronous.periodic_executor import _EXECUTORS
from pymongo.periodic_executor import _EXECUTORS
def unregistered(ref):

View File

@ -52,22 +52,14 @@ class TestCommandMonitoring(IntegrationTest):
listener: EventListener
@classmethod
@client_context.require_connection
def _setup_class(cls):
super()._setup_class()
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False
)
@classmethod
def _tearDown_class(cls):
cls.client.close()
super()._tearDown_class()
def tearDown(self):
@client_context.require_connection
def setUp(self) -> None:
super().setUp()
self.listener.reset()
super().tearDown()
self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False)
def test_started_simple(self):
self.client.pymongo_test.command("ping")
@ -1140,26 +1132,23 @@ class TestGlobalListener(IntegrationTest):
saved_listeners: Any
@classmethod
@client_context.require_connection
def _setup_class(cls):
super()._setup_class()
def setUpClass(cls) -> None:
cls.listener = OvertCommandListener()
# We plan to call register(), which internally modifies _LISTENERS.
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
monitoring.register(cls.listener)
cls.client = cls.unmanaged_single_client()
# Get one (authenticated) socket in the pool.
cls.client.pymongo_test.command("ping")
@classmethod
def _tearDown_class(cls):
monitoring._LISTENERS = cls.saved_listeners
cls.client.close()
super()._tearDown_class()
@client_context.require_connection
def setUp(self):
super().setUp()
self.listener.reset()
self.client = self.single_client()
# Get one (authenticated) socket in the pool.
self.client.pymongo_test.command("ping")
@classmethod
def tearDownClass(cls):
monitoring._LISTENERS = cls.saved_listeners
def test_simple(self):
self.client.pymongo_test.command("ping")

View File

@ -31,24 +31,16 @@ from pymongo.read_concern import ReadConcern
class TestReadConcern(IntegrationTest):
listener: OvertCommandListener
@classmethod
@client_context.require_connection
def setUpClass(cls):
super().setUpClass()
cls.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
def setUp(self):
super().setUp()
self.listener = OvertCommandListener()
self.client = self.rs_or_single_client(event_listeners=[self.listener])
self.db = self.client.pymongo_test
client_context.client.pymongo_test.create_collection("coll")
@classmethod
def tearDownClass(cls):
cls.client.close()
client_context.client.pymongo_test.drop_collection("coll")
super().tearDownClass()
def tearDown(self):
self.listener.reset()
super().tearDown()
client_context.client.pymongo_test.drop_collection("coll")
def test_read_concern(self):
rc = ReadConcern()

View File

@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(IntegrationTest):
RUN_ON_SERVERLESS = True
deprecation_filter: DeprecationFilter
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.deprecation_filter = DeprecationFilter()
def setUp(self) -> None:
super().setUp()
self.deprecation_filter = DeprecationFilter()
@classmethod
def _tearDown_class(cls):
cls.deprecation_filter.stop()
super()._tearDown_class()
def tearDown(self) -> None:
self.deprecation_filter.stop()
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs
@classmethod
def _setup_class(cls):
super()._setup_class()
def setUp(self) -> None:
super().setUp()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True)
cls.db = cls.client.pymongo_test
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.client = self.rs_or_single_client(retryWrites=True)
self.db = self.client.pymongo_test
@classmethod
def _tearDown_class(cls):
cls.knobs.disable()
cls.client.close()
super()._tearDown_class()
def tearDown(self) -> None:
self.knobs.disable()
@client_context.require_no_standalone
def test_actionable_error_message(self):
@ -180,26 +173,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
listener: OvertCommandListener
knobs: client_knobs
@classmethod
@client_context.require_no_mmap
def _setup_class(cls):
super()._setup_class()
def setUp(self) -> None:
super().setUp()
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
cls.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(
retryWrites=True, event_listeners=[cls.listener]
)
cls.db = cls.client.pymongo_test
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.listener = OvertCommandListener()
self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener])
self.db = self.client.pymongo_test
@classmethod
def _tearDown_class(cls):
cls.knobs.disable()
cls.client.close()
super()._tearDown_class()
def setUp(self):
if client_context.is_rs and client_context.test_commands_enabled:
self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
@ -210,6 +193,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
self.knobs.disable()
def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()
@ -438,13 +422,12 @@ class TestWriteConcernError(IntegrationTest):
RUN_ON_SERVERLESS = True
fail_insert: dict
@classmethod
@client_context.require_replica_set
@client_context.require_no_mmap
@client_context.require_failCommand_fail_point
def _setup_class(cls):
super()._setup_class()
cls.fail_insert = {
def setUp(self) -> None:
super().setUp()
self.fail_insert = {
"configureFailPoint": "failCommand",
"mode": {"times": 2},
"data": {

View File

@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest):
@classmethod
@client_context.require_failCommand_fail_point
def setUpClass(cls):
super().setUpClass()
super().setUp(cls)
# Speed up the tests by decreasing the event publish frequency.
cls.knobs = client_knobs(
events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1

View File

@ -82,36 +82,27 @@ class TestSession(IntegrationTest):
client2: MongoClient
sensitive_commands: Set[str]
@classmethod
@client_context.require_sessions
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
# Create a second client so we can make sure clients cannot share
# sessions.
cls.client2 = cls.unmanaged_rs_or_single_client()
self.client2 = self.rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid".
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
monitoring._SENSITIVE_COMMANDS.clear()
@classmethod
def _tearDown_class(cls):
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
cls.client2.close()
super()._tearDown_class()
def setUp(self):
self.listener = SessionTestListener()
self.session_checker_listener = SessionTestListener()
self.client = self.rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener]
)
self.addCleanup(self.client.close)
self.db = self.client.pymongo_test
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
def tearDown(self):
"""All sessions used in the test must be returned to the pool."""
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
self.client.drop_database("pymongo_test")
used_lsids = self.initial_lsids.copy()
for event in self.session_checker_listener.started_events:
@ -121,6 +112,8 @@ class TestSession(IntegrationTest):
current_lsids = {s["id"] for s in session_ids(self.client)}
self.assertLessEqual(used_lsids, current_lsids)
super().tearDown()
def _test_ops(self, client, *ops):
listener = client.options.event_listeners[0]
@ -832,18 +825,11 @@ class TestCausalConsistency(UnitTest):
listener: SessionTestListener
client: MongoClient
@classmethod
def _setup_class(cls):
cls.listener = SessionTestListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def _tearDown_class(cls):
cls.client.close()
@client_context.require_sessions
def setUp(self):
super().setUp()
self.listener = SessionTestListener()
self.client = self.rs_or_single_client(event_listeners=[self.listener])
@client_context.require_no_standalone
def test_core(self):

View File

@ -105,6 +105,7 @@ class Update(threading.Thread):
class TestThreads(IntegrationTest):
def setUp(self):
super().setUp()
self.db = self.client.pymongo_test
def test_threading(self):

View File

@ -395,19 +395,12 @@ class PatchSessionTimeout:
class TestTransactionsConvenientAPI(TransactionsBase):
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.mongos_clients = []
def setUp(self) -> None:
super().setUp()
self.mongos_clients = []
if client_context.supports_transactions():
for address in client_context.mongoses:
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
@classmethod
def _tearDown_class(cls):
for client in cls.mongos_clients:
client.close()
super()._tearDown_class()
self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
def _set_fail_point(self, client, command_args):
cmd = {"configureFailPoint": "failCommand"}

View File

@ -114,10 +114,9 @@ class TestMypyFails(unittest.TestCase):
class TestPymongo(IntegrationTest):
coll: Collection
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.coll = cls.client.test.test
def setUp(self):
super().setUp()
self.coll = self.client.test.test
def test_insert_find(self) -> None:
doc = {"my": "doc"}

View File

@ -304,7 +304,6 @@ class EntityMapUtil:
kwargs["h"] = uri
client = self.test.rs_or_single_client(**kwargs)
self[spec["id"]] = client
self.test.addCleanup(client.close)
return
elif entity_type == "database":
client = self[spec["client"]]
@ -479,31 +478,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
db.create_collection(coll_name, write_concern=wc, **opts)
@classmethod
def _setup_class(cls):
# super call creates internal client cls.client
super()._setup_class()
# process file-level runOnRequirements
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
if not cls.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if client_context.storage_engine == "mmapv1":
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
cls.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
cls.mongos_clients = []
if (
client_context.supports_transactions()
and not client_context.load_balancer
and not client_context.serverless
):
for address in client_context.mongoses:
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
def setUpClass(cls) -> None:
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(
heartbeat_frequency=0.1,
@ -514,17 +489,36 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
cls.knobs.enable()
@classmethod
def _tearDown_class(cls):
def tearDownClass(cls) -> None:
cls.knobs.disable()
for client in cls.mongos_clients:
client.close()
super()._tearDown_class()
def setUp(self):
# super call creates internal client cls.client
super().setUp()
# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not self.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if client_context.storage_engine == "mmapv1":
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
self.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
self.mongos_clients = []
if (
client_context.supports_transactions()
and not client_context.load_balancer
and not client_context.serverless
):
for address in client_context.mongoses:
self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
# process schemaVersion
# note: we check major schema version during class generation
# note: we do this here because we cannot run assertions in setUpClass
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
self.assertLessEqual(
version,
@ -1026,7 +1020,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
)
client = self.single_client("{}:{}".format(*session._pinned_address))
self.addCleanup(client.close)
self.__set_fail_point(client=client, command_args=spec["failPoint"])
def _testOperation_createEntities(self, spec):

View File

@ -99,6 +99,12 @@ class BaseListener:
"""Wait for a number of events to be published, or fail."""
wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)")
async def async_wait_for_event(self, event, count):
"""Wait for a number of events to be published, or fail."""
await async_wait_until(
lambda: self.event_count(event) >= count, f"find {count} {event} event(s)"
)
class CMAPListener(BaseListener, monitoring.ConnectionPoolListener):
def connection_created(self, event):
@ -644,7 +650,10 @@ async def async_wait_until(predicate, success_description, timeout=10):
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
if iscoroutinefunction(predicate):
retval = await predicate()
else:
retval = predicate()
if retval:
return retval

View File

@ -249,30 +249,22 @@ class SpecRunner(IntegrationTest):
knobs: client_knobs
listener: EventListener
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.mongos_clients = []
def setUp(self) -> None:
super().setUp()
self.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable()
@classmethod
def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
client.close()
super()._tearDown_class()
def setUp(self):
super().setUp()
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
self.knobs.enable()
self.targets = {}
self.listener = None # type: ignore
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
def tearDown(self) -> None:
self.knobs.disable()
def _set_fail_point(self, client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
@ -697,8 +689,6 @@ class SpecRunner(IntegrationTest):
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addCleanup(client.close)
# Create session0 and session1.
sessions = {}

View File

@ -110,6 +110,13 @@ replacements = {
"async_set_fail_point": "set_fail_point",
"async_ensure_all_connected": "ensure_all_connected",
"async_repl_set_step_down": "repl_set_step_down",
"AsyncPeriodicExecutor": "PeriodicExecutor",
"async_wait_for_event": "wait_for_event",
"pymongo_server_monitor_task": "pymongo_server_monitor_thread",
"pymongo_server_rtt_task": "pymongo_server_rtt_thread",
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
}
docstring_replacements: dict[tuple[str, str], str] = {
@ -130,8 +137,6 @@ docstring_removals: set[str] = {
".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release."
}
type_replacements = {"_Condition": "threading.Condition"}
import_replacements = {"test.synchronous": "test"}
_pymongo_base = "./pymongo/asynchronous/"
@ -234,8 +239,6 @@ def process_files(files: list[str]) -> None:
lines = translate_async_sleeps(lines)
if file in docstring_translate_files:
lines = translate_docstrings(lines)
translate_locks(lines)
translate_types(lines)
if file in sync_test_files:
translate_imports(lines)
f.seek(0)
@ -269,34 +272,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]:
return lines
def translate_locks(lines: list[str]) -> list[str]:
lock_lines = [line for line in lines if "_Lock(" in line]
cond_lines = [line for line in lines if "_Condition(" in line]
for line in lock_lines:
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
if res:
old = res[0]
index = lines.index(line)
lines[index] = line.replace(old, res[1])
for line in cond_lines:
res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line)
if res:
old = res[0]
index = lines.index(line)
lines[index] = line.replace(old, res[1])
return lines
def translate_types(lines: list[str]) -> list[str]:
for k, v in type_replacements.items():
matches = [line for line in lines if k in line and "import" not in line]
for line in matches:
index = lines.index(line)
lines[index] = line.replace(k, v)
return lines
def translate_imports(lines: list[str]) -> list[str]:
for k, v in import_replacements.items():
matches = [line for line in lines if k in line and "import" in line]