Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-09-30 20:12:16 -05:00
commit 1418c90be2
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
38 changed files with 1284 additions and 224 deletions

View File

@ -1324,7 +1324,7 @@ def decode_iter(
elements = data[position : position + obj_size]
position += obj_size
yield _bson_to_dict(elements, opts) # type:ignore[misc, type-var]
yield _bson_to_dict(elements, opts) # type:ignore[misc]
@overload
@ -1370,7 +1370,7 @@ def decode_file_iter(
raise InvalidBSON("cut off in middle of objsize")
obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4
elements = size_data + file_obj.read(max(0, obj_size))
yield _bson_to_dict(elements, opts) # type:ignore[type-var, arg-type, misc]
yield _bson_to_dict(elements, opts) # type:ignore[arg-type, misc]
def is_valid(bson: bytes) -> bool:

View File

@ -223,7 +223,7 @@ class Decimal128:
"from list or tuple. Must have exactly 2 "
"elements."
)
self.__high, self.__low = value # type: ignore
self.__high, self.__low = value
else:
raise TypeError(f"Cannot convert {value!r} to Decimal128")

View File

@ -324,7 +324,7 @@ class JSONOptions(_BASE_CLASS):
"JSONOptions.datetime_representation must be one of LEGACY, "
"NUMBERLONG, or ISO8601 from DatetimeRepresentation."
)
self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) # type:ignore[arg-type]
self = cast(JSONOptions, super().__new__(cls, *args, **kwargs))
if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL):
raise ValueError(
"JSONOptions.json_mode must be one of LEGACY, RELAXED, "

View File

@ -68,7 +68,7 @@ class SON(Dict[_Key, _Value]):
self.update(kwargs)
def __new__(cls: Type[SON[_Key, _Value]], *args: Any, **kwargs: Any) -> SON[_Key, _Value]:
instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var]
instance = super().__new__(cls, *args, **kwargs)
instance.__keys = []
return instance

View File

@ -13,8 +13,9 @@ features = ["docs","test"]
test = "sphinx-build -E -b doctest doc ./doc/_build/doctest"
[envs.typing]
features = ["encryption", "ocsp", "zstd", "aws"]
dependencies = ["mypy==1.2.0","pyright==1.1.290", "certifi", "typing_extensions"]
pre-install-commands = [
"pip install -q -r requirements/typing.txt",
]
[envs.typing.scripts]
check-mypy = [
"mypy --install-types --non-interactive bson gridfs tools pymongo",

View File

@ -88,7 +88,7 @@ TEXT = "text"
from pymongo import _csot
from pymongo._version import __version__, get_version_string, version_tuple
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION, has_c
from pymongo.cursor import CursorType
from pymongo.operations import (
DeleteMany,
@ -116,16 +116,6 @@ version = __version__
"""Current version of PyMongo."""
def has_c() -> bool:
"""Is the C extension installed?"""
try:
from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401
return True
except ImportError:
return False
def timeout(seconds: Optional[float]) -> ContextManager[None]:
"""**(Provisional)** Apply the given timeout for a block of operations.

View File

@ -75,14 +75,13 @@ class _TimeoutContext(AbstractContextManager):
self._timeout = timeout
self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None
def __enter__(self) -> _TimeoutContext:
def __enter__(self) -> None:
timeout_token = TIMEOUT.set(self._timeout)
prev_deadline = DEADLINE.get()
next_deadline = time.monotonic() + self._timeout if self._timeout else float("inf")
deadline_token = DEADLINE.set(min(prev_deadline, next_deadline))
rtt_token = RTT.set(0.0)
self._tokens = (timeout_token, deadline_token, rtt_token)
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._tokens:

View File

@ -35,6 +35,7 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
@ -332,13 +333,33 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"""
return self._database
@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncCollection[_DocumentType]:
...
@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncCollection[_DocumentTypeArg]:
...
def with_options(
self,
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> AsyncCollection[_DocumentType]:
) -> AsyncCollection[_DocumentType] | AsyncCollection[_DocumentTypeArg]:
"""Get a clone of this collection changing the specified settings.
>>> coll1.read_preference

View File

@ -146,13 +146,33 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
"""The name of this :class:`AsyncDatabase`."""
return self._name
@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncDatabase[_DocumentType]:
...
@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncDatabase[_DocumentTypeArg]:
...
def with_options(
self,
codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> AsyncDatabase[_DocumentType]:
) -> AsyncDatabase[_DocumentType] | AsyncDatabase[_DocumentTypeArg]:
"""Get a clone of this database changing the specified settings.
>>> db1.read_preference

View File

@ -913,7 +913,7 @@ async def _configured_socket(
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
@ -992,7 +992,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _ALock(_create_lock())
_lock = _create_lock()
self.lock = _ALock(_lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1018,7 +1019,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
self.size_cond = _ACondition(threading.Condition(_lock))
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1026,7 +1027,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type]
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id

View File

@ -170,8 +170,9 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
self._lock = _ALock(_create_lock())
self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type]
_lock = _create_lock()
self._lock = _ALock(_lock)
self._condition = _ACondition(self._settings.condition_class(_lock))
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None

View File

@ -850,7 +850,7 @@ def get_validated_options(
return x
def get_setter_key(x: str) -> str:
return options.cased_key(x) # type: ignore[attr-defined]
return options.cased_key(x)
else:
validated_options = {}
@ -1060,3 +1060,13 @@ class _CaseInsensitiveDictionary(MutableMapping[str, Any]):
def cased_key(self, key: str) -> Any:
return self.__casedkeys[key.lower()]
def has_c() -> bool:
"""Is the C extension installed?"""
try:
from pymongo import _cmessage # type: ignore[attr-defined] # noqa: F401
return True
except ImportError:
return False

View File

@ -26,7 +26,7 @@ _NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
def _have_snappy() -> bool:
try:
import snappy # type:ignore[import] # noqa: F401
import snappy # type:ignore[import-not-found] # noqa: F401
return True
except ImportError:

View File

@ -21,7 +21,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
try:
import pymongocrypt # type:ignore[import] # noqa: F401
import pymongocrypt # type:ignore[import-untyped] # noqa: F401
# Check for pymongocrypt>=1.10.
from pymongocrypt import synchronous as _ # noqa: F401

View File

@ -14,17 +14,20 @@
from __future__ import annotations
import asyncio
import collections
import os
import threading
import time
import weakref
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, TypeVar
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
# References to instances of _create_lock
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()
_T = TypeVar("_T")
def _create_lock() -> threading.Lock:
"""Represents a lock that is tracked upon instantiation using a WeakSet and
@ -43,7 +46,14 @@ def _release_locks() -> None:
lock.release()
# Needed only for synchro.py compat.
def _Lock(lock: threading.Lock) -> threading.Lock:
return lock
class _ALock:
__slots__ = ("_lock",)
def __init__(self, lock: threading.Lock) -> None:
self._lock = lock
@ -81,9 +91,18 @@ class _ALock:
self.release()
def _safe_set_result(fut: asyncio.Future) -> None:
# Ensure the future hasn't been cancelled before calling set_result.
if not fut.done():
fut.set_result(False)
class _ACondition:
__slots__ = ("_condition", "_waiters")
def __init__(self, condition: threading.Condition) -> None:
self._condition = condition
self._waiters: collections.deque = collections.deque()
async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
if timeout > 0:
@ -99,30 +118,116 @@ class _ACondition:
await asyncio.sleep(0)
async def wait(self, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait(0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False
"""Wait until notified.
async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
if timeout is not None:
tstart = time.monotonic()
while True:
notified = self._condition.wait_for(predicate, 0.001)
if notified:
return True
if timeout is not None and (time.monotonic() - tstart) > timeout:
return False
If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.
This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.
This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._waiters.append((loop, fut))
self.release()
try:
try:
try:
await asyncio.wait_for(fut, timeout)
return True
except asyncio.TimeoutError:
return False # Return false on timeout for sync pool compat.
finally:
# Must re-acquire lock even if wait is cancelled.
# We only catch CancelledError here, since we don't want any
# other (fatal) errors with the future to cause us to spin.
err = None
while True:
try:
await self.acquire()
break
except asyncio.exceptions.CancelledError as e:
err = e
self._waiters.remove((loop, fut))
if err is not None:
try:
raise err # Re-raise most recent exception instance.
finally:
err = None # Break reference cycles.
except BaseException:
# Any error raised out of here _may_ have occurred after this Task
# believed to have been successfully notified.
# Make sure to notify another Task instead. This may result
# in a "spurious wakeup", which is allowed as part of the
# Condition Variable protocol.
self.notify(1)
raise
async def wait_for(self, predicate: Callable[[], _T]) -> _T:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be
interpreted as a boolean value. The method will repeatedly
wait() until it evaluates to true. The final predicate value is
the return value.
"""
result = predicate()
while not result:
await self.wait()
result = predicate()
return result
def notify(self, n: int = 1) -> None:
self._condition.notify(n)
"""By default, wake up one coroutine waiting on this condition, if any.
If the calling coroutine has not acquired the lock when this method
is called, a RuntimeError is raised.
This method wakes up at most n of the coroutines waiting for the
condition variable; it is a no-op if no coroutines are waiting.
Note: an awakened coroutine does not actually return from its
wait() call until it can reacquire the lock. Since notify() does
not release the lock, its caller should.
"""
idx = 0
to_remove = []
for loop, fut in self._waiters:
if idx >= n:
break
if fut.done():
continue
try:
loop.call_soon_threadsafe(_safe_set_result, fut)
except RuntimeError:
# Loop was closed, ignore.
to_remove.append((loop, fut))
continue
idx += 1
for waiter in to_remove:
self._waiters.remove(waiter)
def notify_all(self) -> None:
self._condition.notify_all()
"""Wake up all threads waiting on this condition. This method acts
like notify(), but wakes up all waiting threads instead of one. If the
calling thread has not acquired the lock when this method is called,
a RuntimeError is raised.
"""
self.notify(len(self._waiters))
def locked(self) -> bool:
"""Only needed for tests in test_locks."""
return self._condition._lock.locked() # type: ignore[attr-defined]
def release(self) -> None:
self._condition.release()

View File

@ -33,6 +33,7 @@ from pymongo.common import (
MAX_POOL_SIZE,
MIN_POOL_SIZE,
WAIT_QUEUE_TIMEOUT,
has_c,
)
if TYPE_CHECKING:
@ -363,6 +364,11 @@ class PoolOptions:
# },
# 'platform': 'CPython 3.8.0|MyPlatform'
# }
if has_c():
self.__metadata["driver"]["name"] = "{}|{}".format(
self.__metadata["driver"]["name"],
"c",
)
if not is_sync:
self.__metadata["driver"]["name"] = "{}|{}".format(
self.__metadata["driver"]["name"],

View File

@ -34,6 +34,7 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
@ -333,13 +334,33 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""
return self._database
@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Collection[_DocumentType]:
...
@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Collection[_DocumentTypeArg]:
...
def with_options(
self,
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> Collection[_DocumentType]:
) -> Collection[_DocumentType] | Collection[_DocumentTypeArg]:
"""Get a clone of this collection changing the specified settings.
>>> coll1.read_preference

View File

@ -146,13 +146,33 @@ class Database(common.BaseObject, Generic[_DocumentType]):
"""The name of this :class:`Database`."""
return self._name
@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Database[_DocumentType]:
...
@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Database[_DocumentTypeArg]:
...
def with_options(
self,
codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> Database[_DocumentType]:
) -> Database[_DocumentType] | Database[_DocumentTypeArg]:
"""Get a clone of this database changing the specified settings.
>>> db1.read_preference

View File

@ -62,7 +62,7 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
@ -909,7 +909,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
@ -988,7 +988,8 @@ class Pool:
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _create_lock()
_lock = _create_lock()
self.lock = _Lock(_lock)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
@ -1014,7 +1015,7 @@ class Pool:
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type]
self.size_cond = threading.Condition(_lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
@ -1022,7 +1023,7 @@ class Pool:
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type]
self._max_connecting_cond = threading.Condition(_lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id

View File

@ -39,7 +39,7 @@ from pymongo.errors import (
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
@ -170,8 +170,9 @@ class Topology:
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
self._lock = _create_lock()
self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type]
_lock = _create_lock()
self._lock = _Lock(_lock)
self._condition = self._settings.condition_class(_lock)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None

7
requirements/typing.txt Normal file
View File

@ -0,0 +1,7 @@
mypy==1.11.2
pyright==1.1.382.post1
typing_extensions
-r ./encryption.txt
-r ./ocsp.txt
-r ./zstd.txt
-r ./aws.txt

View File

@ -313,7 +313,7 @@ class ClientContext:
params = self.cmd_line["parsed"].get("setParameter", {})
if params.get("enableTestCommands") == "1":
self.test_commands_enabled = True
self.has_ipv6 = self._server_started_with_ipv6()
self.has_ipv6 = self._server_started_with_ipv6()
self.is_mongos = (self.hello).get("msg") == "isdbgrid"
if self.is_mongos:

View File

@ -313,7 +313,7 @@ class AsyncClientContext:
params = self.cmd_line["parsed"].get("setParameter", {})
if params.get("enableTestCommands") == "1":
self.test_commands_enabled = True
self.has_ipv6 = await self._server_started_with_ipv6()
self.has_ipv6 = await self._server_started_with_ipv6()
self.is_mongos = (await self.hello).get("msg") == "isdbgrid"
if self.is_mongos:

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import _thread as thread
import asyncio
import base64
import contextlib
import copy
import datetime
@ -31,13 +32,15 @@ import subprocess
import sys
import threading
import time
from typing import Iterable, Type, no_type_check
import uuid
from typing import Any, Iterable, Type, no_type_check
from unittest import mock
from unittest.mock import patch
import pytest
import pytest_asyncio
from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation
from pymongo.operations import _Op
sys.path[0:0] = [""]
@ -57,6 +60,7 @@ from test.asynchronous import (
unittest,
)
from test.asynchronous.pymongo_mocks import AsyncMockClient
from test.test_binary import BinaryData
from test.utils import (
NTHREADS,
CMAPListener,
@ -95,7 +99,7 @@ from pymongo.asynchronous.pool import (
from pymongo.asynchronous.settings import TOPOLOGY_TYPE
from pymongo.asynchronous.topology import _ErrorContext
from pymongo.client_options import ClientOptions
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c
from pymongo.compression_support import _have_snappy, _have_zstd
from pymongo.driver_info import DriverInfo
from pymongo.errors import (
@ -343,7 +347,10 @@ class AsyncClientUnitTest(AsyncUnitTest):
async def test_metadata(self):
metadata = copy.deepcopy(_METADATA)
metadata["driver"]["name"] = "PyMongo|async"
if has_c():
metadata["driver"]["name"] = "PyMongo|c|async"
else:
metadata["driver"]["name"] = "PyMongo|async"
metadata["application"] = {"name": "foobar"}
client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false")
options = client.options
@ -366,7 +373,10 @@ class AsyncClientUnitTest(AsyncUnitTest):
with self.assertRaises(TypeError):
self.simple_client(driver=("Foo", "1", "a"))
# Test appending to driver info.
metadata["driver"]["name"] = "PyMongo|async|FooDriver"
if has_c():
metadata["driver"]["name"] = "PyMongo|c|async|FooDriver"
else:
metadata["driver"]["name"] = "PyMongo|async|FooDriver"
metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"])
client = self.simple_client(
"foo",
@ -1927,7 +1937,10 @@ class TestClient(AsyncIntegrationTest):
async def _test_handshake(self, env_vars, expected_env):
with patch.dict("os.environ", env_vars):
metadata = copy.deepcopy(_METADATA)
metadata["driver"]["name"] = "PyMongo|async"
if has_c():
metadata["driver"]["name"] = "PyMongo|c|async"
else:
metadata["driver"]["name"] = "PyMongo|async"
if expected_env is not None:
metadata["env"] = expected_env
@ -2020,6 +2033,75 @@ class TestClient(AsyncIntegrationTest):
async def test_dict_hints_create_index(self):
await self.db.t.create_index({"x": pymongo.ASCENDING})
async def test_legacy_java_uuid_roundtrip(self):
data = BinaryData.java_data
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY))
await async_client_context.client.pymongo_test.drop_collection("java_uuid")
db = async_client_context.client.pymongo_test
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY))
await coll.insert_many(docs)
self.assertEqual(5, await coll.count_documents({}))
async for d in coll.find():
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
async for d in coll.find():
self.assertNotEqual(d["newguid"], d["newguidstring"])
await async_client_context.client.pymongo_test.drop_collection("java_uuid")
async def test_legacy_csharp_uuid_roundtrip(self):
data = BinaryData.csharp_data
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY))
await async_client_context.client.pymongo_test.drop_collection("csharp_uuid")
db = async_client_context.client.pymongo_test
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY))
await coll.insert_many(docs)
self.assertEqual(5, await coll.count_documents({}))
async for d in coll.find():
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
async for d in coll.find():
self.assertNotEqual(d["newguid"], d["newguidstring"])
await async_client_context.client.pymongo_test.drop_collection("csharp_uuid")
async def test_uri_to_uuid(self):
uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
client = await self.async_single_client(uri, connect=False)
self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY)
async def test_uuid_queries(self):
db = async_client_context.client.pymongo_test
coll = db.test
await coll.drop()
uu = uuid.uuid4()
await coll.insert_one({"uuid": Binary(uu.bytes, 3)})
self.assertEqual(1, await coll.count_documents({}))
# Test regular UUID queries (using subtype 4).
coll = db.get_collection(
"test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
)
self.assertEqual(0, await coll.count_documents({"uuid": uu}))
await coll.insert_one({"uuid": uu})
self.assertEqual(2, await coll.count_documents({}))
docs = await coll.find({"uuid": uu}).to_list()
self.assertEqual(1, len(docs))
self.assertEqual(uu, docs[0]["uuid"])
# Test both.
uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY)
predicate = {"uuid": {"$in": [uu, uu_legacy]}}
self.assertEqual(2, await coll.count_documents(predicate))
docs = await coll.find(predicate).to_list()
self.assertEqual(2, len(docs))
await coll.drop()
class TestExhaustCursor(AsyncIntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""
@ -2351,7 +2433,9 @@ class TestMongoClientFailover(AsyncMockClientTest):
# But it can reconnect.
c.revive_host("a:1")
await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
await (await c._get_topology()).select_servers(
writable_server_selector, _Op.TEST, server_selection_timeout=10
)
self.assertEqual(await c.address, ("a", 1))
async def _test_network_error(self, operation_callback):

View File

@ -30,6 +30,7 @@ from test.utils import (
)
from unittest.mock import patch
import pymongo
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
timeoutMS=2000,
w="majority",
)
await client.admin.command("ping") # Init the client first.
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(10):
await client.admin.command("ping")
with self.assertRaises(ClientBulkWriteException) as context:
await client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, NetworkTimeout)

View File

@ -1414,7 +1414,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])
@ -1456,7 +1456,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_command_cursor_to_list_csot_applied(self):
client = await self.async_single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
await client.admin.command("ping")
coll = client.pymongo.test
await coll.insert_many([{} for _ in range(5)])

View File

@ -711,7 +711,7 @@ class TestDatabase(AsyncIntegrationTest):
"write_concern": WriteConcern(w=1),
"read_concern": ReadConcern(level="local"),
}
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload]
for opt in newopts:
self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))

View File

@ -0,0 +1,513 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lock.py"""
from __future__ import annotations
import asyncio
import sys
import threading
import unittest
sys.path[0:0] = [""]
from pymongo.lock import _ACondition
# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py
# Includes tests for:
# - https://github.com/python/cpython/issues/111693
# - https://github.com/python/cpython/issues/112202
class TestConditionStdlib(unittest.IsolatedAsyncioTestCase):
async def test_wait(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
await cond.acquire()
if await cond.wait():
result.append(1)
return True
async def c2(result):
await cond.acquire()
if await cond.wait():
result.append(2)
return True
async def c3(result):
await cond.acquire()
if await cond.wait():
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertFalse(cond.locked())
self.assertTrue(await cond.acquire())
cond.notify()
await asyncio.sleep(0)
self.assertEqual([], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.notify(2)
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2], result)
self.assertTrue(cond.locked())
cond.release()
await asyncio.sleep(0)
self.assertEqual([1, 2, 3], result)
self.assertTrue(cond.locked())
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_wait_cancel(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
wait = asyncio.create_task(cond.wait())
asyncio.get_running_loop().call_soon(wait.cancel)
with self.assertRaises(asyncio.CancelledError):
await wait
self.assertFalse(cond._waiters)
self.assertTrue(cond.locked())
async def test_wait_cancel_contested(self):
cond = _ACondition(threading.Condition(threading.Lock()))
await cond.acquire()
self.assertTrue(cond.locked())
wait_task = asyncio.create_task(cond.wait())
await asyncio.sleep(0)
self.assertFalse(cond.locked())
# Notify, but contest the lock before cancelling
await cond.acquire()
self.assertTrue(cond.locked())
cond.notify()
asyncio.get_running_loop().call_soon(wait_task.cancel)
asyncio.get_running_loop().call_soon(cond.release)
try:
await wait_task
except asyncio.CancelledError:
# Should not happen, since no cancellation points
pass
self.assertTrue(cond.locked())
async def test_wait_cancel_after_notify(self):
# See bpo-32841
waited = False
cond = _ACondition(threading.Condition(threading.Lock()))
async def wait_on_cond():
nonlocal waited
async with cond:
waited = True # Make sure this area was reached
await cond.wait()
waiter = asyncio.create_task(wait_on_cond())
await asyncio.sleep(0) # Start waiting
await cond.acquire()
cond.notify()
await asyncio.sleep(0) # Get to acquire()
waiter.cancel()
await asyncio.sleep(0) # Activate cancellation
cond.release()
await asyncio.sleep(0) # Cancellation should occur
self.assertTrue(waiter.cancelled())
self.assertTrue(waited)
async def test_wait_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
with self.assertRaises(RuntimeError):
await cond.wait()
async def test_wait_for(self):
cond = _ACondition(threading.Condition(threading.Lock()))
presult = False
def predicate():
return presult
result = []
async def c1(result):
await cond.acquire()
if await cond.wait_for(predicate):
result.append(1)
cond.release()
return True
t = asyncio.create_task(c1(result))
await asyncio.sleep(0)
self.assertEqual([], result)
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([], result)
presult = True
await cond.acquire()
cond.notify()
cond.release()
await asyncio.sleep(0)
self.assertEqual([1], result)
self.assertTrue(t.done())
self.assertTrue(t.result())
async def test_wait_for_unacquired(self):
cond = _ACondition(threading.Condition(threading.Lock()))
# predicate can return true immediately
res = await cond.wait_for(lambda: [1, 2, 3])
self.assertEqual([1, 2, 3], res)
with self.assertRaises(RuntimeError):
await cond.wait_for(lambda: False)
async def test_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
async def c3(result):
async with cond:
if await cond.wait():
result.append(3)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify(1)
await asyncio.sleep(1)
self.assertEqual([1], result)
async with cond:
cond.notify(1)
cond.notify(2048)
await asyncio.sleep(1)
self.assertEqual([1, 2, 3], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
async def test_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
result = []
async def c1(result):
async with cond:
if await cond.wait():
result.append(1)
return True
async def c2(result):
async with cond:
if await cond.wait():
result.append(2)
return True
t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
await asyncio.sleep(0)
self.assertEqual([], result)
async with cond:
cond.notify_all()
await asyncio.sleep(1)
self.assertEqual([1, 2], result)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
async def test_context_manager(self):
cond = _ACondition(threading.Condition(threading.Lock()))
self.assertFalse(cond.locked())
async with cond:
self.assertTrue(cond.locked())
self.assertFalse(cond.locked())
async def test_timeout_in_block(self):
condition = _ACondition(threading.Condition(threading.Lock()))
async with condition:
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(condition.wait(), timeout=0.5)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_wakeup(self):
# Test that a cancelled error, received when awaiting wakeup,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(
sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11"
)
async def test_cancelled_error_re_aquire(self):
# Test that a cancelled error, received when re-aquiring lock,
# will be re-raised un-modified.
wake = False
raised = None
cond = _ACondition(threading.Condition(threading.Lock()))
async def func():
nonlocal raised
async with cond:
with self.assertRaises(asyncio.CancelledError) as err:
await cond.wait_for(lambda: wake)
raised = err.exception
raise raised
task = asyncio.create_task(func())
await asyncio.sleep(0)
# Task is waiting on the condition
await cond.acquire()
wake = True
cond.notify()
await asyncio.sleep(0)
# Task is now trying to re-acquire the lock, cancel it there.
task.cancel(msg="foo") # type: ignore[call-arg]
cond.release()
with self.assertRaises(asyncio.CancelledError) as err:
await task
self.assertEqual(err.exception.args, ("foo",))
# We should have got the _same_ exception instance as the one
# originally raised.
self.assertIs(err.exception, raised)
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is awaiting initial
# wakeup on the wakeup queue.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# Cancel it while it is awaiting to be run.
# This cancellation could come from the outside
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
@unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
async def test_cancelled_wakeup_relock(self):
# Test that a task cancelled at the "same" time as it is woken
# up as part of a Condition.notify() does not result in a lost wakeup.
# This test simulates a cancel while the target task is acquiring the lock
# again.
condition = _ACondition(threading.Condition(threading.Lock()))
state = 0
async def consumer():
nonlocal state
async with condition:
while True:
await condition.wait_for(lambda: state != 0)
if state < 0:
return
state -= 1
# create two consumers
c = [asyncio.create_task(consumer()) for _ in range(2)]
# wait for them to settle
await asyncio.sleep(0.1)
async with condition:
# produce one item and wake up one
state += 1
condition.notify(1)
# now we sleep for a bit. This allows the target task to wake up and
# settle on re-aquiring the lock
await asyncio.sleep(0)
# Cancel it while awaiting the lock
# This cancel could come the outside.
c[0].cancel()
# now wait for the item to be consumed
# if it doesn't means that our "notify" didn"t take hold.
# because it raced with a cancel()
try:
async with asyncio.timeout(1):
await condition.wait_for(lambda: state == 0)
except TimeoutError:
pass
self.assertEqual(state, 0)
# clean up
state = -1
condition.notify_all()
await c[1]
class TestCondition(unittest.IsolatedAsyncioTestCase):
async def test_multiple_loops_notify(self):
cond = _ACondition(threading.Condition(threading.Lock()))
def tmain(cond):
async def atmain(cond):
await asyncio.sleep(1)
async with cond:
cond.notify(1)
asyncio.run(atmain(cond))
t = threading.Thread(target=tmain, args=(cond,))
t.start()
async with cond:
self.assertTrue(await cond.wait(30))
t.join()
async def test_multiple_loops_notify_all(self):
cond = _ACondition(threading.Condition(threading.Lock()))
results = []
def tmain(cond, results):
async def atmain(cond, results):
await asyncio.sleep(1)
async with cond:
res = await cond.wait(30)
results.append(res)
asyncio.run(atmain(cond, results))
nthreads = 5
threads = []
for _ in range(nthreads):
threads.append(threading.Thread(target=tmain, args=(cond, results)))
for t in threads:
t.start()
await asyncio.sleep(2)
async with cond:
cond.notify_all()
for t in threads:
t.join()
self.assertEqual(results, [True] * nthreads)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,219 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import datetime
import sys
import uuid
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from bson import Code, DBRef, decode, encode
from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation
from bson.codec_options import CodecOptions
from bson.errors import InvalidBSON
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument
from bson.son import SON
_IS_SYNC = False
class TestRawBSONDocument(AsyncIntegrationTest):
# {'_id': ObjectId('556df68b6e32ab21a95e0785'),
# 'name': 'Sherlock',
# 'addresses': [{'street': 'Baker Street'}]}
bson_string = (
b"Z\x00\x00\x00\x07_id\x00Um\xf6\x8bn2\xab!\xa9^\x07\x85\x02name\x00\t"
b"\x00\x00\x00Sherlock\x00\x04addresses\x00&\x00\x00\x00\x030\x00\x1e"
b"\x00\x00\x00\x02street\x00\r\x00\x00\x00Baker Street\x00\x00\x00\x00"
)
document = RawBSONDocument(bson_string)
async def asyncTearDown(self):
if async_client_context.connected:
await self.client.pymongo_test.test_raw.drop()
def test_decode(self):
self.assertEqual("Sherlock", self.document["name"])
first_address = self.document["addresses"][0]
self.assertIsInstance(first_address, RawBSONDocument)
self.assertEqual("Baker Street", first_address["street"])
def test_raw(self):
self.assertEqual(self.bson_string, self.document.raw)
def test_empty_doc(self):
doc = RawBSONDocument(encode({}))
with self.assertRaises(KeyError):
doc["does-not-exist"]
def test_invalid_bson_sequence(self):
bson_byte_sequence = encode({"a": 1}) + encode({})
with self.assertRaisesRegex(InvalidBSON, "invalid object length"):
RawBSONDocument(bson_byte_sequence)
def test_invalid_bson_eoo(self):
invalid_bson_eoo = encode({"a": 1})[:-1] + b"\x01"
with self.assertRaisesRegex(InvalidBSON, "bad eoo"):
RawBSONDocument(invalid_bson_eoo)
@async_client_context.require_connection
async def test_round_trip(self):
db = self.client.get_database(
"pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument)
)
await db.test_raw.insert_one(self.document)
result = await db.test_raw.find_one(self.document["_id"])
assert result is not None
self.assertIsInstance(result, RawBSONDocument)
self.assertEqual(dict(self.document.items()), dict(result.items()))
@async_client_context.require_connection
async def test_round_trip_raw_uuid(self):
coll = self.client.get_database("pymongo_test").test_raw
uid = uuid.uuid4()
doc = {"_id": 1, "bin4": Binary(uid.bytes, 4), "bin3": Binary(uid.bytes, 3)}
raw = RawBSONDocument(encode(doc))
await coll.insert_one(raw)
self.assertEqual(await coll.find_one(), doc)
uuid_coll = coll.with_options(
codec_options=coll.codec_options.with_options(
uuid_representation=UuidRepresentation.STANDARD
)
)
self.assertEqual(
await uuid_coll.find_one(), {"_id": 1, "bin4": uid, "bin3": Binary(uid.bytes, 3)}
)
# Test that the raw bytes haven't changed.
raw_coll = coll.with_options(codec_options=DEFAULT_RAW_BSON_OPTIONS)
self.assertEqual(await raw_coll.find_one(), raw)
def test_with_codec_options(self):
# {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000),
# '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')}
# encoded with JAVA_LEGACY uuid representation.
bson_string = (
b"-\x00\x00\x00\x05_id\x00\x10\x00\x00\x00\x03eI_\x97\x8f\xabo\x02"
b"\xff`L\x87\xad\x85\xbf\x9f\tdate\x00\x8a\xd6\xb9\xbaM"
b"\x01\x00\x00\x00"
)
document = RawBSONDocument(
bson_string,
codec_options=CodecOptions(
uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument
),
)
self.assertEqual(uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), document["_id"])
@async_client_context.require_connection
async def test_round_trip_codec_options(self):
doc = {
"date": datetime.datetime(2015, 6, 3, 18, 40, 50, 826000),
"_id": uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"),
}
db = self.client.pymongo_test
coll = db.get_collection(
"test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY)
)
await coll.insert_one(doc)
raw_java_legacy = CodecOptions(
uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument
)
coll = db.get_collection("test_raw", codec_options=raw_java_legacy)
self.assertEqual(
RawBSONDocument(encode(doc, codec_options=raw_java_legacy)), await coll.find_one()
)
@async_client_context.require_connection
async def test_raw_bson_document_embedded(self):
doc = {"embedded": self.document}
db = self.client.pymongo_test
await db.test_raw.insert_one(doc)
result = await db.test_raw.find_one()
assert result is not None
self.assertEqual(decode(self.document.raw), result["embedded"])
# Make sure that CodecOptions are preserved.
# {'embedded': [
# {'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000),
# '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')}
# ]}
# encoded with JAVA_LEGACY uuid representation.
bson_string = (
b"D\x00\x00\x00\x04embedded\x005\x00\x00\x00\x030\x00-\x00\x00\x00"
b"\tdate\x00\x8a\xd6\xb9\xbaM\x01\x00\x00\x05_id\x00\x10\x00\x00"
b"\x00\x03eI_\x97\x8f\xabo\x02\xff`L\x87\xad\x85\xbf\x9f\x00\x00"
b"\x00"
)
rbd = RawBSONDocument(
bson_string,
codec_options=CodecOptions(
uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument
),
)
await db.test_raw.drop()
await db.test_raw.insert_one(rbd)
result = await db.get_collection(
"test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY)
).find_one()
assert result is not None
self.assertEqual(rbd["embedded"][0]["_id"], result["embedded"][0]["_id"])
@async_client_context.require_connection
async def test_write_response_raw_bson(self):
coll = self.client.get_database(
"pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument)
).test_raw
# No Exceptions raised while handling write response.
await coll.insert_one(self.document)
await coll.delete_one(self.document)
await coll.insert_many([self.document])
await coll.delete_many(self.document)
await coll.update_one(self.document, {"$set": {"a": "b"}}, upsert=True)
await coll.update_many(self.document, {"$set": {"b": "c"}})
def test_preserve_key_ordering(self):
keyvaluepairs = [
("a", 1),
("b", 2),
("c", 3),
]
rawdoc = RawBSONDocument(encode(SON(keyvaluepairs)))
for rkey, elt in zip(rawdoc, keyvaluepairs):
self.assertEqual(rkey, elt[0])
def test_contains_code_with_scope(self):
doc = RawBSONDocument(encode({"value": Code("x=1", scope={})}))
self.assertEqual(decode(encode(doc)), {"value": Code("x=1", {})})
self.assertEqual(doc["value"].scope, RawBSONDocument(encode({})))
def test_contains_dbref(self):
doc = RawBSONDocument(encode({"value": DBRef("test", "id")}))
raw = {"$ref": "test", "$id": "id"}
raw_encoded = encode(decode(encode(raw)))
self.assertEqual(decode(encode(doc)), {"value": DBRef("test", "id")})
self.assertEqual(doc["value"].raw, raw_encoded)
if __name__ == "__main__":
unittest.main()

View File

@ -34,53 +34,49 @@ from bson.binary import *
from bson.codec_options import CodecOptions
from bson.son import SON
from pymongo.common import validate_uuid_representation
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
class BinaryData:
# Generated by the Java driver
from_java = (
b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu"
b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND"
b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+"
b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1"
b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR"
b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA"
b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z"
b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf"
b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx"
b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My"
b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB"
b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp"
b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc"
b"0MQAA"
)
java_data = base64.b64decode(from_java)
# Generated by the .net driver
from_csharp = (
b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl"
b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2"
b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V"
b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl"
b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A"
b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z"
b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU"
b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn"
b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA"
b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT"
b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP"
b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00"
b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA="
)
csharp_data = base64.b64decode(from_csharp)
class TestBinary(unittest.TestCase):
csharp_data: bytes
java_data: bytes
@classmethod
def setUpClass(cls):
# Generated by the Java driver
from_java = (
b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu"
b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND"
b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+"
b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1"
b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR"
b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA"
b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z"
b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf"
b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx"
b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My"
b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB"
b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp"
b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc"
b"0MQAA"
)
cls.java_data = base64.b64decode(from_java)
# Generated by the .net driver
from_csharp = (
b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl"
b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2"
b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V"
b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl"
b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A"
b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z"
b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU"
b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn"
b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA"
b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT"
b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP"
b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00"
b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA="
)
cls.csharp_data = base64.b64decode(from_csharp)
def test_binary(self):
a_string = "hello world"
a_binary = Binary(b"hello world")
@ -159,7 +155,7 @@ class TestBinary(unittest.TestCase):
def test_legacy_java_uuid(self):
# Test decoding
data = self.java_data
data = BinaryData.java_data
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY))
for d in docs:
self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
@ -197,27 +193,8 @@ class TestBinary(unittest.TestCase):
)
self.assertEqual(data, encoded)
@client_context.require_connection
def test_legacy_java_uuid_roundtrip(self):
data = self.java_data
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY))
client_context.client.pymongo_test.drop_collection("java_uuid")
db = client_context.client.pymongo_test
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY))
coll.insert_many(docs)
self.assertEqual(5, coll.count_documents({}))
for d in coll.find():
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
for d in coll.find():
self.assertNotEqual(d["newguid"], d["newguidstring"])
client_context.client.pymongo_test.drop_collection("java_uuid")
def test_legacy_csharp_uuid(self):
data = self.csharp_data
data = BinaryData.csharp_data
# Test decoding
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, PYTHON_LEGACY))
@ -257,59 +234,6 @@ class TestBinary(unittest.TestCase):
)
self.assertEqual(data, encoded)
@client_context.require_connection
def test_legacy_csharp_uuid_roundtrip(self):
data = self.csharp_data
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY))
client_context.client.pymongo_test.drop_collection("csharp_uuid")
db = client_context.client.pymongo_test
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY))
coll.insert_many(docs)
self.assertEqual(5, coll.count_documents({}))
for d in coll.find():
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
for d in coll.find():
self.assertNotEqual(d["newguid"], d["newguidstring"])
client_context.client.pymongo_test.drop_collection("csharp_uuid")
def test_uri_to_uuid(self):
uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
client = MongoClient(uri, connect=False)
self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY)
@client_context.require_connection
def test_uuid_queries(self):
db = client_context.client.pymongo_test
coll = db.test
coll.drop()
uu = uuid.uuid4()
coll.insert_one({"uuid": Binary(uu.bytes, 3)})
self.assertEqual(1, coll.count_documents({}))
# Test regular UUID queries (using subtype 4).
coll = db.get_collection(
"test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
)
self.assertEqual(0, coll.count_documents({"uuid": uu}))
coll.insert_one({"uuid": uu})
self.assertEqual(2, coll.count_documents({}))
docs = list(coll.find({"uuid": uu}))
self.assertEqual(1, len(docs))
self.assertEqual(uu, docs[0]["uuid"])
# Test both.
uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY)
predicate = {"uuid": {"$in": [uu, uu_legacy]}}
self.assertEqual(2, coll.count_documents(predicate))
docs = list(coll.find(predicate))
self.assertEqual(2, len(docs))
coll.drop()
def test_pickle(self):
b1 = Binary(b"123", 2)

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import _thread as thread
import asyncio
import base64
import contextlib
import copy
import datetime
@ -31,12 +32,14 @@ import subprocess
import sys
import threading
import time
from typing import Iterable, Type, no_type_check
import uuid
from typing import Any, Iterable, Type, no_type_check
from unittest import mock
from unittest.mock import patch
import pytest
from bson.binary import CSHARP_LEGACY, JAVA_LEGACY, PYTHON_LEGACY, Binary, UuidRepresentation
from pymongo.operations import _Op
sys.path[0:0] = [""]
@ -56,6 +59,7 @@ from test import (
unittest,
)
from test.pymongo_mocks import MockClient
from test.test_binary import BinaryData
from test.utils import (
NTHREADS,
CMAPListener,
@ -83,7 +87,7 @@ from bson.son import SON
from bson.tz_util import utc
from pymongo import event_loggers, message, monitoring
from pymongo.client_options import ClientOptions
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT, has_c
from pymongo.compression_support import _have_snappy, _have_zstd
from pymongo.driver_info import DriverInfo
from pymongo.errors import (
@ -335,7 +339,10 @@ class ClientUnitTest(UnitTest):
def test_metadata(self):
metadata = copy.deepcopy(_METADATA)
metadata["driver"]["name"] = "PyMongo"
if has_c():
metadata["driver"]["name"] = "PyMongo|c"
else:
metadata["driver"]["name"] = "PyMongo"
metadata["application"] = {"name": "foobar"}
client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false")
options = client.options
@ -358,7 +365,10 @@ class ClientUnitTest(UnitTest):
with self.assertRaises(TypeError):
self.simple_client(driver=("Foo", "1", "a"))
# Test appending to driver info.
metadata["driver"]["name"] = "PyMongo|FooDriver"
if has_c():
metadata["driver"]["name"] = "PyMongo|c|FooDriver"
else:
metadata["driver"]["name"] = "PyMongo|FooDriver"
metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"])
client = self.simple_client(
"foo",
@ -1885,7 +1895,10 @@ class TestClient(IntegrationTest):
def _test_handshake(self, env_vars, expected_env):
with patch.dict("os.environ", env_vars):
metadata = copy.deepcopy(_METADATA)
metadata["driver"]["name"] = "PyMongo"
if has_c():
metadata["driver"]["name"] = "PyMongo|c"
else:
metadata["driver"]["name"] = "PyMongo"
if expected_env is not None:
metadata["env"] = expected_env
@ -1978,6 +1991,75 @@ class TestClient(IntegrationTest):
def test_dict_hints_create_index(self):
self.db.t.create_index({"x": pymongo.ASCENDING})
def test_legacy_java_uuid_roundtrip(self):
data = BinaryData.java_data
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY))
client_context.client.pymongo_test.drop_collection("java_uuid")
db = client_context.client.pymongo_test
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY))
coll.insert_many(docs)
self.assertEqual(5, coll.count_documents({}))
for d in coll.find():
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
for d in coll.find():
self.assertNotEqual(d["newguid"], d["newguidstring"])
client_context.client.pymongo_test.drop_collection("java_uuid")
def test_legacy_csharp_uuid_roundtrip(self):
data = BinaryData.csharp_data
docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY))
client_context.client.pymongo_test.drop_collection("csharp_uuid")
db = client_context.client.pymongo_test
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY))
coll.insert_many(docs)
self.assertEqual(5, coll.count_documents({}))
for d in coll.find():
self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"]))
coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
for d in coll.find():
self.assertNotEqual(d["newguid"], d["newguidstring"])
client_context.client.pymongo_test.drop_collection("csharp_uuid")
def test_uri_to_uuid(self):
uri = "mongodb://foo/?uuidrepresentation=csharpLegacy"
client = self.single_client(uri, connect=False)
self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY)
def test_uuid_queries(self):
db = client_context.client.pymongo_test
coll = db.test
coll.drop()
uu = uuid.uuid4()
coll.insert_one({"uuid": Binary(uu.bytes, 3)})
self.assertEqual(1, coll.count_documents({}))
# Test regular UUID queries (using subtype 4).
coll = db.get_collection(
"test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
)
self.assertEqual(0, coll.count_documents({"uuid": uu}))
coll.insert_one({"uuid": uu})
self.assertEqual(2, coll.count_documents({}))
docs = coll.find({"uuid": uu}).to_list()
self.assertEqual(1, len(docs))
self.assertEqual(uu, docs[0]["uuid"])
# Test both.
uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY)
predicate = {"uuid": {"$in": [uu, uu_legacy]}}
self.assertEqual(2, coll.count_documents(predicate))
docs = coll.find(predicate).to_list()
self.assertEqual(2, len(docs))
coll.drop()
class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""
@ -2307,7 +2389,9 @@ class TestMongoClientFailover(MockClientTest):
# But it can reconnect.
c.revive_host("a:1")
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
(c._get_topology()).select_servers(
writable_server_selector, _Op.TEST, server_selection_timeout=10
)
self.assertEqual(c.address, ("a", 1))
def _test_network_error(self, operation_callback):

View File

@ -30,6 +30,7 @@ from test.utils import (
)
from unittest.mock import patch
import pymongo
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
ClientBulkWriteException,
@ -597,7 +598,9 @@ class TestClientBulkWriteCSOT(IntegrationTest):
timeoutMS=2000,
w="majority",
)
client.admin.command("ping") # Init the client first.
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(10):
client.admin.command("ping")
with self.assertRaises(ClientBulkWriteException) as context:
client.bulk_write(models=models)
self.assertIsInstance(context.exception.error, NetworkTimeout)

View File

@ -1405,7 +1405,7 @@ class TestCursor(IntegrationTest):
def test_to_list_csot_applied(self):
client = self.single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
client.admin.command("ping")
coll = client.pymongo.test
coll.insert_many([{} for _ in range(5)])
@ -1447,7 +1447,7 @@ class TestCursor(IntegrationTest):
def test_command_cursor_to_list_csot_applied(self):
client = self.single_client(timeoutMS=500)
# Initialize the client with a larger timeout to help make test less flakey
with pymongo.timeout(2):
with pymongo.timeout(10):
client.admin.command("ping")
coll = client.pymongo.test
coll.insert_many([{} for _ in range(5)])

View File

@ -702,7 +702,7 @@ class TestDatabase(IntegrationTest):
"write_concern": WriteConcern(w=1),
"read_concern": ReadConcern(level="local"),
}
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload]
for opt in newopts:
self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))

View File

@ -19,8 +19,7 @@ import uuid
sys.path[0:0] = [""]
from test import client_context, unittest
from test.test_client import IntegrationTest
from test import IntegrationTest, client_context, unittest
from bson import Code, DBRef, decode, encode
from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation
@ -29,6 +28,8 @@ from bson.errors import InvalidBSON
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument
from bson.son import SON
_IS_SYNC = True
class TestRawBSONDocument(IntegrationTest):
# {'_id': ObjectId('556df68b6e32ab21a95e0785'),

View File

@ -19,6 +19,7 @@ import os
import threading
from test import IntegrationTest, client_context, unittest
from test.utils import (
CMAPListener,
OvertCommandListener,
SpecTestCreator,
get_pool,
@ -27,6 +28,7 @@ from test.utils import (
from test.utils_selection_tests import create_topology
from pymongo.common import clean_node
from pymongo.monitoring import ConnectionReadyEvent
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference
@ -131,19 +133,20 @@ class TestProse(IntegrationTest):
@client_context.require_multiple_mongoses
def test_load_balancing(self):
listener = OvertCommandListener()
cmap_listener = CMAPListener()
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
# varying RTTs.
client = self.rs_client(
client_context.mongos_seeds(),
appName="loadBalancingTest",
event_listeners=[listener],
event_listeners=[listener, cmap_listener],
localThresholdMS=30000,
minPoolSize=10,
)
self.addCleanup(client.close)
wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections")
# Delay find commands on
# Wait for both pools to be populated.
cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
# Delay find commands on only one mongos.
delay_finds = {
"configureFailPoint": "failCommand",
"mode": {"times": 10000},
@ -161,7 +164,7 @@ class TestProse(IntegrationTest):
freqs = self.frequencies(client, listener)
self.assertLessEqual(freqs[delayed_server], 0.25)
listener.reset()
freqs = self.frequencies(client, listener, n_finds=100)
freqs = self.frequencies(client, listener, n_finds=150)
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)

View File

@ -34,7 +34,7 @@ from typing import (
cast,
)
try:
if TYPE_CHECKING:
from typing_extensions import NotRequired, TypedDict
from bson import ObjectId
@ -49,16 +49,13 @@ try:
year: int
class ImplicitMovie(TypedDict):
_id: NotRequired[ObjectId] # pyright: ignore[reportGeneralTypeIssues]
_id: NotRequired[ObjectId]
name: str
year: int
except ImportError:
Movie = dict # type:ignore[misc,assignment]
ImplicitMovie = dict # type: ignore[assignment,misc]
MovieWithId = dict # type: ignore[assignment,misc]
TypedDict = None
NotRequired = None # type: ignore[assignment]
else:
Movie = dict
ImplicitMovie = dict
NotRequired = None
try:
@ -234,6 +231,19 @@ class TestPymongo(IntegrationTest):
execute_transaction, read_preference=ReadPreference.PRIMARY
)
def test_with_options(self) -> None:
coll: Collection[Dict[str, Any]] = self.coll
coll.drop()
doc = {"name": "foo", "year": 1982, "other": 1}
coll.insert_one(doc)
coll2 = coll.with_options(codec_options=CodecOptions(document_class=Movie))
retrieved = coll2.find_one()
assert retrieved is not None
assert retrieved["name"] == "foo"
# We expect a type error here.
assert retrieved["other"] == 1 # type:ignore[typeddict-item]
class TestDecode(unittest.TestCase):
def test_bson_decode(self) -> None:
@ -426,7 +436,7 @@ class TestDocumentType(PyMongoTestCase):
)
coll.bulk_write(
[
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971})
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore
] # No error because it is in-line.
)
@ -443,7 +453,7 @@ class TestDocumentType(PyMongoTestCase):
)
coll.bulk_write(
[
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971})
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore
] # No error because it is in-line.
)
@ -566,7 +576,7 @@ class TestCodecOptionsDocumentType(unittest.TestCase):
def test_typeddict_document_type(self) -> None:
options: CodecOptions[Movie] = CodecOptions()
# Suppress: Cannot instantiate type "Type[Movie]".
obj = options.document_class(name="a", year=1) # type: ignore[misc]
obj = options.document_class(name="a", year=1)
assert obj["year"] == 1
assert obj["name"] == "a"

View File

@ -23,7 +23,7 @@ import re
from os import listdir
from pathlib import Path
from unasync import Rule, unasync_files # type: ignore[import]
from unasync import Rule, unasync_files # type: ignore[import-not-found]
replacements = {
"AsyncCollection": "Collection",
@ -101,6 +101,7 @@ replacements = {
"default_async": "default",
"aclose": "close",
"PyMongo|async": "PyMongo",
"PyMongo|c|async": "PyMongo|c",
"AsyncTestGridFile": "TestGridFile",
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
}
@ -144,7 +145,17 @@ gridfs_files = [
_gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file()
]
test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()]
def async_only_test(f: str) -> bool:
"""Return True for async tests that should not be converted to sync."""
return f in ["test_locks.py"]
test_files = [
_test_base + f
for f in listdir(_test_base)
if (Path(_test_base) / f).is_file() and not async_only_test(f)
]
sync_files = [
_pymongo_dest_base + f
@ -171,16 +182,17 @@ converted_tests = [
"test_change_stream.py",
"test_client.py",
"test_client_bulk_write.py",
"test_client_context.py",
"test_collection.py",
"test_cursor.py",
"test_database.py",
"test_encryption.py",
"test_grid_file.py",
"test_logger.py",
"test_monitoring.py",
"test_raw_bson.py",
"test_session.py",
"test_transactions.py",
"test_client_context.py",
"test_monitoring.py",
]
sync_test_files = [
@ -240,7 +252,7 @@ def translate_locks(lines: list[str]) -> list[str]:
lock_lines = [line for line in lines if "_Lock(" in line]
cond_lines = [line for line in lines if "_Condition(" in line]
for line in lock_lines:
res = re.search(r"_Lock\(([^()]*\(\))\)", line)
res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line)
if res:
old = res[0]
index = lines.index(line)