Async client uses tasks instead of threads
PYTHON-4725 - Async client should use tasks for SDAM instead of threads PYTHON-4860 - Async client should use asyncio.Lock and asyncio.Condition PYTHON-4941 - Synchronous unified test runner being used in asynchronous tests PYTHON-4843 - Async test suite should use a single event loop PYTHON-4945 - Fix test cleanups for mongoses Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com>
This commit is contained in:
parent
9b5c0981d9
commit
0e8d70457f
@ -281,7 +281,7 @@ functions:
|
||||
"run tests":
|
||||
- command: subprocess.exec
|
||||
params:
|
||||
include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
|
||||
include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
|
||||
binary: bash
|
||||
working_dir: "src"
|
||||
args:
|
||||
|
||||
@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
|
||||
2) License Notice for _asyncio_lock.py
|
||||
-----------------------------------------
|
||||
|
||||
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
||||
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
||||
otherwise using this software ("Python") in source or binary form and
|
||||
its associated documentation.
|
||||
|
||||
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
||||
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
||||
analyze, test, perform and/or display publicly, prepare derivative works,
|
||||
distribute, and otherwise use Python alone or in any derivative version,
|
||||
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
||||
i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved"
|
||||
are retained in Python alone or in any derivative version prepared by Licensee.
|
||||
|
||||
3. In the event Licensee prepares a derivative work that is based on
|
||||
or incorporates Python or any part thereof, and wants to make
|
||||
the derivative work available to others as provided herein, then
|
||||
Licensee hereby agrees to include in any such work a brief summary of
|
||||
the changes made to Python.
|
||||
|
||||
4. PSF is making Python available to Licensee on an "AS IS"
|
||||
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
||||
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
||||
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
||||
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
||||
INFRINGE ANY THIRD PARTY RIGHTS.
|
||||
|
||||
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
||||
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
||||
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
||||
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
||||
|
||||
6. This License Agreement will automatically terminate upon a material
|
||||
breach of its terms and conditions.
|
||||
|
||||
7. Nothing in this License Agreement shall be deemed to create any
|
||||
relationship of agency, partnership, or joint venture between PSF and
|
||||
Licensee. This License Agreement does not grant permission to use PSF
|
||||
trademarks or trade name in a trademark sense to endorse or promote
|
||||
products or services of Licensee, or any third party.
|
||||
|
||||
8. By copying, installing or otherwise using Python, Licensee
|
||||
agrees to be bound by the terms and conditions of this License
|
||||
Agreement.
|
||||
|
||||
Permission to use, copy, modify, and/or distribute this software for any
|
||||
purpose with or without fee is hereby granted.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
|
||||
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
|
||||
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
|
||||
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
|
||||
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
|
||||
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
||||
PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
309
pymongo/_asyncio_lock.py
Normal file
309
pymongo/_asyncio_lock.py
Normal file
@ -0,0 +1,309 @@
|
||||
# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
|
||||
|
||||
"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
|
||||
to port 3.13 fixes to older versions of Python.
|
||||
Can be removed once we drop Python 3.12 support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import threading
|
||||
from asyncio import events, exceptions
|
||||
from typing import Any, Coroutine, Optional
|
||||
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
class _LoopBoundMixin:
|
||||
_loop = None
|
||||
|
||||
def _get_loop(self) -> Any:
|
||||
loop = events._get_running_loop()
|
||||
|
||||
if self._loop is None:
|
||||
with _global_lock:
|
||||
if self._loop is None:
|
||||
self._loop = loop
|
||||
if loop is not self._loop:
|
||||
raise RuntimeError(f"{self!r} is bound to a different event loop")
|
||||
return loop
|
||||
|
||||
|
||||
class _ContextManagerMixin:
|
||||
async def __aenter__(self) -> None:
|
||||
await self.acquire() # type: ignore[attr-defined]
|
||||
# We have no use for the "as ..." clause in the with
|
||||
# statement for locks.
|
||||
return
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class Lock(_ContextManagerMixin, _LoopBoundMixin):
|
||||
"""Primitive lock objects.
|
||||
|
||||
A primitive lock is a synchronization primitive that is not owned
|
||||
by a particular task when locked. A primitive lock is in one
|
||||
of two states, 'locked' or 'unlocked'.
|
||||
|
||||
It is created in the unlocked state. It has two basic methods,
|
||||
acquire() and release(). When the state is unlocked, acquire()
|
||||
changes the state to locked and returns immediately. When the
|
||||
state is locked, acquire() blocks until a call to release() in
|
||||
another task changes it to unlocked, then the acquire() call
|
||||
resets it to locked and returns. The release() method should only
|
||||
be called in the locked state; it changes the state to unlocked
|
||||
and returns immediately. If an attempt is made to release an
|
||||
unlocked lock, a RuntimeError will be raised.
|
||||
|
||||
When more than one task is blocked in acquire() waiting for
|
||||
the state to turn to unlocked, only one task proceeds when a
|
||||
release() call resets the state to unlocked; successive release()
|
||||
calls will unblock tasks in FIFO order.
|
||||
|
||||
Locks also support the asynchronous context management protocol.
|
||||
'async with lock' statement should be used.
|
||||
|
||||
Usage:
|
||||
|
||||
lock = Lock()
|
||||
...
|
||||
await lock.acquire()
|
||||
try:
|
||||
...
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
Context manager usage:
|
||||
|
||||
lock = Lock()
|
||||
...
|
||||
async with lock:
|
||||
...
|
||||
|
||||
Lock objects can be tested for locking state:
|
||||
|
||||
if not lock.locked():
|
||||
await lock.acquire()
|
||||
else:
|
||||
# lock is acquired
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._waiters: Optional[collections.deque] = None
|
||||
self._locked = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
res = super().__repr__()
|
||||
extra = "locked" if self._locked else "unlocked"
|
||||
if self._waiters:
|
||||
extra = f"{extra}, waiters:{len(self._waiters)}"
|
||||
return f"<{res[1:-1]} [{extra}]>"
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Return True if lock is acquired."""
|
||||
return self._locked
|
||||
|
||||
async def acquire(self) -> bool:
|
||||
"""Acquire a lock.
|
||||
|
||||
This method blocks until the lock is unlocked, then sets it to
|
||||
locked and returns True.
|
||||
"""
|
||||
# Implement fair scheduling, where thread always waits
|
||||
# its turn. Jumping the queue if all are cancelled is an optimization.
|
||||
if not self._locked and (
|
||||
self._waiters is None or all(w.cancelled() for w in self._waiters)
|
||||
):
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
if self._waiters is None:
|
||||
self._waiters = collections.deque()
|
||||
fut = self._get_loop().create_future()
|
||||
self._waiters.append(fut)
|
||||
|
||||
try:
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
self._waiters.remove(fut)
|
||||
except exceptions.CancelledError:
|
||||
# Currently the only exception designed be able to occur here.
|
||||
|
||||
# Ensure the lock invariant: If lock is not claimed (or about
|
||||
# to be claimed by us) and there is a Task in waiters,
|
||||
# ensure that the Task at the head will run.
|
||||
if not self._locked:
|
||||
self._wake_up_first()
|
||||
raise
|
||||
|
||||
# assert self._locked is False
|
||||
self._locked = True
|
||||
return True
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release a lock.
|
||||
|
||||
When the lock is locked, reset it to unlocked, and return.
|
||||
If any other tasks are blocked waiting for the lock to become
|
||||
unlocked, allow exactly one of them to proceed.
|
||||
|
||||
When invoked on an unlocked lock, a RuntimeError is raised.
|
||||
|
||||
There is no return value.
|
||||
"""
|
||||
if self._locked:
|
||||
self._locked = False
|
||||
self._wake_up_first()
|
||||
else:
|
||||
raise RuntimeError("Lock is not acquired.")
|
||||
|
||||
def _wake_up_first(self) -> None:
|
||||
"""Ensure that the first waiter will wake up."""
|
||||
if not self._waiters:
|
||||
return
|
||||
try:
|
||||
fut = next(iter(self._waiters))
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
# .done() means that the waiter is already set to wake up.
|
||||
if not fut.done():
|
||||
fut.set_result(True)
|
||||
|
||||
|
||||
class Condition(_ContextManagerMixin, _LoopBoundMixin):
|
||||
"""Asynchronous equivalent to threading.Condition.
|
||||
|
||||
This class implements condition variable objects. A condition variable
|
||||
allows one or more tasks to wait until they are notified by another
|
||||
task.
|
||||
|
||||
A new Lock object is created and used as the underlying lock.
|
||||
"""
|
||||
|
||||
def __init__(self, lock: Optional[Lock] = None) -> None:
|
||||
if lock is None:
|
||||
lock = Lock()
|
||||
|
||||
self._lock = lock
|
||||
# Export the lock's locked(), acquire() and release() methods.
|
||||
self.locked = lock.locked
|
||||
self.acquire = lock.acquire
|
||||
self.release = lock.release
|
||||
|
||||
self._waiters: collections.deque = collections.deque()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
res = super().__repr__()
|
||||
extra = "locked" if self.locked() else "unlocked"
|
||||
if self._waiters:
|
||||
extra = f"{extra}, waiters:{len(self._waiters)}"
|
||||
return f"<{res[1:-1]} [{extra}]>"
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Wait until notified.
|
||||
|
||||
If the calling task has not acquired the lock when this
|
||||
method is called, a RuntimeError is raised.
|
||||
|
||||
This method releases the underlying lock, and then blocks
|
||||
until it is awakened by a notify() or notify_all() call for
|
||||
the same condition variable in another task. Once
|
||||
awakened, it re-acquires the lock and returns True.
|
||||
|
||||
This method may return spuriously,
|
||||
which is why the caller should always
|
||||
re-check the state and be prepared to wait() again.
|
||||
"""
|
||||
if not self.locked():
|
||||
raise RuntimeError("cannot wait on un-acquired lock")
|
||||
|
||||
fut = self._get_loop().create_future()
|
||||
self.release()
|
||||
try:
|
||||
try:
|
||||
self._waiters.append(fut)
|
||||
try:
|
||||
await fut
|
||||
return True
|
||||
finally:
|
||||
self._waiters.remove(fut)
|
||||
|
||||
finally:
|
||||
# Must re-acquire lock even if wait is cancelled.
|
||||
# We only catch CancelledError here, since we don't want any
|
||||
# other (fatal) errors with the future to cause us to spin.
|
||||
err = None
|
||||
while True:
|
||||
try:
|
||||
await self.acquire()
|
||||
break
|
||||
except exceptions.CancelledError as e:
|
||||
err = e
|
||||
|
||||
if err is not None:
|
||||
try:
|
||||
raise err # Re-raise most recent exception instance.
|
||||
finally:
|
||||
err = None # Break reference cycles.
|
||||
except BaseException:
|
||||
# Any error raised out of here _may_ have occurred after this Task
|
||||
# believed to have been successfully notified.
|
||||
# Make sure to notify another Task instead. This may result
|
||||
# in a "spurious wakeup", which is allowed as part of the
|
||||
# Condition Variable protocol.
|
||||
self._notify(1)
|
||||
raise
|
||||
|
||||
async def wait_for(self, predicate: Any) -> Coroutine:
|
||||
"""Wait until a predicate becomes true.
|
||||
|
||||
The predicate should be a callable whose result will be
|
||||
interpreted as a boolean value. The method will repeatedly
|
||||
wait() until it evaluates to true. The final predicate value is
|
||||
the return value.
|
||||
"""
|
||||
result = predicate()
|
||||
while not result:
|
||||
await self.wait()
|
||||
result = predicate()
|
||||
return result
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
"""By default, wake up one task waiting on this condition, if any.
|
||||
If the calling task has not acquired the lock when this method
|
||||
is called, a RuntimeError is raised.
|
||||
|
||||
This method wakes up n of the tasks waiting for the condition
|
||||
variable; if fewer than n are waiting, they are all awoken.
|
||||
|
||||
Note: an awakened task does not actually return from its
|
||||
wait() call until it can reacquire the lock. Since notify() does
|
||||
not release the lock, its caller should.
|
||||
"""
|
||||
if not self.locked():
|
||||
raise RuntimeError("cannot notify on un-acquired lock")
|
||||
self._notify(n)
|
||||
|
||||
def _notify(self, n: int) -> None:
|
||||
idx = 0
|
||||
for fut in self._waiters:
|
||||
if idx >= n:
|
||||
break
|
||||
|
||||
if not fut.done():
|
||||
idx += 1
|
||||
fut.set_result(False)
|
||||
|
||||
def notify_all(self) -> None:
|
||||
"""Wake up all tasks waiting on this condition. This method acts
|
||||
like notify(), but wakes up all waiting tasks instead of one. If the
|
||||
calling task has not acquired the lock when this method is called,
|
||||
a RuntimeError is raised.
|
||||
"""
|
||||
self.notify(len(self._waiters))
|
||||
49
pymongo/_asyncio_task.py
Normal file
49
pymongo/_asyncio_task.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
|
||||
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, Coroutine, Optional
|
||||
|
||||
|
||||
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
|
||||
class _Task(asyncio.Task):
|
||||
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
|
||||
super().__init__(coro, name=name)
|
||||
self._cancel_requests = 0
|
||||
asyncio._register_task(self)
|
||||
|
||||
def cancel(self, msg: Optional[str] = None) -> bool:
|
||||
self._cancel_requests += 1
|
||||
return super().cancel(msg=msg)
|
||||
|
||||
def uncancel(self) -> int:
|
||||
if self._cancel_requests > 0:
|
||||
self._cancel_requests -= 1
|
||||
return self._cancel_requests
|
||||
|
||||
def cancelling(self) -> int:
|
||||
return self._cancel_requests
|
||||
|
||||
|
||||
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
|
||||
if sys.version_info >= (3, 11):
|
||||
return asyncio.create_task(coro, name=name)
|
||||
return _Task(coro, name=name)
|
||||
@ -476,7 +476,6 @@ class _AsyncClientBulk:
|
||||
if op_type == "delete":
|
||||
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
||||
full_result[f"{op_type}Results"][original_index] = res
|
||||
|
||||
except Exception as exc:
|
||||
# Attempt to close the cursor, then raise top-level error.
|
||||
if cmd_cursor.alive:
|
||||
|
||||
@ -45,7 +45,7 @@ from pymongo.common import (
|
||||
)
|
||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||
from pymongo.lock import _ALock, _create_lock
|
||||
from pymongo.lock import _async_create_lock
|
||||
from pymongo.message import (
|
||||
_CursorAddress,
|
||||
_GetMore,
|
||||
@ -77,7 +77,7 @@ class _ConnectionManager:
|
||||
def __init__(self, conn: AsyncConnection, more_to_come: bool):
|
||||
self.conn: Optional[AsyncConnection] = conn
|
||||
self.more_to_come = more_to_come
|
||||
self._alock = _ALock(_create_lock())
|
||||
self._lock = _async_create_lock()
|
||||
|
||||
def update_exhaust(self, more_to_come: bool) -> None:
|
||||
self.more_to_come = more_to_come
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Support for explicit client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import socket
|
||||
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
||||
# BSON encoding/decoding errors are unrelated to encryption so
|
||||
# we should propagate them unchanged.
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptionError(exc) from exc
|
||||
|
||||
@ -200,6 +203,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
conn.close()
|
||||
except (PyMongoError, MongoCryptError):
|
||||
raise # Propagate pymongo errors directly.
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as error:
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
@ -722,6 +727,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
||||
await database.create_collection(name=name, **kwargs),
|
||||
encrypted_fields,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ access:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
@ -59,8 +60,8 @@ from typing import (
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common, helpers_shared, uri_parser
|
||||
from pymongo.asynchronous import client_session, database, periodic_executor
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||
from pymongo.asynchronous import client_session, database
|
||||
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||
@ -82,7 +83,11 @@ from pymongo.errors import (
|
||||
WaitQueueTimeoutError,
|
||||
WriteConcernError,
|
||||
)
|
||||
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
|
||||
from pymongo.lock import (
|
||||
_HAS_REGISTER_AT_FORK,
|
||||
_async_create_lock,
|
||||
_release_locks,
|
||||
)
|
||||
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
||||
from pymongo.message import _CursorAddress, _GetMore, _Query
|
||||
from pymongo.monitoring import ConnectionClosedReason
|
||||
@ -842,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||
|
||||
self._default_database_name = dbase
|
||||
self._lock = _ALock(_create_lock())
|
||||
self._lock = _async_create_lock()
|
||||
self._kill_cursors_queue: list = []
|
||||
|
||||
self._event_listeners = options.pool_options._event_listeners
|
||||
@ -908,7 +913,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
await AsyncMongoClient._process_periodic_tasks(client)
|
||||
return True
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||
interval=common.KILL_CURSOR_FREQUENCY,
|
||||
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
||||
target=target,
|
||||
@ -1722,7 +1727,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
address=address,
|
||||
)
|
||||
|
||||
async with operation.conn_mgr._alock:
|
||||
async with operation.conn_mgr._lock:
|
||||
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||
return await server.run_operation(
|
||||
@ -1970,7 +1975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
try:
|
||||
if conn_mgr:
|
||||
async with conn_mgr._alock:
|
||||
async with conn_mgr._lock:
|
||||
# Cursor is pinned to LB outside of a transaction.
|
||||
assert address is not None
|
||||
assert conn_mgr.conn is not None
|
||||
@ -2033,6 +2038,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||
try:
|
||||
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
# Raise the exception when client is closed so that it
|
||||
@ -2047,6 +2054,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_ids in address_to_cursor_ids.items():
|
||||
try:
|
||||
await self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
raise
|
||||
@ -2061,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
try:
|
||||
await self._process_kill_cursors()
|
||||
await self._topology.update_pool()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
|
||||
@ -16,20 +16,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||
|
||||
from pymongo import common
|
||||
from pymongo import common, periodic_executor
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.asynchronous import periodic_executor
|
||||
from pymongo.asynchronous.periodic_executor import _shutdown_executors
|
||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.lock import _async_create_lock
|
||||
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
||||
from pymongo.periodic_executor import _shutdown_executors
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
@ -76,7 +76,7 @@ class MonitorBase:
|
||||
await monitor._run() # type:ignore[attr-defined]
|
||||
return True
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||
interval=interval, min_interval=min_interval, target=target, name=name
|
||||
)
|
||||
|
||||
@ -112,9 +112,9 @@ class MonitorBase:
|
||||
"""
|
||||
self.gc_safe_close()
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
"""Wait for the monitor to stop."""
|
||||
self._executor.join(timeout)
|
||||
await self._executor.join(timeout)
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
@ -139,7 +139,7 @@ class Monitor(MonitorBase):
|
||||
"""
|
||||
super().__init__(
|
||||
topology,
|
||||
"pymongo_server_monitor_thread",
|
||||
"pymongo_server_monitor_task",
|
||||
topology_settings.heartbeat_frequency,
|
||||
common.MIN_HEARTBEAT_INTERVAL,
|
||||
)
|
||||
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
finally:
|
||||
if self._executor._stopped:
|
||||
await self._rtt_monitor.close()
|
||||
|
||||
async def _check_server(self) -> ServerDescription:
|
||||
"""Call hello or read the next streaming response.
|
||||
@ -252,7 +255,9 @@ class Monitor(MonitorBase):
|
||||
except (OperationFailure, NotPrimaryError) as exc:
|
||||
# Update max cluster time even when hello fails.
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
await self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
@ -280,7 +285,7 @@ class Monitor(MonitorBase):
|
||||
await self._reset_connection()
|
||||
if isinstance(error, _OperationCancelled):
|
||||
raise
|
||||
self._rtt_monitor.reset()
|
||||
await self._rtt_monitor.reset()
|
||||
# Server type defaults to Unknown.
|
||||
return ServerDescription(address, error=error)
|
||||
|
||||
@ -321,9 +326,9 @@ class Monitor(MonitorBase):
|
||||
self._conn_id = conn.id
|
||||
response, round_trip_time = await self._check_with_socket(conn)
|
||||
if not response.awaitable:
|
||||
self._rtt_monitor.add_sample(round_trip_time)
|
||||
await self._rtt_monitor.add_sample(round_trip_time)
|
||||
|
||||
avg_rtt, min_rtt = self._rtt_monitor.get()
|
||||
avg_rtt, min_rtt = await self._rtt_monitor.get()
|
||||
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
|
||||
if self._publish:
|
||||
assert self._listeners is not None
|
||||
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
@ -439,7 +446,7 @@ class _RttMonitor(MonitorBase):
|
||||
"""
|
||||
super().__init__(
|
||||
topology,
|
||||
"pymongo_server_rtt_thread",
|
||||
"pymongo_server_rtt_task",
|
||||
topology_settings.heartbeat_frequency,
|
||||
common.MIN_HEARTBEAT_INTERVAL,
|
||||
)
|
||||
@ -447,7 +454,7 @@ class _RttMonitor(MonitorBase):
|
||||
self._pool = pool
|
||||
self._moving_average = MovingAverage()
|
||||
self._moving_min = MovingMinimum()
|
||||
self._lock = _create_lock()
|
||||
self._lock = _async_create_lock()
|
||||
|
||||
async def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
@ -455,20 +462,20 @@ class _RttMonitor(MonitorBase):
|
||||
# thread has the socket checked out, it will be closed when checked in.
|
||||
await self._pool.reset()
|
||||
|
||||
def add_sample(self, sample: float) -> None:
|
||||
async def add_sample(self, sample: float) -> None:
|
||||
"""Add a RTT sample."""
|
||||
with self._lock:
|
||||
async with self._lock:
|
||||
self._moving_average.add_sample(sample)
|
||||
self._moving_min.add_sample(sample)
|
||||
|
||||
def get(self) -> tuple[Optional[float], float]:
|
||||
async def get(self) -> tuple[Optional[float], float]:
|
||||
"""Get the calculated average, or None if no samples yet and the min."""
|
||||
with self._lock:
|
||||
async with self._lock:
|
||||
return self._moving_average.get(), self._moving_min.get()
|
||||
|
||||
def reset(self) -> None:
|
||||
async def reset(self) -> None:
|
||||
"""Reset the average RTT."""
|
||||
with self._lock:
|
||||
async with self._lock:
|
||||
self._moving_average.reset()
|
||||
self._moving_min.reset()
|
||||
|
||||
@ -478,10 +485,12 @@ class _RttMonitor(MonitorBase):
|
||||
# heartbeat protocol (MongoDB 4.4+).
|
||||
# XXX: Skip check if the server is unknown?
|
||||
rtt = await self._ping()
|
||||
self.add_sample(rtt)
|
||||
await self.add_sample(rtt)
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
await self._pool.reset()
|
||||
|
||||
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
|
||||
shutdown()
|
||||
|
||||
|
||||
atexit.register(_shutdown_resources)
|
||||
if _IS_SYNC:
|
||||
atexit.register(_shutdown_resources)
|
||||
|
||||
@ -1,219 +0,0 @@
|
||||
# Copyright 2014-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Run a target function on a background thread."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo.lock import _ALock, _create_lock
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class PeriodicExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
interval: float,
|
||||
min_interval: float,
|
||||
target: Any,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Run a target function periodically on a background thread.
|
||||
|
||||
If the target's return value is false, the executor stops.
|
||||
|
||||
:param interval: Seconds between calls to `target`.
|
||||
:param min_interval: Minimum seconds between calls if `wake` is
|
||||
called very often.
|
||||
:param target: A function.
|
||||
:param name: A name to give the underlying thread.
|
||||
"""
|
||||
# threading.Event and its internal condition variable are expensive
|
||||
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
|
||||
# The executor's design is constrained by several Python issues, see
|
||||
# "periodic_executor.rst" in this repository.
|
||||
self._event = False
|
||||
self._interval = interval
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
self._thread_will_exit = False
|
||||
self._lock = _ALock(_create_lock())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def _run_async(self) -> None:
|
||||
# The default asyncio loop implementation on Windows
|
||||
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
|
||||
# We explicitly use a different loop implementation here to prevent that issue
|
||||
if sys.platform == "win32":
|
||||
loop = asyncio.SelectorEventLoop()
|
||||
try:
|
||||
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
|
||||
finally:
|
||||
loop.close()
|
||||
else:
|
||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect.
|
||||
|
||||
Not safe to call from multiple threads at once.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._thread_will_exit:
|
||||
# If the background thread has read self._stopped as True
|
||||
# there is a chance that it has not yet exited. The call to
|
||||
# join should not block indefinitely because there is no
|
||||
# other work done outside the while loop in self._run.
|
||||
try:
|
||||
assert self._thread is not None
|
||||
self._thread.join()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
self._thread_will_exit = False
|
||||
self._stopped = False
|
||||
started: Any = False
|
||||
try:
|
||||
started = self._thread and self._thread.is_alive()
|
||||
except ReferenceError:
|
||||
# Thread terminated.
|
||||
pass
|
||||
|
||||
if not started:
|
||||
if _IS_SYNC:
|
||||
thread = threading.Thread(target=self._run, name=self._name)
|
||||
else:
|
||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
||||
thread.daemon = True
|
||||
self._thread = weakref.proxy(thread)
|
||||
_register_executor(self)
|
||||
# Mitigation to RuntimeError firing when thread starts on shutdown
|
||||
# https://github.com/python/cpython/issues/114570
|
||||
try:
|
||||
thread.start()
|
||||
except RuntimeError as e:
|
||||
if "interpreter shutdown" in str(e) or sys.is_finalizing():
|
||||
self._thread = None
|
||||
return
|
||||
raise
|
||||
|
||||
def close(self, dummy: Any = None) -> None:
|
||||
"""Stop. To restart, call open().
|
||||
|
||||
The dummy parameter allows an executor's close method to be a weakref
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._thread is not None:
|
||||
try:
|
||||
self._thread.join(timeout)
|
||||
except (ReferenceError, RuntimeError):
|
||||
# Thread already terminated, or not yet started.
|
||||
pass
|
||||
|
||||
def wake(self) -> None:
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval: int) -> None:
|
||||
self._interval = new_interval
|
||||
|
||||
def skip_sleep(self) -> None:
|
||||
self._skip_sleep = True
|
||||
|
||||
async def _should_stop(self) -> bool:
|
||||
async with self._lock:
|
||||
if self._stopped:
|
||||
self._thread_will_exit = True
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _run(self) -> None:
|
||||
while not await self._should_stop():
|
||||
try:
|
||||
if not await self._target():
|
||||
self._stopped = True
|
||||
break
|
||||
except BaseException:
|
||||
async with self._lock:
|
||||
self._stopped = True
|
||||
self._thread_will_exit = True
|
||||
|
||||
raise
|
||||
|
||||
if self._skip_sleep:
|
||||
self._skip_sleep = False
|
||||
else:
|
||||
deadline = time.monotonic() + self._interval
|
||||
while not self._stopped and time.monotonic() < deadline:
|
||||
await asyncio.sleep(self._min_interval)
|
||||
if self._event:
|
||||
break # Early wake.
|
||||
|
||||
self._event = False
|
||||
|
||||
|
||||
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
|
||||
# an executor is kept alive by a strong reference from its thread and perhaps
|
||||
# from other objects. When the thread dies and all other referrers are freed,
|
||||
# the executor is freed and removed from _EXECUTORS. If any threads are
|
||||
# running when the interpreter begins to shut down, we try to halt and join
|
||||
# them to avoid spurious errors.
|
||||
_EXECUTORS = set()
|
||||
|
||||
|
||||
def _register_executor(executor: PeriodicExecutor) -> None:
|
||||
ref = weakref.ref(executor, _on_executor_deleted)
|
||||
_EXECUTORS.add(ref)
|
||||
|
||||
|
||||
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
|
||||
_EXECUTORS.remove(ref)
|
||||
|
||||
|
||||
def _shutdown_executors() -> None:
|
||||
if _EXECUTORS is None:
|
||||
return
|
||||
|
||||
# Copy the set. Stopping threads has the side effect of removing executors.
|
||||
executors = list(_EXECUTORS)
|
||||
|
||||
# First signal all executors to close...
|
||||
for ref in executors:
|
||||
executor = ref()
|
||||
if executor:
|
||||
executor.close()
|
||||
|
||||
# ...then try to join them.
|
||||
for ref in executors:
|
||||
executor = ref()
|
||||
if executor:
|
||||
executor.join(1)
|
||||
|
||||
executor = None
|
||||
@ -23,7 +23,6 @@ import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import (
|
||||
@ -65,7 +64,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
||||
from pymongo.lock import (
|
||||
_async_cond_wait,
|
||||
_async_create_condition,
|
||||
_async_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
@ -208,11 +211,6 @@ def _raise_connection_failure(
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool:
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
return await condition.wait(timeout)
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
@ -706,6 +704,8 @@ class AsyncConnection:
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@ -992,8 +992,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
_lock = _create_lock()
|
||||
self.lock = _ALock(_lock)
|
||||
self.lock = _async_create_lock()
|
||||
self._max_connecting_cond = _async_create_condition(self.lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1019,7 +1019,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = _ACondition(threading.Condition(_lock))
|
||||
self.size_cond = _async_create_condition(self.lock)
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1027,7 +1027,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
|
||||
self._max_connecting_cond = _async_create_condition(self.lock)
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
@ -1466,7 +1466,8 @@ class Pool:
|
||||
async with self.size_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
||||
while not (self.requests < self.max_pool_size):
|
||||
if not await _cond_wait(self.size_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not await _async_cond_wait(self.size_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.requests < self.max_pool_size:
|
||||
@ -1489,7 +1490,8 @@ class Pool:
|
||||
async with self._max_connecting_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
||||
while not (self.conns or self._pending < self._max_connecting):
|
||||
if not await _cond_wait(self._max_connecting_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not await _async_cond_wait(self._max_connecting_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.conns or self._pending < self._max_connecting:
|
||||
|
||||
@ -27,8 +27,7 @@ import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||
|
||||
from pymongo import _csot, common, helpers_shared
|
||||
from pymongo.asynchronous import periodic_executor
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.asynchronous.monitor import SrvMonitor
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
@ -44,7 +43,11 @@ from pymongo.errors import (
|
||||
WriteError,
|
||||
)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
||||
from pymongo.lock import (
|
||||
_async_cond_wait,
|
||||
_async_create_condition,
|
||||
_async_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_SDAM_LOGGER,
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
@ -170,9 +173,10 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
_lock = _create_lock()
|
||||
self._lock = _ALock(_lock)
|
||||
self._condition = _ACondition(self._settings.condition_class(_lock))
|
||||
self._lock = _async_create_lock()
|
||||
self._condition = _async_create_condition(
|
||||
self._lock, self._settings.condition_class if _IS_SYNC else None
|
||||
)
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
@ -185,7 +189,7 @@ class Topology:
|
||||
async def target() -> bool:
|
||||
return process_events_queue(weak)
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||
interval=common.EVENTS_QUEUE_FREQUENCY,
|
||||
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
||||
target=target,
|
||||
@ -354,7 +358,7 @@ class Topology:
|
||||
# change, or for a timeout. We won't miss any changes that
|
||||
# came after our most recent apply_selector call, since we've
|
||||
# held the lock until now.
|
||||
await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
|
||||
await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
|
||||
self._description.check_compatible()
|
||||
now = time.monotonic()
|
||||
server_descriptions = self._description.apply_selector(
|
||||
@ -654,7 +658,7 @@ class Topology:
|
||||
"""Wake all monitors, wait for at least one to check its server."""
|
||||
async with self._lock:
|
||||
self._request_check_all()
|
||||
await self._condition.wait(wait_time)
|
||||
await _async_cond_wait(self._condition, wait_time)
|
||||
|
||||
def data_bearing_servers(self) -> list[ServerDescription]:
|
||||
"""Return a list of all data-bearing servers.
|
||||
@ -742,7 +746,7 @@ class Topology:
|
||||
if self._publish_server or self._publish_tp:
|
||||
# Make sure the events executor thread is fully closed before publishing the remaining events
|
||||
self.__events_executor.close()
|
||||
self.__events_executor.join(1)
|
||||
await self.__events_executor.join(1)
|
||||
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
|
||||
|
||||
@property
|
||||
|
||||
241
pymongo/lock.py
241
pymongo/lock.py
@ -11,15 +11,20 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Internal helpers for lock and condition coordination primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from asyncio import wait_for
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
import pymongo._asyncio_lock
|
||||
|
||||
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
||||
|
||||
@ -28,6 +33,15 @@ _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202)
|
||||
# in older versions of Python
|
||||
if sys.version_info >= (3, 13):
|
||||
Lock = asyncio.Lock
|
||||
Condition = asyncio.Condition
|
||||
else:
|
||||
Lock = pymongo._asyncio_lock.Lock
|
||||
Condition = pymongo._asyncio_lock.Condition
|
||||
|
||||
|
||||
def _create_lock() -> threading.Lock:
|
||||
"""Represents a lock that is tracked upon instantiation using a WeakSet and
|
||||
@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
def _async_create_lock() -> Lock:
|
||||
"""Represents an asyncio.Lock."""
|
||||
return Lock()
|
||||
|
||||
|
||||
def _create_condition(
|
||||
lock: threading.Lock, condition_class: Optional[Any] = None
|
||||
) -> threading.Condition:
|
||||
"""Represents a threading.Condition."""
|
||||
if condition_class:
|
||||
return condition_class(lock)
|
||||
return threading.Condition(lock)
|
||||
|
||||
|
||||
def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition:
|
||||
"""Represents an asyncio.Condition."""
|
||||
if condition_class:
|
||||
return condition_class(lock)
|
||||
return Condition(lock)
|
||||
|
||||
|
||||
def _release_locks() -> None:
|
||||
# Completed the fork, reset all the locks in the child.
|
||||
for lock in _forkable_locks:
|
||||
@ -46,202 +81,12 @@ def _release_locks() -> None:
|
||||
lock.release()
|
||||
|
||||
|
||||
# Needed only for synchro.py compat.
|
||||
def _Lock(lock: threading.Lock) -> threading.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
class _ALock:
|
||||
__slots__ = ("_lock",)
|
||||
|
||||
def __init__(self, lock: threading.Lock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
return self._lock.acquire(blocking=blocking, timeout=timeout)
|
||||
|
||||
async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
if timeout > 0:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
acquired = self._lock.acquire(blocking=False)
|
||||
if acquired:
|
||||
return True
|
||||
if timeout > 0 and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
if not blocking:
|
||||
return False
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
async def __aenter__(self) -> _ALock:
|
||||
await self.a_acquire()
|
||||
return self
|
||||
|
||||
def __enter__(self) -> _ALock:
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
|
||||
|
||||
def _safe_set_result(fut: asyncio.Future) -> None:
|
||||
# Ensure the future hasn't been cancelled before calling set_result.
|
||||
if not fut.done():
|
||||
fut.set_result(False)
|
||||
|
||||
|
||||
class _ACondition:
|
||||
__slots__ = ("_condition", "_waiters")
|
||||
|
||||
def __init__(self, condition: threading.Condition) -> None:
|
||||
self._condition = condition
|
||||
self._waiters: collections.deque = collections.deque()
|
||||
|
||||
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
if timeout > 0:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
acquired = self._condition.acquire(blocking=False)
|
||||
if acquired:
|
||||
return True
|
||||
if timeout > 0 and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
if not blocking:
|
||||
return False
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
"""Wait until notified.
|
||||
|
||||
If the calling task has not acquired the lock when this
|
||||
method is called, a RuntimeError is raised.
|
||||
|
||||
This method releases the underlying lock, and then blocks
|
||||
until it is awakened by a notify() or notify_all() call for
|
||||
the same condition variable in another task. Once
|
||||
awakened, it re-acquires the lock and returns True.
|
||||
|
||||
This method may return spuriously,
|
||||
which is why the caller should always
|
||||
re-check the state and be prepared to wait() again.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
fut = loop.create_future()
|
||||
self._waiters.append((loop, fut))
|
||||
self.release()
|
||||
async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool:
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(fut, timeout)
|
||||
return True
|
||||
return await wait_for(condition.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return False # Return false on timeout for sync pool compat.
|
||||
finally:
|
||||
# Must re-acquire lock even if wait is cancelled.
|
||||
# We only catch CancelledError here, since we don't want any
|
||||
# other (fatal) errors with the future to cause us to spin.
|
||||
err = None
|
||||
while True:
|
||||
try:
|
||||
await self.acquire()
|
||||
break
|
||||
except asyncio.exceptions.CancelledError as e:
|
||||
err = e
|
||||
return False
|
||||
|
||||
self._waiters.remove((loop, fut))
|
||||
if err is not None:
|
||||
try:
|
||||
raise err # Re-raise most recent exception instance.
|
||||
finally:
|
||||
err = None # Break reference cycles.
|
||||
except BaseException:
|
||||
# Any error raised out of here _may_ have occurred after this Task
|
||||
# believed to have been successfully notified.
|
||||
# Make sure to notify another Task instead. This may result
|
||||
# in a "spurious wakeup", which is allowed as part of the
|
||||
# Condition Variable protocol.
|
||||
self.notify(1)
|
||||
raise
|
||||
|
||||
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
|
||||
"""Wait until a predicate becomes true.
|
||||
|
||||
The predicate should be a callable whose result will be
|
||||
interpreted as a boolean value. The method will repeatedly
|
||||
wait() until it evaluates to true. The final predicate value is
|
||||
the return value.
|
||||
"""
|
||||
result = predicate()
|
||||
while not result:
|
||||
await self.wait()
|
||||
result = predicate()
|
||||
return result
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
"""By default, wake up one coroutine waiting on this condition, if any.
|
||||
If the calling coroutine has not acquired the lock when this method
|
||||
is called, a RuntimeError is raised.
|
||||
|
||||
This method wakes up at most n of the coroutines waiting for the
|
||||
condition variable; it is a no-op if no coroutines are waiting.
|
||||
|
||||
Note: an awakened coroutine does not actually return from its
|
||||
wait() call until it can reacquire the lock. Since notify() does
|
||||
not release the lock, its caller should.
|
||||
"""
|
||||
idx = 0
|
||||
to_remove = []
|
||||
for loop, fut in self._waiters:
|
||||
if idx >= n:
|
||||
break
|
||||
|
||||
if fut.done():
|
||||
continue
|
||||
|
||||
try:
|
||||
loop.call_soon_threadsafe(_safe_set_result, fut)
|
||||
except RuntimeError:
|
||||
# Loop was closed, ignore.
|
||||
to_remove.append((loop, fut))
|
||||
continue
|
||||
|
||||
idx += 1
|
||||
|
||||
for waiter in to_remove:
|
||||
self._waiters.remove(waiter)
|
||||
|
||||
def notify_all(self) -> None:
|
||||
"""Wake up all threads waiting on this condition. This method acts
|
||||
like notify(), but wakes up all waiting threads instead of one. If the
|
||||
calling thread has not acquired the lock when this method is called,
|
||||
a RuntimeError is raised.
|
||||
"""
|
||||
self.notify(len(self._waiters))
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Only needed for tests in test_locks."""
|
||||
return self._condition._lock.locked() # type: ignore[attr-defined]
|
||||
|
||||
def release(self) -> None:
|
||||
self._condition.release()
|
||||
|
||||
async def __aenter__(self) -> _ACondition:
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
def __enter__(self) -> _ACondition:
|
||||
self._condition.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
self.release()
|
||||
def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool:
|
||||
return condition.wait(timeout)
|
||||
|
||||
@ -29,6 +29,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from pymongo import _csot, ssl_support
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.errors import _OperationCancelled
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
@ -259,18 +260,19 @@ async def async_receive_data(
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_event_loop()
|
||||
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
|
||||
cancellation_task = create_task(_poll_cancellation(conn))
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
||||
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
||||
else:
|
||||
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
||||
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
||||
tasks = [read_task, cancellation_task]
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
await asyncio.wait(pending)
|
||||
if len(done) == 0:
|
||||
raise socket.timeout("timed out")
|
||||
|
||||
@ -23,9 +23,102 @@ import time
|
||||
import weakref
|
||||
from typing import Any, Optional
|
||||
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.lock import _create_lock
|
||||
|
||||
_IS_SYNC = True
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AsyncPeriodicExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
interval: float,
|
||||
min_interval: float,
|
||||
target: Any,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Run a target function periodically on a background task.
|
||||
|
||||
If the target's return value is false, the executor stops.
|
||||
|
||||
:param interval: Seconds between calls to `target`.
|
||||
:param min_interval: Minimum seconds between calls if `wake` is
|
||||
called very often.
|
||||
:param target: A function.
|
||||
:param name: A name to give the underlying task.
|
||||
"""
|
||||
self._event = False
|
||||
self._interval = interval
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect."""
|
||||
self._stopped = False
|
||||
|
||||
if self._task is None or (
|
||||
self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined]
|
||||
):
|
||||
self._task = create_task(self._run(), name=self._name)
|
||||
|
||||
def close(self, dummy: Any = None) -> None:
|
||||
"""Stop. To restart, call open().
|
||||
|
||||
The dummy parameter allows an executor's close method to be a weakref
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type]
|
||||
except asyncio.TimeoutError:
|
||||
# Task timed out
|
||||
pass
|
||||
except asyncio.exceptions.CancelledError:
|
||||
# Task was already finished, or not yet started.
|
||||
raise
|
||||
|
||||
def wake(self) -> None:
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval: int) -> None:
|
||||
self._interval = new_interval
|
||||
|
||||
def skip_sleep(self) -> None:
|
||||
self._skip_sleep = True
|
||||
|
||||
async def _run(self) -> None:
|
||||
while not self._stopped:
|
||||
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
|
||||
raise asyncio.CancelledError
|
||||
try:
|
||||
if not await self._target():
|
||||
self._stopped = True
|
||||
break
|
||||
except BaseException:
|
||||
self._stopped = True
|
||||
raise
|
||||
|
||||
if self._skip_sleep:
|
||||
self._skip_sleep = False
|
||||
else:
|
||||
deadline = time.monotonic() + self._interval
|
||||
while not self._stopped and time.monotonic() < deadline:
|
||||
await asyncio.sleep(self._min_interval)
|
||||
if self._event:
|
||||
break # Early wake.
|
||||
|
||||
self._event = False
|
||||
|
||||
|
||||
class PeriodicExecutor:
|
||||
@ -64,19 +157,6 @@ class PeriodicExecutor:
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||
|
||||
def _run_async(self) -> None:
|
||||
# The default asyncio loop implementation on Windows
|
||||
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
|
||||
# We explicitly use a different loop implementation here to prevent that issue
|
||||
if sys.platform == "win32":
|
||||
loop = asyncio.SelectorEventLoop()
|
||||
try:
|
||||
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
|
||||
finally:
|
||||
loop.close()
|
||||
else:
|
||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start. Multiple calls have no effect.
|
||||
|
||||
@ -104,10 +184,7 @@ class PeriodicExecutor:
|
||||
pass
|
||||
|
||||
if not started:
|
||||
if _IS_SYNC:
|
||||
thread = threading.Thread(target=self._run, name=self._name)
|
||||
else:
|
||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
||||
thread.daemon = True
|
||||
self._thread = weakref.proxy(thread)
|
||||
_register_executor(self)
|
||||
@ -474,7 +474,6 @@ class _ClientBulk:
|
||||
if op_type == "delete":
|
||||
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
||||
full_result[f"{op_type}Results"][original_index] = res
|
||||
|
||||
except Exception as exc:
|
||||
# Attempt to close the cursor, then raise top-level error.
|
||||
if cmd_cursor.alive:
|
||||
|
||||
@ -77,7 +77,7 @@ class _ConnectionManager:
|
||||
def __init__(self, conn: Connection, more_to_come: bool):
|
||||
self.conn: Optional[Connection] = conn
|
||||
self.more_to_come = more_to_come
|
||||
self._alock = _create_lock()
|
||||
self._lock = _create_lock()
|
||||
|
||||
def update_exhaust(self, more_to_come: bool) -> None:
|
||||
self.more_to_come = more_to_come
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Support for explicit client-side field level encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import socket
|
||||
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
||||
# BSON encoding/decoding errors are unrelated to encryption so
|
||||
# we should propagate them unchanged.
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptionError(exc) from exc
|
||||
|
||||
@ -200,6 +203,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
conn.close()
|
||||
except (PyMongoError, MongoCryptError):
|
||||
raise # Propagate pymongo errors directly.
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as error:
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
@ -716,6 +721,8 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
database.create_collection(name=name, **kwargs),
|
||||
encrypted_fields,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ access:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
@ -58,7 +59,7 @@ from typing import (
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common, helpers_shared, uri_parser
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
@ -74,7 +75,11 @@ from pymongo.errors import (
|
||||
WaitQueueTimeoutError,
|
||||
WriteConcernError,
|
||||
)
|
||||
from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks
|
||||
from pymongo.lock import (
|
||||
_HAS_REGISTER_AT_FORK,
|
||||
_create_lock,
|
||||
_release_locks,
|
||||
)
|
||||
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
||||
from pymongo.message import _CursorAddress, _GetMore, _Query
|
||||
from pymongo.monitoring import ConnectionClosedReason
|
||||
@ -91,7 +96,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import ClientBulkWriteResult
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous import client_session, database, periodic_executor
|
||||
from pymongo.synchronous import client_session, database
|
||||
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||
from pymongo.synchronous.client_session import _EmptyServerSession
|
||||
@ -1716,7 +1721,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
address=address,
|
||||
)
|
||||
|
||||
with operation.conn_mgr._alock:
|
||||
with operation.conn_mgr._lock:
|
||||
with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||
return server.run_operation(
|
||||
@ -1964,7 +1969,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
try:
|
||||
if conn_mgr:
|
||||
with conn_mgr._alock:
|
||||
with conn_mgr._lock:
|
||||
# Cursor is pinned to LB outside of a transaction.
|
||||
assert address is not None
|
||||
assert conn_mgr.conn is not None
|
||||
@ -2027,6 +2032,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||
try:
|
||||
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
# Raise the exception when client is closed so that it
|
||||
@ -2041,6 +2048,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
for address, cursor_ids in address_to_cursor_ids.items():
|
||||
try:
|
||||
self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
raise
|
||||
@ -2055,6 +2064,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
try:
|
||||
self._process_kill_cursors()
|
||||
self._topology.update_pool()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||
return
|
||||
|
||||
@ -16,24 +16,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||
|
||||
from pymongo import common
|
||||
from pymongo import common, periodic_executor
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
||||
from pymongo.periodic_executor import _shutdown_executors
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
from pymongo.synchronous import periodic_executor
|
||||
from pymongo.synchronous.periodic_executor import _shutdown_executors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
||||
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
self.close()
|
||||
finally:
|
||||
if self._executor._stopped:
|
||||
self._rtt_monitor.close()
|
||||
|
||||
def _check_server(self) -> ServerDescription:
|
||||
"""Call hello or read the next streaming response.
|
||||
@ -254,6 +257,8 @@ class Monitor(MonitorBase):
|
||||
details = cast(Mapping[str, Any], exc.details)
|
||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ReferenceError:
|
||||
raise
|
||||
except Exception as error:
|
||||
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
@ -482,6 +489,8 @@ class _RttMonitor(MonitorBase):
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
self.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self._pool.reset()
|
||||
|
||||
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
|
||||
shutdown()
|
||||
|
||||
|
||||
atexit.register(_shutdown_resources)
|
||||
if _IS_SYNC:
|
||||
atexit.register(_shutdown_resources)
|
||||
|
||||
@ -23,7 +23,6 @@ import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import (
|
||||
@ -62,7 +61,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.lock import (
|
||||
_cond_wait,
|
||||
_create_condition,
|
||||
_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
@ -208,11 +211,6 @@ def _raise_connection_failure(
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
return condition.wait(timeout)
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
@ -704,6 +702,8 @@ class Connection:
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@ -988,8 +988,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
_lock = _create_lock()
|
||||
self.lock = _Lock(_lock)
|
||||
self.lock = _create_lock()
|
||||
self._max_connecting_cond = _create_condition(self.lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1015,7 +1015,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = threading.Condition(_lock)
|
||||
self.size_cond = _create_condition(self.lock)
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1023,7 +1023,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = threading.Condition(_lock)
|
||||
self._max_connecting_cond = _create_condition(self.lock)
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
@ -1460,7 +1460,8 @@ class Pool:
|
||||
with self.size_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
||||
while not (self.requests < self.max_pool_size):
|
||||
if not _cond_wait(self.size_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not _cond_wait(self.size_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.requests < self.max_pool_size:
|
||||
@ -1483,7 +1484,8 @@ class Pool:
|
||||
with self._max_connecting_cond:
|
||||
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
||||
while not (self.conns or self._pending < self._max_connecting):
|
||||
if not _cond_wait(self._max_connecting_cond, deadline):
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
if not _cond_wait(self._max_connecting_cond, timeout):
|
||||
# Timed out, notify the next thread to ensure a
|
||||
# timeout doesn't consume the condition.
|
||||
if self.conns or self._pending < self._max_connecting:
|
||||
|
||||
@ -27,7 +27,7 @@ import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||
|
||||
from pymongo import _csot, common, helpers_shared
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.errors import (
|
||||
ConnectionFailure,
|
||||
InvalidOperation,
|
||||
@ -39,7 +39,11 @@ from pymongo.errors import (
|
||||
WriteError,
|
||||
)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.lock import (
|
||||
_cond_wait,
|
||||
_create_condition,
|
||||
_create_lock,
|
||||
)
|
||||
from pymongo.logger import (
|
||||
_SDAM_LOGGER,
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
@ -56,7 +60,6 @@ from pymongo.server_selectors import (
|
||||
secondary_server_selector,
|
||||
writable_server_selector,
|
||||
)
|
||||
from pymongo.synchronous import periodic_executor
|
||||
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.synchronous.monitor import SrvMonitor
|
||||
from pymongo.synchronous.pool import Pool
|
||||
@ -170,9 +173,10 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
_lock = _create_lock()
|
||||
self._lock = _Lock(_lock)
|
||||
self._condition = self._settings.condition_class(_lock)
|
||||
self._lock = _create_lock()
|
||||
self._condition = _create_condition(
|
||||
self._lock, self._settings.condition_class if _IS_SYNC else None
|
||||
)
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
@ -354,7 +358,7 @@ class Topology:
|
||||
# change, or for a timeout. We won't miss any changes that
|
||||
# came after our most recent apply_selector call, since we've
|
||||
# held the lock until now.
|
||||
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
|
||||
_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
|
||||
self._description.check_compatible()
|
||||
now = time.monotonic()
|
||||
server_descriptions = self._description.apply_selector(
|
||||
@ -652,7 +656,7 @@ class Topology:
|
||||
"""Wake all monitors, wait for at least one to check its server."""
|
||||
with self._lock:
|
||||
self._request_check_all()
|
||||
self._condition.wait(wait_time)
|
||||
_cond_wait(self._condition, wait_time)
|
||||
|
||||
def data_bearing_servers(self) -> list[ServerDescription]:
|
||||
"""Return a list of all data-bearing servers.
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
@ -25,6 +26,7 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
@ -191,6 +193,8 @@ class ClientContext:
|
||||
client.close()
|
||||
|
||||
def _init_client(self):
|
||||
self.mongoses = []
|
||||
self.connection_attempts = []
|
||||
self.client = self._connect(host, port)
|
||||
if self.client is not None:
|
||||
# Return early when connected to dataLake as mongohoused does not
|
||||
@ -860,6 +864,16 @@ class ClientContext:
|
||||
client_context = ClientContext()
|
||||
|
||||
|
||||
def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif client_context.client is not None:
|
||||
client_context.client.close()
|
||||
client_context.client = None
|
||||
client_context._init_client()
|
||||
|
||||
|
||||
class PyMongoTestCase(unittest.TestCase):
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
@ -1106,26 +1120,10 @@ class PyMongoTestCase(unittest.TestCase):
|
||||
class UnitTest(PyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
def setUp(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
def tearDown(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase):
|
||||
db: Database
|
||||
credentials: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
||||
def setUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
reset_client_context()
|
||||
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
||||
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
raise SkipTest("this test does not support serverless")
|
||||
cls.client = client_context.client
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.client = client_context.client
|
||||
self.db = self.client.pymongo_test
|
||||
if client_context.auth_enabled:
|
||||
cls.credentials = {"username": db_user, "password": db_pwd}
|
||||
self.credentials = {"username": db_user, "password": db_pwd}
|
||||
else:
|
||||
cls.credentials = {}
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
pass
|
||||
self.credentials = {}
|
||||
|
||||
def cleanup_colls(self, *collections):
|
||||
"""Cleanup collections faster than drop_collection."""
|
||||
@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest):
|
||||
# MockClients tests that use replicaSet, directConnection=True, pass
|
||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||
# with loadBalanced=True.
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@client_context.require_no_load_balancer
|
||||
def _setup_class(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
||||
|
||||
self.client_knobs.enable()
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
self.client_knobs.disable()
|
||||
super().tearDown()
|
||||
|
||||
@ -1253,7 +1211,6 @@ def teardown():
|
||||
c.drop_database("pymongo_test_mike")
|
||||
c.drop_database("pymongo_test_bernie")
|
||||
c.close()
|
||||
|
||||
print_running_clients()
|
||||
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
@ -25,6 +26,7 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
@ -191,6 +193,8 @@ class AsyncClientContext:
|
||||
await client.close()
|
||||
|
||||
async def _init_client(self):
|
||||
self.mongoses = []
|
||||
self.connection_attempts = []
|
||||
self.client = await self._connect(host, port)
|
||||
if self.client is not None:
|
||||
# Return early when connected to dataLake as mongohoused does not
|
||||
@ -862,6 +866,16 @@ class AsyncClientContext:
|
||||
async_client_context = AsyncClientContext()
|
||||
|
||||
|
||||
async def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif async_client_context.client is not None:
|
||||
await async_client_context.client.close()
|
||||
async_client_context.client = None
|
||||
await async_client_context._init_client()
|
||||
|
||||
|
||||
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
@ -1124,26 +1138,10 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
async def asyncSetUp(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
async def asyncTearDown(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
db: AsyncDatabase
|
||||
credentials: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
||||
async def asyncSetUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
await reset_client_context()
|
||||
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
||||
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
raise SkipTest("this test does not support serverless")
|
||||
cls.client = async_client_context.client
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.client = async_client_context.client
|
||||
self.db = self.client.pymongo_test
|
||||
if async_client_context.auth_enabled:
|
||||
cls.credentials = {"username": db_user, "password": db_pwd}
|
||||
self.credentials = {"username": db_user, "password": db_pwd}
|
||||
else:
|
||||
cls.credentials = {}
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
pass
|
||||
self.credentials = {}
|
||||
|
||||
async def cleanup_colls(self, *collections):
|
||||
"""Cleanup collections faster than drop_collection."""
|
||||
@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest):
|
||||
# MockClients tests that use replicaSet, directConnection=True, pass
|
||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||
# with loadBalanced=True.
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_no_load_balancer
|
||||
async def _setup_class(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
|
||||
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
||||
|
||||
self.client_knobs.enable()
|
||||
|
||||
def tearDown(self):
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.client_knobs.disable()
|
||||
super().tearDown()
|
||||
await super().asyncTearDown()
|
||||
|
||||
|
||||
async def async_setup():
|
||||
@ -1271,7 +1229,6 @@ async def async_teardown():
|
||||
await c.drop_database("pymongo_test_mike")
|
||||
await c.drop_database("pymongo_test_bernie")
|
||||
await c.close()
|
||||
|
||||
print_running_clients()
|
||||
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ def event_loop_policy():
|
||||
return asyncio.get_event_loop_policy()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
@pytest_asyncio.fixture(scope="package", autouse=True)
|
||||
async def test_setup_and_teardown():
|
||||
await async_setup()
|
||||
yield
|
||||
|
||||
@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
|
||||
coll: AsyncCollection
|
||||
coll_w0: AsyncCollection
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.coll = cls.db.test
|
||||
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
async def asyncSetUp(self):
|
||||
super().setUp()
|
||||
await super().asyncSetUp()
|
||||
self.coll = self.db.test
|
||||
await self.coll.drop()
|
||||
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
def assertEqualResponse(self, expected, actual):
|
||||
"""Compare response from bulk.execute() to expected response."""
|
||||
@ -787,14 +783,10 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
|
||||
|
||||
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
|
||||
@classmethod
|
||||
@async_client_context.require_auth
|
||||
@async_client_context.require_no_api_version
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
super().setUp()
|
||||
await super().asyncSetUp()
|
||||
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
||||
await self.db.command(
|
||||
"createRole",
|
||||
@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
|
||||
w: Optional[int]
|
||||
secondary: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.w = async_client_context.w
|
||||
cls.secondary = None
|
||||
if cls.w is not None and cls.w > 1:
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.w = async_client_context.w
|
||||
self.secondary = None
|
||||
if self.w is not None and self.w > 1:
|
||||
for member in (await async_client_context.hello)["hosts"]:
|
||||
if member != (await async_client_context.hello)["primary"]:
|
||||
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
|
||||
self.secondary = await self.async_single_client(*partition_node(member))
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def async_tearDownClass(cls):
|
||||
if cls.secondary:
|
||||
await cls.secondary.close()
|
||||
async def asyncTearDown(self):
|
||||
if self.secondary:
|
||||
await self.secondary.close()
|
||||
|
||||
async def cause_wtimeout(self, requests, ordered):
|
||||
if not async_client_context.test_commands_enabled:
|
||||
|
||||
@ -836,18 +836,16 @@ class ProseSpecTestsMixin:
|
||||
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||
dbs: list
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_version_min(4, 0, 0, -1)
|
||||
@async_client_context.require_change_streams
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.dbs = [cls.db, cls.client.pymongo_test_2]
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.dbs = [self.db, self.client.pymongo_test_2]
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
for db in cls.dbs:
|
||||
await cls.client.drop_database(db)
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self):
|
||||
for db in self.dbs:
|
||||
await self.client.drop_database(db)
|
||||
await super().asyncTearDown()
|
||||
|
||||
async def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return await client.watch(*args, **kwargs)
|
||||
@ -898,11 +896,10 @@ class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||
|
||||
|
||||
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||
@classmethod
|
||||
@async_client_context.require_version_min(4, 0, 0, -1)
|
||||
@async_client_context.require_change_streams
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
|
||||
async def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return await client[self.db.name].watch(*args, **kwargs)
|
||||
@ -988,12 +985,9 @@ class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixi
|
||||
class TestAsyncCollectionAsyncChangeStream(
|
||||
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
|
||||
):
|
||||
@classmethod
|
||||
@async_client_context.require_change_streams
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
# Use a new collection for each test.
|
||||
await self.watched_collection().drop()
|
||||
await self.watched_collection().insert_one({})
|
||||
@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
listener: AllowListEventListener
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def asyncSetUp(self):
|
||||
super().asyncSetUp()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener = AllowListEventListener("aggregate", "getMore")
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
self.listener.reset()
|
||||
|
||||
async def asyncSetUpCluster(self, scenario_dict):
|
||||
|
||||
@ -73,7 +73,6 @@ from test.utils import (
|
||||
is_greenthread_patched,
|
||||
lazy_client_trial,
|
||||
one,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
import bson
|
||||
@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
|
||||
client: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
connect=False, serverSelectionTimeoutMS=100
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
@ -693,8 +687,8 @@ class TestClient(AsyncIntegrationTest):
|
||||
# When the reaper runs at the same time as the get_socket, two
|
||||
# connections could be created and checked into the pool.
|
||||
self.assertGreaterEqual(len(server._pool.conns), 1)
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
|
||||
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
|
||||
|
||||
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -710,8 +704,8 @@ class TestClient(AsyncIntegrationTest):
|
||||
# When the reaper runs at the same time as the get_socket,
|
||||
# maxPoolSize=1 should prevent two connections from being created.
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
|
||||
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
|
||||
|
||||
async def test_max_idle_time_reaper_removes_stale(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -727,7 +721,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
async with server._pool.checkout() as conn_two:
|
||||
pass
|
||||
self.assertIs(conn_one, conn_two)
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 0,
|
||||
"stale socket reaped and new one NOT added to the pool",
|
||||
)
|
||||
@ -745,7 +739,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
server = await (await client._get_topology()).select_server(
|
||||
readable_server_selector, _Op.TEST
|
||||
)
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"pool initialized with 10 connections",
|
||||
)
|
||||
@ -753,7 +747,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
# Assert that if a socket is closed, a new one takes its place
|
||||
async with server._pool.checkout() as conn:
|
||||
conn.close_conn(None)
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"a closed socket gets replaced from the pool",
|
||||
)
|
||||
@ -939,8 +933,10 @@ class TestClient(AsyncIntegrationTest):
|
||||
async with eval(the_repr) as client_two:
|
||||
self.assertEqual(client_two, client)
|
||||
|
||||
def test_getters(self):
|
||||
wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes")
|
||||
async def test_getters(self):
|
||||
await async_wait_until(
|
||||
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
|
||||
)
|
||||
|
||||
async def test_list_databases(self):
|
||||
cmd_docs = (await self.client.admin.command("listDatabases"))["databases"]
|
||||
@ -1065,14 +1061,21 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertFalse(client._topology._opened)
|
||||
|
||||
# Ensure kill cursors thread has not been started.
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertFalse(kc_task and not kc_task.done())
|
||||
# Using the client should open topology and start the thread.
|
||||
await client.admin.command("ping")
|
||||
self.assertTrue(client._topology._opened)
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertTrue(kc_task and not kc_task.done())
|
||||
|
||||
async def test_close_does_not_open_servers(self):
|
||||
client = await self.async_rs_client(connect=False)
|
||||
@ -1277,6 +1280,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def test_server_selection_timeout(self):
|
||||
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
|
||||
|
||||
@ -1289,18 +1293,22 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertRaises(
|
||||
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
|
||||
)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient(
|
||||
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
|
||||
)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
||||
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
# Test invalid timeout in URI ignored and set to default.
|
||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
await client.close()
|
||||
|
||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
@ -1608,7 +1616,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await async_client_context.port,
|
||||
)
|
||||
await self.async_single_client(uri, event_listeners=[listener])
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents"
|
||||
)
|
||||
|
||||
@ -1766,16 +1774,16 @@ class TestClient(AsyncIntegrationTest):
|
||||
pool = await async_get_pool(client)
|
||||
original_connect = pool.connect
|
||||
|
||||
def stall_connect(*args, **kwargs):
|
||||
time.sleep(2)
|
||||
return original_connect(*args, **kwargs)
|
||||
async def stall_connect(*args, **kwargs):
|
||||
await asyncio.sleep(2)
|
||||
return await original_connect(*args, **kwargs)
|
||||
|
||||
pool.connect = stall_connect
|
||||
# Un-patch Pool.connect to break the cyclic reference.
|
||||
self.addCleanup(delattr, pool, "connect")
|
||||
|
||||
# Wait for the background thread to start creating connections
|
||||
wait_until(lambda: len(pool.conns) > 1, "start creating connections")
|
||||
await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections")
|
||||
|
||||
# Assert that application operations do not block.
|
||||
for _ in range(10):
|
||||
@ -1858,7 +1866,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await client.close()
|
||||
# Add cursor to kill cursors queue
|
||||
del cursor
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: client._kill_cursors_queue,
|
||||
"waited for cursor to be added to queue",
|
||||
)
|
||||
@ -2232,7 +2240,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
|
||||
await cursor.to_list()
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: len(client._kill_cursors_queue) == 0,
|
||||
"waited for all killCursor requests to complete",
|
||||
)
|
||||
@ -2403,7 +2411,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
|
||||
self.assertEqual(await c.address, ("a", 1))
|
||||
# Fail over.
|
||||
@ -2430,7 +2438,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
|
||||
# Total failure.
|
||||
c.kill_host("a:1")
|
||||
@ -2472,7 +2480,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION)
|
||||
c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1)
|
||||
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
wait_until(lambda: len(c.nodes) == 2, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 2, "connect")
|
||||
|
||||
c.kill_host("a:1")
|
||||
|
||||
@ -2544,11 +2552,11 @@ class TestClientPool(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
self.assertEqual(await c.address, ("a", 1))
|
||||
self.assertEqual(await c.arbiters, {("c", 3)})
|
||||
# Assert that we create 2 and only 2 pooled connections.
|
||||
listener.wait_for_event(monitoring.ConnectionReadyEvent, 2)
|
||||
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2)
|
||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
|
||||
# Assert that we do not create connections to arbiters.
|
||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||
@ -2574,10 +2582,10 @@ class TestClientPool(AsyncMockClientTest):
|
||||
)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 1, "connect")
|
||||
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
|
||||
self.assertEqual(await c.address, ("c", 3))
|
||||
# Assert that we create 1 pooled connection.
|
||||
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
|
||||
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
|
||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
|
||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||
self.assertEqual(len(arbiter.pool.conns), 1)
|
||||
|
||||
@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest):
|
||||
warn_context: Any
|
||||
collation: Collation
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
cls.collation = Collation("en_US")
|
||||
cls.warn_context = warnings.catch_warnings()
|
||||
cls.warn_context.__enter__()
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
self.db = self.client.pymongo_test
|
||||
self.collation = Collation("en_US")
|
||||
self.warn_context = warnings.catch_warnings()
|
||||
self.warn_context.__enter__()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.warn_context.__exit__()
|
||||
cls.warn_context = None
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def tearDown(self):
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.warn_context.__exit__()
|
||||
self.warn_context = None
|
||||
self.listener.reset()
|
||||
super().tearDown()
|
||||
await super().asyncTearDown()
|
||||
|
||||
def last_command_started(self):
|
||||
return self.listener.started_events[-1].command
|
||||
|
||||
@ -40,7 +40,6 @@ from test.utils import (
|
||||
async_get_pool,
|
||||
async_is_mongos,
|
||||
async_wait_until,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from bson import encode
|
||||
@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest):
|
||||
db: AsyncDatabase
|
||||
client: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
cls.client = AsyncMongoClient(connect=False)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.client = self.simple_client(connect=False)
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
def test_collection(self):
|
||||
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
|
||||
@ -165,27 +160,14 @@ class TestCollectionNoConnect(AsyncUnitTest):
|
||||
class AsyncTestCollection(AsyncIntegrationTest):
|
||||
w: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.w = async_client_context.w # type: ignore
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
|
||||
else:
|
||||
asyncio.run(cls.async_tearDownClass())
|
||||
|
||||
@classmethod
|
||||
async def async_tearDownClass(cls):
|
||||
await cls.db.drop_collection("test_large_limit")
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await self.db.test.drop()
|
||||
await super().asyncSetUp()
|
||||
self.w = async_client_context.w # type: ignore
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.db.test.drop()
|
||||
await self.db.drop_collection("test_large_limit")
|
||||
await super().asyncTearDown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def write_concern_collection(self):
|
||||
@ -1023,7 +1005,10 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
await db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
||||
await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
||||
|
||||
await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
|
||||
async def predicate():
|
||||
return await db_w0.test.find_one({"x": 1})
|
||||
|
||||
await async_wait_until(predicate, "find w:0 replaced document")
|
||||
|
||||
async def test_update_bypass_document_validation(self):
|
||||
db = self.db
|
||||
@ -1871,7 +1856,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
await cur.close()
|
||||
cur = None
|
||||
# Wait until the background thread returns the socket.
|
||||
wait_until(lambda: pool.active_sockets == 0, "return socket")
|
||||
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
|
||||
# The socket should be discarded.
|
||||
self.assertEqual(0, len(pool.conns))
|
||||
|
||||
|
||||
@ -19,7 +19,12 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
async_client_context,
|
||||
reset_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.asynchronous.helpers import async_repl_set_step_down
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
|
||||
listener: CMAPListener
|
||||
coll: AsyncCollection
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_replica_set
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = CMAPListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
async def asyncSetUp(self):
|
||||
self.listener = CMAPListener()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
)
|
||||
|
||||
# Ensure connections to all servers in replica set. This is to test
|
||||
# that the is_writable flag is properly updated for connections that
|
||||
# survive a replica set election.
|
||||
await async_ensure_all_connected(cls.client)
|
||||
cls.listener.reset()
|
||||
|
||||
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await async_ensure_all_connected(self.client)
|
||||
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
# Note that all ops use same write-concern as self.db (majority).
|
||||
await self.db.drop_collection("step-down")
|
||||
await self.db.create_collection("step-down")
|
||||
|
||||
@ -56,6 +56,9 @@ class TestCreateEntities(AsyncIntegrationTest):
|
||||
self.assertGreater(len(final_entity_map["events1"]), 0)
|
||||
for event in final_entity_map["events1"]:
|
||||
self.assertIn("PoolCreatedEvent", event["name"])
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
await client.close()
|
||||
|
||||
async def test_store_all_others_as_entities(self):
|
||||
self.scenario_runner = UnifiedSpecTestMixinV1()
|
||||
@ -122,6 +125,9 @@ class TestCreateEntities(AsyncIntegrationTest):
|
||||
self.assertEqual(entity_map["failures"], [])
|
||||
self.assertEqual(entity_map["successes"], 2)
|
||||
self.assertEqual(entity_map["iterations"], 5)
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -34,9 +34,9 @@ from test.utils import (
|
||||
AllowListEventListener,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
async_wait_until,
|
||||
delay,
|
||||
ignore_deprecations,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from bson import decode_all
|
||||
@ -1324,8 +1324,8 @@ class TestCursor(AsyncIntegrationTest):
|
||||
with self.assertRaises(ExecutionTimeout):
|
||||
await cursor.next()
|
||||
|
||||
def assertCursorKilled():
|
||||
wait_until(
|
||||
async def assertCursorKilled():
|
||||
await async_wait_until(
|
||||
lambda: len(listener.succeeded_events),
|
||||
"find successful killCursors command",
|
||||
)
|
||||
@ -1335,7 +1335,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
self.assertEqual(1, len(listener.succeeded_events))
|
||||
self.assertEqual("killCursors", listener.succeeded_events[0].command_name)
|
||||
|
||||
assertCursorKilled()
|
||||
await assertCursorKilled()
|
||||
listener.reset()
|
||||
|
||||
cursor = await coll.aggregate([], batchSize=1)
|
||||
@ -1345,7 +1345,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
with self.assertRaises(ExecutionTimeout):
|
||||
await cursor.next()
|
||||
|
||||
assertCursorKilled()
|
||||
await assertCursorKilled()
|
||||
|
||||
def test_delete_not_initialized(self):
|
||||
# Creating a cursor with invalid arguments will not run __init__
|
||||
@ -1647,10 +1647,6 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
|
||||
|
||||
class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
async def test_aggregate_raw(self):
|
||||
c = self.db.test
|
||||
await c.drop()
|
||||
|
||||
@ -717,7 +717,8 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
|
||||
|
||||
class TestDatabaseAggregation(AsyncIntegrationTest):
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.pipeline: List[Mapping[str, Any]] = [
|
||||
{"$listLocalSessions": {}},
|
||||
{"$limit": 1},
|
||||
|
||||
@ -211,11 +211,10 @@ class TestClientOptions(AsyncPyMongoTestCase):
|
||||
class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
|
||||
"""Base class for encryption integration tests."""
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@async_client_context.require_version_min(4, 2, -1)
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
|
||||
def assertEncrypted(self, val):
|
||||
self.assertIsInstance(val, Binary)
|
||||
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestClientMaxWireVersion(AsyncIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
@async_client_context.require_version_max(4, 0, 99)
|
||||
async def test_raise_max_wire_version_error(self):
|
||||
@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
||||
"local": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
await cls.client.db.coll.drop()
|
||||
cls.vault = await create_key_vault(cls.client.keyvault.datakeys)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
await self.client.db.coll.drop()
|
||||
self.vault = await create_key_vault(self.client.keyvault.datakeys)
|
||||
|
||||
# Configure the encrypted field via the local schema_map option.
|
||||
schemas = {
|
||||
@ -846,25 +843,22 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
||||
}
|
||||
}
|
||||
opts = AutoEncryptionOpts(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
|
||||
self.KMS_PROVIDERS,
|
||||
"keyvault.datakeys",
|
||||
schema_map=schemas,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
)
|
||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
||||
self.client_encrypted = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
cls.client_encryption = cls.unmanaged_create_client_encryption(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.vault.drop()
|
||||
await cls.client.close()
|
||||
await cls.client_encrypted.close()
|
||||
await cls.client_encryption.close()
|
||||
|
||||
def setUp(self):
|
||||
self.listener.reset()
|
||||
|
||||
async def asyncTearDown(self) -> None:
|
||||
await self.vault.drop()
|
||||
|
||||
async def run_test(self, provider_name):
|
||||
# Create data key.
|
||||
master_key: Any = self.MASTER_KEYS[provider_name]
|
||||
@ -1011,10 +1005,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestCorpus(AsyncEncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
|
||||
@staticmethod
|
||||
def kms_providers():
|
||||
@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
client_encrypted: AsyncMongoClient
|
||||
listener: OvertCommandListener
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
db = async_client_context.client.db
|
||||
cls.coll = db.coll
|
||||
await cls.coll.drop()
|
||||
self.coll = db.coll
|
||||
await self.coll.drop()
|
||||
# Configure the encrypted 'db.coll' collection via jsonSchema.
|
||||
json_schema = json_data("limits", "limits-schema.json")
|
||||
await db.create_collection(
|
||||
@ -1211,17 +1203,14 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
await coll.insert_one(json_data("limits", "limits-key.json"))
|
||||
|
||||
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[cls.listener]
|
||||
self.listener = OvertCommandListener()
|
||||
self.client_encrypted = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[self.listener]
|
||||
)
|
||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
||||
self.coll_encrypted = self.client_encrypted.db.coll
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.coll_encrypted.drop()
|
||||
await cls.client_encrypted.close()
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self) -> None:
|
||||
await self.coll_encrypted.drop()
|
||||
|
||||
async def test_01_insert_succeeds_under_2MiB(self):
|
||||
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
||||
@ -1245,7 +1234,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
|
||||
self.listener.reset()
|
||||
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
||||
self.assertEqual(
|
||||
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||
)
|
||||
|
||||
async def test_04_bulk_batch_split(self):
|
||||
limits_doc = json_data("limits", "limits-doc.json")
|
||||
@ -1255,7 +1246,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
doc2.update(limits_doc)
|
||||
self.listener.reset()
|
||||
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
||||
self.assertEqual(
|
||||
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||
)
|
||||
|
||||
async def test_05_insert_succeeds_just_under_16MiB(self):
|
||||
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
|
||||
@ -1285,15 +1278,12 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
||||
class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||
"""Prose tests for creating data keys with a custom endpoint."""
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
kms_providers = {
|
||||
"aws": AWS_CREDS,
|
||||
"azure": AZURE_CREDS,
|
||||
@ -1322,10 +1312,6 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||
self._kmip_host_error = None
|
||||
self._invalid_host_error = None
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.client_encryption.close()
|
||||
await self.client_encryption_invalid.close()
|
||||
|
||||
async def run_test_expected_success(self, provider_name, master_key):
|
||||
data_key_id = await self.client_encryption.create_data_key(
|
||||
provider_name, master_key=master_key
|
||||
@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
KEYVAULT_COLL = "datakeys"
|
||||
client: AsyncMongoClient
|
||||
|
||||
async def asyncSetUp(self):
|
||||
async def _setup(self):
|
||||
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
|
||||
await create_key_vault(keyvault, self.DEK)
|
||||
|
||||
async def _test_explicit(self, expectation):
|
||||
await self._setup()
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||
async_client_context.client,
|
||||
OPTS,
|
||||
)
|
||||
self.addAsyncCleanup(client_encryption.close)
|
||||
|
||||
ciphertext = await client_encryption.encrypt(
|
||||
"string0",
|
||||
@ -1523,6 +1509,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")
|
||||
|
||||
async def _test_automatic(self, expectation_extjson, payload):
|
||||
await self._setup()
|
||||
encrypted_db = "db"
|
||||
encrypted_coll = "coll"
|
||||
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
||||
@ -1537,7 +1524,6 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
client = await self.async_rs_or_single_client(
|
||||
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
|
||||
coll = client.get_database(encrypted_db).get_collection(
|
||||
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
|
||||
@ -1559,13 +1545,12 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
||||
async def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||
self.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super().asyncSetUp()
|
||||
|
||||
async def test_explicit(self):
|
||||
return await self._test_explicit(
|
||||
@ -1585,13 +1570,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegratio
|
||||
|
||||
|
||||
class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
||||
async def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
await super().asyncSetUp()
|
||||
|
||||
async def test_explicit(self):
|
||||
return await self._test_explicit(
|
||||
@ -1613,6 +1597,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
|
||||
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.client_test = await self.async_rs_or_single_client(
|
||||
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
||||
)
|
||||
@ -1645,7 +1630,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
self.ciphertext = await client_encryption.encrypt(
|
||||
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
|
||||
)
|
||||
await client_encryption.close()
|
||||
|
||||
self.client_listener = OvertCommandListener()
|
||||
self.topology_listener = TopologyEventListener()
|
||||
@ -1840,6 +1824,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
|
||||
class TestDecryptProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.client = async_client_context.client
|
||||
await self.client.db.drop_collection("decryption_events")
|
||||
await create_key_vault(self.client.keyvault.datakeys)
|
||||
@ -2275,6 +2260,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
|
||||
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.client = async_client_context.client
|
||||
await create_key_vault(self.client.keyvault.datakeys)
|
||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||
@ -2624,8 +2610,6 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
|
||||
assert isinstance(res["encrypted_indexed"], Binary)
|
||||
assert isinstance(res["encrypted_unindexed"], Binary)
|
||||
|
||||
await client_encryption.close()
|
||||
|
||||
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
|
||||
class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
|
||||
@ -3089,17 +3073,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
|
||||
mongocryptd_client: AsyncMongoClient
|
||||
MONGOCRYPTD_PORT = 27020
|
||||
|
||||
@classmethod
|
||||
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
start_mongocryptd(cls.MONGOCRYPTD_PORT)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
start_mongocryptd(self.MONGOCRYPTD_PORT)
|
||||
|
||||
self.listener = OvertCommandListener()
|
||||
self.mongocryptd_client = self.simple_client(
|
||||
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
||||
|
||||
@ -97,6 +97,7 @@ class AsyncTestGridFileNoConnect(AsyncUnitTest):
|
||||
|
||||
class AsyncTestGridFile(AsyncIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
||||
|
||||
async def test_basic(self):
|
||||
|
||||
@ -16,21 +16,20 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
from pymongo.lock import _async_create_condition, _async_create_lock
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.lock import _ACondition
|
||||
|
||||
|
||||
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
||||
# Includes tests for:
|
||||
# - https://github.com/python/cpython/issues/111693
|
||||
# - https://github.com/python/cpython/issues/112202
|
||||
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
if sys.version_info < (3, 13):
|
||||
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
||||
# Includes tests for:
|
||||
# - https://github.com/python/cpython/issues/111693
|
||||
# - https://github.com/python/cpython/issues/112202
|
||||
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_wait(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
@ -93,7 +92,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_wait_cancel(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
await cond.acquire()
|
||||
|
||||
wait = asyncio.create_task(cond.wait())
|
||||
@ -104,7 +103,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_contested(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
@ -132,7 +131,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
# See bpo-32841
|
||||
waited = False
|
||||
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
async def wait_on_cond():
|
||||
nonlocal waited
|
||||
@ -155,12 +154,12 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(waited)
|
||||
|
||||
async def test_wait_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait()
|
||||
|
||||
async def test_wait_for(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
presult = False
|
||||
|
||||
def predicate():
|
||||
@ -197,7 +196,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(t.result())
|
||||
|
||||
async def test_wait_for_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
# predicate can return true immediately
|
||||
res = await cond.wait_for(lambda: [1, 2, 3])
|
||||
@ -207,7 +206,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
await cond.wait_for(lambda: False)
|
||||
|
||||
async def test_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
@ -254,7 +253,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
result = []
|
||||
|
||||
@ -287,14 +286,14 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertTrue(t2.result())
|
||||
|
||||
async def test_context_manager(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
self.assertFalse(cond.locked())
|
||||
async with cond:
|
||||
self.assertTrue(cond.locked())
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
async def test_timeout_in_block(self):
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
condition = _async_create_condition(_async_create_lock())
|
||||
async with condition:
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
||||
@ -307,7 +306,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
@ -336,7 +335,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
cond = _async_create_condition(_async_create_lock())
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
@ -369,7 +368,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is awaiting initial
|
||||
# wakeup on the wakeup queue.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
condition = _async_create_condition(_async_create_lock())
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
@ -415,7 +414,7 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is acquiring the lock
|
||||
# again.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
condition = _async_create_condition(_async_create_lock())
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
@ -459,55 +458,5 @@ class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
|
||||
class TestCondition(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_multiple_loops_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
def tmain(cond):
|
||||
async def atmain(cond):
|
||||
await asyncio.sleep(1)
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
|
||||
asyncio.run(atmain(cond))
|
||||
|
||||
t = threading.Thread(target=tmain, args=(cond,))
|
||||
t.start()
|
||||
|
||||
async with cond:
|
||||
self.assertTrue(await cond.wait(30))
|
||||
t.join()
|
||||
|
||||
async def test_multiple_loops_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
results = []
|
||||
|
||||
def tmain(cond, results):
|
||||
async def atmain(cond, results):
|
||||
await asyncio.sleep(1)
|
||||
async with cond:
|
||||
res = await cond.wait(30)
|
||||
results.append(res)
|
||||
|
||||
asyncio.run(atmain(cond, results))
|
||||
|
||||
nthreads = 5
|
||||
threads = []
|
||||
for _ in range(nthreads):
|
||||
threads.append(threading.Thread(target=tmain, args=(cond, results)))
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
await asyncio.sleep(2)
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
self.assertEqual(results, [True] * nthreads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
def setUpClass(cls) -> None:
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
@async_client_context.require_connection
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.listener.reset()
|
||||
await super().asyncTearDown()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
event_listeners=[self.listener], retryWrites=False
|
||||
)
|
||||
|
||||
async def test_started_simple(self):
|
||||
await self.client.pymongo_test.command("ping")
|
||||
@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
|
||||
saved_listeners: Any
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
def setUpClass(cls) -> None:
|
||||
cls.listener = OvertCommandListener()
|
||||
# We plan to call register(), which internally modifies _LISTENERS.
|
||||
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
||||
monitoring.register(cls.listener)
|
||||
cls.client = await cls.unmanaged_async_single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
await cls.client.pymongo_test.command("ping")
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
monitoring._LISTENERS = cls.saved_listeners
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener.reset()
|
||||
self.client = await self.async_single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
await self.client.pymongo_test.command("ping")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
monitoring._LISTENERS = cls.saved_listeners
|
||||
|
||||
async def test_simple(self):
|
||||
await self.client.pymongo_test.command("ping")
|
||||
|
||||
@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
|
||||
RUN_ON_SERVERLESS = True
|
||||
deprecation_filter: DeprecationFilter
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.deprecation_filter = DeprecationFilter()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.deprecation_filter = DeprecationFilter()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.deprecation_filter.stop()
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.deprecation_filter.stop()
|
||||
|
||||
|
||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True)
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.client = await self.async_rs_or_single_client(retryWrites=True)
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.knobs.disable()
|
||||
|
||||
@async_client_context.require_no_standalone
|
||||
async def test_actionable_error_message(self):
|
||||
@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
listener: OvertCommandListener
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_no_mmap
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[cls.listener]
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[self.listener]
|
||||
)
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
if async_client_context.is_rs and async_client_context.test_commands_enabled:
|
||||
await self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
||||
@ -210,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
await self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||
)
|
||||
self.knobs.disable()
|
||||
|
||||
async def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest):
|
||||
RUN_ON_SERVERLESS = True
|
||||
fail_insert: dict
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_no_mmap
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.fail_insert = {
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.fail_insert = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 2},
|
||||
"data": {
|
||||
|
||||
@ -38,7 +38,6 @@ from test.utils import (
|
||||
ExceptionCatchingThread,
|
||||
OvertCommandListener,
|
||||
async_wait_until,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from bson import DBRef
|
||||
@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest):
|
||||
client2: AsyncMongoClient
|
||||
sensitive_commands: Set[str]
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_sessions
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
# Create a second client so we can make sure clients cannot share
|
||||
# sessions.
|
||||
cls.client2 = await cls.unmanaged_async_rs_or_single_client()
|
||||
self.client2 = await self.async_rs_or_single_client()
|
||||
|
||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
monitoring._SENSITIVE_COMMANDS.clear()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
|
||||
await cls.client2.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
self.listener = SessionTestListener()
|
||||
self.session_checker_listener = SessionTestListener()
|
||||
self.client = await self.async_rs_or_single_client(
|
||||
event_listeners=[self.listener, self.session_checker_listener]
|
||||
)
|
||||
self.addAsyncCleanup(self.client.close)
|
||||
self.db = self.client.pymongo_test
|
||||
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
||||
|
||||
async def asyncTearDown(self):
|
||||
"""All sessions used in the test must be returned to the pool."""
|
||||
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
|
||||
await self.client.drop_database("pymongo_test")
|
||||
used_lsids = self.initial_lsids.copy()
|
||||
for event in self.session_checker_listener.started_events:
|
||||
@ -122,6 +112,8 @@ class TestSession(AsyncIntegrationTest):
|
||||
current_lsids = {s["id"] for s in session_ids(self.client)}
|
||||
self.assertLessEqual(used_lsids, current_lsids)
|
||||
|
||||
await super().asyncTearDown()
|
||||
|
||||
async def _test_ops(self, client, *ops):
|
||||
listener = client.options.event_listeners[0]
|
||||
|
||||
@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest):
|
||||
listener: SessionTestListener
|
||||
client: AsyncMongoClient
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
cls.listener = SessionTestListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.close()
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener = SessionTestListener()
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
|
||||
@async_client_context.require_no_standalone
|
||||
async def test_core(self):
|
||||
|
||||
@ -26,7 +26,7 @@ sys.path[0:0] = [""]
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
wait_until,
|
||||
async_wait_until,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
@ -162,7 +162,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
client = await self.async_rs_client(
|
||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||
)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
@ -191,7 +191,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
client = await self.async_rs_client(
|
||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||
)
|
||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
@ -403,21 +403,12 @@ class PatchSessionTimeout:
|
||||
|
||||
|
||||
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.mongos_clients = []
|
||||
if async_client_context.supports_transactions():
|
||||
for address in async_client_context.mongoses:
|
||||
cls.mongos_clients.append(
|
||||
await cls.unmanaged_async_single_client("{}:{}".format(*address))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
for client in cls.mongos_clients:
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
|
||||
|
||||
async def _set_fail_point(self, client, command_args):
|
||||
cmd = {"configureFailPoint": "failCommand"}
|
||||
|
||||
@ -50,6 +50,7 @@ from test.unified_format_shared import (
|
||||
)
|
||||
from test.utils import (
|
||||
async_get_pool,
|
||||
async_wait_until,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
parse_spec_options,
|
||||
@ -304,7 +305,6 @@ class EntityMapUtil:
|
||||
kwargs["h"] = uri
|
||||
client = await self.test.async_rs_or_single_client(**kwargs)
|
||||
self[spec["id"]] = client
|
||||
self.test.addAsyncCleanup(client.close)
|
||||
return
|
||||
elif entity_type == "database":
|
||||
client = self[spec["client"]]
|
||||
@ -479,33 +479,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
await db.create_collection(coll_name, write_concern=wc, **opts)
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
# super call creates internal client cls.client
|
||||
await super()._setup_class()
|
||||
# process file-level runOnRequirements
|
||||
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
|
||||
if not await cls.should_run_on(run_on_spec):
|
||||
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
|
||||
|
||||
# add any special-casing for skipping tests here
|
||||
if async_client_context.storage_engine == "mmapv1":
|
||||
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||
cls.TEST_PATH
|
||||
):
|
||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||
|
||||
# Handle mongos_clients for transactions tests.
|
||||
cls.mongos_clients = []
|
||||
if (
|
||||
async_client_context.supports_transactions()
|
||||
and not async_client_context.load_balancer
|
||||
and not async_client_context.serverless
|
||||
):
|
||||
for address in async_client_context.mongoses:
|
||||
cls.mongos_clients.append(
|
||||
await cls.unmanaged_async_single_client("{}:{}".format(*address))
|
||||
)
|
||||
|
||||
def setUpClass(cls) -> None:
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(
|
||||
heartbeat_frequency=0.1,
|
||||
@ -516,17 +490,36 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
cls.knobs.enable()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
def tearDownClass(cls) -> None:
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
# super call creates internal client cls.client
|
||||
await super().asyncSetUp()
|
||||
# process file-level runOnRequirements
|
||||
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
|
||||
if not await self.should_run_on(run_on_spec):
|
||||
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
|
||||
|
||||
# add any special-casing for skipping tests here
|
||||
if async_client_context.storage_engine == "mmapv1":
|
||||
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||
self.TEST_PATH
|
||||
):
|
||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||
|
||||
# Handle mongos_clients for transactions tests.
|
||||
self.mongos_clients = []
|
||||
if (
|
||||
async_client_context.supports_transactions()
|
||||
and not async_client_context.load_balancer
|
||||
and not async_client_context.serverless
|
||||
):
|
||||
for address in async_client_context.mongoses:
|
||||
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
|
||||
|
||||
# process schemaVersion
|
||||
# note: we check major schema version during class generation
|
||||
# note: we do this here because we cannot run assertions in setUpClass
|
||||
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
|
||||
self.assertLessEqual(
|
||||
version,
|
||||
@ -1036,7 +1029,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
)
|
||||
|
||||
client = await self.async_single_client("{}:{}".format(*session._pinned_address))
|
||||
self.addAsyncCleanup(client.close)
|
||||
await self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
||||
|
||||
async def _testOperation_createEntities(self, spec):
|
||||
@ -1137,13 +1129,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
client, event, count = spec["client"], spec["event"], spec["count"]
|
||||
self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}")
|
||||
|
||||
def _testOperation_waitForEvent(self, spec):
|
||||
async def _testOperation_waitForEvent(self, spec):
|
||||
"""Run the waitForEvent test operation.
|
||||
|
||||
Wait for a number of events to be published, or fail.
|
||||
"""
|
||||
client, event, count = spec["client"], spec["event"], spec["count"]
|
||||
wait_until(
|
||||
await async_wait_until(
|
||||
lambda: self._event_count(client, event) >= count,
|
||||
f"find {count} {event} event(s)",
|
||||
)
|
||||
|
||||
@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
knobs: client_knobs
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
async def asyncSetUp(self) -> None:
|
||||
await super().asyncSetUp()
|
||||
self.mongos_clients = []
|
||||
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.targets = {}
|
||||
self.listener = None # type: ignore
|
||||
self.pool_listener = None
|
||||
self.server_listener = None
|
||||
self.maxDiff = None
|
||||
|
||||
async def asyncTearDown(self) -> None:
|
||||
self.knobs.disable()
|
||||
|
||||
async def _set_fail_point(self, client, command_args):
|
||||
cmd = SON([("configureFailPoint", "failCommand")])
|
||||
cmd.update(command_args)
|
||||
@ -700,8 +692,6 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
self.listener = listener
|
||||
self.pool_listener = pool_listener
|
||||
self.server_listener = server_listener
|
||||
# Close the client explicitly to avoid having too many threads open.
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
# Create session0 and session1.
|
||||
sessions = {}
|
||||
|
||||
@ -20,7 +20,7 @@ def event_loop_policy():
|
||||
return asyncio.get_event_loop_policy()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(scope="package", autouse=True)
|
||||
def test_setup_and_teardown():
|
||||
setup()
|
||||
yield
|
||||
|
||||
@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest):
|
||||
coll: Collection
|
||||
coll_w0: Collection
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.coll = cls.db.test
|
||||
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.coll = self.db.test
|
||||
self.coll.drop()
|
||||
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
def assertEqualResponse(self, expected, actual):
|
||||
"""Compare response from bulk.execute() to expected response."""
|
||||
@ -785,12 +781,8 @@ class TestBulk(BulkTestBase):
|
||||
|
||||
|
||||
class BulkAuthorizationTestBase(BulkTestBase):
|
||||
@classmethod
|
||||
@client_context.require_auth
|
||||
@client_context.require_no_api_version
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
||||
@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
w: Optional[int]
|
||||
secondary: MongoClient
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.w = client_context.w
|
||||
cls.secondary = None
|
||||
if cls.w is not None and cls.w > 1:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.w = client_context.w
|
||||
self.secondary = None
|
||||
if self.w is not None and self.w > 1:
|
||||
for member in (client_context.hello)["hosts"]:
|
||||
if member != (client_context.hello)["primary"]:
|
||||
cls.secondary = cls.unmanaged_single_client(*partition_node(member))
|
||||
self.secondary = self.single_client(*partition_node(member))
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def async_tearDownClass(cls):
|
||||
if cls.secondary:
|
||||
cls.secondary.close()
|
||||
def tearDown(self):
|
||||
if self.secondary:
|
||||
self.secondary.close()
|
||||
|
||||
def cause_wtimeout(self, requests, ordered):
|
||||
if not client_context.test_commands_enabled:
|
||||
|
||||
@ -820,18 +820,16 @@ class ProseSpecTestsMixin:
|
||||
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
dbs: list
|
||||
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0, -1)
|
||||
@client_context.require_change_streams
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.dbs = [cls.db, cls.client.pymongo_test_2]
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.dbs = [self.db, self.client.pymongo_test_2]
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
for db in cls.dbs:
|
||||
cls.client.drop_database(db)
|
||||
super()._tearDown_class()
|
||||
def tearDown(self):
|
||||
for db in self.dbs:
|
||||
self.client.drop_database(db)
|
||||
super().tearDown()
|
||||
|
||||
def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return client.watch(*args, **kwargs)
|
||||
@ -882,11 +880,10 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
|
||||
|
||||
class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0, -1)
|
||||
@client_context.require_change_streams
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
def change_stream_with_client(self, client, *args, **kwargs):
|
||||
return client[self.db.name].watch(*args, **kwargs)
|
||||
@ -968,12 +965,9 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||
|
||||
|
||||
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
|
||||
@classmethod
|
||||
@client_context.require_change_streams
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Use a new collection for each test.
|
||||
self.watched_collection().drop()
|
||||
self.watched_collection().insert_one({})
|
||||
@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
listener: AllowListEventListener
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.listener = AllowListEventListener("aggregate", "getMore")
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||
self.listener.reset()
|
||||
|
||||
def setUpCluster(self, scenario_dict):
|
||||
|
||||
@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest):
|
||||
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
def setUp(self) -> None:
|
||||
self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
@ -1039,14 +1034,21 @@ class TestClient(IntegrationTest):
|
||||
self.assertFalse(client._topology._opened)
|
||||
|
||||
# Ensure kill cursors thread has not been started.
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertFalse(kc_task and not kc_task.done())
|
||||
# Using the client should open topology and start the thread.
|
||||
client.admin.command("ping")
|
||||
self.assertTrue(client._topology._opened)
|
||||
if _IS_SYNC:
|
||||
kc_thread = client._kill_cursors_executor._thread
|
||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||
else:
|
||||
kc_task = client._kill_cursors_executor._task
|
||||
self.assertTrue(kc_task and not kc_task.done())
|
||||
|
||||
def test_close_does_not_open_servers(self):
|
||||
client = self.rs_client(connect=False)
|
||||
@ -1241,6 +1243,7 @@ class TestClient(IntegrationTest):
|
||||
def test_server_selection_timeout(self):
|
||||
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
client = MongoClient(serverSelectionTimeoutMS=0, connect=False)
|
||||
|
||||
@ -1251,16 +1254,20 @@ class TestClient(IntegrationTest):
|
||||
self.assertRaises(
|
||||
ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
|
||||
)
|
||||
client.close()
|
||||
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
|
||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
||||
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
# Test invalid timeout in URI ignored and set to default.
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
client.close()
|
||||
|
||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||
|
||||
@ -97,26 +97,19 @@ class TestCollation(IntegrationTest):
|
||||
warn_context: Any
|
||||
collation: Collation
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
cls.collation = Collation("en_US")
|
||||
cls.warn_context = warnings.catch_warnings()
|
||||
cls.warn_context.__enter__()
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||
self.db = self.client.pymongo_test
|
||||
self.collation = Collation("en_US")
|
||||
self.warn_context = warnings.catch_warnings()
|
||||
self.warn_context.__enter__()
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.warn_context.__exit__()
|
||||
cls.warn_context = None
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
self.warn_context.__exit__()
|
||||
self.warn_context = None
|
||||
self.listener.reset()
|
||||
super().tearDown()
|
||||
|
||||
|
||||
@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest):
|
||||
db: Database
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls.client = MongoClient(connect=False)
|
||||
cls.db = cls.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.client = self.simple_client(connect=False)
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
def test_collection(self):
|
||||
self.assertRaises(TypeError, Collection, self.db, 5)
|
||||
@ -164,27 +160,14 @@ class TestCollectionNoConnect(UnitTest):
|
||||
class TestCollection(IntegrationTest):
|
||||
w: int
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.w = client_context.w # type: ignore
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
|
||||
else:
|
||||
asyncio.run(cls.async_tearDownClass())
|
||||
|
||||
@classmethod
|
||||
def async_tearDownClass(cls):
|
||||
cls.db.drop_collection("test_large_limit")
|
||||
|
||||
def setUp(self):
|
||||
self.db.test.drop()
|
||||
super().setUp()
|
||||
self.w = client_context.w # type: ignore
|
||||
|
||||
def tearDown(self):
|
||||
self.db.test.drop()
|
||||
self.db.drop_collection("test_large_limit")
|
||||
super().tearDown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def write_concern_collection(self):
|
||||
@ -1010,7 +993,10 @@ class TestCollection(IntegrationTest):
|
||||
db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
||||
db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
||||
|
||||
wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
|
||||
def predicate():
|
||||
return db_w0.test.find_one({"x": 1})
|
||||
|
||||
wait_until(predicate, "find w:0 replaced document")
|
||||
|
||||
def test_update_bypass_document_validation(self):
|
||||
db = self.db
|
||||
|
||||
@ -19,7 +19,12 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
client_context,
|
||||
reset_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.helpers import repl_set_step_down
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
||||
listener: CMAPListener
|
||||
coll: Collection
|
||||
|
||||
@classmethod
|
||||
@client_context.require_replica_set
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = CMAPListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
def setUp(self):
|
||||
self.listener = CMAPListener()
|
||||
self.client = self.rs_or_single_client(
|
||||
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||
)
|
||||
|
||||
# Ensure connections to all servers in replica set. This is to test
|
||||
# that the is_writable flag is properly updated for connections that
|
||||
# survive a replica set election.
|
||||
ensure_all_connected(cls.client)
|
||||
cls.listener.reset()
|
||||
|
||||
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
|
||||
def setUp(self):
|
||||
ensure_all_connected(self.client)
|
||||
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||
# Note that all ops use same write-concern as self.db (majority).
|
||||
self.db.drop_collection("step-down")
|
||||
self.db.create_collection("step-down")
|
||||
|
||||
@ -56,6 +56,9 @@ class TestCreateEntities(IntegrationTest):
|
||||
self.assertGreater(len(final_entity_map["events1"]), 0)
|
||||
for event in final_entity_map["events1"]:
|
||||
self.assertIn("PoolCreatedEvent", event["name"])
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
client.close()
|
||||
|
||||
def test_store_all_others_as_entities(self):
|
||||
self.scenario_runner = UnifiedSpecTestMixinV1()
|
||||
@ -122,6 +125,9 @@ class TestCreateEntities(IntegrationTest):
|
||||
self.assertEqual(entity_map["failures"], [])
|
||||
self.assertEqual(entity_map["successes"], 2)
|
||||
self.assertEqual(entity_map["iterations"], 5)
|
||||
if self.scenario_runner.mongos_clients:
|
||||
for client in self.scenario_runner.mongos_clients:
|
||||
client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1636,10 +1636,6 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
|
||||
|
||||
class TestRawBatchCommandCursor(IntegrationTest):
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def test_aggregate_raw(self):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
|
||||
@ -633,6 +633,7 @@ class TestTypeRegistry(unittest.TestCase):
|
||||
|
||||
class TestCollectionWCustomType(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.drop()
|
||||
|
||||
def tearDown(self):
|
||||
@ -754,6 +755,7 @@ class TestCollectionWCustomType(IntegrationTest):
|
||||
|
||||
class TestGridFileCustomType(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.drop_collection("fs.files")
|
||||
self.db.drop_collection("fs.chunks")
|
||||
|
||||
@ -917,11 +919,10 @@ class ChangeStreamsWCustomTypesTestMixin:
|
||||
|
||||
|
||||
class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_change_streams
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.db.test.delete_many({})
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.delete_many({})
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
@ -935,12 +936,11 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus
|
||||
|
||||
|
||||
class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
@client_context.require_change_streams
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.db.test.delete_many({})
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.delete_many({})
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
@ -954,12 +954,11 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto
|
||||
|
||||
|
||||
class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||
@classmethod
|
||||
@client_context.require_version_min(4, 0, 0)
|
||||
@client_context.require_change_streams
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.db.test.delete_many({})
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.delete_many({})
|
||||
|
||||
def tearDown(self):
|
||||
self.input_target.drop()
|
||||
|
||||
@ -709,6 +709,7 @@ class TestDatabase(IntegrationTest):
|
||||
|
||||
class TestDatabaseAggregation(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.pipeline: List[Mapping[str, Any]] = [
|
||||
{"$listLocalSessions": {}},
|
||||
{"$limit": 1},
|
||||
|
||||
@ -211,11 +211,10 @@ class TestClientOptions(PyMongoTestCase):
|
||||
class EncryptionIntegrationTest(IntegrationTest):
|
||||
"""Base class for encryption integration tests."""
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@client_context.require_version_min(4, 2, -1)
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
def assertEncrypted(self, val):
|
||||
self.assertIsInstance(val, Binary)
|
||||
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestClientMaxWireVersion(IntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@client_context.require_version_max(4, 0, 99)
|
||||
def test_raise_max_wire_version_error(self):
|
||||
@ -816,17 +814,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
"local": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.client.db.coll.drop()
|
||||
cls.vault = create_key_vault(cls.client.keyvault.datakeys)
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||
self.client.db.coll.drop()
|
||||
self.vault = create_key_vault(self.client.keyvault.datakeys)
|
||||
|
||||
# Configure the encrypted field via the local schema_map option.
|
||||
schemas = {
|
||||
@ -844,25 +841,22 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
||||
}
|
||||
}
|
||||
opts = AutoEncryptionOpts(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
|
||||
self.KMS_PROVIDERS,
|
||||
"keyvault.datakeys",
|
||||
schema_map=schemas,
|
||||
kms_tls_options=KMS_TLS_OPTS,
|
||||
)
|
||||
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
|
||||
self.client_encrypted = self.rs_or_single_client(
|
||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||
)
|
||||
cls.client_encryption = cls.unmanaged_create_client_encryption(
|
||||
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
self.client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.vault.drop()
|
||||
cls.client.close()
|
||||
cls.client_encrypted.close()
|
||||
cls.client_encryption.close()
|
||||
|
||||
def setUp(self):
|
||||
self.listener.reset()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.vault.drop()
|
||||
|
||||
def run_test(self, provider_name):
|
||||
# Create data key.
|
||||
master_key: Any = self.MASTER_KEYS[provider_name]
|
||||
@ -1007,10 +1001,9 @@ class TestViews(EncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestCorpus(EncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@staticmethod
|
||||
def kms_providers():
|
||||
@ -1184,12 +1177,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
client_encrypted: MongoClient
|
||||
listener: OvertCommandListener
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
db = client_context.client.db
|
||||
cls.coll = db.coll
|
||||
cls.coll.drop()
|
||||
self.coll = db.coll
|
||||
self.coll.drop()
|
||||
# Configure the encrypted 'db.coll' collection via jsonSchema.
|
||||
json_schema = json_data("limits", "limits-schema.json")
|
||||
db.create_collection(
|
||||
@ -1207,17 +1199,14 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
coll.insert_one(json_data("limits", "limits-key.json"))
|
||||
|
||||
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[cls.listener]
|
||||
self.listener = OvertCommandListener()
|
||||
self.client_encrypted = self.rs_or_single_client(
|
||||
auto_encryption_opts=opts, event_listeners=[self.listener]
|
||||
)
|
||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
||||
self.coll_encrypted = self.client_encrypted.db.coll
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.coll_encrypted.drop()
|
||||
cls.client_encrypted.close()
|
||||
super()._tearDown_class()
|
||||
def tearDown(self) -> None:
|
||||
self.coll_encrypted.drop()
|
||||
|
||||
def test_01_insert_succeeds_under_2MiB(self):
|
||||
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
||||
@ -1241,7 +1230,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
|
||||
self.listener.reset()
|
||||
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
||||
self.assertEqual(
|
||||
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||
)
|
||||
|
||||
def test_04_bulk_batch_split(self):
|
||||
limits_doc = json_data("limits", "limits-doc.json")
|
||||
@ -1251,7 +1242,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
doc2.update(limits_doc)
|
||||
self.listener.reset()
|
||||
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
||||
self.assertEqual(
|
||||
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||
)
|
||||
|
||||
def test_05_insert_succeeds_just_under_16MiB(self):
|
||||
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
|
||||
@ -1281,15 +1274,12 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
||||
class TestCustomEndpoint(EncryptionIntegrationTest):
|
||||
"""Prose tests for creating data keys with a custom endpoint."""
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(
|
||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||
"No environment credentials are set",
|
||||
)
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
kms_providers = {
|
||||
"aws": AWS_CREDS,
|
||||
"azure": AZURE_CREDS,
|
||||
@ -1318,10 +1308,6 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
|
||||
self._kmip_host_error = None
|
||||
self._invalid_host_error = None
|
||||
|
||||
def tearDown(self):
|
||||
self.client_encryption.close()
|
||||
self.client_encryption_invalid.close()
|
||||
|
||||
def run_test_expected_success(self, provider_name, master_key):
|
||||
data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key)
|
||||
encrypted = self.client_encryption.encrypt(
|
||||
@ -1494,18 +1480,18 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
||||
KEYVAULT_COLL = "datakeys"
|
||||
client: MongoClient
|
||||
|
||||
def setUp(self):
|
||||
def _setup(self):
|
||||
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
|
||||
create_key_vault(keyvault, self.DEK)
|
||||
|
||||
def _test_explicit(self, expectation):
|
||||
self._setup()
|
||||
client_encryption = self.create_client_encryption(
|
||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||
client_context.client,
|
||||
OPTS,
|
||||
)
|
||||
self.addCleanup(client_encryption.close)
|
||||
|
||||
ciphertext = client_encryption.encrypt(
|
||||
"string0",
|
||||
@ -1517,6 +1503,7 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
||||
self.assertEqual(client_encryption.decrypt(ciphertext), "string0")
|
||||
|
||||
def _test_automatic(self, expectation_extjson, payload):
|
||||
self._setup()
|
||||
encrypted_db = "db"
|
||||
encrypted_coll = "coll"
|
||||
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
||||
@ -1531,7 +1518,6 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
||||
client = self.rs_or_single_client(
|
||||
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
|
||||
coll = client.get_database(encrypted_db).get_collection(
|
||||
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
|
||||
@ -1553,13 +1539,12 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
||||
|
||||
|
||||
class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
||||
def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
super()._setup_class()
|
||||
def setUp(self):
|
||||
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||
self.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
super().setUp()
|
||||
|
||||
def test_explicit(self):
|
||||
return self._test_explicit(
|
||||
@ -1579,13 +1564,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest
|
||||
|
||||
|
||||
class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
||||
def _setup_class(cls):
|
||||
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
super()._setup_class()
|
||||
def setUp(self):
|
||||
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||
super().setUp()
|
||||
|
||||
def test_explicit(self):
|
||||
return self._test_explicit(
|
||||
@ -1607,6 +1591,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
|
||||
class TestDeadlockProse(EncryptionIntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.client_test = self.rs_or_single_client(
|
||||
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
||||
)
|
||||
@ -1637,7 +1622,6 @@ class TestDeadlockProse(EncryptionIntegrationTest):
|
||||
self.ciphertext = client_encryption.encrypt(
|
||||
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
|
||||
)
|
||||
client_encryption.close()
|
||||
|
||||
self.client_listener = OvertCommandListener()
|
||||
self.topology_listener = TopologyEventListener()
|
||||
@ -1832,6 +1816,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
|
||||
class TestDecryptProse(EncryptionIntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.client = client_context.client
|
||||
self.client.db.drop_collection("decryption_events")
|
||||
create_key_vault(self.client.keyvault.datakeys)
|
||||
@ -2267,6 +2252,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
|
||||
class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.client = client_context.client
|
||||
create_key_vault(self.client.keyvault.datakeys)
|
||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||
@ -2608,8 +2594,6 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
|
||||
assert isinstance(res["encrypted_indexed"], Binary)
|
||||
assert isinstance(res["encrypted_unindexed"], Binary)
|
||||
|
||||
client_encryption.close()
|
||||
|
||||
|
||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
|
||||
class TestRangeQueryProse(EncryptionIntegrationTest):
|
||||
@ -3071,17 +3055,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest):
|
||||
mongocryptd_client: MongoClient
|
||||
MONGOCRYPTD_PORT = 27020
|
||||
|
||||
@classmethod
|
||||
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
start_mongocryptd(cls.MONGOCRYPTD_PORT)
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
start_mongocryptd(self.MONGOCRYPTD_PORT)
|
||||
|
||||
self.listener = OvertCommandListener()
|
||||
self.mongocryptd_client = self.simple_client(
|
||||
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
||||
|
||||
@ -33,19 +33,14 @@ from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
class TestSampleShellCommands(IntegrationTest):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
# Run once before any tests run.
|
||||
cls.db.inventory.drop()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.client.drop_database("pymongo_test")
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.inventory.drop()
|
||||
|
||||
def tearDown(self):
|
||||
# Run after every test.
|
||||
self.db.inventory.drop()
|
||||
self.client.drop_database("pymongo_test")
|
||||
|
||||
def test_first_three_examples(self):
|
||||
db = self.db
|
||||
|
||||
@ -97,6 +97,7 @@ class TestGridFileNoConnect(UnitTest):
|
||||
|
||||
class TestGridFile(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
||||
|
||||
def test_basic(self):
|
||||
|
||||
@ -75,9 +75,9 @@ class JustRead(threading.Thread):
|
||||
class TestGridfsNoConnect(unittest.TestCase):
|
||||
db: Database
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.db = MongoClient(connect=False).pymongo_test
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db = MongoClient(connect=False).pymongo_test
|
||||
|
||||
def test_gridfs(self):
|
||||
self.assertRaises(TypeError, gridfs.GridFS, "foo")
|
||||
@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest):
|
||||
fs: gridfs.GridFS
|
||||
alt: gridfs.GridFS
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.fs = gridfs.GridFS(cls.db)
|
||||
cls.alt = gridfs.GridFS(cls.db, "alt")
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.fs = gridfs.GridFS(self.db)
|
||||
self.alt = gridfs.GridFS(self.db, "alt")
|
||||
self.cleanup_colls(
|
||||
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
||||
)
|
||||
@ -509,10 +506,9 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
|
||||
class TestGridfsReplicaSet(IntegrationTest):
|
||||
@classmethod
|
||||
@client_context.require_secondaries_count(1)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
||||
@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest):
|
||||
fs: gridfs.GridFSBucket
|
||||
alt: gridfs.GridFSBucket
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.fs = gridfs.GridFSBucket(cls.db)
|
||||
cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt")
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.fs = gridfs.GridFSBucket(self.db)
|
||||
self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt")
|
||||
self.cleanup_colls(
|
||||
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
||||
)
|
||||
@ -479,10 +476,9 @@ class TestGridfs(IntegrationTest):
|
||||
|
||||
|
||||
class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||
@classmethod
|
||||
@client_context.require_secondaries_count(1)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
||||
@ -29,7 +29,7 @@ from test.utils import (
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from pymongo.synchronous.periodic_executor import _EXECUTORS
|
||||
from pymongo.periodic_executor import _EXECUTORS
|
||||
|
||||
|
||||
def unregistered(ref):
|
||||
|
||||
@ -52,22 +52,14 @@ class TestCommandMonitoring(IntegrationTest):
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUpClass(cls) -> None:
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(
|
||||
event_listeners=[cls.listener], retryWrites=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def tearDown(self):
|
||||
@client_context.require_connection
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.listener.reset()
|
||||
super().tearDown()
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False)
|
||||
|
||||
def test_started_simple(self):
|
||||
self.client.pymongo_test.command("ping")
|
||||
@ -1140,26 +1132,23 @@ class TestGlobalListener(IntegrationTest):
|
||||
saved_listeners: Any
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUpClass(cls) -> None:
|
||||
cls.listener = OvertCommandListener()
|
||||
# We plan to call register(), which internally modifies _LISTENERS.
|
||||
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
||||
monitoring.register(cls.listener)
|
||||
cls.client = cls.unmanaged_single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
cls.client.pymongo_test.command("ping")
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
monitoring._LISTENERS = cls.saved_listeners
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
@client_context.require_connection
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.listener.reset()
|
||||
self.client = self.single_client()
|
||||
# Get one (authenticated) socket in the pool.
|
||||
self.client.pymongo_test.command("ping")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
monitoring._LISTENERS = cls.saved_listeners
|
||||
|
||||
def test_simple(self):
|
||||
self.client.pymongo_test.command("ping")
|
||||
|
||||
@ -31,24 +31,16 @@ from pymongo.read_concern import ReadConcern
|
||||
class TestReadConcern(IntegrationTest):
|
||||
listener: OvertCommandListener
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||
self.db = self.client.pymongo_test
|
||||
client_context.client.pymongo_test.create_collection("coll")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.client.close()
|
||||
client_context.client.pymongo_test.drop_collection("coll")
|
||||
super().tearDownClass()
|
||||
|
||||
def tearDown(self):
|
||||
self.listener.reset()
|
||||
super().tearDown()
|
||||
client_context.client.pymongo_test.drop_collection("coll")
|
||||
|
||||
def test_read_concern(self):
|
||||
rc = ReadConcern()
|
||||
|
||||
@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(IntegrationTest):
|
||||
RUN_ON_SERVERLESS = True
|
||||
deprecation_filter: DeprecationFilter
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.deprecation_filter = DeprecationFilter()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.deprecation_filter = DeprecationFilter()
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.deprecation_filter.stop()
|
||||
super()._tearDown_class()
|
||||
def tearDown(self) -> None:
|
||||
self.deprecation_filter.stop()
|
||||
|
||||
|
||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True)
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.client = self.rs_or_single_client(retryWrites=True)
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
def tearDown(self) -> None:
|
||||
self.knobs.disable()
|
||||
|
||||
@client_context.require_no_standalone
|
||||
def test_actionable_error_message(self):
|
||||
@ -180,26 +173,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
listener: OvertCommandListener
|
||||
knobs: client_knobs
|
||||
|
||||
@classmethod
|
||||
@client_context.require_no_mmap
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
cls.listener = OvertCommandListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(
|
||||
retryWrites=True, event_listeners=[cls.listener]
|
||||
)
|
||||
cls.db = cls.client.pymongo_test
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener])
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
cls.client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
if client_context.is_rs and client_context.test_commands_enabled:
|
||||
self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
||||
@ -210,6 +193,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
self.client.admin.command(
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||
)
|
||||
self.knobs.disable()
|
||||
|
||||
def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
@ -438,13 +422,12 @@ class TestWriteConcernError(IntegrationTest):
|
||||
RUN_ON_SERVERLESS = True
|
||||
fail_insert: dict
|
||||
|
||||
@classmethod
|
||||
@client_context.require_replica_set
|
||||
@client_context.require_no_mmap
|
||||
@client_context.require_failCommand_fail_point
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.fail_insert = {
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.fail_insert = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 2},
|
||||
"data": {
|
||||
|
||||
@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest):
|
||||
@classmethod
|
||||
@client_context.require_failCommand_fail_point
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
super().setUp(cls)
|
||||
# Speed up the tests by decreasing the event publish frequency.
|
||||
cls.knobs = client_knobs(
|
||||
events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1
|
||||
|
||||
@ -82,36 +82,27 @@ class TestSession(IntegrationTest):
|
||||
client2: MongoClient
|
||||
sensitive_commands: Set[str]
|
||||
|
||||
@classmethod
|
||||
@client_context.require_sessions
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Create a second client so we can make sure clients cannot share
|
||||
# sessions.
|
||||
cls.client2 = cls.unmanaged_rs_or_single_client()
|
||||
self.client2 = self.rs_or_single_client()
|
||||
|
||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||
monitoring._SENSITIVE_COMMANDS.clear()
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
|
||||
cls.client2.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
self.listener = SessionTestListener()
|
||||
self.session_checker_listener = SessionTestListener()
|
||||
self.client = self.rs_or_single_client(
|
||||
event_listeners=[self.listener, self.session_checker_listener]
|
||||
)
|
||||
self.addCleanup(self.client.close)
|
||||
self.db = self.client.pymongo_test
|
||||
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
||||
|
||||
def tearDown(self):
|
||||
"""All sessions used in the test must be returned to the pool."""
|
||||
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
|
||||
self.client.drop_database("pymongo_test")
|
||||
used_lsids = self.initial_lsids.copy()
|
||||
for event in self.session_checker_listener.started_events:
|
||||
@ -121,6 +112,8 @@ class TestSession(IntegrationTest):
|
||||
current_lsids = {s["id"] for s in session_ids(self.client)}
|
||||
self.assertLessEqual(used_lsids, current_lsids)
|
||||
|
||||
super().tearDown()
|
||||
|
||||
def _test_ops(self, client, *ops):
|
||||
listener = client.options.event_listeners[0]
|
||||
|
||||
@ -832,18 +825,11 @@ class TestCausalConsistency(UnitTest):
|
||||
listener: SessionTestListener
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls.listener = SessionTestListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
|
||||
@client_context.require_sessions
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.listener = SessionTestListener()
|
||||
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||
|
||||
@client_context.require_no_standalone
|
||||
def test_core(self):
|
||||
|
||||
@ -105,6 +105,7 @@ class Update(threading.Thread):
|
||||
|
||||
class TestThreads(IntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db = self.client.pymongo_test
|
||||
|
||||
def test_threading(self):
|
||||
|
||||
@ -395,19 +395,12 @@ class PatchSessionTimeout:
|
||||
|
||||
|
||||
class TestTransactionsConvenientAPI(TransactionsBase):
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.mongos_clients = []
|
||||
if client_context.supports_transactions():
|
||||
for address in client_context.mongoses:
|
||||
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
for client in cls.mongos_clients:
|
||||
client.close()
|
||||
super()._tearDown_class()
|
||||
self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
|
||||
|
||||
def _set_fail_point(self, client, command_args):
|
||||
cmd = {"configureFailPoint": "failCommand"}
|
||||
|
||||
@ -114,10 +114,9 @@ class TestMypyFails(unittest.TestCase):
|
||||
class TestPymongo(IntegrationTest):
|
||||
coll: Collection
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.coll = cls.client.test.test
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.coll = self.client.test.test
|
||||
|
||||
def test_insert_find(self) -> None:
|
||||
doc = {"my": "doc"}
|
||||
|
||||
@ -304,7 +304,6 @@ class EntityMapUtil:
|
||||
kwargs["h"] = uri
|
||||
client = self.test.rs_or_single_client(**kwargs)
|
||||
self[spec["id"]] = client
|
||||
self.test.addCleanup(client.close)
|
||||
return
|
||||
elif entity_type == "database":
|
||||
client = self[spec["client"]]
|
||||
@ -479,31 +478,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
db.create_collection(coll_name, write_concern=wc, **opts)
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
# super call creates internal client cls.client
|
||||
super()._setup_class()
|
||||
# process file-level runOnRequirements
|
||||
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
|
||||
if not cls.should_run_on(run_on_spec):
|
||||
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
|
||||
|
||||
# add any special-casing for skipping tests here
|
||||
if client_context.storage_engine == "mmapv1":
|
||||
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||
cls.TEST_PATH
|
||||
):
|
||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||
|
||||
# Handle mongos_clients for transactions tests.
|
||||
cls.mongos_clients = []
|
||||
if (
|
||||
client_context.supports_transactions()
|
||||
and not client_context.load_balancer
|
||||
and not client_context.serverless
|
||||
):
|
||||
for address in client_context.mongoses:
|
||||
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
|
||||
|
||||
def setUpClass(cls) -> None:
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(
|
||||
heartbeat_frequency=0.1,
|
||||
@ -514,17 +489,36 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
cls.knobs.enable()
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
def tearDownClass(cls) -> None:
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
# super call creates internal client cls.client
|
||||
super().setUp()
|
||||
# process file-level runOnRequirements
|
||||
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
|
||||
if not self.should_run_on(run_on_spec):
|
||||
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
|
||||
|
||||
# add any special-casing for skipping tests here
|
||||
if client_context.storage_engine == "mmapv1":
|
||||
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||
self.TEST_PATH
|
||||
):
|
||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||
|
||||
# Handle mongos_clients for transactions tests.
|
||||
self.mongos_clients = []
|
||||
if (
|
||||
client_context.supports_transactions()
|
||||
and not client_context.load_balancer
|
||||
and not client_context.serverless
|
||||
):
|
||||
for address in client_context.mongoses:
|
||||
self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
|
||||
|
||||
# process schemaVersion
|
||||
# note: we check major schema version during class generation
|
||||
# note: we do this here because we cannot run assertions in setUpClass
|
||||
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
|
||||
self.assertLessEqual(
|
||||
version,
|
||||
@ -1026,7 +1020,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
)
|
||||
|
||||
client = self.single_client("{}:{}".format(*session._pinned_address))
|
||||
self.addCleanup(client.close)
|
||||
self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
||||
|
||||
def _testOperation_createEntities(self, spec):
|
||||
|
||||
@ -99,6 +99,12 @@ class BaseListener:
|
||||
"""Wait for a number of events to be published, or fail."""
|
||||
wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)")
|
||||
|
||||
async def async_wait_for_event(self, event, count):
|
||||
"""Wait for a number of events to be published, or fail."""
|
||||
await async_wait_until(
|
||||
lambda: self.event_count(event) >= count, f"find {count} {event} event(s)"
|
||||
)
|
||||
|
||||
|
||||
class CMAPListener(BaseListener, monitoring.ConnectionPoolListener):
|
||||
def connection_created(self, event):
|
||||
@ -644,7 +650,10 @@ async def async_wait_until(predicate, success_description, timeout=10):
|
||||
start = time.time()
|
||||
interval = min(float(timeout) / 100, 0.1)
|
||||
while True:
|
||||
if iscoroutinefunction(predicate):
|
||||
retval = await predicate()
|
||||
else:
|
||||
retval = predicate()
|
||||
if retval:
|
||||
return retval
|
||||
|
||||
|
||||
@ -249,30 +249,22 @@ class SpecRunner(IntegrationTest):
|
||||
knobs: client_knobs
|
||||
listener: EventListener
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.mongos_clients = []
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.mongos_clients = []
|
||||
|
||||
# Speed up the tests by decreasing the heartbeat frequency.
|
||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
cls.knobs.enable()
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
client.close()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||
self.knobs.enable()
|
||||
self.targets = {}
|
||||
self.listener = None # type: ignore
|
||||
self.pool_listener = None
|
||||
self.server_listener = None
|
||||
self.maxDiff = None
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.knobs.disable()
|
||||
|
||||
def _set_fail_point(self, client, command_args):
|
||||
cmd = SON([("configureFailPoint", "failCommand")])
|
||||
cmd.update(command_args)
|
||||
@ -697,8 +689,6 @@ class SpecRunner(IntegrationTest):
|
||||
self.listener = listener
|
||||
self.pool_listener = pool_listener
|
||||
self.server_listener = server_listener
|
||||
# Close the client explicitly to avoid having too many threads open.
|
||||
self.addCleanup(client.close)
|
||||
|
||||
# Create session0 and session1.
|
||||
sessions = {}
|
||||
|
||||
@ -110,6 +110,13 @@ replacements = {
|
||||
"async_set_fail_point": "set_fail_point",
|
||||
"async_ensure_all_connected": "ensure_all_connected",
|
||||
"async_repl_set_step_down": "repl_set_step_down",
|
||||
"AsyncPeriodicExecutor": "PeriodicExecutor",
|
||||
"async_wait_for_event": "wait_for_event",
|
||||
"pymongo_server_monitor_task": "pymongo_server_monitor_thread",
|
||||
"pymongo_server_rtt_task": "pymongo_server_rtt_thread",
|
||||
"_async_create_lock": "_create_lock",
|
||||
"_async_create_condition": "_create_condition",
|
||||
"_async_cond_wait": "_cond_wait",
|
||||
}
|
||||
|
||||
docstring_replacements: dict[tuple[str, str], str] = {
|
||||
@ -130,8 +137,6 @@ docstring_removals: set[str] = {
|
||||
".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release."
|
||||
}
|
||||
|
||||
type_replacements = {"_Condition": "threading.Condition"}
|
||||
|
||||
import_replacements = {"test.synchronous": "test"}
|
||||
|
||||
_pymongo_base = "./pymongo/asynchronous/"
|
||||
@ -234,8 +239,6 @@ def process_files(files: list[str]) -> None:
|
||||
lines = translate_async_sleeps(lines)
|
||||
if file in docstring_translate_files:
|
||||
lines = translate_docstrings(lines)
|
||||
translate_locks(lines)
|
||||
translate_types(lines)
|
||||
if file in sync_test_files:
|
||||
translate_imports(lines)
|
||||
f.seek(0)
|
||||
@ -269,34 +272,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]:
|
||||
return lines
|
||||
|
||||
|
||||
def translate_locks(lines: list[str]) -> list[str]:
|
||||
lock_lines = [line for line in lines if "_Lock(" in line]
|
||||
cond_lines = [line for line in lines if "_Condition(" in line]
|
||||
for line in lock_lines:
|
||||
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
|
||||
if res:
|
||||
old = res[0]
|
||||
index = lines.index(line)
|
||||
lines[index] = line.replace(old, res[1])
|
||||
for line in cond_lines:
|
||||
res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line)
|
||||
if res:
|
||||
old = res[0]
|
||||
index = lines.index(line)
|
||||
lines[index] = line.replace(old, res[1])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def translate_types(lines: list[str]) -> list[str]:
|
||||
for k, v in type_replacements.items():
|
||||
matches = [line for line in lines if k in line and "import" not in line]
|
||||
for line in matches:
|
||||
index = lines.index(line)
|
||||
lines[index] = line.replace(k, v)
|
||||
return lines
|
||||
|
||||
|
||||
def translate_imports(lines: list[str]) -> list[str]:
|
||||
for k, v in import_replacements.items():
|
||||
matches = [line for line in lines if k in line and "import" in line]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user