diff --git a/bson/__init__.py b/bson/__init__.py index e8ac7c444..e866a99c8 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -1324,7 +1324,7 @@ def decode_iter( elements = data[position : position + obj_size] position += obj_size - yield _bson_to_dict(elements, opts) # type:ignore[misc, type-var] + yield _bson_to_dict(elements, opts) # type:ignore[misc] @overload @@ -1370,7 +1370,7 @@ def decode_file_iter( raise InvalidBSON("cut off in middle of objsize") obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4 elements = size_data + file_obj.read(max(0, obj_size)) - yield _bson_to_dict(elements, opts) # type:ignore[type-var, arg-type, misc] + yield _bson_to_dict(elements, opts) # type:ignore[arg-type, misc] def is_valid(bson: bytes) -> bool: diff --git a/bson/decimal128.py b/bson/decimal128.py index 8581d5a3c..016afb5eb 100644 --- a/bson/decimal128.py +++ b/bson/decimal128.py @@ -223,7 +223,7 @@ class Decimal128: "from list or tuple. Must have exactly 2 " "elements." ) - self.__high, self.__low = value # type: ignore + self.__high, self.__low = value else: raise TypeError(f"Cannot convert {value!r} to Decimal128") diff --git a/bson/json_util.py b/bson/json_util.py index 4269ba985..6f34e4103 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -324,7 +324,7 @@ class JSONOptions(_BASE_CLASS): "JSONOptions.datetime_representation must be one of LEGACY, " "NUMBERLONG, or ISO8601 from DatetimeRepresentation." 
) - self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) # type:ignore[arg-type] + self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL): raise ValueError( "JSONOptions.json_mode must be one of LEGACY, RELAXED, " diff --git a/bson/son.py b/bson/son.py index cf6271723..24275fce1 100644 --- a/bson/son.py +++ b/bson/son.py @@ -68,7 +68,7 @@ class SON(Dict[_Key, _Value]): self.update(kwargs) def __new__(cls: Type[SON[_Key, _Value]], *args: Any, **kwargs: Any) -> SON[_Key, _Value]: - instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var] + instance = super().__new__(cls, *args, **kwargs) instance.__keys = [] return instance diff --git a/hatch.toml b/hatch.toml index 8b1cf93e3..d5293a1d7 100644 --- a/hatch.toml +++ b/hatch.toml @@ -13,8 +13,9 @@ features = ["docs","test"] test = "sphinx-build -E -b doctest doc ./doc/_build/doctest" [envs.typing] -features = ["encryption", "ocsp", "zstd", "aws"] -dependencies = ["mypy==1.2.0","pyright==1.1.290", "certifi", "typing_extensions"] +pre-install-commands = [ + "pip install -q -r requirements/typing.txt", +] [envs.typing.scripts] check-mypy = [ "mypy --install-types --non-interactive bson gridfs tools pymongo", diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 8116788bc..6416f939e 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -88,7 +88,7 @@ TEXT = "text" from pymongo import _csot from pymongo._version import __version__, get_version_string, version_tuple -from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION +from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c from pymongo.cursor import CursorType from pymongo.operations import ( DeleteMany, @@ -116,16 +116,6 @@ version = __version__ """Current version of PyMongo.""" -def has_c() -> bool: - """Is the C extension installed?""" - try: - from pymongo import _cmessage # type: 
ignore[attr-defined] # noqa: F401 - - return True - except ImportError: - return False - - def timeout(seconds: Optional[float]) -> ContextManager[None]: """**(Provisional)** Apply the given timeout for a block of operations. diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 94328f981..06c6b68ac 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -75,14 +75,13 @@ class _TimeoutContext(AbstractContextManager): self._timeout = timeout self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None - def __enter__(self) -> _TimeoutContext: + def __enter__(self) -> None: timeout_token = TIMEOUT.set(self._timeout) prev_deadline = DEADLINE.get() next_deadline = time.monotonic() + self._timeout if self._timeout else float("inf") deadline_token = DEADLINE.set(min(prev_deadline, next_deadline)) rtt_token = RTT.set(0.0) self._tokens = (timeout_token, deadline_token, rtt_token) - return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self._tokens: diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 1ec74aad0..5abc41a7e 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -35,6 +35,7 @@ from typing import ( TypeVar, Union, cast, + overload, ) from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions @@ -332,13 +333,33 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): """ return self._database + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncCollection[_DocumentType]: + ... 
+ + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncCollection[_DocumentTypeArg]: + ... + def with_options( self, codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> AsyncCollection[_DocumentType]: + ) -> AsyncCollection[_DocumentType] | AsyncCollection[_DocumentTypeArg]: """Get a clone of this collection changing the specified settings. >>> coll1.read_preference diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 06c0eca2c..98a0a6ff3 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -146,13 +146,33 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): """The name of this :class:`AsyncDatabase`.""" return self._name + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncDatabase[_DocumentType]: + ... + + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> AsyncDatabase[_DocumentTypeArg]: + ... 
+ def with_options( self, codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> AsyncDatabase[_DocumentType]: + ) -> AsyncDatabase[_DocumentType] | AsyncDatabase[_DocumentTypeArg]: """Get a clone of this database changing the specified settings. >>> db1.read_preference diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a65704242..a9f02d650 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -913,7 +913,7 @@ async def _configured_socket( and not options.tls_allow_invalid_hostnames ): try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] except _CertificateError: ssl_sock.close() raise @@ -992,7 +992,8 @@ class Pool: # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _ALock(_create_lock()) + _lock = _create_lock() + self.lock = _ALock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1018,7 +1019,7 @@ class Pool: # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self.size_cond = _ACondition(threading.Condition(_lock)) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1026,7 +1027,7 @@ class Pool: # The second portion of the wait queue. 
# Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self._max_connecting_cond = _ACondition(threading.Condition(_lock)) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 4e778cbc1..82af4257b 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -170,8 +170,9 @@ class Topology: self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _ALock(_create_lock()) - self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _ALock(_lock) + self._condition = _ACondition(self._settings.condition_class(_lock)) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/pymongo/common.py b/pymongo/common.py index a073eba57..126d0ee46 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -850,7 +850,7 @@ def get_validated_options( return x def get_setter_key(x: str) -> str: - return options.cased_key(x) # type: ignore[attr-defined] + return options.cased_key(x) else: validated_options = {} @@ -1060,3 +1060,13 @@ class _CaseInsensitiveDictionary(MutableMapping[str, Any]): def cased_key(self, key: str) -> Any: return self.__casedkeys[key.lower()] + + +def has_c() -> bool: + """Is the C extension installed?""" + try: + from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401 + + return True + except ImportError: + return False diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index 7123b90df..c71e4bddc 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -26,7 +26,7 @@ 
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) def _have_snappy() -> bool: try: - import snappy # type:ignore[import] # noqa: F401 + import snappy # type:ignore[import-not-found] # noqa: F401 return True except ImportError: diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index df1302650..ee749e7ac 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -21,7 +21,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, Mapping, Optional try: - import pymongocrypt # type:ignore[import] # noqa: F401 + import pymongocrypt # type:ignore[import-untyped] # noqa: F401 # Check for pymongocrypt>=1.10. from pymongocrypt import synchronous as _ # noqa: F401 diff --git a/pymongo/lock.py b/pymongo/lock.py index b05f6acff..0cbfb4a57 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -14,17 +14,20 @@ from __future__ import annotations import asyncio +import collections import os import threading import time import weakref -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") # References to instances of _create_lock _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet() +_T = TypeVar("_T") + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and @@ -43,7 +46,14 @@ def _release_locks() -> None: lock.release() +# Needed only for synchro.py compat. +def _Lock(lock: threading.Lock) -> threading.Lock: + return lock + + class _ALock: + __slots__ = ("_lock",) + def __init__(self, lock: threading.Lock) -> None: self._lock = lock @@ -81,9 +91,18 @@ class _ALock: self.release() +def _safe_set_result(fut: asyncio.Future) -> None: + # Ensure the future hasn't been cancelled before calling set_result. 
+ if not fut.done(): + fut.set_result(False) + + class _ACondition: + __slots__ = ("_condition", "_waiters") + def __init__(self, condition: threading.Condition) -> None: self._condition = condition + self._waiters: collections.deque = collections.deque() async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: if timeout > 0: @@ -99,30 +118,116 @@ class _ACondition: await asyncio.sleep(0) async def wait(self, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait(0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False + """Wait until notified. - async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait_for(predicate, 0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._waiters.append((loop, fut)) + self.release() + try: + try: + try: + await asyncio.wait_for(fut, timeout) + return True + except asyncio.TimeoutError: + return False # Return false on timeout for sync pool compat. + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. 
+ err = None + while True: + try: + await self.acquire() + break + except asyncio.exceptions.CancelledError as e: + err = e + + self._waiters.remove((loop, fut)) + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self.notify(1) + raise + + async def wait_for(self, predicate: Callable[[], _T]) -> _T: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result def notify(self, n: int = 1) -> None: - self._condition.notify(n) + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + idx = 0 + to_remove = [] + for loop, fut in self._waiters: + if idx >= n: + break + + if fut.done(): + continue + + try: + loop.call_soon_threadsafe(_safe_set_result, fut) + except RuntimeError: + # Loop was closed, ignore. 
+ to_remove.append((loop, fut)) + continue + + idx += 1 + + for waiter in to_remove: + self._waiters.remove(waiter) def notify_all(self) -> None: - self._condition.notify_all() + """Wake up all coroutines waiting on this condition. This method acts + like notify(), but wakes up all waiting coroutines instead of one. If the + calling coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + def locked(self) -> bool: + """Only needed for tests in test_locks.""" + return self._condition._lock.locked() # type: ignore[attr-defined] def release(self) -> None: self._condition.release() diff --git a/pymongo/pool_options.py b/pymongo/pool_options.py index 6ec97d7d1..61486c91c 100644 --- a/pymongo/pool_options.py +++ b/pymongo/pool_options.py @@ -33,6 +33,7 @@ from pymongo.common import ( MAX_POOL_SIZE, MIN_POOL_SIZE, WAIT_QUEUE_TIMEOUT, + has_c, ) if TYPE_CHECKING: @@ -363,6 +364,11 @@ class PoolOptions: # }, # 'platform': 'CPython 3.8.0|MyPlatform' # } + if has_c(): + self.__metadata["driver"]["name"] = "{}|{}".format( + self.__metadata["driver"]["name"], + "c", + ) if not is_sync: self.__metadata["driver"]["name"] = "{}|{}".format( self.__metadata["driver"]["name"], diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 7a41aef31..15a1913ea 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -34,6 +34,7 @@ from typing import ( TypeVar, Union, cast, + overload, ) from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions @@ -333,13 +334,33 @@ class Collection(common.BaseObject, Generic[_DocumentType]): """ return self._database + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Collection[_DocumentType]: + ... 
+ + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Collection[_DocumentTypeArg]: + ... + def with_options( self, codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> Collection[_DocumentType]: + ) -> Collection[_DocumentType] | Collection[_DocumentTypeArg]: """Get a clone of this collection changing the specified settings. >>> coll1.read_preference diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index c57a59e09..a0bef5534 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -146,13 +146,33 @@ class Database(common.BaseObject, Generic[_DocumentType]): """The name of this :class:`Database`.""" return self._name + @overload + def with_options( + self, + codec_options: None = None, + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Database[_DocumentType]: + ... + + @overload + def with_options( + self, + codec_options: bson.CodecOptions[_DocumentTypeArg], + read_preference: Optional[_ServerMode] = ..., + write_concern: Optional[WriteConcern] = ..., + read_concern: Optional[ReadConcern] = ..., + ) -> Database[_DocumentTypeArg]: + ... + def with_options( self, codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - ) -> Database[_DocumentType]: + ) -> Database[_DocumentType] | Database[_DocumentTypeArg]: """Get a clone of this database changing the specified settings. 
>>> db1.read_preference diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 94a1d1043..eb007a347 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -62,7 +62,7 @@ from pymongo.errors import ( # type:ignore[attr-defined] _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -909,7 +909,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket. and not options.tls_allow_invalid_hostnames ): try: - ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined] except _CertificateError: ssl_sock.close() raise @@ -988,7 +988,8 @@ class Pool: # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _create_lock() + _lock = _create_lock() + self.lock = _Lock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1014,7 +1015,7 @@ class Pool: # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self.size_cond = threading.Condition(_lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1022,7 +1023,7 @@ class Pool: # The second portion of the wait queue. 
# Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self._max_connecting_cond = threading.Condition(_lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index e8070e30a..a350c1702 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -39,7 +39,7 @@ from pymongo.errors import ( WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -170,8 +170,9 @@ class Topology: self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _create_lock() - self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _Lock(_lock) + self._condition = self._settings.condition_class(_lock) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/requirements/typing.txt b/requirements/typing.txt new file mode 100644 index 000000000..1669e6bbc --- /dev/null +++ b/requirements/typing.txt @@ -0,0 +1,7 @@ +mypy==1.11.2 +pyright==1.1.382.post1 +typing_extensions +-r ./encryption.txt +-r ./ocsp.txt +-r ./zstd.txt +-r ./aws.txt diff --git a/test/__init__.py b/test/__init__.py index 1a17ff14c..af12bc032 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -313,7 +313,7 @@ class ClientContext: params = self.cmd_line["parsed"].get("setParameter", {}) if params.get("enableTestCommands") == "1": self.test_commands_enabled = True - self.has_ipv6 = self._server_started_with_ipv6() + self.has_ipv6 = self._server_started_with_ipv6() self.is_mongos = (self.hello).get("msg") == 
"isdbgrid" if self.is_mongos: diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 0d9433158..2a44785b2 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -313,7 +313,7 @@ class AsyncClientContext: params = self.cmd_line["parsed"].get("setParameter", {}) if params.get("enableTestCommands") == "1": self.test_commands_enabled = True - self.has_ipv6 = await self._server_started_with_ipv6() + self.has_ipv6 = await self._server_started_with_ipv6() self.is_mongos = (await self.hello).get("msg") == "isdbgrid" if self.is_mongos: diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index f610f3277..5c0633179 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -17,6 +17,7 @@ from __future__ import annotations import _thread as thread import asyncio +import base64 import contextlib import copy import datetime @@ -31,13 +32,15 @@ import subprocess import sys import threading import time -from typing import Iterable, Type, no_type_check +import uuid +from typing import Any, Iterable, Type, no_type_check from unittest import mock from unittest.mock import patch import pytest import pytest_asyncio +from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation from pymongo.operations import _Op sys.path[0:0] = [""] @@ -57,6 +60,7 @@ from test.asynchronous import ( unittest, ) from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.test_binary import BinaryData from test.utils import ( NTHREADS, CMAPListener, @@ -95,7 +99,7 @@ from pymongo.asynchronous.pool import ( from pymongo.asynchronous.settings import TOPOLOGY_TYPE from pymongo.asynchronous.topology import _ErrorContext from pymongo.client_options import ClientOptions -from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT +from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c from pymongo.compression_support import _have_snappy, 
_have_zstd from pymongo.driver_info import DriverInfo from pymongo.errors import ( @@ -343,7 +347,10 @@ class AsyncClientUnitTest(AsyncUnitTest): async def test_metadata(self): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo|async" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|async" + else: + metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options @@ -366,7 +373,10 @@ class AsyncClientUnitTest(AsyncUnitTest): with self.assertRaises(TypeError): self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. - metadata["driver"]["name"] = "PyMongo|async|FooDriver" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|async|FooDriver" + else: + metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) client = self.simple_client( "foo", @@ -1927,7 +1937,10 @@ class TestClient(AsyncIntegrationTest): async def _test_handshake(self, env_vars, expected_env): with patch.dict("os.environ", env_vars): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo|async" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|async" + else: + metadata["driver"]["name"] = "PyMongo|async" if expected_env is not None: metadata["env"] = expected_env @@ -2020,6 +2033,75 @@ class TestClient(AsyncIntegrationTest): async def test_dict_hints_create_index(self): await self.db.t.create_index({"x": pymongo.ASCENDING}) + async def test_legacy_java_uuid_roundtrip(self): + data = BinaryData.java_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) + + await async_client_context.client.pymongo_test.drop_collection("java_uuid") + db = async_client_context.client.pymongo_test + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) + + await 
coll.insert_many(docs) + self.assertEqual(5, await coll.count_documents({})) + async for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + async for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + await async_client_context.client.pymongo_test.drop_collection("java_uuid") + + async def test_legacy_csharp_uuid_roundtrip(self): + data = BinaryData.csharp_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) + + await async_client_context.client.pymongo_test.drop_collection("csharp_uuid") + db = async_client_context.client.pymongo_test + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) + + await coll.insert_many(docs) + self.assertEqual(5, await coll.count_documents({})) + async for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + async for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + await async_client_context.client.pymongo_test.drop_collection("csharp_uuid") + + async def test_uri_to_uuid(self): + uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" + client = await self.async_single_client(uri, connect=False) + self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) + + async def test_uuid_queries(self): + db = async_client_context.client.pymongo_test + coll = db.test + await coll.drop() + + uu = uuid.uuid4() + await coll.insert_one({"uuid": Binary(uu.bytes, 3)}) + self.assertEqual(1, await coll.count_documents({})) + + # Test regular UUID queries (using subtype 4). 
+ coll = db.get_collection( + "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + ) + self.assertEqual(0, await coll.count_documents({"uuid": uu})) + await coll.insert_one({"uuid": uu}) + self.assertEqual(2, await coll.count_documents({})) + docs = await coll.find({"uuid": uu}).to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(uu, docs[0]["uuid"]) + + # Test both. + uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) + predicate = {"uuid": {"$in": [uu, uu_legacy]}} + self.assertEqual(2, await coll.count_documents(predicate)) + docs = await coll.find(predicate).to_list() + self.assertEqual(2, len(docs)) + await coll.drop() + class TestExhaustCursor(AsyncIntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" @@ -2351,7 +2433,9 @@ class TestMongoClientFailover(AsyncMockClientTest): # But it can reconnect. c.revive_host("a:1") - await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + await (await c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(await c.address, ("a", 1)) async def _test_network_error(self, operation_callback): diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 3a1729945..80cfd30bd 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -30,6 +30,7 @@ from test.utils import ( ) from unittest.mock import patch +import pymongo from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( @@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest): timeoutMS=2000, w="majority", ) - await client.admin.command("ping") # Init the client first. 
+ # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + await client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 33eaacee9..e79ad0064 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1414,7 +1414,7 @@ class TestCursor(AsyncIntegrationTest): async def test_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) @@ -1456,7 +1456,7 @@ class TestCursor(AsyncIntegrationTest): async def test_command_cursor_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index c5d62323d..61369c854 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -711,7 +711,7 @@ class TestDatabase(AsyncIntegrationTest): "write_concern": WriteConcern(w=1), "read_concern": ReadConcern(level="local"), } - db2 = db1.with_options(**newopts) # type: ignore[arg-type] + db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload] for opt in newopts: self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py new file mode 100644 index 
000000000..e0e7f2fc8 --- /dev/null +++ b/test/asynchronous/test_locks.py @@ -0,0 +1,513 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for lock.py""" +from __future__ import annotations + +import asyncio +import sys +import threading +import unittest + +sys.path[0:0] = [""] + +from pymongo.lock import _ACondition + + +# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py +# Includes tests for: +# - https://github.com/python/cpython/issues/111693 +# - https://github.com/python/cpython/issues/112202 +class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + await 
asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_wait_cancel(self): + cond = _ACondition(threading.Condition(threading.Lock())) + await cond.acquire() + + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + await cond.acquire() + self.assertTrue(cond.locked()) + + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) + + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) + + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass + + self.assertTrue(cond.locked()) + + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False + + cond = _ACondition(threading.Condition(threading.Lock())) + + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() + + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting + + await 
cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + with self.assertRaises(RuntimeError): + await cond.wait() + + async def test_wait_for(self): + cond = _ACondition(threading.Condition(threading.Lock())) + presult = False + + def predicate(): + return presult + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() + return True + + t = asyncio.create_task(c1(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) + + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + async def test_wait_for_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) + + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) + + async def test_notify(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await 
asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) + + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _ACondition(threading.Condition(threading.Lock())) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) + + async def test_timeout_in_block(self): + condition = _ACondition(threading.Condition(threading.Lock())) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. 
+ wake = False + raised = None + cond = _ACondition(threading.Condition(threading.Lock())) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. + wake = False + raised = None + cond = _ACondition(threading.Condition(threading.Lock())) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. 
+ self.assertIs(err.exception, raised) + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = _ACondition(threading.Condition(threading.Lock())) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. 
+ condition = _ACondition(threading.Condition(threading.Lock())) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + +class TestCondition(unittest.IsolatedAsyncioTestCase): + async def test_multiple_loops_notify(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + def tmain(cond): + async def atmain(cond): + await asyncio.sleep(1) + async with cond: + cond.notify(1) + + asyncio.run(atmain(cond)) + + t = threading.Thread(target=tmain, args=(cond,)) + t.start() + + async with cond: + self.assertTrue(await cond.wait(30)) + t.join() + + async def test_multiple_loops_notify_all(self): + cond = _ACondition(threading.Condition(threading.Lock())) + results = [] + + def tmain(cond, results): + async def atmain(cond, results): + await asyncio.sleep(1) + async with cond: + res = await cond.wait(30) + results.append(res) + + asyncio.run(atmain(cond, results)) + + nthreads = 5 + threads = [] + for _ in range(nthreads): + threads.append(threading.Thread(target=tmain, args=(cond, results))) + for t in threads: + t.start() + + await 
asyncio.sleep(2) + async with cond: + cond.notify_all() + + for t in threads: + t.join() + + self.assertEqual(results, [True] * nthreads) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_raw_bson.py b/test/asynchronous/test_raw_bson.py new file mode 100644 index 000000000..70832ea66 --- /dev/null +++ b/test/asynchronous/test_raw_bson.py @@ -0,0 +1,219 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import datetime +import sys +import uuid + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest + +from bson import Code, DBRef, decode, encode +from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation +from bson.codec_options import CodecOptions +from bson.errors import InvalidBSON +from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument +from bson.son import SON + +_IS_SYNC = False + + +class TestRawBSONDocument(AsyncIntegrationTest): + # {'_id': ObjectId('556df68b6e32ab21a95e0785'), + # 'name': 'Sherlock', + # 'addresses': [{'street': 'Baker Street'}]} + bson_string = ( + b"Z\x00\x00\x00\x07_id\x00Um\xf6\x8bn2\xab!\xa9^\x07\x85\x02name\x00\t" + b"\x00\x00\x00Sherlock\x00\x04addresses\x00&\x00\x00\x00\x030\x00\x1e" + b"\x00\x00\x00\x02street\x00\r\x00\x00\x00Baker Street\x00\x00\x00\x00" + ) + document = RawBSONDocument(bson_string) + + async def asyncTearDown(self): + if 
async_client_context.connected: + await self.client.pymongo_test.test_raw.drop() + + def test_decode(self): + self.assertEqual("Sherlock", self.document["name"]) + first_address = self.document["addresses"][0] + self.assertIsInstance(first_address, RawBSONDocument) + self.assertEqual("Baker Street", first_address["street"]) + + def test_raw(self): + self.assertEqual(self.bson_string, self.document.raw) + + def test_empty_doc(self): + doc = RawBSONDocument(encode({})) + with self.assertRaises(KeyError): + doc["does-not-exist"] + + def test_invalid_bson_sequence(self): + bson_byte_sequence = encode({"a": 1}) + encode({}) + with self.assertRaisesRegex(InvalidBSON, "invalid object length"): + RawBSONDocument(bson_byte_sequence) + + def test_invalid_bson_eoo(self): + invalid_bson_eoo = encode({"a": 1})[:-1] + b"\x01" + with self.assertRaisesRegex(InvalidBSON, "bad eoo"): + RawBSONDocument(invalid_bson_eoo) + + @async_client_context.require_connection + async def test_round_trip(self): + db = self.client.get_database( + "pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument) + ) + await db.test_raw.insert_one(self.document) + result = await db.test_raw.find_one(self.document["_id"]) + assert result is not None + self.assertIsInstance(result, RawBSONDocument) + self.assertEqual(dict(self.document.items()), dict(result.items())) + + @async_client_context.require_connection + async def test_round_trip_raw_uuid(self): + coll = self.client.get_database("pymongo_test").test_raw + uid = uuid.uuid4() + doc = {"_id": 1, "bin4": Binary(uid.bytes, 4), "bin3": Binary(uid.bytes, 3)} + raw = RawBSONDocument(encode(doc)) + await coll.insert_one(raw) + self.assertEqual(await coll.find_one(), doc) + uuid_coll = coll.with_options( + codec_options=coll.codec_options.with_options( + uuid_representation=UuidRepresentation.STANDARD + ) + ) + self.assertEqual( + await uuid_coll.find_one(), {"_id": 1, "bin4": uid, "bin3": Binary(uid.bytes, 3)} + ) + + # Test that the raw 
bytes haven't changed. + raw_coll = coll.with_options(codec_options=DEFAULT_RAW_BSON_OPTIONS) + self.assertEqual(await raw_coll.find_one(), raw) + + def test_with_codec_options(self): + # {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), + # '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')} + # encoded with JAVA_LEGACY uuid representation. + bson_string = ( + b"-\x00\x00\x00\x05_id\x00\x10\x00\x00\x00\x03eI_\x97\x8f\xabo\x02" + b"\xff`L\x87\xad\x85\xbf\x9f\tdate\x00\x8a\xd6\xb9\xbaM" + b"\x01\x00\x00\x00" + ) + document = RawBSONDocument( + bson_string, + codec_options=CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ), + ) + + self.assertEqual(uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), document["_id"]) + + @async_client_context.require_connection + async def test_round_trip_codec_options(self): + doc = { + "date": datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), + "_id": uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), + } + db = self.client.pymongo_test + coll = db.get_collection( + "test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY) + ) + await coll.insert_one(doc) + raw_java_legacy = CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ) + coll = db.get_collection("test_raw", codec_options=raw_java_legacy) + self.assertEqual( + RawBSONDocument(encode(doc, codec_options=raw_java_legacy)), await coll.find_one() + ) + + @async_client_context.require_connection + async def test_raw_bson_document_embedded(self): + doc = {"embedded": self.document} + db = self.client.pymongo_test + await db.test_raw.insert_one(doc) + result = await db.test_raw.find_one() + assert result is not None + self.assertEqual(decode(self.document.raw), result["embedded"]) + + # Make sure that CodecOptions are preserved. 
+ # {'embedded': [ + # {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), + # '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')} + # ]} + # encoded with JAVA_LEGACY uuid representation. + bson_string = ( + b"D\x00\x00\x00\x04embedded\x005\x00\x00\x00\x030\x00-\x00\x00\x00" + b"\tdate\x00\x8a\xd6\xb9\xbaM\x01\x00\x00\x05_id\x00\x10\x00\x00" + b"\x00\x03eI_\x97\x8f\xabo\x02\xff`L\x87\xad\x85\xbf\x9f\x00\x00" + b"\x00" + ) + rbd = RawBSONDocument( + bson_string, + codec_options=CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ), + ) + + await db.test_raw.drop() + await db.test_raw.insert_one(rbd) + result = await db.get_collection( + "test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY) + ).find_one() + assert result is not None + self.assertEqual(rbd["embedded"][0]["_id"], result["embedded"][0]["_id"]) + + @async_client_context.require_connection + async def test_write_response_raw_bson(self): + coll = self.client.get_database( + "pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument) + ).test_raw + + # No Exceptions raised while handling write response. 
+ await coll.insert_one(self.document) + await coll.delete_one(self.document) + await coll.insert_many([self.document]) + await coll.delete_many(self.document) + await coll.update_one(self.document, {"$set": {"a": "b"}}, upsert=True) + await coll.update_many(self.document, {"$set": {"b": "c"}}) + + def test_preserve_key_ordering(self): + keyvaluepairs = [ + ("a", 1), + ("b", 2), + ("c", 3), + ] + rawdoc = RawBSONDocument(encode(SON(keyvaluepairs))) + + for rkey, elt in zip(rawdoc, keyvaluepairs): + self.assertEqual(rkey, elt[0]) + + def test_contains_code_with_scope(self): + doc = RawBSONDocument(encode({"value": Code("x=1", scope={})})) + + self.assertEqual(decode(encode(doc)), {"value": Code("x=1", {})}) + self.assertEqual(doc["value"].scope, RawBSONDocument(encode({}))) + + def test_contains_dbref(self): + doc = RawBSONDocument(encode({"value": DBRef("test", "id")})) + raw = {"$ref": "test", "$id": "id"} + raw_encoded = encode(decode(encode(raw))) + + self.assertEqual(decode(encode(doc)), {"value": DBRef("test", "id")}) + self.assertEqual(doc["value"].raw, raw_encoded) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_binary.py b/test/test_binary.py index 93f6d0831..567c5ae92 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -34,53 +34,49 @@ from bson.binary import * from bson.codec_options import CodecOptions from bson.son import SON from pymongo.common import validate_uuid_representation -from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +class BinaryData: + # Generated by the Java driver + from_java = ( + b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu" + b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND" + b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+" + b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1" + b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR" + b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA" + 
b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z" + b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf" + b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx" + b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My" + b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB" + b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp" + b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc" + b"0MQAA" + ) + java_data = base64.b64decode(from_java) + + # Generated by the .net driver + from_csharp = ( + b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl" + b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2" + b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V" + b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl" + b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A" + b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z" + b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU" + b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn" + b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA" + b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT" + b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP" + b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00" + b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=" + ) + csharp_data = base64.b64decode(from_csharp) + + class TestBinary(unittest.TestCase): - csharp_data: bytes - java_data: bytes - - @classmethod - def setUpClass(cls): - # Generated by the Java driver - from_java = ( - b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu" - b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND" - b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+" - b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1" - b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR" - b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA" - b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z" - b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf" - 
b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx" - b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My" - b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB" - b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp" - b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc" - b"0MQAA" - ) - cls.java_data = base64.b64decode(from_java) - - # Generated by the .net driver - from_csharp = ( - b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl" - b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2" - b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V" - b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl" - b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A" - b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z" - b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU" - b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn" - b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA" - b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT" - b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP" - b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00" - b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=" - ) - cls.csharp_data = base64.b64decode(from_csharp) - def test_binary(self): a_string = "hello world" a_binary = Binary(b"hello world") @@ -159,7 +155,7 @@ class TestBinary(unittest.TestCase): def test_legacy_java_uuid(self): # Test decoding - data = self.java_data + data = BinaryData.java_data docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY)) for d in docs: self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) @@ -197,27 +193,8 @@ class TestBinary(unittest.TestCase): ) self.assertEqual(data, encoded) - @client_context.require_connection - def test_legacy_java_uuid_roundtrip(self): - data = self.java_data - docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) - - client_context.client.pymongo_test.drop_collection("java_uuid") - db = 
client_context.client.pymongo_test - coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) - - coll.insert_many(docs) - self.assertEqual(5, coll.count_documents({})) - for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) - - coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - client_context.client.pymongo_test.drop_collection("java_uuid") - def test_legacy_csharp_uuid(self): - data = self.csharp_data + data = BinaryData.csharp_data # Test decoding docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY)) @@ -257,59 +234,6 @@ class TestBinary(unittest.TestCase): ) self.assertEqual(data, encoded) - @client_context.require_connection - def test_legacy_csharp_uuid_roundtrip(self): - data = self.csharp_data - docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) - - client_context.client.pymongo_test.drop_collection("csharp_uuid") - db = client_context.client.pymongo_test - coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) - - coll.insert_many(docs) - self.assertEqual(5, coll.count_documents({})) - for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) - - coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - client_context.client.pymongo_test.drop_collection("csharp_uuid") - - def test_uri_to_uuid(self): - uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" - client = MongoClient(uri, connect=False) - self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) - - @client_context.require_connection - def test_uuid_queries(self): - db = client_context.client.pymongo_test - coll = db.test - coll.drop() - - uu = uuid.uuid4() - 
coll.insert_one({"uuid": Binary(uu.bytes, 3)}) - self.assertEqual(1, coll.count_documents({})) - - # Test regular UUID queries (using subtype 4). - coll = db.get_collection( - "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - ) - self.assertEqual(0, coll.count_documents({"uuid": uu})) - coll.insert_one({"uuid": uu}) - self.assertEqual(2, coll.count_documents({})) - docs = list(coll.find({"uuid": uu})) - self.assertEqual(1, len(docs)) - self.assertEqual(uu, docs[0]["uuid"]) - - # Test both. - uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) - predicate = {"uuid": {"$in": [uu, uu_legacy]}} - self.assertEqual(2, coll.count_documents(predicate)) - docs = list(coll.find(predicate)) - self.assertEqual(2, len(docs)) - coll.drop() - def test_pickle(self): b1 = Binary(b"123", 2) diff --git a/test/test_client.py b/test/test_client.py index bc45325f0..c88a8fd9b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -17,6 +17,7 @@ from __future__ import annotations import _thread as thread import asyncio +import base64 import contextlib import copy import datetime @@ -31,12 +32,14 @@ import subprocess import sys import threading import time -from typing import Iterable, Type, no_type_check +import uuid +from typing import Any, Iterable, Type, no_type_check from unittest import mock from unittest.mock import patch import pytest +from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation from pymongo.operations import _Op sys.path[0:0] = [""] @@ -56,6 +59,7 @@ from test import ( unittest, ) from test.pymongo_mocks import MockClient +from test.test_binary import BinaryData from test.utils import ( NTHREADS, CMAPListener, @@ -83,7 +87,7 @@ from bson.son import SON from bson.tz_util import utc from pymongo import event_loggers, message, monitoring from pymongo.client_options import ClientOptions -from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT +from pymongo.common import 
_UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c from pymongo.compression_support import _have_snappy, _have_zstd from pymongo.driver_info import DriverInfo from pymongo.errors import ( @@ -335,7 +339,10 @@ class ClientUnitTest(UnitTest): def test_metadata(self): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c" + else: + metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options @@ -358,7 +365,10 @@ class ClientUnitTest(UnitTest): with self.assertRaises(TypeError): self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. - metadata["driver"]["name"] = "PyMongo|FooDriver" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c|FooDriver" + else: + metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) client = self.simple_client( "foo", @@ -1885,7 +1895,10 @@ class TestClient(IntegrationTest): def _test_handshake(self, env_vars, expected_env): with patch.dict("os.environ", env_vars): metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo" + if has_c(): + metadata["driver"]["name"] = "PyMongo|c" + else: + metadata["driver"]["name"] = "PyMongo" if expected_env is not None: metadata["env"] = expected_env @@ -1978,6 +1991,75 @@ class TestClient(IntegrationTest): def test_dict_hints_create_index(self): self.db.t.create_index({"x": pymongo.ASCENDING}) + def test_legacy_java_uuid_roundtrip(self): + data = BinaryData.java_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) + + client_context.client.pymongo_test.drop_collection("java_uuid") + db = client_context.client.pymongo_test + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) + + coll.insert_many(docs) + self.assertEqual(5, 
coll.count_documents({})) + for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + client_context.client.pymongo_test.drop_collection("java_uuid") + + def test_legacy_csharp_uuid_roundtrip(self): + data = BinaryData.csharp_data + docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) + + client_context.client.pymongo_test.drop_collection("csharp_uuid") + db = client_context.client.pymongo_test + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) + + coll.insert_many(docs) + self.assertEqual(5, coll.count_documents({})) + for d in coll.find(): + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + for d in coll.find(): + self.assertNotEqual(d["newguid"], d["newguidstring"]) + client_context.client.pymongo_test.drop_collection("csharp_uuid") + + def test_uri_to_uuid(self): + uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" + client = self.single_client(uri, connect=False) + self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) + + def test_uuid_queries(self): + db = client_context.client.pymongo_test + coll = db.test + coll.drop() + + uu = uuid.uuid4() + coll.insert_one({"uuid": Binary(uu.bytes, 3)}) + self.assertEqual(1, coll.count_documents({})) + + # Test regular UUID queries (using subtype 4). 
+ coll = db.get_collection( + "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + ) + self.assertEqual(0, coll.count_documents({"uuid": uu})) + coll.insert_one({"uuid": uu}) + self.assertEqual(2, coll.count_documents({})) + docs = coll.find({"uuid": uu}).to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(uu, docs[0]["uuid"]) + + # Test both. + uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) + predicate = {"uuid": {"$in": [uu, uu_legacy]}} + self.assertEqual(2, coll.count_documents(predicate)) + docs = coll.find(predicate).to_list() + self.assertEqual(2, len(docs)) + coll.drop() + class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" @@ -2307,7 +2389,9 @@ class TestMongoClientFailover(MockClientTest): # But it can reconnect. c.revive_host("a:1") - (c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + (c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(c.address, ("a", 1)) def _test_network_error(self, operation_callback): diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index ebbdc74c1..d1aff03fc 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -30,6 +30,7 @@ from test.utils import ( ) from unittest.mock import patch +import pymongo from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( ClientBulkWriteException, @@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(IntegrationTest): timeoutMS=2000, w="majority", ) - client.admin.command("ping") # Init the client first. 
+ # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) diff --git a/test/test_cursor.py b/test/test_cursor.py index d99732aec..7c073bf35 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1405,7 +1405,7 @@ class TestCursor(IntegrationTest): def test_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = client.pymongo.test coll.insert_many([{} for _ in range(5)]) @@ -1447,7 +1447,7 @@ class TestCursor(IntegrationTest): def test_command_cursor_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = client.pymongo.test coll.insert_many([{} for _ in range(5)]) diff --git a/test/test_database.py b/test/test_database.py index fe07f343c..4973ed013 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -702,7 +702,7 @@ class TestDatabase(IntegrationTest): "write_concern": WriteConcern(w=1), "read_concern": ReadConcern(level="local"), } - db2 = db1.with_options(**newopts) # type: ignore[arg-type] + db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload] for opt in newopts: self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 11bc80dd9..4d9a3ceb0 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -19,8 +19,7 @@ import uuid sys.path[0:0] = [""] -from test import client_context, unittest -from test.test_client import IntegrationTest +from test 
import IntegrationTest, client_context, unittest from bson import Code, DBRef, decode, encode from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation @@ -29,6 +28,8 @@ from bson.errors import InvalidBSON from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument from bson.son import SON +_IS_SYNC = True + class TestRawBSONDocument(IntegrationTest): # {'_id': ObjectId('556df68b6e32ab21a95e0785'), diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 8e030f61e..7cab42cca 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -19,6 +19,7 @@ import os import threading from test import IntegrationTest, client_context, unittest from test.utils import ( + CMAPListener, OvertCommandListener, SpecTestCreator, get_pool, @@ -27,6 +28,7 @@ from test.utils import ( from test.utils_selection_tests import create_topology from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference @@ -131,19 +133,20 @@ class TestProse(IntegrationTest): @client_context.require_multiple_mongoses def test_load_balancing(self): listener = OvertCommandListener() + cmap_listener = CMAPListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. client = self.rs_client( client_context.mongos_seeds(), appName="loadBalancingTest", - event_listeners=[listener], + event_listeners=[listener, cmap_listener], localThresholdMS=30000, minPoolSize=10, ) - self.addCleanup(client.close) wait_until(lambda: len(client.nodes) == 2, "discover both nodes") - wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections") - # Delay find commands on + # Wait for both pools to be populated. + cmap_listener.wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. 
delay_finds = { "configureFailPoint": "failCommand", "mode": {"times": 10000}, @@ -161,7 +164,7 @@ class TestProse(IntegrationTest): freqs = self.frequencies(client, listener) self.assertLessEqual(freqs[delayed_server], 0.25) listener.reset() - freqs = self.frequencies(client, listener, n_finds=100) + freqs = self.frequencies(client, listener, n_finds=150) self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) diff --git a/test/test_typing.py b/test/test_typing.py index 6cfe40537..441707616 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -34,7 +34,7 @@ from typing import ( cast, ) -try: +if TYPE_CHECKING: from typing_extensions import NotRequired, TypedDict from bson import ObjectId @@ -49,16 +49,13 @@ try: year: int class ImplicitMovie(TypedDict): - _id: NotRequired[ObjectId] # pyright: ignore[reportGeneralTypeIssues] + _id: NotRequired[ObjectId] name: str year: int - -except ImportError: - Movie = dict # type:ignore[misc,assignment] - ImplicitMovie = dict # type: ignore[assignment,misc] - MovieWithId = dict # type: ignore[assignment,misc] - TypedDict = None - NotRequired = None # type: ignore[assignment] +else: + Movie = dict + ImplicitMovie = dict + NotRequired = None try: @@ -234,6 +231,19 @@ class TestPymongo(IntegrationTest): execute_transaction, read_preference=ReadPreference.PRIMARY ) + def test_with_options(self) -> None: + coll: Collection[Dict[str, Any]] = self.coll + coll.drop() + doc = {"name": "foo", "year": 1982, "other": 1} + coll.insert_one(doc) + + coll2 = coll.with_options(codec_options=CodecOptions(document_class=Movie)) + retrieved = coll2.find_one() + assert retrieved is not None + assert retrieved["name"] == "foo" + # We expect a type error here. 
+ assert retrieved["other"] == 1 # type:ignore[typeddict-item] + class TestDecode(unittest.TestCase): def test_bson_decode(self) -> None: @@ -426,7 +436,7 @@ class TestDocumentType(PyMongoTestCase): ) coll.bulk_write( [ - InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) + InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore ] # No error because it is in-line. ) @@ -443,7 +453,7 @@ class TestDocumentType(PyMongoTestCase): ) coll.bulk_write( [ - ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) + ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore ] # No error because it is in-line. ) @@ -566,7 +576,7 @@ class TestCodecOptionsDocumentType(unittest.TestCase): def test_typeddict_document_type(self) -> None: options: CodecOptions[Movie] = CodecOptions() # Suppress: Cannot instantiate type "Type[Movie]". - obj = options.document_class(name="a", year=1) # type: ignore[misc] + obj = options.document_class(name="a", year=1) assert obj["year"] == 1 assert obj["name"] == "a" diff --git a/tools/synchro.py b/tools/synchro.py index 59d6e653e..e0c194f96 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -23,7 +23,7 @@ import re from os import listdir from pathlib import Path -from unasync import Rule, unasync_files # type: ignore[import] +from unasync import Rule, unasync_files # type: ignore[import-not-found] replacements = { "AsyncCollection": "Collection", @@ -101,6 +101,7 @@ replacements = { "default_async": "default", "aclose": "close", "PyMongo|async": "PyMongo", + "PyMongo|c|async": "PyMongo|c", "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", } @@ -144,7 +145,17 @@ gridfs_files = [ _gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file() ] -test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()] + +def async_only_test(f: str) -> bool: + """Return True for 
async tests that should not be converted to sync.""" + return f in ["test_locks.py"] + + +test_files = [ + _test_base + f + for f in listdir(_test_base) + if (Path(_test_base) / f).is_file() and not async_only_test(f) +] sync_files = [ _pymongo_dest_base + f @@ -171,16 +182,17 @@ converted_tests = [ "test_change_stream.py", "test_client.py", "test_client_bulk_write.py", + "test_client_context.py", "test_collection.py", "test_cursor.py", "test_database.py", "test_encryption.py", "test_grid_file.py", "test_logger.py", + "test_monitoring.py", + "test_raw_bson.py", "test_session.py", "test_transactions.py", - "test_client_context.py", - "test_monitoring.py", ] sync_test_files = [ @@ -240,7 +252,7 @@ def translate_locks(lines: list[str]) -> list[str]: lock_lines = [line for line in lines if "_Lock(" in line] cond_lines = [line for line in lines if "_Condition(" in line] for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\(\))\)", line) + res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) if res: old = res[0] index = lines.index(line)