Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
1418c90be2
@ -1324,7 +1324,7 @@ def decode_iter(
|
||||
elements = data[position : position + obj_size]
|
||||
position += obj_size
|
||||
|
||||
yield _bson_to_dict(elements, opts) # type:ignore[misc, type-var]
|
||||
yield _bson_to_dict(elements, opts) # type:ignore[misc]
|
||||
|
||||
|
||||
@overload
|
||||
@ -1370,7 +1370,7 @@ def decode_file_iter(
|
||||
raise InvalidBSON("cut off in middle of objsize")
|
||||
obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4
|
||||
elements = size_data + file_obj.read(max(0, obj_size))
|
||||
yield _bson_to_dict(elements, opts) # type:ignore[type-var, arg-type, misc]
|
||||
yield _bson_to_dict(elements, opts) # type:ignore[arg-type, misc]
|
||||
|
||||
|
||||
def is_valid(bson: bytes) -> bool:
|
||||
|
||||
@ -223,7 +223,7 @@ class Decimal128:
|
||||
"from list or tuple. Must have exactly 2 "
|
||||
"elements."
|
||||
)
|
||||
self.__high, self.__low = value # type: ignore
|
||||
self.__high, self.__low = value
|
||||
else:
|
||||
raise TypeError(f"Cannot convert {value!r} to Decimal128")
|
||||
|
||||
|
||||
@ -324,7 +324,7 @@ class JSONOptions(_BASE_CLASS):
|
||||
"JSONOptions.datetime_representation must be one of LEGACY, "
|
||||
"NUMBERLONG, or ISO8601 from DatetimeRepresentation."
|
||||
)
|
||||
self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) # type:ignore[arg-type]
|
||||
self = cast(JSONOptions, super().__new__(cls, *args, **kwargs))
|
||||
if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL):
|
||||
raise ValueError(
|
||||
"JSONOptions.json_mode must be one of LEGACY, RELAXED, "
|
||||
|
||||
@ -68,7 +68,7 @@ class SON(Dict[_Key, _Value]):
|
||||
self.update(kwargs)
|
||||
|
||||
def __new__(cls: Type[SON[_Key, _Value]], *args: Any, **kwargs: Any) -> SON[_Key, _Value]:
|
||||
instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var]
|
||||
instance = super().__new__(cls, *args, **kwargs)
|
||||
instance.__keys = []
|
||||
return instance
|
||||
|
||||
|
||||
@ -13,8 +13,9 @@ features = ["docs","test"]
|
||||
test = "sphinx-build -E -b doctest doc ./doc/_build/doctest"
|
||||
|
||||
[envs.typing]
|
||||
features = ["encryption", "ocsp", "zstd", "aws"]
|
||||
dependencies = ["mypy==1.2.0","pyright==1.1.290", "certifi", "typing_extensions"]
|
||||
pre-install-commands = [
|
||||
"pip install -q -r requirements/typing.txt",
|
||||
]
|
||||
[envs.typing.scripts]
|
||||
check-mypy = [
|
||||
"mypy --install-types --non-interactive bson gridfs tools pymongo",
|
||||
|
||||
@ -88,7 +88,7 @@ TEXT = "text"
|
||||
|
||||
from pymongo import _csot
|
||||
from pymongo._version import __version__, get_version_string, version_tuple
|
||||
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
|
||||
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c
|
||||
from pymongo.cursor import CursorType
|
||||
from pymongo.operations import (
|
||||
DeleteMany,
|
||||
@ -116,16 +116,6 @@ version = __version__
|
||||
"""Current version of PyMongo."""
|
||||
|
||||
|
||||
def has_c() -> bool:
|
||||
"""Is the C extension installed?"""
|
||||
try:
|
||||
from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def timeout(seconds: Optional[float]) -> ContextManager[None]:
|
||||
"""**(Provisional)** Apply the given timeout for a block of operations.
|
||||
|
||||
|
||||
@ -75,14 +75,13 @@ class _TimeoutContext(AbstractContextManager):
|
||||
self._timeout = timeout
|
||||
self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None
|
||||
|
||||
def __enter__(self) -> _TimeoutContext:
|
||||
def __enter__(self) -> None:
|
||||
timeout_token = TIMEOUT.set(self._timeout)
|
||||
prev_deadline = DEADLINE.get()
|
||||
next_deadline = time.monotonic() + self._timeout if self._timeout else float("inf")
|
||||
deadline_token = DEADLINE.set(min(prev_deadline, next_deadline))
|
||||
rtt_token = RTT.set(0.0)
|
||||
self._tokens = (timeout_token, deadline_token, rtt_token)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
if self._tokens:
|
||||
|
||||
@ -35,6 +35,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
|
||||
@ -332,13 +333,33 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
return self._database
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: None = None,
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> AsyncCollection[_DocumentType]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: bson.CodecOptions[_DocumentTypeArg],
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> AsyncCollection[_DocumentTypeArg]:
|
||||
...
|
||||
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
) -> AsyncCollection[_DocumentType]:
|
||||
) -> AsyncCollection[_DocumentType] | AsyncCollection[_DocumentTypeArg]:
|
||||
"""Get a clone of this collection changing the specified settings.
|
||||
|
||||
>>> coll1.read_preference
|
||||
|
||||
@ -146,13 +146,33 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
"""The name of this :class:`AsyncDatabase`."""
|
||||
return self._name
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: None = None,
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> AsyncDatabase[_DocumentType]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: bson.CodecOptions[_DocumentTypeArg],
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> AsyncDatabase[_DocumentTypeArg]:
|
||||
...
|
||||
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
) -> AsyncDatabase[_DocumentType]:
|
||||
) -> AsyncDatabase[_DocumentType] | AsyncDatabase[_DocumentTypeArg]:
|
||||
"""Get a clone of this database changing the specified settings.
|
||||
|
||||
>>> db1.read_preference
|
||||
|
||||
@ -913,7 +913,7 @@ async def _configured_socket(
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
@ -992,7 +992,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
self.lock = _ALock(_create_lock())
|
||||
_lock = _create_lock()
|
||||
self.lock = _ALock(_lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1018,7 +1019,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
|
||||
self.size_cond = _ACondition(threading.Condition(_lock))
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1026,7 +1027,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
|
||||
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
|
||||
@ -170,8 +170,9 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._lock = _ALock(_create_lock())
|
||||
self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type]
|
||||
_lock = _create_lock()
|
||||
self._lock = _ALock(_lock)
|
||||
self._condition = _ACondition(self._settings.condition_class(_lock))
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
|
||||
@ -850,7 +850,7 @@ def get_validated_options(
|
||||
return x
|
||||
|
||||
def get_setter_key(x: str) -> str:
|
||||
return options.cased_key(x) # type: ignore[attr-defined]
|
||||
return options.cased_key(x)
|
||||
|
||||
else:
|
||||
validated_options = {}
|
||||
@ -1060,3 +1060,13 @@ class _CaseInsensitiveDictionary(MutableMapping[str, Any]):
|
||||
|
||||
def cased_key(self, key: str) -> Any:
|
||||
return self.__casedkeys[key.lower()]
|
||||
|
||||
|
||||
def has_c() -> bool:
|
||||
"""Is the C extension installed?"""
|
||||
try:
|
||||
from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@ -26,7 +26,7 @@ _NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
|
||||
|
||||
def _have_snappy() -> bool:
|
||||
try:
|
||||
import snappy # type:ignore[import] # noqa: F401
|
||||
import snappy # type:ignore[import-not-found] # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
|
||||
@ -21,7 +21,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
try:
|
||||
import pymongocrypt # type:ignore[import] # noqa: F401
|
||||
import pymongocrypt # type:ignore[import-untyped] # noqa: F401
|
||||
|
||||
# Check for pymongocrypt>=1.10.
|
||||
from pymongocrypt import synchronous as _ # noqa: F401
|
||||
|
||||
145
pymongo/lock.py
145
pymongo/lock.py
@ -14,17 +14,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
|
||||
|
||||
# References to instances of _create_lock
|
||||
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _create_lock() -> threading.Lock:
|
||||
"""Represents a lock that is tracked upon instantiation using a WeakSet and
|
||||
@ -43,7 +46,14 @@ def _release_locks() -> None:
|
||||
lock.release()
|
||||
|
||||
|
||||
# Needed only for synchro.py compat.
|
||||
def _Lock(lock: threading.Lock) -> threading.Lock:
|
||||
return lock
|
||||
|
||||
|
||||
class _ALock:
|
||||
__slots__ = ("_lock",)
|
||||
|
||||
def __init__(self, lock: threading.Lock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
@ -81,9 +91,18 @@ class _ALock:
|
||||
self.release()
|
||||
|
||||
|
||||
def _safe_set_result(fut: asyncio.Future) -> None:
|
||||
# Ensure the future hasn't been cancelled before calling set_result.
|
||||
if not fut.done():
|
||||
fut.set_result(False)
|
||||
|
||||
|
||||
class _ACondition:
|
||||
__slots__ = ("_condition", "_waiters")
|
||||
|
||||
def __init__(self, condition: threading.Condition) -> None:
|
||||
self._condition = condition
|
||||
self._waiters: collections.deque = collections.deque()
|
||||
|
||||
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
|
||||
if timeout > 0:
|
||||
@ -99,30 +118,116 @@ class _ACondition:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
if timeout is not None:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
notified = self._condition.wait(0.001)
|
||||
if notified:
|
||||
return True
|
||||
if timeout is not None and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
"""Wait until notified.
|
||||
|
||||
async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
|
||||
if timeout is not None:
|
||||
tstart = time.monotonic()
|
||||
while True:
|
||||
notified = self._condition.wait_for(predicate, 0.001)
|
||||
if notified:
|
||||
return True
|
||||
if timeout is not None and (time.monotonic() - tstart) > timeout:
|
||||
return False
|
||||
If the calling task has not acquired the lock when this
|
||||
method is called, a RuntimeError is raised.
|
||||
|
||||
This method releases the underlying lock, and then blocks
|
||||
until it is awakened by a notify() or notify_all() call for
|
||||
the same condition variable in another task. Once
|
||||
awakened, it re-acquires the lock and returns True.
|
||||
|
||||
This method may return spuriously,
|
||||
which is why the caller should always
|
||||
re-check the state and be prepared to wait() again.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
fut = loop.create_future()
|
||||
self._waiters.append((loop, fut))
|
||||
self.release()
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(fut, timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False # Return false on timeout for sync pool compat.
|
||||
finally:
|
||||
# Must re-acquire lock even if wait is cancelled.
|
||||
# We only catch CancelledError here, since we don't want any
|
||||
# other (fatal) errors with the future to cause us to spin.
|
||||
err = None
|
||||
while True:
|
||||
try:
|
||||
await self.acquire()
|
||||
break
|
||||
except asyncio.exceptions.CancelledError as e:
|
||||
err = e
|
||||
|
||||
self._waiters.remove((loop, fut))
|
||||
if err is not None:
|
||||
try:
|
||||
raise err # Re-raise most recent exception instance.
|
||||
finally:
|
||||
err = None # Break reference cycles.
|
||||
except BaseException:
|
||||
# Any error raised out of here _may_ have occurred after this Task
|
||||
# believed to have been successfully notified.
|
||||
# Make sure to notify another Task instead. This may result
|
||||
# in a "spurious wakeup", which is allowed as part of the
|
||||
# Condition Variable protocol.
|
||||
self.notify(1)
|
||||
raise
|
||||
|
||||
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
|
||||
"""Wait until a predicate becomes true.
|
||||
|
||||
The predicate should be a callable whose result will be
|
||||
interpreted as a boolean value. The method will repeatedly
|
||||
wait() until it evaluates to true. The final predicate value is
|
||||
the return value.
|
||||
"""
|
||||
result = predicate()
|
||||
while not result:
|
||||
await self.wait()
|
||||
result = predicate()
|
||||
return result
|
||||
|
||||
def notify(self, n: int = 1) -> None:
|
||||
self._condition.notify(n)
|
||||
"""By default, wake up one coroutine waiting on this condition, if any.
|
||||
If the calling coroutine has not acquired the lock when this method
|
||||
is called, a RuntimeError is raised.
|
||||
|
||||
This method wakes up at most n of the coroutines waiting for the
|
||||
condition variable; it is a no-op if no coroutines are waiting.
|
||||
|
||||
Note: an awakened coroutine does not actually return from its
|
||||
wait() call until it can reacquire the lock. Since notify() does
|
||||
not release the lock, its caller should.
|
||||
"""
|
||||
idx = 0
|
||||
to_remove = []
|
||||
for loop, fut in self._waiters:
|
||||
if idx >= n:
|
||||
break
|
||||
|
||||
if fut.done():
|
||||
continue
|
||||
|
||||
try:
|
||||
loop.call_soon_threadsafe(_safe_set_result, fut)
|
||||
except RuntimeError:
|
||||
# Loop was closed, ignore.
|
||||
to_remove.append((loop, fut))
|
||||
continue
|
||||
|
||||
idx += 1
|
||||
|
||||
for waiter in to_remove:
|
||||
self._waiters.remove(waiter)
|
||||
|
||||
def notify_all(self) -> None:
|
||||
self._condition.notify_all()
|
||||
"""Wake up all threads waiting on this condition. This method acts
|
||||
like notify(), but wakes up all waiting threads instead of one. If the
|
||||
calling thread has not acquired the lock when this method is called,
|
||||
a RuntimeError is raised.
|
||||
"""
|
||||
self.notify(len(self._waiters))
|
||||
|
||||
def locked(self) -> bool:
|
||||
"""Only needed for tests in test_locks."""
|
||||
return self._condition._lock.locked() # type: ignore[attr-defined]
|
||||
|
||||
def release(self) -> None:
|
||||
self._condition.release()
|
||||
|
||||
@ -33,6 +33,7 @@ from pymongo.common import (
|
||||
MAX_POOL_SIZE,
|
||||
MIN_POOL_SIZE,
|
||||
WAIT_QUEUE_TIMEOUT,
|
||||
has_c,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -363,6 +364,11 @@ class PoolOptions:
|
||||
# },
|
||||
# 'platform': 'CPython 3.8.0|MyPlatform'
|
||||
# }
|
||||
if has_c():
|
||||
self.__metadata["driver"]["name"] = "{}|{}".format(
|
||||
self.__metadata["driver"]["name"],
|
||||
"c",
|
||||
)
|
||||
if not is_sync:
|
||||
self.__metadata["driver"]["name"] = "{}|{}".format(
|
||||
self.__metadata["driver"]["name"],
|
||||
|
||||
@ -34,6 +34,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
|
||||
@ -333,13 +334,33 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
return self._database
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: None = None,
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> Collection[_DocumentType]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: bson.CodecOptions[_DocumentTypeArg],
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> Collection[_DocumentTypeArg]:
|
||||
...
|
||||
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
) -> Collection[_DocumentType]:
|
||||
) -> Collection[_DocumentType] | Collection[_DocumentTypeArg]:
|
||||
"""Get a clone of this collection changing the specified settings.
|
||||
|
||||
>>> coll1.read_preference
|
||||
|
||||
@ -146,13 +146,33 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
"""The name of this :class:`Database`."""
|
||||
return self._name
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: None = None,
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> Database[_DocumentType]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: bson.CodecOptions[_DocumentTypeArg],
|
||||
read_preference: Optional[_ServerMode] = ...,
|
||||
write_concern: Optional[WriteConcern] = ...,
|
||||
read_concern: Optional[ReadConcern] = ...,
|
||||
) -> Database[_DocumentTypeArg]:
|
||||
...
|
||||
|
||||
def with_options(
|
||||
self,
|
||||
codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
) -> Database[_DocumentType]:
|
||||
) -> Database[_DocumentType] | Database[_DocumentTypeArg]:
|
||||
"""Get a clone of this database changing the specified settings.
|
||||
|
||||
>>> db1.read_preference
|
||||
|
||||
@ -62,7 +62,7 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.logger import (
|
||||
_CONNECTION_LOGGER,
|
||||
_ConnectionStatusMessage,
|
||||
@ -909,7 +909,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
@ -988,7 +988,8 @@ class Pool:
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
self.lock = _create_lock()
|
||||
_lock = _create_lock()
|
||||
self.lock = _Lock(_lock)
|
||||
self.active_sockets = 0
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
@ -1014,7 +1015,7 @@ class Pool:
|
||||
# The first portion of the wait queue.
|
||||
# Enforces: maxPoolSize
|
||||
# Also used for: clearing the wait queue
|
||||
self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type]
|
||||
self.size_cond = threading.Condition(_lock)
|
||||
self.requests = 0
|
||||
self.max_pool_size = self.opts.max_pool_size
|
||||
if not self.max_pool_size:
|
||||
@ -1022,7 +1023,7 @@ class Pool:
|
||||
# The second portion of the wait queue.
|
||||
# Enforces: maxConnecting
|
||||
# Also used for: clearing the wait queue
|
||||
self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type]
|
||||
self._max_connecting_cond = threading.Condition(_lock)
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
self._client_id = client_id
|
||||
|
||||
@ -39,7 +39,7 @@ from pymongo.errors import (
|
||||
WriteError,
|
||||
)
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.lock import _create_lock, _Lock
|
||||
from pymongo.logger import (
|
||||
_SDAM_LOGGER,
|
||||
_SERVER_SELECTION_LOGGER,
|
||||
@ -170,8 +170,9 @@ class Topology:
|
||||
self._seed_addresses = list(topology_description.server_descriptions())
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._lock = _create_lock()
|
||||
self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type]
|
||||
_lock = _create_lock()
|
||||
self._lock = _Lock(_lock)
|
||||
self._condition = self._settings.condition_class(_lock)
|
||||
self._servers: dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
|
||||
7
requirements/typing.txt
Normal file
7
requirements/typing.txt
Normal file
@ -0,0 +1,7 @@
|
||||
mypy==1.11.2
|
||||
pyright==1.1.382.post1
|
||||
typing_extensions
|
||||
-r ./encryption.txt
|
||||
-r ./ocsp.txt
|
||||
-r ./zstd.txt
|
||||
-r ./aws.txt
|
||||
@ -313,7 +313,7 @@ class ClientContext:
|
||||
params = self.cmd_line["parsed"].get("setParameter", {})
|
||||
if params.get("enableTestCommands") == "1":
|
||||
self.test_commands_enabled = True
|
||||
self.has_ipv6 = self._server_started_with_ipv6()
|
||||
self.has_ipv6 = self._server_started_with_ipv6()
|
||||
|
||||
self.is_mongos = (self.hello).get("msg") == "isdbgrid"
|
||||
if self.is_mongos:
|
||||
|
||||
@ -313,7 +313,7 @@ class AsyncClientContext:
|
||||
params = self.cmd_line["parsed"].get("setParameter", {})
|
||||
if params.get("enableTestCommands") == "1":
|
||||
self.test_commands_enabled = True
|
||||
self.has_ipv6 = await self._server_started_with_ipv6()
|
||||
self.has_ipv6 = await self._server_started_with_ipv6()
|
||||
|
||||
self.is_mongos = (await self.hello).get("msg") == "isdbgrid"
|
||||
if self.is_mongos:
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import _thread as thread
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import copy
|
||||
import datetime
|
||||
@ -31,13 +32,15 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Iterable, Type, no_type_check
|
||||
import uuid
|
||||
from typing import Any, Iterable, Type, no_type_check
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation
|
||||
from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -57,6 +60,7 @@ from test.asynchronous import (
|
||||
unittest,
|
||||
)
|
||||
from test.asynchronous.pymongo_mocks import AsyncMockClient
|
||||
from test.test_binary import BinaryData
|
||||
from test.utils import (
|
||||
NTHREADS,
|
||||
CMAPListener,
|
||||
@ -95,7 +99,7 @@ from pymongo.asynchronous.pool import (
|
||||
from pymongo.asynchronous.settings import TOPOLOGY_TYPE
|
||||
from pymongo.asynchronous.topology import _ErrorContext
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT
|
||||
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c
|
||||
from pymongo.compression_support import _have_snappy, _have_zstd
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.errors import (
|
||||
@ -343,7 +347,10 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
|
||||
async def test_metadata(self):
|
||||
metadata = copy.deepcopy(_METADATA)
|
||||
metadata["driver"]["name"] = "PyMongo|async"
|
||||
if has_c():
|
||||
metadata["driver"]["name"] = "PyMongo|c|async"
|
||||
else:
|
||||
metadata["driver"]["name"] = "PyMongo|async"
|
||||
metadata["application"] = {"name": "foobar"}
|
||||
client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false")
|
||||
options = client.options
|
||||
@ -366,7 +373,10 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
with self.assertRaises(TypeError):
|
||||
self.simple_client(driver=("Foo", "1", "a"))
|
||||
# Test appending to driver info.
|
||||
metadata["driver"]["name"] = "PyMongo|async|FooDriver"
|
||||
if has_c():
|
||||
metadata["driver"]["name"] = "PyMongo|c|async|FooDriver"
|
||||
else:
|
||||
metadata["driver"]["name"] = "PyMongo|async|FooDriver"
|
||||
metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"])
|
||||
client = self.simple_client(
|
||||
"foo",
|
||||
@ -1927,7 +1937,10 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def _test_handshake(self, env_vars, expected_env):
|
||||
with patch.dict("os.environ", env_vars):
|
||||
metadata = copy.deepcopy(_METADATA)
|
||||
metadata["driver"]["name"] = "PyMongo|async"
|
||||
if has_c():
|
||||
metadata["driver"]["name"] = "PyMongo|c|async"
|
||||
else:
|
||||
metadata["driver"]["name"] = "PyMongo|async"
|
||||
if expected_env is not None:
|
||||
metadata["env"] = expected_env
|
||||
|
||||
@ -2020,6 +2033,75 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def test_dict_hints_create_index(self):
|
||||
await self.db.t.create_index({"x": pymongo.ASCENDING})
|
||||
|
||||
async def test_legacy_java_uuid_roundtrip(self):
|
||||
data = BinaryData.java_data
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY))
|
||||
|
||||
await async_client_context.client.pymongo_test.drop_collection("java_uuid")
|
||||
db = async_client_context.client.pymongo_test
|
||||
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY))
|
||||
|
||||
await coll.insert_many(docs)
|
||||
self.assertEqual(5, await coll.count_documents({}))
|
||||
async for d in coll.find():
|
||||
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
|
||||
|
||||
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
async for d in coll.find():
|
||||
self.assertNotEqual(d["newguid"], d["newguidstring"])
|
||||
await async_client_context.client.pymongo_test.drop_collection("java_uuid")
|
||||
|
||||
async def test_legacy_csharp_uuid_roundtrip(self):
|
||||
data = BinaryData.csharp_data
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY))
|
||||
|
||||
await async_client_context.client.pymongo_test.drop_collection("csharp_uuid")
|
||||
db = async_client_context.client.pymongo_test
|
||||
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY))
|
||||
|
||||
await coll.insert_many(docs)
|
||||
self.assertEqual(5, await coll.count_documents({}))
|
||||
async for d in coll.find():
|
||||
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
|
||||
|
||||
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
async for d in coll.find():
|
||||
self.assertNotEqual(d["newguid"], d["newguidstring"])
|
||||
await async_client_context.client.pymongo_test.drop_collection("csharp_uuid")
|
||||
|
||||
async def test_uri_to_uuid(self):
|
||||
uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
|
||||
client = await self.async_single_client(uri, connect=False)
|
||||
self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY)
|
||||
|
||||
async def test_uuid_queries(self):
|
||||
db = async_client_context.client.pymongo_test
|
||||
coll = db.test
|
||||
await coll.drop()
|
||||
|
||||
uu = uuid.uuid4()
|
||||
await coll.insert_one({"uuid": Binary(uu.bytes, 3)})
|
||||
self.assertEqual(1, await coll.count_documents({}))
|
||||
|
||||
# Test regular UUID queries (using subtype 4).
|
||||
coll = db.get_collection(
|
||||
"test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
)
|
||||
self.assertEqual(0, await coll.count_documents({"uuid": uu}))
|
||||
await coll.insert_one({"uuid": uu})
|
||||
self.assertEqual(2, await coll.count_documents({}))
|
||||
docs = await coll.find({"uuid": uu}).to_list()
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(uu, docs[0]["uuid"])
|
||||
|
||||
# Test both.
|
||||
uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY)
|
||||
predicate = {"uuid": {"$in": [uu, uu_legacy]}}
|
||||
self.assertEqual(2, await coll.count_documents(predicate))
|
||||
docs = await coll.find(predicate).to_list()
|
||||
self.assertEqual(2, len(docs))
|
||||
await coll.drop()
|
||||
|
||||
|
||||
class TestExhaustCursor(AsyncIntegrationTest):
|
||||
"""Test that clients properly handle errors from exhaust cursors."""
|
||||
@ -2351,7 +2433,9 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
|
||||
# But it can reconnect.
|
||||
c.revive_host("a:1")
|
||||
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
await (await c._get_topology()).select_servers(
|
||||
writable_server_selector, _Op.TEST, server_selection_timeout=10
|
||||
)
|
||||
self.assertEqual(await c.address, ("a", 1))
|
||||
|
||||
async def _test_network_error(self, operation_callback):
|
||||
|
||||
@ -30,6 +30,7 @@ from test.utils import (
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import pymongo
|
||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
|
||||
from pymongo.errors import (
|
||||
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
|
||||
timeoutMS=2000,
|
||||
w="majority",
|
||||
)
|
||||
await client.admin.command("ping") # Init the client first.
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(10):
|
||||
await client.admin.command("ping")
|
||||
with self.assertRaises(ClientBulkWriteException) as context:
|
||||
await client.bulk_write(models=models)
|
||||
self.assertIsInstance(context.exception.error, NetworkTimeout)
|
||||
|
||||
@ -1414,7 +1414,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_to_list_csot_applied(self):
|
||||
client = await self.async_single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
await client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
@ -1456,7 +1456,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_command_cursor_to_list_csot_applied(self):
|
||||
client = await self.async_single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
await client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
|
||||
@ -711,7 +711,7 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
"write_concern": WriteConcern(w=1),
|
||||
"read_concern": ReadConcern(level="local"),
|
||||
}
|
||||
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
|
||||
db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload]
|
||||
for opt in newopts:
|
||||
self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))
|
||||
|
||||
|
||||
513
test/asynchronous/test_locks.py
Normal file
513
test/asynchronous/test_locks.py
Normal file
@ -0,0 +1,513 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for lock.py"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.lock import _ACondition
|
||||
|
||||
|
||||
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
|
||||
# Includes tests for:
|
||||
# - https://github.com/python/cpython/issues/111693
|
||||
# - https://github.com/python/cpython/issues/112202
|
||||
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_wait(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
async def c3(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait():
|
||||
result.append(3)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
self.assertTrue(await cond.acquire())
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.notify(2)
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_wait_cancel(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
await cond.acquire()
|
||||
|
||||
wait = asyncio.create_task(cond.wait())
|
||||
asyncio.get_running_loop().call_soon(wait.cancel)
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await wait
|
||||
self.assertFalse(cond._waiters)
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_contested(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
wait_task = asyncio.create_task(cond.wait())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
# Notify, but contest the lock before cancelling
|
||||
await cond.acquire()
|
||||
self.assertTrue(cond.locked())
|
||||
cond.notify()
|
||||
asyncio.get_running_loop().call_soon(wait_task.cancel)
|
||||
asyncio.get_running_loop().call_soon(cond.release)
|
||||
|
||||
try:
|
||||
await wait_task
|
||||
except asyncio.CancelledError:
|
||||
# Should not happen, since no cancellation points
|
||||
pass
|
||||
|
||||
self.assertTrue(cond.locked())
|
||||
|
||||
async def test_wait_cancel_after_notify(self):
|
||||
# See bpo-32841
|
||||
waited = False
|
||||
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def wait_on_cond():
|
||||
nonlocal waited
|
||||
async with cond:
|
||||
waited = True # Make sure this area was reached
|
||||
await cond.wait()
|
||||
|
||||
waiter = asyncio.create_task(wait_on_cond())
|
||||
await asyncio.sleep(0) # Start waiting
|
||||
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
await asyncio.sleep(0) # Get to acquire()
|
||||
waiter.cancel()
|
||||
await asyncio.sleep(0) # Activate cancellation
|
||||
cond.release()
|
||||
await asyncio.sleep(0) # Cancellation should occur
|
||||
|
||||
self.assertTrue(waiter.cancelled())
|
||||
self.assertTrue(waited)
|
||||
|
||||
async def test_wait_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait()
|
||||
|
||||
async def test_wait_for(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
presult = False
|
||||
|
||||
def predicate():
|
||||
return presult
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
await cond.acquire()
|
||||
if await cond.wait_for(predicate):
|
||||
result.append(1)
|
||||
cond.release()
|
||||
return True
|
||||
|
||||
t = asyncio.create_task(c1(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
presult = True
|
||||
await cond.acquire()
|
||||
cond.notify()
|
||||
cond.release()
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
self.assertTrue(t.done())
|
||||
self.assertTrue(t.result())
|
||||
|
||||
async def test_wait_for_unacquired(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
# predicate can return true immediately
|
||||
res = await cond.wait_for(lambda: [1, 2, 3])
|
||||
self.assertEqual([1, 2, 3], res)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
await cond.wait_for(lambda: False)
|
||||
|
||||
async def test_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
async def c3(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(3)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
t3 = asyncio.create_task(c3(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
cond.notify(2048)
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2, 3], result)
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
self.assertTrue(t3.done())
|
||||
self.assertTrue(t3.result())
|
||||
|
||||
async def test_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
result = []
|
||||
|
||||
async def c1(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(1)
|
||||
return True
|
||||
|
||||
async def c2(result):
|
||||
async with cond:
|
||||
if await cond.wait():
|
||||
result.append(2)
|
||||
return True
|
||||
|
||||
t1 = asyncio.create_task(c1(result))
|
||||
t2 = asyncio.create_task(c2(result))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([], result)
|
||||
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
await asyncio.sleep(1)
|
||||
self.assertEqual([1, 2], result)
|
||||
|
||||
self.assertTrue(t1.done())
|
||||
self.assertTrue(t1.result())
|
||||
self.assertTrue(t2.done())
|
||||
self.assertTrue(t2.result())
|
||||
|
||||
async def test_context_manager(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
self.assertFalse(cond.locked())
|
||||
async with cond:
|
||||
self.assertTrue(cond.locked())
|
||||
self.assertFalse(cond.locked())
|
||||
|
||||
async def test_timeout_in_block(self):
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
async with condition:
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(condition.wait(), timeout=0.5)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_wakeup(self):
|
||||
# Test that a cancelled error, received when awaiting wakeup,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
|
||||
)
|
||||
async def test_cancelled_error_re_aquire(self):
|
||||
# Test that a cancelled error, received when re-aquiring lock,
|
||||
# will be re-raised un-modified.
|
||||
wake = False
|
||||
raised = None
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
async def func():
|
||||
nonlocal raised
|
||||
async with cond:
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await cond.wait_for(lambda: wake)
|
||||
raised = err.exception
|
||||
raise raised
|
||||
|
||||
task = asyncio.create_task(func())
|
||||
await asyncio.sleep(0)
|
||||
# Task is waiting on the condition
|
||||
await cond.acquire()
|
||||
wake = True
|
||||
cond.notify()
|
||||
await asyncio.sleep(0)
|
||||
# Task is now trying to re-acquire the lock, cancel it there.
|
||||
task.cancel(msg="foo") # type: ignore[call-arg]
|
||||
cond.release()
|
||||
with self.assertRaises(asyncio.CancelledError) as err:
|
||||
await task
|
||||
self.assertEqual(err.exception.args, ("foo",))
|
||||
# We should have got the _same_ exception instance as the one
|
||||
# originally raised.
|
||||
self.assertIs(err.exception, raised)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is awaiting initial
|
||||
# wakeup on the wakeup queue.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# Cancel it while it is awaiting to be run.
|
||||
# This cancellation could come from the outside
|
||||
c[0].cancel()
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
|
||||
async def test_cancelled_wakeup_relock(self):
|
||||
# Test that a task cancelled at the "same" time as it is woken
|
||||
# up as part of a Condition.notify() does not result in a lost wakeup.
|
||||
# This test simulates a cancel while the target task is acquiring the lock
|
||||
# again.
|
||||
condition = _ACondition(threading.Condition(threading.Lock()))
|
||||
state = 0
|
||||
|
||||
async def consumer():
|
||||
nonlocal state
|
||||
async with condition:
|
||||
while True:
|
||||
await condition.wait_for(lambda: state != 0)
|
||||
if state < 0:
|
||||
return
|
||||
state -= 1
|
||||
|
||||
# create two consumers
|
||||
c = [asyncio.create_task(consumer()) for _ in range(2)]
|
||||
# wait for them to settle
|
||||
await asyncio.sleep(0.1)
|
||||
async with condition:
|
||||
# produce one item and wake up one
|
||||
state += 1
|
||||
condition.notify(1)
|
||||
|
||||
# now we sleep for a bit. This allows the target task to wake up and
|
||||
# settle on re-aquiring the lock
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Cancel it while awaiting the lock
|
||||
# This cancel could come the outside.
|
||||
c[0].cancel()
|
||||
|
||||
# now wait for the item to be consumed
|
||||
# if it doesn't means that our "notify" didn"t take hold.
|
||||
# because it raced with a cancel()
|
||||
try:
|
||||
async with asyncio.timeout(1):
|
||||
await condition.wait_for(lambda: state == 0)
|
||||
except TimeoutError:
|
||||
pass
|
||||
self.assertEqual(state, 0)
|
||||
|
||||
# clean up
|
||||
state = -1
|
||||
condition.notify_all()
|
||||
await c[1]
|
||||
|
||||
|
||||
class TestCondition(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_multiple_loops_notify(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
|
||||
def tmain(cond):
|
||||
async def atmain(cond):
|
||||
await asyncio.sleep(1)
|
||||
async with cond:
|
||||
cond.notify(1)
|
||||
|
||||
asyncio.run(atmain(cond))
|
||||
|
||||
t = threading.Thread(target=tmain, args=(cond,))
|
||||
t.start()
|
||||
|
||||
async with cond:
|
||||
self.assertTrue(await cond.wait(30))
|
||||
t.join()
|
||||
|
||||
async def test_multiple_loops_notify_all(self):
|
||||
cond = _ACondition(threading.Condition(threading.Lock()))
|
||||
results = []
|
||||
|
||||
def tmain(cond, results):
|
||||
async def atmain(cond, results):
|
||||
await asyncio.sleep(1)
|
||||
async with cond:
|
||||
res = await cond.wait(30)
|
||||
results.append(res)
|
||||
|
||||
asyncio.run(atmain(cond, results))
|
||||
|
||||
nthreads = 5
|
||||
threads = []
|
||||
for _ in range(nthreads):
|
||||
threads.append(threading.Thread(target=tmain, args=(cond, results)))
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
await asyncio.sleep(2)
|
||||
async with cond:
|
||||
cond.notify_all()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
self.assertEqual(results, [True] * nthreads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
219
test/asynchronous/test_raw_bson.py
Normal file
219
test/asynchronous/test_raw_bson.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
|
||||
from bson import Code, DBRef, decode, encode
|
||||
from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.errors import InvalidBSON
|
||||
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument
|
||||
from bson.son import SON
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestRawBSONDocument(AsyncIntegrationTest):
|
||||
# {'_id': ObjectId('556df68b6e32ab21a95e0785'),
|
||||
# 'name': 'Sherlock',
|
||||
# 'addresses': [{'street': 'Baker Street'}]}
|
||||
bson_string = (
|
||||
b"Z\x00\x00\x00\x07_id\x00Um\xf6\x8bn2\xab!\xa9^\x07\x85\x02name\x00\t"
|
||||
b"\x00\x00\x00Sherlock\x00\x04addresses\x00&\x00\x00\x00\x030\x00\x1e"
|
||||
b"\x00\x00\x00\x02street\x00\r\x00\x00\x00Baker Street\x00\x00\x00\x00"
|
||||
)
|
||||
document = RawBSONDocument(bson_string)
|
||||
|
||||
async def asyncTearDown(self):
|
||||
if async_client_context.connected:
|
||||
await self.client.pymongo_test.test_raw.drop()
|
||||
|
||||
def test_decode(self):
|
||||
self.assertEqual("Sherlock", self.document["name"])
|
||||
first_address = self.document["addresses"][0]
|
||||
self.assertIsInstance(first_address, RawBSONDocument)
|
||||
self.assertEqual("Baker Street", first_address["street"])
|
||||
|
||||
def test_raw(self):
|
||||
self.assertEqual(self.bson_string, self.document.raw)
|
||||
|
||||
def test_empty_doc(self):
|
||||
doc = RawBSONDocument(encode({}))
|
||||
with self.assertRaises(KeyError):
|
||||
doc["does-not-exist"]
|
||||
|
||||
def test_invalid_bson_sequence(self):
|
||||
bson_byte_sequence = encode({"a": 1}) + encode({})
|
||||
with self.assertRaisesRegex(InvalidBSON, "invalid object length"):
|
||||
RawBSONDocument(bson_byte_sequence)
|
||||
|
||||
def test_invalid_bson_eoo(self):
|
||||
invalid_bson_eoo = encode({"a": 1})[:-1] + b"\x01"
|
||||
with self.assertRaisesRegex(InvalidBSON, "bad eoo"):
|
||||
RawBSONDocument(invalid_bson_eoo)
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def test_round_trip(self):
|
||||
db = self.client.get_database(
|
||||
"pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument)
|
||||
)
|
||||
await db.test_raw.insert_one(self.document)
|
||||
result = await db.test_raw.find_one(self.document["_id"])
|
||||
assert result is not None
|
||||
self.assertIsInstance(result, RawBSONDocument)
|
||||
self.assertEqual(dict(self.document.items()), dict(result.items()))
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def test_round_trip_raw_uuid(self):
|
||||
coll = self.client.get_database("pymongo_test").test_raw
|
||||
uid = uuid.uuid4()
|
||||
doc = {"_id": 1, "bin4": Binary(uid.bytes, 4), "bin3": Binary(uid.bytes, 3)}
|
||||
raw = RawBSONDocument(encode(doc))
|
||||
await coll.insert_one(raw)
|
||||
self.assertEqual(await coll.find_one(), doc)
|
||||
uuid_coll = coll.with_options(
|
||||
codec_options=coll.codec_options.with_options(
|
||||
uuid_representation=UuidRepresentation.STANDARD
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
await uuid_coll.find_one(), {"_id": 1, "bin4": uid, "bin3": Binary(uid.bytes, 3)}
|
||||
)
|
||||
|
||||
# Test that the raw bytes haven't changed.
|
||||
raw_coll = coll.with_options(codec_options=DEFAULT_RAW_BSON_OPTIONS)
|
||||
self.assertEqual(await raw_coll.find_one(), raw)
|
||||
|
||||
def test_with_codec_options(self):
|
||||
# {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000),
|
||||
# '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')}
|
||||
# encoded with JAVA_LEGACY uuid representation.
|
||||
bson_string = (
|
||||
b"-\x00\x00\x00\x05_id\x00\x10\x00\x00\x00\x03eI_\x97\x8f\xabo\x02"
|
||||
b"\xff`L\x87\xad\x85\xbf\x9f\tdate\x00\x8a\xd6\xb9\xbaM"
|
||||
b"\x01\x00\x00\x00"
|
||||
)
|
||||
document = RawBSONDocument(
|
||||
bson_string,
|
||||
codec_options=CodecOptions(
|
||||
uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), document["_id"])
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def test_round_trip_codec_options(self):
|
||||
doc = {
|
||||
"date": datetime.datetime(2015, 6, 3, 18, 40, 50, 826000),
|
||||
"_id": uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"),
|
||||
}
|
||||
db = self.client.pymongo_test
|
||||
coll = db.get_collection(
|
||||
"test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY)
|
||||
)
|
||||
await coll.insert_one(doc)
|
||||
raw_java_legacy = CodecOptions(
|
||||
uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument
|
||||
)
|
||||
coll = db.get_collection("test_raw", codec_options=raw_java_legacy)
|
||||
self.assertEqual(
|
||||
RawBSONDocument(encode(doc, codec_options=raw_java_legacy)), await coll.find_one()
|
||||
)
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def test_raw_bson_document_embedded(self):
|
||||
doc = {"embedded": self.document}
|
||||
db = self.client.pymongo_test
|
||||
await db.test_raw.insert_one(doc)
|
||||
result = await db.test_raw.find_one()
|
||||
assert result is not None
|
||||
self.assertEqual(decode(self.document.raw), result["embedded"])
|
||||
|
||||
# Make sure that CodecOptions are preserved.
|
||||
# {'embedded': [
|
||||
# {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000),
|
||||
# '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')}
|
||||
# ]}
|
||||
# encoded with JAVA_LEGACY uuid representation.
|
||||
bson_string = (
|
||||
b"D\x00\x00\x00\x04embedded\x005\x00\x00\x00\x030\x00-\x00\x00\x00"
|
||||
b"\tdate\x00\x8a\xd6\xb9\xbaM\x01\x00\x00\x05_id\x00\x10\x00\x00"
|
||||
b"\x00\x03eI_\x97\x8f\xabo\x02\xff`L\x87\xad\x85\xbf\x9f\x00\x00"
|
||||
b"\x00"
|
||||
)
|
||||
rbd = RawBSONDocument(
|
||||
bson_string,
|
||||
codec_options=CodecOptions(
|
||||
uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument
|
||||
),
|
||||
)
|
||||
|
||||
await db.test_raw.drop()
|
||||
await db.test_raw.insert_one(rbd)
|
||||
result = await db.get_collection(
|
||||
"test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY)
|
||||
).find_one()
|
||||
assert result is not None
|
||||
self.assertEqual(rbd["embedded"][0]["_id"], result["embedded"][0]["_id"])
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def test_write_response_raw_bson(self):
|
||||
coll = self.client.get_database(
|
||||
"pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument)
|
||||
).test_raw
|
||||
|
||||
# No Exceptions raised while handling write response.
|
||||
await coll.insert_one(self.document)
|
||||
await coll.delete_one(self.document)
|
||||
await coll.insert_many([self.document])
|
||||
await coll.delete_many(self.document)
|
||||
await coll.update_one(self.document, {"$set": {"a": "b"}}, upsert=True)
|
||||
await coll.update_many(self.document, {"$set": {"b": "c"}})
|
||||
|
||||
def test_preserve_key_ordering(self):
|
||||
keyvaluepairs = [
|
||||
("a", 1),
|
||||
("b", 2),
|
||||
("c", 3),
|
||||
]
|
||||
rawdoc = RawBSONDocument(encode(SON(keyvaluepairs)))
|
||||
|
||||
for rkey, elt in zip(rawdoc, keyvaluepairs):
|
||||
self.assertEqual(rkey, elt[0])
|
||||
|
||||
def test_contains_code_with_scope(self):
|
||||
doc = RawBSONDocument(encode({"value": Code("x=1", scope={})}))
|
||||
|
||||
self.assertEqual(decode(encode(doc)), {"value": Code("x=1", {})})
|
||||
self.assertEqual(doc["value"].scope, RawBSONDocument(encode({})))
|
||||
|
||||
def test_contains_dbref(self):
|
||||
doc = RawBSONDocument(encode({"value": DBRef("test", "id")}))
|
||||
raw = {"$ref": "test", "$id": "id"}
|
||||
raw_encoded = encode(decode(encode(raw)))
|
||||
|
||||
self.assertEqual(decode(encode(doc)), {"value": DBRef("test", "id")})
|
||||
self.assertEqual(doc["value"].raw, raw_encoded)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -34,53 +34,49 @@ from bson.binary import *
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.son import SON
|
||||
from pymongo.common import validate_uuid_representation
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
class BinaryData:
|
||||
# Generated by the Java driver
|
||||
from_java = (
|
||||
b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu"
|
||||
b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND"
|
||||
b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+"
|
||||
b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1"
|
||||
b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR"
|
||||
b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA"
|
||||
b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z"
|
||||
b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf"
|
||||
b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx"
|
||||
b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My"
|
||||
b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB"
|
||||
b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp"
|
||||
b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc"
|
||||
b"0MQAA"
|
||||
)
|
||||
java_data = base64.b64decode(from_java)
|
||||
|
||||
# Generated by the .net driver
|
||||
from_csharp = (
|
||||
b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl"
|
||||
b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2"
|
||||
b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V"
|
||||
b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl"
|
||||
b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A"
|
||||
b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z"
|
||||
b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU"
|
||||
b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn"
|
||||
b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA"
|
||||
b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT"
|
||||
b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP"
|
||||
b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00"
|
||||
b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA="
|
||||
)
|
||||
csharp_data = base64.b64decode(from_csharp)
|
||||
|
||||
|
||||
class TestBinary(unittest.TestCase):
|
||||
csharp_data: bytes
|
||||
java_data: bytes
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Generated by the Java driver
|
||||
from_java = (
|
||||
b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu"
|
||||
b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND"
|
||||
b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+"
|
||||
b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1"
|
||||
b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR"
|
||||
b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA"
|
||||
b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z"
|
||||
b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf"
|
||||
b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx"
|
||||
b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My"
|
||||
b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB"
|
||||
b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp"
|
||||
b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc"
|
||||
b"0MQAA"
|
||||
)
|
||||
cls.java_data = base64.b64decode(from_java)
|
||||
|
||||
# Generated by the .net driver
|
||||
from_csharp = (
|
||||
b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl"
|
||||
b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2"
|
||||
b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V"
|
||||
b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl"
|
||||
b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A"
|
||||
b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z"
|
||||
b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU"
|
||||
b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn"
|
||||
b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA"
|
||||
b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT"
|
||||
b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP"
|
||||
b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00"
|
||||
b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA="
|
||||
)
|
||||
cls.csharp_data = base64.b64decode(from_csharp)
|
||||
|
||||
def test_binary(self):
|
||||
a_string = "hello world"
|
||||
a_binary = Binary(b"hello world")
|
||||
@ -159,7 +155,7 @@ class TestBinary(unittest.TestCase):
|
||||
|
||||
def test_legacy_java_uuid(self):
|
||||
# Test decoding
|
||||
data = self.java_data
|
||||
data = BinaryData.java_data
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY))
|
||||
for d in docs:
|
||||
self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
|
||||
@ -197,27 +193,8 @@ class TestBinary(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(data, encoded)
|
||||
|
||||
@client_context.require_connection
|
||||
def test_legacy_java_uuid_roundtrip(self):
|
||||
data = self.java_data
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY))
|
||||
|
||||
client_context.client.pymongo_test.drop_collection("java_uuid")
|
||||
db = client_context.client.pymongo_test
|
||||
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY))
|
||||
|
||||
coll.insert_many(docs)
|
||||
self.assertEqual(5, coll.count_documents({}))
|
||||
for d in coll.find():
|
||||
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
|
||||
|
||||
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
for d in coll.find():
|
||||
self.assertNotEqual(d["newguid"], d["newguidstring"])
|
||||
client_context.client.pymongo_test.drop_collection("java_uuid")
|
||||
|
||||
def test_legacy_csharp_uuid(self):
|
||||
data = self.csharp_data
|
||||
data = BinaryData.csharp_data
|
||||
|
||||
# Test decoding
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY))
|
||||
@ -257,59 +234,6 @@ class TestBinary(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(data, encoded)
|
||||
|
||||
@client_context.require_connection
|
||||
def test_legacy_csharp_uuid_roundtrip(self):
|
||||
data = self.csharp_data
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY))
|
||||
|
||||
client_context.client.pymongo_test.drop_collection("csharp_uuid")
|
||||
db = client_context.client.pymongo_test
|
||||
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY))
|
||||
|
||||
coll.insert_many(docs)
|
||||
self.assertEqual(5, coll.count_documents({}))
|
||||
for d in coll.find():
|
||||
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
|
||||
|
||||
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
for d in coll.find():
|
||||
self.assertNotEqual(d["newguid"], d["newguidstring"])
|
||||
client_context.client.pymongo_test.drop_collection("csharp_uuid")
|
||||
|
||||
def test_uri_to_uuid(self):
|
||||
uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
|
||||
client = MongoClient(uri, connect=False)
|
||||
self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY)
|
||||
|
||||
@client_context.require_connection
|
||||
def test_uuid_queries(self):
|
||||
db = client_context.client.pymongo_test
|
||||
coll = db.test
|
||||
coll.drop()
|
||||
|
||||
uu = uuid.uuid4()
|
||||
coll.insert_one({"uuid": Binary(uu.bytes, 3)})
|
||||
self.assertEqual(1, coll.count_documents({}))
|
||||
|
||||
# Test regular UUID queries (using subtype 4).
|
||||
coll = db.get_collection(
|
||||
"test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
)
|
||||
self.assertEqual(0, coll.count_documents({"uuid": uu}))
|
||||
coll.insert_one({"uuid": uu})
|
||||
self.assertEqual(2, coll.count_documents({}))
|
||||
docs = list(coll.find({"uuid": uu}))
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(uu, docs[0]["uuid"])
|
||||
|
||||
# Test both.
|
||||
uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY)
|
||||
predicate = {"uuid": {"$in": [uu, uu_legacy]}}
|
||||
self.assertEqual(2, coll.count_documents(predicate))
|
||||
docs = list(coll.find(predicate))
|
||||
self.assertEqual(2, len(docs))
|
||||
coll.drop()
|
||||
|
||||
def test_pickle(self):
|
||||
b1 = Binary(b"123", 2)
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import _thread as thread
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import copy
|
||||
import datetime
|
||||
@ -31,12 +32,14 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Iterable, Type, no_type_check
|
||||
import uuid
|
||||
from typing import Any, Iterable, Type, no_type_check
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation
|
||||
from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -56,6 +59,7 @@ from test import (
|
||||
unittest,
|
||||
)
|
||||
from test.pymongo_mocks import MockClient
|
||||
from test.test_binary import BinaryData
|
||||
from test.utils import (
|
||||
NTHREADS,
|
||||
CMAPListener,
|
||||
@ -83,7 +87,7 @@ from bson.son import SON
|
||||
from bson.tz_util import utc
|
||||
from pymongo import event_loggers, message, monitoring
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT
|
||||
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c
|
||||
from pymongo.compression_support import _have_snappy, _have_zstd
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.errors import (
|
||||
@ -335,7 +339,10 @@ class ClientUnitTest(UnitTest):
|
||||
|
||||
def test_metadata(self):
|
||||
metadata = copy.deepcopy(_METADATA)
|
||||
metadata["driver"]["name"] = "PyMongo"
|
||||
if has_c():
|
||||
metadata["driver"]["name"] = "PyMongo|c"
|
||||
else:
|
||||
metadata["driver"]["name"] = "PyMongo"
|
||||
metadata["application"] = {"name": "foobar"}
|
||||
client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false")
|
||||
options = client.options
|
||||
@ -358,7 +365,10 @@ class ClientUnitTest(UnitTest):
|
||||
with self.assertRaises(TypeError):
|
||||
self.simple_client(driver=("Foo", "1", "a"))
|
||||
# Test appending to driver info.
|
||||
metadata["driver"]["name"] = "PyMongo|FooDriver"
|
||||
if has_c():
|
||||
metadata["driver"]["name"] = "PyMongo|c|FooDriver"
|
||||
else:
|
||||
metadata["driver"]["name"] = "PyMongo|FooDriver"
|
||||
metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"])
|
||||
client = self.simple_client(
|
||||
"foo",
|
||||
@ -1885,7 +1895,10 @@ class TestClient(IntegrationTest):
|
||||
def _test_handshake(self, env_vars, expected_env):
|
||||
with patch.dict("os.environ", env_vars):
|
||||
metadata = copy.deepcopy(_METADATA)
|
||||
metadata["driver"]["name"] = "PyMongo"
|
||||
if has_c():
|
||||
metadata["driver"]["name"] = "PyMongo|c"
|
||||
else:
|
||||
metadata["driver"]["name"] = "PyMongo"
|
||||
if expected_env is not None:
|
||||
metadata["env"] = expected_env
|
||||
|
||||
@ -1978,6 +1991,75 @@ class TestClient(IntegrationTest):
|
||||
def test_dict_hints_create_index(self):
|
||||
self.db.t.create_index({"x": pymongo.ASCENDING})
|
||||
|
||||
def test_legacy_java_uuid_roundtrip(self):
|
||||
data = BinaryData.java_data
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY))
|
||||
|
||||
client_context.client.pymongo_test.drop_collection("java_uuid")
|
||||
db = client_context.client.pymongo_test
|
||||
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY))
|
||||
|
||||
coll.insert_many(docs)
|
||||
self.assertEqual(5, coll.count_documents({}))
|
||||
for d in coll.find():
|
||||
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
|
||||
|
||||
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
for d in coll.find():
|
||||
self.assertNotEqual(d["newguid"], d["newguidstring"])
|
||||
client_context.client.pymongo_test.drop_collection("java_uuid")
|
||||
|
||||
def test_legacy_csharp_uuid_roundtrip(self):
|
||||
data = BinaryData.csharp_data
|
||||
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY))
|
||||
|
||||
client_context.client.pymongo_test.drop_collection("csharp_uuid")
|
||||
db = client_context.client.pymongo_test
|
||||
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY))
|
||||
|
||||
coll.insert_many(docs)
|
||||
self.assertEqual(5, coll.count_documents({}))
|
||||
for d in coll.find():
|
||||
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
|
||||
|
||||
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
for d in coll.find():
|
||||
self.assertNotEqual(d["newguid"], d["newguidstring"])
|
||||
client_context.client.pymongo_test.drop_collection("csharp_uuid")
|
||||
|
||||
def test_uri_to_uuid(self):
|
||||
uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
|
||||
client = self.single_client(uri, connect=False)
|
||||
self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY)
|
||||
|
||||
def test_uuid_queries(self):
|
||||
db = client_context.client.pymongo_test
|
||||
coll = db.test
|
||||
coll.drop()
|
||||
|
||||
uu = uuid.uuid4()
|
||||
coll.insert_one({"uuid": Binary(uu.bytes, 3)})
|
||||
self.assertEqual(1, coll.count_documents({}))
|
||||
|
||||
# Test regular UUID queries (using subtype 4).
|
||||
coll = db.get_collection(
|
||||
"test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
|
||||
)
|
||||
self.assertEqual(0, coll.count_documents({"uuid": uu}))
|
||||
coll.insert_one({"uuid": uu})
|
||||
self.assertEqual(2, coll.count_documents({}))
|
||||
docs = coll.find({"uuid": uu}).to_list()
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(uu, docs[0]["uuid"])
|
||||
|
||||
# Test both.
|
||||
uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY)
|
||||
predicate = {"uuid": {"$in": [uu, uu_legacy]}}
|
||||
self.assertEqual(2, coll.count_documents(predicate))
|
||||
docs = coll.find(predicate).to_list()
|
||||
self.assertEqual(2, len(docs))
|
||||
coll.drop()
|
||||
|
||||
|
||||
class TestExhaustCursor(IntegrationTest):
|
||||
"""Test that clients properly handle errors from exhaust cursors."""
|
||||
@ -2307,7 +2389,9 @@ class TestMongoClientFailover(MockClientTest):
|
||||
|
||||
# But it can reconnect.
|
||||
c.revive_host("a:1")
|
||||
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
(c._get_topology()).select_servers(
|
||||
writable_server_selector, _Op.TEST, server_selection_timeout=10
|
||||
)
|
||||
self.assertEqual(c.address, ("a", 1))
|
||||
|
||||
def _test_network_error(self, operation_callback):
|
||||
|
||||
@ -30,6 +30,7 @@ from test.utils import (
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import pymongo
|
||||
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
|
||||
from pymongo.errors import (
|
||||
ClientBulkWriteException,
|
||||
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(IntegrationTest):
|
||||
timeoutMS=2000,
|
||||
w="majority",
|
||||
)
|
||||
client.admin.command("ping") # Init the client first.
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(10):
|
||||
client.admin.command("ping")
|
||||
with self.assertRaises(ClientBulkWriteException) as context:
|
||||
client.bulk_write(models=models)
|
||||
self.assertIsInstance(context.exception.error, NetworkTimeout)
|
||||
|
||||
@ -1405,7 +1405,7 @@ class TestCursor(IntegrationTest):
|
||||
def test_to_list_csot_applied(self):
|
||||
client = self.single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
@ -1447,7 +1447,7 @@ class TestCursor(IntegrationTest):
|
||||
def test_command_cursor_to_list_csot_applied(self):
|
||||
client = self.single_client(timeoutMS=500)
|
||||
# Initialize the client with a larger timeout to help make test less flakey
|
||||
with pymongo.timeout(2):
|
||||
with pymongo.timeout(10):
|
||||
client.admin.command("ping")
|
||||
coll = client.pymongo.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
|
||||
@ -702,7 +702,7 @@ class TestDatabase(IntegrationTest):
|
||||
"write_concern": WriteConcern(w=1),
|
||||
"read_concern": ReadConcern(level="local"),
|
||||
}
|
||||
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
|
||||
db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload]
|
||||
for opt in newopts:
|
||||
self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))
|
||||
|
||||
|
||||
@ -19,8 +19,7 @@ import uuid
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import client_context, unittest
|
||||
from test.test_client import IntegrationTest
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
|
||||
from bson import Code, DBRef, decode, encode
|
||||
from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation
|
||||
@ -29,6 +28,8 @@ from bson.errors import InvalidBSON
|
||||
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument
|
||||
from bson.son import SON
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestRawBSONDocument(IntegrationTest):
|
||||
# {'_id': ObjectId('556df68b6e32ab21a95e0785'),
|
||||
|
||||
@ -19,6 +19,7 @@ import os
|
||||
import threading
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
get_pool,
|
||||
@ -27,6 +28,7 @@ from test.utils import (
|
||||
from test.utils_selection_tests import create_topology
|
||||
|
||||
from pymongo.common import clean_node
|
||||
from pymongo.monitoring import ConnectionReadyEvent
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
|
||||
@ -131,19 +133,20 @@ class TestProse(IntegrationTest):
|
||||
@client_context.require_multiple_mongoses
|
||||
def test_load_balancing(self):
|
||||
listener = OvertCommandListener()
|
||||
cmap_listener = CMAPListener()
|
||||
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
|
||||
# varying RTTs.
|
||||
client = self.rs_client(
|
||||
client_context.mongos_seeds(),
|
||||
appName="loadBalancingTest",
|
||||
event_listeners=[listener],
|
||||
event_listeners=[listener, cmap_listener],
|
||||
localThresholdMS=30000,
|
||||
minPoolSize=10,
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
|
||||
wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections")
|
||||
# Delay find commands on
|
||||
# Wait for both pools to be populated.
|
||||
cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
|
||||
# Delay find commands on only one mongos.
|
||||
delay_finds = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 10000},
|
||||
@ -161,7 +164,7 @@ class TestProse(IntegrationTest):
|
||||
freqs = self.frequencies(client, listener)
|
||||
self.assertLessEqual(freqs[delayed_server], 0.25)
|
||||
listener.reset()
|
||||
freqs = self.frequencies(client, listener, n_finds=100)
|
||||
freqs = self.frequencies(client, listener, n_finds=150)
|
||||
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)
|
||||
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
try:
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from bson import ObjectId
|
||||
@ -49,16 +49,13 @@ try:
|
||||
year: int
|
||||
|
||||
class ImplicitMovie(TypedDict):
|
||||
_id: NotRequired[ObjectId] # pyright: ignore[reportGeneralTypeIssues]
|
||||
_id: NotRequired[ObjectId]
|
||||
name: str
|
||||
year: int
|
||||
|
||||
except ImportError:
|
||||
Movie = dict # type:ignore[misc,assignment]
|
||||
ImplicitMovie = dict # type: ignore[assignment,misc]
|
||||
MovieWithId = dict # type: ignore[assignment,misc]
|
||||
TypedDict = None
|
||||
NotRequired = None # type: ignore[assignment]
|
||||
else:
|
||||
Movie = dict
|
||||
ImplicitMovie = dict
|
||||
NotRequired = None
|
||||
|
||||
|
||||
try:
|
||||
@ -234,6 +231,19 @@ class TestPymongo(IntegrationTest):
|
||||
execute_transaction, read_preference=ReadPreference.PRIMARY
|
||||
)
|
||||
|
||||
def test_with_options(self) -> None:
|
||||
coll: Collection[Dict[str, Any]] = self.coll
|
||||
coll.drop()
|
||||
doc = {"name": "foo", "year": 1982, "other": 1}
|
||||
coll.insert_one(doc)
|
||||
|
||||
coll2 = coll.with_options(codec_options=CodecOptions(document_class=Movie))
|
||||
retrieved = coll2.find_one()
|
||||
assert retrieved is not None
|
||||
assert retrieved["name"] == "foo"
|
||||
# We expect a type error here.
|
||||
assert retrieved["other"] == 1 # type:ignore[typeddict-item]
|
||||
|
||||
|
||||
class TestDecode(unittest.TestCase):
|
||||
def test_bson_decode(self) -> None:
|
||||
@ -426,7 +436,7 @@ class TestDocumentType(PyMongoTestCase):
|
||||
)
|
||||
coll.bulk_write(
|
||||
[
|
||||
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971})
|
||||
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore
|
||||
] # No error because it is in-line.
|
||||
)
|
||||
|
||||
@ -443,7 +453,7 @@ class TestDocumentType(PyMongoTestCase):
|
||||
)
|
||||
coll.bulk_write(
|
||||
[
|
||||
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971})
|
||||
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore
|
||||
] # No error because it is in-line.
|
||||
)
|
||||
|
||||
@ -566,7 +576,7 @@ class TestCodecOptionsDocumentType(unittest.TestCase):
|
||||
def test_typeddict_document_type(self) -> None:
|
||||
options: CodecOptions[Movie] = CodecOptions()
|
||||
# Suppress: Cannot instantiate type "Type[Movie]".
|
||||
obj = options.document_class(name="a", year=1) # type: ignore[misc]
|
||||
obj = options.document_class(name="a", year=1)
|
||||
assert obj["year"] == 1
|
||||
assert obj["name"] == "a"
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ import re
|
||||
from os import listdir
|
||||
from pathlib import Path
|
||||
|
||||
from unasync import Rule, unasync_files # type: ignore[import]
|
||||
from unasync import Rule, unasync_files # type: ignore[import-not-found]
|
||||
|
||||
replacements = {
|
||||
"AsyncCollection": "Collection",
|
||||
@ -101,6 +101,7 @@ replacements = {
|
||||
"default_async": "default",
|
||||
"aclose": "close",
|
||||
"PyMongo|async": "PyMongo",
|
||||
"PyMongo|c|async": "PyMongo|c",
|
||||
"AsyncTestGridFile": "TestGridFile",
|
||||
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
|
||||
}
|
||||
@ -144,7 +145,17 @@ gridfs_files = [
|
||||
_gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file()
|
||||
]
|
||||
|
||||
test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()]
|
||||
|
||||
def async_only_test(f: str) -> bool:
|
||||
"""Return True for async tests that should not be converted to sync."""
|
||||
return f in ["test_locks.py"]
|
||||
|
||||
|
||||
test_files = [
|
||||
_test_base + f
|
||||
for f in listdir(_test_base)
|
||||
if (Path(_test_base) / f).is_file() and not async_only_test(f)
|
||||
]
|
||||
|
||||
sync_files = [
|
||||
_pymongo_dest_base + f
|
||||
@ -171,16 +182,17 @@ converted_tests = [
|
||||
"test_change_stream.py",
|
||||
"test_client.py",
|
||||
"test_client_bulk_write.py",
|
||||
"test_client_context.py",
|
||||
"test_collection.py",
|
||||
"test_cursor.py",
|
||||
"test_database.py",
|
||||
"test_encryption.py",
|
||||
"test_grid_file.py",
|
||||
"test_logger.py",
|
||||
"test_monitoring.py",
|
||||
"test_raw_bson.py",
|
||||
"test_session.py",
|
||||
"test_transactions.py",
|
||||
"test_client_context.py",
|
||||
"test_monitoring.py",
|
||||
]
|
||||
|
||||
sync_test_files = [
|
||||
@ -240,7 +252,7 @@ def translate_locks(lines: list[str]) -> list[str]:
|
||||
lock_lines = [line for line in lines if "_Lock(" in line]
|
||||
cond_lines = [line for line in lines if "_Condition(" in line]
|
||||
for line in lock_lines:
|
||||
res = re.search(r"_Lock\(([^()]*\(\))\)", line)
|
||||
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
|
||||
if res:
|
||||
old = res[0]
|
||||
index = lines.index(line)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user