Async client uses tasks instead of threads
PYTHON-4725 - Async client should use tasks for SDAM instead of threads PYTHON-4860 - Async client should use asyncio.Lock and asyncio.Condition PYTHON-4941 - Synchronous unified test runner being used in asynchronous tests PYTHON-4843 - Async test suite should use a single event loop PYTHON-4945 - Fix test cleanups for mongoses Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com>
This commit is contained in:
parent
9b5c0981d9
commit
0e8d70457f
@ -281,7 +281,7 @@ functions:
|
|||||||
"run tests":
|
"run tests":
|
||||||
- command: subprocess.exec
|
- command: subprocess.exec
|
||||||
params:
|
params:
|
||||||
include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
|
include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"]
|
||||||
binary: bash
|
binary: bash
|
||||||
working_dir: "src"
|
working_dir: "src"
|
||||||
args:
|
args:
|
||||||
|
|||||||
@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
THE SOFTWARE.
|
THE SOFTWARE.
|
||||||
|
|
||||||
|
2) License Notice for _asyncio_lock.py
|
||||||
|
-----------------------------------------
|
||||||
|
|
||||||
|
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
||||||
|
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
||||||
|
otherwise using this software ("Python") in source or binary form and
|
||||||
|
its associated documentation.
|
||||||
|
|
||||||
|
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
||||||
|
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
||||||
|
analyze, test, perform and/or display publicly, prepare derivative works,
|
||||||
|
distribute, and otherwise use Python alone or in any derivative version,
|
||||||
|
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
||||||
|
i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved"
|
||||||
|
are retained in Python alone or in any derivative version prepared by Licensee.
|
||||||
|
|
||||||
|
3. In the event Licensee prepares a derivative work that is based on
|
||||||
|
or incorporates Python or any part thereof, and wants to make
|
||||||
|
the derivative work available to others as provided herein, then
|
||||||
|
Licensee hereby agrees to include in any such work a brief summary of
|
||||||
|
the changes made to Python.
|
||||||
|
|
||||||
|
4. PSF is making Python available to Licensee on an "AS IS"
|
||||||
|
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
||||||
|
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
||||||
|
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
||||||
|
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
||||||
|
INFRINGE ANY THIRD PARTY RIGHTS.
|
||||||
|
|
||||||
|
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
||||||
|
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
||||||
|
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
||||||
|
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
||||||
|
|
||||||
|
6. This License Agreement will automatically terminate upon a material
|
||||||
|
breach of its terms and conditions.
|
||||||
|
|
||||||
|
7. Nothing in this License Agreement shall be deemed to create any
|
||||||
|
relationship of agency, partnership, or joint venture between PSF and
|
||||||
|
Licensee. This License Agreement does not grant permission to use PSF
|
||||||
|
trademarks or trade name in a trademark sense to endorse or promote
|
||||||
|
products or services of Licensee, or any third party.
|
||||||
|
|
||||||
|
8. By copying, installing or otherwise using Python, Licensee
|
||||||
|
agrees to be bound by the terms and conditions of this License
|
||||||
|
Agreement.
|
||||||
|
|
||||||
|
Permission to use, copy, modify, and/or distribute this software for any
|
||||||
|
purpose with or without fee is hereby granted.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
|
||||||
|
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
|
||||||
|
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
|
||||||
|
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
|
||||||
|
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
|
||||||
|
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
||||||
|
PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|||||||
309
pymongo/_asyncio_lock.py
Normal file
309
pymongo/_asyncio_lock.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
|
||||||
|
|
||||||
|
"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
|
||||||
|
to port 3.13 fixes to older versions of Python.
|
||||||
|
Can be removed once we drop Python 3.12 support."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import threading
|
||||||
|
from asyncio import events, exceptions
|
||||||
|
from typing import Any, Coroutine, Optional
|
||||||
|
|
||||||
|
_global_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
class _LoopBoundMixin:
|
||||||
|
_loop = None
|
||||||
|
|
||||||
|
def _get_loop(self) -> Any:
|
||||||
|
loop = events._get_running_loop()
|
||||||
|
|
||||||
|
if self._loop is None:
|
||||||
|
with _global_lock:
|
||||||
|
if self._loop is None:
|
||||||
|
self._loop = loop
|
||||||
|
if loop is not self._loop:
|
||||||
|
raise RuntimeError(f"{self!r} is bound to a different event loop")
|
||||||
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
class _ContextManagerMixin:
|
||||||
|
async def __aenter__(self) -> None:
|
||||||
|
await self.acquire() # type: ignore[attr-defined]
|
||||||
|
# We have no use for the "as ..." clause in the with
|
||||||
|
# statement for locks.
|
||||||
|
return
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||||
|
self.release() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
class Lock(_ContextManagerMixin, _LoopBoundMixin):
|
||||||
|
"""Primitive lock objects.
|
||||||
|
|
||||||
|
A primitive lock is a synchronization primitive that is not owned
|
||||||
|
by a particular task when locked. A primitive lock is in one
|
||||||
|
of two states, 'locked' or 'unlocked'.
|
||||||
|
|
||||||
|
It is created in the unlocked state. It has two basic methods,
|
||||||
|
acquire() and release(). When the state is unlocked, acquire()
|
||||||
|
changes the state to locked and returns immediately. When the
|
||||||
|
state is locked, acquire() blocks until a call to release() in
|
||||||
|
another task changes it to unlocked, then the acquire() call
|
||||||
|
resets it to locked and returns. The release() method should only
|
||||||
|
be called in the locked state; it changes the state to unlocked
|
||||||
|
and returns immediately. If an attempt is made to release an
|
||||||
|
unlocked lock, a RuntimeError will be raised.
|
||||||
|
|
||||||
|
When more than one task is blocked in acquire() waiting for
|
||||||
|
the state to turn to unlocked, only one task proceeds when a
|
||||||
|
release() call resets the state to unlocked; successive release()
|
||||||
|
calls will unblock tasks in FIFO order.
|
||||||
|
|
||||||
|
Locks also support the asynchronous context management protocol.
|
||||||
|
'async with lock' statement should be used.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
lock = Lock()
|
||||||
|
...
|
||||||
|
await lock.acquire()
|
||||||
|
try:
|
||||||
|
...
|
||||||
|
finally:
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
Context manager usage:
|
||||||
|
|
||||||
|
lock = Lock()
|
||||||
|
...
|
||||||
|
async with lock:
|
||||||
|
...
|
||||||
|
|
||||||
|
Lock objects can be tested for locking state:
|
||||||
|
|
||||||
|
if not lock.locked():
|
||||||
|
await lock.acquire()
|
||||||
|
else:
|
||||||
|
# lock is acquired
|
||||||
|
...
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._waiters: Optional[collections.deque] = None
|
||||||
|
self._locked = False
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
res = super().__repr__()
|
||||||
|
extra = "locked" if self._locked else "unlocked"
|
||||||
|
if self._waiters:
|
||||||
|
extra = f"{extra}, waiters:{len(self._waiters)}"
|
||||||
|
return f"<{res[1:-1]} [{extra}]>"
|
||||||
|
|
||||||
|
def locked(self) -> bool:
|
||||||
|
"""Return True if lock is acquired."""
|
||||||
|
return self._locked
|
||||||
|
|
||||||
|
async def acquire(self) -> bool:
|
||||||
|
"""Acquire a lock.
|
||||||
|
|
||||||
|
This method blocks until the lock is unlocked, then sets it to
|
||||||
|
locked and returns True.
|
||||||
|
"""
|
||||||
|
# Implement fair scheduling, where thread always waits
|
||||||
|
# its turn. Jumping the queue if all are cancelled is an optimization.
|
||||||
|
if not self._locked and (
|
||||||
|
self._waiters is None or all(w.cancelled() for w in self._waiters)
|
||||||
|
):
|
||||||
|
self._locked = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._waiters is None:
|
||||||
|
self._waiters = collections.deque()
|
||||||
|
fut = self._get_loop().create_future()
|
||||||
|
self._waiters.append(fut)
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
await fut
|
||||||
|
finally:
|
||||||
|
self._waiters.remove(fut)
|
||||||
|
except exceptions.CancelledError:
|
||||||
|
# Currently the only exception designed be able to occur here.
|
||||||
|
|
||||||
|
# Ensure the lock invariant: If lock is not claimed (or about
|
||||||
|
# to be claimed by us) and there is a Task in waiters,
|
||||||
|
# ensure that the Task at the head will run.
|
||||||
|
if not self._locked:
|
||||||
|
self._wake_up_first()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# assert self._locked is False
|
||||||
|
self._locked = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
"""Release a lock.
|
||||||
|
|
||||||
|
When the lock is locked, reset it to unlocked, and return.
|
||||||
|
If any other tasks are blocked waiting for the lock to become
|
||||||
|
unlocked, allow exactly one of them to proceed.
|
||||||
|
|
||||||
|
When invoked on an unlocked lock, a RuntimeError is raised.
|
||||||
|
|
||||||
|
There is no return value.
|
||||||
|
"""
|
||||||
|
if self._locked:
|
||||||
|
self._locked = False
|
||||||
|
self._wake_up_first()
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Lock is not acquired.")
|
||||||
|
|
||||||
|
def _wake_up_first(self) -> None:
|
||||||
|
"""Ensure that the first waiter will wake up."""
|
||||||
|
if not self._waiters:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
fut = next(iter(self._waiters))
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
|
||||||
|
# .done() means that the waiter is already set to wake up.
|
||||||
|
if not fut.done():
|
||||||
|
fut.set_result(True)
|
||||||
|
|
||||||
|
|
||||||
|
class Condition(_ContextManagerMixin, _LoopBoundMixin):
|
||||||
|
"""Asynchronous equivalent to threading.Condition.
|
||||||
|
|
||||||
|
This class implements condition variable objects. A condition variable
|
||||||
|
allows one or more tasks to wait until they are notified by another
|
||||||
|
task.
|
||||||
|
|
||||||
|
A new Lock object is created and used as the underlying lock.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lock: Optional[Lock] = None) -> None:
|
||||||
|
if lock is None:
|
||||||
|
lock = Lock()
|
||||||
|
|
||||||
|
self._lock = lock
|
||||||
|
# Export the lock's locked(), acquire() and release() methods.
|
||||||
|
self.locked = lock.locked
|
||||||
|
self.acquire = lock.acquire
|
||||||
|
self.release = lock.release
|
||||||
|
|
||||||
|
self._waiters: collections.deque = collections.deque()
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
res = super().__repr__()
|
||||||
|
extra = "locked" if self.locked() else "unlocked"
|
||||||
|
if self._waiters:
|
||||||
|
extra = f"{extra}, waiters:{len(self._waiters)}"
|
||||||
|
return f"<{res[1:-1]} [{extra}]>"
|
||||||
|
|
||||||
|
async def wait(self) -> bool:
|
||||||
|
"""Wait until notified.
|
||||||
|
|
||||||
|
If the calling task has not acquired the lock when this
|
||||||
|
method is called, a RuntimeError is raised.
|
||||||
|
|
||||||
|
This method releases the underlying lock, and then blocks
|
||||||
|
until it is awakened by a notify() or notify_all() call for
|
||||||
|
the same condition variable in another task. Once
|
||||||
|
awakened, it re-acquires the lock and returns True.
|
||||||
|
|
||||||
|
This method may return spuriously,
|
||||||
|
which is why the caller should always
|
||||||
|
re-check the state and be prepared to wait() again.
|
||||||
|
"""
|
||||||
|
if not self.locked():
|
||||||
|
raise RuntimeError("cannot wait on un-acquired lock")
|
||||||
|
|
||||||
|
fut = self._get_loop().create_future()
|
||||||
|
self.release()
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
self._waiters.append(fut)
|
||||||
|
try:
|
||||||
|
await fut
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
self._waiters.remove(fut)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Must re-acquire lock even if wait is cancelled.
|
||||||
|
# We only catch CancelledError here, since we don't want any
|
||||||
|
# other (fatal) errors with the future to cause us to spin.
|
||||||
|
err = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await self.acquire()
|
||||||
|
break
|
||||||
|
except exceptions.CancelledError as e:
|
||||||
|
err = e
|
||||||
|
|
||||||
|
if err is not None:
|
||||||
|
try:
|
||||||
|
raise err # Re-raise most recent exception instance.
|
||||||
|
finally:
|
||||||
|
err = None # Break reference cycles.
|
||||||
|
except BaseException:
|
||||||
|
# Any error raised out of here _may_ have occurred after this Task
|
||||||
|
# believed to have been successfully notified.
|
||||||
|
# Make sure to notify another Task instead. This may result
|
||||||
|
# in a "spurious wakeup", which is allowed as part of the
|
||||||
|
# Condition Variable protocol.
|
||||||
|
self._notify(1)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def wait_for(self, predicate: Any) -> Coroutine:
|
||||||
|
"""Wait until a predicate becomes true.
|
||||||
|
|
||||||
|
The predicate should be a callable whose result will be
|
||||||
|
interpreted as a boolean value. The method will repeatedly
|
||||||
|
wait() until it evaluates to true. The final predicate value is
|
||||||
|
the return value.
|
||||||
|
"""
|
||||||
|
result = predicate()
|
||||||
|
while not result:
|
||||||
|
await self.wait()
|
||||||
|
result = predicate()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def notify(self, n: int = 1) -> None:
|
||||||
|
"""By default, wake up one task waiting on this condition, if any.
|
||||||
|
If the calling task has not acquired the lock when this method
|
||||||
|
is called, a RuntimeError is raised.
|
||||||
|
|
||||||
|
This method wakes up n of the tasks waiting for the condition
|
||||||
|
variable; if fewer than n are waiting, they are all awoken.
|
||||||
|
|
||||||
|
Note: an awakened task does not actually return from its
|
||||||
|
wait() call until it can reacquire the lock. Since notify() does
|
||||||
|
not release the lock, its caller should.
|
||||||
|
"""
|
||||||
|
if not self.locked():
|
||||||
|
raise RuntimeError("cannot notify on un-acquired lock")
|
||||||
|
self._notify(n)
|
||||||
|
|
||||||
|
def _notify(self, n: int) -> None:
|
||||||
|
idx = 0
|
||||||
|
for fut in self._waiters:
|
||||||
|
if idx >= n:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not fut.done():
|
||||||
|
idx += 1
|
||||||
|
fut.set_result(False)
|
||||||
|
|
||||||
|
def notify_all(self) -> None:
|
||||||
|
"""Wake up all tasks waiting on this condition. This method acts
|
||||||
|
like notify(), but wakes up all waiting tasks instead of one. If the
|
||||||
|
calling task has not acquired the lock when this method is called,
|
||||||
|
a RuntimeError is raised.
|
||||||
|
"""
|
||||||
|
self.notify(len(self._waiters))
|
||||||
49
pymongo/_asyncio_task.py
Normal file
49
pymongo/_asyncio_task.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# Copyright 2024-present MongoDB, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
|
||||||
|
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
|
||||||
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from typing import Any, Coroutine, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
|
||||||
|
class _Task(asyncio.Task):
|
||||||
|
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
|
||||||
|
super().__init__(coro, name=name)
|
||||||
|
self._cancel_requests = 0
|
||||||
|
asyncio._register_task(self)
|
||||||
|
|
||||||
|
def cancel(self, msg: Optional[str] = None) -> bool:
|
||||||
|
self._cancel_requests += 1
|
||||||
|
return super().cancel(msg=msg)
|
||||||
|
|
||||||
|
def uncancel(self) -> int:
|
||||||
|
if self._cancel_requests > 0:
|
||||||
|
self._cancel_requests -= 1
|
||||||
|
return self._cancel_requests
|
||||||
|
|
||||||
|
def cancelling(self) -> int:
|
||||||
|
return self._cancel_requests
|
||||||
|
|
||||||
|
|
||||||
|
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
return asyncio.create_task(coro, name=name)
|
||||||
|
return _Task(coro, name=name)
|
||||||
@ -476,7 +476,6 @@ class _AsyncClientBulk:
|
|||||||
if op_type == "delete":
|
if op_type == "delete":
|
||||||
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
||||||
full_result[f"{op_type}Results"][original_index] = res
|
full_result[f"{op_type}Results"][original_index] = res
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# Attempt to close the cursor, then raise top-level error.
|
# Attempt to close the cursor, then raise top-level error.
|
||||||
if cmd_cursor.alive:
|
if cmd_cursor.alive:
|
||||||
|
|||||||
@ -45,7 +45,7 @@ from pymongo.common import (
|
|||||||
)
|
)
|
||||||
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
|
||||||
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
|
||||||
from pymongo.lock import _ALock, _create_lock
|
from pymongo.lock import _async_create_lock
|
||||||
from pymongo.message import (
|
from pymongo.message import (
|
||||||
_CursorAddress,
|
_CursorAddress,
|
||||||
_GetMore,
|
_GetMore,
|
||||||
@ -77,7 +77,7 @@ class _ConnectionManager:
|
|||||||
def __init__(self, conn: AsyncConnection, more_to_come: bool):
|
def __init__(self, conn: AsyncConnection, more_to_come: bool):
|
||||||
self.conn: Optional[AsyncConnection] = conn
|
self.conn: Optional[AsyncConnection] = conn
|
||||||
self.more_to_come = more_to_come
|
self.more_to_come = more_to_come
|
||||||
self._alock = _ALock(_create_lock())
|
self._lock = _async_create_lock()
|
||||||
|
|
||||||
def update_exhaust(self, more_to_come: bool) -> None:
|
def update_exhaust(self, more_to_come: bool) -> None:
|
||||||
self.more_to_come = more_to_come
|
self.more_to_come = more_to_come
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
"""Support for explicit client-side field level encryption."""
|
"""Support for explicit client-side field level encryption."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import enum
|
import enum
|
||||||
import socket
|
import socket
|
||||||
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
|||||||
# BSON encoding/decoding errors are unrelated to encryption so
|
# BSON encoding/decoding errors are unrelated to encryption so
|
||||||
# we should propagate them unchanged.
|
# we should propagate them unchanged.
|
||||||
raise
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise EncryptionError(exc) from exc
|
raise EncryptionError(exc) from exc
|
||||||
|
|
||||||
@ -200,6 +203,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
|||||||
conn.close()
|
conn.close()
|
||||||
except (PyMongoError, MongoCryptError):
|
except (PyMongoError, MongoCryptError):
|
||||||
raise # Propagate pymongo errors directly.
|
raise # Propagate pymongo errors directly.
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
# Wrap I/O errors in PyMongo exceptions.
|
# Wrap I/O errors in PyMongo exceptions.
|
||||||
_raise_connection_failure((host, port), error)
|
_raise_connection_failure((host, port), error)
|
||||||
@ -722,6 +727,8 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
|||||||
await database.create_collection(name=name, **kwargs),
|
await database.create_collection(name=name, **kwargs),
|
||||||
encrypted_fields,
|
encrypted_fields,
|
||||||
)
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||||
|
|
||||||
|
|||||||
@ -32,6 +32,7 @@ access:
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@ -59,8 +60,8 @@ from typing import (
|
|||||||
|
|
||||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||||
from bson.timestamp import Timestamp
|
from bson.timestamp import Timestamp
|
||||||
from pymongo import _csot, common, helpers_shared, uri_parser
|
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||||
from pymongo.asynchronous import client_session, database, periodic_executor
|
from pymongo.asynchronous import client_session, database
|
||||||
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
||||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||||
@ -82,7 +83,11 @@ from pymongo.errors import (
|
|||||||
WaitQueueTimeoutError,
|
WaitQueueTimeoutError,
|
||||||
WriteConcernError,
|
WriteConcernError,
|
||||||
)
|
)
|
||||||
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
|
from pymongo.lock import (
|
||||||
|
_HAS_REGISTER_AT_FORK,
|
||||||
|
_async_create_lock,
|
||||||
|
_release_locks,
|
||||||
|
)
|
||||||
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
||||||
from pymongo.message import _CursorAddress, _GetMore, _Query
|
from pymongo.message import _CursorAddress, _GetMore, _Query
|
||||||
from pymongo.monitoring import ConnectionClosedReason
|
from pymongo.monitoring import ConnectionClosedReason
|
||||||
@ -842,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||||
|
|
||||||
self._default_database_name = dbase
|
self._default_database_name = dbase
|
||||||
self._lock = _ALock(_create_lock())
|
self._lock = _async_create_lock()
|
||||||
self._kill_cursors_queue: list = []
|
self._kill_cursors_queue: list = []
|
||||||
|
|
||||||
self._event_listeners = options.pool_options._event_listeners
|
self._event_listeners = options.pool_options._event_listeners
|
||||||
@ -908,7 +913,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
await AsyncMongoClient._process_periodic_tasks(client)
|
await AsyncMongoClient._process_periodic_tasks(client)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
executor = periodic_executor.PeriodicExecutor(
|
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||||
interval=common.KILL_CURSOR_FREQUENCY,
|
interval=common.KILL_CURSOR_FREQUENCY,
|
||||||
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
||||||
target=target,
|
target=target,
|
||||||
@ -1722,7 +1727,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
address=address,
|
address=address,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with operation.conn_mgr._alock:
|
async with operation.conn_mgr._lock:
|
||||||
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
||||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||||
return await server.run_operation(
|
return await server.run_operation(
|
||||||
@ -1970,7 +1975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if conn_mgr:
|
if conn_mgr:
|
||||||
async with conn_mgr._alock:
|
async with conn_mgr._lock:
|
||||||
# Cursor is pinned to LB outside of a transaction.
|
# Cursor is pinned to LB outside of a transaction.
|
||||||
assert address is not None
|
assert address is not None
|
||||||
assert conn_mgr.conn is not None
|
assert conn_mgr.conn is not None
|
||||||
@ -2033,6 +2038,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||||
try:
|
try:
|
||||||
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||||
# Raise the exception when client is closed so that it
|
# Raise the exception when client is closed so that it
|
||||||
@ -2047,6 +2054,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
for address, cursor_ids in address_to_cursor_ids.items():
|
for address, cursor_ids in address_to_cursor_ids.items():
|
||||||
try:
|
try:
|
||||||
await self._kill_cursors(cursor_ids, address, topology, session=None)
|
await self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||||
raise
|
raise
|
||||||
@ -2061,6 +2070,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
try:
|
try:
|
||||||
await self._process_kill_cursors()
|
await self._process_kill_cursors()
|
||||||
await self._topology.update_pool()
|
await self._topology.update_pool()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -16,20 +16,20 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import atexit
|
import atexit
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||||
|
|
||||||
from pymongo import common
|
from pymongo import common, periodic_executor
|
||||||
from pymongo._csot import MovingMinimum
|
from pymongo._csot import MovingMinimum
|
||||||
from pymongo.asynchronous import periodic_executor
|
|
||||||
from pymongo.asynchronous.periodic_executor import _shutdown_executors
|
|
||||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||||
from pymongo.hello import Hello
|
from pymongo.hello import Hello
|
||||||
from pymongo.lock import _create_lock
|
from pymongo.lock import _async_create_lock
|
||||||
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
||||||
|
from pymongo.periodic_executor import _shutdown_executors
|
||||||
from pymongo.pool_options import _is_faas
|
from pymongo.pool_options import _is_faas
|
||||||
from pymongo.read_preferences import MovingAverage
|
from pymongo.read_preferences import MovingAverage
|
||||||
from pymongo.server_description import ServerDescription
|
from pymongo.server_description import ServerDescription
|
||||||
@ -76,7 +76,7 @@ class MonitorBase:
|
|||||||
await monitor._run() # type:ignore[attr-defined]
|
await monitor._run() # type:ignore[attr-defined]
|
||||||
return True
|
return True
|
||||||
|
|
||||||
executor = periodic_executor.PeriodicExecutor(
|
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||||
interval=interval, min_interval=min_interval, target=target, name=name
|
interval=interval, min_interval=min_interval, target=target, name=name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -112,9 +112,9 @@ class MonitorBase:
|
|||||||
"""
|
"""
|
||||||
self.gc_safe_close()
|
self.gc_safe_close()
|
||||||
|
|
||||||
def join(self, timeout: Optional[int] = None) -> None:
|
async def join(self, timeout: Optional[int] = None) -> None:
|
||||||
"""Wait for the monitor to stop."""
|
"""Wait for the monitor to stop."""
|
||||||
self._executor.join(timeout)
|
await self._executor.join(timeout)
|
||||||
|
|
||||||
def request_check(self) -> None:
|
def request_check(self) -> None:
|
||||||
"""If the monitor is sleeping, wake it soon."""
|
"""If the monitor is sleeping, wake it soon."""
|
||||||
@ -139,7 +139,7 @@ class Monitor(MonitorBase):
|
|||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
topology,
|
topology,
|
||||||
"pymongo_server_monitor_thread",
|
"pymongo_server_monitor_task",
|
||||||
topology_settings.heartbeat_frequency,
|
topology_settings.heartbeat_frequency,
|
||||||
common.MIN_HEARTBEAT_INTERVAL,
|
common.MIN_HEARTBEAT_INTERVAL,
|
||||||
)
|
)
|
||||||
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
|
|||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
# Topology was garbage-collected.
|
# Topology was garbage-collected.
|
||||||
await self.close()
|
await self.close()
|
||||||
|
finally:
|
||||||
|
if self._executor._stopped:
|
||||||
|
await self._rtt_monitor.close()
|
||||||
|
|
||||||
async def _check_server(self) -> ServerDescription:
|
async def _check_server(self) -> ServerDescription:
|
||||||
"""Call hello or read the next streaming response.
|
"""Call hello or read the next streaming response.
|
||||||
@ -252,8 +255,10 @@ class Monitor(MonitorBase):
|
|||||||
except (OperationFailure, NotPrimaryError) as exc:
|
except (OperationFailure, NotPrimaryError) as exc:
|
||||||
# Update max cluster time even when hello fails.
|
# Update max cluster time even when hello fails.
|
||||||
details = cast(Mapping[str, Any], exc.details)
|
details = cast(Mapping[str, Any], exc.details)
|
||||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
await self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||||
raise
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
raise
|
raise
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
@ -280,7 +285,7 @@ class Monitor(MonitorBase):
|
|||||||
await self._reset_connection()
|
await self._reset_connection()
|
||||||
if isinstance(error, _OperationCancelled):
|
if isinstance(error, _OperationCancelled):
|
||||||
raise
|
raise
|
||||||
self._rtt_monitor.reset()
|
await self._rtt_monitor.reset()
|
||||||
# Server type defaults to Unknown.
|
# Server type defaults to Unknown.
|
||||||
return ServerDescription(address, error=error)
|
return ServerDescription(address, error=error)
|
||||||
|
|
||||||
@ -321,9 +326,9 @@ class Monitor(MonitorBase):
|
|||||||
self._conn_id = conn.id
|
self._conn_id = conn.id
|
||||||
response, round_trip_time = await self._check_with_socket(conn)
|
response, round_trip_time = await self._check_with_socket(conn)
|
||||||
if not response.awaitable:
|
if not response.awaitable:
|
||||||
self._rtt_monitor.add_sample(round_trip_time)
|
await self._rtt_monitor.add_sample(round_trip_time)
|
||||||
|
|
||||||
avg_rtt, min_rtt = self._rtt_monitor.get()
|
avg_rtt, min_rtt = await self._rtt_monitor.get()
|
||||||
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
|
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
|
||||||
if self._publish:
|
if self._publish:
|
||||||
assert self._listeners is not None
|
assert self._listeners is not None
|
||||||
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
|
|||||||
if len(seedlist) == 0:
|
if len(seedlist) == 0:
|
||||||
# As per the spec: this should be treated as a failure.
|
# As per the spec: this should be treated as a failure.
|
||||||
raise Exception
|
raise Exception
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
# As per the spec, upon encountering an error:
|
# As per the spec, upon encountering an error:
|
||||||
# - An error must not be raised
|
# - An error must not be raised
|
||||||
@ -439,7 +446,7 @@ class _RttMonitor(MonitorBase):
|
|||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
topology,
|
topology,
|
||||||
"pymongo_server_rtt_thread",
|
"pymongo_server_rtt_task",
|
||||||
topology_settings.heartbeat_frequency,
|
topology_settings.heartbeat_frequency,
|
||||||
common.MIN_HEARTBEAT_INTERVAL,
|
common.MIN_HEARTBEAT_INTERVAL,
|
||||||
)
|
)
|
||||||
@ -447,7 +454,7 @@ class _RttMonitor(MonitorBase):
|
|||||||
self._pool = pool
|
self._pool = pool
|
||||||
self._moving_average = MovingAverage()
|
self._moving_average = MovingAverage()
|
||||||
self._moving_min = MovingMinimum()
|
self._moving_min = MovingMinimum()
|
||||||
self._lock = _create_lock()
|
self._lock = _async_create_lock()
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
self.gc_safe_close()
|
self.gc_safe_close()
|
||||||
@ -455,20 +462,20 @@ class _RttMonitor(MonitorBase):
|
|||||||
# thread has the socket checked out, it will be closed when checked in.
|
# thread has the socket checked out, it will be closed when checked in.
|
||||||
await self._pool.reset()
|
await self._pool.reset()
|
||||||
|
|
||||||
def add_sample(self, sample: float) -> None:
|
async def add_sample(self, sample: float) -> None:
|
||||||
"""Add a RTT sample."""
|
"""Add a RTT sample."""
|
||||||
with self._lock:
|
async with self._lock:
|
||||||
self._moving_average.add_sample(sample)
|
self._moving_average.add_sample(sample)
|
||||||
self._moving_min.add_sample(sample)
|
self._moving_min.add_sample(sample)
|
||||||
|
|
||||||
def get(self) -> tuple[Optional[float], float]:
|
async def get(self) -> tuple[Optional[float], float]:
|
||||||
"""Get the calculated average, or None if no samples yet and the min."""
|
"""Get the calculated average, or None if no samples yet and the min."""
|
||||||
with self._lock:
|
async with self._lock:
|
||||||
return self._moving_average.get(), self._moving_min.get()
|
return self._moving_average.get(), self._moving_min.get()
|
||||||
|
|
||||||
def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
"""Reset the average RTT."""
|
"""Reset the average RTT."""
|
||||||
with self._lock:
|
async with self._lock:
|
||||||
self._moving_average.reset()
|
self._moving_average.reset()
|
||||||
self._moving_min.reset()
|
self._moving_min.reset()
|
||||||
|
|
||||||
@ -478,10 +485,12 @@ class _RttMonitor(MonitorBase):
|
|||||||
# heartbeat protocol (MongoDB 4.4+).
|
# heartbeat protocol (MongoDB 4.4+).
|
||||||
# XXX: Skip check if the server is unknown?
|
# XXX: Skip check if the server is unknown?
|
||||||
rtt = await self._ping()
|
rtt = await self._ping()
|
||||||
self.add_sample(rtt)
|
await self.add_sample(rtt)
|
||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
# Topology was garbage-collected.
|
# Topology was garbage-collected.
|
||||||
await self.close()
|
await self.close()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
await self._pool.reset()
|
await self._pool.reset()
|
||||||
|
|
||||||
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
|
|||||||
shutdown()
|
shutdown()
|
||||||
|
|
||||||
|
|
||||||
atexit.register(_shutdown_resources)
|
if _IS_SYNC:
|
||||||
|
atexit.register(_shutdown_resources)
|
||||||
|
|||||||
@ -1,219 +0,0 @@
|
|||||||
# Copyright 2014-present MongoDB, Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
|
||||||
# may not use this file except in compliance with the License. You
|
|
||||||
# may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
# implied. See the License for the specific language governing
|
|
||||||
# permissions and limitations under the License.
|
|
||||||
|
|
||||||
"""Run a target function on a background thread."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import weakref
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from pymongo.lock import _ALock, _create_lock
|
|
||||||
|
|
||||||
_IS_SYNC = False
|
|
||||||
|
|
||||||
|
|
||||||
class PeriodicExecutor:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
interval: float,
|
|
||||||
min_interval: float,
|
|
||||||
target: Any,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""Run a target function periodically on a background thread.
|
|
||||||
|
|
||||||
If the target's return value is false, the executor stops.
|
|
||||||
|
|
||||||
:param interval: Seconds between calls to `target`.
|
|
||||||
:param min_interval: Minimum seconds between calls if `wake` is
|
|
||||||
called very often.
|
|
||||||
:param target: A function.
|
|
||||||
:param name: A name to give the underlying thread.
|
|
||||||
"""
|
|
||||||
# threading.Event and its internal condition variable are expensive
|
|
||||||
# in Python 2, see PYTHON-983. Use a boolean to know when to wake.
|
|
||||||
# The executor's design is constrained by several Python issues, see
|
|
||||||
# "periodic_executor.rst" in this repository.
|
|
||||||
self._event = False
|
|
||||||
self._interval = interval
|
|
||||||
self._min_interval = min_interval
|
|
||||||
self._target = target
|
|
||||||
self._stopped = False
|
|
||||||
self._thread: Optional[threading.Thread] = None
|
|
||||||
self._name = name
|
|
||||||
self._skip_sleep = False
|
|
||||||
self._thread_will_exit = False
|
|
||||||
self._lock = _ALock(_create_lock())
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
|
||||||
|
|
||||||
def _run_async(self) -> None:
|
|
||||||
# The default asyncio loop implementation on Windows
|
|
||||||
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
|
|
||||||
# We explicitly use a different loop implementation here to prevent that issue
|
|
||||||
if sys.platform == "win32":
|
|
||||||
loop = asyncio.SelectorEventLoop()
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
else:
|
|
||||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
|
||||||
|
|
||||||
def open(self) -> None:
|
|
||||||
"""Start. Multiple calls have no effect.
|
|
||||||
|
|
||||||
Not safe to call from multiple threads at once.
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
if self._thread_will_exit:
|
|
||||||
# If the background thread has read self._stopped as True
|
|
||||||
# there is a chance that it has not yet exited. The call to
|
|
||||||
# join should not block indefinitely because there is no
|
|
||||||
# other work done outside the while loop in self._run.
|
|
||||||
try:
|
|
||||||
assert self._thread is not None
|
|
||||||
self._thread.join()
|
|
||||||
except ReferenceError:
|
|
||||||
# Thread terminated.
|
|
||||||
pass
|
|
||||||
self._thread_will_exit = False
|
|
||||||
self._stopped = False
|
|
||||||
started: Any = False
|
|
||||||
try:
|
|
||||||
started = self._thread and self._thread.is_alive()
|
|
||||||
except ReferenceError:
|
|
||||||
# Thread terminated.
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not started:
|
|
||||||
if _IS_SYNC:
|
|
||||||
thread = threading.Thread(target=self._run, name=self._name)
|
|
||||||
else:
|
|
||||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
|
||||||
thread.daemon = True
|
|
||||||
self._thread = weakref.proxy(thread)
|
|
||||||
_register_executor(self)
|
|
||||||
# Mitigation to RuntimeError firing when thread starts on shutdown
|
|
||||||
# https://github.com/python/cpython/issues/114570
|
|
||||||
try:
|
|
||||||
thread.start()
|
|
||||||
except RuntimeError as e:
|
|
||||||
if "interpreter shutdown" in str(e) or sys.is_finalizing():
|
|
||||||
self._thread = None
|
|
||||||
return
|
|
||||||
raise
|
|
||||||
|
|
||||||
def close(self, dummy: Any = None) -> None:
|
|
||||||
"""Stop. To restart, call open().
|
|
||||||
|
|
||||||
The dummy parameter allows an executor's close method to be a weakref
|
|
||||||
callback; see monitor.py.
|
|
||||||
"""
|
|
||||||
self._stopped = True
|
|
||||||
|
|
||||||
def join(self, timeout: Optional[int] = None) -> None:
|
|
||||||
if self._thread is not None:
|
|
||||||
try:
|
|
||||||
self._thread.join(timeout)
|
|
||||||
except (ReferenceError, RuntimeError):
|
|
||||||
# Thread already terminated, or not yet started.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def wake(self) -> None:
|
|
||||||
"""Execute the target function soon."""
|
|
||||||
self._event = True
|
|
||||||
|
|
||||||
def update_interval(self, new_interval: int) -> None:
|
|
||||||
self._interval = new_interval
|
|
||||||
|
|
||||||
def skip_sleep(self) -> None:
|
|
||||||
self._skip_sleep = True
|
|
||||||
|
|
||||||
async def _should_stop(self) -> bool:
|
|
||||||
async with self._lock:
|
|
||||||
if self._stopped:
|
|
||||||
self._thread_will_exit = True
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _run(self) -> None:
|
|
||||||
while not await self._should_stop():
|
|
||||||
try:
|
|
||||||
if not await self._target():
|
|
||||||
self._stopped = True
|
|
||||||
break
|
|
||||||
except BaseException:
|
|
||||||
async with self._lock:
|
|
||||||
self._stopped = True
|
|
||||||
self._thread_will_exit = True
|
|
||||||
|
|
||||||
raise
|
|
||||||
|
|
||||||
if self._skip_sleep:
|
|
||||||
self._skip_sleep = False
|
|
||||||
else:
|
|
||||||
deadline = time.monotonic() + self._interval
|
|
||||||
while not self._stopped and time.monotonic() < deadline:
|
|
||||||
await asyncio.sleep(self._min_interval)
|
|
||||||
if self._event:
|
|
||||||
break # Early wake.
|
|
||||||
|
|
||||||
self._event = False
|
|
||||||
|
|
||||||
|
|
||||||
# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started,
|
|
||||||
# an executor is kept alive by a strong reference from its thread and perhaps
|
|
||||||
# from other objects. When the thread dies and all other referrers are freed,
|
|
||||||
# the executor is freed and removed from _EXECUTORS. If any threads are
|
|
||||||
# running when the interpreter begins to shut down, we try to halt and join
|
|
||||||
# them to avoid spurious errors.
|
|
||||||
_EXECUTORS = set()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_executor(executor: PeriodicExecutor) -> None:
|
|
||||||
ref = weakref.ref(executor, _on_executor_deleted)
|
|
||||||
_EXECUTORS.add(ref)
|
|
||||||
|
|
||||||
|
|
||||||
def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None:
|
|
||||||
_EXECUTORS.remove(ref)
|
|
||||||
|
|
||||||
|
|
||||||
def _shutdown_executors() -> None:
|
|
||||||
if _EXECUTORS is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Copy the set. Stopping threads has the side effect of removing executors.
|
|
||||||
executors = list(_EXECUTORS)
|
|
||||||
|
|
||||||
# First signal all executors to close...
|
|
||||||
for ref in executors:
|
|
||||||
executor = ref()
|
|
||||||
if executor:
|
|
||||||
executor.close()
|
|
||||||
|
|
||||||
# ...then try to join them.
|
|
||||||
for ref in executors:
|
|
||||||
executor = ref()
|
|
||||||
if executor:
|
|
||||||
executor.join(1)
|
|
||||||
|
|
||||||
executor = None
|
|
||||||
@ -23,7 +23,6 @@ import os
|
|||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -65,7 +64,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
|||||||
_CertificateError,
|
_CertificateError,
|
||||||
)
|
)
|
||||||
from pymongo.hello import Hello, HelloCompat
|
from pymongo.hello import Hello, HelloCompat
|
||||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
from pymongo.lock import (
|
||||||
|
_async_cond_wait,
|
||||||
|
_async_create_condition,
|
||||||
|
_async_create_lock,
|
||||||
|
)
|
||||||
from pymongo.logger import (
|
from pymongo.logger import (
|
||||||
_CONNECTION_LOGGER,
|
_CONNECTION_LOGGER,
|
||||||
_ConnectionStatusMessage,
|
_ConnectionStatusMessage,
|
||||||
@ -208,11 +211,6 @@ def _raise_connection_failure(
|
|||||||
raise AutoReconnect(msg) from error
|
raise AutoReconnect(msg) from error
|
||||||
|
|
||||||
|
|
||||||
async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool:
|
|
||||||
timeout = deadline - time.monotonic() if deadline else None
|
|
||||||
return await condition.wait(timeout)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||||
details = {}
|
details = {}
|
||||||
timeout = _csot.get_timeout()
|
timeout = _csot.get_timeout()
|
||||||
@ -706,6 +704,8 @@ class AsyncConnection:
|
|||||||
# shutdown.
|
# shutdown.
|
||||||
try:
|
try:
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception: # noqa: S110
|
except Exception: # noqa: S110
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -992,8 +992,8 @@ class Pool:
|
|||||||
# from the right side.
|
# from the right side.
|
||||||
self.conns: collections.deque = collections.deque()
|
self.conns: collections.deque = collections.deque()
|
||||||
self.active_contexts: set[_CancellationContext] = set()
|
self.active_contexts: set[_CancellationContext] = set()
|
||||||
_lock = _create_lock()
|
self.lock = _async_create_lock()
|
||||||
self.lock = _ALock(_lock)
|
self._max_connecting_cond = _async_create_condition(self.lock)
|
||||||
self.active_sockets = 0
|
self.active_sockets = 0
|
||||||
# Monotonically increasing connection ID required for CMAP Events.
|
# Monotonically increasing connection ID required for CMAP Events.
|
||||||
self.next_connection_id = 1
|
self.next_connection_id = 1
|
||||||
@ -1019,7 +1019,7 @@ class Pool:
|
|||||||
# The first portion of the wait queue.
|
# The first portion of the wait queue.
|
||||||
# Enforces: maxPoolSize
|
# Enforces: maxPoolSize
|
||||||
# Also used for: clearing the wait queue
|
# Also used for: clearing the wait queue
|
||||||
self.size_cond = _ACondition(threading.Condition(_lock))
|
self.size_cond = _async_create_condition(self.lock)
|
||||||
self.requests = 0
|
self.requests = 0
|
||||||
self.max_pool_size = self.opts.max_pool_size
|
self.max_pool_size = self.opts.max_pool_size
|
||||||
if not self.max_pool_size:
|
if not self.max_pool_size:
|
||||||
@ -1027,7 +1027,7 @@ class Pool:
|
|||||||
# The second portion of the wait queue.
|
# The second portion of the wait queue.
|
||||||
# Enforces: maxConnecting
|
# Enforces: maxConnecting
|
||||||
# Also used for: clearing the wait queue
|
# Also used for: clearing the wait queue
|
||||||
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
|
self._max_connecting_cond = _async_create_condition(self.lock)
|
||||||
self._max_connecting = self.opts.max_connecting
|
self._max_connecting = self.opts.max_connecting
|
||||||
self._pending = 0
|
self._pending = 0
|
||||||
self._client_id = client_id
|
self._client_id = client_id
|
||||||
@ -1466,7 +1466,8 @@ class Pool:
|
|||||||
async with self.size_cond:
|
async with self.size_cond:
|
||||||
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
||||||
while not (self.requests < self.max_pool_size):
|
while not (self.requests < self.max_pool_size):
|
||||||
if not await _cond_wait(self.size_cond, deadline):
|
timeout = deadline - time.monotonic() if deadline else None
|
||||||
|
if not await _async_cond_wait(self.size_cond, timeout):
|
||||||
# Timed out, notify the next thread to ensure a
|
# Timed out, notify the next thread to ensure a
|
||||||
# timeout doesn't consume the condition.
|
# timeout doesn't consume the condition.
|
||||||
if self.requests < self.max_pool_size:
|
if self.requests < self.max_pool_size:
|
||||||
@ -1489,7 +1490,8 @@ class Pool:
|
|||||||
async with self._max_connecting_cond:
|
async with self._max_connecting_cond:
|
||||||
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
||||||
while not (self.conns or self._pending < self._max_connecting):
|
while not (self.conns or self._pending < self._max_connecting):
|
||||||
if not await _cond_wait(self._max_connecting_cond, deadline):
|
timeout = deadline - time.monotonic() if deadline else None
|
||||||
|
if not await _async_cond_wait(self._max_connecting_cond, timeout):
|
||||||
# Timed out, notify the next thread to ensure a
|
# Timed out, notify the next thread to ensure a
|
||||||
# timeout doesn't consume the condition.
|
# timeout doesn't consume the condition.
|
||||||
if self.conns or self._pending < self._max_connecting:
|
if self.conns or self._pending < self._max_connecting:
|
||||||
|
|||||||
@ -27,8 +27,7 @@ import weakref
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||||
|
|
||||||
from pymongo import _csot, common, helpers_shared
|
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||||
from pymongo.asynchronous import periodic_executor
|
|
||||||
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
|
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
|
||||||
from pymongo.asynchronous.monitor import SrvMonitor
|
from pymongo.asynchronous.monitor import SrvMonitor
|
||||||
from pymongo.asynchronous.pool import Pool
|
from pymongo.asynchronous.pool import Pool
|
||||||
@ -44,7 +43,11 @@ from pymongo.errors import (
|
|||||||
WriteError,
|
WriteError,
|
||||||
)
|
)
|
||||||
from pymongo.hello import Hello
|
from pymongo.hello import Hello
|
||||||
from pymongo.lock import _ACondition, _ALock, _create_lock
|
from pymongo.lock import (
|
||||||
|
_async_cond_wait,
|
||||||
|
_async_create_condition,
|
||||||
|
_async_create_lock,
|
||||||
|
)
|
||||||
from pymongo.logger import (
|
from pymongo.logger import (
|
||||||
_SDAM_LOGGER,
|
_SDAM_LOGGER,
|
||||||
_SERVER_SELECTION_LOGGER,
|
_SERVER_SELECTION_LOGGER,
|
||||||
@ -170,9 +173,10 @@ class Topology:
|
|||||||
self._seed_addresses = list(topology_description.server_descriptions())
|
self._seed_addresses = list(topology_description.server_descriptions())
|
||||||
self._opened = False
|
self._opened = False
|
||||||
self._closed = False
|
self._closed = False
|
||||||
_lock = _create_lock()
|
self._lock = _async_create_lock()
|
||||||
self._lock = _ALock(_lock)
|
self._condition = _async_create_condition(
|
||||||
self._condition = _ACondition(self._settings.condition_class(_lock))
|
self._lock, self._settings.condition_class if _IS_SYNC else None
|
||||||
|
)
|
||||||
self._servers: dict[_Address, Server] = {}
|
self._servers: dict[_Address, Server] = {}
|
||||||
self._pid: Optional[int] = None
|
self._pid: Optional[int] = None
|
||||||
self._max_cluster_time: Optional[ClusterTime] = None
|
self._max_cluster_time: Optional[ClusterTime] = None
|
||||||
@ -185,7 +189,7 @@ class Topology:
|
|||||||
async def target() -> bool:
|
async def target() -> bool:
|
||||||
return process_events_queue(weak)
|
return process_events_queue(weak)
|
||||||
|
|
||||||
executor = periodic_executor.PeriodicExecutor(
|
executor = periodic_executor.AsyncPeriodicExecutor(
|
||||||
interval=common.EVENTS_QUEUE_FREQUENCY,
|
interval=common.EVENTS_QUEUE_FREQUENCY,
|
||||||
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
min_interval=common.MIN_HEARTBEAT_INTERVAL,
|
||||||
target=target,
|
target=target,
|
||||||
@ -354,7 +358,7 @@ class Topology:
|
|||||||
# change, or for a timeout. We won't miss any changes that
|
# change, or for a timeout. We won't miss any changes that
|
||||||
# came after our most recent apply_selector call, since we've
|
# came after our most recent apply_selector call, since we've
|
||||||
# held the lock until now.
|
# held the lock until now.
|
||||||
await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
|
await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
|
||||||
self._description.check_compatible()
|
self._description.check_compatible()
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
server_descriptions = self._description.apply_selector(
|
server_descriptions = self._description.apply_selector(
|
||||||
@ -654,7 +658,7 @@ class Topology:
|
|||||||
"""Wake all monitors, wait for at least one to check its server."""
|
"""Wake all monitors, wait for at least one to check its server."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._request_check_all()
|
self._request_check_all()
|
||||||
await self._condition.wait(wait_time)
|
await _async_cond_wait(self._condition, wait_time)
|
||||||
|
|
||||||
def data_bearing_servers(self) -> list[ServerDescription]:
|
def data_bearing_servers(self) -> list[ServerDescription]:
|
||||||
"""Return a list of all data-bearing servers.
|
"""Return a list of all data-bearing servers.
|
||||||
@ -742,7 +746,7 @@ class Topology:
|
|||||||
if self._publish_server or self._publish_tp:
|
if self._publish_server or self._publish_tp:
|
||||||
# Make sure the events executor thread is fully closed before publishing the remaining events
|
# Make sure the events executor thread is fully closed before publishing the remaining events
|
||||||
self.__events_executor.close()
|
self.__events_executor.close()
|
||||||
self.__events_executor.join(1)
|
await self.__events_executor.join(1)
|
||||||
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
|
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
245
pymongo/lock.py
245
pymongo/lock.py
@ -11,15 +11,20 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Internal helpers for lock and condition coordination primitives."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
from asyncio import wait_for
|
||||||
|
from typing import Any, Optional, TypeVar
|
||||||
|
|
||||||
|
import pymongo._asyncio_lock
|
||||||
|
|
||||||
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
||||||
|
|
||||||
@ -28,6 +33,15 @@ _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
|
|||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202)
|
||||||
|
# in older versions of Python
|
||||||
|
if sys.version_info >= (3, 13):
|
||||||
|
Lock = asyncio.Lock
|
||||||
|
Condition = asyncio.Condition
|
||||||
|
else:
|
||||||
|
Lock = pymongo._asyncio_lock.Lock
|
||||||
|
Condition = pymongo._asyncio_lock.Condition
|
||||||
|
|
||||||
|
|
||||||
def _create_lock() -> threading.Lock:
|
def _create_lock() -> threading.Lock:
|
||||||
"""Represents a lock that is tracked upon instantiation using a WeakSet and
|
"""Represents a lock that is tracked upon instantiation using a WeakSet and
|
||||||
@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock:
|
|||||||
return lock
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
def _async_create_lock() -> Lock:
|
||||||
|
"""Represents an asyncio.Lock."""
|
||||||
|
return Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_condition(
|
||||||
|
lock: threading.Lock, condition_class: Optional[Any] = None
|
||||||
|
) -> threading.Condition:
|
||||||
|
"""Represents a threading.Condition."""
|
||||||
|
if condition_class:
|
||||||
|
return condition_class(lock)
|
||||||
|
return threading.Condition(lock)
|
||||||
|
|
||||||
|
|
||||||
|
def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition:
|
||||||
|
"""Represents an asyncio.Condition."""
|
||||||
|
if condition_class:
|
||||||
|
return condition_class(lock)
|
||||||
|
return Condition(lock)
|
||||||
|
|
||||||
|
|
||||||
def _release_locks() -> None:
|
def _release_locks() -> None:
|
||||||
# Completed the fork, reset all the locks in the child.
|
# Completed the fork, reset all the locks in the child.
|
||||||
for lock in _forkable_locks:
|
for lock in _forkable_locks:
|
||||||
@ -46,202 +81,12 @@ def _release_locks() -> None:
|
|||||||
lock.release()
|
lock.release()
|
||||||
|
|
||||||
|
|
||||||
# Needed only for synchro.py compat.
|
async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool:
|
||||||
def _Lock(lock: threading.Lock) -> threading.Lock:
|
try:
|
||||||
return lock
|
return await wait_for(condition.wait(), timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class _ALock:
|
def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool:
|
||||||
__slots__ = ("_lock",)
|
return condition.wait(timeout)
|
||||||
|
|
||||||
def __init__(self, lock: threading.Lock) -> None:
|
|
||||||
self._lock = lock
|
|
||||||
|
|
||||||
def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
|
||||||
return self._lock.acquire(blocking=blocking, timeout=timeout)
|
|
||||||
|
|
||||||
async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
|
||||||
if timeout > 0:
|
|
||||||
tstart = time.monotonic()
|
|
||||||
while True:
|
|
||||||
acquired = self._lock.acquire(blocking=False)
|
|
||||||
if acquired:
|
|
||||||
return True
|
|
||||||
if timeout > 0 and (time.monotonic() - tstart) > timeout:
|
|
||||||
return False
|
|
||||||
if not blocking:
|
|
||||||
return False
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
def release(self) -> None:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
async def __aenter__(self) -> _ALock:
|
|
||||||
await self.a_acquire()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __enter__(self) -> _ALock:
|
|
||||||
self._lock.acquire()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
|
||||||
self.release()
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
|
||||||
self.release()
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_set_result(fut: asyncio.Future) -> None:
|
|
||||||
# Ensure the future hasn't been cancelled before calling set_result.
|
|
||||||
if not fut.done():
|
|
||||||
fut.set_result(False)
|
|
||||||
|
|
||||||
|
|
||||||
class _ACondition:
|
|
||||||
__slots__ = ("_condition", "_waiters")
|
|
||||||
|
|
||||||
def __init__(self, condition: threading.Condition) -> None:
|
|
||||||
self._condition = condition
|
|
||||||
self._waiters: collections.deque = collections.deque()
|
|
||||||
|
|
||||||
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
|
||||||
if timeout > 0:
|
|
||||||
tstart = time.monotonic()
|
|
||||||
while True:
|
|
||||||
acquired = self._condition.acquire(blocking=False)
|
|
||||||
if acquired:
|
|
||||||
return True
|
|
||||||
if timeout > 0 and (time.monotonic() - tstart) > timeout:
|
|
||||||
return False
|
|
||||||
if not blocking:
|
|
||||||
return False
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
async def wait(self, timeout: Optional[float] = None) -> bool:
|
|
||||||
"""Wait until notified.
|
|
||||||
|
|
||||||
If the calling task has not acquired the lock when this
|
|
||||||
method is called, a RuntimeError is raised.
|
|
||||||
|
|
||||||
This method releases the underlying lock, and then blocks
|
|
||||||
until it is awakened by a notify() or notify_all() call for
|
|
||||||
the same condition variable in another task. Once
|
|
||||||
awakened, it re-acquires the lock and returns True.
|
|
||||||
|
|
||||||
This method may return spuriously,
|
|
||||||
which is why the caller should always
|
|
||||||
re-check the state and be prepared to wait() again.
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
fut = loop.create_future()
|
|
||||||
self._waiters.append((loop, fut))
|
|
||||||
self.release()
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(fut, timeout)
|
|
||||||
return True
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return False # Return false on timeout for sync pool compat.
|
|
||||||
finally:
|
|
||||||
# Must re-acquire lock even if wait is cancelled.
|
|
||||||
# We only catch CancelledError here, since we don't want any
|
|
||||||
# other (fatal) errors with the future to cause us to spin.
|
|
||||||
err = None
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await self.acquire()
|
|
||||||
break
|
|
||||||
except asyncio.exceptions.CancelledError as e:
|
|
||||||
err = e
|
|
||||||
|
|
||||||
self._waiters.remove((loop, fut))
|
|
||||||
if err is not None:
|
|
||||||
try:
|
|
||||||
raise err # Re-raise most recent exception instance.
|
|
||||||
finally:
|
|
||||||
err = None # Break reference cycles.
|
|
||||||
except BaseException:
|
|
||||||
# Any error raised out of here _may_ have occurred after this Task
|
|
||||||
# believed to have been successfully notified.
|
|
||||||
# Make sure to notify another Task instead. This may result
|
|
||||||
# in a "spurious wakeup", which is allowed as part of the
|
|
||||||
# Condition Variable protocol.
|
|
||||||
self.notify(1)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
|
|
||||||
"""Wait until a predicate becomes true.
|
|
||||||
|
|
||||||
The predicate should be a callable whose result will be
|
|
||||||
interpreted as a boolean value. The method will repeatedly
|
|
||||||
wait() until it evaluates to true. The final predicate value is
|
|
||||||
the return value.
|
|
||||||
"""
|
|
||||||
result = predicate()
|
|
||||||
while not result:
|
|
||||||
await self.wait()
|
|
||||||
result = predicate()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def notify(self, n: int = 1) -> None:
|
|
||||||
"""By default, wake up one coroutine waiting on this condition, if any.
|
|
||||||
If the calling coroutine has not acquired the lock when this method
|
|
||||||
is called, a RuntimeError is raised.
|
|
||||||
|
|
||||||
This method wakes up at most n of the coroutines waiting for the
|
|
||||||
condition variable; it is a no-op if no coroutines are waiting.
|
|
||||||
|
|
||||||
Note: an awakened coroutine does not actually return from its
|
|
||||||
wait() call until it can reacquire the lock. Since notify() does
|
|
||||||
not release the lock, its caller should.
|
|
||||||
"""
|
|
||||||
idx = 0
|
|
||||||
to_remove = []
|
|
||||||
for loop, fut in self._waiters:
|
|
||||||
if idx >= n:
|
|
||||||
break
|
|
||||||
|
|
||||||
if fut.done():
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.call_soon_threadsafe(_safe_set_result, fut)
|
|
||||||
except RuntimeError:
|
|
||||||
# Loop was closed, ignore.
|
|
||||||
to_remove.append((loop, fut))
|
|
||||||
continue
|
|
||||||
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
for waiter in to_remove:
|
|
||||||
self._waiters.remove(waiter)
|
|
||||||
|
|
||||||
def notify_all(self) -> None:
|
|
||||||
"""Wake up all threads waiting on this condition. This method acts
|
|
||||||
like notify(), but wakes up all waiting threads instead of one. If the
|
|
||||||
calling thread has not acquired the lock when this method is called,
|
|
||||||
a RuntimeError is raised.
|
|
||||||
"""
|
|
||||||
self.notify(len(self._waiters))
|
|
||||||
|
|
||||||
def locked(self) -> bool:
|
|
||||||
"""Only needed for tests in test_locks."""
|
|
||||||
return self._condition._lock.locked() # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def release(self) -> None:
|
|
||||||
self._condition.release()
|
|
||||||
|
|
||||||
async def __aenter__(self) -> _ACondition:
|
|
||||||
await self.acquire()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __enter__(self) -> _ACondition:
|
|
||||||
self._condition.acquire()
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
|
||||||
self.release()
|
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
|
||||||
self.release()
|
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from pymongo import _csot, ssl_support
|
from pymongo import _csot, ssl_support
|
||||||
|
from pymongo._asyncio_task import create_task
|
||||||
from pymongo.errors import _OperationCancelled
|
from pymongo.errors import _OperationCancelled
|
||||||
from pymongo.socket_checker import _errno_from_exception
|
from pymongo.socket_checker import _errno_from_exception
|
||||||
|
|
||||||
@ -259,19 +260,20 @@ async def async_receive_data(
|
|||||||
|
|
||||||
sock.settimeout(0.0)
|
sock.settimeout(0.0)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
|
cancellation_task = create_task(_poll_cancellation(conn))
|
||||||
try:
|
try:
|
||||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||||
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
||||||
tasks = [read_task, cancellation_task]
|
tasks = [read_task, cancellation_task]
|
||||||
done, pending = await asyncio.wait(
|
done, pending = await asyncio.wait(
|
||||||
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||||
)
|
)
|
||||||
for task in pending:
|
for task in pending:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
await asyncio.wait(pending)
|
if pending:
|
||||||
|
await asyncio.wait(pending)
|
||||||
if len(done) == 0:
|
if len(done) == 0:
|
||||||
raise socket.timeout("timed out")
|
raise socket.timeout("timed out")
|
||||||
if read_task in done:
|
if read_task in done:
|
||||||
|
|||||||
@ -23,9 +23,102 @@ import time
|
|||||||
import weakref
|
import weakref
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pymongo._asyncio_task import create_task
|
||||||
from pymongo.lock import _create_lock
|
from pymongo.lock import _create_lock
|
||||||
|
|
||||||
_IS_SYNC = True
|
_IS_SYNC = False
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncPeriodicExecutor:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
interval: float,
|
||||||
|
min_interval: float,
|
||||||
|
target: Any,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Run a target function periodically on a background task.
|
||||||
|
|
||||||
|
If the target's return value is false, the executor stops.
|
||||||
|
|
||||||
|
:param interval: Seconds between calls to `target`.
|
||||||
|
:param min_interval: Minimum seconds between calls if `wake` is
|
||||||
|
called very often.
|
||||||
|
:param target: A function.
|
||||||
|
:param name: A name to give the underlying task.
|
||||||
|
"""
|
||||||
|
self._event = False
|
||||||
|
self._interval = interval
|
||||||
|
self._min_interval = min_interval
|
||||||
|
self._target = target
|
||||||
|
self._stopped = False
|
||||||
|
self._task: Optional[asyncio.Task] = None
|
||||||
|
self._name = name
|
||||||
|
self._skip_sleep = False
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||||
|
|
||||||
|
def open(self) -> None:
|
||||||
|
"""Start. Multiple calls have no effect."""
|
||||||
|
self._stopped = False
|
||||||
|
|
||||||
|
if self._task is None or (
|
||||||
|
self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined]
|
||||||
|
):
|
||||||
|
self._task = create_task(self._run(), name=self._name)
|
||||||
|
|
||||||
|
def close(self, dummy: Any = None) -> None:
|
||||||
|
"""Stop. To restart, call open().
|
||||||
|
|
||||||
|
The dummy parameter allows an executor's close method to be a weakref
|
||||||
|
callback; see monitor.py.
|
||||||
|
"""
|
||||||
|
self._stopped = True
|
||||||
|
|
||||||
|
async def join(self, timeout: Optional[int] = None) -> None:
|
||||||
|
if self._task is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type]
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Task timed out
|
||||||
|
pass
|
||||||
|
except asyncio.exceptions.CancelledError:
|
||||||
|
# Task was already finished, or not yet started.
|
||||||
|
raise
|
||||||
|
|
||||||
|
def wake(self) -> None:
|
||||||
|
"""Execute the target function soon."""
|
||||||
|
self._event = True
|
||||||
|
|
||||||
|
def update_interval(self, new_interval: int) -> None:
|
||||||
|
self._interval = new_interval
|
||||||
|
|
||||||
|
def skip_sleep(self) -> None:
|
||||||
|
self._skip_sleep = True
|
||||||
|
|
||||||
|
async def _run(self) -> None:
|
||||||
|
while not self._stopped:
|
||||||
|
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
try:
|
||||||
|
if not await self._target():
|
||||||
|
self._stopped = True
|
||||||
|
break
|
||||||
|
except BaseException:
|
||||||
|
self._stopped = True
|
||||||
|
raise
|
||||||
|
|
||||||
|
if self._skip_sleep:
|
||||||
|
self._skip_sleep = False
|
||||||
|
else:
|
||||||
|
deadline = time.monotonic() + self._interval
|
||||||
|
while not self._stopped and time.monotonic() < deadline:
|
||||||
|
await asyncio.sleep(self._min_interval)
|
||||||
|
if self._event:
|
||||||
|
break # Early wake.
|
||||||
|
|
||||||
|
self._event = False
|
||||||
|
|
||||||
|
|
||||||
class PeriodicExecutor:
|
class PeriodicExecutor:
|
||||||
@ -64,19 +157,6 @@ class PeriodicExecutor:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>"
|
||||||
|
|
||||||
def _run_async(self) -> None:
|
|
||||||
# The default asyncio loop implementation on Windows
|
|
||||||
# has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240)
|
|
||||||
# We explicitly use a different loop implementation here to prevent that issue
|
|
||||||
if sys.platform == "win32":
|
|
||||||
loop = asyncio.SelectorEventLoop()
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(self._run()) # type: ignore[func-returns-value]
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
else:
|
|
||||||
asyncio.run(self._run()) # type: ignore[func-returns-value]
|
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
"""Start. Multiple calls have no effect.
|
"""Start. Multiple calls have no effect.
|
||||||
|
|
||||||
@ -104,10 +184,7 @@ class PeriodicExecutor:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if not started:
|
if not started:
|
||||||
if _IS_SYNC:
|
thread = threading.Thread(target=self._run, name=self._name)
|
||||||
thread = threading.Thread(target=self._run, name=self._name)
|
|
||||||
else:
|
|
||||||
thread = threading.Thread(target=self._run_async, name=self._name)
|
|
||||||
thread.daemon = True
|
thread.daemon = True
|
||||||
self._thread = weakref.proxy(thread)
|
self._thread = weakref.proxy(thread)
|
||||||
_register_executor(self)
|
_register_executor(self)
|
||||||
@ -474,7 +474,6 @@ class _ClientBulk:
|
|||||||
if op_type == "delete":
|
if op_type == "delete":
|
||||||
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
|
||||||
full_result[f"{op_type}Results"][original_index] = res
|
full_result[f"{op_type}Results"][original_index] = res
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# Attempt to close the cursor, then raise top-level error.
|
# Attempt to close the cursor, then raise top-level error.
|
||||||
if cmd_cursor.alive:
|
if cmd_cursor.alive:
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class _ConnectionManager:
|
|||||||
def __init__(self, conn: Connection, more_to_come: bool):
|
def __init__(self, conn: Connection, more_to_come: bool):
|
||||||
self.conn: Optional[Connection] = conn
|
self.conn: Optional[Connection] = conn
|
||||||
self.more_to_come = more_to_come
|
self.more_to_come = more_to_come
|
||||||
self._alock = _create_lock()
|
self._lock = _create_lock()
|
||||||
|
|
||||||
def update_exhaust(self, more_to_come: bool) -> None:
|
def update_exhaust(self, more_to_come: bool) -> None:
|
||||||
self.more_to_come = more_to_come
|
self.more_to_come = more_to_come
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
"""Support for explicit client-side field level encryption."""
|
"""Support for explicit client-side field level encryption."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import enum
|
import enum
|
||||||
import socket
|
import socket
|
||||||
@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
|
|||||||
# BSON encoding/decoding errors are unrelated to encryption so
|
# BSON encoding/decoding errors are unrelated to encryption so
|
||||||
# we should propagate them unchanged.
|
# we should propagate them unchanged.
|
||||||
raise
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise EncryptionError(exc) from exc
|
raise EncryptionError(exc) from exc
|
||||||
|
|
||||||
@ -200,6 +203,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
|||||||
conn.close()
|
conn.close()
|
||||||
except (PyMongoError, MongoCryptError):
|
except (PyMongoError, MongoCryptError):
|
||||||
raise # Propagate pymongo errors directly.
|
raise # Propagate pymongo errors directly.
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
# Wrap I/O errors in PyMongo exceptions.
|
# Wrap I/O errors in PyMongo exceptions.
|
||||||
_raise_connection_failure((host, port), error)
|
_raise_connection_failure((host, port), error)
|
||||||
@ -716,6 +721,8 @@ class ClientEncryption(Generic[_DocumentType]):
|
|||||||
database.create_collection(name=name, **kwargs),
|
database.create_collection(name=name, **kwargs),
|
||||||
encrypted_fields,
|
encrypted_fields,
|
||||||
)
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
raise EncryptedCollectionError(exc, encrypted_fields) from exc
|
||||||
|
|
||||||
|
|||||||
@ -32,6 +32,7 @@ access:
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@ -58,7 +59,7 @@ from typing import (
|
|||||||
|
|
||||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||||
from bson.timestamp import Timestamp
|
from bson.timestamp import Timestamp
|
||||||
from pymongo import _csot, common, helpers_shared, uri_parser
|
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||||
from pymongo.client_options import ClientOptions
|
from pymongo.client_options import ClientOptions
|
||||||
from pymongo.errors import (
|
from pymongo.errors import (
|
||||||
AutoReconnect,
|
AutoReconnect,
|
||||||
@ -74,7 +75,11 @@ from pymongo.errors import (
|
|||||||
WaitQueueTimeoutError,
|
WaitQueueTimeoutError,
|
||||||
WriteConcernError,
|
WriteConcernError,
|
||||||
)
|
)
|
||||||
from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks
|
from pymongo.lock import (
|
||||||
|
_HAS_REGISTER_AT_FORK,
|
||||||
|
_create_lock,
|
||||||
|
_release_locks,
|
||||||
|
)
|
||||||
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
|
||||||
from pymongo.message import _CursorAddress, _GetMore, _Query
|
from pymongo.message import _CursorAddress, _GetMore, _Query
|
||||||
from pymongo.monitoring import ConnectionClosedReason
|
from pymongo.monitoring import ConnectionClosedReason
|
||||||
@ -91,7 +96,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
|||||||
from pymongo.results import ClientBulkWriteResult
|
from pymongo.results import ClientBulkWriteResult
|
||||||
from pymongo.server_selectors import writable_server_selector
|
from pymongo.server_selectors import writable_server_selector
|
||||||
from pymongo.server_type import SERVER_TYPE
|
from pymongo.server_type import SERVER_TYPE
|
||||||
from pymongo.synchronous import client_session, database, periodic_executor
|
from pymongo.synchronous import client_session, database
|
||||||
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||||
from pymongo.synchronous.client_session import _EmptyServerSession
|
from pymongo.synchronous.client_session import _EmptyServerSession
|
||||||
@ -1716,7 +1721,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
address=address,
|
address=address,
|
||||||
)
|
)
|
||||||
|
|
||||||
with operation.conn_mgr._alock:
|
with operation.conn_mgr._lock:
|
||||||
with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
|
||||||
err_handler.contribute_socket(operation.conn_mgr.conn)
|
err_handler.contribute_socket(operation.conn_mgr.conn)
|
||||||
return server.run_operation(
|
return server.run_operation(
|
||||||
@ -1964,7 +1969,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if conn_mgr:
|
if conn_mgr:
|
||||||
with conn_mgr._alock:
|
with conn_mgr._lock:
|
||||||
# Cursor is pinned to LB outside of a transaction.
|
# Cursor is pinned to LB outside of a transaction.
|
||||||
assert address is not None
|
assert address is not None
|
||||||
assert conn_mgr.conn is not None
|
assert conn_mgr.conn is not None
|
||||||
@ -2027,6 +2032,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
for address, cursor_id, conn_mgr in pinned_cursors:
|
for address, cursor_id, conn_mgr in pinned_cursors:
|
||||||
try:
|
try:
|
||||||
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||||
# Raise the exception when client is closed so that it
|
# Raise the exception when client is closed so that it
|
||||||
@ -2041,6 +2048,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
for address, cursor_ids in address_to_cursor_ids.items():
|
for address, cursor_ids in address_to_cursor_ids.items():
|
||||||
try:
|
try:
|
||||||
self._kill_cursors(cursor_ids, address, topology, session=None)
|
self._kill_cursors(cursor_ids, address, topology, session=None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||||
raise
|
raise
|
||||||
@ -2055,6 +2064,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
try:
|
try:
|
||||||
self._process_kill_cursors()
|
self._process_kill_cursors()
|
||||||
self._topology.update_pool()
|
self._topology.update_pool()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
if isinstance(exc, InvalidOperation) and self._topology._closed:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -16,24 +16,24 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import atexit
|
import atexit
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
from typing import TYPE_CHECKING, Any, Mapping, Optional, cast
|
||||||
|
|
||||||
from pymongo import common
|
from pymongo import common, periodic_executor
|
||||||
from pymongo._csot import MovingMinimum
|
from pymongo._csot import MovingMinimum
|
||||||
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
|
||||||
from pymongo.hello import Hello
|
from pymongo.hello import Hello
|
||||||
from pymongo.lock import _create_lock
|
from pymongo.lock import _create_lock
|
||||||
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
|
||||||
|
from pymongo.periodic_executor import _shutdown_executors
|
||||||
from pymongo.pool_options import _is_faas
|
from pymongo.pool_options import _is_faas
|
||||||
from pymongo.read_preferences import MovingAverage
|
from pymongo.read_preferences import MovingAverage
|
||||||
from pymongo.server_description import ServerDescription
|
from pymongo.server_description import ServerDescription
|
||||||
from pymongo.srv_resolver import _SrvResolver
|
from pymongo.srv_resolver import _SrvResolver
|
||||||
from pymongo.synchronous import periodic_executor
|
|
||||||
from pymongo.synchronous.periodic_executor import _shutdown_executors
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
||||||
@ -238,6 +238,9 @@ class Monitor(MonitorBase):
|
|||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
# Topology was garbage-collected.
|
# Topology was garbage-collected.
|
||||||
self.close()
|
self.close()
|
||||||
|
finally:
|
||||||
|
if self._executor._stopped:
|
||||||
|
self._rtt_monitor.close()
|
||||||
|
|
||||||
def _check_server(self) -> ServerDescription:
|
def _check_server(self) -> ServerDescription:
|
||||||
"""Call hello or read the next streaming response.
|
"""Call hello or read the next streaming response.
|
||||||
@ -254,6 +257,8 @@ class Monitor(MonitorBase):
|
|||||||
details = cast(Mapping[str, Any], exc.details)
|
details = cast(Mapping[str, Any], exc.details)
|
||||||
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
self._topology.receive_cluster_time(details.get("$clusterTime"))
|
||||||
raise
|
raise
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
raise
|
raise
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
@ -419,6 +424,8 @@ class SrvMonitor(MonitorBase):
|
|||||||
if len(seedlist) == 0:
|
if len(seedlist) == 0:
|
||||||
# As per the spec: this should be treated as a failure.
|
# As per the spec: this should be treated as a failure.
|
||||||
raise Exception
|
raise Exception
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
# As per the spec, upon encountering an error:
|
# As per the spec, upon encountering an error:
|
||||||
# - An error must not be raised
|
# - An error must not be raised
|
||||||
@ -482,6 +489,8 @@ class _RttMonitor(MonitorBase):
|
|||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
# Topology was garbage-collected.
|
# Topology was garbage-collected.
|
||||||
self.close()
|
self.close()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
self._pool.reset()
|
self._pool.reset()
|
||||||
|
|
||||||
@ -536,4 +545,5 @@ def _shutdown_resources() -> None:
|
|||||||
shutdown()
|
shutdown()
|
||||||
|
|
||||||
|
|
||||||
atexit.register(_shutdown_resources)
|
if _IS_SYNC:
|
||||||
|
atexit.register(_shutdown_resources)
|
||||||
|
|||||||
@ -23,7 +23,6 @@ import os
|
|||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -62,7 +61,11 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
|||||||
_CertificateError,
|
_CertificateError,
|
||||||
)
|
)
|
||||||
from pymongo.hello import Hello, HelloCompat
|
from pymongo.hello import Hello, HelloCompat
|
||||||
from pymongo.lock import _create_lock, _Lock
|
from pymongo.lock import (
|
||||||
|
_cond_wait,
|
||||||
|
_create_condition,
|
||||||
|
_create_lock,
|
||||||
|
)
|
||||||
from pymongo.logger import (
|
from pymongo.logger import (
|
||||||
_CONNECTION_LOGGER,
|
_CONNECTION_LOGGER,
|
||||||
_ConnectionStatusMessage,
|
_ConnectionStatusMessage,
|
||||||
@ -208,11 +211,6 @@ def _raise_connection_failure(
|
|||||||
raise AutoReconnect(msg) from error
|
raise AutoReconnect(msg) from error
|
||||||
|
|
||||||
|
|
||||||
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
|
|
||||||
timeout = deadline - time.monotonic() if deadline else None
|
|
||||||
return condition.wait(timeout)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||||
details = {}
|
details = {}
|
||||||
timeout = _csot.get_timeout()
|
timeout = _csot.get_timeout()
|
||||||
@ -704,6 +702,8 @@ class Connection:
|
|||||||
# shutdown.
|
# shutdown.
|
||||||
try:
|
try:
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
except Exception: # noqa: S110
|
except Exception: # noqa: S110
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -988,8 +988,8 @@ class Pool:
|
|||||||
# from the right side.
|
# from the right side.
|
||||||
self.conns: collections.deque = collections.deque()
|
self.conns: collections.deque = collections.deque()
|
||||||
self.active_contexts: set[_CancellationContext] = set()
|
self.active_contexts: set[_CancellationContext] = set()
|
||||||
_lock = _create_lock()
|
self.lock = _create_lock()
|
||||||
self.lock = _Lock(_lock)
|
self._max_connecting_cond = _create_condition(self.lock)
|
||||||
self.active_sockets = 0
|
self.active_sockets = 0
|
||||||
# Monotonically increasing connection ID required for CMAP Events.
|
# Monotonically increasing connection ID required for CMAP Events.
|
||||||
self.next_connection_id = 1
|
self.next_connection_id = 1
|
||||||
@ -1015,7 +1015,7 @@ class Pool:
|
|||||||
# The first portion of the wait queue.
|
# The first portion of the wait queue.
|
||||||
# Enforces: maxPoolSize
|
# Enforces: maxPoolSize
|
||||||
# Also used for: clearing the wait queue
|
# Also used for: clearing the wait queue
|
||||||
self.size_cond = threading.Condition(_lock)
|
self.size_cond = _create_condition(self.lock)
|
||||||
self.requests = 0
|
self.requests = 0
|
||||||
self.max_pool_size = self.opts.max_pool_size
|
self.max_pool_size = self.opts.max_pool_size
|
||||||
if not self.max_pool_size:
|
if not self.max_pool_size:
|
||||||
@ -1023,7 +1023,7 @@ class Pool:
|
|||||||
# The second portion of the wait queue.
|
# The second portion of the wait queue.
|
||||||
# Enforces: maxConnecting
|
# Enforces: maxConnecting
|
||||||
# Also used for: clearing the wait queue
|
# Also used for: clearing the wait queue
|
||||||
self._max_connecting_cond = threading.Condition(_lock)
|
self._max_connecting_cond = _create_condition(self.lock)
|
||||||
self._max_connecting = self.opts.max_connecting
|
self._max_connecting = self.opts.max_connecting
|
||||||
self._pending = 0
|
self._pending = 0
|
||||||
self._client_id = client_id
|
self._client_id = client_id
|
||||||
@ -1460,7 +1460,8 @@ class Pool:
|
|||||||
with self.size_cond:
|
with self.size_cond:
|
||||||
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
self._raise_if_not_ready(checkout_started_time, emit_event=True)
|
||||||
while not (self.requests < self.max_pool_size):
|
while not (self.requests < self.max_pool_size):
|
||||||
if not _cond_wait(self.size_cond, deadline):
|
timeout = deadline - time.monotonic() if deadline else None
|
||||||
|
if not _cond_wait(self.size_cond, timeout):
|
||||||
# Timed out, notify the next thread to ensure a
|
# Timed out, notify the next thread to ensure a
|
||||||
# timeout doesn't consume the condition.
|
# timeout doesn't consume the condition.
|
||||||
if self.requests < self.max_pool_size:
|
if self.requests < self.max_pool_size:
|
||||||
@ -1483,7 +1484,8 @@ class Pool:
|
|||||||
with self._max_connecting_cond:
|
with self._max_connecting_cond:
|
||||||
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
self._raise_if_not_ready(checkout_started_time, emit_event=False)
|
||||||
while not (self.conns or self._pending < self._max_connecting):
|
while not (self.conns or self._pending < self._max_connecting):
|
||||||
if not _cond_wait(self._max_connecting_cond, deadline):
|
timeout = deadline - time.monotonic() if deadline else None
|
||||||
|
if not _cond_wait(self._max_connecting_cond, timeout):
|
||||||
# Timed out, notify the next thread to ensure a
|
# Timed out, notify the next thread to ensure a
|
||||||
# timeout doesn't consume the condition.
|
# timeout doesn't consume the condition.
|
||||||
if self.conns or self._pending < self._max_connecting:
|
if self.conns or self._pending < self._max_connecting:
|
||||||
|
|||||||
@ -27,7 +27,7 @@ import weakref
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||||
|
|
||||||
from pymongo import _csot, common, helpers_shared
|
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||||
from pymongo.errors import (
|
from pymongo.errors import (
|
||||||
ConnectionFailure,
|
ConnectionFailure,
|
||||||
InvalidOperation,
|
InvalidOperation,
|
||||||
@ -39,7 +39,11 @@ from pymongo.errors import (
|
|||||||
WriteError,
|
WriteError,
|
||||||
)
|
)
|
||||||
from pymongo.hello import Hello
|
from pymongo.hello import Hello
|
||||||
from pymongo.lock import _create_lock, _Lock
|
from pymongo.lock import (
|
||||||
|
_cond_wait,
|
||||||
|
_create_condition,
|
||||||
|
_create_lock,
|
||||||
|
)
|
||||||
from pymongo.logger import (
|
from pymongo.logger import (
|
||||||
_SDAM_LOGGER,
|
_SDAM_LOGGER,
|
||||||
_SERVER_SELECTION_LOGGER,
|
_SERVER_SELECTION_LOGGER,
|
||||||
@ -56,7 +60,6 @@ from pymongo.server_selectors import (
|
|||||||
secondary_server_selector,
|
secondary_server_selector,
|
||||||
writable_server_selector,
|
writable_server_selector,
|
||||||
)
|
)
|
||||||
from pymongo.synchronous import periodic_executor
|
|
||||||
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
|
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
|
||||||
from pymongo.synchronous.monitor import SrvMonitor
|
from pymongo.synchronous.monitor import SrvMonitor
|
||||||
from pymongo.synchronous.pool import Pool
|
from pymongo.synchronous.pool import Pool
|
||||||
@ -170,9 +173,10 @@ class Topology:
|
|||||||
self._seed_addresses = list(topology_description.server_descriptions())
|
self._seed_addresses = list(topology_description.server_descriptions())
|
||||||
self._opened = False
|
self._opened = False
|
||||||
self._closed = False
|
self._closed = False
|
||||||
_lock = _create_lock()
|
self._lock = _create_lock()
|
||||||
self._lock = _Lock(_lock)
|
self._condition = _create_condition(
|
||||||
self._condition = self._settings.condition_class(_lock)
|
self._lock, self._settings.condition_class if _IS_SYNC else None
|
||||||
|
)
|
||||||
self._servers: dict[_Address, Server] = {}
|
self._servers: dict[_Address, Server] = {}
|
||||||
self._pid: Optional[int] = None
|
self._pid: Optional[int] = None
|
||||||
self._max_cluster_time: Optional[ClusterTime] = None
|
self._max_cluster_time: Optional[ClusterTime] = None
|
||||||
@ -354,7 +358,7 @@ class Topology:
|
|||||||
# change, or for a timeout. We won't miss any changes that
|
# change, or for a timeout. We won't miss any changes that
|
||||||
# came after our most recent apply_selector call, since we've
|
# came after our most recent apply_selector call, since we've
|
||||||
# held the lock until now.
|
# held the lock until now.
|
||||||
self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
|
_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
|
||||||
self._description.check_compatible()
|
self._description.check_compatible()
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
server_descriptions = self._description.apply_selector(
|
server_descriptions = self._description.apply_selector(
|
||||||
@ -652,7 +656,7 @@ class Topology:
|
|||||||
"""Wake all monitors, wait for at least one to check its server."""
|
"""Wake all monitors, wait for at least one to check its server."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._request_check_all()
|
self._request_check_all()
|
||||||
self._condition.wait(wait_time)
|
_cond_wait(self._condition, wait_time)
|
||||||
|
|
||||||
def data_bearing_servers(self) -> list[ServerDescription]:
|
def data_bearing_servers(self) -> list[ServerDescription]:
|
||||||
"""Return a list of all data-bearing servers.
|
"""Return a list of all data-bearing servers.
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
@ -25,6 +26,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from asyncio import iscoroutinefunction
|
from asyncio import iscoroutinefunction
|
||||||
@ -191,6 +193,8 @@ class ClientContext:
|
|||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
def _init_client(self):
|
def _init_client(self):
|
||||||
|
self.mongoses = []
|
||||||
|
self.connection_attempts = []
|
||||||
self.client = self._connect(host, port)
|
self.client = self._connect(host, port)
|
||||||
if self.client is not None:
|
if self.client is not None:
|
||||||
# Return early when connected to dataLake as mongohoused does not
|
# Return early when connected to dataLake as mongohoused does not
|
||||||
@ -860,6 +864,16 @@ class ClientContext:
|
|||||||
client_context = ClientContext()
|
client_context = ClientContext()
|
||||||
|
|
||||||
|
|
||||||
|
def reset_client_context():
|
||||||
|
if _IS_SYNC:
|
||||||
|
# sync tests don't need to reset a client context
|
||||||
|
return
|
||||||
|
elif client_context.client is not None:
|
||||||
|
client_context.client.close()
|
||||||
|
client_context.client = None
|
||||||
|
client_context._init_client()
|
||||||
|
|
||||||
|
|
||||||
class PyMongoTestCase(unittest.TestCase):
|
class PyMongoTestCase(unittest.TestCase):
|
||||||
def assertEqualCommand(self, expected, actual, msg=None):
|
def assertEqualCommand(self, expected, actual, msg=None):
|
||||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||||
@ -1106,26 +1120,10 @@ class PyMongoTestCase(unittest.TestCase):
|
|||||||
class UnitTest(PyMongoTestCase):
|
class UnitTest(PyMongoTestCase):
|
||||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||||
|
|
||||||
@classmethod
|
def setUp(self) -> None:
|
||||||
def setUpClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._setup_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._setup_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._tearDown_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._tearDown_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _setup_class(cls):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self) -> None:
|
||||||
def _tearDown_class(cls):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase):
|
|||||||
db: Database
|
db: Database
|
||||||
credentials: Dict[str, str]
|
credentials: Dict[str, str]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._setup_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._setup_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._tearDown_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._tearDown_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_connection
|
@client_context.require_connection
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
if not _IS_SYNC:
|
||||||
|
reset_client_context()
|
||||||
|
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||||
raise SkipTest("this test does not support load balancers")
|
raise SkipTest("this test does not support load balancers")
|
||||||
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||||
raise SkipTest("this test does not support serverless")
|
raise SkipTest("this test does not support serverless")
|
||||||
cls.client = client_context.client
|
self.client = client_context.client
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
if client_context.auth_enabled:
|
if client_context.auth_enabled:
|
||||||
cls.credentials = {"username": db_user, "password": db_pwd}
|
self.credentials = {"username": db_user, "password": db_pwd}
|
||||||
else:
|
else:
|
||||||
cls.credentials = {}
|
self.credentials = {}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def cleanup_colls(self, *collections):
|
def cleanup_colls(self, *collections):
|
||||||
"""Cleanup collections faster than drop_collection."""
|
"""Cleanup collections faster than drop_collection."""
|
||||||
@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest):
|
|||||||
# MockClients tests that use replicaSet, directConnection=True, pass
|
# MockClients tests that use replicaSet, directConnection=True, pass
|
||||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||||
# with loadBalanced=True.
|
# with loadBalanced=True.
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._setup_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._setup_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._tearDown_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._tearDown_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_no_load_balancer
|
@client_context.require_no_load_balancer
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
||||||
|
|
||||||
self.client_knobs.enable()
|
self.client_knobs.enable()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
self.client_knobs.disable()
|
self.client_knobs.disable()
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
||||||
@ -1253,7 +1211,6 @@ def teardown():
|
|||||||
c.drop_database("pymongo_test_mike")
|
c.drop_database("pymongo_test_mike")
|
||||||
c.drop_database("pymongo_test_bernie")
|
c.drop_database("pymongo_test_bernie")
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
print_running_clients()
|
print_running_clients()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
@ -25,6 +26,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from asyncio import iscoroutinefunction
|
from asyncio import iscoroutinefunction
|
||||||
@ -191,6 +193,8 @@ class AsyncClientContext:
|
|||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
async def _init_client(self):
|
async def _init_client(self):
|
||||||
|
self.mongoses = []
|
||||||
|
self.connection_attempts = []
|
||||||
self.client = await self._connect(host, port)
|
self.client = await self._connect(host, port)
|
||||||
if self.client is not None:
|
if self.client is not None:
|
||||||
# Return early when connected to dataLake as mongohoused does not
|
# Return early when connected to dataLake as mongohoused does not
|
||||||
@ -862,6 +866,16 @@ class AsyncClientContext:
|
|||||||
async_client_context = AsyncClientContext()
|
async_client_context = AsyncClientContext()
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_client_context():
|
||||||
|
if _IS_SYNC:
|
||||||
|
# sync tests don't need to reset a client context
|
||||||
|
return
|
||||||
|
elif async_client_context.client is not None:
|
||||||
|
await async_client_context.client.close()
|
||||||
|
async_client_context.client = None
|
||||||
|
await async_client_context._init_client()
|
||||||
|
|
||||||
|
|
||||||
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||||
def assertEqualCommand(self, expected, actual, msg=None):
|
def assertEqualCommand(self, expected, actual, msg=None):
|
||||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||||
@ -1124,26 +1138,10 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
|||||||
class AsyncUnitTest(AsyncPyMongoTestCase):
|
class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self) -> None:
|
||||||
def setUpClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._setup_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._setup_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._tearDown_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._tearDown_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _setup_class(cls):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
async def asyncTearDown(self) -> None:
|
||||||
async def _tearDown_class(cls):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
|||||||
db: AsyncDatabase
|
db: AsyncDatabase
|
||||||
credentials: Dict[str, str]
|
credentials: Dict[str, str]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._setup_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._setup_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._tearDown_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._tearDown_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_connection
|
@async_client_context.require_connection
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
|
if not _IS_SYNC:
|
||||||
|
await reset_client_context()
|
||||||
|
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||||
raise SkipTest("this test does not support load balancers")
|
raise SkipTest("this test does not support load balancers")
|
||||||
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
|
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||||
raise SkipTest("this test does not support serverless")
|
raise SkipTest("this test does not support serverless")
|
||||||
cls.client = async_client_context.client
|
self.client = async_client_context.client
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
if async_client_context.auth_enabled:
|
if async_client_context.auth_enabled:
|
||||||
cls.credentials = {"username": db_user, "password": db_pwd}
|
self.credentials = {"username": db_user, "password": db_pwd}
|
||||||
else:
|
else:
|
||||||
cls.credentials = {}
|
self.credentials = {}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def cleanup_colls(self, *collections):
|
async def cleanup_colls(self, *collections):
|
||||||
"""Cleanup collections faster than drop_collection."""
|
"""Cleanup collections faster than drop_collection."""
|
||||||
@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest):
|
|||||||
# MockClients tests that use replicaSet, directConnection=True, pass
|
# MockClients tests that use replicaSet, directConnection=True, pass
|
||||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||||
# with loadBalanced=True.
|
# with loadBalanced=True.
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._setup_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._setup_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls._tearDown_class()
|
|
||||||
else:
|
|
||||||
asyncio.run(cls._tearDown_class())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_no_load_balancer
|
@async_client_context.require_no_load_balancer
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
pass
|
await super().asyncSetUp()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
|
||||||
|
|
||||||
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
|
||||||
|
|
||||||
self.client_knobs.enable()
|
self.client_knobs.enable()
|
||||||
|
|
||||||
def tearDown(self):
|
async def asyncTearDown(self) -> None:
|
||||||
self.client_knobs.disable()
|
self.client_knobs.disable()
|
||||||
super().tearDown()
|
await super().asyncTearDown()
|
||||||
|
|
||||||
|
|
||||||
async def async_setup():
|
async def async_setup():
|
||||||
@ -1271,7 +1229,6 @@ async def async_teardown():
|
|||||||
await c.drop_database("pymongo_test_mike")
|
await c.drop_database("pymongo_test_mike")
|
||||||
await c.drop_database("pymongo_test_bernie")
|
await c.drop_database("pymongo_test_bernie")
|
||||||
await c.close()
|
await c.close()
|
||||||
|
|
||||||
print_running_clients()
|
print_running_clients()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def event_loop_policy():
|
|||||||
return asyncio.get_event_loop_policy()
|
return asyncio.get_event_loop_policy()
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
@pytest_asyncio.fixture(scope="package", autouse=True)
|
||||||
async def test_setup_and_teardown():
|
async def test_setup_and_teardown():
|
||||||
await async_setup()
|
await async_setup()
|
||||||
yield
|
yield
|
||||||
|
|||||||
@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
|
|||||||
coll: AsyncCollection
|
coll: AsyncCollection
|
||||||
coll_w0: AsyncCollection
|
coll_w0: AsyncCollection
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _setup_class(cls):
|
|
||||||
await super()._setup_class()
|
|
||||||
cls.coll = cls.db.test
|
|
||||||
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
super().setUp()
|
await super().asyncSetUp()
|
||||||
|
self.coll = self.db.test
|
||||||
await self.coll.drop()
|
await self.coll.drop()
|
||||||
|
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
|
||||||
|
|
||||||
def assertEqualResponse(self, expected, actual):
|
def assertEqualResponse(self, expected, actual):
|
||||||
"""Compare response from bulk.execute() to expected response."""
|
"""Compare response from bulk.execute() to expected response."""
|
||||||
@ -787,14 +783,10 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
|
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_auth
|
@async_client_context.require_auth
|
||||||
@async_client_context.require_no_api_version
|
@async_client_context.require_no_api_version
|
||||||
async def _setup_class(cls):
|
|
||||||
await super()._setup_class()
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
super().setUp()
|
await super().asyncSetUp()
|
||||||
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
||||||
await self.db.command(
|
await self.db.command(
|
||||||
"createRole",
|
"createRole",
|
||||||
@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
|
|||||||
w: Optional[int]
|
w: Optional[int]
|
||||||
secondary: AsyncMongoClient
|
secondary: AsyncMongoClient
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self):
|
||||||
async def _setup_class(cls):
|
await super().asyncSetUp()
|
||||||
await super()._setup_class()
|
self.w = async_client_context.w
|
||||||
cls.w = async_client_context.w
|
self.secondary = None
|
||||||
cls.secondary = None
|
if self.w is not None and self.w > 1:
|
||||||
if cls.w is not None and cls.w > 1:
|
|
||||||
for member in (await async_client_context.hello)["hosts"]:
|
for member in (await async_client_context.hello)["hosts"]:
|
||||||
if member != (await async_client_context.hello)["primary"]:
|
if member != (await async_client_context.hello)["primary"]:
|
||||||
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
|
self.secondary = await self.async_single_client(*partition_node(member))
|
||||||
break
|
break
|
||||||
|
|
||||||
@classmethod
|
async def asyncTearDown(self):
|
||||||
async def async_tearDownClass(cls):
|
if self.secondary:
|
||||||
if cls.secondary:
|
await self.secondary.close()
|
||||||
await cls.secondary.close()
|
|
||||||
|
|
||||||
async def cause_wtimeout(self, requests, ordered):
|
async def cause_wtimeout(self, requests, ordered):
|
||||||
if not async_client_context.test_commands_enabled:
|
if not async_client_context.test_commands_enabled:
|
||||||
|
|||||||
@ -836,18 +836,16 @@ class ProseSpecTestsMixin:
|
|||||||
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||||
dbs: list
|
dbs: list
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_version_min(4, 0, 0, -1)
|
@async_client_context.require_version_min(4, 0, 0, -1)
|
||||||
@async_client_context.require_change_streams
|
@async_client_context.require_change_streams
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
cls.dbs = [cls.db, cls.client.pymongo_test_2]
|
self.dbs = [self.db, self.client.pymongo_test_2]
|
||||||
|
|
||||||
@classmethod
|
async def asyncTearDown(self):
|
||||||
async def _tearDown_class(cls):
|
for db in self.dbs:
|
||||||
for db in cls.dbs:
|
await self.client.drop_database(db)
|
||||||
await cls.client.drop_database(db)
|
await super().asyncTearDown()
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def change_stream_with_client(self, client, *args, **kwargs):
|
async def change_stream_with_client(self, client, *args, **kwargs):
|
||||||
return await client.watch(*args, **kwargs)
|
return await client.watch(*args, **kwargs)
|
||||||
@ -898,11 +896,10 @@ class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
|||||||
|
|
||||||
|
|
||||||
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_version_min(4, 0, 0, -1)
|
@async_client_context.require_version_min(4, 0, 0, -1)
|
||||||
@async_client_context.require_change_streams
|
@async_client_context.require_change_streams
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
|
|
||||||
async def change_stream_with_client(self, client, *args, **kwargs):
|
async def change_stream_with_client(self, client, *args, **kwargs):
|
||||||
return await client[self.db.name].watch(*args, **kwargs)
|
return await client[self.db.name].watch(*args, **kwargs)
|
||||||
@ -988,12 +985,9 @@ class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixi
|
|||||||
class TestAsyncCollectionAsyncChangeStream(
|
class TestAsyncCollectionAsyncChangeStream(
|
||||||
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
|
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
|
||||||
):
|
):
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_change_streams
|
@async_client_context.require_change_streams
|
||||||
async def _setup_class(cls):
|
|
||||||
await super()._setup_class()
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
# Use a new collection for each test.
|
# Use a new collection for each test.
|
||||||
await self.watched_collection().drop()
|
await self.watched_collection().drop()
|
||||||
await self.watched_collection().insert_one({})
|
await self.watched_collection().insert_one({})
|
||||||
@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
|
|||||||
RUN_ON_LOAD_BALANCER = True
|
RUN_ON_LOAD_BALANCER = True
|
||||||
listener: AllowListEventListener
|
listener: AllowListEventListener
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_connection
|
@async_client_context.require_connection
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
self.listener = AllowListEventListener("aggregate", "getMore")
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
await cls.client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
def asyncSetUp(self):
|
|
||||||
super().asyncSetUp()
|
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
|
|
||||||
async def asyncSetUpCluster(self, scenario_dict):
|
async def asyncSetUpCluster(self, scenario_dict):
|
||||||
|
|||||||
@ -73,7 +73,6 @@ from test.utils import (
|
|||||||
is_greenthread_patched,
|
is_greenthread_patched,
|
||||||
lazy_client_trial,
|
lazy_client_trial,
|
||||||
one,
|
one,
|
||||||
wait_until,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import bson
|
import bson
|
||||||
@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
|||||||
|
|
||||||
client: AsyncMongoClient
|
client: AsyncMongoClient
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self) -> None:
|
||||||
async def _setup_class(cls):
|
self.client = await self.async_rs_or_single_client(
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
|
||||||
connect=False, serverSelectionTimeoutMS=100
|
connect=False, serverSelectionTimeoutMS=100
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
await cls.client.close()
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_fixtures(self, caplog):
|
def inject_fixtures(self, caplog):
|
||||||
self._caplog = caplog
|
self._caplog = caplog
|
||||||
@ -693,8 +687,8 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
# When the reaper runs at the same time as the get_socket, two
|
# When the reaper runs at the same time as the get_socket, two
|
||||||
# connections could be created and checked into the pool.
|
# connections could be created and checked into the pool.
|
||||||
self.assertGreaterEqual(len(server._pool.conns), 1)
|
self.assertGreaterEqual(len(server._pool.conns), 1)
|
||||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||||
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
|
await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
|
||||||
|
|
||||||
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
|
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
|
||||||
with client_knobs(kill_cursor_frequency=0.1):
|
with client_knobs(kill_cursor_frequency=0.1):
|
||||||
@ -710,8 +704,8 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
# When the reaper runs at the same time as the get_socket,
|
# When the reaper runs at the same time as the get_socket,
|
||||||
# maxPoolSize=1 should prevent two connections from being created.
|
# maxPoolSize=1 should prevent two connections from being created.
|
||||||
self.assertEqual(1, len(server._pool.conns))
|
self.assertEqual(1, len(server._pool.conns))
|
||||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||||
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
|
await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
|
||||||
|
|
||||||
async def test_max_idle_time_reaper_removes_stale(self):
|
async def test_max_idle_time_reaper_removes_stale(self):
|
||||||
with client_knobs(kill_cursor_frequency=0.1):
|
with client_knobs(kill_cursor_frequency=0.1):
|
||||||
@ -727,7 +721,7 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
async with server._pool.checkout() as conn_two:
|
async with server._pool.checkout() as conn_two:
|
||||||
pass
|
pass
|
||||||
self.assertIs(conn_one, conn_two)
|
self.assertIs(conn_one, conn_two)
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: len(server._pool.conns) == 0,
|
lambda: len(server._pool.conns) == 0,
|
||||||
"stale socket reaped and new one NOT added to the pool",
|
"stale socket reaped and new one NOT added to the pool",
|
||||||
)
|
)
|
||||||
@ -745,7 +739,7 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
server = await (await client._get_topology()).select_server(
|
server = await (await client._get_topology()).select_server(
|
||||||
readable_server_selector, _Op.TEST
|
readable_server_selector, _Op.TEST
|
||||||
)
|
)
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: len(server._pool.conns) == 10,
|
lambda: len(server._pool.conns) == 10,
|
||||||
"pool initialized with 10 connections",
|
"pool initialized with 10 connections",
|
||||||
)
|
)
|
||||||
@ -753,7 +747,7 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
# Assert that if a socket is closed, a new one takes its place
|
# Assert that if a socket is closed, a new one takes its place
|
||||||
async with server._pool.checkout() as conn:
|
async with server._pool.checkout() as conn:
|
||||||
conn.close_conn(None)
|
conn.close_conn(None)
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: len(server._pool.conns) == 10,
|
lambda: len(server._pool.conns) == 10,
|
||||||
"a closed socket gets replaced from the pool",
|
"a closed socket gets replaced from the pool",
|
||||||
)
|
)
|
||||||
@ -939,8 +933,10 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
async with eval(the_repr) as client_two:
|
async with eval(the_repr) as client_two:
|
||||||
self.assertEqual(client_two, client)
|
self.assertEqual(client_two, client)
|
||||||
|
|
||||||
def test_getters(self):
|
async def test_getters(self):
|
||||||
wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes")
|
await async_wait_until(
|
||||||
|
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
|
||||||
|
)
|
||||||
|
|
||||||
async def test_list_databases(self):
|
async def test_list_databases(self):
|
||||||
cmd_docs = (await self.client.admin.command("listDatabases"))["databases"]
|
cmd_docs = (await self.client.admin.command("listDatabases"))["databases"]
|
||||||
@ -1065,14 +1061,21 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
self.assertFalse(client._topology._opened)
|
self.assertFalse(client._topology._opened)
|
||||||
|
|
||||||
# Ensure kill cursors thread has not been started.
|
# Ensure kill cursors thread has not been started.
|
||||||
kc_thread = client._kill_cursors_executor._thread
|
if _IS_SYNC:
|
||||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
kc_thread = client._kill_cursors_executor._thread
|
||||||
|
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||||
|
else:
|
||||||
|
kc_task = client._kill_cursors_executor._task
|
||||||
|
self.assertFalse(kc_task and not kc_task.done())
|
||||||
# Using the client should open topology and start the thread.
|
# Using the client should open topology and start the thread.
|
||||||
await client.admin.command("ping")
|
await client.admin.command("ping")
|
||||||
self.assertTrue(client._topology._opened)
|
self.assertTrue(client._topology._opened)
|
||||||
kc_thread = client._kill_cursors_executor._thread
|
if _IS_SYNC:
|
||||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
kc_thread = client._kill_cursors_executor._thread
|
||||||
|
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||||
|
else:
|
||||||
|
kc_task = client._kill_cursors_executor._task
|
||||||
|
self.assertTrue(kc_task and not kc_task.done())
|
||||||
|
|
||||||
async def test_close_does_not_open_servers(self):
|
async def test_close_does_not_open_servers(self):
|
||||||
client = await self.async_rs_client(connect=False)
|
client = await self.async_rs_client(connect=False)
|
||||||
@ -1277,6 +1280,7 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
async def test_server_selection_timeout(self):
|
async def test_server_selection_timeout(self):
|
||||||
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
|
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||||
|
await client.close()
|
||||||
|
|
||||||
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
|
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
|
||||||
|
|
||||||
@ -1289,18 +1293,22 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
|
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
|
||||||
)
|
)
|
||||||
|
await client.close()
|
||||||
|
|
||||||
client = AsyncMongoClient(
|
client = AsyncMongoClient(
|
||||||
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
|
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
|
||||||
)
|
)
|
||||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||||
|
await client.close()
|
||||||
|
|
||||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
||||||
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
||||||
|
await client.close()
|
||||||
|
|
||||||
# Test invalid timeout in URI ignored and set to default.
|
# Test invalid timeout in URI ignored and set to default.
|
||||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
||||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||||
|
await client.close()
|
||||||
|
|
||||||
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
||||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||||
@ -1608,7 +1616,7 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
await async_client_context.port,
|
await async_client_context.port,
|
||||||
)
|
)
|
||||||
await self.async_single_client(uri, event_listeners=[listener])
|
await self.async_single_client(uri, event_listeners=[listener])
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents"
|
lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1766,16 +1774,16 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
pool = await async_get_pool(client)
|
pool = await async_get_pool(client)
|
||||||
original_connect = pool.connect
|
original_connect = pool.connect
|
||||||
|
|
||||||
def stall_connect(*args, **kwargs):
|
async def stall_connect(*args, **kwargs):
|
||||||
time.sleep(2)
|
await asyncio.sleep(2)
|
||||||
return original_connect(*args, **kwargs)
|
return await original_connect(*args, **kwargs)
|
||||||
|
|
||||||
pool.connect = stall_connect
|
pool.connect = stall_connect
|
||||||
# Un-patch Pool.connect to break the cyclic reference.
|
# Un-patch Pool.connect to break the cyclic reference.
|
||||||
self.addCleanup(delattr, pool, "connect")
|
self.addCleanup(delattr, pool, "connect")
|
||||||
|
|
||||||
# Wait for the background thread to start creating connections
|
# Wait for the background thread to start creating connections
|
||||||
wait_until(lambda: len(pool.conns) > 1, "start creating connections")
|
await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections")
|
||||||
|
|
||||||
# Assert that application operations do not block.
|
# Assert that application operations do not block.
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
@ -1858,7 +1866,7 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
await client.close()
|
await client.close()
|
||||||
# Add cursor to kill cursors queue
|
# Add cursor to kill cursors queue
|
||||||
del cursor
|
del cursor
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: client._kill_cursors_queue,
|
lambda: client._kill_cursors_queue,
|
||||||
"waited for cursor to be added to queue",
|
"waited for cursor to be added to queue",
|
||||||
)
|
)
|
||||||
@ -2232,7 +2240,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
|
|||||||
await cursor.to_list()
|
await cursor.to_list()
|
||||||
self.assertTrue(conn.closed)
|
self.assertTrue(conn.closed)
|
||||||
|
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: len(client._kill_cursors_queue) == 0,
|
lambda: len(client._kill_cursors_queue) == 0,
|
||||||
"waited for all killCursor requests to complete",
|
"waited for all killCursor requests to complete",
|
||||||
)
|
)
|
||||||
@ -2403,7 +2411,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
|||||||
)
|
)
|
||||||
self.addAsyncCleanup(c.close)
|
self.addAsyncCleanup(c.close)
|
||||||
|
|
||||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||||
|
|
||||||
self.assertEqual(await c.address, ("a", 1))
|
self.assertEqual(await c.address, ("a", 1))
|
||||||
# Fail over.
|
# Fail over.
|
||||||
@ -2430,7 +2438,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
|||||||
)
|
)
|
||||||
self.addAsyncCleanup(c.close)
|
self.addAsyncCleanup(c.close)
|
||||||
|
|
||||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||||
|
|
||||||
# Total failure.
|
# Total failure.
|
||||||
c.kill_host("a:1")
|
c.kill_host("a:1")
|
||||||
@ -2472,7 +2480,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
|||||||
c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION)
|
c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION)
|
||||||
c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1)
|
c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1)
|
||||||
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||||
wait_until(lambda: len(c.nodes) == 2, "connect")
|
await async_wait_until(lambda: len(c.nodes) == 2, "connect")
|
||||||
|
|
||||||
c.kill_host("a:1")
|
c.kill_host("a:1")
|
||||||
|
|
||||||
@ -2544,11 +2552,11 @@ class TestClientPool(AsyncMockClientTest):
|
|||||||
)
|
)
|
||||||
self.addAsyncCleanup(c.close)
|
self.addAsyncCleanup(c.close)
|
||||||
|
|
||||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
await async_wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||||
self.assertEqual(await c.address, ("a", 1))
|
self.assertEqual(await c.address, ("a", 1))
|
||||||
self.assertEqual(await c.arbiters, {("c", 3)})
|
self.assertEqual(await c.arbiters, {("c", 3)})
|
||||||
# Assert that we create 2 and only 2 pooled connections.
|
# Assert that we create 2 and only 2 pooled connections.
|
||||||
listener.wait_for_event(monitoring.ConnectionReadyEvent, 2)
|
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2)
|
||||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
|
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2)
|
||||||
# Assert that we do not create connections to arbiters.
|
# Assert that we do not create connections to arbiters.
|
||||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||||
@ -2574,10 +2582,10 @@ class TestClientPool(AsyncMockClientTest):
|
|||||||
)
|
)
|
||||||
self.addAsyncCleanup(c.close)
|
self.addAsyncCleanup(c.close)
|
||||||
|
|
||||||
wait_until(lambda: len(c.nodes) == 1, "connect")
|
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
|
||||||
self.assertEqual(await c.address, ("c", 3))
|
self.assertEqual(await c.address, ("c", 3))
|
||||||
# Assert that we create 1 pooled connection.
|
# Assert that we create 1 pooled connection.
|
||||||
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
|
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
|
||||||
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
|
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
|
||||||
arbiter = c._topology.get_server_by_address(("c", 3))
|
arbiter = c._topology.get_server_by_address(("c", 3))
|
||||||
self.assertEqual(len(arbiter.pool.conns), 1)
|
self.assertEqual(len(arbiter.pool.conns), 1)
|
||||||
|
|||||||
@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest):
|
|||||||
warn_context: Any
|
warn_context: Any
|
||||||
collation: Collation
|
collation: Collation
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_connection
|
@async_client_context.require_connection
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
cls.collation = Collation("en_US")
|
self.collation = Collation("en_US")
|
||||||
cls.warn_context = warnings.catch_warnings()
|
self.warn_context = warnings.catch_warnings()
|
||||||
cls.warn_context.__enter__()
|
self.warn_context.__enter__()
|
||||||
warnings.simplefilter("ignore", DeprecationWarning)
|
|
||||||
|
|
||||||
@classmethod
|
async def asyncTearDown(self) -> None:
|
||||||
async def _tearDown_class(cls):
|
self.warn_context.__exit__()
|
||||||
cls.warn_context.__exit__()
|
self.warn_context = None
|
||||||
cls.warn_context = None
|
|
||||||
await cls.client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
super().tearDown()
|
await super().asyncTearDown()
|
||||||
|
|
||||||
def last_command_started(self):
|
def last_command_started(self):
|
||||||
return self.listener.started_events[-1].command
|
return self.listener.started_events[-1].command
|
||||||
|
|||||||
@ -40,7 +40,6 @@ from test.utils import (
|
|||||||
async_get_pool,
|
async_get_pool,
|
||||||
async_is_mongos,
|
async_is_mongos,
|
||||||
async_wait_until,
|
async_wait_until,
|
||||||
wait_until,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from bson import encode
|
from bson import encode
|
||||||
@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest):
|
|||||||
db: AsyncDatabase
|
db: AsyncDatabase
|
||||||
client: AsyncMongoClient
|
client: AsyncMongoClient
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self) -> None:
|
||||||
async def _setup_class(cls):
|
await super().asyncSetUp()
|
||||||
cls.client = AsyncMongoClient(connect=False)
|
self.client = self.simple_client(connect=False)
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
await cls.client.close()
|
|
||||||
|
|
||||||
def test_collection(self):
|
def test_collection(self):
|
||||||
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
|
self.assertRaises(TypeError, AsyncCollection, self.db, 5)
|
||||||
@ -165,27 +160,14 @@ class TestCollectionNoConnect(AsyncUnitTest):
|
|||||||
class AsyncTestCollection(AsyncIntegrationTest):
|
class AsyncTestCollection(AsyncIntegrationTest):
|
||||||
w: int
|
w: int
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
super().setUpClass()
|
|
||||||
cls.w = async_client_context.w # type: ignore
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
|
|
||||||
else:
|
|
||||||
asyncio.run(cls.async_tearDownClass())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def async_tearDownClass(cls):
|
|
||||||
await cls.db.drop_collection("test_large_limit")
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
await self.db.test.drop()
|
await super().asyncSetUp()
|
||||||
|
self.w = async_client_context.w # type: ignore
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
async def asyncTearDown(self):
|
||||||
await self.db.test.drop()
|
await self.db.test.drop()
|
||||||
|
await self.db.drop_collection("test_large_limit")
|
||||||
|
await super().asyncTearDown()
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def write_concern_collection(self):
|
def write_concern_collection(self):
|
||||||
@ -1023,7 +1005,10 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
|||||||
await db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
await db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
||||||
await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
||||||
|
|
||||||
await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
|
async def predicate():
|
||||||
|
return await db_w0.test.find_one({"x": 1})
|
||||||
|
|
||||||
|
await async_wait_until(predicate, "find w:0 replaced document")
|
||||||
|
|
||||||
async def test_update_bypass_document_validation(self):
|
async def test_update_bypass_document_validation(self):
|
||||||
db = self.db
|
db = self.db
|
||||||
@ -1871,7 +1856,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
|||||||
await cur.close()
|
await cur.close()
|
||||||
cur = None
|
cur = None
|
||||||
# Wait until the background thread returns the socket.
|
# Wait until the background thread returns the socket.
|
||||||
wait_until(lambda: pool.active_sockets == 0, "return socket")
|
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
|
||||||
# The socket should be discarded.
|
# The socket should be discarded.
|
||||||
self.assertEqual(0, len(pool.conns))
|
self.assertEqual(0, len(pool.conns))
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,12 @@ import sys
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
from test.asynchronous import (
|
||||||
|
AsyncIntegrationTest,
|
||||||
|
async_client_context,
|
||||||
|
reset_client_context,
|
||||||
|
unittest,
|
||||||
|
)
|
||||||
from test.asynchronous.helpers import async_repl_set_step_down
|
from test.asynchronous.helpers import async_repl_set_step_down
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
CMAPListener,
|
CMAPListener,
|
||||||
@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
|
|||||||
listener: CMAPListener
|
listener: CMAPListener
|
||||||
coll: AsyncCollection
|
coll: AsyncCollection
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_replica_set
|
@async_client_context.require_replica_set
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
await super()._setup_class()
|
self.listener = CMAPListener()
|
||||||
cls.listener = CMAPListener()
|
self.client = await self.async_rs_or_single_client(
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||||
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure connections to all servers in replica set. This is to test
|
# Ensure connections to all servers in replica set. This is to test
|
||||||
# that the is_writable flag is properly updated for connections that
|
# that the is_writable flag is properly updated for connections that
|
||||||
# survive a replica set election.
|
# survive a replica set election.
|
||||||
await async_ensure_all_connected(cls.client)
|
await async_ensure_all_connected(self.client)
|
||||||
cls.listener.reset()
|
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||||
|
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||||
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
|
||||||
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
await cls.client.close()
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
# Note that all ops use same write-concern as self.db (majority).
|
# Note that all ops use same write-concern as self.db (majority).
|
||||||
await self.db.drop_collection("step-down")
|
await self.db.drop_collection("step-down")
|
||||||
await self.db.create_collection("step-down")
|
await self.db.create_collection("step-down")
|
||||||
|
|||||||
@ -56,6 +56,9 @@ class TestCreateEntities(AsyncIntegrationTest):
|
|||||||
self.assertGreater(len(final_entity_map["events1"]), 0)
|
self.assertGreater(len(final_entity_map["events1"]), 0)
|
||||||
for event in final_entity_map["events1"]:
|
for event in final_entity_map["events1"]:
|
||||||
self.assertIn("PoolCreatedEvent", event["name"])
|
self.assertIn("PoolCreatedEvent", event["name"])
|
||||||
|
if self.scenario_runner.mongos_clients:
|
||||||
|
for client in self.scenario_runner.mongos_clients:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
async def test_store_all_others_as_entities(self):
|
async def test_store_all_others_as_entities(self):
|
||||||
self.scenario_runner = UnifiedSpecTestMixinV1()
|
self.scenario_runner = UnifiedSpecTestMixinV1()
|
||||||
@ -122,6 +125,9 @@ class TestCreateEntities(AsyncIntegrationTest):
|
|||||||
self.assertEqual(entity_map["failures"], [])
|
self.assertEqual(entity_map["failures"], [])
|
||||||
self.assertEqual(entity_map["successes"], 2)
|
self.assertEqual(entity_map["successes"], 2)
|
||||||
self.assertEqual(entity_map["iterations"], 5)
|
self.assertEqual(entity_map["iterations"], 5)
|
||||||
|
if self.scenario_runner.mongos_clients:
|
||||||
|
for client in self.scenario_runner.mongos_clients:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -34,9 +34,9 @@ from test.utils import (
|
|||||||
AllowListEventListener,
|
AllowListEventListener,
|
||||||
EventListener,
|
EventListener,
|
||||||
OvertCommandListener,
|
OvertCommandListener,
|
||||||
|
async_wait_until,
|
||||||
delay,
|
delay,
|
||||||
ignore_deprecations,
|
ignore_deprecations,
|
||||||
wait_until,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from bson import decode_all
|
from bson import decode_all
|
||||||
@ -1324,8 +1324,8 @@ class TestCursor(AsyncIntegrationTest):
|
|||||||
with self.assertRaises(ExecutionTimeout):
|
with self.assertRaises(ExecutionTimeout):
|
||||||
await cursor.next()
|
await cursor.next()
|
||||||
|
|
||||||
def assertCursorKilled():
|
async def assertCursorKilled():
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: len(listener.succeeded_events),
|
lambda: len(listener.succeeded_events),
|
||||||
"find successful killCursors command",
|
"find successful killCursors command",
|
||||||
)
|
)
|
||||||
@ -1335,7 +1335,7 @@ class TestCursor(AsyncIntegrationTest):
|
|||||||
self.assertEqual(1, len(listener.succeeded_events))
|
self.assertEqual(1, len(listener.succeeded_events))
|
||||||
self.assertEqual("killCursors", listener.succeeded_events[0].command_name)
|
self.assertEqual("killCursors", listener.succeeded_events[0].command_name)
|
||||||
|
|
||||||
assertCursorKilled()
|
await assertCursorKilled()
|
||||||
listener.reset()
|
listener.reset()
|
||||||
|
|
||||||
cursor = await coll.aggregate([], batchSize=1)
|
cursor = await coll.aggregate([], batchSize=1)
|
||||||
@ -1345,7 +1345,7 @@ class TestCursor(AsyncIntegrationTest):
|
|||||||
with self.assertRaises(ExecutionTimeout):
|
with self.assertRaises(ExecutionTimeout):
|
||||||
await cursor.next()
|
await cursor.next()
|
||||||
|
|
||||||
assertCursorKilled()
|
await assertCursorKilled()
|
||||||
|
|
||||||
def test_delete_not_initialized(self):
|
def test_delete_not_initialized(self):
|
||||||
# Creating a cursor with invalid arguments will not run __init__
|
# Creating a cursor with invalid arguments will not run __init__
|
||||||
@ -1647,10 +1647,6 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||||
@classmethod
|
|
||||||
async def _setup_class(cls):
|
|
||||||
await super()._setup_class()
|
|
||||||
|
|
||||||
async def test_aggregate_raw(self):
|
async def test_aggregate_raw(self):
|
||||||
c = self.db.test
|
c = self.db.test
|
||||||
await c.drop()
|
await c.drop()
|
||||||
|
|||||||
@ -717,7 +717,8 @@ class TestDatabase(AsyncIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestDatabaseAggregation(AsyncIntegrationTest):
|
class TestDatabaseAggregation(AsyncIntegrationTest):
|
||||||
def setUp(self):
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
self.pipeline: List[Mapping[str, Any]] = [
|
self.pipeline: List[Mapping[str, Any]] = [
|
||||||
{"$listLocalSessions": {}},
|
{"$listLocalSessions": {}},
|
||||||
{"$limit": 1},
|
{"$limit": 1},
|
||||||
|
|||||||
@ -211,11 +211,10 @@ class TestClientOptions(AsyncPyMongoTestCase):
|
|||||||
class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
|
class AsyncEncryptionIntegrationTest(AsyncIntegrationTest):
|
||||||
"""Base class for encryption integration tests."""
|
"""Base class for encryption integration tests."""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||||
@async_client_context.require_version_min(4, 2, -1)
|
@async_client_context.require_version_min(4, 2, -1)
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
|
|
||||||
def assertEncrypted(self, val):
|
def assertEncrypted(self, val):
|
||||||
self.assertIsInstance(val, Binary)
|
self.assertIsInstance(val, Binary)
|
||||||
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestClientMaxWireVersion(AsyncIntegrationTest):
|
class TestClientMaxWireVersion(AsyncIntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
|
|
||||||
@async_client_context.require_version_max(4, 0, 99)
|
@async_client_context.require_version_max(4, 0, 99)
|
||||||
async def test_raise_max_wire_version_error(self):
|
async def test_raise_max_wire_version_error(self):
|
||||||
@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
|||||||
"local": None,
|
"local": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||||
"No environment credentials are set",
|
"No environment credentials are set",
|
||||||
)
|
)
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||||
await cls.client.db.coll.drop()
|
await self.client.db.coll.drop()
|
||||||
cls.vault = await create_key_vault(cls.client.keyvault.datakeys)
|
self.vault = await create_key_vault(self.client.keyvault.datakeys)
|
||||||
|
|
||||||
# Configure the encrypted field via the local schema_map option.
|
# Configure the encrypted field via the local schema_map option.
|
||||||
schemas = {
|
schemas = {
|
||||||
@ -846,25 +843,22 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
opts = AutoEncryptionOpts(
|
opts = AutoEncryptionOpts(
|
||||||
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
|
self.KMS_PROVIDERS,
|
||||||
|
"keyvault.datakeys",
|
||||||
|
schema_map=schemas,
|
||||||
|
kms_tls_options=KMS_TLS_OPTS,
|
||||||
)
|
)
|
||||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
self.client_encrypted = await self.async_rs_or_single_client(
|
||||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||||
)
|
)
|
||||||
cls.client_encryption = cls.unmanaged_create_client_encryption(
|
self.client_encryption = self.create_client_encryption(
|
||||||
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
await cls.vault.drop()
|
|
||||||
await cls.client.close()
|
|
||||||
await cls.client_encrypted.close()
|
|
||||||
await cls.client_encryption.close()
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
|
|
||||||
|
async def asyncTearDown(self) -> None:
|
||||||
|
await self.vault.drop()
|
||||||
|
|
||||||
async def run_test(self, provider_name):
|
async def run_test(self, provider_name):
|
||||||
# Create data key.
|
# Create data key.
|
||||||
master_key: Any = self.MASTER_KEYS[provider_name]
|
master_key: Any = self.MASTER_KEYS[provider_name]
|
||||||
@ -1011,10 +1005,9 @@ class TestViews(AsyncEncryptionIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestCorpus(AsyncEncryptionIntegrationTest):
|
class TestCorpus(AsyncEncryptionIntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def kms_providers():
|
def kms_providers():
|
||||||
@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
|||||||
client_encrypted: AsyncMongoClient
|
client_encrypted: AsyncMongoClient
|
||||||
listener: OvertCommandListener
|
listener: OvertCommandListener
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self):
|
||||||
async def _setup_class(cls):
|
await super().asyncSetUp()
|
||||||
await super()._setup_class()
|
|
||||||
db = async_client_context.client.db
|
db = async_client_context.client.db
|
||||||
cls.coll = db.coll
|
self.coll = db.coll
|
||||||
await cls.coll.drop()
|
await self.coll.drop()
|
||||||
# Configure the encrypted 'db.coll' collection via jsonSchema.
|
# Configure the encrypted 'db.coll' collection via jsonSchema.
|
||||||
json_schema = json_data("limits", "limits-schema.json")
|
json_schema = json_data("limits", "limits-schema.json")
|
||||||
await db.create_collection(
|
await db.create_collection(
|
||||||
@ -1211,17 +1203,14 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
|||||||
await coll.insert_one(json_data("limits", "limits-key.json"))
|
await coll.insert_one(json_data("limits", "limits-key.json"))
|
||||||
|
|
||||||
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client(
|
self.client_encrypted = await self.async_rs_or_single_client(
|
||||||
auto_encryption_opts=opts, event_listeners=[cls.listener]
|
auto_encryption_opts=opts, event_listeners=[self.listener]
|
||||||
)
|
)
|
||||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
self.coll_encrypted = self.client_encrypted.db.coll
|
||||||
|
|
||||||
@classmethod
|
async def asyncTearDown(self) -> None:
|
||||||
async def _tearDown_class(cls):
|
await self.coll_encrypted.drop()
|
||||||
await cls.coll_encrypted.drop()
|
|
||||||
await cls.client_encrypted.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def test_01_insert_succeeds_under_2MiB(self):
|
async def test_01_insert_succeeds_under_2MiB(self):
|
||||||
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
||||||
@ -1245,7 +1234,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
|||||||
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
|
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
self.assertEqual(
|
||||||
|
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||||
|
)
|
||||||
|
|
||||||
async def test_04_bulk_batch_split(self):
|
async def test_04_bulk_batch_split(self):
|
||||||
limits_doc = json_data("limits", "limits-doc.json")
|
limits_doc = json_data("limits", "limits-doc.json")
|
||||||
@ -1255,7 +1246,9 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
|||||||
doc2.update(limits_doc)
|
doc2.update(limits_doc)
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
self.assertEqual(
|
||||||
|
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||||
|
)
|
||||||
|
|
||||||
async def test_05_insert_succeeds_just_under_16MiB(self):
|
async def test_05_insert_succeeds_just_under_16MiB(self):
|
||||||
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
|
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
|
||||||
@ -1285,15 +1278,12 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest):
|
|||||||
class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
||||||
"""Prose tests for creating data keys with a custom endpoint."""
|
"""Prose tests for creating data keys with a custom endpoint."""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||||
"No environment credentials are set",
|
"No environment credentials are set",
|
||||||
)
|
)
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
kms_providers = {
|
kms_providers = {
|
||||||
"aws": AWS_CREDS,
|
"aws": AWS_CREDS,
|
||||||
"azure": AZURE_CREDS,
|
"azure": AZURE_CREDS,
|
||||||
@ -1322,10 +1312,6 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
|
|||||||
self._kmip_host_error = None
|
self._kmip_host_error = None
|
||||||
self._invalid_host_error = None
|
self._invalid_host_error = None
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
|
||||||
await self.client_encryption.close()
|
|
||||||
await self.client_encryption_invalid.close()
|
|
||||||
|
|
||||||
async def run_test_expected_success(self, provider_name, master_key):
|
async def run_test_expected_success(self, provider_name, master_key):
|
||||||
data_key_id = await self.client_encryption.create_data_key(
|
data_key_id = await self.client_encryption.create_data_key(
|
||||||
provider_name, master_key=master_key
|
provider_name, master_key=master_key
|
||||||
@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
|||||||
KEYVAULT_COLL = "datakeys"
|
KEYVAULT_COLL = "datakeys"
|
||||||
client: AsyncMongoClient
|
client: AsyncMongoClient
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def _setup(self):
|
||||||
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
|
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
|
||||||
await create_key_vault(keyvault, self.DEK)
|
await create_key_vault(keyvault, self.DEK)
|
||||||
|
|
||||||
async def _test_explicit(self, expectation):
|
async def _test_explicit(self, expectation):
|
||||||
|
await self._setup()
|
||||||
client_encryption = self.create_client_encryption(
|
client_encryption = self.create_client_encryption(
|
||||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||||
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||||
async_client_context.client,
|
async_client_context.client,
|
||||||
OPTS,
|
OPTS,
|
||||||
)
|
)
|
||||||
self.addAsyncCleanup(client_encryption.close)
|
|
||||||
|
|
||||||
ciphertext = await client_encryption.encrypt(
|
ciphertext = await client_encryption.encrypt(
|
||||||
"string0",
|
"string0",
|
||||||
@ -1523,6 +1509,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
|||||||
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")
|
self.assertEqual(await client_encryption.decrypt(ciphertext), "string0")
|
||||||
|
|
||||||
async def _test_automatic(self, expectation_extjson, payload):
|
async def _test_automatic(self, expectation_extjson, payload):
|
||||||
|
await self._setup()
|
||||||
encrypted_db = "db"
|
encrypted_db = "db"
|
||||||
encrypted_coll = "coll"
|
encrypted_coll = "coll"
|
||||||
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
||||||
@ -1537,7 +1524,6 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
|||||||
client = await self.async_rs_or_single_client(
|
client = await self.async_rs_or_single_client(
|
||||||
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
||||||
)
|
)
|
||||||
self.addAsyncCleanup(client.aclose)
|
|
||||||
|
|
||||||
coll = client.get_database(encrypted_db).get_collection(
|
coll = client.get_database(encrypted_db).get_collection(
|
||||||
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
|
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
|
||||||
@ -1559,13 +1545,12 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||||
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
|
self.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
|
|
||||||
async def test_explicit(self):
|
async def test_explicit(self):
|
||||||
return await self._test_explicit(
|
return await self._test_explicit(
|
||||||
@ -1585,13 +1570,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegratio
|
|||||||
|
|
||||||
|
|
||||||
class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||||
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
|
|
||||||
async def test_explicit(self):
|
async def test_explicit(self):
|
||||||
return await self._test_explicit(
|
return await self._test_explicit(
|
||||||
@ -1613,6 +1597,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT
|
|||||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
|
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
|
||||||
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
self.client_test = await self.async_rs_or_single_client(
|
self.client_test = await self.async_rs_or_single_client(
|
||||||
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
||||||
)
|
)
|
||||||
@ -1645,7 +1630,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
|||||||
self.ciphertext = await client_encryption.encrypt(
|
self.ciphertext = await client_encryption.encrypt(
|
||||||
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
|
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
|
||||||
)
|
)
|
||||||
await client_encryption.close()
|
|
||||||
|
|
||||||
self.client_listener = OvertCommandListener()
|
self.client_listener = OvertCommandListener()
|
||||||
self.topology_listener = TopologyEventListener()
|
self.topology_listener = TopologyEventListener()
|
||||||
@ -1840,6 +1824,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest):
|
|||||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
|
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
|
||||||
class TestDecryptProse(AsyncEncryptionIntegrationTest):
|
class TestDecryptProse(AsyncEncryptionIntegrationTest):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
self.client = async_client_context.client
|
self.client = async_client_context.client
|
||||||
await self.client.db.drop_collection("decryption_events")
|
await self.client.db.drop_collection("decryption_events")
|
||||||
await create_key_vault(self.client.keyvault.datakeys)
|
await create_key_vault(self.client.keyvault.datakeys)
|
||||||
@ -2275,6 +2260,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
|
|||||||
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
|
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
|
||||||
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
|
class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
self.client = async_client_context.client
|
self.client = async_client_context.client
|
||||||
await create_key_vault(self.client.keyvault.datakeys)
|
await create_key_vault(self.client.keyvault.datakeys)
|
||||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||||
@ -2624,8 +2610,6 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest):
|
|||||||
assert isinstance(res["encrypted_indexed"], Binary)
|
assert isinstance(res["encrypted_indexed"], Binary)
|
||||||
assert isinstance(res["encrypted_unindexed"], Binary)
|
assert isinstance(res["encrypted_unindexed"], Binary)
|
||||||
|
|
||||||
await client_encryption.close()
|
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
|
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
|
||||||
class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
|
class TestRangeQueryProse(AsyncEncryptionIntegrationTest):
|
||||||
@ -3089,17 +3073,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest):
|
|||||||
mongocryptd_client: AsyncMongoClient
|
mongocryptd_client: AsyncMongoClient
|
||||||
MONGOCRYPTD_PORT = 27020
|
MONGOCRYPTD_PORT = 27020
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
||||||
async def _setup_class(cls):
|
|
||||||
await super()._setup_class()
|
|
||||||
start_mongocryptd(cls.MONGOCRYPTD_PORT)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def asyncSetUp(self) -> None:
|
async def asyncSetUp(self) -> None:
|
||||||
|
await super().asyncSetUp()
|
||||||
|
start_mongocryptd(self.MONGOCRYPTD_PORT)
|
||||||
|
|
||||||
self.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
self.mongocryptd_client = self.simple_client(
|
self.mongocryptd_client = self.simple_client(
|
||||||
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
||||||
|
|||||||
@ -97,6 +97,7 @@ class AsyncTestGridFileNoConnect(AsyncUnitTest):
|
|||||||
|
|
||||||
class AsyncTestGridFile(AsyncIntegrationTest):
|
class AsyncTestGridFile(AsyncIntegrationTest):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
|
await super().asyncSetUp()
|
||||||
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
||||||
|
|
||||||
async def test_basic(self):
|
async def test_basic(self):
|
||||||
|
|||||||
@ -16,498 +16,447 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from pymongo.lock import _async_create_condition, _async_create_lock
|
||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from pymongo.lock import _ACondition
|
if sys.version_info < (3, 13):
|
||||||
|
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
||||||
|
# Includes tests for:
|
||||||
|
# - https://github.com/python/cpython/issues/111693
|
||||||
|
# - https://github.com/python/cpython/issues/112202
|
||||||
|
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||||
|
async def test_wait(self):
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
result = []
|
||||||
|
|
||||||
|
async def c1(result):
|
||||||
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
await cond.acquire()
|
||||||
# Includes tests for:
|
|
||||||
# - https://github.com/python/cpython/issues/111693
|
|
||||||
# - https://github.com/python/cpython/issues/112202
|
|
||||||
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
|
||||||
async def test_wait(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
result = []
|
|
||||||
|
|
||||||
async def c1(result):
|
|
||||||
await cond.acquire()
|
|
||||||
if await cond.wait():
|
|
||||||
result.append(1)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def c2(result):
|
|
||||||
await cond.acquire()
|
|
||||||
if await cond.wait():
|
|
||||||
result.append(2)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def c3(result):
|
|
||||||
await cond.acquire()
|
|
||||||
if await cond.wait():
|
|
||||||
result.append(3)
|
|
||||||
return True
|
|
||||||
|
|
||||||
t1 = asyncio.create_task(c1(result))
|
|
||||||
t2 = asyncio.create_task(c2(result))
|
|
||||||
t3 = asyncio.create_task(c3(result))
|
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([], result)
|
|
||||||
self.assertFalse(cond.locked())
|
|
||||||
|
|
||||||
self.assertTrue(await cond.acquire())
|
|
||||||
cond.notify()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([], result)
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
cond.release()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([1], result)
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
cond.notify(2)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([1], result)
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
cond.release()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([1, 2], result)
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
cond.release()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([1, 2, 3], result)
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
self.assertTrue(t1.done())
|
|
||||||
self.assertTrue(t1.result())
|
|
||||||
self.assertTrue(t2.done())
|
|
||||||
self.assertTrue(t2.result())
|
|
||||||
self.assertTrue(t3.done())
|
|
||||||
self.assertTrue(t3.result())
|
|
||||||
|
|
||||||
async def test_wait_cancel(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
await cond.acquire()
|
|
||||||
|
|
||||||
wait = asyncio.create_task(cond.wait())
|
|
||||||
asyncio.get_running_loop().call_soon(wait.cancel)
|
|
||||||
with self.assertRaises(asyncio.CancelledError):
|
|
||||||
await wait
|
|
||||||
self.assertFalse(cond._waiters)
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
async def test_wait_cancel_contested(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
|
|
||||||
await cond.acquire()
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
wait_task = asyncio.create_task(cond.wait())
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertFalse(cond.locked())
|
|
||||||
|
|
||||||
# Notify, but contest the lock before cancelling
|
|
||||||
await cond.acquire()
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
cond.notify()
|
|
||||||
asyncio.get_running_loop().call_soon(wait_task.cancel)
|
|
||||||
asyncio.get_running_loop().call_soon(cond.release)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await wait_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Should not happen, since no cancellation points
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
|
|
||||||
async def test_wait_cancel_after_notify(self):
|
|
||||||
# See bpo-32841
|
|
||||||
waited = False
|
|
||||||
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
|
|
||||||
async def wait_on_cond():
|
|
||||||
nonlocal waited
|
|
||||||
async with cond:
|
|
||||||
waited = True # Make sure this area was reached
|
|
||||||
await cond.wait()
|
|
||||||
|
|
||||||
waiter = asyncio.create_task(wait_on_cond())
|
|
||||||
await asyncio.sleep(0) # Start waiting
|
|
||||||
|
|
||||||
await cond.acquire()
|
|
||||||
cond.notify()
|
|
||||||
await asyncio.sleep(0) # Get to acquire()
|
|
||||||
waiter.cancel()
|
|
||||||
await asyncio.sleep(0) # Activate cancellation
|
|
||||||
cond.release()
|
|
||||||
await asyncio.sleep(0) # Cancellation should occur
|
|
||||||
|
|
||||||
self.assertTrue(waiter.cancelled())
|
|
||||||
self.assertTrue(waited)
|
|
||||||
|
|
||||||
async def test_wait_unacquired(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
await cond.wait()
|
|
||||||
|
|
||||||
async def test_wait_for(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
presult = False
|
|
||||||
|
|
||||||
def predicate():
|
|
||||||
return presult
|
|
||||||
|
|
||||||
result = []
|
|
||||||
|
|
||||||
async def c1(result):
|
|
||||||
await cond.acquire()
|
|
||||||
if await cond.wait_for(predicate):
|
|
||||||
result.append(1)
|
|
||||||
cond.release()
|
|
||||||
return True
|
|
||||||
|
|
||||||
t = asyncio.create_task(c1(result))
|
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([], result)
|
|
||||||
|
|
||||||
await cond.acquire()
|
|
||||||
cond.notify()
|
|
||||||
cond.release()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([], result)
|
|
||||||
|
|
||||||
presult = True
|
|
||||||
await cond.acquire()
|
|
||||||
cond.notify()
|
|
||||||
cond.release()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([1], result)
|
|
||||||
|
|
||||||
self.assertTrue(t.done())
|
|
||||||
self.assertTrue(t.result())
|
|
||||||
|
|
||||||
async def test_wait_for_unacquired(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
|
|
||||||
# predicate can return true immediately
|
|
||||||
res = await cond.wait_for(lambda: [1, 2, 3])
|
|
||||||
self.assertEqual([1, 2, 3], res)
|
|
||||||
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
await cond.wait_for(lambda: False)
|
|
||||||
|
|
||||||
async def test_notify(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
result = []
|
|
||||||
|
|
||||||
async def c1(result):
|
|
||||||
async with cond:
|
|
||||||
if await cond.wait():
|
if await cond.wait():
|
||||||
result.append(1)
|
result.append(1)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def c2(result):
|
async def c2(result):
|
||||||
async with cond:
|
await cond.acquire()
|
||||||
if await cond.wait():
|
if await cond.wait():
|
||||||
result.append(2)
|
result.append(2)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def c3(result):
|
async def c3(result):
|
||||||
async with cond:
|
await cond.acquire()
|
||||||
if await cond.wait():
|
if await cond.wait():
|
||||||
result.append(3)
|
result.append(3)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
t1 = asyncio.create_task(c1(result))
|
t1 = asyncio.create_task(c1(result))
|
||||||
t2 = asyncio.create_task(c2(result))
|
t2 = asyncio.create_task(c2(result))
|
||||||
t3 = asyncio.create_task(c3(result))
|
t3 = asyncio.create_task(c3(result))
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([], result)
|
|
||||||
|
|
||||||
async with cond:
|
|
||||||
cond.notify(1)
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
self.assertEqual([1], result)
|
|
||||||
|
|
||||||
async with cond:
|
|
||||||
cond.notify(1)
|
|
||||||
cond.notify(2048)
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
self.assertEqual([1, 2, 3], result)
|
|
||||||
|
|
||||||
self.assertTrue(t1.done())
|
|
||||||
self.assertTrue(t1.result())
|
|
||||||
self.assertTrue(t2.done())
|
|
||||||
self.assertTrue(t2.result())
|
|
||||||
self.assertTrue(t3.done())
|
|
||||||
self.assertTrue(t3.result())
|
|
||||||
|
|
||||||
async def test_notify_all(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
|
|
||||||
result = []
|
|
||||||
|
|
||||||
async def c1(result):
|
|
||||||
async with cond:
|
|
||||||
if await cond.wait():
|
|
||||||
result.append(1)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def c2(result):
|
|
||||||
async with cond:
|
|
||||||
if await cond.wait():
|
|
||||||
result.append(2)
|
|
||||||
return True
|
|
||||||
|
|
||||||
t1 = asyncio.create_task(c1(result))
|
|
||||||
t2 = asyncio.create_task(c2(result))
|
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
self.assertEqual([], result)
|
|
||||||
|
|
||||||
async with cond:
|
|
||||||
cond.notify_all()
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
self.assertEqual([1, 2], result)
|
|
||||||
|
|
||||||
self.assertTrue(t1.done())
|
|
||||||
self.assertTrue(t1.result())
|
|
||||||
self.assertTrue(t2.done())
|
|
||||||
self.assertTrue(t2.result())
|
|
||||||
|
|
||||||
async def test_context_manager(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
self.assertFalse(cond.locked())
|
|
||||||
async with cond:
|
|
||||||
self.assertTrue(cond.locked())
|
|
||||||
self.assertFalse(cond.locked())
|
|
||||||
|
|
||||||
async def test_timeout_in_block(self):
|
|
||||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
async with condition:
|
|
||||||
with self.assertRaises(asyncio.TimeoutError):
|
|
||||||
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
|
||||||
)
|
|
||||||
async def test_cancelled_error_wakeup(self):
|
|
||||||
# Test that a cancelled error, received when awaiting wakeup,
|
|
||||||
# will be re-raised un-modified.
|
|
||||||
wake = False
|
|
||||||
raised = None
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
|
|
||||||
async def func():
|
|
||||||
nonlocal raised
|
|
||||||
async with cond:
|
|
||||||
with self.assertRaises(asyncio.CancelledError) as err:
|
|
||||||
await cond.wait_for(lambda: wake)
|
|
||||||
raised = err.exception
|
|
||||||
raise raised
|
|
||||||
|
|
||||||
task = asyncio.create_task(func())
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
# Task is waiting on the condition, cancel it there.
|
|
||||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
|
||||||
with self.assertRaises(asyncio.CancelledError) as err:
|
|
||||||
await task
|
|
||||||
self.assertEqual(err.exception.args, ("foo",))
|
|
||||||
# We should have got the _same_ exception instance as the one
|
|
||||||
# originally raised.
|
|
||||||
self.assertIs(err.exception, raised)
|
|
||||||
|
|
||||||
@unittest.skipIf(
|
|
||||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
|
||||||
)
|
|
||||||
async def test_cancelled_error_re_aquire(self):
|
|
||||||
# Test that a cancelled error, received when re-aquiring lock,
|
|
||||||
# will be re-raised un-modified.
|
|
||||||
wake = False
|
|
||||||
raised = None
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
|
|
||||||
async def func():
|
|
||||||
nonlocal raised
|
|
||||||
async with cond:
|
|
||||||
with self.assertRaises(asyncio.CancelledError) as err:
|
|
||||||
await cond.wait_for(lambda: wake)
|
|
||||||
raised = err.exception
|
|
||||||
raise raised
|
|
||||||
|
|
||||||
task = asyncio.create_task(func())
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
# Task is waiting on the condition
|
|
||||||
await cond.acquire()
|
|
||||||
wake = True
|
|
||||||
cond.notify()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
# Task is now trying to re-acquire the lock, cancel it there.
|
|
||||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
|
||||||
cond.release()
|
|
||||||
with self.assertRaises(asyncio.CancelledError) as err:
|
|
||||||
await task
|
|
||||||
self.assertEqual(err.exception.args, ("foo",))
|
|
||||||
# We should have got the _same_ exception instance as the one
|
|
||||||
# originally raised.
|
|
||||||
self.assertIs(err.exception, raised)
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
|
||||||
async def test_cancelled_wakeup(self):
|
|
||||||
# Test that a task cancelled at the "same" time as it is woken
|
|
||||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
|
||||||
# This test simulates a cancel while the target task is awaiting initial
|
|
||||||
# wakeup on the wakeup queue.
|
|
||||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
state = 0
|
|
||||||
|
|
||||||
async def consumer():
|
|
||||||
nonlocal state
|
|
||||||
async with condition:
|
|
||||||
while True:
|
|
||||||
await condition.wait_for(lambda: state != 0)
|
|
||||||
if state < 0:
|
|
||||||
return
|
|
||||||
state -= 1
|
|
||||||
|
|
||||||
# create two consumers
|
|
||||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
|
||||||
# wait for them to settle
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
async with condition:
|
|
||||||
# produce one item and wake up one
|
|
||||||
state += 1
|
|
||||||
condition.notify(1)
|
|
||||||
|
|
||||||
# Cancel it while it is awaiting to be run.
|
|
||||||
# This cancellation could come from the outside
|
|
||||||
c[0].cancel()
|
|
||||||
|
|
||||||
# now wait for the item to be consumed
|
|
||||||
# if it doesn't means that our "notify" didn"t take hold.
|
|
||||||
# because it raced with a cancel()
|
|
||||||
try:
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
await condition.wait_for(lambda: state == 0)
|
|
||||||
except TimeoutError:
|
|
||||||
pass
|
|
||||||
self.assertEqual(state, 0)
|
|
||||||
|
|
||||||
# clean up
|
|
||||||
state = -1
|
|
||||||
condition.notify_all()
|
|
||||||
await c[1]
|
|
||||||
|
|
||||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
|
||||||
async def test_cancelled_wakeup_relock(self):
|
|
||||||
# Test that a task cancelled at the "same" time as it is woken
|
|
||||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
|
||||||
# This test simulates a cancel while the target task is acquiring the lock
|
|
||||||
# again.
|
|
||||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
state = 0
|
|
||||||
|
|
||||||
async def consumer():
|
|
||||||
nonlocal state
|
|
||||||
async with condition:
|
|
||||||
while True:
|
|
||||||
await condition.wait_for(lambda: state != 0)
|
|
||||||
if state < 0:
|
|
||||||
return
|
|
||||||
state -= 1
|
|
||||||
|
|
||||||
# create two consumers
|
|
||||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
|
||||||
# wait for them to settle
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
async with condition:
|
|
||||||
# produce one item and wake up one
|
|
||||||
state += 1
|
|
||||||
condition.notify(1)
|
|
||||||
|
|
||||||
# now we sleep for a bit. This allows the target task to wake up and
|
|
||||||
# settle on re-aquiring the lock
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([], result)
|
||||||
|
self.assertFalse(cond.locked())
|
||||||
|
|
||||||
# Cancel it while awaiting the lock
|
self.assertTrue(await cond.acquire())
|
||||||
# This cancel could come the outside.
|
cond.notify()
|
||||||
c[0].cancel()
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([], result)
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
|
||||||
|
cond.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([1], result)
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
|
||||||
|
cond.notify(2)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([1], result)
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
|
||||||
|
cond.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([1, 2], result)
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
|
||||||
|
cond.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([1, 2, 3], result)
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
|
||||||
|
self.assertTrue(t1.done())
|
||||||
|
self.assertTrue(t1.result())
|
||||||
|
self.assertTrue(t2.done())
|
||||||
|
self.assertTrue(t2.result())
|
||||||
|
self.assertTrue(t3.done())
|
||||||
|
self.assertTrue(t3.result())
|
||||||
|
|
||||||
|
async def test_wait_cancel(self):
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
await cond.acquire()
|
||||||
|
|
||||||
|
wait = asyncio.create_task(cond.wait())
|
||||||
|
asyncio.get_running_loop().call_soon(wait.cancel)
|
||||||
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
|
await wait
|
||||||
|
self.assertFalse(cond._waiters)
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
|
||||||
|
async def test_wait_cancel_contested(self):
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
|
||||||
|
await cond.acquire()
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
|
||||||
|
wait_task = asyncio.create_task(cond.wait())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertFalse(cond.locked())
|
||||||
|
|
||||||
|
# Notify, but contest the lock before cancelling
|
||||||
|
await cond.acquire()
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
cond.notify()
|
||||||
|
asyncio.get_running_loop().call_soon(wait_task.cancel)
|
||||||
|
asyncio.get_running_loop().call_soon(cond.release)
|
||||||
|
|
||||||
# now wait for the item to be consumed
|
|
||||||
# if it doesn't means that our "notify" didn"t take hold.
|
|
||||||
# because it raced with a cancel()
|
|
||||||
try:
|
try:
|
||||||
async with asyncio.timeout(1):
|
await wait_task
|
||||||
await condition.wait_for(lambda: state == 0)
|
except asyncio.CancelledError:
|
||||||
except TimeoutError:
|
# Should not happen, since no cancellation points
|
||||||
pass
|
pass
|
||||||
self.assertEqual(state, 0)
|
|
||||||
|
|
||||||
# clean up
|
self.assertTrue(cond.locked())
|
||||||
state = -1
|
|
||||||
condition.notify_all()
|
|
||||||
await c[1]
|
|
||||||
|
|
||||||
|
async def test_wait_cancel_after_notify(self):
|
||||||
|
# See bpo-32841
|
||||||
|
waited = False
|
||||||
|
|
||||||
class TestCondition(unittest.IsolatedAsyncioTestCase):
|
cond = _async_create_condition(_async_create_lock())
|
||||||
async def test_multiple_loops_notify(self):
|
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
|
||||||
|
|
||||||
def tmain(cond):
|
async def wait_on_cond():
|
||||||
async def atmain(cond):
|
nonlocal waited
|
||||||
await asyncio.sleep(1)
|
|
||||||
async with cond:
|
async with cond:
|
||||||
cond.notify(1)
|
waited = True # Make sure this area was reached
|
||||||
|
await cond.wait()
|
||||||
|
|
||||||
asyncio.run(atmain(cond))
|
waiter = asyncio.create_task(wait_on_cond())
|
||||||
|
await asyncio.sleep(0) # Start waiting
|
||||||
|
|
||||||
t = threading.Thread(target=tmain, args=(cond,))
|
await cond.acquire()
|
||||||
t.start()
|
cond.notify()
|
||||||
|
await asyncio.sleep(0) # Get to acquire()
|
||||||
|
waiter.cancel()
|
||||||
|
await asyncio.sleep(0) # Activate cancellation
|
||||||
|
cond.release()
|
||||||
|
await asyncio.sleep(0) # Cancellation should occur
|
||||||
|
|
||||||
async with cond:
|
self.assertTrue(waiter.cancelled())
|
||||||
self.assertTrue(await cond.wait(30))
|
self.assertTrue(waited)
|
||||||
t.join()
|
|
||||||
|
|
||||||
async def test_multiple_loops_notify_all(self):
|
async def test_wait_unacquired(self):
|
||||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
cond = _async_create_condition(_async_create_lock())
|
||||||
results = []
|
with self.assertRaises(RuntimeError):
|
||||||
|
await cond.wait()
|
||||||
|
|
||||||
def tmain(cond, results):
|
async def test_wait_for(self):
|
||||||
async def atmain(cond, results):
|
cond = _async_create_condition(_async_create_lock())
|
||||||
await asyncio.sleep(1)
|
presult = False
|
||||||
|
|
||||||
|
def predicate():
|
||||||
|
return presult
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
async def c1(result):
|
||||||
|
await cond.acquire()
|
||||||
|
if await cond.wait_for(predicate):
|
||||||
|
result.append(1)
|
||||||
|
cond.release()
|
||||||
|
return True
|
||||||
|
|
||||||
|
t = asyncio.create_task(c1(result))
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([], result)
|
||||||
|
|
||||||
|
await cond.acquire()
|
||||||
|
cond.notify()
|
||||||
|
cond.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([], result)
|
||||||
|
|
||||||
|
presult = True
|
||||||
|
await cond.acquire()
|
||||||
|
cond.notify()
|
||||||
|
cond.release()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([1], result)
|
||||||
|
|
||||||
|
self.assertTrue(t.done())
|
||||||
|
self.assertTrue(t.result())
|
||||||
|
|
||||||
|
async def test_wait_for_unacquired(self):
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
|
||||||
|
# predicate can return true immediately
|
||||||
|
res = await cond.wait_for(lambda: [1, 2, 3])
|
||||||
|
self.assertEqual([1, 2, 3], res)
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
await cond.wait_for(lambda: False)
|
||||||
|
|
||||||
|
async def test_notify(self):
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
result = []
|
||||||
|
|
||||||
|
async def c1(result):
|
||||||
async with cond:
|
async with cond:
|
||||||
res = await cond.wait(30)
|
if await cond.wait():
|
||||||
results.append(res)
|
result.append(1)
|
||||||
|
return True
|
||||||
|
|
||||||
asyncio.run(atmain(cond, results))
|
async def c2(result):
|
||||||
|
async with cond:
|
||||||
|
if await cond.wait():
|
||||||
|
result.append(2)
|
||||||
|
return True
|
||||||
|
|
||||||
nthreads = 5
|
async def c3(result):
|
||||||
threads = []
|
async with cond:
|
||||||
for _ in range(nthreads):
|
if await cond.wait():
|
||||||
threads.append(threading.Thread(target=tmain, args=(cond, results)))
|
result.append(3)
|
||||||
for t in threads:
|
return True
|
||||||
t.start()
|
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
t1 = asyncio.create_task(c1(result))
|
||||||
async with cond:
|
t2 = asyncio.create_task(c2(result))
|
||||||
cond.notify_all()
|
t3 = asyncio.create_task(c3(result))
|
||||||
|
|
||||||
for t in threads:
|
await asyncio.sleep(0)
|
||||||
t.join()
|
self.assertEqual([], result)
|
||||||
|
|
||||||
self.assertEqual(results, [True] * nthreads)
|
async with cond:
|
||||||
|
cond.notify(1)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
self.assertEqual([1], result)
|
||||||
|
|
||||||
|
async with cond:
|
||||||
|
cond.notify(1)
|
||||||
|
cond.notify(2048)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
self.assertEqual([1, 2, 3], result)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
self.assertTrue(t1.done())
|
||||||
unittest.main()
|
self.assertTrue(t1.result())
|
||||||
|
self.assertTrue(t2.done())
|
||||||
|
self.assertTrue(t2.result())
|
||||||
|
self.assertTrue(t3.done())
|
||||||
|
self.assertTrue(t3.result())
|
||||||
|
|
||||||
|
async def test_notify_all(self):
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
async def c1(result):
|
||||||
|
async with cond:
|
||||||
|
if await cond.wait():
|
||||||
|
result.append(1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def c2(result):
|
||||||
|
async with cond:
|
||||||
|
if await cond.wait():
|
||||||
|
result.append(2)
|
||||||
|
return True
|
||||||
|
|
||||||
|
t1 = asyncio.create_task(c1(result))
|
||||||
|
t2 = asyncio.create_task(c2(result))
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
self.assertEqual([], result)
|
||||||
|
|
||||||
|
async with cond:
|
||||||
|
cond.notify_all()
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
self.assertEqual([1, 2], result)
|
||||||
|
|
||||||
|
self.assertTrue(t1.done())
|
||||||
|
self.assertTrue(t1.result())
|
||||||
|
self.assertTrue(t2.done())
|
||||||
|
self.assertTrue(t2.result())
|
||||||
|
|
||||||
|
async def test_context_manager(self):
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
self.assertFalse(cond.locked())
|
||||||
|
async with cond:
|
||||||
|
self.assertTrue(cond.locked())
|
||||||
|
self.assertFalse(cond.locked())
|
||||||
|
|
||||||
|
async def test_timeout_in_block(self):
|
||||||
|
condition = _async_create_condition(_async_create_lock())
|
||||||
|
async with condition:
|
||||||
|
with self.assertRaises(asyncio.TimeoutError):
|
||||||
|
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||||
|
)
|
||||||
|
async def test_cancelled_error_wakeup(self):
|
||||||
|
# Test that a cancelled error, received when awaiting wakeup,
|
||||||
|
# will be re-raised un-modified.
|
||||||
|
wake = False
|
||||||
|
raised = None
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
nonlocal raised
|
||||||
|
async with cond:
|
||||||
|
with self.assertRaises(asyncio.CancelledError) as err:
|
||||||
|
await cond.wait_for(lambda: wake)
|
||||||
|
raised = err.exception
|
||||||
|
raise raised
|
||||||
|
|
||||||
|
task = asyncio.create_task(func())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# Task is waiting on the condition, cancel it there.
|
||||||
|
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||||
|
with self.assertRaises(asyncio.CancelledError) as err:
|
||||||
|
await task
|
||||||
|
self.assertEqual(err.exception.args, ("foo",))
|
||||||
|
# We should have got the _same_ exception instance as the one
|
||||||
|
# originally raised.
|
||||||
|
self.assertIs(err.exception, raised)
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||||
|
)
|
||||||
|
async def test_cancelled_error_re_aquire(self):
|
||||||
|
# Test that a cancelled error, received when re-aquiring lock,
|
||||||
|
# will be re-raised un-modified.
|
||||||
|
wake = False
|
||||||
|
raised = None
|
||||||
|
cond = _async_create_condition(_async_create_lock())
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
nonlocal raised
|
||||||
|
async with cond:
|
||||||
|
with self.assertRaises(asyncio.CancelledError) as err:
|
||||||
|
await cond.wait_for(lambda: wake)
|
||||||
|
raised = err.exception
|
||||||
|
raise raised
|
||||||
|
|
||||||
|
task = asyncio.create_task(func())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# Task is waiting on the condition
|
||||||
|
await cond.acquire()
|
||||||
|
wake = True
|
||||||
|
cond.notify()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
# Task is now trying to re-acquire the lock, cancel it there.
|
||||||
|
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||||
|
cond.release()
|
||||||
|
with self.assertRaises(asyncio.CancelledError) as err:
|
||||||
|
await task
|
||||||
|
self.assertEqual(err.exception.args, ("foo",))
|
||||||
|
# We should have got the _same_ exception instance as the one
|
||||||
|
# originally raised.
|
||||||
|
self.assertIs(err.exception, raised)
|
||||||
|
|
||||||
|
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||||
|
async def test_cancelled_wakeup(self):
|
||||||
|
# Test that a task cancelled at the "same" time as it is woken
|
||||||
|
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||||
|
# This test simulates a cancel while the target task is awaiting initial
|
||||||
|
# wakeup on the wakeup queue.
|
||||||
|
condition = _async_create_condition(_async_create_lock())
|
||||||
|
state = 0
|
||||||
|
|
||||||
|
async def consumer():
|
||||||
|
nonlocal state
|
||||||
|
async with condition:
|
||||||
|
while True:
|
||||||
|
await condition.wait_for(lambda: state != 0)
|
||||||
|
if state < 0:
|
||||||
|
return
|
||||||
|
state -= 1
|
||||||
|
|
||||||
|
# create two consumers
|
||||||
|
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||||
|
# wait for them to settle
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
async with condition:
|
||||||
|
# produce one item and wake up one
|
||||||
|
state += 1
|
||||||
|
condition.notify(1)
|
||||||
|
|
||||||
|
# Cancel it while it is awaiting to be run.
|
||||||
|
# This cancellation could come from the outside
|
||||||
|
c[0].cancel()
|
||||||
|
|
||||||
|
# now wait for the item to be consumed
|
||||||
|
# if it doesn't means that our "notify" didn"t take hold.
|
||||||
|
# because it raced with a cancel()
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await condition.wait_for(lambda: state == 0)
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
self.assertEqual(state, 0)
|
||||||
|
|
||||||
|
# clean up
|
||||||
|
state = -1
|
||||||
|
condition.notify_all()
|
||||||
|
await c[1]
|
||||||
|
|
||||||
|
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||||
|
async def test_cancelled_wakeup_relock(self):
|
||||||
|
# Test that a task cancelled at the "same" time as it is woken
|
||||||
|
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||||
|
# This test simulates a cancel while the target task is acquiring the lock
|
||||||
|
# again.
|
||||||
|
condition = _async_create_condition(_async_create_lock())
|
||||||
|
state = 0
|
||||||
|
|
||||||
|
async def consumer():
|
||||||
|
nonlocal state
|
||||||
|
async with condition:
|
||||||
|
while True:
|
||||||
|
await condition.wait_for(lambda: state != 0)
|
||||||
|
if state < 0:
|
||||||
|
return
|
||||||
|
state -= 1
|
||||||
|
|
||||||
|
# create two consumers
|
||||||
|
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||||
|
# wait for them to settle
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
async with condition:
|
||||||
|
# produce one item and wake up one
|
||||||
|
state += 1
|
||||||
|
condition.notify(1)
|
||||||
|
|
||||||
|
# now we sleep for a bit. This allows the target task to wake up and
|
||||||
|
# settle on re-aquiring the lock
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# Cancel it while awaiting the lock
|
||||||
|
# This cancel could come the outside.
|
||||||
|
c[0].cancel()
|
||||||
|
|
||||||
|
# now wait for the item to be consumed
|
||||||
|
# if it doesn't means that our "notify" didn"t take hold.
|
||||||
|
# because it raced with a cancel()
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await condition.wait_for(lambda: state == 0)
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
self.assertEqual(state, 0)
|
||||||
|
|
||||||
|
# clean up
|
||||||
|
state = -1
|
||||||
|
condition.notify_all()
|
||||||
|
await c[1]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|||||||
@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest):
|
|||||||
listener: EventListener
|
listener: EventListener
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@async_client_context.require_connection
|
def setUpClass(cls) -> None:
|
||||||
async def _setup_class(cls):
|
|
||||||
await super()._setup_class()
|
|
||||||
cls.listener = OvertCommandListener()
|
cls.listener = OvertCommandListener()
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
|
||||||
event_listeners=[cls.listener], retryWrites=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@async_client_context.require_connection
|
||||||
async def _tearDown_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
await cls.client.close()
|
await super().asyncSetUp()
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
await super().asyncTearDown()
|
self.client = await self.async_rs_or_single_client(
|
||||||
|
event_listeners=[self.listener], retryWrites=False
|
||||||
|
)
|
||||||
|
|
||||||
async def test_started_simple(self):
|
async def test_started_simple(self):
|
||||||
await self.client.pymongo_test.command("ping")
|
await self.client.pymongo_test.command("ping")
|
||||||
@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest):
|
|||||||
saved_listeners: Any
|
saved_listeners: Any
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@async_client_context.require_connection
|
def setUpClass(cls) -> None:
|
||||||
async def _setup_class(cls):
|
|
||||||
await super()._setup_class()
|
|
||||||
cls.listener = OvertCommandListener()
|
cls.listener = OvertCommandListener()
|
||||||
# We plan to call register(), which internally modifies _LISTENERS.
|
# We plan to call register(), which internally modifies _LISTENERS.
|
||||||
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
||||||
monitoring.register(cls.listener)
|
monitoring.register(cls.listener)
|
||||||
cls.client = await cls.unmanaged_async_single_client()
|
|
||||||
# Get one (authenticated) socket in the pool.
|
|
||||||
await cls.client.pymongo_test.command("ping")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
monitoring._LISTENERS = cls.saved_listeners
|
|
||||||
await cls.client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
|
@async_client_context.require_connection
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
await super().asyncSetUp()
|
await super().asyncSetUp()
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
|
self.client = await self.async_single_client()
|
||||||
|
# Get one (authenticated) socket in the pool.
|
||||||
|
await self.client.pymongo_test.command("ping")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
monitoring._LISTENERS = cls.saved_listeners
|
||||||
|
|
||||||
async def test_simple(self):
|
async def test_simple(self):
|
||||||
await self.client.pymongo_test.command("ping")
|
await self.client.pymongo_test.command("ping")
|
||||||
|
|||||||
@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
|
|||||||
RUN_ON_SERVERLESS = True
|
RUN_ON_SERVERLESS = True
|
||||||
deprecation_filter: DeprecationFilter
|
deprecation_filter: DeprecationFilter
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self) -> None:
|
||||||
async def _setup_class(cls):
|
await super().asyncSetUp()
|
||||||
await super()._setup_class()
|
self.deprecation_filter = DeprecationFilter()
|
||||||
cls.deprecation_filter = DeprecationFilter()
|
|
||||||
|
|
||||||
@classmethod
|
async def asyncTearDown(self) -> None:
|
||||||
async def _tearDown_class(cls):
|
self.deprecation_filter.stop()
|
||||||
cls.deprecation_filter.stop()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
|
|
||||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||||
knobs: client_knobs
|
knobs: client_knobs
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self) -> None:
|
||||||
async def _setup_class(cls):
|
await super().asyncSetUp()
|
||||||
await super()._setup_class()
|
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||||
cls.knobs.enable()
|
self.knobs.enable()
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True)
|
self.client = await self.async_rs_or_single_client(retryWrites=True)
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
|
|
||||||
@classmethod
|
async def asyncTearDown(self) -> None:
|
||||||
async def _tearDown_class(cls):
|
self.knobs.disable()
|
||||||
cls.knobs.disable()
|
|
||||||
await cls.client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
@async_client_context.require_no_standalone
|
@async_client_context.require_no_standalone
|
||||||
async def test_actionable_error_message(self):
|
async def test_actionable_error_message(self):
|
||||||
@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
|||||||
listener: OvertCommandListener
|
listener: OvertCommandListener
|
||||||
knobs: client_knobs
|
knobs: client_knobs
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_no_mmap
|
@async_client_context.require_no_mmap
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||||
cls.knobs.enable()
|
self.knobs.enable()
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(
|
self.client = await self.async_rs_or_single_client(
|
||||||
retryWrites=True, event_listeners=[cls.listener]
|
retryWrites=True, event_listeners=[self.listener]
|
||||||
)
|
)
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
cls.knobs.disable()
|
|
||||||
await cls.client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
if async_client_context.is_rs and async_client_context.test_commands_enabled:
|
if async_client_context.is_rs and async_client_context.test_commands_enabled:
|
||||||
await self.client.admin.command(
|
await self.client.admin.command(
|
||||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
||||||
@ -210,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
|||||||
await self.client.admin.command(
|
await self.client.admin.command(
|
||||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||||
)
|
)
|
||||||
|
self.knobs.disable()
|
||||||
|
|
||||||
async def test_supported_single_statement_no_retry(self):
|
async def test_supported_single_statement_no_retry(self):
|
||||||
listener = OvertCommandListener()
|
listener = OvertCommandListener()
|
||||||
@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest):
|
|||||||
RUN_ON_SERVERLESS = True
|
RUN_ON_SERVERLESS = True
|
||||||
fail_insert: dict
|
fail_insert: dict
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_replica_set
|
@async_client_context.require_replica_set
|
||||||
@async_client_context.require_no_mmap
|
@async_client_context.require_no_mmap
|
||||||
@async_client_context.require_failCommand_fail_point
|
@async_client_context.require_failCommand_fail_point
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self) -> None:
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
cls.fail_insert = {
|
self.fail_insert = {
|
||||||
"configureFailPoint": "failCommand",
|
"configureFailPoint": "failCommand",
|
||||||
"mode": {"times": 2},
|
"mode": {"times": 2},
|
||||||
"data": {
|
"data": {
|
||||||
|
|||||||
@ -38,7 +38,6 @@ from test.utils import (
|
|||||||
ExceptionCatchingThread,
|
ExceptionCatchingThread,
|
||||||
OvertCommandListener,
|
OvertCommandListener,
|
||||||
async_wait_until,
|
async_wait_until,
|
||||||
wait_until,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from bson import DBRef
|
from bson import DBRef
|
||||||
@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest):
|
|||||||
client2: AsyncMongoClient
|
client2: AsyncMongoClient
|
||||||
sensitive_commands: Set[str]
|
sensitive_commands: Set[str]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@async_client_context.require_sessions
|
@async_client_context.require_sessions
|
||||||
async def _setup_class(cls):
|
async def asyncSetUp(self):
|
||||||
await super()._setup_class()
|
await super().asyncSetUp()
|
||||||
# Create a second client so we can make sure clients cannot share
|
# Create a second client so we can make sure clients cannot share
|
||||||
# sessions.
|
# sessions.
|
||||||
cls.client2 = await cls.unmanaged_async_rs_or_single_client()
|
self.client2 = await self.async_rs_or_single_client()
|
||||||
|
|
||||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||||
monitoring._SENSITIVE_COMMANDS.clear()
|
monitoring._SENSITIVE_COMMANDS.clear()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
|
|
||||||
await cls.client2.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
self.listener = SessionTestListener()
|
self.listener = SessionTestListener()
|
||||||
self.session_checker_listener = SessionTestListener()
|
self.session_checker_listener = SessionTestListener()
|
||||||
self.client = await self.async_rs_or_single_client(
|
self.client = await self.async_rs_or_single_client(
|
||||||
event_listeners=[self.listener, self.session_checker_listener]
|
event_listeners=[self.listener, self.session_checker_listener]
|
||||||
)
|
)
|
||||||
self.addAsyncCleanup(self.client.close)
|
|
||||||
self.db = self.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
async def asyncTearDown(self):
|
||||||
"""All sessions used in the test must be returned to the pool."""
|
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
|
||||||
await self.client.drop_database("pymongo_test")
|
await self.client.drop_database("pymongo_test")
|
||||||
used_lsids = self.initial_lsids.copy()
|
used_lsids = self.initial_lsids.copy()
|
||||||
for event in self.session_checker_listener.started_events:
|
for event in self.session_checker_listener.started_events:
|
||||||
@ -122,6 +112,8 @@ class TestSession(AsyncIntegrationTest):
|
|||||||
current_lsids = {s["id"] for s in session_ids(self.client)}
|
current_lsids = {s["id"] for s in session_ids(self.client)}
|
||||||
self.assertLessEqual(used_lsids, current_lsids)
|
self.assertLessEqual(used_lsids, current_lsids)
|
||||||
|
|
||||||
|
await super().asyncTearDown()
|
||||||
|
|
||||||
async def _test_ops(self, client, *ops):
|
async def _test_ops(self, client, *ops):
|
||||||
listener = client.options.event_listeners[0]
|
listener = client.options.event_listeners[0]
|
||||||
|
|
||||||
@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest):
|
|||||||
listener: SessionTestListener
|
listener: SessionTestListener
|
||||||
client: AsyncMongoClient
|
client: AsyncMongoClient
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _setup_class(cls):
|
|
||||||
cls.listener = SessionTestListener()
|
|
||||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
await cls.client.close()
|
|
||||||
|
|
||||||
@async_client_context.require_sessions
|
@async_client_context.require_sessions
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
await super().asyncSetUp()
|
await super().asyncSetUp()
|
||||||
|
self.listener = SessionTestListener()
|
||||||
|
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||||
|
|
||||||
@async_client_context.require_no_standalone
|
@async_client_context.require_no_standalone
|
||||||
async def test_core(self):
|
async def test_core(self):
|
||||||
|
|||||||
@ -26,7 +26,7 @@ sys.path[0:0] = [""]
|
|||||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
OvertCommandListener,
|
OvertCommandListener,
|
||||||
wait_until,
|
async_wait_until,
|
||||||
)
|
)
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ class TestTransactions(AsyncTransactionsBase):
|
|||||||
client = await self.async_rs_client(
|
client = await self.async_rs_client(
|
||||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||||
)
|
)
|
||||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||||
coll = client.test.test
|
coll = client.test.test
|
||||||
# Create the collection.
|
# Create the collection.
|
||||||
await coll.insert_one({})
|
await coll.insert_one({})
|
||||||
@ -191,7 +191,7 @@ class TestTransactions(AsyncTransactionsBase):
|
|||||||
client = await self.async_rs_client(
|
client = await self.async_rs_client(
|
||||||
async_client_context.mongos_seeds(), localThresholdMS=1000
|
async_client_context.mongos_seeds(), localThresholdMS=1000
|
||||||
)
|
)
|
||||||
wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses")
|
||||||
coll = client.test.test
|
coll = client.test.test
|
||||||
# Create the collection.
|
# Create the collection.
|
||||||
await coll.insert_one({})
|
await coll.insert_one({})
|
||||||
@ -403,21 +403,12 @@ class PatchSessionTimeout:
|
|||||||
|
|
||||||
|
|
||||||
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
||||||
@classmethod
|
async def asyncSetUp(self) -> None:
|
||||||
async def _setup_class(cls):
|
await super().asyncSetUp()
|
||||||
await super()._setup_class()
|
self.mongos_clients = []
|
||||||
cls.mongos_clients = []
|
|
||||||
if async_client_context.supports_transactions():
|
if async_client_context.supports_transactions():
|
||||||
for address in async_client_context.mongoses:
|
for address in async_client_context.mongoses:
|
||||||
cls.mongos_clients.append(
|
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
|
||||||
await cls.unmanaged_async_single_client("{}:{}".format(*address))
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
for client in cls.mongos_clients:
|
|
||||||
await client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def _set_fail_point(self, client, command_args):
|
async def _set_fail_point(self, client, command_args):
|
||||||
cmd = {"configureFailPoint": "failCommand"}
|
cmd = {"configureFailPoint": "failCommand"}
|
||||||
|
|||||||
@ -50,6 +50,7 @@ from test.unified_format_shared import (
|
|||||||
)
|
)
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
async_get_pool,
|
async_get_pool,
|
||||||
|
async_wait_until,
|
||||||
camel_to_snake,
|
camel_to_snake,
|
||||||
camel_to_snake_args,
|
camel_to_snake_args,
|
||||||
parse_spec_options,
|
parse_spec_options,
|
||||||
@ -304,7 +305,6 @@ class EntityMapUtil:
|
|||||||
kwargs["h"] = uri
|
kwargs["h"] = uri
|
||||||
client = await self.test.async_rs_or_single_client(**kwargs)
|
client = await self.test.async_rs_or_single_client(**kwargs)
|
||||||
self[spec["id"]] = client
|
self[spec["id"]] = client
|
||||||
self.test.addAsyncCleanup(client.close)
|
|
||||||
return
|
return
|
||||||
elif entity_type == "database":
|
elif entity_type == "database":
|
||||||
client = self[spec["client"]]
|
client = self[spec["client"]]
|
||||||
@ -479,33 +479,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
|||||||
await db.create_collection(coll_name, write_concern=wc, **opts)
|
await db.create_collection(coll_name, write_concern=wc, **opts)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _setup_class(cls):
|
def setUpClass(cls) -> None:
|
||||||
# super call creates internal client cls.client
|
|
||||||
await super()._setup_class()
|
|
||||||
# process file-level runOnRequirements
|
|
||||||
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
|
|
||||||
if not await cls.should_run_on(run_on_spec):
|
|
||||||
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
|
|
||||||
|
|
||||||
# add any special-casing for skipping tests here
|
|
||||||
if async_client_context.storage_engine == "mmapv1":
|
|
||||||
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
|
|
||||||
cls.TEST_PATH
|
|
||||||
):
|
|
||||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
|
||||||
|
|
||||||
# Handle mongos_clients for transactions tests.
|
|
||||||
cls.mongos_clients = []
|
|
||||||
if (
|
|
||||||
async_client_context.supports_transactions()
|
|
||||||
and not async_client_context.load_balancer
|
|
||||||
and not async_client_context.serverless
|
|
||||||
):
|
|
||||||
for address in async_client_context.mongoses:
|
|
||||||
cls.mongos_clients.append(
|
|
||||||
await cls.unmanaged_async_single_client("{}:{}".format(*address))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(
|
cls.knobs = client_knobs(
|
||||||
heartbeat_frequency=0.1,
|
heartbeat_frequency=0.1,
|
||||||
@ -516,17 +490,36 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
|||||||
cls.knobs.enable()
|
cls.knobs.enable()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _tearDown_class(cls):
|
def tearDownClass(cls) -> None:
|
||||||
cls.knobs.disable()
|
cls.knobs.disable()
|
||||||
for client in cls.mongos_clients:
|
|
||||||
await client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
|
# super call creates internal client cls.client
|
||||||
await super().asyncSetUp()
|
await super().asyncSetUp()
|
||||||
|
# process file-level runOnRequirements
|
||||||
|
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
|
||||||
|
if not await self.should_run_on(run_on_spec):
|
||||||
|
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
|
||||||
|
|
||||||
|
# add any special-casing for skipping tests here
|
||||||
|
if async_client_context.storage_engine == "mmapv1":
|
||||||
|
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||||
|
self.TEST_PATH
|
||||||
|
):
|
||||||
|
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||||
|
|
||||||
|
# Handle mongos_clients for transactions tests.
|
||||||
|
self.mongos_clients = []
|
||||||
|
if (
|
||||||
|
async_client_context.supports_transactions()
|
||||||
|
and not async_client_context.load_balancer
|
||||||
|
and not async_client_context.serverless
|
||||||
|
):
|
||||||
|
for address in async_client_context.mongoses:
|
||||||
|
self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address)))
|
||||||
|
|
||||||
# process schemaVersion
|
# process schemaVersion
|
||||||
# note: we check major schema version during class generation
|
# note: we check major schema version during class generation
|
||||||
# note: we do this here because we cannot run assertions in setUpClass
|
|
||||||
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
|
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
|
||||||
self.assertLessEqual(
|
self.assertLessEqual(
|
||||||
version,
|
version,
|
||||||
@ -1036,7 +1029,6 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
client = await self.async_single_client("{}:{}".format(*session._pinned_address))
|
client = await self.async_single_client("{}:{}".format(*session._pinned_address))
|
||||||
self.addAsyncCleanup(client.close)
|
|
||||||
await self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
await self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
||||||
|
|
||||||
async def _testOperation_createEntities(self, spec):
|
async def _testOperation_createEntities(self, spec):
|
||||||
@ -1137,13 +1129,13 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
|||||||
client, event, count = spec["client"], spec["event"], spec["count"]
|
client, event, count = spec["client"], spec["event"], spec["count"]
|
||||||
self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}")
|
self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}")
|
||||||
|
|
||||||
def _testOperation_waitForEvent(self, spec):
|
async def _testOperation_waitForEvent(self, spec):
|
||||||
"""Run the waitForEvent test operation.
|
"""Run the waitForEvent test operation.
|
||||||
|
|
||||||
Wait for a number of events to be published, or fail.
|
Wait for a number of events to be published, or fail.
|
||||||
"""
|
"""
|
||||||
client, event, count = spec["client"], spec["event"], spec["count"]
|
client, event, count = spec["client"], spec["event"], spec["count"]
|
||||||
wait_until(
|
await async_wait_until(
|
||||||
lambda: self._event_count(client, event) >= count,
|
lambda: self._event_count(client, event) >= count,
|
||||||
f"find {count} {event} event(s)",
|
f"find {count} {event} event(s)",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
|||||||
knobs: client_knobs
|
knobs: client_knobs
|
||||||
listener: EventListener
|
listener: EventListener
|
||||||
|
|
||||||
@classmethod
|
async def asyncSetUp(self) -> None:
|
||||||
async def _setup_class(cls):
|
await super().asyncSetUp()
|
||||||
await super()._setup_class()
|
self.mongos_clients = []
|
||||||
cls.mongos_clients = []
|
|
||||||
|
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||||
cls.knobs.enable()
|
self.knobs.enable()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _tearDown_class(cls):
|
|
||||||
cls.knobs.disable()
|
|
||||||
for client in cls.mongos_clients:
|
|
||||||
await client.close()
|
|
||||||
await super()._tearDown_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
|
||||||
self.targets = {}
|
self.targets = {}
|
||||||
self.listener = None # type: ignore
|
self.listener = None # type: ignore
|
||||||
self.pool_listener = None
|
self.pool_listener = None
|
||||||
self.server_listener = None
|
self.server_listener = None
|
||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
|
|
||||||
|
async def asyncTearDown(self) -> None:
|
||||||
|
self.knobs.disable()
|
||||||
|
|
||||||
async def _set_fail_point(self, client, command_args):
|
async def _set_fail_point(self, client, command_args):
|
||||||
cmd = SON([("configureFailPoint", "failCommand")])
|
cmd = SON([("configureFailPoint", "failCommand")])
|
||||||
cmd.update(command_args)
|
cmd.update(command_args)
|
||||||
@ -700,8 +692,6 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
|||||||
self.listener = listener
|
self.listener = listener
|
||||||
self.pool_listener = pool_listener
|
self.pool_listener = pool_listener
|
||||||
self.server_listener = server_listener
|
self.server_listener = server_listener
|
||||||
# Close the client explicitly to avoid having too many threads open.
|
|
||||||
self.addAsyncCleanup(client.close)
|
|
||||||
|
|
||||||
# Create session0 and session1.
|
# Create session0 and session1.
|
||||||
sessions = {}
|
sessions = {}
|
||||||
|
|||||||
@ -20,7 +20,7 @@ def event_loop_policy():
|
|||||||
return asyncio.get_event_loop_policy()
|
return asyncio.get_event_loop_policy()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="package", autouse=True)
|
||||||
def test_setup_and_teardown():
|
def test_setup_and_teardown():
|
||||||
setup()
|
setup()
|
||||||
yield
|
yield
|
||||||
|
|||||||
@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest):
|
|||||||
coll: Collection
|
coll: Collection
|
||||||
coll_w0: Collection
|
coll_w0: Collection
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
cls.coll = cls.db.test
|
|
||||||
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
self.coll = self.db.test
|
||||||
self.coll.drop()
|
self.coll.drop()
|
||||||
|
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
|
||||||
|
|
||||||
def assertEqualResponse(self, expected, actual):
|
def assertEqualResponse(self, expected, actual):
|
||||||
"""Compare response from bulk.execute() to expected response."""
|
"""Compare response from bulk.execute() to expected response."""
|
||||||
@ -785,12 +781,8 @@ class TestBulk(BulkTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class BulkAuthorizationTestBase(BulkTestBase):
|
class BulkAuthorizationTestBase(BulkTestBase):
|
||||||
@classmethod
|
|
||||||
@client_context.require_auth
|
@client_context.require_auth
|
||||||
@client_context.require_no_api_version
|
@client_context.require_no_api_version
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
client_context.create_user(self.db.name, "readonly", "pw", ["read"])
|
||||||
@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase):
|
|||||||
w: Optional[int]
|
w: Optional[int]
|
||||||
secondary: MongoClient
|
secondary: MongoClient
|
||||||
|
|
||||||
@classmethod
|
def setUp(self):
|
||||||
def _setup_class(cls):
|
super().setUp()
|
||||||
super()._setup_class()
|
self.w = client_context.w
|
||||||
cls.w = client_context.w
|
self.secondary = None
|
||||||
cls.secondary = None
|
if self.w is not None and self.w > 1:
|
||||||
if cls.w is not None and cls.w > 1:
|
|
||||||
for member in (client_context.hello)["hosts"]:
|
for member in (client_context.hello)["hosts"]:
|
||||||
if member != (client_context.hello)["primary"]:
|
if member != (client_context.hello)["primary"]:
|
||||||
cls.secondary = cls.unmanaged_single_client(*partition_node(member))
|
self.secondary = self.single_client(*partition_node(member))
|
||||||
break
|
break
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self):
|
||||||
def async_tearDownClass(cls):
|
if self.secondary:
|
||||||
if cls.secondary:
|
self.secondary.close()
|
||||||
cls.secondary.close()
|
|
||||||
|
|
||||||
def cause_wtimeout(self, requests, ordered):
|
def cause_wtimeout(self, requests, ordered):
|
||||||
if not client_context.test_commands_enabled:
|
if not client_context.test_commands_enabled:
|
||||||
|
|||||||
@ -820,18 +820,16 @@ class ProseSpecTestsMixin:
|
|||||||
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||||
dbs: list
|
dbs: list
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_version_min(4, 0, 0, -1)
|
@client_context.require_version_min(4, 0, 0, -1)
|
||||||
@client_context.require_change_streams
|
@client_context.require_change_streams
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
cls.dbs = [cls.db, cls.client.pymongo_test_2]
|
self.dbs = [self.db, self.client.pymongo_test_2]
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self):
|
||||||
def _tearDown_class(cls):
|
for db in self.dbs:
|
||||||
for db in cls.dbs:
|
self.client.drop_database(db)
|
||||||
cls.client.drop_database(db)
|
super().tearDown()
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def change_stream_with_client(self, client, *args, **kwargs):
|
def change_stream_with_client(self, client, *args, **kwargs):
|
||||||
return client.watch(*args, **kwargs)
|
return client.watch(*args, **kwargs)
|
||||||
@ -882,11 +880,10 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
|
|||||||
|
|
||||||
|
|
||||||
class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
|
class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
|
||||||
@classmethod
|
|
||||||
@client_context.require_version_min(4, 0, 0, -1)
|
@client_context.require_version_min(4, 0, 0, -1)
|
||||||
@client_context.require_change_streams
|
@client_context.require_change_streams
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
|
|
||||||
def change_stream_with_client(self, client, *args, **kwargs):
|
def change_stream_with_client(self, client, *args, **kwargs):
|
||||||
return client[self.db.name].watch(*args, **kwargs)
|
return client[self.db.name].watch(*args, **kwargs)
|
||||||
@ -968,12 +965,9 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
|
|||||||
|
|
||||||
|
|
||||||
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
|
class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin):
|
||||||
@classmethod
|
|
||||||
@client_context.require_change_streams
|
@client_context.require_change_streams
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
# Use a new collection for each test.
|
# Use a new collection for each test.
|
||||||
self.watched_collection().drop()
|
self.watched_collection().drop()
|
||||||
self.watched_collection().insert_one({})
|
self.watched_collection().insert_one({})
|
||||||
@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest):
|
|||||||
RUN_ON_LOAD_BALANCER = True
|
RUN_ON_LOAD_BALANCER = True
|
||||||
listener: AllowListEventListener
|
listener: AllowListEventListener
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_connection
|
@client_context.require_connection
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
cls.listener = AllowListEventListener("aggregate", "getMore")
|
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
self.listener = AllowListEventListener("aggregate", "getMore")
|
||||||
|
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
|
|
||||||
def setUpCluster(self, scenario_dict):
|
def setUpCluster(self, scenario_dict):
|
||||||
|
|||||||
@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest):
|
|||||||
|
|
||||||
client: MongoClient
|
client: MongoClient
|
||||||
|
|
||||||
@classmethod
|
def setUp(self) -> None:
|
||||||
def _setup_class(cls):
|
self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.client.close()
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_fixtures(self, caplog):
|
def inject_fixtures(self, caplog):
|
||||||
@ -1039,14 +1034,21 @@ class TestClient(IntegrationTest):
|
|||||||
self.assertFalse(client._topology._opened)
|
self.assertFalse(client._topology._opened)
|
||||||
|
|
||||||
# Ensure kill cursors thread has not been started.
|
# Ensure kill cursors thread has not been started.
|
||||||
kc_thread = client._kill_cursors_executor._thread
|
if _IS_SYNC:
|
||||||
self.assertFalse(kc_thread and kc_thread.is_alive())
|
kc_thread = client._kill_cursors_executor._thread
|
||||||
|
self.assertFalse(kc_thread and kc_thread.is_alive())
|
||||||
|
else:
|
||||||
|
kc_task = client._kill_cursors_executor._task
|
||||||
|
self.assertFalse(kc_task and not kc_task.done())
|
||||||
# Using the client should open topology and start the thread.
|
# Using the client should open topology and start the thread.
|
||||||
client.admin.command("ping")
|
client.admin.command("ping")
|
||||||
self.assertTrue(client._topology._opened)
|
self.assertTrue(client._topology._opened)
|
||||||
kc_thread = client._kill_cursors_executor._thread
|
if _IS_SYNC:
|
||||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
kc_thread = client._kill_cursors_executor._thread
|
||||||
|
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||||
|
else:
|
||||||
|
kc_task = client._kill_cursors_executor._task
|
||||||
|
self.assertTrue(kc_task and not kc_task.done())
|
||||||
|
|
||||||
def test_close_does_not_open_servers(self):
|
def test_close_does_not_open_servers(self):
|
||||||
client = self.rs_client(connect=False)
|
client = self.rs_client(connect=False)
|
||||||
@ -1241,6 +1243,7 @@ class TestClient(IntegrationTest):
|
|||||||
def test_server_selection_timeout(self):
|
def test_server_selection_timeout(self):
|
||||||
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
|
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||||
|
client.close()
|
||||||
|
|
||||||
client = MongoClient(serverSelectionTimeoutMS=0, connect=False)
|
client = MongoClient(serverSelectionTimeoutMS=0, connect=False)
|
||||||
|
|
||||||
@ -1251,16 +1254,20 @@ class TestClient(IntegrationTest):
|
|||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
|
ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False
|
||||||
)
|
)
|
||||||
|
client.close()
|
||||||
|
|
||||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
|
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False)
|
||||||
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
|
||||||
|
client.close()
|
||||||
|
|
||||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
|
||||||
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
self.assertAlmostEqual(0, client.options.server_selection_timeout)
|
||||||
|
client.close()
|
||||||
|
|
||||||
# Test invalid timeout in URI ignored and set to default.
|
# Test invalid timeout in URI ignored and set to default.
|
||||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
|
||||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||||
|
client.close()
|
||||||
|
|
||||||
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
|
||||||
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
self.assertAlmostEqual(30, client.options.server_selection_timeout)
|
||||||
|
|||||||
@ -97,26 +97,19 @@ class TestCollation(IntegrationTest):
|
|||||||
warn_context: Any
|
warn_context: Any
|
||||||
collation: Collation
|
collation: Collation
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_connection
|
@client_context.require_connection
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
cls.collation = Collation("en_US")
|
self.collation = Collation("en_US")
|
||||||
cls.warn_context = warnings.catch_warnings()
|
self.warn_context = warnings.catch_warnings()
|
||||||
cls.warn_context.__enter__()
|
self.warn_context.__enter__()
|
||||||
warnings.simplefilter("ignore", DeprecationWarning)
|
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self) -> None:
|
||||||
def _tearDown_class(cls):
|
self.warn_context.__exit__()
|
||||||
cls.warn_context.__exit__()
|
self.warn_context = None
|
||||||
cls.warn_context = None
|
|
||||||
cls.client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
||||||
|
|||||||
@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest):
|
|||||||
db: Database
|
db: Database
|
||||||
client: MongoClient
|
client: MongoClient
|
||||||
|
|
||||||
@classmethod
|
def setUp(self) -> None:
|
||||||
def _setup_class(cls):
|
super().setUp()
|
||||||
cls.client = MongoClient(connect=False)
|
self.client = self.simple_client(connect=False)
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.client.close()
|
|
||||||
|
|
||||||
def test_collection(self):
|
def test_collection(self):
|
||||||
self.assertRaises(TypeError, Collection, self.db, 5)
|
self.assertRaises(TypeError, Collection, self.db, 5)
|
||||||
@ -164,27 +160,14 @@ class TestCollectionNoConnect(UnitTest):
|
|||||||
class TestCollection(IntegrationTest):
|
class TestCollection(IntegrationTest):
|
||||||
w: int
|
w: int
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
super().setUpClass()
|
|
||||||
cls.w = client_context.w # type: ignore
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
if _IS_SYNC:
|
|
||||||
cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine]
|
|
||||||
else:
|
|
||||||
asyncio.run(cls.async_tearDownClass())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def async_tearDownClass(cls):
|
|
||||||
cls.db.drop_collection("test_large_limit")
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.db.test.drop()
|
super().setUp()
|
||||||
|
self.w = client_context.w # type: ignore
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.db.test.drop()
|
self.db.test.drop()
|
||||||
|
self.db.drop_collection("test_large_limit")
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def write_concern_collection(self):
|
def write_concern_collection(self):
|
||||||
@ -1010,7 +993,10 @@ class TestCollection(IntegrationTest):
|
|||||||
db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
db.test.insert_one({"y": 1}, bypass_document_validation=True)
|
||||||
db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True)
|
||||||
|
|
||||||
wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document")
|
def predicate():
|
||||||
|
return db_w0.test.find_one({"x": 1})
|
||||||
|
|
||||||
|
wait_until(predicate, "find w:0 replaced document")
|
||||||
|
|
||||||
def test_update_bypass_document_validation(self):
|
def test_update_bypass_document_validation(self):
|
||||||
db = self.db
|
db = self.db
|
||||||
|
|||||||
@ -19,7 +19,12 @@ import sys
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import IntegrationTest, client_context, unittest
|
from test import (
|
||||||
|
IntegrationTest,
|
||||||
|
client_context,
|
||||||
|
reset_client_context,
|
||||||
|
unittest,
|
||||||
|
)
|
||||||
from test.helpers import repl_set_step_down
|
from test.helpers import repl_set_step_down
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
CMAPListener,
|
CMAPListener,
|
||||||
@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
|
|||||||
listener: CMAPListener
|
listener: CMAPListener
|
||||||
coll: Collection
|
coll: Collection
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_replica_set
|
@client_context.require_replica_set
|
||||||
def _setup_class(cls):
|
def setUp(self):
|
||||||
super()._setup_class()
|
self.listener = CMAPListener()
|
||||||
cls.listener = CMAPListener()
|
self.client = self.rs_or_single_client(
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(
|
event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500
|
||||||
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure connections to all servers in replica set. This is to test
|
# Ensure connections to all servers in replica set. This is to test
|
||||||
# that the is_writable flag is properly updated for connections that
|
# that the is_writable flag is properly updated for connections that
|
||||||
# survive a replica set election.
|
# survive a replica set election.
|
||||||
ensure_all_connected(cls.client)
|
ensure_all_connected(self.client)
|
||||||
cls.listener.reset()
|
self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
||||||
|
self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
||||||
cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
|
|
||||||
cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.client.close()
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
# Note that all ops use same write-concern as self.db (majority).
|
# Note that all ops use same write-concern as self.db (majority).
|
||||||
self.db.drop_collection("step-down")
|
self.db.drop_collection("step-down")
|
||||||
self.db.create_collection("step-down")
|
self.db.create_collection("step-down")
|
||||||
|
|||||||
@ -56,6 +56,9 @@ class TestCreateEntities(IntegrationTest):
|
|||||||
self.assertGreater(len(final_entity_map["events1"]), 0)
|
self.assertGreater(len(final_entity_map["events1"]), 0)
|
||||||
for event in final_entity_map["events1"]:
|
for event in final_entity_map["events1"]:
|
||||||
self.assertIn("PoolCreatedEvent", event["name"])
|
self.assertIn("PoolCreatedEvent", event["name"])
|
||||||
|
if self.scenario_runner.mongos_clients:
|
||||||
|
for client in self.scenario_runner.mongos_clients:
|
||||||
|
client.close()
|
||||||
|
|
||||||
def test_store_all_others_as_entities(self):
|
def test_store_all_others_as_entities(self):
|
||||||
self.scenario_runner = UnifiedSpecTestMixinV1()
|
self.scenario_runner = UnifiedSpecTestMixinV1()
|
||||||
@ -122,6 +125,9 @@ class TestCreateEntities(IntegrationTest):
|
|||||||
self.assertEqual(entity_map["failures"], [])
|
self.assertEqual(entity_map["failures"], [])
|
||||||
self.assertEqual(entity_map["successes"], 2)
|
self.assertEqual(entity_map["successes"], 2)
|
||||||
self.assertEqual(entity_map["iterations"], 5)
|
self.assertEqual(entity_map["iterations"], 5)
|
||||||
|
if self.scenario_runner.mongos_clients:
|
||||||
|
for client in self.scenario_runner.mongos_clients:
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -1636,10 +1636,6 @@ class TestRawBatchCursor(IntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestRawBatchCommandCursor(IntegrationTest):
|
class TestRawBatchCommandCursor(IntegrationTest):
|
||||||
@classmethod
|
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
|
|
||||||
def test_aggregate_raw(self):
|
def test_aggregate_raw(self):
|
||||||
c = self.db.test
|
c = self.db.test
|
||||||
c.drop()
|
c.drop()
|
||||||
|
|||||||
@ -633,6 +633,7 @@ class TestTypeRegistry(unittest.TestCase):
|
|||||||
|
|
||||||
class TestCollectionWCustomType(IntegrationTest):
|
class TestCollectionWCustomType(IntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.db.test.drop()
|
self.db.test.drop()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
@ -754,6 +755,7 @@ class TestCollectionWCustomType(IntegrationTest):
|
|||||||
|
|
||||||
class TestGridFileCustomType(IntegrationTest):
|
class TestGridFileCustomType(IntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.db.drop_collection("fs.files")
|
self.db.drop_collection("fs.files")
|
||||||
self.db.drop_collection("fs.chunks")
|
self.db.drop_collection("fs.chunks")
|
||||||
|
|
||||||
@ -917,11 +919,10 @@ class ChangeStreamsWCustomTypesTestMixin:
|
|||||||
|
|
||||||
|
|
||||||
class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||||
@classmethod
|
|
||||||
@client_context.require_change_streams
|
@client_context.require_change_streams
|
||||||
def setUpClass(cls):
|
def setUp(self):
|
||||||
super().setUpClass()
|
super().setUp()
|
||||||
cls.db.test.delete_many({})
|
self.db.test.delete_many({})
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.input_target.drop()
|
self.input_target.drop()
|
||||||
@ -935,12 +936,11 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus
|
|||||||
|
|
||||||
|
|
||||||
class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||||
@classmethod
|
|
||||||
@client_context.require_version_min(4, 0, 0)
|
@client_context.require_version_min(4, 0, 0)
|
||||||
@client_context.require_change_streams
|
@client_context.require_change_streams
|
||||||
def setUpClass(cls):
|
def setUp(self):
|
||||||
super().setUpClass()
|
super().setUp()
|
||||||
cls.db.test.delete_many({})
|
self.db.test.delete_many({})
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.input_target.drop()
|
self.input_target.drop()
|
||||||
@ -954,12 +954,11 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto
|
|||||||
|
|
||||||
|
|
||||||
class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin):
|
||||||
@classmethod
|
|
||||||
@client_context.require_version_min(4, 0, 0)
|
@client_context.require_version_min(4, 0, 0)
|
||||||
@client_context.require_change_streams
|
@client_context.require_change_streams
|
||||||
def setUpClass(cls):
|
def setUp(self):
|
||||||
super().setUpClass()
|
super().setUp()
|
||||||
cls.db.test.delete_many({})
|
self.db.test.delete_many({})
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.input_target.drop()
|
self.input_target.drop()
|
||||||
|
|||||||
@ -709,6 +709,7 @@ class TestDatabase(IntegrationTest):
|
|||||||
|
|
||||||
class TestDatabaseAggregation(IntegrationTest):
|
class TestDatabaseAggregation(IntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.pipeline: List[Mapping[str, Any]] = [
|
self.pipeline: List[Mapping[str, Any]] = [
|
||||||
{"$listLocalSessions": {}},
|
{"$listLocalSessions": {}},
|
||||||
{"$limit": 1},
|
{"$limit": 1},
|
||||||
|
|||||||
@ -211,11 +211,10 @@ class TestClientOptions(PyMongoTestCase):
|
|||||||
class EncryptionIntegrationTest(IntegrationTest):
|
class EncryptionIntegrationTest(IntegrationTest):
|
||||||
"""Base class for encryption integration tests."""
|
"""Base class for encryption integration tests."""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||||
@client_context.require_version_min(4, 2, -1)
|
@client_context.require_version_min(4, 2, -1)
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
|
|
||||||
def assertEncrypted(self, val):
|
def assertEncrypted(self, val):
|
||||||
self.assertIsInstance(val, Binary)
|
self.assertIsInstance(val, Binary)
|
||||||
@ -430,10 +429,9 @@ class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestClientMaxWireVersion(IntegrationTest):
|
class TestClientMaxWireVersion(IntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||||
def _setup_class(cls):
|
def setUp(self):
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
|
|
||||||
@client_context.require_version_max(4, 0, 99)
|
@client_context.require_version_max(4, 0, 99)
|
||||||
def test_raise_max_wire_version_error(self):
|
def test_raise_max_wire_version_error(self):
|
||||||
@ -816,17 +814,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
|||||||
"local": None,
|
"local": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||||
"No environment credentials are set",
|
"No environment credentials are set",
|
||||||
)
|
)
|
||||||
def _setup_class(cls):
|
def setUp(self):
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||||
cls.client.db.coll.drop()
|
self.client.db.coll.drop()
|
||||||
cls.vault = create_key_vault(cls.client.keyvault.datakeys)
|
self.vault = create_key_vault(self.client.keyvault.datakeys)
|
||||||
|
|
||||||
# Configure the encrypted field via the local schema_map option.
|
# Configure the encrypted field via the local schema_map option.
|
||||||
schemas = {
|
schemas = {
|
||||||
@ -844,25 +841,22 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
opts = AutoEncryptionOpts(
|
opts = AutoEncryptionOpts(
|
||||||
cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS
|
self.KMS_PROVIDERS,
|
||||||
|
"keyvault.datakeys",
|
||||||
|
schema_map=schemas,
|
||||||
|
kms_tls_options=KMS_TLS_OPTS,
|
||||||
)
|
)
|
||||||
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
|
self.client_encrypted = self.rs_or_single_client(
|
||||||
auto_encryption_opts=opts, uuidRepresentation="standard"
|
auto_encryption_opts=opts, uuidRepresentation="standard"
|
||||||
)
|
)
|
||||||
cls.client_encryption = cls.unmanaged_create_client_encryption(
|
self.client_encryption = self.create_client_encryption(
|
||||||
cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.vault.drop()
|
|
||||||
cls.client.close()
|
|
||||||
cls.client_encrypted.close()
|
|
||||||
cls.client_encryption.close()
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.vault.drop()
|
||||||
|
|
||||||
def run_test(self, provider_name):
|
def run_test(self, provider_name):
|
||||||
# Create data key.
|
# Create data key.
|
||||||
master_key: Any = self.MASTER_KEYS[provider_name]
|
master_key: Any = self.MASTER_KEYS[provider_name]
|
||||||
@ -1007,10 +1001,9 @@ class TestViews(EncryptionIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestCorpus(EncryptionIntegrationTest):
|
class TestCorpus(EncryptionIntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
|
||||||
def _setup_class(cls):
|
def setUp(self):
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def kms_providers():
|
def kms_providers():
|
||||||
@ -1184,12 +1177,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
|||||||
client_encrypted: MongoClient
|
client_encrypted: MongoClient
|
||||||
listener: OvertCommandListener
|
listener: OvertCommandListener
|
||||||
|
|
||||||
@classmethod
|
def setUp(self):
|
||||||
def _setup_class(cls):
|
super().setUp()
|
||||||
super()._setup_class()
|
|
||||||
db = client_context.client.db
|
db = client_context.client.db
|
||||||
cls.coll = db.coll
|
self.coll = db.coll
|
||||||
cls.coll.drop()
|
self.coll.drop()
|
||||||
# Configure the encrypted 'db.coll' collection via jsonSchema.
|
# Configure the encrypted 'db.coll' collection via jsonSchema.
|
||||||
json_schema = json_data("limits", "limits-schema.json")
|
json_schema = json_data("limits", "limits-schema.json")
|
||||||
db.create_collection(
|
db.create_collection(
|
||||||
@ -1207,17 +1199,14 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
|||||||
coll.insert_one(json_data("limits", "limits-key.json"))
|
coll.insert_one(json_data("limits", "limits-key.json"))
|
||||||
|
|
||||||
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys")
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client_encrypted = cls.unmanaged_rs_or_single_client(
|
self.client_encrypted = self.rs_or_single_client(
|
||||||
auto_encryption_opts=opts, event_listeners=[cls.listener]
|
auto_encryption_opts=opts, event_listeners=[self.listener]
|
||||||
)
|
)
|
||||||
cls.coll_encrypted = cls.client_encrypted.db.coll
|
self.coll_encrypted = self.client_encrypted.db.coll
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self) -> None:
|
||||||
def _tearDown_class(cls):
|
self.coll_encrypted.drop()
|
||||||
cls.coll_encrypted.drop()
|
|
||||||
cls.client_encrypted.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def test_01_insert_succeeds_under_2MiB(self):
|
def test_01_insert_succeeds_under_2MiB(self):
|
||||||
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB}
|
||||||
@ -1241,7 +1230,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
|||||||
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
|
doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB}
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
self.assertEqual(
|
||||||
|
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||||
|
)
|
||||||
|
|
||||||
def test_04_bulk_batch_split(self):
|
def test_04_bulk_batch_split(self):
|
||||||
limits_doc = json_data("limits", "limits-doc.json")
|
limits_doc = json_data("limits", "limits-doc.json")
|
||||||
@ -1251,7 +1242,9 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
|||||||
doc2.update(limits_doc)
|
doc2.update(limits_doc)
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)])
|
||||||
self.assertEqual(self.listener.started_command_names(), ["insert", "insert"])
|
self.assertEqual(
|
||||||
|
len([c for c in self.listener.started_command_names() if c == "insert"]), 2
|
||||||
|
)
|
||||||
|
|
||||||
def test_05_insert_succeeds_just_under_16MiB(self):
|
def test_05_insert_succeeds_just_under_16MiB(self):
|
||||||
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
|
doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)}
|
||||||
@ -1281,15 +1274,12 @@ class TestBsonSizeBatches(EncryptionIntegrationTest):
|
|||||||
class TestCustomEndpoint(EncryptionIntegrationTest):
|
class TestCustomEndpoint(EncryptionIntegrationTest):
|
||||||
"""Prose tests for creating data keys with a custom endpoint."""
|
"""Prose tests for creating data keys with a custom endpoint."""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]),
|
||||||
"No environment credentials are set",
|
"No environment credentials are set",
|
||||||
)
|
)
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
kms_providers = {
|
kms_providers = {
|
||||||
"aws": AWS_CREDS,
|
"aws": AWS_CREDS,
|
||||||
"azure": AZURE_CREDS,
|
"azure": AZURE_CREDS,
|
||||||
@ -1318,10 +1308,6 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
|
|||||||
self._kmip_host_error = None
|
self._kmip_host_error = None
|
||||||
self._invalid_host_error = None
|
self._invalid_host_error = None
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.client_encryption.close()
|
|
||||||
self.client_encryption_invalid.close()
|
|
||||||
|
|
||||||
def run_test_expected_success(self, provider_name, master_key):
|
def run_test_expected_success(self, provider_name, master_key):
|
||||||
data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key)
|
data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key)
|
||||||
encrypted = self.client_encryption.encrypt(
|
encrypted = self.client_encryption.encrypt(
|
||||||
@ -1494,18 +1480,18 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
|||||||
KEYVAULT_COLL = "datakeys"
|
KEYVAULT_COLL = "datakeys"
|
||||||
client: MongoClient
|
client: MongoClient
|
||||||
|
|
||||||
def setUp(self):
|
def _setup(self):
|
||||||
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
|
keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL)
|
||||||
create_key_vault(keyvault, self.DEK)
|
create_key_vault(keyvault, self.DEK)
|
||||||
|
|
||||||
def _test_explicit(self, expectation):
|
def _test_explicit(self, expectation):
|
||||||
|
self._setup()
|
||||||
client_encryption = self.create_client_encryption(
|
client_encryption = self.create_client_encryption(
|
||||||
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
self.KMS_PROVIDER_MAP, # type: ignore[arg-type]
|
||||||
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]),
|
||||||
client_context.client,
|
client_context.client,
|
||||||
OPTS,
|
OPTS,
|
||||||
)
|
)
|
||||||
self.addCleanup(client_encryption.close)
|
|
||||||
|
|
||||||
ciphertext = client_encryption.encrypt(
|
ciphertext = client_encryption.encrypt(
|
||||||
"string0",
|
"string0",
|
||||||
@ -1517,6 +1503,7 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
|||||||
self.assertEqual(client_encryption.decrypt(ciphertext), "string0")
|
self.assertEqual(client_encryption.decrypt(ciphertext), "string0")
|
||||||
|
|
||||||
def _test_automatic(self, expectation_extjson, payload):
|
def _test_automatic(self, expectation_extjson, payload):
|
||||||
|
self._setup()
|
||||||
encrypted_db = "db"
|
encrypted_db = "db"
|
||||||
encrypted_coll = "coll"
|
encrypted_coll = "coll"
|
||||||
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL])
|
||||||
@ -1531,7 +1518,6 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
|||||||
client = self.rs_or_single_client(
|
client = self.rs_or_single_client(
|
||||||
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]
|
||||||
)
|
)
|
||||||
self.addCleanup(client.close)
|
|
||||||
|
|
||||||
coll = client.get_database(encrypted_db).get_collection(
|
coll = client.get_database(encrypted_db).get_collection(
|
||||||
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
|
encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority")
|
||||||
@ -1553,13 +1539,12 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
@unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set")
|
||||||
def _setup_class(cls):
|
def setUp(self):
|
||||||
cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS}
|
||||||
cls.DEK = json_data(BASE, "custom", "azure-dek.json")
|
self.DEK = json_data(BASE, "custom", "azure-dek.json")
|
||||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
|
|
||||||
def test_explicit(self):
|
def test_explicit(self):
|
||||||
return self._test_explicit(
|
return self._test_explicit(
|
||||||
@ -1579,13 +1564,12 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest
|
|||||||
|
|
||||||
|
|
||||||
class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
||||||
@classmethod
|
|
||||||
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
@unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set")
|
||||||
def _setup_class(cls):
|
def setUp(self):
|
||||||
cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS}
|
||||||
cls.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
self.DEK = json_data(BASE, "custom", "gcp-dek.json")
|
||||||
cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json")
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
|
|
||||||
def test_explicit(self):
|
def test_explicit(self):
|
||||||
return self._test_explicit(
|
return self._test_explicit(
|
||||||
@ -1607,6 +1591,7 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest):
|
|||||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
|
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests
|
||||||
class TestDeadlockProse(EncryptionIntegrationTest):
|
class TestDeadlockProse(EncryptionIntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.client_test = self.rs_or_single_client(
|
self.client_test = self.rs_or_single_client(
|
||||||
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard"
|
||||||
)
|
)
|
||||||
@ -1637,7 +1622,6 @@ class TestDeadlockProse(EncryptionIntegrationTest):
|
|||||||
self.ciphertext = client_encryption.encrypt(
|
self.ciphertext = client_encryption.encrypt(
|
||||||
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
|
"string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local"
|
||||||
)
|
)
|
||||||
client_encryption.close()
|
|
||||||
|
|
||||||
self.client_listener = OvertCommandListener()
|
self.client_listener = OvertCommandListener()
|
||||||
self.topology_listener = TopologyEventListener()
|
self.topology_listener = TopologyEventListener()
|
||||||
@ -1832,6 +1816,7 @@ class TestDeadlockProse(EncryptionIntegrationTest):
|
|||||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
|
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events
|
||||||
class TestDecryptProse(EncryptionIntegrationTest):
|
class TestDecryptProse(EncryptionIntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.client = client_context.client
|
self.client = client_context.client
|
||||||
self.client.db.drop_collection("decryption_events")
|
self.client.db.drop_collection("decryption_events")
|
||||||
create_key_vault(self.client.keyvault.datakeys)
|
create_key_vault(self.client.keyvault.datakeys)
|
||||||
@ -2267,6 +2252,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
|||||||
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
|
# https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames
|
||||||
class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
|
class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.client = client_context.client
|
self.client = client_context.client
|
||||||
create_key_vault(self.client.keyvault.datakeys)
|
create_key_vault(self.client.keyvault.datakeys)
|
||||||
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}}
|
||||||
@ -2608,8 +2594,6 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest):
|
|||||||
assert isinstance(res["encrypted_indexed"], Binary)
|
assert isinstance(res["encrypted_indexed"], Binary)
|
||||||
assert isinstance(res["encrypted_unindexed"], Binary)
|
assert isinstance(res["encrypted_unindexed"], Binary)
|
||||||
|
|
||||||
client_encryption.close()
|
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
|
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption
|
||||||
class TestRangeQueryProse(EncryptionIntegrationTest):
|
class TestRangeQueryProse(EncryptionIntegrationTest):
|
||||||
@ -3071,17 +3055,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest):
|
|||||||
mongocryptd_client: MongoClient
|
mongocryptd_client: MongoClient
|
||||||
MONGOCRYPTD_PORT = 27020
|
MONGOCRYPTD_PORT = 27020
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
@unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed")
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
start_mongocryptd(cls.MONGOCRYPTD_PORT)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
start_mongocryptd(self.MONGOCRYPTD_PORT)
|
||||||
|
|
||||||
self.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
self.mongocryptd_client = self.simple_client(
|
self.mongocryptd_client = self.simple_client(
|
||||||
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener]
|
||||||
|
|||||||
@ -33,19 +33,14 @@ from pymongo.write_concern import WriteConcern
|
|||||||
|
|
||||||
|
|
||||||
class TestSampleShellCommands(IntegrationTest):
|
class TestSampleShellCommands(IntegrationTest):
|
||||||
@classmethod
|
def setUp(self):
|
||||||
def setUpClass(cls):
|
super().setUp()
|
||||||
super().setUpClass()
|
self.db.inventory.drop()
|
||||||
# Run once before any tests run.
|
|
||||||
cls.db.inventory.drop()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
cls.client.drop_database("pymongo_test")
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# Run after every test.
|
# Run after every test.
|
||||||
self.db.inventory.drop()
|
self.db.inventory.drop()
|
||||||
|
self.client.drop_database("pymongo_test")
|
||||||
|
|
||||||
def test_first_three_examples(self):
|
def test_first_three_examples(self):
|
||||||
db = self.db
|
db = self.db
|
||||||
|
|||||||
@ -97,6 +97,7 @@ class TestGridFileNoConnect(UnitTest):
|
|||||||
|
|
||||||
class TestGridFile(IntegrationTest):
|
class TestGridFile(IntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|||||||
@ -75,9 +75,9 @@ class JustRead(threading.Thread):
|
|||||||
class TestGridfsNoConnect(unittest.TestCase):
|
class TestGridfsNoConnect(unittest.TestCase):
|
||||||
db: Database
|
db: Database
|
||||||
|
|
||||||
@classmethod
|
def setUp(self):
|
||||||
def setUpClass(cls):
|
super().setUp()
|
||||||
cls.db = MongoClient(connect=False).pymongo_test
|
self.db = MongoClient(connect=False).pymongo_test
|
||||||
|
|
||||||
def test_gridfs(self):
|
def test_gridfs(self):
|
||||||
self.assertRaises(TypeError, gridfs.GridFS, "foo")
|
self.assertRaises(TypeError, gridfs.GridFS, "foo")
|
||||||
@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest):
|
|||||||
fs: gridfs.GridFS
|
fs: gridfs.GridFS
|
||||||
alt: gridfs.GridFS
|
alt: gridfs.GridFS
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
super().setUpClass()
|
|
||||||
cls.fs = gridfs.GridFS(cls.db)
|
|
||||||
cls.alt = gridfs.GridFS(cls.db, "alt")
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.fs = gridfs.GridFS(self.db)
|
||||||
|
self.alt = gridfs.GridFS(self.db, "alt")
|
||||||
self.cleanup_colls(
|
self.cleanup_colls(
|
||||||
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
||||||
)
|
)
|
||||||
@ -509,10 +506,9 @@ class TestGridfs(IntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestGridfsReplicaSet(IntegrationTest):
|
class TestGridfsReplicaSet(IntegrationTest):
|
||||||
@classmethod
|
|
||||||
@client_context.require_secondaries_count(1)
|
@client_context.require_secondaries_count(1)
|
||||||
def setUpClass(cls):
|
def setUp(self):
|
||||||
super().setUpClass()
|
super().setUp()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
|
|||||||
@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest):
|
|||||||
fs: gridfs.GridFSBucket
|
fs: gridfs.GridFSBucket
|
||||||
alt: gridfs.GridFSBucket
|
alt: gridfs.GridFSBucket
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
super().setUpClass()
|
|
||||||
cls.fs = gridfs.GridFSBucket(cls.db)
|
|
||||||
cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt")
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.fs = gridfs.GridFSBucket(self.db)
|
||||||
|
self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt")
|
||||||
self.cleanup_colls(
|
self.cleanup_colls(
|
||||||
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
|
||||||
)
|
)
|
||||||
@ -479,10 +476,9 @@ class TestGridfs(IntegrationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TestGridfsBucketReplicaSet(IntegrationTest):
|
class TestGridfsBucketReplicaSet(IntegrationTest):
|
||||||
@classmethod
|
|
||||||
@client_context.require_secondaries_count(1)
|
@client_context.require_secondaries_count(1)
|
||||||
def setUpClass(cls):
|
def setUp(self):
|
||||||
super().setUpClass()
|
super().setUp()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from test.utils import (
|
|||||||
wait_until,
|
wait_until,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pymongo.synchronous.periodic_executor import _EXECUTORS
|
from pymongo.periodic_executor import _EXECUTORS
|
||||||
|
|
||||||
|
|
||||||
def unregistered(ref):
|
def unregistered(ref):
|
||||||
|
|||||||
@ -52,22 +52,14 @@ class TestCommandMonitoring(IntegrationTest):
|
|||||||
listener: EventListener
|
listener: EventListener
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@client_context.require_connection
|
def setUpClass(cls) -> None:
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
cls.listener = OvertCommandListener()
|
cls.listener = OvertCommandListener()
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(
|
|
||||||
event_listeners=[cls.listener], retryWrites=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@client_context.require_connection
|
||||||
def _tearDown_class(cls):
|
def setUp(self) -> None:
|
||||||
cls.client.close()
|
super().setUp()
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
super().tearDown()
|
self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False)
|
||||||
|
|
||||||
def test_started_simple(self):
|
def test_started_simple(self):
|
||||||
self.client.pymongo_test.command("ping")
|
self.client.pymongo_test.command("ping")
|
||||||
@ -1140,26 +1132,23 @@ class TestGlobalListener(IntegrationTest):
|
|||||||
saved_listeners: Any
|
saved_listeners: Any
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@client_context.require_connection
|
def setUpClass(cls) -> None:
|
||||||
def _setup_class(cls):
|
|
||||||
super()._setup_class()
|
|
||||||
cls.listener = OvertCommandListener()
|
cls.listener = OvertCommandListener()
|
||||||
# We plan to call register(), which internally modifies _LISTENERS.
|
# We plan to call register(), which internally modifies _LISTENERS.
|
||||||
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS)
|
||||||
monitoring.register(cls.listener)
|
monitoring.register(cls.listener)
|
||||||
cls.client = cls.unmanaged_single_client()
|
|
||||||
# Get one (authenticated) socket in the pool.
|
|
||||||
cls.client.pymongo_test.command("ping")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
monitoring._LISTENERS = cls.saved_listeners
|
|
||||||
cls.client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
|
@client_context.require_connection
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.listener.reset()
|
self.listener.reset()
|
||||||
|
self.client = self.single_client()
|
||||||
|
# Get one (authenticated) socket in the pool.
|
||||||
|
self.client.pymongo_test.command("ping")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
monitoring._LISTENERS = cls.saved_listeners
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
self.client.pymongo_test.command("ping")
|
self.client.pymongo_test.command("ping")
|
||||||
|
|||||||
@ -31,24 +31,16 @@ from pymongo.read_concern import ReadConcern
|
|||||||
class TestReadConcern(IntegrationTest):
|
class TestReadConcern(IntegrationTest):
|
||||||
listener: OvertCommandListener
|
listener: OvertCommandListener
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_connection
|
@client_context.require_connection
|
||||||
def setUpClass(cls):
|
def setUp(self):
|
||||||
super().setUpClass()
|
super().setUp()
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
client_context.client.pymongo_test.create_collection("coll")
|
client_context.client.pymongo_test.create_collection("coll")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
cls.client.close()
|
|
||||||
client_context.client.pymongo_test.drop_collection("coll")
|
|
||||||
super().tearDownClass()
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.listener.reset()
|
client_context.client.pymongo_test.drop_collection("coll")
|
||||||
super().tearDown()
|
|
||||||
|
|
||||||
def test_read_concern(self):
|
def test_read_concern(self):
|
||||||
rc = ReadConcern()
|
rc = ReadConcern()
|
||||||
|
|||||||
@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(IntegrationTest):
|
|||||||
RUN_ON_SERVERLESS = True
|
RUN_ON_SERVERLESS = True
|
||||||
deprecation_filter: DeprecationFilter
|
deprecation_filter: DeprecationFilter
|
||||||
|
|
||||||
@classmethod
|
def setUp(self) -> None:
|
||||||
def _setup_class(cls):
|
super().setUp()
|
||||||
super()._setup_class()
|
self.deprecation_filter = DeprecationFilter()
|
||||||
cls.deprecation_filter = DeprecationFilter()
|
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self) -> None:
|
||||||
def _tearDown_class(cls):
|
self.deprecation_filter.stop()
|
||||||
cls.deprecation_filter.stop()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
|
|
||||||
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest):
|
||||||
knobs: client_knobs
|
knobs: client_knobs
|
||||||
|
|
||||||
@classmethod
|
def setUp(self) -> None:
|
||||||
def _setup_class(cls):
|
super().setUp()
|
||||||
super()._setup_class()
|
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||||
cls.knobs.enable()
|
self.knobs.enable()
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True)
|
self.client = self.rs_or_single_client(retryWrites=True)
|
||||||
cls.db = cls.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
|
|
||||||
@classmethod
|
def tearDown(self) -> None:
|
||||||
def _tearDown_class(cls):
|
self.knobs.disable()
|
||||||
cls.knobs.disable()
|
|
||||||
cls.client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
@client_context.require_no_standalone
|
@client_context.require_no_standalone
|
||||||
def test_actionable_error_message(self):
|
def test_actionable_error_message(self):
|
||||||
@ -180,26 +173,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
|||||||
listener: OvertCommandListener
|
listener: OvertCommandListener
|
||||||
knobs: client_knobs
|
knobs: client_knobs
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_no_mmap
|
@client_context.require_no_mmap
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||||
cls.knobs.enable()
|
self.knobs.enable()
|
||||||
cls.listener = OvertCommandListener()
|
self.listener = OvertCommandListener()
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(
|
self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener])
|
||||||
retryWrites=True, event_listeners=[cls.listener]
|
self.db = self.client.pymongo_test
|
||||||
)
|
|
||||||
cls.db = cls.client.pymongo_test
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.knobs.disable()
|
|
||||||
cls.client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
if client_context.is_rs and client_context.test_commands_enabled:
|
if client_context.is_rs and client_context.test_commands_enabled:
|
||||||
self.client.admin.command(
|
self.client.admin.command(
|
||||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")])
|
||||||
@ -210,6 +193,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
|||||||
self.client.admin.command(
|
self.client.admin.command(
|
||||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||||
)
|
)
|
||||||
|
self.knobs.disable()
|
||||||
|
|
||||||
def test_supported_single_statement_no_retry(self):
|
def test_supported_single_statement_no_retry(self):
|
||||||
listener = OvertCommandListener()
|
listener = OvertCommandListener()
|
||||||
@ -438,13 +422,12 @@ class TestWriteConcernError(IntegrationTest):
|
|||||||
RUN_ON_SERVERLESS = True
|
RUN_ON_SERVERLESS = True
|
||||||
fail_insert: dict
|
fail_insert: dict
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_replica_set
|
@client_context.require_replica_set
|
||||||
@client_context.require_no_mmap
|
@client_context.require_no_mmap
|
||||||
@client_context.require_failCommand_fail_point
|
@client_context.require_failCommand_fail_point
|
||||||
def _setup_class(cls):
|
def setUp(self) -> None:
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
cls.fail_insert = {
|
self.fail_insert = {
|
||||||
"configureFailPoint": "failCommand",
|
"configureFailPoint": "failCommand",
|
||||||
"mode": {"times": 2},
|
"mode": {"times": 2},
|
||||||
"data": {
|
"data": {
|
||||||
|
|||||||
@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@client_context.require_failCommand_fail_point
|
@client_context.require_failCommand_fail_point
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super().setUpClass()
|
super().setUp(cls)
|
||||||
# Speed up the tests by decreasing the event publish frequency.
|
# Speed up the tests by decreasing the event publish frequency.
|
||||||
cls.knobs = client_knobs(
|
cls.knobs = client_knobs(
|
||||||
events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1
|
events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1
|
||||||
|
|||||||
@ -82,36 +82,27 @@ class TestSession(IntegrationTest):
|
|||||||
client2: MongoClient
|
client2: MongoClient
|
||||||
sensitive_commands: Set[str]
|
sensitive_commands: Set[str]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@client_context.require_sessions
|
@client_context.require_sessions
|
||||||
def _setup_class(cls):
|
def setUp(self):
|
||||||
super()._setup_class()
|
super().setUp()
|
||||||
# Create a second client so we can make sure clients cannot share
|
# Create a second client so we can make sure clients cannot share
|
||||||
# sessions.
|
# sessions.
|
||||||
cls.client2 = cls.unmanaged_rs_or_single_client()
|
self.client2 = self.rs_or_single_client()
|
||||||
|
|
||||||
# Redact no commands, so we can test user-admin commands have "lsid".
|
# Redact no commands, so we can test user-admin commands have "lsid".
|
||||||
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy()
|
||||||
monitoring._SENSITIVE_COMMANDS.clear()
|
monitoring._SENSITIVE_COMMANDS.clear()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
|
|
||||||
cls.client2.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.listener = SessionTestListener()
|
self.listener = SessionTestListener()
|
||||||
self.session_checker_listener = SessionTestListener()
|
self.session_checker_listener = SessionTestListener()
|
||||||
self.client = self.rs_or_single_client(
|
self.client = self.rs_or_single_client(
|
||||||
event_listeners=[self.listener, self.session_checker_listener]
|
event_listeners=[self.listener, self.session_checker_listener]
|
||||||
)
|
)
|
||||||
self.addCleanup(self.client.close)
|
|
||||||
self.db = self.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""All sessions used in the test must be returned to the pool."""
|
monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands)
|
||||||
self.client.drop_database("pymongo_test")
|
self.client.drop_database("pymongo_test")
|
||||||
used_lsids = self.initial_lsids.copy()
|
used_lsids = self.initial_lsids.copy()
|
||||||
for event in self.session_checker_listener.started_events:
|
for event in self.session_checker_listener.started_events:
|
||||||
@ -121,6 +112,8 @@ class TestSession(IntegrationTest):
|
|||||||
current_lsids = {s["id"] for s in session_ids(self.client)}
|
current_lsids = {s["id"] for s in session_ids(self.client)}
|
||||||
self.assertLessEqual(used_lsids, current_lsids)
|
self.assertLessEqual(used_lsids, current_lsids)
|
||||||
|
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
def _test_ops(self, client, *ops):
|
def _test_ops(self, client, *ops):
|
||||||
listener = client.options.event_listeners[0]
|
listener = client.options.event_listeners[0]
|
||||||
|
|
||||||
@ -832,18 +825,11 @@ class TestCausalConsistency(UnitTest):
|
|||||||
listener: SessionTestListener
|
listener: SessionTestListener
|
||||||
client: MongoClient
|
client: MongoClient
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _setup_class(cls):
|
|
||||||
cls.listener = SessionTestListener()
|
|
||||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.client.close()
|
|
||||||
|
|
||||||
@client_context.require_sessions
|
@client_context.require_sessions
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
self.listener = SessionTestListener()
|
||||||
|
self.client = self.rs_or_single_client(event_listeners=[self.listener])
|
||||||
|
|
||||||
@client_context.require_no_standalone
|
@client_context.require_no_standalone
|
||||||
def test_core(self):
|
def test_core(self):
|
||||||
|
|||||||
@ -105,6 +105,7 @@ class Update(threading.Thread):
|
|||||||
|
|
||||||
class TestThreads(IntegrationTest):
|
class TestThreads(IntegrationTest):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
self.db = self.client.pymongo_test
|
self.db = self.client.pymongo_test
|
||||||
|
|
||||||
def test_threading(self):
|
def test_threading(self):
|
||||||
|
|||||||
@ -395,19 +395,12 @@ class PatchSessionTimeout:
|
|||||||
|
|
||||||
|
|
||||||
class TestTransactionsConvenientAPI(TransactionsBase):
|
class TestTransactionsConvenientAPI(TransactionsBase):
|
||||||
@classmethod
|
def setUp(self) -> None:
|
||||||
def _setup_class(cls):
|
super().setUp()
|
||||||
super()._setup_class()
|
self.mongos_clients = []
|
||||||
cls.mongos_clients = []
|
|
||||||
if client_context.supports_transactions():
|
if client_context.supports_transactions():
|
||||||
for address in client_context.mongoses:
|
for address in client_context.mongoses:
|
||||||
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
|
self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
for client in cls.mongos_clients:
|
|
||||||
client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def _set_fail_point(self, client, command_args):
|
def _set_fail_point(self, client, command_args):
|
||||||
cmd = {"configureFailPoint": "failCommand"}
|
cmd = {"configureFailPoint": "failCommand"}
|
||||||
|
|||||||
@ -114,10 +114,9 @@ class TestMypyFails(unittest.TestCase):
|
|||||||
class TestPymongo(IntegrationTest):
|
class TestPymongo(IntegrationTest):
|
||||||
coll: Collection
|
coll: Collection
|
||||||
|
|
||||||
@classmethod
|
def setUp(self):
|
||||||
def setUpClass(cls):
|
super().setUp()
|
||||||
super().setUpClass()
|
self.coll = self.client.test.test
|
||||||
cls.coll = cls.client.test.test
|
|
||||||
|
|
||||||
def test_insert_find(self) -> None:
|
def test_insert_find(self) -> None:
|
||||||
doc = {"my": "doc"}
|
doc = {"my": "doc"}
|
||||||
|
|||||||
@ -304,7 +304,6 @@ class EntityMapUtil:
|
|||||||
kwargs["h"] = uri
|
kwargs["h"] = uri
|
||||||
client = self.test.rs_or_single_client(**kwargs)
|
client = self.test.rs_or_single_client(**kwargs)
|
||||||
self[spec["id"]] = client
|
self[spec["id"]] = client
|
||||||
self.test.addCleanup(client.close)
|
|
||||||
return
|
return
|
||||||
elif entity_type == "database":
|
elif entity_type == "database":
|
||||||
client = self[spec["client"]]
|
client = self[spec["client"]]
|
||||||
@ -479,31 +478,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
|||||||
db.create_collection(coll_name, write_concern=wc, **opts)
|
db.create_collection(coll_name, write_concern=wc, **opts)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _setup_class(cls):
|
def setUpClass(cls) -> None:
|
||||||
# super call creates internal client cls.client
|
|
||||||
super()._setup_class()
|
|
||||||
# process file-level runOnRequirements
|
|
||||||
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
|
|
||||||
if not cls.should_run_on(run_on_spec):
|
|
||||||
raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
|
|
||||||
|
|
||||||
# add any special-casing for skipping tests here
|
|
||||||
if client_context.storage_engine == "mmapv1":
|
|
||||||
if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
|
|
||||||
cls.TEST_PATH
|
|
||||||
):
|
|
||||||
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
|
||||||
|
|
||||||
# Handle mongos_clients for transactions tests.
|
|
||||||
cls.mongos_clients = []
|
|
||||||
if (
|
|
||||||
client_context.supports_transactions()
|
|
||||||
and not client_context.load_balancer
|
|
||||||
and not client_context.serverless
|
|
||||||
):
|
|
||||||
for address in client_context.mongoses:
|
|
||||||
cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
|
|
||||||
|
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(
|
cls.knobs = client_knobs(
|
||||||
heartbeat_frequency=0.1,
|
heartbeat_frequency=0.1,
|
||||||
@ -514,17 +489,36 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
|||||||
cls.knobs.enable()
|
cls.knobs.enable()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _tearDown_class(cls):
|
def tearDownClass(cls) -> None:
|
||||||
cls.knobs.disable()
|
cls.knobs.disable()
|
||||||
for client in cls.mongos_clients:
|
|
||||||
client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
# super call creates internal client cls.client
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
# process file-level runOnRequirements
|
||||||
|
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
|
||||||
|
if not self.should_run_on(run_on_spec):
|
||||||
|
raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied")
|
||||||
|
|
||||||
|
# add any special-casing for skipping tests here
|
||||||
|
if client_context.storage_engine == "mmapv1":
|
||||||
|
if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str(
|
||||||
|
self.TEST_PATH
|
||||||
|
):
|
||||||
|
raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
|
||||||
|
|
||||||
|
# Handle mongos_clients for transactions tests.
|
||||||
|
self.mongos_clients = []
|
||||||
|
if (
|
||||||
|
client_context.supports_transactions()
|
||||||
|
and not client_context.load_balancer
|
||||||
|
and not client_context.serverless
|
||||||
|
):
|
||||||
|
for address in client_context.mongoses:
|
||||||
|
self.mongos_clients.append(self.single_client("{}:{}".format(*address)))
|
||||||
|
|
||||||
# process schemaVersion
|
# process schemaVersion
|
||||||
# note: we check major schema version during class generation
|
# note: we check major schema version during class generation
|
||||||
# note: we do this here because we cannot run assertions in setUpClass
|
|
||||||
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
|
version = Version.from_string(self.TEST_SPEC["schemaVersion"])
|
||||||
self.assertLessEqual(
|
self.assertLessEqual(
|
||||||
version,
|
version,
|
||||||
@ -1026,7 +1020,6 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
client = self.single_client("{}:{}".format(*session._pinned_address))
|
client = self.single_client("{}:{}".format(*session._pinned_address))
|
||||||
self.addCleanup(client.close)
|
|
||||||
self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
self.__set_fail_point(client=client, command_args=spec["failPoint"])
|
||||||
|
|
||||||
def _testOperation_createEntities(self, spec):
|
def _testOperation_createEntities(self, spec):
|
||||||
|
|||||||
@ -99,6 +99,12 @@ class BaseListener:
|
|||||||
"""Wait for a number of events to be published, or fail."""
|
"""Wait for a number of events to be published, or fail."""
|
||||||
wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)")
|
wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)")
|
||||||
|
|
||||||
|
async def async_wait_for_event(self, event, count):
|
||||||
|
"""Wait for a number of events to be published, or fail."""
|
||||||
|
await async_wait_until(
|
||||||
|
lambda: self.event_count(event) >= count, f"find {count} {event} event(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CMAPListener(BaseListener, monitoring.ConnectionPoolListener):
|
class CMAPListener(BaseListener, monitoring.ConnectionPoolListener):
|
||||||
def connection_created(self, event):
|
def connection_created(self, event):
|
||||||
@ -644,7 +650,10 @@ async def async_wait_until(predicate, success_description, timeout=10):
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
interval = min(float(timeout) / 100, 0.1)
|
interval = min(float(timeout) / 100, 0.1)
|
||||||
while True:
|
while True:
|
||||||
retval = await predicate()
|
if iscoroutinefunction(predicate):
|
||||||
|
retval = await predicate()
|
||||||
|
else:
|
||||||
|
retval = predicate()
|
||||||
if retval:
|
if retval:
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
|||||||
@ -249,30 +249,22 @@ class SpecRunner(IntegrationTest):
|
|||||||
knobs: client_knobs
|
knobs: client_knobs
|
||||||
listener: EventListener
|
listener: EventListener
|
||||||
|
|
||||||
@classmethod
|
def setUp(self) -> None:
|
||||||
def _setup_class(cls):
|
super().setUp()
|
||||||
super()._setup_class()
|
self.mongos_clients = []
|
||||||
cls.mongos_clients = []
|
|
||||||
|
|
||||||
# Speed up the tests by decreasing the heartbeat frequency.
|
# Speed up the tests by decreasing the heartbeat frequency.
|
||||||
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
|
||||||
cls.knobs.enable()
|
self.knobs.enable()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _tearDown_class(cls):
|
|
||||||
cls.knobs.disable()
|
|
||||||
for client in cls.mongos_clients:
|
|
||||||
client.close()
|
|
||||||
super()._tearDown_class()
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
super().setUp()
|
|
||||||
self.targets = {}
|
self.targets = {}
|
||||||
self.listener = None # type: ignore
|
self.listener = None # type: ignore
|
||||||
self.pool_listener = None
|
self.pool_listener = None
|
||||||
self.server_listener = None
|
self.server_listener = None
|
||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.knobs.disable()
|
||||||
|
|
||||||
def _set_fail_point(self, client, command_args):
|
def _set_fail_point(self, client, command_args):
|
||||||
cmd = SON([("configureFailPoint", "failCommand")])
|
cmd = SON([("configureFailPoint", "failCommand")])
|
||||||
cmd.update(command_args)
|
cmd.update(command_args)
|
||||||
@ -697,8 +689,6 @@ class SpecRunner(IntegrationTest):
|
|||||||
self.listener = listener
|
self.listener = listener
|
||||||
self.pool_listener = pool_listener
|
self.pool_listener = pool_listener
|
||||||
self.server_listener = server_listener
|
self.server_listener = server_listener
|
||||||
# Close the client explicitly to avoid having too many threads open.
|
|
||||||
self.addCleanup(client.close)
|
|
||||||
|
|
||||||
# Create session0 and session1.
|
# Create session0 and session1.
|
||||||
sessions = {}
|
sessions = {}
|
||||||
|
|||||||
@ -110,6 +110,13 @@ replacements = {
|
|||||||
"async_set_fail_point": "set_fail_point",
|
"async_set_fail_point": "set_fail_point",
|
||||||
"async_ensure_all_connected": "ensure_all_connected",
|
"async_ensure_all_connected": "ensure_all_connected",
|
||||||
"async_repl_set_step_down": "repl_set_step_down",
|
"async_repl_set_step_down": "repl_set_step_down",
|
||||||
|
"AsyncPeriodicExecutor": "PeriodicExecutor",
|
||||||
|
"async_wait_for_event": "wait_for_event",
|
||||||
|
"pymongo_server_monitor_task": "pymongo_server_monitor_thread",
|
||||||
|
"pymongo_server_rtt_task": "pymongo_server_rtt_thread",
|
||||||
|
"_async_create_lock": "_create_lock",
|
||||||
|
"_async_create_condition": "_create_condition",
|
||||||
|
"_async_cond_wait": "_cond_wait",
|
||||||
}
|
}
|
||||||
|
|
||||||
docstring_replacements: dict[tuple[str, str], str] = {
|
docstring_replacements: dict[tuple[str, str], str] = {
|
||||||
@ -130,8 +137,6 @@ docstring_removals: set[str] = {
|
|||||||
".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release."
|
".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release."
|
||||||
}
|
}
|
||||||
|
|
||||||
type_replacements = {"_Condition": "threading.Condition"}
|
|
||||||
|
|
||||||
import_replacements = {"test.synchronous": "test"}
|
import_replacements = {"test.synchronous": "test"}
|
||||||
|
|
||||||
_pymongo_base = "./pymongo/asynchronous/"
|
_pymongo_base = "./pymongo/asynchronous/"
|
||||||
@ -234,8 +239,6 @@ def process_files(files: list[str]) -> None:
|
|||||||
lines = translate_async_sleeps(lines)
|
lines = translate_async_sleeps(lines)
|
||||||
if file in docstring_translate_files:
|
if file in docstring_translate_files:
|
||||||
lines = translate_docstrings(lines)
|
lines = translate_docstrings(lines)
|
||||||
translate_locks(lines)
|
|
||||||
translate_types(lines)
|
|
||||||
if file in sync_test_files:
|
if file in sync_test_files:
|
||||||
translate_imports(lines)
|
translate_imports(lines)
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
@ -269,34 +272,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]:
|
|||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
def translate_locks(lines: list[str]) -> list[str]:
|
|
||||||
lock_lines = [line for line in lines if "_Lock(" in line]
|
|
||||||
cond_lines = [line for line in lines if "_Condition(" in line]
|
|
||||||
for line in lock_lines:
|
|
||||||
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
|
|
||||||
if res:
|
|
||||||
old = res[0]
|
|
||||||
index = lines.index(line)
|
|
||||||
lines[index] = line.replace(old, res[1])
|
|
||||||
for line in cond_lines:
|
|
||||||
res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line)
|
|
||||||
if res:
|
|
||||||
old = res[0]
|
|
||||||
index = lines.index(line)
|
|
||||||
lines[index] = line.replace(old, res[1])
|
|
||||||
|
|
||||||
return lines
|
|
||||||
|
|
||||||
|
|
||||||
def translate_types(lines: list[str]) -> list[str]:
|
|
||||||
for k, v in type_replacements.items():
|
|
||||||
matches = [line for line in lines if k in line and "import" not in line]
|
|
||||||
for line in matches:
|
|
||||||
index = lines.index(line)
|
|
||||||
lines[index] = line.replace(k, v)
|
|
||||||
return lines
|
|
||||||
|
|
||||||
|
|
||||||
def translate_imports(lines: list[str]) -> list[str]:
|
def translate_imports(lines: list[str]) -> list[str]:
|
||||||
for k, v in import_replacements.items():
|
for k, v in import_replacements.items():
|
||||||
matches = [line for line in lines if k in line and "import" in line]
|
matches = [line for line in lines if k in line and "import" in line]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user