Async client uses tasks instead of threads

PYTHON-4725 - Async client should use tasks for SDAM instead of threads
PYTHON-4860 - Async client should use asyncio.Lock and asyncio.Condition
PYTHON-4941 - Synchronous unified test runner being used in asynchronous tests
PYTHON-4843 - Async test suite should use a single event loop
PYTHON-4945 - Fix test cleanups for mongoses

Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com>
This commit is contained in:
Noah Stapp 2024-11-26 16:55:27 -05:00 committed by GitHub
parent 9b5c0981d9
commit 0e8d70457f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
72 changed files with 1737 additions and 1981 deletions

View File

@ -281,7 +281,7 @@ functions:
"run tests": "run tests":
- command: subprocess.exec - command: subprocess.exec
params: params:
include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"] include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
binary: bash binary: bash
working_dir: "src" working_dir: "src"
args: args:

View File

@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
2) License Notice for _asyncio_lock.py
-----------------------------------------
1. This LICENSE AGREEMENT is between the Python Software Foundation
("PSF"), and the Individual or Organization ("Licensee") accessing and
otherwise using this software ("Python") in source or binary form and
its associated documentation.
2. Subject to the terms and conditions of this License Agreement, PSF hereby
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
analyze, test, perform and/or display publicly, prepare derivative works,
distribute, and otherwise use Python alone or in any derivative version,
provided, however, that PSF's License Agreement and PSF's notice of copyright,
i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved"
are retained in Python alone or in any derivative version prepared by Licensee.
3. In the event Licensee prepares a derivative work that is based on
or incorporates Python or any part thereof, and wants to make
the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to Python.
4. PSF is making Python available to Licensee on an "AS IS"
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between PSF and
Licensee. This License Agreement does not grant permission to use PSF
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.
8. By copying, installing or otherwise using Python, Licensee
agrees to be bound by the terms and conditions of this License
Agreement.
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

309
pymongo/_asyncio_lock.py Normal file
View File

@ -0,0 +1,309 @@
# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
to port 3.13 fixes to older versions of Python.
Can be removed once we drop Python 3.12 support."""
from __future__ import annotations
import collections
import threading
from asyncio import events, exceptions
from typing import Any, Coroutine, Optional
_global_lock = threading.Lock()
class _LoopBoundMixin:
_loop = None
def _get_loop(self) -> Any:
loop = events._get_running_loop()
if self._loop is None:
with _global_lock:
if self._loop is None:
self._loop = loop
if loop is not self._loop:
raise RuntimeError(f"{self!r} is bound to a different event loop")
return loop
class _ContextManagerMixin:
async def __aenter__(self) -> None:
await self.acquire() # type: ignore[attr-defined]
# We have no use for the "as ..." clause in the with
# statement for locks.
return
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release() # type: ignore[attr-defined]
class Lock(_ContextManagerMixin, _LoopBoundMixin):
"""Primitive lock objects.
A primitive lock is a synchronization primitive that is not owned
by a particular task when locked. A primitive lock is in one
of two states, 'locked' or 'unlocked'.
It is created in the unlocked state. It has two basic methods,
acquire() and release(). When the state is unlocked, acquire()
changes the state to locked and returns immediately. When the
state is locked, acquire() blocks until a call to release() in
another task changes it to unlocked, then the acquire() call
resets it to locked and returns. The release() method should only
be called in the locked state; it changes the state to unlocked
and returns immediately. If an attempt is made to release an
unlocked lock, a RuntimeError will be raised.
When more than one task is blocked in acquire() waiting for
the state to turn to unlocked, only one task proceeds when a
release() call resets the state to unlocked; successive release()
calls will unblock tasks in FIFO order.
Locks also support the asynchronous context management protocol.
'async with lock' statement should be used.
Usage:
lock = Lock()
...
await lock.acquire()
try:
...
finally:
lock.release()
Context manager usage:
lock = Lock()
...
async with lock:
...
Lock objects can be tested for locking state:
if not lock.locked():
await lock.acquire()
else:
# lock is acquired
...
"""
def __init__(self) -> None:
self._waiters: Optional[collections.deque] = None
self._locked = False
def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self._locked else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"
def locked(self) -> bool:
"""Return True if lock is acquired."""
return self._locked
async def acquire(self) -> bool:
"""Acquire a lock.
This method blocks until the lock is unlocked, then sets it to
locked and returns True.
"""
# Implement fair scheduling, where thread always waits
# its turn. Jumping the queue if all are cancelled is an optimization.
if not self._locked and (
self._waiters is None or all(w.cancelled() for w in self._waiters)
):
self._locked = True
return True
if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)
try:
try:
await fut
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
# Currently the only exception designed be able to occur here.
# Ensure the lock invariant: If lock is not claimed (or about
# to be claimed by us) and there is a Task in waiters,
# ensure that the Task at the head will run.
if not self._locked:
self._wake_up_first()
raise
# assert self._locked is False
self._locked = True
return True
def release(self) -> None:
"""Release a lock.
When the lock is locked, reset it to unlocked, and return.
If any other tasks are blocked waiting for the lock to become
unlocked, allow exactly one of them to proceed.
When invoked on an unlocked lock, a RuntimeError is raised.
There is no return value.
"""
if self._locked:
self._locked = False
self._wake_up_first()
else:
raise RuntimeError("Lock is not acquired.")
def _wake_up_first(self) -> None:
"""Ensure that the first waiter will wake up."""
if not self._waiters:
return
try:
fut = next(iter(self._waiters))
except StopIteration:
return
# .done() means that the waiter is already set to wake up.
if not fut.done():
fut.set_result(True)
class Condition(_ContextManagerMixin, _LoopBoundMixin):
"""Asynchronous equivalent to threading.Condition.
This class implements condition variable objects. A condition variable
allows one or more tasks to wait until they are notified by another
task.
A new Lock object is created and used as the underlying lock.
"""
def __init__(self, lock: Optional[Lock] = None) -> None:
if lock is None:
lock = Lock()
self._lock = lock
# Export the lock's locked(), acquire() and release() methods.
self.locked = lock.locked
self.acquire = lock.acquire
self.release = lock.release
self._waiters: collections.deque = collections.deque()
def __repr__(self) -> str:
res = super().__repr__()
extra = "locked" if self.locked() else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"
async def wait(self) -> bool:
"""Wait until notified.
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
if not self.locked():
raise RuntimeError("cannot wait on un-acquired lock")
fut = self._get_loop().create_future()
self.release()
try:
try:
self._waiters.append(fut)
try:
await fut
return True
finally:
self._waiters.remove(fut)
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except exceptions.CancelledError as e:
err = e
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self._notify(1)
raise
async def wait_for(self, predicate: Any) -> Coroutine:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
"""By default, wake up one task waiting on this condition, if any.
If the calling task has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up n of the tasks waiting for the condition
variable; if fewer than n are waiting, they are all awoken.
Note: an awakened task does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
if not self.locked():
raise RuntimeError("cannot notify on un-acquired lock")
self._notify(n)
def _notify(self, n: int) -> None:
idx = 0
for fut in self._waiters:
if idx >= n:
break
if not fut.done():
idx += 1
fut.set_result(False)
def notify_all(self) -> None:
"""Wake up all tasks waiting on this condition. This method acts
like notify(), but wakes up all waiting tasks instead of one. If the
calling task has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))

49
pymongo/_asyncio_task.py Normal file
View File

@ -0,0 +1,49 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
from __future__ import annotations
import asyncio
import sys
from typing import Any, Coroutine, Optional
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
class _Task(asyncio.Task):
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
super().__init__(coro, name=name)
self._cancel_requests = 0
asyncio._register_task(self)
def cancel(self, msg: Optional[str] = None) -> bool:
self._cancel_requests += 1
return super().cancel(msg=msg)
def uncancel(self) -> int:
if self._cancel_requests > 0:
self._cancel_requests -= 1
return self._cancel_requests
def cancelling(self) -> int:
return self._cancel_requests
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
if sys.version_info >= (3, 11):
return asyncio.create_task(coro, name=name)
return _Task(coro, name=name)

View File

@ -476,7 +476,6 @@ class _AsyncClientBulk:
if op_type == "delete": if op_type == "delete":
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
full_result[f"{op_type}Results"][original_index] = res full_result[f"{op_type}Results"][original_index] = res
except Exception as exc: except Exception as exc:
# Attempt to close the cursor, then raise top-level error. # Attempt to close the cursor, then raise top-level error.
if cmd_cursor.alive: if cmd_cursor.alive:

View File

@ -45,7 +45,7 @@ from pymongo.common import (
) )
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _ALock, _create_lock from pymongo.lock import _async_create_lock
from pymongo.message import ( from pymongo.message import (
_CursorAddress, _CursorAddress,
_GetMore, _GetMore,
@ -77,7 +77,7 @@ class _ConnectionManager:
def __init__(self, conn: AsyncConnection, more_to_come: bool): def __init__(self, conn: AsyncConnection, more_to_come: bool):
self.conn: Optional[AsyncConnection] = conn self.conn: Optional[AsyncConnection] = conn
self.more_to_come = more_to_come self.more_to_come = more_to_come
self._alock = _ALock(_create_lock()) self._lock = _async_create_lock()
def update_exhaust(self, more_to_come: bool) -> None: def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come self.more_to_come = more_to_come

View File

@ -15,6 +15,7 @@
"""Support for explicit client-side field level encryption.""" """Support for explicit client-side field level encryption."""
from __future__ import annotations from __future__ import annotations
import asyncio
import contextlib import contextlib
import enum import enum
import socket import socket
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so # BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged. # we should propagate them unchanged.
raise raise
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
raise EncryptionError(exc) from exc raise EncryptionError(exc) from exc
@ -200,6 +203,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
conn.close() conn.close()
except (PyMongoError, MongoCryptError): except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly. raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error: except Exception as error:
# Wrap I/O errors in PyMongo exceptions. # Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error) _raise_connection_failure((host, port), error)
@ -722,6 +727,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
await database.create_collection(name=name, **kwargs), await database.create_collection(name=name, **kwargs),
encrypted_fields, encrypted_fields,
) )
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -32,6 +32,7 @@ access:
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import contextlib import contextlib
import os import os
import warnings import warnings
@ -59,8 +60,8 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, uri_parser from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo.asynchronous import client_session, database, periodic_executor from pymongo.asynchronous import client_session, database
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.client_session import _EmptyServerSession
@ -82,7 +83,11 @@ from pymongo.errors import (
WaitQueueTimeoutError, WaitQueueTimeoutError,
WriteConcernError, WriteConcernError,
) )
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks from pymongo.lock import (
_HAS_REGISTER_AT_FORK,
_async_create_lock,
_release_locks,
)
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason from pymongo.monitoring import ConnectionClosedReason
@ -842,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
self._default_database_name = dbase self._default_database_name = dbase
self._lock = _ALock(_create_lock()) self._lock = _async_create_lock()
self._kill_cursors_queue: list = [] self._kill_cursors_queue: list = []
self._event_listeners = options.pool_options._event_listeners self._event_listeners = options.pool_options._event_listeners
@ -908,7 +913,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
await AsyncMongoClient._process_periodic_tasks(client) await AsyncMongoClient._process_periodic_tasks(client)
return True return True
executor = periodic_executor.PeriodicExecutor( executor = periodic_executor.AsyncPeriodicExecutor(
interval=common.KILL_CURSOR_FREQUENCY, interval=common.KILL_CURSOR_FREQUENCY,
min_interval=common.MIN_HEARTBEAT_INTERVAL, min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target, target=target,
@ -1722,7 +1727,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
address=address, address=address,
) )
async with operation.conn_mgr._alock: async with operation.conn_mgr._lock:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn) err_handler.contribute_socket(operation.conn_mgr.conn)
return await server.run_operation( return await server.run_operation(
@ -1970,7 +1975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
try: try:
if conn_mgr: if conn_mgr:
async with conn_mgr._alock: async with conn_mgr._lock:
# Cursor is pinned to LB outside of a transaction. # Cursor is pinned to LB outside of a transaction.
assert address is not None assert address is not None
assert conn_mgr.conn is not None assert conn_mgr.conn is not None
@ -2033,6 +2038,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors: for address, cursor_id, conn_mgr in pinned_cursors:
try: try:
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed: if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it # Raise the exception when client is closed so that it
@ -2047,6 +2054,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items(): for address, cursor_ids in address_to_cursor_ids.items():
try: try:
await self._kill_cursors(cursor_ids, address, topology, session=None) await self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed: if isinstance(exc, InvalidOperation) and self._topology._closed:
raise raise
@ -2061,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
try: try:
await self._process_kill_cursors() await self._process_kill_cursors()
await self._topology.update_pool() await self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed: if isinstance(exc, InvalidOperation) and self._topology._closed:
return return

View File

@ -16,20 +16,20 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import atexit import atexit
import logging import logging
import time import time
import weakref import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum from pymongo._csot import MovingMinimum
from pymongo.asynchronous import periodic_executor
from pymongo.asynchronous.periodic_executor import _shutdown_executors
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello from pymongo.hello import Hello
from pymongo.lock import _create_lock from pymongo.lock import _async_create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription from pymongo.server_description import ServerDescription
@ -76,7 +76,7 @@ class MonitorBase:
await monitor._run() # type:ignore[attr-defined] await monitor._run() # type:ignore[attr-defined]
return True return True
executor = periodic_executor.PeriodicExecutor( executor = periodic_executor.AsyncPeriodicExecutor(
interval=interval, min_interval=min_interval, target=target, name=name interval=interval, min_interval=min_interval, target=target, name=name
) )
@ -112,9 +112,9 @@ class MonitorBase:
""" """
self.gc_safe_close() self.gc_safe_close()
def join(self, timeout: Optional[int] = None) -> None: async def join(self, timeout: Optional[int] = None) -> None:
"""Wait for the monitor to stop.""" """Wait for the monitor to stop."""
self._executor.join(timeout) await self._executor.join(timeout)
def request_check(self) -> None: def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon.""" """If the monitor is sleeping, wake it soon."""
@ -139,7 +139,7 @@ class Monitor(MonitorBase):
""" """
super().__init__( super().__init__(
topology, topology,
"pymongo_server_monitor_thread", "pymongo_server_monitor_task",
topology_settings.heartbeat_frequency, topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL, common.MIN_HEARTBEAT_INTERVAL,
) )
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
except ReferenceError: except ReferenceError:
# Topology was garbage-collected. # Topology was garbage-collected.
await self.close() await self.close()
finally:
if self._executor._stopped:
await self._rtt_monitor.close()
async def _check_server(self) -> ServerDescription: async def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response. """Call hello or read the next streaming response.
@ -252,8 +255,10 @@ class Monitor(MonitorBase):
except (OperationFailure, NotPrimaryError) as exc: except (OperationFailure, NotPrimaryError) as exc:
# Update max cluster time even when hello fails. # Update max cluster time even when hello fails.
details = cast(Mapping[str, Any], exc.details) details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime")) await self._topology.receive_cluster_time(details.get("$clusterTime"))
raise raise
except asyncio.CancelledError:
raise
except ReferenceError: except ReferenceError:
raise raise
except Exception as error: except Exception as error:
@ -280,7 +285,7 @@ class Monitor(MonitorBase):
await self._reset_connection() await self._reset_connection()
if isinstance(error, _OperationCancelled): if isinstance(error, _OperationCancelled):
raise raise
self._rtt_monitor.reset() await self._rtt_monitor.reset()
# Server type defaults to Unknown. # Server type defaults to Unknown.
return ServerDescription(address, error=error) return ServerDescription(address, error=error)
@ -321,9 +326,9 @@ class Monitor(MonitorBase):
self._conn_id = conn.id self._conn_id = conn.id
response, round_trip_time = await self._check_with_socket(conn) response, round_trip_time = await self._check_with_socket(conn)
if not response.awaitable: if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time) await self._rtt_monitor.add_sample(round_trip_time)
avg_rtt, min_rtt = self._rtt_monitor.get() avg_rtt, min_rtt = await self._rtt_monitor.get()
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt) sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
if self._publish: if self._publish:
assert self._listeners is not None assert self._listeners is not None
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0: if len(seedlist) == 0:
# As per the spec: this should be treated as a failure. # As per the spec: this should be treated as a failure.
raise Exception raise Exception
except asyncio.CancelledError:
raise
except Exception: except Exception:
# As per the spec, upon encountering an error: # As per the spec, upon encountering an error:
# - An error must not be raised # - An error must not be raised
@ -439,7 +446,7 @@ class _RttMonitor(MonitorBase):
""" """
super().__init__( super().__init__(
topology, topology,
"pymongo_server_rtt_thread", "pymongo_server_rtt_task",
topology_settings.heartbeat_frequency, topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL, common.MIN_HEARTBEAT_INTERVAL,
) )
@ -447,7 +454,7 @@ class _RttMonitor(MonitorBase):
self._pool = pool self._pool = pool
self._moving_average = MovingAverage() self._moving_average = MovingAverage()
self._moving_min = MovingMinimum() self._moving_min = MovingMinimum()
self._lock = _create_lock() self._lock = _async_create_lock()
async def close(self) -> None: async def close(self) -> None:
self.gc_safe_close() self.gc_safe_close()
@ -455,20 +462,20 @@ class _RttMonitor(MonitorBase):
# thread has the socket checked out, it will be closed when checked in. # thread has the socket checked out, it will be closed when checked in.
await self._pool.reset() await self._pool.reset()
def add_sample(self, sample: float) -> None: async def add_sample(self, sample: float) -> None:
"""Add a RTT sample.""" """Add a RTT sample."""
with self._lock: async with self._lock:
self._moving_average.add_sample(sample) self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample) self._moving_min.add_sample(sample)
def get(self) -> tuple[Optional[float], float]: async def get(self) -> tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min.""" """Get the calculated average, or None if no samples yet and the min."""
with self._lock: async with self._lock:
return self._moving_average.get(), self._moving_min.get() return self._moving_average.get(), self._moving_min.get()
def reset(self) -> None: async def reset(self) -> None:
"""Reset the average RTT.""" """Reset the average RTT."""
with self._lock: async with self._lock:
self._moving_average.reset() self._moving_average.reset()
self._moving_min.reset() self._moving_min.reset()
@ -478,10 +485,12 @@ class _RttMonitor(MonitorBase):
# heartbeat protocol (MongoDB 4.4+). # heartbeat protocol (MongoDB 4.4+).
# XXX: Skip check if the server is unknown? # XXX: Skip check if the server is unknown?
rtt = await self._ping() rtt = await self._ping()
self.add_sample(rtt) await self.add_sample(rtt)
except ReferenceError: except ReferenceError:
# Topology was garbage-collected. # Topology was garbage-collected.
await self.close() await self.close()
except asyncio.CancelledError:
raise
except Exception: except Exception:
await self._pool.reset() await self._pool.reset()
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
shutdown() shutdown()
atexit.register(_shutdown_resources) if _IS_SYNC:
atexit.register(_shutdown_resources)

View File

@ -1,219 +0,0 @@
# Copyright 2014-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Run a target function on a background thread."""
from __future__ import annotations
import asyncio
import sys
import threading
import time
import weakref
from typing import Any, Optional
from pymongo.lock import _ALock, _create_lock
_IS_SYNC = False
class PeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background thread.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying thread.
"""
# threading.Event and its internal condition variable are expensive
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
# The executor's design is constrained by several Python issues, see
# "periodic_executor.rst" in this repository.
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._thread: Optional[threading.Thread] = None
self._name = name
self._skip_sleep = False
self._thread_will_exit = False
self._lock = _ALock(_create_lock())
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
# The default asyncio loop implementation on Windows
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
# We explicitly use a different loop implementation here to prevent that issue
if sys.platform == "win32":
loop = asyncio.SelectorEventLoop()
try:
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
finally:
loop.close()
else:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None:
"""Start. Multiple calls have no effect.
Not safe to call from multiple threads at once.
"""
with self._lock:
if self._thread_will_exit:
# If the background thread has read self._stopped as True
# there is a chance that it has not yet exited. The call to
# join should not block indefinitely because there is no
# other work done outside the while loop in self._run.
try:
assert self._thread is not None
self._thread.join()
except ReferenceError:
# Thread terminated.
pass
self._thread_will_exit = False
self._stopped = False
started: Any = False
try:
started = self._thread and self._thread.is_alive()
except ReferenceError:
# Thread terminated.
pass
if not started:
if _IS_SYNC:
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread.daemon = True
self._thread = weakref.proxy(thread)
_register_executor(self)
# Mitigation to RuntimeError firing when thread starts on shutdown
# https://github.com/python/cpython/issues/114570
try:
thread.start()
except RuntimeError as e:
if "interpreter shutdown" in str(e) or sys.is_finalizing():
self._thread = None
return
raise
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
def join(self, timeout: Optional[int] = None) -> None:
if self._thread is not None:
try:
self._thread.join(timeout)
except (ReferenceError, RuntimeError):
# Thread already terminated, or not yet started.
pass
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _should_stop(self) -> bool:
async with self._lock:
if self._stopped:
self._thread_will_exit = True
return True
return False
async def _run(self) -> None:
while not await self._should_stop():
try:
if not await self._target():
self._stopped = True
break
except BaseException:
async with self._lock:
self._stopped = True
self._thread_will_exit = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
# an executor is kept alive by a strong reference from its thread and perhaps
# from other objects. When the thread dies and all other referrers are freed,
# the executor is freed and removed from _EXECUTORS. If any threads are
# running when the interpreter begins to shut down, we try to halt and join
# them to avoid spurious errors.
_EXECUTORS = set()
def _register_executor(executor: PeriodicExecutor) -> None:
ref = weakref.ref(executor, _on_executor_deleted)
_EXECUTORS.add(ref)
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
_EXECUTORS.remove(ref)
def _shutdown_executors() -> None:
if _EXECUTORS is None:
return
# Copy the set. Stopping threads has the side effect of removing executors.
executors = list(_EXECUTORS)
# First signal all executors to close...
for ref in executors:
executor = ref()
if executor:
executor.close()
# ...then try to join them.
for ref in executors:
executor = ref()
if executor:
executor.join(1)
executor = None

View File

@ -23,7 +23,6 @@ import os
import socket import socket
import ssl import ssl
import sys import sys
import threading
import time import time
import weakref import weakref
from typing import ( from typing import (
@ -65,7 +64,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError, _CertificateError,
) )
from pymongo.hello import Hello, HelloCompat from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _ACondition, _ALock, _create_lock from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import ( from pymongo.logger import (
_CONNECTION_LOGGER, _CONNECTION_LOGGER,
_ConnectionStatusMessage, _ConnectionStatusMessage,
@ -208,11 +211,6 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error raise AutoReconnect(msg) from error
async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return await condition.wait(timeout)
def _get_timeout_details(options: PoolOptions) -> dict[str, float]: def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {} details = {}
timeout = _csot.get_timeout() timeout = _csot.get_timeout()
@ -706,6 +704,8 @@ class AsyncConnection:
# shutdown. # shutdown.
try: try:
self.conn.close() self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110 except Exception: # noqa: S110
pass pass
@ -992,8 +992,8 @@ class Pool:
# from the right side. # from the right side.
self.conns: collections.deque = collections.deque() self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set() self.active_contexts: set[_CancellationContext] = set()
_lock = _create_lock() self.lock = _async_create_lock()
self.lock = _ALock(_lock) self._max_connecting_cond = _async_create_condition(self.lock)
self.active_sockets = 0 self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events. # Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1 self.next_connection_id = 1
@ -1019,7 +1019,7 @@ class Pool:
# The first portion of the wait queue. # The first portion of the wait queue.
# Enforces: maxPoolSize # Enforces: maxPoolSize
# Also used for: clearing the wait queue # Also used for: clearing the wait queue
self.size_cond = _ACondition(threading.Condition(_lock)) self.size_cond = _async_create_condition(self.lock)
self.requests = 0 self.requests = 0
self.max_pool_size = self.opts.max_pool_size self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size: if not self.max_pool_size:
@ -1027,7 +1027,7 @@ class Pool:
# The second portion of the wait queue. # The second portion of the wait queue.
# Enforces: maxConnecting # Enforces: maxConnecting
# Also used for: clearing the wait queue # Also used for: clearing the wait queue
self._max_connecting_cond = _ACondition(threading.Condition(_lock)) self._max_connecting_cond = _async_create_condition(self.lock)
self._max_connecting = self.opts.max_connecting self._max_connecting = self.opts.max_connecting
self._pending = 0 self._pending = 0
self._client_id = client_id self._client_id = client_id
@ -1466,7 +1466,8 @@ class Pool:
async with self.size_cond: async with self.size_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=True) self._raise_if_not_ready(checkout_started_time, emit_event=True)
while not (self.requests < self.max_pool_size): while not (self.requests < self.max_pool_size):
if not await _cond_wait(self.size_cond, deadline): timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self.size_cond, timeout):
# Timed out, notify the next thread to ensure a # Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition. # timeout doesn't consume the condition.
if self.requests < self.max_pool_size: if self.requests < self.max_pool_size:
@ -1489,7 +1490,8 @@ class Pool:
async with self._max_connecting_cond: async with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False) self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self._max_connecting): while not (self.conns or self._pending < self._max_connecting):
if not await _cond_wait(self._max_connecting_cond, deadline): timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a # Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition. # timeout doesn't consume the condition.
if self.conns or self._pending < self._max_connecting: if self.conns or self._pending < self._max_connecting:

View File

@ -27,8 +27,7 @@ import weakref
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous import periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.monitor import SrvMonitor from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.pool import Pool from pymongo.asynchronous.pool import Pool
@ -44,7 +43,11 @@ from pymongo.errors import (
WriteError, WriteError,
) )
from pymongo.hello import Hello from pymongo.hello import Hello
from pymongo.lock import _ACondition, _ALock, _create_lock from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import ( from pymongo.logger import (
_SDAM_LOGGER, _SDAM_LOGGER,
_SERVER_SELECTION_LOGGER, _SERVER_SELECTION_LOGGER,
@ -170,9 +173,10 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions()) self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False self._opened = False
self._closed = False self._closed = False
_lock = _create_lock() self._lock = _async_create_lock()
self._lock = _ALock(_lock) self._condition = _async_create_condition(
self._condition = _ACondition(self._settings.condition_class(_lock)) self._lock, self._settings.condition_class if _IS_SYNC else None
)
self._servers: dict[_Address, Server] = {} self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None self._max_cluster_time: Optional[ClusterTime] = None
@ -185,7 +189,7 @@ class Topology:
async def target() -> bool: async def target() -> bool:
return process_events_queue(weak) return process_events_queue(weak)
executor = periodic_executor.PeriodicExecutor( executor = periodic_executor.AsyncPeriodicExecutor(
interval=common.EVENTS_QUEUE_FREQUENCY, interval=common.EVENTS_QUEUE_FREQUENCY,
min_interval=common.MIN_HEARTBEAT_INTERVAL, min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target, target=target,
@ -354,7 +358,7 @@ class Topology:
# change, or for a timeout. We won't miss any changes that # change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've # came after our most recent apply_selector call, since we've
# held the lock until now. # held the lock until now.
await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible() self._description.check_compatible()
now = time.monotonic() now = time.monotonic()
server_descriptions = self._description.apply_selector( server_descriptions = self._description.apply_selector(
@ -654,7 +658,7 @@ class Topology:
"""Wake all monitors, wait for at least one to check its server.""" """Wake all monitors, wait for at least one to check its server."""
async with self._lock: async with self._lock:
self._request_check_all() self._request_check_all()
await self._condition.wait(wait_time) await _async_cond_wait(self._condition, wait_time)
def data_bearing_servers(self) -> list[ServerDescription]: def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers. """Return a list of all data-bearing servers.
@ -742,7 +746,7 @@ class Topology:
if self._publish_server or self._publish_tp: if self._publish_server or self._publish_tp:
# Make sure the events executor thread is fully closed before publishing the remaining events # Make sure the events executor thread is fully closed before publishing the remaining events
self.__events_executor.close() self.__events_executor.close()
self.__events_executor.join(1) await self.__events_executor.join(1)
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type] process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
@property @property

View File

@ -11,15 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Internal helpers for lock and condition coordination primitives."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import collections
import os import os
import sys
import threading import threading
import time
import weakref import weakref
from typing import Any, Callable, Optional, TypeVar from asyncio import wait_for
from typing import Any, Optional, TypeVar
import pymongo._asyncio_lock
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
@ -28,6 +33,15 @@ _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
_T = TypeVar("_T") _T = TypeVar("_T")
# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202)
# in older versions of Python
if sys.version_info >= (3, 13):
Lock = asyncio.Lock
Condition = asyncio.Condition
else:
Lock = pymongo._asyncio_lock.Lock
Condition = pymongo._asyncio_lock.Condition
def _create_lock() -> threading.Lock: def _create_lock() -> threading.Lock:
"""Represents a lock that is tracked upon instantiation using a WeakSet and """Represents a lock that is tracked upon instantiation using a WeakSet and
@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock:
return lock return lock
def _async_create_lock() -> Lock:
"""Represents an asyncio.Lock."""
return Lock()
def _create_condition(
lock: threading.Lock, condition_class: Optional[Any] = None
) -> threading.Condition:
"""Represents a threading.Condition."""
if condition_class:
return condition_class(lock)
return threading.Condition(lock)
def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition:
"""Represents an asyncio.Condition."""
if condition_class:
return condition_class(lock)
return Condition(lock)
def _release_locks() -> None: def _release_locks() -> None:
# Completed the fork, reset all the locks in the child. # Completed the fork, reset all the locks in the child.
for lock in _forkable_locks: for lock in _forkable_locks:
@ -46,202 +81,12 @@ def _release_locks() -> None:
lock.release() lock.release()
# Needed only for synchro.py compat. async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool:
def _Lock(lock: threading.Lock) -> threading.Lock: try:
return lock return await wait_for(condition.wait(), timeout)
except asyncio.TimeoutError:
return False
class _ALock: def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool:
__slots__ = ("_lock",) return condition.wait(timeout)
def __init__(self, lock: threading.Lock) -> None:
self._lock = lock
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
return self._lock.acquire(blocking=blocking, timeout=timeout)
async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._lock.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
def release(self) -> None:
self._lock.release()
async def __aenter__(self) -> _ALock:
await self.a_acquire()
return self
def __enter__(self) -> _ALock:
self._lock.acquire()
return self
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def _safe_set_result(fut: asyncio.Future) -> None:
# Ensure the future hasn't been cancelled before calling set_result.
if not fut.done():
fut.set_result(False)
class _ACondition:
__slots__ = ("_condition", "_waiters")
def __init__(self, condition: threading.Condition) -> None:
self._condition = condition
self._waiters: collections.deque = collections.deque()
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
tstart = time.monotonic()
while True:
acquired = self._condition.acquire(blocking=False)
if acquired:
return True
if timeout > 0 and (time.monotonic() - tstart) > timeout:
return False
if not blocking:
return False
await asyncio.sleep(0)
async def wait(self, timeout: Optional[float] = None) -> bool:
"""Wait until notified.
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._waiters.append((loop, fut))
self.release()
try:
try:
try:
await asyncio.wait_for(fut, timeout)
return True
except asyncio.TimeoutError:
return False # Return false on timeout for sync pool compat.
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except asyncio.exceptions.CancelledError as e:
err = e
self._waiters.remove((loop, fut))
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self.notify(1)
raise
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
"""By default, wake up one coroutine waiting on this condition, if any.
If the calling coroutine has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up at most n of the coroutines waiting for the
condition variable; it is a no-op if no coroutines are waiting.
Note: an awakened coroutine does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
idx = 0
to_remove = []
for loop, fut in self._waiters:
if idx >= n:
break
if fut.done():
continue
try:
loop.call_soon_threadsafe(_safe_set_result, fut)
except RuntimeError:
# Loop was closed, ignore.
to_remove.append((loop, fut))
continue
idx += 1
for waiter in to_remove:
self._waiters.remove(waiter)
def notify_all(self) -> None:
"""Wake up all threads waiting on this condition. This method acts
like notify(), but wakes up all waiting threads instead of one. If the
calling thread has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))
def locked(self) -> bool:
"""Only needed for tests in test_locks."""
return self._condition._lock.locked() # type: ignore[attr-defined]
def release(self) -> None:
self._condition.release()
async def __aenter__(self) -> _ACondition:
await self.acquire()
return self
def __enter__(self) -> _ACondition:
self._condition.acquire()
return self
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
self.release()

View File

@ -29,6 +29,7 @@ from typing import (
) )
from pymongo import _csot, ssl_support from pymongo import _csot, ssl_support
from pymongo._asyncio_task import create_task
from pymongo.errors import _OperationCancelled from pymongo.errors import _OperationCancelled
from pymongo.socket_checker import _errno_from_exception from pymongo.socket_checker import _errno_from_exception
@ -259,19 +260,20 @@ async def async_receive_data(
sock.settimeout(0.0) sock.settimeout(0.0)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
cancellation_task = asyncio.create_task(_poll_cancellation(conn)) cancellation_task = create_task(_poll_cancellation(conn))
try: try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
else: else:
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
tasks = [read_task, cancellation_task] tasks = [read_task, cancellation_task]
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
) )
for task in pending: for task in pending:
task.cancel() task.cancel()
await asyncio.wait(pending) if pending:
await asyncio.wait(pending)
if len(done) == 0: if len(done) == 0:
raise socket.timeout("timed out") raise socket.timeout("timed out")
if read_task in done: if read_task in done:

View File

@ -23,9 +23,102 @@ import time
import weakref import weakref
from typing import Any, Optional from typing import Any, Optional
from pymongo._asyncio_task import create_task
from pymongo.lock import _create_lock from pymongo.lock import _create_lock
_IS_SYNC = True _IS_SYNC = False
class AsyncPeriodicExecutor:
def __init__(
self,
interval: float,
min_interval: float,
target: Any,
name: Optional[str] = None,
):
"""Run a target function periodically on a background task.
If the target's return value is false, the executor stops.
:param interval: Seconds between calls to `target`.
:param min_interval: Minimum seconds between calls if `wake` is
called very often.
:param target: A function.
:param name: A name to give the underlying task.
"""
self._event = False
self._interval = interval
self._min_interval = min_interval
self._target = target
self._stopped = False
self._task: Optional[asyncio.Task] = None
self._name = name
self._skip_sleep = False
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def open(self) -> None:
"""Start. Multiple calls have no effect."""
self._stopped = False
if self._task is None or (
self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined]
):
self._task = create_task(self._run(), name=self._name)
def close(self, dummy: Any = None) -> None:
"""Stop. To restart, call open().
The dummy parameter allows an executor's close method to be a weakref
callback; see monitor.py.
"""
self._stopped = True
async def join(self, timeout: Optional[int] = None) -> None:
if self._task is not None:
try:
await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type]
except asyncio.TimeoutError:
# Task timed out
pass
except asyncio.exceptions.CancelledError:
# Task was already finished, or not yet started.
raise
def wake(self) -> None:
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval: int) -> None:
self._interval = new_interval
def skip_sleep(self) -> None:
self._skip_sleep = True
async def _run(self) -> None:
while not self._stopped:
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
raise asyncio.CancelledError
try:
if not await self._target():
self._stopped = True
break
except BaseException:
self._stopped = True
raise
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = time.monotonic() + self._interval
while not self._stopped and time.monotonic() < deadline:
await asyncio.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
class PeriodicExecutor: class PeriodicExecutor:
@ -64,19 +157,6 @@ class PeriodicExecutor:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
def _run_async(self) -> None:
# The default asyncio loop implementation on Windows
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
# We explicitly use a different loop implementation here to prevent that issue
if sys.platform == "win32":
loop = asyncio.SelectorEventLoop()
try:
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
finally:
loop.close()
else:
asyncio.run(self._run()) # type: ignore[func-returns-value]
def open(self) -> None: def open(self) -> None:
"""Start. Multiple calls have no effect. """Start. Multiple calls have no effect.
@ -104,10 +184,7 @@ class PeriodicExecutor:
pass pass
if not started: if not started:
if _IS_SYNC: thread = threading.Thread(target=self._run, name=self._name)
thread = threading.Thread(target=self._run, name=self._name)
else:
thread = threading.Thread(target=self._run_async, name=self._name)
thread.daemon = True thread.daemon = True
self._thread = weakref.proxy(thread) self._thread = weakref.proxy(thread)
_register_executor(self) _register_executor(self)

View File

@ -474,7 +474,6 @@ class _ClientBulk:
if op_type == "delete": if op_type == "delete":
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
full_result[f"{op_type}Results"][original_index] = res full_result[f"{op_type}Results"][original_index] = res
except Exception as exc: except Exception as exc:
# Attempt to close the cursor, then raise top-level error. # Attempt to close the cursor, then raise top-level error.
if cmd_cursor.alive: if cmd_cursor.alive:

View File

@ -77,7 +77,7 @@ class _ConnectionManager:
def __init__(self, conn: Connection, more_to_come: bool): def __init__(self, conn: Connection, more_to_come: bool):
self.conn: Optional[Connection] = conn self.conn: Optional[Connection] = conn
self.more_to_come = more_to_come self.more_to_come = more_to_come
self._alock = _create_lock() self._lock = _create_lock()
def update_exhaust(self, more_to_come: bool) -> None: def update_exhaust(self, more_to_come: bool) -> None:
self.more_to_come = more_to_come self.more_to_come = more_to_come

View File

@ -15,6 +15,7 @@
"""Support for explicit client-side field level encryption.""" """Support for explicit client-side field level encryption."""
from __future__ import annotations from __future__ import annotations
import asyncio
import contextlib import contextlib
import enum import enum
import socket import socket
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so # BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged. # we should propagate them unchanged.
raise raise
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
raise EncryptionError(exc) from exc raise EncryptionError(exc) from exc
@ -200,6 +203,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
conn.close() conn.close()
except (PyMongoError, MongoCryptError): except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly. raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error: except Exception as error:
# Wrap I/O errors in PyMongo exceptions. # Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error) _raise_connection_failure((host, port), error)
@ -716,6 +721,8 @@ class ClientEncryption(Generic[_DocumentType]):
database.create_collection(name=name, **kwargs), database.create_collection(name=name, **kwargs),
encrypted_fields, encrypted_fields,
) )
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -32,6 +32,7 @@ access:
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import contextlib import contextlib
import os import os
import warnings import warnings
@ -58,7 +59,7 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, uri_parser from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo.client_options import ClientOptions from pymongo.client_options import ClientOptions
from pymongo.errors import ( from pymongo.errors import (
AutoReconnect, AutoReconnect,
@ -74,7 +75,11 @@ from pymongo.errors import (
WaitQueueTimeoutError, WaitQueueTimeoutError,
WriteConcernError, WriteConcernError,
) )
from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks from pymongo.lock import (
_HAS_REGISTER_AT_FORK,
_create_lock,
_release_locks,
)
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason from pymongo.monitoring import ConnectionClosedReason
@ -91,7 +96,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult from pymongo.results import ClientBulkWriteResult
from pymongo.server_selectors import writable_server_selector from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, periodic_executor from pymongo.synchronous import client_session, database
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.client_session import _EmptyServerSession
@ -1716,7 +1721,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
address=address, address=address,
) )
with operation.conn_mgr._alock: with operation.conn_mgr._lock:
with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn) err_handler.contribute_socket(operation.conn_mgr.conn)
return server.run_operation( return server.run_operation(
@ -1964,7 +1969,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
try: try:
if conn_mgr: if conn_mgr:
with conn_mgr._alock: with conn_mgr._lock:
# Cursor is pinned to LB outside of a transaction. # Cursor is pinned to LB outside of a transaction.
assert address is not None assert address is not None
assert conn_mgr.conn is not None assert conn_mgr.conn is not None
@ -2027,6 +2032,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors: for address, cursor_id, conn_mgr in pinned_cursors:
try: try:
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed: if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it # Raise the exception when client is closed so that it
@ -2041,6 +2048,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items(): for address, cursor_ids in address_to_cursor_ids.items():
try: try:
self._kill_cursors(cursor_ids, address, topology, session=None) self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed: if isinstance(exc, InvalidOperation) and self._topology._closed:
raise raise
@ -2055,6 +2064,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
try: try:
self._process_kill_cursors() self._process_kill_cursors()
self._topology.update_pool() self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc: except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed: if isinstance(exc, InvalidOperation) and self._topology._closed:
return return

View File

@ -16,24 +16,24 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import atexit import atexit
import logging import logging
import time import time
import weakref import weakref
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
from pymongo import common from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum from pymongo._csot import MovingMinimum
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello from pymongo.hello import Hello
from pymongo.lock import _create_lock from pymongo.lock import _create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver from pymongo.srv_resolver import _SrvResolver
from pymongo.synchronous import periodic_executor
from pymongo.synchronous.periodic_executor import _shutdown_executors
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
except ReferenceError: except ReferenceError:
# Topology was garbage-collected. # Topology was garbage-collected.
self.close() self.close()
finally:
if self._executor._stopped:
self._rtt_monitor.close()
def _check_server(self) -> ServerDescription: def _check_server(self) -> ServerDescription:
"""Call hello or read the next streaming response. """Call hello or read the next streaming response.
@ -254,6 +257,8 @@ class Monitor(MonitorBase):
details = cast(Mapping[str, Any], exc.details) details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime")) self._topology.receive_cluster_time(details.get("$clusterTime"))
raise raise
except asyncio.CancelledError:
raise
except ReferenceError: except ReferenceError:
raise raise
except Exception as error: except Exception as error:
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0: if len(seedlist) == 0:
# As per the spec: this should be treated as a failure. # As per the spec: this should be treated as a failure.
raise Exception raise Exception
except asyncio.CancelledError:
raise
except Exception: except Exception:
# As per the spec, upon encountering an error: # As per the spec, upon encountering an error:
# - An error must not be raised # - An error must not be raised
@ -482,6 +489,8 @@ class _RttMonitor(MonitorBase):
except ReferenceError: except ReferenceError:
# Topology was garbage-collected. # Topology was garbage-collected.
self.close() self.close()
except asyncio.CancelledError:
raise
except Exception: except Exception:
self._pool.reset() self._pool.reset()
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
shutdown() shutdown()
atexit.register(_shutdown_resources) if _IS_SYNC:
atexit.register(_shutdown_resources)

View File

@ -23,7 +23,6 @@ import os
import socket import socket
import ssl import ssl
import sys import sys
import threading
import time import time
import weakref import weakref
from typing import ( from typing import (
@ -62,7 +61,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError, _CertificateError,
) )
from pymongo.hello import Hello, HelloCompat from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _create_lock, _Lock from pymongo.lock import (
_cond_wait,
_create_condition,
_create_lock,
)
from pymongo.logger import ( from pymongo.logger import (
_CONNECTION_LOGGER, _CONNECTION_LOGGER,
_ConnectionStatusMessage, _ConnectionStatusMessage,
@ -208,11 +211,6 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error raise AutoReconnect(msg) from error
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return condition.wait(timeout)
def _get_timeout_details(options: PoolOptions) -> dict[str, float]: def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {} details = {}
timeout = _csot.get_timeout() timeout = _csot.get_timeout()
@ -704,6 +702,8 @@ class Connection:
# shutdown. # shutdown.
try: try:
self.conn.close() self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110 except Exception: # noqa: S110
pass pass
@ -988,8 +988,8 @@ class Pool:
# from the right side. # from the right side.
self.conns: collections.deque = collections.deque() self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set() self.active_contexts: set[_CancellationContext] = set()
_lock = _create_lock() self.lock = _create_lock()
self.lock = _Lock(_lock) self._max_connecting_cond = _create_condition(self.lock)
self.active_sockets = 0 self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events. # Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1 self.next_connection_id = 1
@ -1015,7 +1015,7 @@ class Pool:
# The first portion of the wait queue. # The first portion of the wait queue.
# Enforces: maxPoolSize # Enforces: maxPoolSize
# Also used for: clearing the wait queue # Also used for: clearing the wait queue
self.size_cond = threading.Condition(_lock) self.size_cond = _create_condition(self.lock)
self.requests = 0 self.requests = 0
self.max_pool_size = self.opts.max_pool_size self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size: if not self.max_pool_size:
@ -1023,7 +1023,7 @@ class Pool:
# The second portion of the wait queue. # The second portion of the wait queue.
# Enforces: maxConnecting # Enforces: maxConnecting
# Also used for: clearing the wait queue # Also used for: clearing the wait queue
self._max_connecting_cond = threading.Condition(_lock) self._max_connecting_cond = _create_condition(self.lock)
self._max_connecting = self.opts.max_connecting self._max_connecting = self.opts.max_connecting
self._pending = 0 self._pending = 0
self._client_id = client_id self._client_id = client_id
@ -1460,7 +1460,8 @@ class Pool:
with self.size_cond: with self.size_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=True) self._raise_if_not_ready(checkout_started_time, emit_event=True)
while not (self.requests < self.max_pool_size): while not (self.requests < self.max_pool_size):
if not _cond_wait(self.size_cond, deadline): timeout = deadline - time.monotonic() if deadline else None
if not _cond_wait(self.size_cond, timeout):
# Timed out, notify the next thread to ensure a # Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition. # timeout doesn't consume the condition.
if self.requests < self.max_pool_size: if self.requests < self.max_pool_size:
@ -1483,7 +1484,8 @@ class Pool:
with self._max_connecting_cond: with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False) self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self._max_connecting): while not (self.conns or self._pending < self._max_connecting):
if not _cond_wait(self._max_connecting_cond, deadline): timeout = deadline - time.monotonic() if deadline else None
if not _cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a # Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition. # timeout doesn't consume the condition.
if self.conns or self._pending < self._max_connecting: if self.conns or self._pending < self._max_connecting:

View File

@ -27,7 +27,7 @@ import weakref
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.errors import ( from pymongo.errors import (
ConnectionFailure, ConnectionFailure,
InvalidOperation, InvalidOperation,
@ -39,7 +39,11 @@ from pymongo.errors import (
WriteError, WriteError,
) )
from pymongo.hello import Hello from pymongo.hello import Hello
from pymongo.lock import _create_lock, _Lock from pymongo.lock import (
_cond_wait,
_create_condition,
_create_lock,
)
from pymongo.logger import ( from pymongo.logger import (
_SDAM_LOGGER, _SDAM_LOGGER,
_SERVER_SELECTION_LOGGER, _SERVER_SELECTION_LOGGER,
@ -56,7 +60,6 @@ from pymongo.server_selectors import (
secondary_server_selector, secondary_server_selector,
writable_server_selector, writable_server_selector,
) )
from pymongo.synchronous import periodic_executor
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.pool import Pool from pymongo.synchronous.pool import Pool
@ -170,9 +173,10 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions()) self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False self._opened = False
self._closed = False self._closed = False
_lock = _create_lock() self._lock = _create_lock()
self._lock = _Lock(_lock) self._condition = _create_condition(
self._condition = self._settings.condition_class(_lock) self._lock, self._settings.condition_class if _IS_SYNC else None
)
self._servers: dict[_Address, Server] = {} self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None self._max_cluster_time: Optional[ClusterTime] = None
@ -354,7 +358,7 @@ class Topology:
# change, or for a timeout. We won't miss any changes that # change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've # came after our most recent apply_selector call, since we've
# held the lock until now. # held the lock until now.
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) _cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible() self._description.check_compatible()
now = time.monotonic() now = time.monotonic()
server_descriptions = self._description.apply_selector( server_descriptions = self._description.apply_selector(
@ -652,7 +656,7 @@ class Topology:
"""Wake all monitors, wait for at least one to check its server.""" """Wake all monitors, wait for at least one to check its server."""
with self._lock: with self._lock:
self._request_check_all() self._request_check_all()
self._condition.wait(wait_time) _cond_wait(self._condition, wait_time)
def data_bearing_servers(self) -> list[ServerDescription]: def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers. """Return a list of all data-bearing servers.

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio import asyncio
import gc import gc
import logging
import multiprocessing import multiprocessing
import os import os
import signal import signal
@ -25,6 +26,7 @@ import subprocess
import sys import sys
import threading import threading
import time import time
import traceback
import unittest import unittest
import warnings import warnings
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
@ -191,6 +193,8 @@ class ClientContext:
client.close() client.close()
def _init_client(self): def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = self._connect(host, port) self.client = self._connect(host, port)
if self.client is not None: if self.client is not None:
# Return early when connected to dataLake as mongohoused does not # Return early when connected to dataLake as mongohoused does not
@ -860,6 +864,16 @@ class ClientContext:
client_context = ClientContext() client_context = ClientContext()
def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif client_context.client is not None:
client_context.client.close()
client_context.client = None
client_context._init_client()
class PyMongoTestCase(unittest.TestCase): class PyMongoTestCase(unittest.TestCase):
def assertEqualCommand(self, expected, actual, msg=None): def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1106,26 +1120,10 @@ class PyMongoTestCase(unittest.TestCase):
class UnitTest(PyMongoTestCase): class UnitTest(PyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB.""" """Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod def setUp(self) -> None:
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
def _setup_class(cls):
pass pass
@classmethod def tearDown(self) -> None:
def _tearDown_class(cls):
pass pass
@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase):
db: Database db: Database
credentials: Dict[str, str] credentials: Dict[str, str]
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_connection @client_context.require_connection
def _setup_class(cls): def setUp(self) -> None:
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): if not _IS_SYNC:
reset_client_context()
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers") raise SkipTest("this test does not support load balancers")
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
raise SkipTest("this test does not support serverless") raise SkipTest("this test does not support serverless")
cls.client = client_context.client self.client = client_context.client
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
if client_context.auth_enabled: if client_context.auth_enabled:
cls.credentials = {"username": db_user, "password": db_pwd} self.credentials = {"username": db_user, "password": db_pwd}
else: else:
cls.credentials = {} self.credentials = {}
@classmethod
def _tearDown_class(cls):
pass
def cleanup_colls(self, *collections): def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection.""" """Cleanup collections faster than drop_collection."""
@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest):
# MockClients tests that use replicaSet, directConnection=True, pass # MockClients tests that use replicaSet, directConnection=True, pass
# multiple seed addresses, or wait for heartbeat events are incompatible # multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True. # with loadBalanced=True.
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_no_load_balancer @client_context.require_no_load_balancer
def _setup_class(cls): def setUp(self) -> None:
pass
@classmethod
def _tearDown_class(cls):
pass
def setUp(self):
super().setUp() super().setUp()
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
self.client_knobs.enable() self.client_knobs.enable()
def tearDown(self): def tearDown(self) -> None:
self.client_knobs.disable() self.client_knobs.disable()
super().tearDown() super().tearDown()
@ -1253,7 +1211,6 @@ def teardown():
c.drop_database("pymongo_test_mike") c.drop_database("pymongo_test_mike")
c.drop_database("pymongo_test_bernie") c.drop_database("pymongo_test_bernie")
c.close() c.close()
print_running_clients() print_running_clients()

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio import asyncio
import gc import gc
import logging
import multiprocessing import multiprocessing
import os import os
import signal import signal
@ -25,6 +26,7 @@ import subprocess
import sys import sys
import threading import threading
import time import time
import traceback
import unittest import unittest
import warnings import warnings
from asyncio import iscoroutinefunction from asyncio import iscoroutinefunction
@ -191,6 +193,8 @@ class AsyncClientContext:
await client.close() await client.close()
async def _init_client(self): async def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = await self._connect(host, port) self.client = await self._connect(host, port)
if self.client is not None: if self.client is not None:
# Return early when connected to dataLake as mongohoused does not # Return early when connected to dataLake as mongohoused does not
@ -862,6 +866,16 @@ class AsyncClientContext:
async_client_context = AsyncClientContext() async_client_context = AsyncClientContext()
async def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif async_client_context.client is not None:
await async_client_context.client.close()
async_client_context.client = None
await async_client_context._init_client()
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
def assertEqualCommand(self, expected, actual, msg=None): def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1124,26 +1138,10 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
class AsyncUnitTest(AsyncPyMongoTestCase): class AsyncUnitTest(AsyncPyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB.""" """Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod async def asyncSetUp(self) -> None:
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
async def _setup_class(cls):
pass pass
@classmethod async def asyncTearDown(self) -> None:
async def _tearDown_class(cls):
pass pass
@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
db: AsyncDatabase db: AsyncDatabase
credentials: Dict[str, str] credentials: Dict[str, str]
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_connection @async_client_context.require_connection
async def _setup_class(cls): async def asyncSetUp(self) -> None:
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): if not _IS_SYNC:
await reset_client_context()
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers") raise SkipTest("this test does not support load balancers")
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
raise SkipTest("this test does not support serverless") raise SkipTest("this test does not support serverless")
cls.client = async_client_context.client self.client = async_client_context.client
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
if async_client_context.auth_enabled: if async_client_context.auth_enabled:
cls.credentials = {"username": db_user, "password": db_pwd} self.credentials = {"username": db_user, "password": db_pwd}
else: else:
cls.credentials = {} self.credentials = {}
@classmethod
async def _tearDown_class(cls):
pass
async def cleanup_colls(self, *collections): async def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection.""" """Cleanup collections faster than drop_collection."""
@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest):
# MockClients tests that use replicaSet, directConnection=True, pass # MockClients tests that use replicaSet, directConnection=True, pass
# multiple seed addresses, or wait for heartbeat events are incompatible # multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True. # with loadBalanced=True.
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_no_load_balancer @async_client_context.require_no_load_balancer
async def _setup_class(cls): async def asyncSetUp(self) -> None:
pass await super().asyncSetUp()
@classmethod
async def _tearDown_class(cls):
pass
def setUp(self):
super().setUp()
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
self.client_knobs.enable() self.client_knobs.enable()
def tearDown(self): async def asyncTearDown(self) -> None:
self.client_knobs.disable() self.client_knobs.disable()
super().tearDown() await super().asyncTearDown()
async def async_setup(): async def async_setup():
@ -1271,7 +1229,6 @@ async def async_teardown():
await c.drop_database("pymongo_test_mike") await c.drop_database("pymongo_test_mike")
await c.drop_database("pymongo_test_bernie") await c.drop_database("pymongo_test_bernie")
await c.close() await c.close()
print_running_clients() print_running_clients()

View File

@ -22,7 +22,7 @@ def event_loop_policy():
return asyncio.get_event_loop_policy() return asyncio.get_event_loop_policy()
@pytest_asyncio.fixture(scope="session", autouse=True) @pytest_asyncio.fixture(scope="package", autouse=True)
async def test_setup_and_teardown(): async def test_setup_and_teardown():
await async_setup() await async_setup()
yield yield

View File

@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
coll: AsyncCollection coll: AsyncCollection
coll_w0: AsyncCollection coll_w0: AsyncCollection
@classmethod
async def _setup_class(cls):
await super()._setup_class()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
async def asyncSetUp(self): async def asyncSetUp(self):
super().setUp() await super().asyncSetUp()
self.coll = self.db.test
await self.coll.drop() await self.coll.drop()
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
def assertEqualResponse(self, expected, actual): def assertEqualResponse(self, expected, actual):
"""Compare response from bulk.execute() to expected response.""" """Compare response from bulk.execute() to expected response."""
@ -787,14 +783,10 @@ class AsyncTestBulk(AsyncBulkTestBase):
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase): class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
@classmethod
@async_client_context.require_auth @async_client_context.require_auth
@async_client_context.require_no_api_version @async_client_context.require_no_api_version
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self): async def asyncSetUp(self):
super().setUp() await super().asyncSetUp()
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"]) await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
await self.db.command( await self.db.command(
"createRole", "createRole",
@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
w: Optional[int] w: Optional[int]
secondary: AsyncMongoClient secondary: AsyncMongoClient
@classmethod async def asyncSetUp(self):
async def _setup_class(cls): await super().asyncSetUp()
await super()._setup_class() self.w = async_client_context.w
cls.w = async_client_context.w self.secondary = None
cls.secondary = None if self.w is not None and self.w > 1:
if cls.w is not None and cls.w > 1:
for member in (await async_client_context.hello)["hosts"]: for member in (await async_client_context.hello)["hosts"]:
if member != (await async_client_context.hello)["primary"]: if member != (await async_client_context.hello)["primary"]:
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) self.secondary = await self.async_single_client(*partition_node(member))
break break
@classmethod async def asyncTearDown(self):
async def async_tearDownClass(cls): if self.secondary:
if cls.secondary: await self.secondary.close()
await cls.secondary.close()
async def cause_wtimeout(self, requests, ordered): async def cause_wtimeout(self, requests, ordered):
if not async_client_context.test_commands_enabled: if not async_client_context.test_commands_enabled:

View File

@ -836,18 +836,16 @@ class ProseSpecTestsMixin:
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
dbs: list dbs: list
@classmethod
@async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_version_min(4, 0, 0, -1)
@async_client_context.require_change_streams @async_client_context.require_change_streams
async def _setup_class(cls): async def asyncSetUp(self) -> None:
await super()._setup_class() await super().asyncSetUp()
cls.dbs = [cls.db, cls.client.pymongo_test_2] self.dbs = [self.db, self.client.pymongo_test_2]
@classmethod async def asyncTearDown(self):
async def _tearDown_class(cls): for db in self.dbs:
for db in cls.dbs: await self.client.drop_database(db)
await cls.client.drop_database(db) await super().asyncTearDown()
await super()._tearDown_class()
async def change_stream_with_client(self, client, *args, **kwargs): async def change_stream_with_client(self, client, *args, **kwargs):
return await client.watch(*args, **kwargs) return await client.watch(*args, **kwargs)
@ -898,11 +896,10 @@ class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
@classmethod
@async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_version_min(4, 0, 0, -1)
@async_client_context.require_change_streams @async_client_context.require_change_streams
async def _setup_class(cls): async def asyncSetUp(self) -> None:
await super()._setup_class() await super().asyncSetUp()
async def change_stream_with_client(self, client, *args, **kwargs): async def change_stream_with_client(self, client, *args, **kwargs):
return await client[self.db.name].watch(*args, **kwargs) return await client[self.db.name].watch(*args, **kwargs)
@ -988,12 +985,9 @@ class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixi
class TestAsyncCollectionAsyncChangeStream( class TestAsyncCollectionAsyncChangeStream(
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
): ):
@classmethod
@async_client_context.require_change_streams @async_client_context.require_change_streams
async def _setup_class(cls):
await super()._setup_class()
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp()
# Use a new collection for each test. # Use a new collection for each test.
await self.watched_collection().drop() await self.watched_collection().drop()
await self.watched_collection().insert_one({}) await self.watched_collection().insert_one({})
@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True RUN_ON_LOAD_BALANCER = True
listener: AllowListEventListener listener: AllowListEventListener
@classmethod
@async_client_context.require_connection @async_client_context.require_connection
async def _setup_class(cls): async def asyncSetUp(self):
await super()._setup_class() await super().asyncSetUp()
cls.listener = AllowListEventListener("aggregate", "getMore") self.listener = AllowListEventListener("aggregate", "getMore")
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
await super()._tearDown_class()
def asyncSetUp(self):
super().asyncSetUp()
self.listener.reset() self.listener.reset()
async def asyncSetUpCluster(self, scenario_dict): async def asyncSetUpCluster(self, scenario_dict):

View File

@ -73,7 +73,6 @@ from test.utils import (
is_greenthread_patched, is_greenthread_patched,
lazy_client_trial, lazy_client_trial,
one, one,
wait_until,
) )
import bson import bson
@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
client: AsyncMongoClient client: AsyncMongoClient
@classmethod async def asyncSetUp(self) -> None:
async def _setup_class(cls): self.client = await self.async_rs_or_single_client(
cls.client = await cls.unmanaged_async_rs_or_single_client(
connect=False, serverSelectionTimeoutMS=100 connect=False, serverSelectionTimeoutMS=100
) )
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_fixtures(self, caplog): def inject_fixtures(self, caplog):
self._caplog = caplog self._caplog = caplog
@ -693,8 +687,8 @@ class TestClient(AsyncIntegrationTest):
# When the reaper runs at the same time as the get_socket, two # When the reaper runs at the same time as the get_socket, two
# connections could be created and checked into the pool. # connections could be created and checked into the pool.
self.assertGreaterEqual(len(server._pool.conns), 1) self.assertGreaterEqual(len(server._pool.conns), 1)
wait_until(lambda: conn not in server._pool.conns, "remove stale socket") await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
with client_knobs(kill_cursor_frequency=0.1): with client_knobs(kill_cursor_frequency=0.1):
@ -710,8 +704,8 @@ class TestClient(AsyncIntegrationTest):
# When the reaper runs at the same time as the get_socket, # When the reaper runs at the same time as the get_socket,
# maxPoolSize=1 should prevent two connections from being created. # maxPoolSize=1 should prevent two connections from being created.
self.assertEqual(1, len(server._pool.conns)) self.assertEqual(1, len(server._pool.conns))
wait_until(lambda: conn not in server._pool.conns, "remove stale socket") await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
async def test_max_idle_time_reaper_removes_stale(self): async def test_max_idle_time_reaper_removes_stale(self):
with client_knobs(kill_cursor_frequency=0.1): with client_knobs(kill_cursor_frequency=0.1):
@ -727,7 +721,7 @@ class TestClient(AsyncIntegrationTest):
async with server._pool.checkout() as conn_two: async with server._pool.checkout() as conn_two:
pass pass
self.assertIs(conn_one, conn_two) self.assertIs(conn_one, conn_two)
wait_until( await async_wait_until(
lambda: len(server._pool.conns) == 0, lambda: len(server._pool.conns) == 0,
"stale socket reaped and new one NOT added to the pool", "stale socket reaped and new one NOT added to the pool",
) )
@ -745,7 +739,7 @@ class TestClient(AsyncIntegrationTest):
server = await (await client._get_topology()).select_server( server = await (await client._get_topology()).select_server(
readable_server_selector, _Op.TEST readable_server_selector, _Op.TEST
) )
wait_until( await async_wait_until(
lambda: len(server._pool.conns) == 10, lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections", "pool initialized with 10 connections",
) )
@ -753,7 +747,7 @@ class TestClient(AsyncIntegrationTest):
# Assert that if a socket is closed, a new one takes its place # Assert that if a socket is closed, a new one takes its place
async with server._pool.checkout() as conn: async with server._pool.checkout() as conn:
conn.close_conn(None) conn.close_conn(None)
wait_until( await async_wait_until(
lambda: len(server._pool.conns) == 10, lambda: len(server._pool.conns) == 10,
"a closed socket gets replaced from the pool", "a closed socket gets replaced from the pool",
) )
@ -939,8 +933,10 @@ class TestClient(AsyncIntegrationTest):
async with eval(the_repr) as client_two: async with eval(the_repr) as client_two:
self.assertEqual(client_two, client) self.assertEqual(client_two, client)
def test_getters(self): async def test_getters(self):
wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") await async_wait_until(
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
)
async def test_list_databases(self): async def test_list_databases(self):
cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] cmd_docs = (await self.client.admin.command("listDatabases"))["databases"]
@ -1065,14 +1061,21 @@ class TestClient(AsyncIntegrationTest):
self.assertFalse(client._topology._opened) self.assertFalse(client._topology._opened)
# Ensure kill cursors thread has not been started. # Ensure kill cursors thread has not been started.
kc_thread = client._kill_cursors_executor._thread if _IS_SYNC:
self.assertFalse(kc_thread and kc_thread.is_alive()) kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertFalse(kc_task and not kc_task.done())
# Using the client should open topology and start the thread. # Using the client should open topology and start the thread.
await client.admin.command("ping") await client.admin.command("ping")
self.assertTrue(client._topology._opened) self.assertTrue(client._topology._opened)
kc_thread = client._kill_cursors_executor._thread if _IS_SYNC:
self.assertTrue(kc_thread and kc_thread.is_alive()) kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertTrue(kc_task and not kc_task.done())
async def test_close_does_not_open_servers(self): async def test_close_does_not_open_servers(self):
client = await self.async_rs_client(connect=False) client = await self.async_rs_client(connect=False)
@ -1277,6 +1280,7 @@ class TestClient(AsyncIntegrationTest):
async def test_server_selection_timeout(self): async def test_server_selection_timeout(self):
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout) self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
@ -1289,18 +1293,22 @@ class TestClient(AsyncIntegrationTest):
self.assertRaises( self.assertRaises(
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
) )
await client.close()
client = AsyncMongoClient( client = AsyncMongoClient(
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
) )
self.assertAlmostEqual(0.1, client.options.server_selection_timeout) self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
self.assertAlmostEqual(0, client.options.server_selection_timeout) self.assertAlmostEqual(0, client.options.server_selection_timeout)
await client.close()
# Test invalid timeout in URI ignored and set to default. # Test invalid timeout in URI ignored and set to default.
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout) self.assertAlmostEqual(30, client.options.server_selection_timeout)
await client.close()
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout) self.assertAlmostEqual(30, client.options.server_selection_timeout)
@ -1608,7 +1616,7 @@ class TestClient(AsyncIntegrationTest):
await async_client_context.port, await async_client_context.port,
) )
await self.async_single_client(uri, event_listeners=[listener]) await self.async_single_client(uri, event_listeners=[listener])
wait_until( await async_wait_until(
lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents"
) )
@ -1766,16 +1774,16 @@ class TestClient(AsyncIntegrationTest):
pool = await async_get_pool(client) pool = await async_get_pool(client)
original_connect = pool.connect original_connect = pool.connect
def stall_connect(*args, **kwargs): async def stall_connect(*args, **kwargs):
time.sleep(2) await asyncio.sleep(2)
return original_connect(*args, **kwargs) return await original_connect(*args, **kwargs)
pool.connect = stall_connect pool.connect = stall_connect
# Un-patch Pool.connect to break the cyclic reference. # Un-patch Pool.connect to break the cyclic reference.
self.addCleanup(delattr, pool, "connect") self.addCleanup(delattr, pool, "connect")
# Wait for the background thread to start creating connections # Wait for the background thread to start creating connections
wait_until(lambda: len(pool.conns) > 1, "start creating connections") await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections")
# Assert that application operations do not block. # Assert that application operations do not block.
for _ in range(10): for _ in range(10):
@ -1858,7 +1866,7 @@ class TestClient(AsyncIntegrationTest):
await client.close() await client.close()
# Add cursor to kill cursors queue # Add cursor to kill cursors queue
del cursor del cursor
wait_until( await async_wait_until(
lambda: client._kill_cursors_queue, lambda: client._kill_cursors_queue,
"waited for cursor to be added to queue", "waited for cursor to be added to queue",
) )
@ -2232,7 +2240,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
await cursor.to_list() await cursor.to_list()
self.assertTrue(conn.closed) self.assertTrue(conn.closed)
wait_until( await async_wait_until(
lambda: len(client._kill_cursors_queue) == 0, lambda: len(client._kill_cursors_queue) == 0,
"waited for all killCursor requests to complete", "waited for all killCursor requests to complete",
) )
@ -2403,7 +2411,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
) )
self.addAsyncCleanup(c.close) self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect") await async_wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(await c.address, ("a", 1)) self.assertEqual(await c.address, ("a", 1))
# Fail over. # Fail over.
@ -2430,7 +2438,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
) )
self.addAsyncCleanup(c.close) self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect") await async_wait_until(lambda: len(c.nodes) == 3, "connect")
# Total failure. # Total failure.
c.kill_host("a:1") c.kill_host("a:1")
@ -2472,7 +2480,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION) c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION)
c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1) c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1)
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
wait_until(lambda: len(c.nodes) == 2, "connect") await async_wait_until(lambda: len(c.nodes) == 2, "connect")
c.kill_host("a:1") c.kill_host("a:1")
@ -2544,11 +2552,11 @@ class TestClientPool(AsyncMockClientTest):
) )
self.addAsyncCleanup(c.close) self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect") await async_wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(await c.address, ("a", 1)) self.assertEqual(await c.address, ("a", 1))
self.assertEqual(await c.arbiters, {("c", 3)}) self.assertEqual(await c.arbiters, {("c", 3)})
# Assert that we create 2 and only 2 pooled connections. # Assert that we create 2 and only 2 pooled connections.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 2) await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
# Assert that we do not create connections to arbiters. # Assert that we do not create connections to arbiters.
arbiter = c._topology.get_server_by_address(("c", 3)) arbiter = c._topology.get_server_by_address(("c", 3))
@ -2574,10 +2582,10 @@ class TestClientPool(AsyncMockClientTest):
) )
self.addAsyncCleanup(c.close) self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 1, "connect") await async_wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(await c.address, ("c", 3)) self.assertEqual(await c.address, ("c", 3))
# Assert that we create 1 pooled connection. # Assert that we create 1 pooled connection.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3)) arbiter = c._topology.get_server_by_address(("c", 3))
self.assertEqual(len(arbiter.pool.conns), 1) self.assertEqual(len(arbiter.pool.conns), 1)

View File

@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest):
warn_context: Any warn_context: Any
collation: Collation collation: Collation
@classmethod
@async_client_context.require_connection @async_client_context.require_connection
async def _setup_class(cls): async def asyncSetUp(self) -> None:
await super()._setup_class() await super().asyncSetUp()
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
cls.collation = Collation("en_US") self.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings() self.warn_context = warnings.catch_warnings()
cls.warn_context.__enter__() self.warn_context.__enter__()
warnings.simplefilter("ignore", DeprecationWarning)
@classmethod async def asyncTearDown(self) -> None:
async def _tearDown_class(cls): self.warn_context.__exit__()
cls.warn_context.__exit__() self.warn_context = None
cls.warn_context = None
await cls.client.close()
await super()._tearDown_class()
def tearDown(self):
self.listener.reset() self.listener.reset()
super().tearDown() await super().asyncTearDown()
def last_command_started(self): def last_command_started(self):
return self.listener.started_events[-1].command return self.listener.started_events[-1].command

View File

@ -40,7 +40,6 @@ from test.utils import (
async_get_pool, async_get_pool,
async_is_mongos, async_is_mongos,
async_wait_until, async_wait_until,
wait_until,
) )
from bson import encode from bson import encode
@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest):
db: AsyncDatabase db: AsyncDatabase
client: AsyncMongoClient client: AsyncMongoClient
@classmethod async def asyncSetUp(self) -> None:
async def _setup_class(cls): await super().asyncSetUp()
cls.client = AsyncMongoClient(connect=False) self.client = self.simple_client(connect=False)
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
def test_collection(self): def test_collection(self):
self.assertRaises(TypeError, AsyncCollection, self.db, 5) self.assertRaises(TypeError, AsyncCollection, self.db, 5)
@ -165,27 +160,14 @@ class TestCollectionNoConnect(AsyncUnitTest):
class AsyncTestCollection(AsyncIntegrationTest): class AsyncTestCollection(AsyncIntegrationTest):
w: int w: int
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.w = async_client_context.w # type: ignore
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
else:
asyncio.run(cls.async_tearDownClass())
@classmethod
async def async_tearDownClass(cls):
await cls.db.drop_collection("test_large_limit")
async def asyncSetUp(self): async def asyncSetUp(self):
await self.db.test.drop() await super().asyncSetUp()
self.w = async_client_context.w # type: ignore
async def asyncTearDown(self): async def asyncTearDown(self):
await self.db.test.drop() await self.db.test.drop()
await self.db.drop_collection("test_large_limit")
await super().asyncTearDown()
@contextlib.contextmanager @contextlib.contextmanager
def write_concern_collection(self): def write_concern_collection(self):
@ -1023,7 +1005,10 @@ class AsyncTestCollection(AsyncIntegrationTest):
await db.test.insert_one({"y": 1}, bypass_document_validation=True) await db.test.insert_one({"y": 1}, bypass_document_validation=True)
await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") async def predicate():
return await db_w0.test.find_one({"x": 1})
await async_wait_until(predicate, "find w:0 replaced document")
async def test_update_bypass_document_validation(self): async def test_update_bypass_document_validation(self):
db = self.db db = self.db
@ -1871,7 +1856,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
await cur.close() await cur.close()
cur = None cur = None
# Wait until the background thread returns the socket. # Wait until the background thread returns the socket.
wait_until(lambda: pool.active_sockets == 0, "return socket") await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
# The socket should be discarded. # The socket should be discarded.
self.assertEqual(0, len(pool.conns)) self.assertEqual(0, len(pool.conns))

View File

@ -19,7 +19,12 @@ import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous import (
AsyncIntegrationTest,
async_client_context,
reset_client_context,
unittest,
)
from test.asynchronous.helpers import async_repl_set_step_down from test.asynchronous.helpers import async_repl_set_step_down
from test.utils import ( from test.utils import (
CMAPListener, CMAPListener,
@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
listener: CMAPListener listener: CMAPListener
coll: AsyncCollection coll: AsyncCollection
@classmethod
@async_client_context.require_replica_set @async_client_context.require_replica_set
async def _setup_class(cls): async def asyncSetUp(self):
await super()._setup_class() self.listener = CMAPListener()
cls.listener = CMAPListener() self.client = await self.async_rs_or_single_client(
cls.client = await cls.unmanaged_async_rs_or_single_client( event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
) )
# Ensure connections to all servers in replica set. This is to test # Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for connections that # that the is_writable flag is properly updated for connections that
# survive a replica set election. # survive a replica set election.
await async_ensure_all_connected(cls.client) await async_ensure_all_connected(self.client)
cls.listener.reset() self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
async def asyncSetUp(self):
# Note that all ops use same write-concern as self.db (majority). # Note that all ops use same write-concern as self.db (majority).
await self.db.drop_collection("step-down") await self.db.drop_collection("step-down")
await self.db.create_collection("step-down") await self.db.create_collection("step-down")

View File

@ -56,6 +56,9 @@ class TestCreateEntities(AsyncIntegrationTest):
self.assertGreater(len(final_entity_map["events1"]), 0) self.assertGreater(len(final_entity_map["events1"]), 0)
for event in final_entity_map["events1"]: for event in final_entity_map["events1"]:
self.assertIn("PoolCreatedEvent", event["name"]) self.assertIn("PoolCreatedEvent", event["name"])
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()
async def test_store_all_others_as_entities(self): async def test_store_all_others_as_entities(self):
self.scenario_runner = UnifiedSpecTestMixinV1() self.scenario_runner = UnifiedSpecTestMixinV1()
@ -122,6 +125,9 @@ class TestCreateEntities(AsyncIntegrationTest):
self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["failures"], [])
self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["successes"], 2)
self.assertEqual(entity_map["iterations"], 5) self.assertEqual(entity_map["iterations"], 5)
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -34,9 +34,9 @@ from test.utils import (
AllowListEventListener, AllowListEventListener,
EventListener, EventListener,
OvertCommandListener, OvertCommandListener,
async_wait_until,
delay, delay,
ignore_deprecations, ignore_deprecations,
wait_until,
) )
from bson import decode_all from bson import decode_all
@ -1324,8 +1324,8 @@ class TestCursor(AsyncIntegrationTest):
with self.assertRaises(ExecutionTimeout): with self.assertRaises(ExecutionTimeout):
await cursor.next() await cursor.next()
def assertCursorKilled(): async def assertCursorKilled():
wait_until( await async_wait_until(
lambda: len(listener.succeeded_events), lambda: len(listener.succeeded_events),
"find successful killCursors command", "find successful killCursors command",
) )
@ -1335,7 +1335,7 @@ class TestCursor(AsyncIntegrationTest):
self.assertEqual(1, len(listener.succeeded_events)) self.assertEqual(1, len(listener.succeeded_events))
self.assertEqual("killCursors", listener.succeeded_events[0].command_name) self.assertEqual("killCursors", listener.succeeded_events[0].command_name)
assertCursorKilled() await assertCursorKilled()
listener.reset() listener.reset()
cursor = await coll.aggregate([], batchSize=1) cursor = await coll.aggregate([], batchSize=1)
@ -1345,7 +1345,7 @@ class TestCursor(AsyncIntegrationTest):
with self.assertRaises(ExecutionTimeout): with self.assertRaises(ExecutionTimeout):
await cursor.next() await cursor.next()
assertCursorKilled() await assertCursorKilled()
def test_delete_not_initialized(self): def test_delete_not_initialized(self):
# Creating a cursor with invalid arguments will not run __init__ # Creating a cursor with invalid arguments will not run __init__
@ -1647,10 +1647,6 @@ class TestRawBatchCursor(AsyncIntegrationTest):
class TestRawBatchCommandCursor(AsyncIntegrationTest): class TestRawBatchCommandCursor(AsyncIntegrationTest):
@classmethod
async def _setup_class(cls):
await super()._setup_class()
async def test_aggregate_raw(self): async def test_aggregate_raw(self):
c = self.db.test c = self.db.test
await c.drop() await c.drop()

View File

@ -717,7 +717,8 @@ class TestDatabase(AsyncIntegrationTest):
class TestDatabaseAggregation(AsyncIntegrationTest): class TestDatabaseAggregation(AsyncIntegrationTest):
def setUp(self): async def asyncSetUp(self):
await super().asyncSetUp()
self.pipeline: List[Mapping[str, Any]] = [ self.pipeline: List[Mapping[str, Any]] = [
{"$listLocalSessions": {}}, {"$listLocalSessions": {}},
{"$limit": 1}, {"$limit": 1},

View File

@ -211,11 +211,10 @@ class TestClientOptions(AsyncPyMongoTestCase):
class AsyncEncryptionIntegrationTest(AsyncIntegrationTest): class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
"""Base class for encryption integration tests.""" """Base class for encryption integration tests."""
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@async_client_context.require_version_min(4, 2, -1) @async_client_context.require_version_min(4, 2, -1)
async def _setup_class(cls): async def asyncSetUp(self) -> None:
await super()._setup_class() await super().asyncSetUp()
def assertEncrypted(self, val): def assertEncrypted(self, val):
self.assertIsInstance(val, Binary) self.assertIsInstance(val, Binary)
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
class TestClientMaxWireVersion(AsyncIntegrationTest): class TestClientMaxWireVersion(AsyncIntegrationTest):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
async def _setup_class(cls): async def asyncSetUp(self):
await super()._setup_class() await super().asyncSetUp()
@async_client_context.require_version_max(4, 0, 99) @async_client_context.require_version_max(4, 0, 99)
async def test_raise_max_wire_version_error(self): async def test_raise_max_wire_version_error(self):
@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
"local": None, "local": None,
} }
@classmethod
@unittest.skipUnless( @unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set", "No environment credentials are set",
) )
async def _setup_class(cls): async def asyncSetUp(self):
await super()._setup_class() await super().asyncSetUp()
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
await cls.client.db.coll.drop() await self.client.db.coll.drop()
cls.vault = await create_key_vault(cls.client.keyvault.datakeys) self.vault = await create_key_vault(self.client.keyvault.datakeys)
# Configure the encrypted field via the local schema_map option. # Configure the encrypted field via the local schema_map option.
schemas = { schemas = {
@ -846,25 +843,22 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
} }
} }
opts = AutoEncryptionOpts( opts = AutoEncryptionOpts(
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS self.KMS_PROVIDERS,
"keyvault.datakeys",
schema_map=schemas,
kms_tls_options=KMS_TLS_OPTS,
) )
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( self.client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard" auto_encryption_opts=opts, uuidRepresentation="standard"
) )
cls.client_encryption = cls.unmanaged_create_client_encryption( self.client_encryption = self.create_client_encryption(
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
) )
@classmethod
async def _tearDown_class(cls):
await cls.vault.drop()
await cls.client.close()
await cls.client_encrypted.close()
await cls.client_encryption.close()
def setUp(self):
self.listener.reset() self.listener.reset()
async def asyncTearDown(self) -> None:
await self.vault.drop()
async def run_test(self, provider_name): async def run_test(self, provider_name):
# Create data key. # Create data key.
master_key: Any = self.MASTER_KEYS[provider_name] master_key: Any = self.MASTER_KEYS[provider_name]
@ -1011,10 +1005,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
class TestCorpus(AsyncEncryptionIntegrationTest): class TestCorpus(AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def _setup_class(cls): async def asyncSetUp(self):
await super()._setup_class() await super().asyncSetUp()
@staticmethod @staticmethod
def kms_providers(): def kms_providers():
@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
client_encrypted: AsyncMongoClient client_encrypted: AsyncMongoClient
listener: OvertCommandListener listener: OvertCommandListener
@classmethod async def asyncSetUp(self):
async def _setup_class(cls): await super().asyncSetUp()
await super()._setup_class()
db = async_client_context.client.db db = async_client_context.client.db
cls.coll = db.coll self.coll = db.coll
await cls.coll.drop() await self.coll.drop()
# Configure the encrypted 'db.coll' collection via jsonSchema. # Configure the encrypted 'db.coll' collection via jsonSchema.
json_schema = json_data("limits", "limits-schema.json") json_schema = json_data("limits", "limits-schema.json")
await db.create_collection( await db.create_collection(
@ -1211,17 +1203,14 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
await coll.insert_one(json_data("limits", "limits-key.json")) await coll.insert_one(json_data("limits", "limits-key.json"))
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( self.client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[cls.listener] auto_encryption_opts=opts, event_listeners=[self.listener]
) )
cls.coll_encrypted = cls.client_encrypted.db.coll self.coll_encrypted = self.client_encrypted.db.coll
@classmethod async def asyncTearDown(self) -> None:
async def _tearDown_class(cls): await self.coll_encrypted.drop()
await cls.coll_encrypted.drop()
await cls.client_encrypted.close()
await super()._tearDown_class()
async def test_01_insert_succeeds_under_2MiB(self): async def test_01_insert_succeeds_under_2MiB(self):
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
@ -1245,7 +1234,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB} doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
self.listener.reset() self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
async def test_04_bulk_batch_split(self): async def test_04_bulk_batch_split(self):
limits_doc = json_data("limits", "limits-doc.json") limits_doc = json_data("limits", "limits-doc.json")
@ -1255,7 +1246,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
doc2.update(limits_doc) doc2.update(limits_doc)
self.listener.reset() self.listener.reset()
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
async def test_05_insert_succeeds_just_under_16MiB(self): async def test_05_insert_succeeds_just_under_16MiB(self):
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)} doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
@ -1285,15 +1278,12 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
class TestCustomEndpoint(AsyncEncryptionIntegrationTest): class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
"""Prose tests for creating data keys with a custom endpoint.""" """Prose tests for creating data keys with a custom endpoint."""
@classmethod
@unittest.skipUnless( @unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set", "No environment credentials are set",
) )
async def _setup_class(cls): async def asyncSetUp(self):
await super()._setup_class() await super().asyncSetUp()
def setUp(self):
kms_providers = { kms_providers = {
"aws": AWS_CREDS, "aws": AWS_CREDS,
"azure": AZURE_CREDS, "azure": AZURE_CREDS,
@ -1322,10 +1312,6 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
self._kmip_host_error = None self._kmip_host_error = None
self._invalid_host_error = None self._invalid_host_error = None
async def asyncTearDown(self):
await self.client_encryption.close()
await self.client_encryption_invalid.close()
async def run_test_expected_success(self, provider_name, master_key): async def run_test_expected_success(self, provider_name, master_key):
data_key_id = await self.client_encryption.create_data_key( data_key_id = await self.client_encryption.create_data_key(
provider_name, master_key=master_key provider_name, master_key=master_key
@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
KEYVAULT_COLL = "datakeys" KEYVAULT_COLL = "datakeys"
client: AsyncMongoClient client: AsyncMongoClient
async def asyncSetUp(self): async def _setup(self):
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
await create_key_vault(keyvault, self.DEK) await create_key_vault(keyvault, self.DEK)
async def _test_explicit(self, expectation): async def _test_explicit(self, expectation):
await self._setup()
client_encryption = self.create_client_encryption( client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type] self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
async_client_context.client, async_client_context.client,
OPTS, OPTS,
) )
self.addAsyncCleanup(client_encryption.close)
ciphertext = await client_encryption.encrypt( ciphertext = await client_encryption.encrypt(
"string0", "string0",
@ -1523,6 +1509,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0") self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")
async def _test_automatic(self, expectation_extjson, payload): async def _test_automatic(self, expectation_extjson, payload):
await self._setup()
encrypted_db = "db" encrypted_db = "db"
encrypted_coll = "coll" encrypted_coll = "coll"
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
@ -1537,7 +1524,6 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
client = await self.async_rs_or_single_client( client = await self.async_rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
) )
self.addAsyncCleanup(client.aclose)
coll = client.get_database(encrypted_db).get_collection( coll = client.get_database(encrypted_db).get_collection(
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority") encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
@ -1559,13 +1545,12 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
async def _setup_class(cls): async def asyncSetUp(self):
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
cls.DEK = json_data(BASE, "custom", "azure-dek.json") self.DEK = json_data(BASE, "custom", "azure-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super()._setup_class() await super().asyncSetUp()
async def test_explicit(self): async def test_explicit(self):
return await self._test_explicit( return await self._test_explicit(
@ -1585,13 +1570,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegratio
class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
async def _setup_class(cls): async def asyncSetUp(self):
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
cls.DEK = json_data(BASE, "custom", "gcp-dek.json") self.DEK = json_data(BASE, "custom", "gcp-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
await super()._setup_class() await super().asyncSetUp()
async def test_explicit(self): async def test_explicit(self):
return await self._test_explicit( return await self._test_explicit(
@ -1613,6 +1597,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
class TestDeadlockProse(AsyncEncryptionIntegrationTest): class TestDeadlockProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp()
self.client_test = await self.async_rs_or_single_client( self.client_test = await self.async_rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
) )
@ -1645,7 +1630,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
self.ciphertext = await client_encryption.encrypt( self.ciphertext = await client_encryption.encrypt(
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local" "string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
) )
await client_encryption.close()
self.client_listener = OvertCommandListener() self.client_listener = OvertCommandListener()
self.topology_listener = TopologyEventListener() self.topology_listener = TopologyEventListener()
@ -1840,6 +1824,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
class TestDecryptProse(AsyncEncryptionIntegrationTest): class TestDecryptProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client self.client = async_client_context.client
await self.client.db.drop_collection("decryption_events") await self.client.db.drop_collection("decryption_events")
await create_key_vault(self.client.keyvault.datakeys) await create_key_vault(self.client.keyvault.datakeys)
@ -2275,6 +2260,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames # https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest): class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp()
self.client = async_client_context.client self.client = async_client_context.client
await create_key_vault(self.client.keyvault.datakeys) await create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
@ -2624,8 +2610,6 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
assert isinstance(res["encrypted_indexed"], Binary) assert isinstance(res["encrypted_indexed"], Binary)
assert isinstance(res["encrypted_unindexed"], Binary) assert isinstance(res["encrypted_unindexed"], Binary)
await client_encryption.close()
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
class TestRangeQueryProse(AsyncEncryptionIntegrationTest): class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
@ -3089,17 +3073,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
mongocryptd_client: AsyncMongoClient mongocryptd_client: AsyncMongoClient
MONGOCRYPTD_PORT = 27020 MONGOCRYPTD_PORT = 27020
@classmethod
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
async def _setup_class(cls):
await super()._setup_class()
start_mongocryptd(cls.MONGOCRYPTD_PORT)
@classmethod
async def _tearDown_class(cls):
await super()._tearDown_class()
async def asyncSetUp(self) -> None: async def asyncSetUp(self) -> None:
await super().asyncSetUp()
start_mongocryptd(self.MONGOCRYPTD_PORT)
self.listener = OvertCommandListener() self.listener = OvertCommandListener()
self.mongocryptd_client = self.simple_client( self.mongocryptd_client = self.simple_client(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]

View File

@ -97,6 +97,7 @@ class AsyncTestGridFileNoConnect(AsyncUnitTest):
class AsyncTestGridFile(AsyncIntegrationTest): class AsyncTestGridFile(AsyncIntegrationTest):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp()
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
async def test_basic(self): async def test_basic(self):

View File

@ -16,498 +16,447 @@ from __future__ import annotations
import asyncio import asyncio
import sys import sys
import threading
import unittest import unittest
from pymongo.lock import _async_create_condition, _async_create_lock
sys.path[0:0] = [""] sys.path[0:0] = [""]
from pymongo.lock import _ACondition if sys.version_info < (3, 13):
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
async def test_wait(self):
cond = _async_create_condition(_async_create_lock())
result = []
async def c1(result):
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py await cond.acquire()
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
async def test_wait(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
await cond.acquire()
if await cond.wait():
result.append(1)
return True
async def c2(result):
await cond.acquire()
if await cond.wait():
result.append(2)
return True
async def c3(result):
await cond.acquire()
if await cond.wait():
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertFalse(cond.locked())
self.assertTrue(await cond.acquire())
cond.notify()
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.notify(2)
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2, 3], result)
self.assertTrue(cond.locked())
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_wait_cancel(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
wait = asyncio.create_task(cond.wait())
asyncio.get_running_loop().call_soon(wait.cancel)
with self.assertRaises(asyncio.CancelledError):
await wait
self.assertFalse(cond._waiters)
self.assertTrue(cond.locked())
async def test_wait_cancel_contested(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
self.assertTrue(cond.locked())
wait_task = asyncio.create_task(cond.wait())
await asyncio.sleep(0)
self.assertFalse(cond.locked())
# Notify, but contest the lock before cancelling
await cond.acquire()
self.assertTrue(cond.locked())
cond.notify()
asyncio.get_running_loop().call_soon(wait_task.cancel)
asyncio.get_running_loop().call_soon(cond.release)
try:
await wait_task
except asyncio.CancelledError:
# Should not happen, since no cancellation points
pass
self.assertTrue(cond.locked())
async def test_wait_cancel_after_notify(self):
# See bpo-32841
waited = False
cond = _ACondition(threading.Condition(threading.Lock()))
async def wait_on_cond():
nonlocal waited
async with cond:
waited = True # Make sure this area was reached
await cond.wait()
waiter = asyncio.create_task(wait_on_cond())
await asyncio.sleep(0) # Start waiting
await cond.acquire()
cond.notify()
await asyncio.sleep(0) # Get to acquire()
waiter.cancel()
await asyncio.sleep(0) # Activate cancellation
cond.release()
await asyncio.sleep(0) # Cancellation should occur
self.assertTrue(waiter.cancelled())
self.assertTrue(waited)
async def test_wait_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
with self.assertRaises(RuntimeError):
await cond.wait()
async def test_wait_for(self):
cond = _ACondition(threading.Condition(threading.Lock()))
presult = False
def predicate():
return presult
result = []
async def c1(result):
await cond.acquire()
if await cond.wait_for(predicate):
result.append(1)
cond.release()
return True
t = asyncio.create_task(c1(result))
await asyncio.sleep(0)
self.assertEqual([], result)
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([], result)
presult = True
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(t.done())
self.assertTrue(t.result())
async def test_wait_for_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
# predicate can return true immediately
res = await cond.wait_for(lambda: [1, 2, 3])
self.assertEqual([1, 2, 3], res)
with self.assertRaises(RuntimeError):
await cond.wait_for(lambda: False)
async def test_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
if await cond.wait(): if await cond.wait():
result.append(1) result.append(1)
return True return True
async def c2(result): async def c2(result):
async with cond: await cond.acquire()
if await cond.wait(): if await cond.wait():
result.append(2) result.append(2)
return True return True
async def c3(result): async def c3(result):
async with cond: await cond.acquire()
if await cond.wait(): if await cond.wait():
result.append(3) result.append(3)
return True return True
t1 = asyncio.create_task(c1(result)) t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result)) t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result)) t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify(1)
await asyncio.sleep(1)
self.assertEqual([1], result)
async with cond:
cond.notify(1)
cond.notify(2048)
await asyncio.sleep(1)
self.assertEqual([1, 2, 3], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify_all()
await asyncio.sleep(1)
self.assertEqual([1, 2], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
async def test_context_manager(self):
cond = _ACondition(threading.Condition(threading.Lock()))
self.assertFalse(cond.locked())
async with cond:
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
async def test_timeout_in_block(self):
condition = _ACondition(threading.Condition(threading.Lock()))
async with condition:
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-aquiring lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is awaiting initial
# wakeup on the wakeup queue.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# Cancel it while it is awaiting to be run.
# This cancellation could come from the outside
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup_relock(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is acquiring the lock
# again.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# now we sleep for a bit. This allows the target task to wake up and
# settle on re-aquiring the lock
await asyncio.sleep(0) await asyncio.sleep(0)
self.assertEqual([], result)
self.assertFalse(cond.locked())
# Cancel it while awaiting the lock self.assertTrue(await cond.acquire())
# This cancel could come the outside. cond.notify()
c[0].cancel() await asyncio.sleep(0)
self.assertEqual([], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.notify(2)
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2, 3], result)
self.assertTrue(cond.locked())
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_wait_cancel(self):
cond = _async_create_condition(_async_create_lock())
await cond.acquire()
wait = asyncio.create_task(cond.wait())
asyncio.get_running_loop().call_soon(wait.cancel)
with self.assertRaises(asyncio.CancelledError):
await wait
self.assertFalse(cond._waiters)
self.assertTrue(cond.locked())
async def test_wait_cancel_contested(self):
cond = _async_create_condition(_async_create_lock())
await cond.acquire()
self.assertTrue(cond.locked())
wait_task = asyncio.create_task(cond.wait())
await asyncio.sleep(0)
self.assertFalse(cond.locked())
# Notify, but contest the lock before cancelling
await cond.acquire()
self.assertTrue(cond.locked())
cond.notify()
asyncio.get_running_loop().call_soon(wait_task.cancel)
asyncio.get_running_loop().call_soon(cond.release)
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try: try:
async with asyncio.timeout(1): await wait_task
await condition.wait_for(lambda: state == 0) except asyncio.CancelledError:
except TimeoutError: # Should not happen, since no cancellation points
pass pass
self.assertEqual(state, 0)
# clean up self.assertTrue(cond.locked())
state = -1
condition.notify_all()
await c[1]
async def test_wait_cancel_after_notify(self):
# See bpo-32841
waited = False
class TestCondition(unittest.IsolatedAsyncioTestCase): cond = _async_create_condition(_async_create_lock())
async def test_multiple_loops_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
def tmain(cond): async def wait_on_cond():
async def atmain(cond): nonlocal waited
await asyncio.sleep(1)
async with cond: async with cond:
cond.notify(1) waited = True # Make sure this area was reached
await cond.wait()
asyncio.run(atmain(cond)) waiter = asyncio.create_task(wait_on_cond())
await asyncio.sleep(0) # Start waiting
t = threading.Thread(target=tmain, args=(cond,)) await cond.acquire()
t.start() cond.notify()
await asyncio.sleep(0) # Get to acquire()
waiter.cancel()
await asyncio.sleep(0) # Activate cancellation
cond.release()
await asyncio.sleep(0) # Cancellation should occur
async with cond: self.assertTrue(waiter.cancelled())
self.assertTrue(await cond.wait(30)) self.assertTrue(waited)
t.join()
async def test_multiple_loops_notify_all(self): async def test_wait_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock())) cond = _async_create_condition(_async_create_lock())
results = [] with self.assertRaises(RuntimeError):
await cond.wait()
def tmain(cond, results): async def test_wait_for(self):
async def atmain(cond, results): cond = _async_create_condition(_async_create_lock())
await asyncio.sleep(1) presult = False
def predicate():
return presult
result = []
async def c1(result):
await cond.acquire()
if await cond.wait_for(predicate):
result.append(1)
cond.release()
return True
t = asyncio.create_task(c1(result))
await asyncio.sleep(0)
self.assertEqual([], result)
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([], result)
presult = True
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(t.done())
self.assertTrue(t.result())
async def test_wait_for_unacquired(self):
cond = _async_create_condition(_async_create_lock())
# predicate can return true immediately
res = await cond.wait_for(lambda: [1, 2, 3])
self.assertEqual([1, 2, 3], res)
with self.assertRaises(RuntimeError):
await cond.wait_for(lambda: False)
async def test_notify(self):
cond = _async_create_condition(_async_create_lock())
result = []
async def c1(result):
async with cond: async with cond:
res = await cond.wait(30) if await cond.wait():
results.append(res) result.append(1)
return True
asyncio.run(atmain(cond, results)) async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
nthreads = 5 async def c3(result):
threads = [] async with cond:
for _ in range(nthreads): if await cond.wait():
threads.append(threading.Thread(target=tmain, args=(cond, results))) result.append(3)
for t in threads: return True
t.start()
await asyncio.sleep(2) t1 = asyncio.create_task(c1(result))
async with cond: t2 = asyncio.create_task(c2(result))
cond.notify_all() t3 = asyncio.create_task(c3(result))
for t in threads: await asyncio.sleep(0)
t.join() self.assertEqual([], result)
self.assertEqual(results, [True] * nthreads) async with cond:
cond.notify(1)
await asyncio.sleep(1)
self.assertEqual([1], result)
async with cond:
cond.notify(1)
cond.notify(2048)
await asyncio.sleep(1)
self.assertEqual([1, 2, 3], result)
if __name__ == "__main__": self.assertTrue(t1.done())
unittest.main() self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_notify_all(self):
cond = _async_create_condition(_async_create_lock())
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify_all()
await asyncio.sleep(1)
self.assertEqual([1, 2], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
async def test_context_manager(self):
cond = _async_create_condition(_async_create_lock())
self.assertFalse(cond.locked())
async with cond:
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
async def test_timeout_in_block(self):
condition = _async_create_condition(_async_create_lock())
async with condition:
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = _async_create_condition(_async_create_lock())
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-aquiring lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = _async_create_condition(_async_create_lock())
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is awaiting initial
# wakeup on the wakeup queue.
condition = _async_create_condition(_async_create_lock())
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# Cancel it while it is awaiting to be run.
# This cancellation could come from the outside
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup_relock(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is acquiring the lock
# again.
condition = _async_create_condition(_async_create_lock())
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# now we sleep for a bit. This allows the target task to wake up and
# settle on re-aquiring the lock
await asyncio.sleep(0)
# Cancel it while awaiting the lock
# This cancel could come the outside.
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
if __name__ == "__main__":
unittest.main()

View File

@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
listener: EventListener listener: EventListener
@classmethod @classmethod
@async_client_context.require_connection def setUpClass(cls) -> None:
async def _setup_class(cls):
await super()._setup_class()
cls.listener = OvertCommandListener() cls.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False
)
@classmethod @async_client_context.require_connection
async def _tearDown_class(cls): async def asyncSetUp(self) -> None:
await cls.client.close() await super().asyncSetUp()
await super()._tearDown_class()
async def asyncTearDown(self):
self.listener.reset() self.listener.reset()
await super().asyncTearDown() self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener], retryWrites=False
)
async def test_started_simple(self): async def test_started_simple(self):
await self.client.pymongo_test.command("ping") await self.client.pymongo_test.command("ping")
@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
saved_listeners: Any saved_listeners: Any
@classmethod @classmethod
@async_client_context.require_connection def setUpClass(cls) -> None:
async def _setup_class(cls):
await super()._setup_class()
cls.listener = OvertCommandListener() cls.listener = OvertCommandListener()
# We plan to call register(), which internally modifies _LISTENERS. # We plan to call register(), which internally modifies _LISTENERS.
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
monitoring.register(cls.listener) monitoring.register(cls.listener)
cls.client = await cls.unmanaged_async_single_client()
# Get one (authenticated) socket in the pool.
await cls.client.pymongo_test.command("ping")
@classmethod
async def _tearDown_class(cls):
monitoring._LISTENERS = cls.saved_listeners
await cls.client.close()
await super()._tearDown_class()
@async_client_context.require_connection
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
self.listener.reset() self.listener.reset()
self.client = await self.async_single_client()
# Get one (authenticated) socket in the pool.
await self.client.pymongo_test.command("ping")
@classmethod
def tearDownClass(cls):
monitoring._LISTENERS = cls.saved_listeners
async def test_simple(self): async def test_simple(self):
await self.client.pymongo_test.command("ping") await self.client.pymongo_test.command("ping")

View File

@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
RUN_ON_SERVERLESS = True RUN_ON_SERVERLESS = True
deprecation_filter: DeprecationFilter deprecation_filter: DeprecationFilter
@classmethod async def asyncSetUp(self) -> None:
async def _setup_class(cls): await super().asyncSetUp()
await super()._setup_class() self.deprecation_filter = DeprecationFilter()
cls.deprecation_filter = DeprecationFilter()
@classmethod async def asyncTearDown(self) -> None:
async def _tearDown_class(cls): self.deprecation_filter.stop()
cls.deprecation_filter.stop()
await super()._tearDown_class()
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs knobs: client_knobs
@classmethod async def asyncSetUp(self) -> None:
async def _setup_class(cls): await super().asyncSetUp()
await super()._setup_class()
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable() self.knobs.enable()
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True) self.client = await self.async_rs_or_single_client(retryWrites=True)
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
@classmethod async def asyncTearDown(self) -> None:
async def _tearDown_class(cls): self.knobs.disable()
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
@async_client_context.require_no_standalone @async_client_context.require_no_standalone
async def test_actionable_error_message(self): async def test_actionable_error_message(self):
@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
listener: OvertCommandListener listener: OvertCommandListener
knobs: client_knobs knobs: client_knobs
@classmethod
@async_client_context.require_no_mmap @async_client_context.require_no_mmap
async def _setup_class(cls): async def asyncSetUp(self) -> None:
await super()._setup_class() await super().asyncSetUp()
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable() self.knobs.enable()
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client = await cls.unmanaged_async_rs_or_single_client( self.client = await self.async_rs_or_single_client(
retryWrites=True, event_listeners=[cls.listener] retryWrites=True, event_listeners=[self.listener]
) )
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
await cls.client.close()
await super()._tearDown_class()
async def asyncSetUp(self):
if async_client_context.is_rs and async_client_context.test_commands_enabled: if async_client_context.is_rs and async_client_context.test_commands_enabled:
await self.client.admin.command( await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
@ -210,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
await self.client.admin.command( await self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
) )
self.knobs.disable()
async def test_supported_single_statement_no_retry(self): async def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener() listener = OvertCommandListener()
@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest):
RUN_ON_SERVERLESS = True RUN_ON_SERVERLESS = True
fail_insert: dict fail_insert: dict
@classmethod
@async_client_context.require_replica_set @async_client_context.require_replica_set
@async_client_context.require_no_mmap @async_client_context.require_no_mmap
@async_client_context.require_failCommand_fail_point @async_client_context.require_failCommand_fail_point
async def _setup_class(cls): async def asyncSetUp(self) -> None:
await super()._setup_class() await super().asyncSetUp()
cls.fail_insert = { self.fail_insert = {
"configureFailPoint": "failCommand", "configureFailPoint": "failCommand",
"mode": {"times": 2}, "mode": {"times": 2},
"data": { "data": {

View File

@ -38,7 +38,6 @@ from test.utils import (
ExceptionCatchingThread, ExceptionCatchingThread,
OvertCommandListener, OvertCommandListener,
async_wait_until, async_wait_until,
wait_until,
) )
from bson import DBRef from bson import DBRef
@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest):
client2: AsyncMongoClient client2: AsyncMongoClient
sensitive_commands: Set[str] sensitive_commands: Set[str]
@classmethod
@async_client_context.require_sessions @async_client_context.require_sessions
async def _setup_class(cls): async def asyncSetUp(self):
await super()._setup_class() await super().asyncSetUp()
# Create a second client so we can make sure clients cannot share # Create a second client so we can make sure clients cannot share
# sessions. # sessions.
cls.client2 = await cls.unmanaged_async_rs_or_single_client() self.client2 = await self.async_rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid". # Redact no commands, so we can test user-admin commands have "lsid".
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
monitoring._SENSITIVE_COMMANDS.clear() monitoring._SENSITIVE_COMMANDS.clear()
@classmethod
async def _tearDown_class(cls):
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
await cls.client2.close()
await super()._tearDown_class()
async def asyncSetUp(self):
self.listener = SessionTestListener() self.listener = SessionTestListener()
self.session_checker_listener = SessionTestListener() self.session_checker_listener = SessionTestListener()
self.client = await self.async_rs_or_single_client( self.client = await self.async_rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener] event_listeners=[self.listener, self.session_checker_listener]
) )
self.addAsyncCleanup(self.client.close)
self.db = self.client.pymongo_test self.db = self.client.pymongo_test
self.initial_lsids = {s["id"] for s in session_ids(self.client)} self.initial_lsids = {s["id"] for s in session_ids(self.client)}
async def asyncTearDown(self): async def asyncTearDown(self):
"""All sessions used in the test must be returned to the pool.""" monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
await self.client.drop_database("pymongo_test") await self.client.drop_database("pymongo_test")
used_lsids = self.initial_lsids.copy() used_lsids = self.initial_lsids.copy()
for event in self.session_checker_listener.started_events: for event in self.session_checker_listener.started_events:
@ -122,6 +112,8 @@ class TestSession(AsyncIntegrationTest):
current_lsids = {s["id"] for s in session_ids(self.client)} current_lsids = {s["id"] for s in session_ids(self.client)}
self.assertLessEqual(used_lsids, current_lsids) self.assertLessEqual(used_lsids, current_lsids)
await super().asyncTearDown()
async def _test_ops(self, client, *ops): async def _test_ops(self, client, *ops):
listener = client.options.event_listeners[0] listener = client.options.event_listeners[0]
@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener listener: SessionTestListener
client: AsyncMongoClient client: AsyncMongoClient
@classmethod
async def _setup_class(cls):
cls.listener = SessionTestListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
async def _tearDown_class(cls):
await cls.client.close()
@async_client_context.require_sessions @async_client_context.require_sessions
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
self.listener = SessionTestListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
@async_client_context.require_no_standalone @async_client_context.require_no_standalone
async def test_core(self): async def test_core(self):

View File

@ -26,7 +26,7 @@ sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import ( from test.utils import (
OvertCommandListener, OvertCommandListener,
wait_until, async_wait_until,
) )
from typing import List from typing import List
@ -162,7 +162,7 @@ class TestTransactions(AsyncTransactionsBase):
client = await self.async_rs_client( client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000 async_client_context.mongos_seeds(), localThresholdMS=1000
) )
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test coll = client.test.test
# Create the collection. # Create the collection.
await coll.insert_one({}) await coll.insert_one({})
@ -191,7 +191,7 @@ class TestTransactions(AsyncTransactionsBase):
client = await self.async_rs_client( client = await self.async_rs_client(
async_client_context.mongos_seeds(), localThresholdMS=1000 async_client_context.mongos_seeds(), localThresholdMS=1000
) )
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
coll = client.test.test coll = client.test.test
# Create the collection. # Create the collection.
await coll.insert_one({}) await coll.insert_one({})
@ -403,21 +403,12 @@ class PatchSessionTimeout:
class TestTransactionsConvenientAPI(AsyncTransactionsBase): class TestTransactionsConvenientAPI(AsyncTransactionsBase):
@classmethod async def asyncSetUp(self) -> None:
async def _setup_class(cls): await super().asyncSetUp()
await super()._setup_class() self.mongos_clients = []
cls.mongos_clients = []
if async_client_context.supports_transactions(): if async_client_context.supports_transactions():
for address in async_client_context.mongoses: for address in async_client_context.mongoses:
cls.mongos_clients.append( self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
await cls.unmanaged_async_single_client("{}:{}".format(*address))
)
@classmethod
async def _tearDown_class(cls):
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
async def _set_fail_point(self, client, command_args): async def _set_fail_point(self, client, command_args):
cmd = {"configureFailPoint": "failCommand"} cmd = {"configureFailPoint": "failCommand"}

View File

@ -50,6 +50,7 @@ from test.unified_format_shared import (
) )
from test.utils import ( from test.utils import (
async_get_pool, async_get_pool,
async_wait_until,
camel_to_snake, camel_to_snake,
camel_to_snake_args, camel_to_snake_args,
parse_spec_options, parse_spec_options,
@ -304,7 +305,6 @@ class EntityMapUtil:
kwargs["h"] = uri kwargs["h"] = uri
client = await self.test.async_rs_or_single_client(**kwargs) client = await self.test.async_rs_or_single_client(**kwargs)
self[spec["id"]] = client self[spec["id"]] = client
self.test.addAsyncCleanup(client.close)
return return
elif entity_type == "database": elif entity_type == "database":
client = self[spec["client"]] client = self[spec["client"]]
@ -479,33 +479,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
await db.create_collection(coll_name, write_concern=wc, **opts) await db.create_collection(coll_name, write_concern=wc, **opts)
@classmethod @classmethod
async def _setup_class(cls): def setUpClass(cls) -> None:
# super call creates internal client cls.client
await super()._setup_class()
# process file-level runOnRequirements
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
if not await cls.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if async_client_context.storage_engine == "mmapv1":
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
cls.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
cls.mongos_clients = []
if (
async_client_context.supports_transactions()
and not async_client_context.load_balancer
and not async_client_context.serverless
):
for address in async_client_context.mongoses:
cls.mongos_clients.append(
await cls.unmanaged_async_single_client("{}:{}".format(*address))
)
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs( cls.knobs = client_knobs(
heartbeat_frequency=0.1, heartbeat_frequency=0.1,
@ -516,17 +490,36 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
cls.knobs.enable() cls.knobs.enable()
@classmethod @classmethod
async def _tearDown_class(cls): def tearDownClass(cls) -> None:
cls.knobs.disable() cls.knobs.disable()
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
async def asyncSetUp(self): async def asyncSetUp(self):
# super call creates internal client cls.client
await super().asyncSetUp() await super().asyncSetUp()
# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not await self.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if async_client_context.storage_engine == "mmapv1":
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
self.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
self.mongos_clients = []
if (
async_client_context.supports_transactions()
and not async_client_context.load_balancer
and not async_client_context.serverless
):
for address in async_client_context.mongoses:
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
# process schemaVersion # process schemaVersion
# note: we check major schema version during class generation # note: we check major schema version during class generation
# note: we do this here because we cannot run assertions in setUpClass
version = Version.from_string(self.TEST_SPEC["schemaVersion"]) version = Version.from_string(self.TEST_SPEC["schemaVersion"])
self.assertLessEqual( self.assertLessEqual(
version, version,
@ -1036,7 +1029,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
) )
client = await self.async_single_client("{}:{}".format(*session._pinned_address)) client = await self.async_single_client("{}:{}".format(*session._pinned_address))
self.addAsyncCleanup(client.close)
await self.__set_fail_point(client=client, command_args=spec["failPoint"]) await self.__set_fail_point(client=client, command_args=spec["failPoint"])
async def _testOperation_createEntities(self, spec): async def _testOperation_createEntities(self, spec):
@ -1137,13 +1129,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
client, event, count = spec["client"], spec["event"], spec["count"] client, event, count = spec["client"], spec["event"], spec["count"]
self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}") self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}")
def _testOperation_waitForEvent(self, spec): async def _testOperation_waitForEvent(self, spec):
"""Run the waitForEvent test operation. """Run the waitForEvent test operation.
Wait for a number of events to be published, or fail. Wait for a number of events to be published, or fail.
""" """
client, event, count = spec["client"], spec["event"], spec["count"] client, event, count = spec["client"], spec["event"], spec["count"]
wait_until( await async_wait_until(
lambda: self._event_count(client, event) >= count, lambda: self._event_count(client, event) >= count,
f"find {count} {event} event(s)", f"find {count} {event} event(s)",
) )

View File

@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest):
knobs: client_knobs knobs: client_knobs
listener: EventListener listener: EventListener
@classmethod async def asyncSetUp(self) -> None:
async def _setup_class(cls): await super().asyncSetUp()
await super()._setup_class() self.mongos_clients = []
cls.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable() self.knobs.enable()
@classmethod
async def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
await client.close()
await super()._tearDown_class()
def setUp(self):
super().setUp()
self.targets = {} self.targets = {}
self.listener = None # type: ignore self.listener = None # type: ignore
self.pool_listener = None self.pool_listener = None
self.server_listener = None self.server_listener = None
self.maxDiff = None self.maxDiff = None
async def asyncTearDown(self) -> None:
self.knobs.disable()
async def _set_fail_point(self, client, command_args): async def _set_fail_point(self, client, command_args):
cmd = SON([("configureFailPoint", "failCommand")]) cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args) cmd.update(command_args)
@ -700,8 +692,6 @@ class AsyncSpecRunner(AsyncIntegrationTest):
self.listener = listener self.listener = listener
self.pool_listener = pool_listener self.pool_listener = pool_listener
self.server_listener = server_listener self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addAsyncCleanup(client.close)
# Create session0 and session1. # Create session0 and session1.
sessions = {} sessions = {}

View File

@ -20,7 +20,7 @@ def event_loop_policy():
return asyncio.get_event_loop_policy() return asyncio.get_event_loop_policy()
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="package", autouse=True)
def test_setup_and_teardown(): def test_setup_and_teardown():
setup() setup()
yield yield

View File

@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest):
coll: Collection coll: Collection
coll_w0: Collection coll_w0: Collection
@classmethod
def _setup_class(cls):
super()._setup_class()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.coll = self.db.test
self.coll.drop() self.coll.drop()
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
def assertEqualResponse(self, expected, actual): def assertEqualResponse(self, expected, actual):
"""Compare response from bulk.execute() to expected response.""" """Compare response from bulk.execute() to expected response."""
@ -785,12 +781,8 @@ class TestBulk(BulkTestBase):
class BulkAuthorizationTestBase(BulkTestBase): class BulkAuthorizationTestBase(BulkTestBase):
@classmethod
@client_context.require_auth @client_context.require_auth
@client_context.require_no_api_version @client_context.require_no_api_version
def _setup_class(cls):
super()._setup_class()
def setUp(self): def setUp(self):
super().setUp() super().setUp()
client_context.create_user(self.db.name, "readonly", "pw", ["read"]) client_context.create_user(self.db.name, "readonly", "pw", ["read"])
@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase):
w: Optional[int] w: Optional[int]
secondary: MongoClient secondary: MongoClient
@classmethod def setUp(self):
def _setup_class(cls): super().setUp()
super()._setup_class() self.w = client_context.w
cls.w = client_context.w self.secondary = None
cls.secondary = None if self.w is not None and self.w > 1:
if cls.w is not None and cls.w > 1:
for member in (client_context.hello)["hosts"]: for member in (client_context.hello)["hosts"]:
if member != (client_context.hello)["primary"]: if member != (client_context.hello)["primary"]:
cls.secondary = cls.unmanaged_single_client(*partition_node(member)) self.secondary = self.single_client(*partition_node(member))
break break
@classmethod def tearDown(self):
def async_tearDownClass(cls): if self.secondary:
if cls.secondary: self.secondary.close()
cls.secondary.close()
def cause_wtimeout(self, requests, ordered): def cause_wtimeout(self, requests, ordered):
if not client_context.test_commands_enabled: if not client_context.test_commands_enabled:

View File

@ -820,18 +820,16 @@ class ProseSpecTestsMixin:
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
dbs: list dbs: list
@classmethod
@client_context.require_version_min(4, 0, 0, -1) @client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams @client_context.require_change_streams
def _setup_class(cls): def setUp(self) -> None:
super()._setup_class() super().setUp()
cls.dbs = [cls.db, cls.client.pymongo_test_2] self.dbs = [self.db, self.client.pymongo_test_2]
@classmethod def tearDown(self):
def _tearDown_class(cls): for db in self.dbs:
for db in cls.dbs: self.client.drop_database(db)
cls.client.drop_database(db) super().tearDown()
super()._tearDown_class()
def change_stream_with_client(self, client, *args, **kwargs): def change_stream_with_client(self, client, *args, **kwargs):
return client.watch(*args, **kwargs) return client.watch(*args, **kwargs)
@ -882,11 +880,10 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
@classmethod
@client_context.require_version_min(4, 0, 0, -1) @client_context.require_version_min(4, 0, 0, -1)
@client_context.require_change_streams @client_context.require_change_streams
def _setup_class(cls): def setUp(self) -> None:
super()._setup_class() super().setUp()
def change_stream_with_client(self, client, *args, **kwargs): def change_stream_with_client(self, client, *args, **kwargs):
return client[self.db.name].watch(*args, **kwargs) return client[self.db.name].watch(*args, **kwargs)
@ -968,12 +965,9 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin): class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
@classmethod
@client_context.require_change_streams @client_context.require_change_streams
def _setup_class(cls):
super()._setup_class()
def setUp(self): def setUp(self):
super().setUp()
# Use a new collection for each test. # Use a new collection for each test.
self.watched_collection().drop() self.watched_collection().drop()
self.watched_collection().insert_one({}) self.watched_collection().insert_one({})
@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest):
RUN_ON_LOAD_BALANCER = True RUN_ON_LOAD_BALANCER = True
listener: AllowListEventListener listener: AllowListEventListener
@classmethod
@client_context.require_connection @client_context.require_connection
def _setup_class(cls):
super()._setup_class()
cls.listener = AllowListEventListener("aggregate", "getMore")
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def _tearDown_class(cls):
cls.client.close()
super()._tearDown_class()
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.listener = AllowListEventListener("aggregate", "getMore")
self.client = self.rs_or_single_client(event_listeners=[self.listener])
self.listener.reset() self.listener.reset()
def setUpCluster(self, scenario_dict): def setUpCluster(self, scenario_dict):

View File

@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest):
client: MongoClient client: MongoClient
@classmethod def setUp(self) -> None:
def _setup_class(cls): self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
@classmethod
def _tearDown_class(cls):
cls.client.close()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_fixtures(self, caplog): def inject_fixtures(self, caplog):
@ -1039,14 +1034,21 @@ class TestClient(IntegrationTest):
self.assertFalse(client._topology._opened) self.assertFalse(client._topology._opened)
# Ensure kill cursors thread has not been started. # Ensure kill cursors thread has not been started.
kc_thread = client._kill_cursors_executor._thread if _IS_SYNC:
self.assertFalse(kc_thread and kc_thread.is_alive()) kc_thread = client._kill_cursors_executor._thread
self.assertFalse(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertFalse(kc_task and not kc_task.done())
# Using the client should open topology and start the thread. # Using the client should open topology and start the thread.
client.admin.command("ping") client.admin.command("ping")
self.assertTrue(client._topology._opened) self.assertTrue(client._topology._opened)
kc_thread = client._kill_cursors_executor._thread if _IS_SYNC:
self.assertTrue(kc_thread and kc_thread.is_alive()) kc_thread = client._kill_cursors_executor._thread
self.assertTrue(kc_thread and kc_thread.is_alive())
else:
kc_task = client._kill_cursors_executor._task
self.assertTrue(kc_task and not kc_task.done())
def test_close_does_not_open_servers(self): def test_close_does_not_open_servers(self):
client = self.rs_client(connect=False) client = self.rs_client(connect=False)
@ -1241,6 +1243,7 @@ class TestClient(IntegrationTest):
def test_server_selection_timeout(self): def test_server_selection_timeout(self):
client = MongoClient(serverSelectionTimeoutMS=100, connect=False) client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout) self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
client.close()
client = MongoClient(serverSelectionTimeoutMS=0, connect=False) client = MongoClient(serverSelectionTimeoutMS=0, connect=False)
@ -1251,16 +1254,20 @@ class TestClient(IntegrationTest):
self.assertRaises( self.assertRaises(
ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
) )
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
self.assertAlmostEqual(0.1, client.options.server_selection_timeout) self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
self.assertAlmostEqual(0, client.options.server_selection_timeout) self.assertAlmostEqual(0, client.options.server_selection_timeout)
client.close()
# Test invalid timeout in URI ignored and set to default. # Test invalid timeout in URI ignored and set to default.
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout) self.assertAlmostEqual(30, client.options.server_selection_timeout)
client.close()
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
self.assertAlmostEqual(30, client.options.server_selection_timeout) self.assertAlmostEqual(30, client.options.server_selection_timeout)

View File

@ -97,26 +97,19 @@ class TestCollation(IntegrationTest):
warn_context: Any warn_context: Any
collation: Collation collation: Collation
@classmethod
@client_context.require_connection @client_context.require_connection
def _setup_class(cls): def setUp(self) -> None:
super()._setup_class() super().setUp()
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) self.client = self.rs_or_single_client(event_listeners=[self.listener])
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
cls.collation = Collation("en_US") self.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings() self.warn_context = warnings.catch_warnings()
cls.warn_context.__enter__() self.warn_context.__enter__()
warnings.simplefilter("ignore", DeprecationWarning)
@classmethod def tearDown(self) -> None:
def _tearDown_class(cls): self.warn_context.__exit__()
cls.warn_context.__exit__() self.warn_context = None
cls.warn_context = None
cls.client.close()
super()._tearDown_class()
def tearDown(self):
self.listener.reset() self.listener.reset()
super().tearDown() super().tearDown()

View File

@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest):
db: Database db: Database
client: MongoClient client: MongoClient
@classmethod def setUp(self) -> None:
def _setup_class(cls): super().setUp()
cls.client = MongoClient(connect=False) self.client = self.simple_client(connect=False)
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
@classmethod
def _tearDown_class(cls):
cls.client.close()
def test_collection(self): def test_collection(self):
self.assertRaises(TypeError, Collection, self.db, 5) self.assertRaises(TypeError, Collection, self.db, 5)
@ -164,27 +160,14 @@ class TestCollectionNoConnect(UnitTest):
class TestCollection(IntegrationTest): class TestCollection(IntegrationTest):
w: int w: int
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.w = client_context.w # type: ignore
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
else:
asyncio.run(cls.async_tearDownClass())
@classmethod
def async_tearDownClass(cls):
cls.db.drop_collection("test_large_limit")
def setUp(self): def setUp(self):
self.db.test.drop() super().setUp()
self.w = client_context.w # type: ignore
def tearDown(self): def tearDown(self):
self.db.test.drop() self.db.test.drop()
self.db.drop_collection("test_large_limit")
super().tearDown()
@contextlib.contextmanager @contextlib.contextmanager
def write_concern_collection(self): def write_concern_collection(self):
@ -1010,7 +993,10 @@ class TestCollection(IntegrationTest):
db.test.insert_one({"y": 1}, bypass_document_validation=True) db.test.insert_one({"y": 1}, bypass_document_validation=True)
db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") def predicate():
return db_w0.test.find_one({"x": 1})
wait_until(predicate, "find w:0 replaced document")
def test_update_bypass_document_validation(self): def test_update_bypass_document_validation(self):
db = self.db db = self.db

View File

@ -19,7 +19,12 @@ import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest from test import (
IntegrationTest,
client_context,
reset_client_context,
unittest,
)
from test.helpers import repl_set_step_down from test.helpers import repl_set_step_down
from test.utils import ( from test.utils import (
CMAPListener, CMAPListener,
@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
listener: CMAPListener listener: CMAPListener
coll: Collection coll: Collection
@classmethod
@client_context.require_replica_set @client_context.require_replica_set
def _setup_class(cls): def setUp(self):
super()._setup_class() self.listener = CMAPListener()
cls.listener = CMAPListener() self.client = self.rs_or_single_client(
cls.client = cls.unmanaged_rs_or_single_client( event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
) )
# Ensure connections to all servers in replica set. This is to test # Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for connections that # that the is_writable flag is properly updated for connections that
# survive a replica set election. # survive a replica set election.
ensure_all_connected(cls.client) ensure_all_connected(self.client)
cls.listener.reset() self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
@classmethod
def _tearDown_class(cls):
cls.client.close()
def setUp(self):
# Note that all ops use same write-concern as self.db (majority). # Note that all ops use same write-concern as self.db (majority).
self.db.drop_collection("step-down") self.db.drop_collection("step-down")
self.db.create_collection("step-down") self.db.create_collection("step-down")

View File

@ -56,6 +56,9 @@ class TestCreateEntities(IntegrationTest):
self.assertGreater(len(final_entity_map["events1"]), 0) self.assertGreater(len(final_entity_map["events1"]), 0)
for event in final_entity_map["events1"]: for event in final_entity_map["events1"]:
self.assertIn("PoolCreatedEvent", event["name"]) self.assertIn("PoolCreatedEvent", event["name"])
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
client.close()
def test_store_all_others_as_entities(self): def test_store_all_others_as_entities(self):
self.scenario_runner = UnifiedSpecTestMixinV1() self.scenario_runner = UnifiedSpecTestMixinV1()
@ -122,6 +125,9 @@ class TestCreateEntities(IntegrationTest):
self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["failures"], [])
self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["successes"], 2)
self.assertEqual(entity_map["iterations"], 5) self.assertEqual(entity_map["iterations"], 5)
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
client.close()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1636,10 +1636,6 @@ class TestRawBatchCursor(IntegrationTest):
class TestRawBatchCommandCursor(IntegrationTest): class TestRawBatchCommandCursor(IntegrationTest):
@classmethod
def _setup_class(cls):
super()._setup_class()
def test_aggregate_raw(self): def test_aggregate_raw(self):
c = self.db.test c = self.db.test
c.drop() c.drop()

View File

@ -633,6 +633,7 @@ class TestTypeRegistry(unittest.TestCase):
class TestCollectionWCustomType(IntegrationTest): class TestCollectionWCustomType(IntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.db.test.drop() self.db.test.drop()
def tearDown(self): def tearDown(self):
@ -754,6 +755,7 @@ class TestCollectionWCustomType(IntegrationTest):
class TestGridFileCustomType(IntegrationTest): class TestGridFileCustomType(IntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.db.drop_collection("fs.files") self.db.drop_collection("fs.files")
self.db.drop_collection("fs.chunks") self.db.drop_collection("fs.chunks")
@ -917,11 +919,10 @@ class ChangeStreamsWCustomTypesTestMixin:
class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_change_streams @client_context.require_change_streams
def setUpClass(cls): def setUp(self):
super().setUpClass() super().setUp()
cls.db.test.delete_many({}) self.db.test.delete_many({})
def tearDown(self): def tearDown(self):
self.input_target.drop() self.input_target.drop()
@ -935,12 +936,11 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus
class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0) @client_context.require_version_min(4, 0, 0)
@client_context.require_change_streams @client_context.require_change_streams
def setUpClass(cls): def setUp(self):
super().setUpClass() super().setUp()
cls.db.test.delete_many({}) self.db.test.delete_many({})
def tearDown(self): def tearDown(self):
self.input_target.drop() self.input_target.drop()
@ -954,12 +954,11 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto
class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
@classmethod
@client_context.require_version_min(4, 0, 0) @client_context.require_version_min(4, 0, 0)
@client_context.require_change_streams @client_context.require_change_streams
def setUpClass(cls): def setUp(self):
super().setUpClass() super().setUp()
cls.db.test.delete_many({}) self.db.test.delete_many({})
def tearDown(self): def tearDown(self):
self.input_target.drop() self.input_target.drop()

View File

@ -709,6 +709,7 @@ class TestDatabase(IntegrationTest):
class TestDatabaseAggregation(IntegrationTest): class TestDatabaseAggregation(IntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.pipeline: List[Mapping[str, Any]] = [ self.pipeline: List[Mapping[str, Any]] = [
{"$listLocalSessions": {}}, {"$listLocalSessions": {}},
{"$limit": 1}, {"$limit": 1},

View File

@ -211,11 +211,10 @@ class TestClientOptions(PyMongoTestCase):
class EncryptionIntegrationTest(IntegrationTest): class EncryptionIntegrationTest(IntegrationTest):
"""Base class for encryption integration tests.""" """Base class for encryption integration tests."""
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@client_context.require_version_min(4, 2, -1) @client_context.require_version_min(4, 2, -1)
def _setup_class(cls): def setUp(self) -> None:
super()._setup_class() super().setUp()
def assertEncrypted(self, val): def assertEncrypted(self, val):
self.assertIsInstance(val, Binary) self.assertIsInstance(val, Binary)
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):
class TestClientMaxWireVersion(IntegrationTest): class TestClientMaxWireVersion(IntegrationTest):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def _setup_class(cls): def setUp(self):
super()._setup_class() super().setUp()
@client_context.require_version_max(4, 0, 99) @client_context.require_version_max(4, 0, 99)
def test_raise_max_wire_version_error(self): def test_raise_max_wire_version_error(self):
@ -816,17 +814,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
"local": None, "local": None,
} }
@classmethod
@unittest.skipUnless( @unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set", "No environment credentials are set",
) )
def _setup_class(cls): def setUp(self):
super()._setup_class() super().setUp()
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) self.client = self.rs_or_single_client(event_listeners=[self.listener])
cls.client.db.coll.drop() self.client.db.coll.drop()
cls.vault = create_key_vault(cls.client.keyvault.datakeys) self.vault = create_key_vault(self.client.keyvault.datakeys)
# Configure the encrypted field via the local schema_map option. # Configure the encrypted field via the local schema_map option.
schemas = { schemas = {
@ -844,25 +841,22 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
} }
} }
opts = AutoEncryptionOpts( opts = AutoEncryptionOpts(
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS self.KMS_PROVIDERS,
"keyvault.datakeys",
schema_map=schemas,
kms_tls_options=KMS_TLS_OPTS,
) )
cls.client_encrypted = cls.unmanaged_rs_or_single_client( self.client_encrypted = self.rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard" auto_encryption_opts=opts, uuidRepresentation="standard"
) )
cls.client_encryption = cls.unmanaged_create_client_encryption( self.client_encryption = self.create_client_encryption(
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
) )
@classmethod
def _tearDown_class(cls):
cls.vault.drop()
cls.client.close()
cls.client_encrypted.close()
cls.client_encryption.close()
def setUp(self):
self.listener.reset() self.listener.reset()
def tearDown(self) -> None:
self.vault.drop()
def run_test(self, provider_name): def run_test(self, provider_name):
# Create data key. # Create data key.
master_key: Any = self.MASTER_KEYS[provider_name] master_key: Any = self.MASTER_KEYS[provider_name]
@ -1007,10 +1001,9 @@ class TestViews(EncryptionIntegrationTest):
class TestCorpus(EncryptionIntegrationTest): class TestCorpus(EncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
def _setup_class(cls): def setUp(self):
super()._setup_class() super().setUp()
@staticmethod @staticmethod
def kms_providers(): def kms_providers():
@ -1184,12 +1177,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
client_encrypted: MongoClient client_encrypted: MongoClient
listener: OvertCommandListener listener: OvertCommandListener
@classmethod def setUp(self):
def _setup_class(cls): super().setUp()
super()._setup_class()
db = client_context.client.db db = client_context.client.db
cls.coll = db.coll self.coll = db.coll
cls.coll.drop() self.coll.drop()
# Configure the encrypted 'db.coll' collection via jsonSchema. # Configure the encrypted 'db.coll' collection via jsonSchema.
json_schema = json_data("limits", "limits-schema.json") json_schema = json_data("limits", "limits-schema.json")
db.create_collection( db.create_collection(
@ -1207,17 +1199,14 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
coll.insert_one(json_data("limits", "limits-key.json")) coll.insert_one(json_data("limits", "limits-key.json"))
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client_encrypted = cls.unmanaged_rs_or_single_client( self.client_encrypted = self.rs_or_single_client(
auto_encryption_opts=opts, event_listeners=[cls.listener] auto_encryption_opts=opts, event_listeners=[self.listener]
) )
cls.coll_encrypted = cls.client_encrypted.db.coll self.coll_encrypted = self.client_encrypted.db.coll
@classmethod def tearDown(self) -> None:
def _tearDown_class(cls): self.coll_encrypted.drop()
cls.coll_encrypted.drop()
cls.client_encrypted.close()
super()._tearDown_class()
def test_01_insert_succeeds_under_2MiB(self): def test_01_insert_succeeds_under_2MiB(self):
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
@ -1241,7 +1230,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB} doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
self.listener.reset() self.listener.reset()
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
def test_04_bulk_batch_split(self): def test_04_bulk_batch_split(self):
limits_doc = json_data("limits", "limits-doc.json") limits_doc = json_data("limits", "limits-doc.json")
@ -1251,7 +1242,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
doc2.update(limits_doc) doc2.update(limits_doc)
self.listener.reset() self.listener.reset()
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) self.assertEqual(
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
)
def test_05_insert_succeeds_just_under_16MiB(self): def test_05_insert_succeeds_just_under_16MiB(self):
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)} doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
@ -1281,15 +1274,12 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
class TestCustomEndpoint(EncryptionIntegrationTest): class TestCustomEndpoint(EncryptionIntegrationTest):
"""Prose tests for creating data keys with a custom endpoint.""" """Prose tests for creating data keys with a custom endpoint."""
@classmethod
@unittest.skipUnless( @unittest.skipUnless(
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
"No environment credentials are set", "No environment credentials are set",
) )
def _setup_class(cls):
super()._setup_class()
def setUp(self): def setUp(self):
super().setUp()
kms_providers = { kms_providers = {
"aws": AWS_CREDS, "aws": AWS_CREDS,
"azure": AZURE_CREDS, "azure": AZURE_CREDS,
@ -1318,10 +1308,6 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
self._kmip_host_error = None self._kmip_host_error = None
self._invalid_host_error = None self._invalid_host_error = None
def tearDown(self):
self.client_encryption.close()
self.client_encryption_invalid.close()
def run_test_expected_success(self, provider_name, master_key): def run_test_expected_success(self, provider_name, master_key):
data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key) data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key)
encrypted = self.client_encryption.encrypt( encrypted = self.client_encryption.encrypt(
@ -1494,18 +1480,18 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
KEYVAULT_COLL = "datakeys" KEYVAULT_COLL = "datakeys"
client: MongoClient client: MongoClient
def setUp(self): def _setup(self):
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
create_key_vault(keyvault, self.DEK) create_key_vault(keyvault, self.DEK)
def _test_explicit(self, expectation): def _test_explicit(self, expectation):
self._setup()
client_encryption = self.create_client_encryption( client_encryption = self.create_client_encryption(
self.KMS_PROVIDER_MAP, # type: ignore[arg-type] self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
client_context.client, client_context.client,
OPTS, OPTS,
) )
self.addCleanup(client_encryption.close)
ciphertext = client_encryption.encrypt( ciphertext = client_encryption.encrypt(
"string0", "string0",
@ -1517,6 +1503,7 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
self.assertEqual(client_encryption.decrypt(ciphertext), "string0") self.assertEqual(client_encryption.decrypt(ciphertext), "string0")
def _test_automatic(self, expectation_extjson, payload): def _test_automatic(self, expectation_extjson, payload):
self._setup()
encrypted_db = "db" encrypted_db = "db"
encrypted_coll = "coll" encrypted_coll = "coll"
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
@ -1531,7 +1518,6 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
client = self.rs_or_single_client( client = self.rs_or_single_client(
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
) )
self.addCleanup(client.close)
coll = client.get_database(encrypted_db).get_collection( coll = client.get_database(encrypted_db).get_collection(
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority") encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
@ -1553,13 +1539,12 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
def _setup_class(cls): def setUp(self):
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
cls.DEK = json_data(BASE, "custom", "azure-dek.json") self.DEK = json_data(BASE, "custom", "azure-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
super()._setup_class() super().setUp()
def test_explicit(self): def test_explicit(self):
return self._test_explicit( return self._test_explicit(
@ -1579,13 +1564,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest
class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
@classmethod
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
def _setup_class(cls): def setUp(self):
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
cls.DEK = json_data(BASE, "custom", "gcp-dek.json") self.DEK = json_data(BASE, "custom", "gcp-dek.json")
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
super()._setup_class() super().setUp()
def test_explicit(self): def test_explicit(self):
return self._test_explicit( return self._test_explicit(
@ -1607,6 +1591,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
class TestDeadlockProse(EncryptionIntegrationTest): class TestDeadlockProse(EncryptionIntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.client_test = self.rs_or_single_client( self.client_test = self.rs_or_single_client(
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
) )
@ -1637,7 +1622,6 @@ class TestDeadlockProse(EncryptionIntegrationTest):
self.ciphertext = client_encryption.encrypt( self.ciphertext = client_encryption.encrypt(
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local" "string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
) )
client_encryption.close()
self.client_listener = OvertCommandListener() self.client_listener = OvertCommandListener()
self.topology_listener = TopologyEventListener() self.topology_listener = TopologyEventListener()
@ -1832,6 +1816,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
class TestDecryptProse(EncryptionIntegrationTest): class TestDecryptProse(EncryptionIntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.client = client_context.client self.client = client_context.client
self.client.db.drop_collection("decryption_events") self.client.db.drop_collection("decryption_events")
create_key_vault(self.client.keyvault.datakeys) create_key_vault(self.client.keyvault.datakeys)
@ -2267,6 +2252,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames # https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest): class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.client = client_context.client self.client = client_context.client
create_key_vault(self.client.keyvault.datakeys) create_key_vault(self.client.keyvault.datakeys)
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
@ -2608,8 +2594,6 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
assert isinstance(res["encrypted_indexed"], Binary) assert isinstance(res["encrypted_indexed"], Binary)
assert isinstance(res["encrypted_unindexed"], Binary) assert isinstance(res["encrypted_unindexed"], Binary)
client_encryption.close()
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
class TestRangeQueryProse(EncryptionIntegrationTest): class TestRangeQueryProse(EncryptionIntegrationTest):
@ -3071,17 +3055,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest):
mongocryptd_client: MongoClient mongocryptd_client: MongoClient
MONGOCRYPTD_PORT = 27020 MONGOCRYPTD_PORT = 27020
@classmethod
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
def _setup_class(cls):
super()._setup_class()
start_mongocryptd(cls.MONGOCRYPTD_PORT)
@classmethod
def _tearDown_class(cls):
super()._tearDown_class()
def setUp(self) -> None: def setUp(self) -> None:
super().setUp()
start_mongocryptd(self.MONGOCRYPTD_PORT)
self.listener = OvertCommandListener() self.listener = OvertCommandListener()
self.mongocryptd_client = self.simple_client( self.mongocryptd_client = self.simple_client(
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]

View File

@ -33,19 +33,14 @@ from pymongo.write_concern import WriteConcern
class TestSampleShellCommands(IntegrationTest): class TestSampleShellCommands(IntegrationTest):
@classmethod def setUp(self):
def setUpClass(cls): super().setUp()
super().setUpClass() self.db.inventory.drop()
# Run once before any tests run.
cls.db.inventory.drop()
@classmethod
def tearDownClass(cls):
cls.client.drop_database("pymongo_test")
def tearDown(self): def tearDown(self):
# Run after every test. # Run after every test.
self.db.inventory.drop() self.db.inventory.drop()
self.client.drop_database("pymongo_test")
def test_first_three_examples(self): def test_first_three_examples(self):
db = self.db db = self.db

View File

@ -97,6 +97,7 @@ class TestGridFileNoConnect(UnitTest):
class TestGridFile(IntegrationTest): class TestGridFile(IntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
def test_basic(self): def test_basic(self):

View File

@ -75,9 +75,9 @@ class JustRead(threading.Thread):
class TestGridfsNoConnect(unittest.TestCase): class TestGridfsNoConnect(unittest.TestCase):
db: Database db: Database
@classmethod def setUp(self):
def setUpClass(cls): super().setUp()
cls.db = MongoClient(connect=False).pymongo_test self.db = MongoClient(connect=False).pymongo_test
def test_gridfs(self): def test_gridfs(self):
self.assertRaises(TypeError, gridfs.GridFS, "foo") self.assertRaises(TypeError, gridfs.GridFS, "foo")
@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest):
fs: gridfs.GridFS fs: gridfs.GridFS
alt: gridfs.GridFS alt: gridfs.GridFS
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.fs = gridfs.GridFS(cls.db)
cls.alt = gridfs.GridFS(cls.db, "alt")
def setUp(self): def setUp(self):
super().setUp()
self.fs = gridfs.GridFS(self.db)
self.alt = gridfs.GridFS(self.db, "alt")
self.cleanup_colls( self.cleanup_colls(
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
) )
@ -509,10 +506,9 @@ class TestGridfs(IntegrationTest):
class TestGridfsReplicaSet(IntegrationTest): class TestGridfsReplicaSet(IntegrationTest):
@classmethod
@client_context.require_secondaries_count(1) @client_context.require_secondaries_count(1)
def setUpClass(cls): def setUp(self):
super().setUpClass() super().setUp()
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):

View File

@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest):
fs: gridfs.GridFSBucket fs: gridfs.GridFSBucket
alt: gridfs.GridFSBucket alt: gridfs.GridFSBucket
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.fs = gridfs.GridFSBucket(cls.db)
cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt")
def setUp(self): def setUp(self):
super().setUp()
self.fs = gridfs.GridFSBucket(self.db)
self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt")
self.cleanup_colls( self.cleanup_colls(
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
) )
@ -479,10 +476,9 @@ class TestGridfs(IntegrationTest):
class TestGridfsBucketReplicaSet(IntegrationTest): class TestGridfsBucketReplicaSet(IntegrationTest):
@classmethod
@client_context.require_secondaries_count(1) @client_context.require_secondaries_count(1)
def setUpClass(cls): def setUp(self):
super().setUpClass() super().setUp()
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):

View File

@ -29,7 +29,7 @@ from test.utils import (
wait_until, wait_until,
) )
from pymongo.synchronous.periodic_executor import _EXECUTORS from pymongo.periodic_executor import _EXECUTORS
def unregistered(ref): def unregistered(ref):

View File

@ -52,22 +52,14 @@ class TestCommandMonitoring(IntegrationTest):
listener: EventListener listener: EventListener
@classmethod @classmethod
@client_context.require_connection def setUpClass(cls) -> None:
def _setup_class(cls):
super()._setup_class()
cls.listener = OvertCommandListener() cls.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(
event_listeners=[cls.listener], retryWrites=False
)
@classmethod @client_context.require_connection
def _tearDown_class(cls): def setUp(self) -> None:
cls.client.close() super().setUp()
super()._tearDown_class()
def tearDown(self):
self.listener.reset() self.listener.reset()
super().tearDown() self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False)
def test_started_simple(self): def test_started_simple(self):
self.client.pymongo_test.command("ping") self.client.pymongo_test.command("ping")
@ -1140,26 +1132,23 @@ class TestGlobalListener(IntegrationTest):
saved_listeners: Any saved_listeners: Any
@classmethod @classmethod
@client_context.require_connection def setUpClass(cls) -> None:
def _setup_class(cls):
super()._setup_class()
cls.listener = OvertCommandListener() cls.listener = OvertCommandListener()
# We plan to call register(), which internally modifies _LISTENERS. # We plan to call register(), which internally modifies _LISTENERS.
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
monitoring.register(cls.listener) monitoring.register(cls.listener)
cls.client = cls.unmanaged_single_client()
# Get one (authenticated) socket in the pool.
cls.client.pymongo_test.command("ping")
@classmethod
def _tearDown_class(cls):
monitoring._LISTENERS = cls.saved_listeners
cls.client.close()
super()._tearDown_class()
@client_context.require_connection
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.listener.reset() self.listener.reset()
self.client = self.single_client()
# Get one (authenticated) socket in the pool.
self.client.pymongo_test.command("ping")
@classmethod
def tearDownClass(cls):
monitoring._LISTENERS = cls.saved_listeners
def test_simple(self): def test_simple(self):
self.client.pymongo_test.command("ping") self.client.pymongo_test.command("ping")

View File

@ -31,24 +31,16 @@ from pymongo.read_concern import ReadConcern
class TestReadConcern(IntegrationTest): class TestReadConcern(IntegrationTest):
listener: OvertCommandListener listener: OvertCommandListener
@classmethod
@client_context.require_connection @client_context.require_connection
def setUpClass(cls): def setUp(self):
super().setUpClass() super().setUp()
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) self.client = self.rs_or_single_client(event_listeners=[self.listener])
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
client_context.client.pymongo_test.create_collection("coll") client_context.client.pymongo_test.create_collection("coll")
@classmethod
def tearDownClass(cls):
cls.client.close()
client_context.client.pymongo_test.drop_collection("coll")
super().tearDownClass()
def tearDown(self): def tearDown(self):
self.listener.reset() client_context.client.pymongo_test.drop_collection("coll")
super().tearDown()
def test_read_concern(self): def test_read_concern(self):
rc = ReadConcern() rc = ReadConcern()

View File

@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(IntegrationTest):
RUN_ON_SERVERLESS = True RUN_ON_SERVERLESS = True
deprecation_filter: DeprecationFilter deprecation_filter: DeprecationFilter
@classmethod def setUp(self) -> None:
def _setup_class(cls): super().setUp()
super()._setup_class() self.deprecation_filter = DeprecationFilter()
cls.deprecation_filter = DeprecationFilter()
@classmethod def tearDown(self) -> None:
def _tearDown_class(cls): self.deprecation_filter.stop()
cls.deprecation_filter.stop()
super()._tearDown_class()
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
knobs: client_knobs knobs: client_knobs
@classmethod def setUp(self) -> None:
def _setup_class(cls): super().setUp()
super()._setup_class()
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable() self.knobs.enable()
cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) self.client = self.rs_or_single_client(retryWrites=True)
cls.db = cls.client.pymongo_test self.db = self.client.pymongo_test
@classmethod def tearDown(self) -> None:
def _tearDown_class(cls): self.knobs.disable()
cls.knobs.disable()
cls.client.close()
super()._tearDown_class()
@client_context.require_no_standalone @client_context.require_no_standalone
def test_actionable_error_message(self): def test_actionable_error_message(self):
@ -180,26 +173,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
listener: OvertCommandListener listener: OvertCommandListener
knobs: client_knobs knobs: client_knobs
@classmethod
@client_context.require_no_mmap @client_context.require_no_mmap
def _setup_class(cls): def setUp(self) -> None:
super()._setup_class() super().setUp()
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable() self.knobs.enable()
cls.listener = OvertCommandListener() self.listener = OvertCommandListener()
cls.client = cls.unmanaged_rs_or_single_client( self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener])
retryWrites=True, event_listeners=[cls.listener] self.db = self.client.pymongo_test
)
cls.db = cls.client.pymongo_test
@classmethod
def _tearDown_class(cls):
cls.knobs.disable()
cls.client.close()
super()._tearDown_class()
def setUp(self):
if client_context.is_rs and client_context.test_commands_enabled: if client_context.is_rs and client_context.test_commands_enabled:
self.client.admin.command( self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
@ -210,6 +193,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
self.client.admin.command( self.client.admin.command(
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
) )
self.knobs.disable()
def test_supported_single_statement_no_retry(self): def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener() listener = OvertCommandListener()
@ -438,13 +422,12 @@ class TestWriteConcernError(IntegrationTest):
RUN_ON_SERVERLESS = True RUN_ON_SERVERLESS = True
fail_insert: dict fail_insert: dict
@classmethod
@client_context.require_replica_set @client_context.require_replica_set
@client_context.require_no_mmap @client_context.require_no_mmap
@client_context.require_failCommand_fail_point @client_context.require_failCommand_fail_point
def _setup_class(cls): def setUp(self) -> None:
super()._setup_class() super().setUp()
cls.fail_insert = { self.fail_insert = {
"configureFailPoint": "failCommand", "configureFailPoint": "failCommand",
"mode": {"times": 2}, "mode": {"times": 2},
"data": { "data": {

View File

@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest):
@classmethod @classmethod
@client_context.require_failCommand_fail_point @client_context.require_failCommand_fail_point
def setUpClass(cls): def setUpClass(cls):
super().setUpClass() super().setUp(cls)
# Speed up the tests by decreasing the event publish frequency. # Speed up the tests by decreasing the event publish frequency.
cls.knobs = client_knobs( cls.knobs = client_knobs(
events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1 events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1

View File

@ -82,36 +82,27 @@ class TestSession(IntegrationTest):
client2: MongoClient client2: MongoClient
sensitive_commands: Set[str] sensitive_commands: Set[str]
@classmethod
@client_context.require_sessions @client_context.require_sessions
def _setup_class(cls): def setUp(self):
super()._setup_class() super().setUp()
# Create a second client so we can make sure clients cannot share # Create a second client so we can make sure clients cannot share
# sessions. # sessions.
cls.client2 = cls.unmanaged_rs_or_single_client() self.client2 = self.rs_or_single_client()
# Redact no commands, so we can test user-admin commands have "lsid". # Redact no commands, so we can test user-admin commands have "lsid".
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
monitoring._SENSITIVE_COMMANDS.clear() monitoring._SENSITIVE_COMMANDS.clear()
@classmethod
def _tearDown_class(cls):
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
cls.client2.close()
super()._tearDown_class()
def setUp(self):
self.listener = SessionTestListener() self.listener = SessionTestListener()
self.session_checker_listener = SessionTestListener() self.session_checker_listener = SessionTestListener()
self.client = self.rs_or_single_client( self.client = self.rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener] event_listeners=[self.listener, self.session_checker_listener]
) )
self.addCleanup(self.client.close)
self.db = self.client.pymongo_test self.db = self.client.pymongo_test
self.initial_lsids = {s["id"] for s in session_ids(self.client)} self.initial_lsids = {s["id"] for s in session_ids(self.client)}
def tearDown(self): def tearDown(self):
"""All sessions used in the test must be returned to the pool.""" monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
self.client.drop_database("pymongo_test") self.client.drop_database("pymongo_test")
used_lsids = self.initial_lsids.copy() used_lsids = self.initial_lsids.copy()
for event in self.session_checker_listener.started_events: for event in self.session_checker_listener.started_events:
@ -121,6 +112,8 @@ class TestSession(IntegrationTest):
current_lsids = {s["id"] for s in session_ids(self.client)} current_lsids = {s["id"] for s in session_ids(self.client)}
self.assertLessEqual(used_lsids, current_lsids) self.assertLessEqual(used_lsids, current_lsids)
super().tearDown()
def _test_ops(self, client, *ops): def _test_ops(self, client, *ops):
listener = client.options.event_listeners[0] listener = client.options.event_listeners[0]
@ -832,18 +825,11 @@ class TestCausalConsistency(UnitTest):
listener: SessionTestListener listener: SessionTestListener
client: MongoClient client: MongoClient
@classmethod
def _setup_class(cls):
cls.listener = SessionTestListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def _tearDown_class(cls):
cls.client.close()
@client_context.require_sessions @client_context.require_sessions
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.listener = SessionTestListener()
self.client = self.rs_or_single_client(event_listeners=[self.listener])
@client_context.require_no_standalone @client_context.require_no_standalone
def test_core(self): def test_core(self):

View File

@ -105,6 +105,7 @@ class Update(threading.Thread):
class TestThreads(IntegrationTest): class TestThreads(IntegrationTest):
def setUp(self): def setUp(self):
super().setUp()
self.db = self.client.pymongo_test self.db = self.client.pymongo_test
def test_threading(self): def test_threading(self):

View File

@ -395,19 +395,12 @@ class PatchSessionTimeout:
class TestTransactionsConvenientAPI(TransactionsBase): class TestTransactionsConvenientAPI(TransactionsBase):
@classmethod def setUp(self) -> None:
def _setup_class(cls): super().setUp()
super()._setup_class() self.mongos_clients = []
cls.mongos_clients = []
if client_context.supports_transactions(): if client_context.supports_transactions():
for address in client_context.mongoses: for address in client_context.mongoses:
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
@classmethod
def _tearDown_class(cls):
for client in cls.mongos_clients:
client.close()
super()._tearDown_class()
def _set_fail_point(self, client, command_args): def _set_fail_point(self, client, command_args):
cmd = {"configureFailPoint": "failCommand"} cmd = {"configureFailPoint": "failCommand"}

View File

@ -114,10 +114,9 @@ class TestMypyFails(unittest.TestCase):
class TestPymongo(IntegrationTest): class TestPymongo(IntegrationTest):
coll: Collection coll: Collection
@classmethod def setUp(self):
def setUpClass(cls): super().setUp()
super().setUpClass() self.coll = self.client.test.test
cls.coll = cls.client.test.test
def test_insert_find(self) -> None: def test_insert_find(self) -> None:
doc = {"my": "doc"} doc = {"my": "doc"}

View File

@ -304,7 +304,6 @@ class EntityMapUtil:
kwargs["h"] = uri kwargs["h"] = uri
client = self.test.rs_or_single_client(**kwargs) client = self.test.rs_or_single_client(**kwargs)
self[spec["id"]] = client self[spec["id"]] = client
self.test.addCleanup(client.close)
return return
elif entity_type == "database": elif entity_type == "database":
client = self[spec["client"]] client = self[spec["client"]]
@ -479,31 +478,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
db.create_collection(coll_name, write_concern=wc, **opts) db.create_collection(coll_name, write_concern=wc, **opts)
@classmethod @classmethod
def _setup_class(cls): def setUpClass(cls) -> None:
# super call creates internal client cls.client
super()._setup_class()
# process file-level runOnRequirements
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
if not cls.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if client_context.storage_engine == "mmapv1":
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
cls.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
cls.mongos_clients = []
if (
client_context.supports_transactions()
and not client_context.load_balancer
and not client_context.serverless
):
for address in client_context.mongoses:
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs( cls.knobs = client_knobs(
heartbeat_frequency=0.1, heartbeat_frequency=0.1,
@ -514,17 +489,36 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
cls.knobs.enable() cls.knobs.enable()
@classmethod @classmethod
def _tearDown_class(cls): def tearDownClass(cls) -> None:
cls.knobs.disable() cls.knobs.disable()
for client in cls.mongos_clients:
client.close()
super()._tearDown_class()
def setUp(self): def setUp(self):
# super call creates internal client cls.client
super().setUp() super().setUp()
# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not self.should_run_on(run_on_spec):
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
# add any special-casing for skipping tests here
if client_context.storage_engine == "mmapv1":
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
self.TEST_PATH
):
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
# Handle mongos_clients for transactions tests.
self.mongos_clients = []
if (
client_context.supports_transactions()
and not client_context.load_balancer
and not client_context.serverless
):
for address in client_context.mongoses:
self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
# process schemaVersion # process schemaVersion
# note: we check major schema version during class generation # note: we check major schema version during class generation
# note: we do this here because we cannot run assertions in setUpClass
version = Version.from_string(self.TEST_SPEC["schemaVersion"]) version = Version.from_string(self.TEST_SPEC["schemaVersion"])
self.assertLessEqual( self.assertLessEqual(
version, version,
@ -1026,7 +1020,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
) )
client = self.single_client("{}:{}".format(*session._pinned_address)) client = self.single_client("{}:{}".format(*session._pinned_address))
self.addCleanup(client.close)
self.__set_fail_point(client=client, command_args=spec["failPoint"]) self.__set_fail_point(client=client, command_args=spec["failPoint"])
def _testOperation_createEntities(self, spec): def _testOperation_createEntities(self, spec):

View File

@ -99,6 +99,12 @@ class BaseListener:
"""Wait for a number of events to be published, or fail.""" """Wait for a number of events to be published, or fail."""
wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)")
async def async_wait_for_event(self, event, count):
"""Wait for a number of events to be published, or fail."""
await async_wait_until(
lambda: self.event_count(event) >= count, f"find {count} {event} event(s)"
)
class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): class CMAPListener(BaseListener, monitoring.ConnectionPoolListener):
def connection_created(self, event): def connection_created(self, event):
@ -644,7 +650,10 @@ async def async_wait_until(predicate, success_description, timeout=10):
start = time.time() start = time.time()
interval = min(float(timeout) / 100, 0.1) interval = min(float(timeout) / 100, 0.1)
while True: while True:
retval = await predicate() if iscoroutinefunction(predicate):
retval = await predicate()
else:
retval = predicate()
if retval: if retval:
return retval return retval

View File

@ -249,30 +249,22 @@ class SpecRunner(IntegrationTest):
knobs: client_knobs knobs: client_knobs
listener: EventListener listener: EventListener
@classmethod def setUp(self) -> None:
def _setup_class(cls): super().setUp()
super()._setup_class() self.mongos_clients = []
cls.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency. # Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
cls.knobs.enable() self.knobs.enable()
@classmethod
def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
client.close()
super()._tearDown_class()
def setUp(self):
super().setUp()
self.targets = {} self.targets = {}
self.listener = None # type: ignore self.listener = None # type: ignore
self.pool_listener = None self.pool_listener = None
self.server_listener = None self.server_listener = None
self.maxDiff = None self.maxDiff = None
def tearDown(self) -> None:
self.knobs.disable()
def _set_fail_point(self, client, command_args): def _set_fail_point(self, client, command_args):
cmd = SON([("configureFailPoint", "failCommand")]) cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args) cmd.update(command_args)
@ -697,8 +689,6 @@ class SpecRunner(IntegrationTest):
self.listener = listener self.listener = listener
self.pool_listener = pool_listener self.pool_listener = pool_listener
self.server_listener = server_listener self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addCleanup(client.close)
# Create session0 and session1. # Create session0 and session1.
sessions = {} sessions = {}

View File

@ -110,6 +110,13 @@ replacements = {
"async_set_fail_point": "set_fail_point", "async_set_fail_point": "set_fail_point",
"async_ensure_all_connected": "ensure_all_connected", "async_ensure_all_connected": "ensure_all_connected",
"async_repl_set_step_down": "repl_set_step_down", "async_repl_set_step_down": "repl_set_step_down",
"AsyncPeriodicExecutor": "PeriodicExecutor",
"async_wait_for_event": "wait_for_event",
"pymongo_server_monitor_task": "pymongo_server_monitor_thread",
"pymongo_server_rtt_task": "pymongo_server_rtt_thread",
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
} }
docstring_replacements: dict[tuple[str, str], str] = { docstring_replacements: dict[tuple[str, str], str] = {
@ -130,8 +137,6 @@ docstring_removals: set[str] = {
".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release." ".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release."
} }
type_replacements = {"_Condition": "threading.Condition"}
import_replacements = {"test.synchronous": "test"} import_replacements = {"test.synchronous": "test"}
_pymongo_base = "./pymongo/asynchronous/" _pymongo_base = "./pymongo/asynchronous/"
@ -234,8 +239,6 @@ def process_files(files: list[str]) -> None:
lines = translate_async_sleeps(lines) lines = translate_async_sleeps(lines)
if file in docstring_translate_files: if file in docstring_translate_files:
lines = translate_docstrings(lines) lines = translate_docstrings(lines)
translate_locks(lines)
translate_types(lines)
if file in sync_test_files: if file in sync_test_files:
translate_imports(lines) translate_imports(lines)
f.seek(0) f.seek(0)
@ -269,34 +272,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]:
return lines return lines
def translate_locks(lines: list[str]) -> list[str]:
lock_lines = [line for line in lines if "_Lock(" in line]
cond_lines = [line for line in lines if "_Condition(" in line]
for line in lock_lines:
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
if res:
old = res[0]
index = lines.index(line)
lines[index] = line.replace(old, res[1])
for line in cond_lines:
res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line)
if res:
old = res[0]
index = lines.index(line)
lines[index] = line.replace(old, res[1])
return lines
def translate_types(lines: list[str]) -> list[str]:
for k, v in type_replacements.items():
matches = [line for line in lines if k in line and "import" not in line]
for line in matches:
index = lines.index(line)
lines[index] = line.replace(k, v)
return lines
def translate_imports(lines: list[str]) -> list[str]: def translate_imports(lines: list[str]) -> list[str]:
for k, v in import_replacements.items(): for k, v in import_replacements.items():
matches = [line for line in lines if k in line and "import" in line] matches = [line for line in lines if k in line and "import" in line]