Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-02-05 21:37:03 -06:00
commit 0609631a0c
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
107 changed files with 6665 additions and 546 deletions

View File

@ -1386,7 +1386,7 @@ def is_valid(bson: bytes) -> bool:
:param bson: the data to be validated
"""
if not isinstance(bson, bytes):
raise TypeError("BSON data must be an instance of a subclass of bytes")
raise TypeError(f"BSON data must be an instance of a subclass of bytes, not {type(bson)}")
try:
_bson_to_dict(bson, DEFAULT_CODEC_OPTIONS)

View File

@ -290,7 +290,7 @@ class Binary(bytes):
subtype: int = BINARY_SUBTYPE,
) -> Binary:
if not isinstance(subtype, int):
raise TypeError("subtype must be an instance of int")
raise TypeError(f"subtype must be an instance of int, not {type(subtype)}")
if subtype >= 256 or subtype < 0:
raise ValueError("subtype must be contained in [0, 256)")
# Support any type that implements the buffer protocol.
@ -321,7 +321,7 @@ class Binary(bytes):
.. versionadded:: 3.11
"""
if not isinstance(uuid, UUID):
raise TypeError("uuid must be an instance of uuid.UUID")
raise TypeError(f"uuid must be an instance of uuid.UUID, not {type(uuid)}")
if uuid_representation not in ALL_UUID_REPRESENTATIONS:
raise ValueError(
@ -470,7 +470,7 @@ class Binary(bytes):
"""
if self.subtype != VECTOR_SUBTYPE:
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector.")
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")
position = 0
dtype, padding = struct.unpack_from("<sB", self, position)

View File

@ -56,7 +56,7 @@ class Code(str):
**kwargs: Any,
) -> Code:
if not isinstance(code, str):
raise TypeError("code must be an instance of str")
raise TypeError(f"code must be an instance of str, not {type(code)}")
self = str.__new__(cls, code)
@ -67,7 +67,7 @@ class Code(str):
if scope is not None:
if not isinstance(scope, _Mapping):
raise TypeError("scope must be an instance of dict")
raise TypeError(f"scope must be an instance of dict, not {type(scope)}")
if self.__scope is not None:
self.__scope.update(scope) # type: ignore
else:

View File

@ -401,17 +401,23 @@ else:
"uuid_representation must be a value from bson.binary.UuidRepresentation"
)
if not isinstance(unicode_decode_error_handler, str):
raise ValueError("unicode_decode_error_handler must be a string")
raise ValueError(
f"unicode_decode_error_handler must be a string, not {type(unicode_decode_error_handler)}"
)
if tzinfo is not None:
if not isinstance(tzinfo, datetime.tzinfo):
raise TypeError("tzinfo must be an instance of datetime.tzinfo")
raise TypeError(
f"tzinfo must be an instance of datetime.tzinfo, not {type(tzinfo)}"
)
if not tz_aware:
raise ValueError("cannot specify tzinfo without also setting tz_aware=True")
type_registry = type_registry or TypeRegistry()
if not isinstance(type_registry, TypeRegistry):
raise TypeError("type_registry must be an instance of TypeRegistry")
raise TypeError(
f"type_registry must be an instance of TypeRegistry, not {type(type_registry)}"
)
return tuple.__new__(
cls,

View File

@ -56,9 +56,9 @@ class DBRef:
.. seealso:: The MongoDB documentation on `dbrefs <https://dochub.mongodb.org/core/dbrefs>`_.
"""
if not isinstance(collection, str):
raise TypeError("collection must be an instance of str")
raise TypeError(f"collection must be an instance of str, not {type(collection)}")
if database is not None and not isinstance(database, str):
raise TypeError("database must be an instance of str")
raise TypeError(f"database must be an instance of str, not {type(database)}")
self.__collection = collection
self.__id = id

View File

@ -277,7 +277,7 @@ class Decimal128:
point in Binary Integer Decimal (BID) format).
"""
if not isinstance(value, bytes):
raise TypeError("value must be an instance of bytes")
raise TypeError(f"value must be an instance of bytes, not {type(value)}")
if len(value) != 16:
raise ValueError("value must be exactly 16 bytes")
return cls((_UNPACK_64(value[8:])[0], _UNPACK_64(value[:8])[0])) # type: ignore

View File

@ -58,9 +58,9 @@ class Timestamp:
time = time - offset
time = int(calendar.timegm(time.timetuple()))
if not isinstance(time, int):
raise TypeError("time must be an instance of int")
raise TypeError(f"time must be an instance of int, not {type(time)}")
if not isinstance(inc, int):
raise TypeError("inc must be an instance of int")
raise TypeError(f"inc must be an instance of int, not {type(inc)}")
if not 0 <= time < UPPERBOUND:
raise ValueError("time must be contained in [0, 2**32)")
if not 0 <= inc < UPPERBOUND:

View File

@ -67,6 +67,14 @@ uMongo
mongomock. The source `is available on GitHub
<https://github.com/Scille/umongo>`_
Django MongoDB Backend
`Django MongoDB Backend <https://django-mongodb-backend.readthedocs.io>`_ is a
database backend library specifically made for Django. The integration takes
advantage of MongoDB's unique document model capabilities, which align
naturally with Django's philosophy of simplified data modeling and
reduced development complexity. The source is available
`on GitHub <https://github.com/mongodb-labs/django-mongodb-backend>`_.
No longer maintained
""""""""""""""""""""

View File

@ -100,7 +100,7 @@ class AsyncGridFS:
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
if not isinstance(database, AsyncDatabase):
raise TypeError("database must be an instance of Database")
raise TypeError(f"database must be an instance of Database, not {type(database)}")
database = _clear_entity_type_registry(database)
@ -503,7 +503,7 @@ class AsyncGridFSBucket:
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
if not isinstance(db, AsyncDatabase):
raise TypeError("database must be an instance of AsyncDatabase")
raise TypeError(f"database must be an instance of AsyncDatabase, not {type(db)}")
db = _clear_entity_type_registry(db)
@ -1082,7 +1082,9 @@ class AsyncGridIn:
:attr:`~pymongo.collection.AsyncCollection.write_concern`
"""
if not isinstance(root_collection, AsyncCollection):
raise TypeError("root_collection must be an instance of AsyncCollection")
raise TypeError(
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
)
if not root_collection.write_concern.acknowledged:
raise ConfigurationError("root_collection must use acknowledged write_concern")
@ -1436,7 +1438,9 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore
from the server. Metadata is fetched when first needed.
"""
if not isinstance(root_collection, AsyncCollection):
raise TypeError("root_collection must be an instance of AsyncCollection")
raise TypeError(
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
)
_disallow_transactions(session)
root_collection = _clear_entity_type_registry(root_collection)

View File

@ -100,7 +100,7 @@ class GridFS:
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
if not isinstance(database, Database):
raise TypeError("database must be an instance of Database")
raise TypeError(f"database must be an instance of Database, not {type(database)}")
database = _clear_entity_type_registry(database)
@ -501,7 +501,7 @@ class GridFSBucket:
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
"""
if not isinstance(db, Database):
raise TypeError("database must be an instance of Database")
raise TypeError(f"database must be an instance of Database, not {type(db)}")
db = _clear_entity_type_registry(db)
@ -1076,7 +1076,9 @@ class GridIn:
:attr:`~pymongo.collection.Collection.write_concern`
"""
if not isinstance(root_collection, Collection):
raise TypeError("root_collection must be an instance of Collection")
raise TypeError(
f"root_collection must be an instance of Collection, not {type(root_collection)}"
)
if not root_collection.write_concern.acknowledged:
raise ConfigurationError("root_collection must use acknowledged write_concern")
@ -1426,7 +1428,9 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore
from the server. Metadata is fetched when first needed.
"""
if not isinstance(root_collection, Collection):
raise TypeError("root_collection must be an instance of Collection")
raise TypeError(
f"root_collection must be an instance of Collection, not {type(root_collection)}"
)
_disallow_transactions(session)
root_collection = _clear_entity_type_registry(root_collection)

View File

@ -160,7 +160,7 @@ def timeout(seconds: Optional[float]) -> ContextManager[None]:
.. versionadded:: 4.2
"""
if not isinstance(seconds, (int, float, type(None))):
raise TypeError("timeout must be None, an int, or a float")
raise TypeError(f"timeout must be None, an int, or a float, not {type(seconds)}")
if seconds and seconds < 0:
raise ValueError("timeout cannot be negative")
if seconds is not None:

View File

@ -160,7 +160,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin):
self._locked = False
self._wake_up_first()
else:
raise RuntimeError("Lock is not acquired.")
raise RuntimeError("Lock is not acquired")
def _wake_up_first(self) -> None:
"""Ensure that the first waiter will wake up."""

View File

@ -46,7 +46,7 @@ def _get_azure_response(
try:
data = json.loads(body)
except Exception:
raise ValueError("Azure IMDS response must be in JSON format.") from None
raise ValueError("Azure IMDS response must be in JSON format") from None
for key in ["access_token", "expires_in"]:
if not data.get(key):

View File

@ -161,7 +161,7 @@ def _password_digest(username: str, password: str) -> str:
if len(password) == 0:
raise ValueError("password can't be empty")
if not isinstance(username, str):
raise TypeError("username must be an instance of str")
raise TypeError(f"username must be an instance of str, not {type(username)}")
md5hash = hashlib.md5() # noqa: S324
data = f"{username}:mongo:{password}"

View File

@ -213,7 +213,9 @@ class _OIDCAuthenticator:
)
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError("Callback result must be of type OIDCCallbackResult")
raise ValueError(
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
)
self.refresh_token = resp.refresh_token
self.access_token = resp.access_token
self.token_gen_id += 1

View File

@ -310,7 +310,9 @@ class TransactionOptions:
)
if max_commit_time_ms is not None:
if not isinstance(max_commit_time_ms, int):
raise TypeError("max_commit_time_ms must be an integer or None")
raise TypeError(
f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}"
)
@property
def read_concern(self) -> Optional[ReadConcern]:
@ -902,7 +904,9 @@ class AsyncClientSession:
another `AsyncClientSession` instance.
"""
if not isinstance(cluster_time, _Mapping):
raise TypeError("cluster_time must be a subclass of collections.Mapping")
raise TypeError(
f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}"
)
if not isinstance(cluster_time.get("clusterTime"), Timestamp):
raise ValueError("Invalid cluster_time")
self._advance_cluster_time(cluster_time)
@ -923,7 +927,9 @@ class AsyncClientSession:
another `AsyncClientSession` instance.
"""
if not isinstance(operation_time, Timestamp):
raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp")
raise TypeError(
f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}"
)
self._advance_operation_time(operation_time)
def _process_response(self, reply: Mapping[str, Any]) -> None:

View File

@ -228,7 +228,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
read_concern or database.read_concern,
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
raise TypeError(f"name must be an instance of str, not {type(name)}")
from pymongo.asynchronous.database import AsyncDatabase
if not isinstance(database, AsyncDatabase):
@ -2475,7 +2475,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
name = helpers_shared._gen_index_name(index_or_name)
if not isinstance(name, str):
raise TypeError("index_or_name must be an instance of str or list")
raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}")
cmd = {"dropIndexes": self._name, "index": name}
cmd.update(kwargs)
@ -3078,7 +3078,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"""
if not isinstance(new_name, str):
raise TypeError("new_name must be an instance of str")
raise TypeError(f"new_name must be an instance of str, not {type(new_name)}")
if not new_name or ".." in new_name:
raise InvalidName("collection names cannot be empty")
@ -3148,7 +3148,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"""
if not isinstance(key, str):
raise TypeError("key must be an instance of str")
raise TypeError(f"key must be an instance of str, not {type(key)}")
cmd = {"distinct": self._name, "key": key}
if filter is not None:
if "query" in kwargs:
@ -3196,7 +3196,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
common.validate_is_mapping("filter", filter)
if not isinstance(return_document, bool):
raise ValueError(
"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER"
f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}"
)
collation = validate_collation_or_none(kwargs.pop("collation", None))
cmd = {"findAndModify": self._name, "query": filter, "new": return_document}

View File

@ -94,7 +94,9 @@ class AsyncCommandCursor(Generic[_DocumentType]):
self.batch_size(batch_size)
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
raise TypeError(
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
)
def __del__(self) -> None:
self._die_no_lock()
@ -115,7 +117,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")

View File

@ -146,9 +146,9 @@ class AsyncCursor(Generic[_DocumentType]):
spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
if not isinstance(skip, int):
raise TypeError("skip must be an instance of int")
raise TypeError(f"skip must be an instance of int, not {type(skip)}")
if not isinstance(limit, int):
raise TypeError("limit must be an instance of int")
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
validate_boolean("no_cursor_timeout", no_cursor_timeout)
if no_cursor_timeout and not self._explicit_session:
warnings.warn(
@ -171,7 +171,7 @@ class AsyncCursor(Generic[_DocumentType]):
validate_boolean("allow_partial_results", allow_partial_results)
validate_boolean("oplog_replay", oplog_replay)
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
# Only set if allow_disk_use is provided by the user, else None.
@ -388,7 +388,7 @@ class AsyncCursor(Generic[_DocumentType]):
cursor.add_option(2)
"""
if not isinstance(mask, int):
raise TypeError("mask must be an int")
raise TypeError(f"mask must be an int, not {type(mask)}")
self._check_okay_to_chain()
if mask & _QUERY_OPTIONS["exhaust"]:
@ -408,7 +408,7 @@ class AsyncCursor(Generic[_DocumentType]):
cursor.remove_option(2)
"""
if not isinstance(mask, int):
raise TypeError("mask must be an int")
raise TypeError(f"mask must be an int, not {type(mask)}")
self._check_okay_to_chain()
if mask & _QUERY_OPTIONS["exhaust"]:
@ -432,7 +432,7 @@ class AsyncCursor(Generic[_DocumentType]):
.. versionadded:: 3.11
"""
if not isinstance(allow_disk_use, bool):
raise TypeError("allow_disk_use must be a bool")
raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}")
self._check_okay_to_chain()
self._allow_disk_use = allow_disk_use
@ -451,7 +451,7 @@ class AsyncCursor(Generic[_DocumentType]):
.. seealso:: The MongoDB documentation on `limit <https://dochub.mongodb.org/core/limit>`_.
"""
if not isinstance(limit, int):
raise TypeError("limit must be an integer")
raise TypeError(f"limit must be an integer, not {type(limit)}")
if self._exhaust:
raise InvalidOperation("Can't use limit and exhaust together.")
self._check_okay_to_chain()
@ -479,7 +479,7 @@ class AsyncCursor(Generic[_DocumentType]):
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
self._check_okay_to_chain()
@ -499,7 +499,7 @@ class AsyncCursor(Generic[_DocumentType]):
:param skip: the number of results to skip
"""
if not isinstance(skip, int):
raise TypeError("skip must be an integer")
raise TypeError(f"skip must be an integer, not {type(skip)}")
if skip < 0:
raise ValueError("skip must be >= 0")
self._check_okay_to_chain()
@ -520,7 +520,7 @@ class AsyncCursor(Generic[_DocumentType]):
:param max_time_ms: the time limit after which the operation is aborted
"""
if not isinstance(max_time_ms, int) and max_time_ms is not None:
raise TypeError("max_time_ms must be an integer or None")
raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}")
self._check_okay_to_chain()
self._max_time_ms = max_time_ms
@ -543,7 +543,9 @@ class AsyncCursor(Generic[_DocumentType]):
.. versionadded:: 3.2
"""
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
raise TypeError(
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
)
self._check_okay_to_chain()
# Ignore max_await_time_ms if not tailable or await_data is False.
@ -679,7 +681,7 @@ class AsyncCursor(Generic[_DocumentType]):
.. versionadded:: 2.7
"""
if not isinstance(spec, (list, tuple)):
raise TypeError("spec must be an instance of list or tuple")
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
self._check_okay_to_chain()
self._max = dict(spec)
@ -701,7 +703,7 @@ class AsyncCursor(Generic[_DocumentType]):
.. versionadded:: 2.7
"""
if not isinstance(spec, (list, tuple)):
raise TypeError("spec must be an instance of list or tuple")
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
self._check_okay_to_chain()
self._min = dict(spec)

View File

@ -122,7 +122,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
from pymongo.asynchronous.mongo_client import AsyncMongoClient
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
raise TypeError(f"name must be an instance of str, not {type(name)}")
if not isinstance(client, AsyncMongoClient):
# This is for compatibility with mocked and subclassed types, such as in Motor.
@ -1310,7 +1310,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
name = name.name
if not isinstance(name, str):
raise TypeError("name_or_collection must be an instance of str")
raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}")
encrypted_fields = await self._get_encrypted_fields(
{"encryptedFields": encrypted_fields},
name,
@ -1374,7 +1374,9 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
name = name.name
if not isinstance(name, str):
raise TypeError("name_or_collection must be an instance of str or AsyncCollection")
raise TypeError(
f"name_or_collection must be an instance of str or AsyncCollection, not {type(name)}"
)
cmd = {"validate": name, "scandata": scandata, "full": full}
if comment is not None:
cmd["comment"] = comment

View File

@ -322,7 +322,9 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS)
data_key_id = raw_doc.get("_id")
if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE:
raise TypeError("data_key _id must be Binary with a UUID subtype")
raise TypeError(
f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}"
)
assert self.key_vault_coll is not None
await self.key_vault_coll.insert_one(raw_doc)
@ -644,7 +646,9 @@ class AsyncClientEncryption(Generic[_DocumentType]):
)
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
raise TypeError(
f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
)
if not isinstance(key_vault_client, AsyncMongoClient):
# This is for compatibility with mocked and subclassed types, such as in Motor.

View File

@ -750,7 +750,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
if port is None:
port = self.PORT
if not isinstance(port, int):
raise TypeError("port must be an instance of int")
raise TypeError(f"port must be an instance of int, not {type(port)}")
# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
@ -1565,6 +1565,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
await self._encrypter.close()
self._closed = True
if not _IS_SYNC:
await asyncio.gather(
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)
if not _IS_SYNC:
# Add support for contextlib.aclosing.
@ -1971,7 +1977,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
The cursor is closed synchronously on the current thread.
"""
if not isinstance(cursor_id, int):
raise TypeError("cursor_id must be an instance of int")
raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}")
try:
if conn_mgr:
@ -2093,7 +2099,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.AsyncClientSession):
raise ValueError("'session' argument must be an AsyncClientSession or None.")
raise ValueError(
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
)
# Don't call end_session.
yield session
return
@ -2247,7 +2255,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
name = name.name
if not isinstance(name, str):
raise TypeError("name_or_database must be an instance of str or a AsyncDatabase")
raise TypeError(
f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}"
)
async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
await self[name]._command(

View File

@ -112,9 +112,9 @@ class MonitorBase:
"""
self.gc_safe_close()
async def join(self, timeout: Optional[int] = None) -> None:
async def join(self) -> None:
"""Wait for the monitor to stop."""
await self._executor.join(timeout)
await self._executor.join()
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@ -189,6 +189,11 @@ class Monitor(MonitorBase):
self._rtt_monitor.gc_safe_close()
self.cancel_check()
async def join(self) -> None:
await asyncio.gather(
self._executor.join(), self._rtt_monitor.join(), return_exceptions=True
) # type: ignore[func-returns-value]
async def close(self) -> None:
self.gc_safe_close()
await self._rtt_monitor.close()

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import asyncio
import logging
import os
import queue
@ -29,7 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
from pymongo.asynchronous.pool import Pool
from pymongo.asynchronous.server import Server
from pymongo.errors import (
@ -207,6 +208,9 @@ class Topology:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)
# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []
async def open(self) -> None:
"""Start monitoring, or restart after a fork.
@ -241,6 +245,8 @@ class Topology:
# Close servers and clear the pools.
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
@ -283,6 +289,10 @@ class Topology:
else:
server_timeout = server_selection_timeout
# Cleanup any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
await self.cleanup_monitors()
async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
@ -520,6 +530,8 @@ class Topology:
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
@ -695,6 +707,8 @@ class Topology:
old_td = self._description
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Mark all servers Unknown.
self._description = self._description.reset()
@ -705,6 +719,8 @@ class Topology:
# Stop SRV polling thread.
if self._srv_monitor:
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
self._opened = False
self._closed = True
@ -944,6 +960,8 @@ class Topology:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)
def _create_pool_for_server(self, address: _Address) -> Pool:
@ -1031,6 +1049,15 @@ class Topology:
else:
return ",".join(str(server.error) for server in servers if server.error)
async def cleanup_monitors(self) -> None:
tasks = []
try:
while self._monitor_tasks:
tasks.append(self._monitor_tasks.pop())
except IndexError:
pass
await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
def __repr__(self) -> str:
msg = ""
if not self._opened:

View File

@ -107,7 +107,7 @@ def _build_credentials_tuple(
) -> MongoCredential:
"""Build and return a mechanism specific credentials tuple."""
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
raise ConfigurationError(f"{mech} requires a username.")
raise ConfigurationError(f"{mech} requires a username")
if mech == "GSSAPI":
if source is not None and source != "$external":
raise ValueError("authentication source must be $external or None for GSSAPI")
@ -219,7 +219,7 @@ def _build_credentials_tuple(
else:
source_database = source or database or "admin"
if passwd is None:
raise ConfigurationError("A password is required.")
raise ConfigurationError("A password is required")
return MongoCredential(mech, source_database, user, passwd, None, _Cache())

View File

@ -223,4 +223,4 @@ def validate_collation_or_none(
return value.document
if isinstance(value, dict):
return value
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
raise TypeError("collation must be a dict, an instance of collation.Collation, or None")

View File

@ -202,7 +202,7 @@ def validate_integer(option: str, value: Any) -> int:
return int(value)
except ValueError:
raise ValueError(f"The value of {option} must be an integer") from None
raise TypeError(f"Wrong type for {option}, value must be an integer")
raise TypeError(f"Wrong type for {option}, value must be an integer, not {type(value)}")
def validate_positive_integer(option: str, value: Any) -> int:
@ -250,7 +250,7 @@ def validate_string(option: str, value: Any) -> str:
"""Validates that 'value' is an instance of `str`."""
if isinstance(value, str):
return value
raise TypeError(f"Wrong type for {option}, value must be an instance of str")
raise TypeError(f"Wrong type for {option}, value must be an instance of str, not {type(value)}")
def validate_string_or_none(option: str, value: Any) -> Optional[str]:
@ -269,7 +269,9 @@ def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]:
return int(value)
except ValueError:
return value
raise TypeError(f"Wrong type for {option}, value must be an integer or a string")
raise TypeError(
f"Wrong type for {option}, value must be an integer or a string, not {type(value)}"
)
def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]:
@ -282,7 +284,9 @@ def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[in
except ValueError:
return value
return validate_non_negative_integer(option, val)
raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string")
raise TypeError(
f"Wrong type for {option}, value must be an non negative integer or a string, not {type(value)}"
)
def validate_positive_float(option: str, value: Any) -> float:
@ -365,7 +369,7 @@ def validate_max_staleness(option: str, value: Any) -> int:
def validate_read_preference(dummy: Any, value: Any) -> _ServerMode:
"""Validate a read preference."""
if not isinstance(value, _ServerMode):
raise TypeError(f"{value!r} is not a read preference.")
raise TypeError(f"{value!r} is not a read preference")
return value
@ -441,7 +445,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
props: dict[str, Any] = {}
if not isinstance(value, str):
if not isinstance(value, dict):
raise ValueError("Auth mechanism properties must be given as a string or a dictionary")
raise ValueError(
f"Auth mechanism properties must be given as a string or a dictionary, not {type(value)}"
)
for key, value in value.items(): # noqa: B020
if isinstance(value, str):
props[key] = value
@ -453,7 +459,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
from pymongo.auth_oidc_shared import OIDCCallback
if not isinstance(value, OIDCCallback):
raise ValueError("callback must be an OIDCCallback object")
raise ValueError(f"callback must be an OIDCCallback object, not {type(value)}")
props[key] = value
else:
raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}")
@ -476,7 +482,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
raise ValueError(
f"{key} is not a supported auth "
"mechanism property. Must be one of "
f"{tuple(_MECHANISM_PROPS)}."
f"{tuple(_MECHANISM_PROPS)}"
)
if key == "CANONICALIZE_HOST_NAME":
@ -520,7 +526,7 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
def validate_list(option: str, value: Any) -> list:
"""Validates that 'value' is a list."""
if not isinstance(value, list):
raise TypeError(f"{option} must be a list")
raise TypeError(f"{option} must be a list, not {type(value)}")
return value
@ -587,7 +593,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]:
if value is None:
return value
if not isinstance(value, ServerApi):
raise TypeError(f"{option} must be an instance of ServerApi")
raise TypeError(f"{option} must be an instance of ServerApi, not {type(value)}")
return value
@ -596,7 +602,7 @@ def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]:
if value is None:
return value
if not callable(value):
raise ValueError(f"{option} must be a callable")
raise ValueError(f"{option} must be a callable, not {type(value)}")
return value
@ -651,7 +657,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A
from pymongo.encryption_options import AutoEncryptionOpts
if not isinstance(value, AutoEncryptionOpts):
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts")
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts, not {type(value)}")
return value
@ -668,7 +674,9 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo
elif isinstance(value, int):
return DatetimeConversion(value)
raise TypeError(f"{option} must be a str or int representing DatetimeConversion")
raise TypeError(
f"{option} must be a str or int representing DatetimeConversion, not {type(value)}"
)
def validate_server_monitoring_mode(option: str, value: str) -> str:
@ -928,12 +936,14 @@ class BaseObject:
if not isinstance(write_concern, WriteConcern):
raise TypeError(
"write_concern must be an instance of pymongo.write_concern.WriteConcern"
f"write_concern must be an instance of pymongo.write_concern.WriteConcern, not {type(write_concern)}"
)
self._write_concern = write_concern
if not isinstance(read_concern, ReadConcern):
raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern")
raise TypeError(
f"read_concern must be an instance of pymongo.read_concern.ReadConcern, not {type(read_concern)}"
)
self._read_concern = read_concern
@property

View File

@ -91,7 +91,7 @@ def validate_zlib_compression_level(option: str, value: Any) -> int:
try:
level = int(value)
except Exception:
raise TypeError(f"{option} must be an integer, not {value!r}.") from None
raise TypeError(f"{option} must be an integer, not {value!r}") from None
if level < -1 or level > 9:
raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
return level

View File

@ -39,7 +39,7 @@ class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])):
for key, value in self._asdict().items():
if value is not None and not isinstance(value, str):
raise TypeError(
f"Wrong type for DriverInfo {key} option, value must be an instance of str"
f"Wrong type for DriverInfo {key} option, value must be an instance of str, not {type(value)}"
)
return self

View File

@ -225,7 +225,9 @@ class AutoEncryptionOpts:
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
self._mongocryptd_spawn_args = mongocryptd_spawn_args
if not isinstance(self._mongocryptd_spawn_args, list):
raise TypeError("mongocryptd_spawn_args must be a list")
raise TypeError(
f"mongocryptd_spawn_args must be a list, not {type(self._mongocryptd_spawn_args)}"
)
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
# Maps KMS provider name to a SSLContext.

View File

@ -122,7 +122,7 @@ def _index_list(
"""
if direction is not None:
if not isinstance(key_or_list, str):
raise TypeError("Expected a string and a direction")
raise TypeError(f"Expected a string and a direction, not {type(key_or_list)}")
return [(key_or_list, direction)]
else:
if isinstance(key_or_list, str):
@ -132,7 +132,9 @@ def _index_list(
elif isinstance(key_or_list, abc.Mapping):
return list(key_or_list.items())
elif not isinstance(key_or_list, (list, tuple)):
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
raise TypeError(
f"if no direction is specified, key_or_list must be an instance of list, not {type(key_or_list)}"
)
values: list[tuple[str, int]] = []
for item in key_or_list:
if isinstance(item, str):
@ -172,11 +174,12 @@ def _index_document(index_list: _IndexList) -> dict[str, Any]:
def _validate_index_key_pair(key: Any, value: Any) -> None:
if not isinstance(key, str):
raise TypeError("first item in each key pair must be an instance of str")
raise TypeError(f"first item in each key pair must be an instance of str, not {type(key)}")
if not isinstance(value, (str, int, abc.Mapping)):
raise TypeError(
"second item in each key pair must be 1, -1, "
"'2d', or another valid MongoDB index specifier."
f", not {type(value)}"
)

View File

@ -472,14 +472,15 @@ def _validate_event_listeners(
) -> Sequence[_EventListeners]:
"""Validate event listeners"""
if not isinstance(listeners, abc.Sequence):
raise TypeError(f"{option} must be a list or tuple")
raise TypeError(f"{option} must be a list or tuple, not {type(listeners)}")
for listener in listeners:
if not isinstance(listener, _EventListener):
raise TypeError(
f"Listeners for {option} must be either a "
"CommandListener, ServerHeartbeatListener, "
"ServerListener, TopologyListener, or "
"ConnectionPoolListener."
"ConnectionPoolListener,"
f"not {type(listener)}"
)
return listeners
@ -496,7 +497,8 @@ def register(listener: _EventListener) -> None:
f"Listeners for {listener} must be either a "
"CommandListener, ServerHeartbeatListener, "
"ServerListener, TopologyListener, or "
"ConnectionPoolListener."
"ConnectionPoolListener,"
f"not {type(listener)}"
)
if isinstance(listener, CommandListener):
_LISTENERS.command_listeners.append(listener)

View File

@ -75,6 +75,8 @@ class AsyncPeriodicExecutor:
callback; see monitor.py.
"""
self._stopped = True
if self._task is not None:
self._task.cancel()
async def join(self, timeout: Optional[int] = None) -> None:
if self._task is not None:

View File

@ -38,7 +38,7 @@ class ReadConcern:
if level is None or isinstance(level, str):
self.__level = level
else:
raise TypeError("level must be a string or None.")
raise TypeError(f"level must be a string or None, not {type(level)}")
@property
def level(self) -> Optional[str]:

View File

@ -115,4 +115,4 @@ else:
def get_ssl_context(*dummy): # type: ignore
"""No ssl module, raise ConfigurationError."""
raise ConfigurationError("The ssl module is not available.")
raise ConfigurationError("The ssl module is not available")

View File

@ -158,7 +158,7 @@ def _password_digest(username: str, password: str) -> str:
if len(password) == 0:
raise ValueError("password can't be empty")
if not isinstance(username, str):
raise TypeError("username must be an instance of str")
raise TypeError(f"username must be an instance of str, not {type(username)}")
md5hash = hashlib.md5() # noqa: S324
data = f"{username}:mongo:{password}"

View File

@ -213,7 +213,9 @@ class _OIDCAuthenticator:
)
resp = cb.fetch(context)
if not isinstance(resp, OIDCCallbackResult):
raise ValueError("Callback result must be of type OIDCCallbackResult")
raise ValueError(
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
)
self.refresh_token = resp.refresh_token
self.access_token = resp.access_token
self.token_gen_id += 1

View File

@ -309,7 +309,9 @@ class TransactionOptions:
)
if max_commit_time_ms is not None:
if not isinstance(max_commit_time_ms, int):
raise TypeError("max_commit_time_ms must be an integer or None")
raise TypeError(
f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}"
)
@property
def read_concern(self) -> Optional[ReadConcern]:
@ -897,7 +899,9 @@ class ClientSession:
another `ClientSession` instance.
"""
if not isinstance(cluster_time, _Mapping):
raise TypeError("cluster_time must be a subclass of collections.Mapping")
raise TypeError(
f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}"
)
if not isinstance(cluster_time.get("clusterTime"), Timestamp):
raise ValueError("Invalid cluster_time")
self._advance_cluster_time(cluster_time)
@ -918,7 +922,9 @@ class ClientSession:
another `ClientSession` instance.
"""
if not isinstance(operation_time, Timestamp):
raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp")
raise TypeError(
f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}"
)
self._advance_operation_time(operation_time)
def _process_response(self, reply: Mapping[str, Any]) -> None:

View File

@ -231,7 +231,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
read_concern or database.read_concern,
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
raise TypeError(f"name must be an instance of str, not {type(name)}")
from pymongo.synchronous.database import Database
if not isinstance(database, Database):
@ -2472,7 +2472,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
name = helpers_shared._gen_index_name(index_or_name)
if not isinstance(name, str):
raise TypeError("index_or_name must be an instance of str or list")
raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}")
cmd = {"dropIndexes": self._name, "index": name}
cmd.update(kwargs)
@ -3071,7 +3071,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""
if not isinstance(new_name, str):
raise TypeError("new_name must be an instance of str")
raise TypeError(f"new_name must be an instance of str, not {type(new_name)}")
if not new_name or ".." in new_name:
raise InvalidName("collection names cannot be empty")
@ -3141,7 +3141,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
"""
if not isinstance(key, str):
raise TypeError("key must be an instance of str")
raise TypeError(f"key must be an instance of str, not {type(key)}")
cmd = {"distinct": self._name, "key": key}
if filter is not None:
if "query" in kwargs:
@ -3189,7 +3189,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
common.validate_is_mapping("filter", filter)
if not isinstance(return_document, bool):
raise ValueError(
"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER"
f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}"
)
collation = validate_collation_or_none(kwargs.pop("collation", None))
cmd = {"findAndModify": self._name, "query": filter, "new": return_document}

View File

@ -94,7 +94,9 @@ class CommandCursor(Generic[_DocumentType]):
self.batch_size(batch_size)
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
raise TypeError(
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
)
def __del__(self) -> None:
self._die_no_lock()
@ -115,7 +117,7 @@ class CommandCursor(Generic[_DocumentType]):
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")

View File

@ -146,9 +146,9 @@ class Cursor(Generic[_DocumentType]):
spec: Mapping[str, Any] = filter or {}
validate_is_mapping("filter", spec)
if not isinstance(skip, int):
raise TypeError("skip must be an instance of int")
raise TypeError(f"skip must be an instance of int, not {type(skip)}")
if not isinstance(limit, int):
raise TypeError("limit must be an instance of int")
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
validate_boolean("no_cursor_timeout", no_cursor_timeout)
if no_cursor_timeout and not self._explicit_session:
warnings.warn(
@ -171,7 +171,7 @@ class Cursor(Generic[_DocumentType]):
validate_boolean("allow_partial_results", allow_partial_results)
validate_boolean("oplog_replay", oplog_replay)
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
# Only set if allow_disk_use is provided by the user, else None.
@ -388,7 +388,7 @@ class Cursor(Generic[_DocumentType]):
cursor.add_option(2)
"""
if not isinstance(mask, int):
raise TypeError("mask must be an int")
raise TypeError(f"mask must be an int, not {type(mask)}")
self._check_okay_to_chain()
if mask & _QUERY_OPTIONS["exhaust"]:
@ -408,7 +408,7 @@ class Cursor(Generic[_DocumentType]):
cursor.remove_option(2)
"""
if not isinstance(mask, int):
raise TypeError("mask must be an int")
raise TypeError(f"mask must be an int, not {type(mask)}")
self._check_okay_to_chain()
if mask & _QUERY_OPTIONS["exhaust"]:
@ -432,7 +432,7 @@ class Cursor(Generic[_DocumentType]):
.. versionadded:: 3.11
"""
if not isinstance(allow_disk_use, bool):
raise TypeError("allow_disk_use must be a bool")
raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}")
self._check_okay_to_chain()
self._allow_disk_use = allow_disk_use
@ -451,7 +451,7 @@ class Cursor(Generic[_DocumentType]):
.. seealso:: The MongoDB documentation on `limit <https://dochub.mongodb.org/core/limit>`_.
"""
if not isinstance(limit, int):
raise TypeError("limit must be an integer")
raise TypeError(f"limit must be an integer, not {type(limit)}")
if self._exhaust:
raise InvalidOperation("Can't use limit and exhaust together.")
self._check_okay_to_chain()
@ -479,7 +479,7 @@ class Cursor(Generic[_DocumentType]):
:param batch_size: The size of each batch of results requested.
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
self._check_okay_to_chain()
@ -499,7 +499,7 @@ class Cursor(Generic[_DocumentType]):
:param skip: the number of results to skip
"""
if not isinstance(skip, int):
raise TypeError("skip must be an integer")
raise TypeError(f"skip must be an integer, not {type(skip)}")
if skip < 0:
raise ValueError("skip must be >= 0")
self._check_okay_to_chain()
@ -520,7 +520,7 @@ class Cursor(Generic[_DocumentType]):
:param max_time_ms: the time limit after which the operation is aborted
"""
if not isinstance(max_time_ms, int) and max_time_ms is not None:
raise TypeError("max_time_ms must be an integer or None")
raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}")
self._check_okay_to_chain()
self._max_time_ms = max_time_ms
@ -543,7 +543,9 @@ class Cursor(Generic[_DocumentType]):
.. versionadded:: 3.2
"""
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
raise TypeError("max_await_time_ms must be an integer or None")
raise TypeError(
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
)
self._check_okay_to_chain()
# Ignore max_await_time_ms if not tailable or await_data is False.
@ -677,7 +679,7 @@ class Cursor(Generic[_DocumentType]):
.. versionadded:: 2.7
"""
if not isinstance(spec, (list, tuple)):
raise TypeError("spec must be an instance of list or tuple")
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
self._check_okay_to_chain()
self._max = dict(spec)
@ -699,7 +701,7 @@ class Cursor(Generic[_DocumentType]):
.. versionadded:: 2.7
"""
if not isinstance(spec, (list, tuple)):
raise TypeError("spec must be an instance of list or tuple")
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
self._check_okay_to_chain()
self._min = dict(spec)

View File

@ -122,7 +122,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
from pymongo.synchronous.mongo_client import MongoClient
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
raise TypeError(f"name must be an instance of str, not {type(name)}")
if not isinstance(client, MongoClient):
# This is for compatibility with mocked and subclassed types, such as in Motor.
@ -1303,7 +1303,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
name = name.name
if not isinstance(name, str):
raise TypeError("name_or_collection must be an instance of str")
raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}")
encrypted_fields = self._get_encrypted_fields(
{"encryptedFields": encrypted_fields},
name,
@ -1367,7 +1367,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
name = name.name
if not isinstance(name, str):
raise TypeError("name_or_collection must be an instance of str or Collection")
raise TypeError(
f"name_or_collection must be an instance of str or Collection, not {type(name)}"
)
cmd = {"validate": name, "scandata": scandata, "full": full}
if comment is not None:
cmd["comment"] = comment

View File

@ -320,7 +320,9 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS)
data_key_id = raw_doc.get("_id")
if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE:
raise TypeError("data_key _id must be Binary with a UUID subtype")
raise TypeError(
f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}"
)
assert self.key_vault_coll is not None
self.key_vault_coll.insert_one(raw_doc)
@ -642,7 +644,9 @@ class ClientEncryption(Generic[_DocumentType]):
)
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
raise TypeError(
f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
)
if not isinstance(key_vault_client, MongoClient):
# This is for compatibility with mocked and subclassed types, such as in Motor.

View File

@ -748,7 +748,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if port is None:
port = self.PORT
if not isinstance(port, int):
raise TypeError("port must be an instance of int")
raise TypeError(f"port must be an instance of int, not {type(port)}")
# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
@ -1559,6 +1559,12 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
self._encrypter.close()
self._closed = True
if not _IS_SYNC:
asyncio.gather(
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)
if not _IS_SYNC:
# Add support for contextlib.closing.
@ -1965,7 +1971,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
The cursor is closed synchronously on the current thread.
"""
if not isinstance(cursor_id, int):
raise TypeError("cursor_id must be an instance of int")
raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}")
try:
if conn_mgr:
@ -2087,7 +2093,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.ClientSession):
raise ValueError("'session' argument must be a ClientSession or None.")
raise ValueError(
f"'session' argument must be a ClientSession or None, not {type(session)}"
)
# Don't call end_session.
yield session
return
@ -2235,7 +2243,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
name = name.name
if not isinstance(name, str):
raise TypeError("name_or_database must be an instance of str or a Database")
raise TypeError(
f"name_or_database must be an instance of str or a Database, not {type(name)}"
)
with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
self[name]._command(

View File

@ -112,9 +112,9 @@ class MonitorBase:
"""
self.gc_safe_close()
def join(self, timeout: Optional[int] = None) -> None:
def join(self) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
self._executor.join()
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@ -189,6 +189,9 @@ class Monitor(MonitorBase):
self._rtt_monitor.gc_safe_close()
self.cancel_check()
def join(self) -> None:
asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value]
def close(self) -> None:
self.gc_safe_close()
self._rtt_monitor.close()

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import asyncio
import logging
import os
import queue
@ -61,7 +62,7 @@ from pymongo.server_selectors import (
writable_server_selector,
)
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
from pymongo.synchronous.pool import Pool
from pymongo.synchronous.server import Server
from pymongo.topology_description import (
@ -207,6 +208,9 @@ class Topology:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)
# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []
def open(self) -> None:
"""Start monitoring, or restart after a fork.
@ -241,6 +245,8 @@ class Topology:
# Close servers and clear the pools.
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
@ -283,6 +289,10 @@ class Topology:
else:
server_timeout = server_selection_timeout
# Cleanup any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
self.cleanup_monitors()
with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
@ -520,6 +530,8 @@ class Topology:
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
@ -693,6 +705,8 @@ class Topology:
old_td = self._description
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Mark all servers Unknown.
self._description = self._description.reset()
@ -703,6 +717,8 @@ class Topology:
# Stop SRV polling thread.
if self._srv_monitor:
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
self._opened = False
self._closed = True
@ -942,6 +958,8 @@ class Topology:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)
def _create_pool_for_server(self, address: _Address) -> Pool:
@ -1029,6 +1047,15 @@ class Topology:
else:
return ",".join(str(server.error) for server in servers if server.error)
def cleanup_monitors(self) -> None:
tasks = []
try:
while self._monitor_tasks:
tasks.append(self._monitor_tasks.pop())
except IndexError:
pass
asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
def __repr__(self) -> str:
msg = ""
if not self._opened:

View File

@ -91,7 +91,7 @@ def parse_userinfo(userinfo: str) -> tuple[str, str]:
user, _, passwd = userinfo.partition(":")
# No password is expected with GSSAPI authentication.
if not user:
raise InvalidURI("The empty string is not valid username.")
raise InvalidURI("The empty string is not valid username")
return unquote_plus(user), unquote_plus(passwd)
@ -347,7 +347,7 @@ def split_options(
semi_idx = opts.find(";")
try:
if and_idx >= 0 and semi_idx >= 0:
raise InvalidURI("Can not mix '&' and ';' for option separators.")
raise InvalidURI("Can not mix '&' and ';' for option separators")
elif and_idx >= 0:
options = _parse_options(opts, "&")
elif semi_idx >= 0:
@ -357,7 +357,7 @@ def split_options(
else:
raise ValueError
except ValueError:
raise InvalidURI("MongoDB URI options are key=value pairs.") from None
raise InvalidURI("MongoDB URI options are key=value pairs") from None
options = _handle_security_options(options)
@ -389,7 +389,7 @@ def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[
nodes = []
for entity in hosts.split(","):
if not entity:
raise ConfigurationError("Empty host (or extra comma in host list).")
raise ConfigurationError("Empty host (or extra comma in host list)")
port = default_port
# Unix socket entities don't have ports
if entity.endswith(".sock"):
@ -502,7 +502,7 @@ def parse_uri(
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
if not scheme_free:
raise InvalidURI("Must provide at least one hostname or IP.")
raise InvalidURI("Must provide at least one hostname or IP")
user = None
passwd = None

View File

@ -74,7 +74,7 @@ class WriteConcern:
if wtimeout is not None:
if not isinstance(wtimeout, int):
raise TypeError("wtimeout must be an integer")
raise TypeError(f"wtimeout must be an integer, not {type(wtimeout)}")
if wtimeout < 0:
raise ValueError("wtimeout cannot be less than 0")
self.__document["wtimeout"] = wtimeout
@ -98,7 +98,7 @@ class WriteConcern:
raise ValueError("w cannot be less than 0")
self.__acknowledged = w > 0
elif not isinstance(w, str):
raise TypeError("w must be an integer or string")
raise TypeError(f"w must be an integer or string, not {type(w)}")
self.__document["w"] = w
self.__server_default = not self.__document

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
@ -30,6 +31,33 @@ import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
sys.path[0:0] = [""]
from test.helpers import (
COMPRESSORS,
IS_SRV,
@ -52,31 +80,7 @@ from test.helpers import (
sanitize_cmd,
sanitize_reply,
)
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
_IS_SYNC = True
@ -589,7 +593,7 @@ class ClientContext:
if self.has_secondaries:
return True
if self.is_mongos:
shard = self.client.config.shards.find_one()["host"] # type:ignore[index]
shard = (self.client.config.shards.find_one())["host"] # type:ignore[index]
num_members = shard.count(",") + 1
return num_members > 1
return False
@ -863,18 +867,66 @@ class ClientContext:
# Reusable client context
client_context = ClientContext()
# Global event loop for async tests.
LOOP = None
def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif client_context.client is not None:
client_context.client.close()
client_context.client = None
client_context._init_client()
def get_loop() -> asyncio.AbstractEventLoop:
"""Get the test suite's global event loop."""
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP
class PyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by TestCase.
def setUp(self):
pass
def tearDown(self):
pass
def addCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)
def _callSetUp(self):
self.setUp()
self._callAsync(self.setUp)
def _callTestMethod(self, method):
self._callMaybeAsync(method)
def _callTearDown(self):
self._callAsync(self.tearDown)
self.tearDown()
def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)
def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))
def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase):
@client_context.require_connection
def setUp(self) -> None:
if not _IS_SYNC:
reset_client_context()
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@ -1186,6 +1236,9 @@ class MockClientTest(UnitTest):
def setup():
if not _IS_SYNC:
# Set up the event loop.
get_loop()
client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
@ -30,6 +31,33 @@ import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import asynccontextmanager, contextmanager
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
sys.path[0:0] = [""]
from test.helpers import (
COMPRESSORS,
IS_SRV,
@ -52,31 +80,7 @@ from test.helpers import (
sanitize_cmd,
sanitize_reply,
)
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import asynccontextmanager, contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
_IS_SYNC = False
@ -588,10 +592,10 @@ class AsyncClientContext:
@property
async def supports_secondary_read_pref(self):
if self.has_secondaries:
if await self.has_secondaries:
return True
if self.is_mongos:
shard = await self.client.config.shards.find_one()["host"] # type:ignore[index]
shard = (await self.client.config.shards.find_one())["host"] # type:ignore[index]
num_members = shard.count(",") + 1
return num_members > 1
return False
@ -865,18 +869,66 @@ class AsyncClientContext:
# Reusable client context
async_client_context = AsyncClientContext()
async def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif async_client_context.client is not None:
await async_client_context.client.close()
async_client_context.client = None
await async_client_context._init_client()
# Global event loop for async tests.
LOOP = None
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
def get_loop() -> asyncio.AbstractEventLoop:
"""Get the test suite's global event loop."""
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP
class AsyncPyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by IsolatedAsyncioTestCase.
async def asyncSetUp(self):
pass
async def asyncTearDown(self):
pass
def addAsyncCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)
def _callSetUp(self):
self.setUp()
self._callAsync(self.asyncSetUp)
def _callTestMethod(self, method):
self._callMaybeAsync(method)
def _callTearDown(self):
self._callAsync(self.asyncTearDown)
self.tearDown()
def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)
def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))
def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1124,15 +1176,15 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
async def disable_replication(self, client):
"""Disable replication on all secondaries."""
for h, p in client.secondaries:
for h, p in await client.secondaries:
secondary = await self.async_single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
async def enable_replication(self, client):
"""Enable replication on all secondaries."""
for h, p in client.secondaries:
for h, p in await client.secondaries:
secondary = await self.async_single_client(h, p)
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
class AsyncUnitTest(AsyncPyMongoTestCase):
@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
@async_client_context.require_connection
async def asyncSetUp(self) -> None:
if not _IS_SYNC:
await reset_client_context()
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@ -1204,6 +1254,9 @@ class AsyncMockClientTest(AsyncUnitTest):
async def async_setup():
if not _IS_SYNC:
# Set up the event loop.
get_loop()
await async_client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")

View File

@ -15,6 +15,7 @@
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations
import asyncio
import base64
import gc
import multiprocessing
@ -30,6 +31,8 @@ import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo._asyncio_task import create_task
try:
import ipaddress
@ -369,3 +372,38 @@ class SystemCertsPatcher:
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs
if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object
class ConcurrentRunner(PARENT):
def __init__(self, **kwargs):
if _IS_SYNC:
super().__init__(**kwargs)
self.name = kwargs.get("name", "ConcurrentRunner")
self.stopped = False
self.task = None
self.target = kwargs.get("target", None)
self.args = kwargs.get("args", [])
if not _IS_SYNC:
async def start(self):
self.task = create_task(self.run(), name=self.name)
async def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
await asyncio.wait([self.task], timeout=timeout)
def is_alive(self):
return not self.stopped
async def run(self):
try:
await self.target(*self.args)
finally:
self.stopped = True

View File

@ -0,0 +1,221 @@
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the SRV support tests."""
from __future__ import annotations
import glob
import json
import os
import pathlib
import sys
sys.path[0:0] = [""]
from test.asynchronous import (
AsyncIntegrationTest,
AsyncPyMongoTestCase,
async_client_context,
unittest,
)
from test.utils import async_wait_until
from pymongo.common import validate_read_preference_tags
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import parse_uri, split_hosts
_IS_SYNC = False
class TestDNSRepl(AsyncPyMongoTestCase):
if _IS_SYNC:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set"
)
else:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set"
)
load_balanced = False
@async_client_context.require_replica_set
def asyncSetUp(self):
pass
class TestDNSLoadBalanced(AsyncPyMongoTestCase):
if _IS_SYNC:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced"
)
else:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced"
)
load_balanced = True
@async_client_context.require_load_balancer
def asyncSetUp(self):
pass
class TestDNSSharded(AsyncPyMongoTestCase):
if _IS_SYNC:
TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded")
else:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded"
)
load_balanced = False
@async_client_context.require_mongos
def asyncSetUp(self):
pass
def create_test(test_case):
async def run_test(self):
uri = test_case["uri"]
seeds = test_case.get("seeds")
num_seeds = test_case.get("numSeeds", len(seeds or []))
hosts = test_case.get("hosts")
num_hosts = test_case.get("numHosts", len(hosts or []))
options = test_case.get("options", {})
if "ssl" in options:
options["tls"] = options.pop("ssl")
parsed_options = test_case.get("parsed_options")
# See DRIVERS-1324, unless tls is explicitly set to False we need TLS.
needs_tls = not (options and (options.get("ssl") is False or options.get("tls") is False))
if needs_tls and not async_client_context.tls:
self.skipTest("this test requires a TLS cluster")
if not needs_tls and async_client_context.tls:
self.skipTest("this test requires a non-TLS cluster")
if seeds:
seeds = split_hosts(",".join(seeds))
if hosts:
hosts = frozenset(split_hosts(",".join(hosts)))
if seeds or num_seeds:
result = parse_uri(uri, validate=True)
if seeds is not None:
self.assertEqual(sorted(result["nodelist"]), sorted(seeds))
if num_seeds is not None:
self.assertEqual(len(result["nodelist"]), num_seeds)
if options:
opts = result["options"]
if "readpreferencetags" in opts:
rpts = validate_read_preference_tags(
"readPreferenceTags", opts.pop("readpreferencetags")
)
opts["readPreferenceTags"] = rpts
self.assertEqual(result["options"], options)
if parsed_options:
for opt, expected in parsed_options.items():
if opt == "user":
self.assertEqual(result["username"], expected)
elif opt == "password":
self.assertEqual(result["password"], expected)
elif opt == "auth_database" or opt == "db":
self.assertEqual(result["database"], expected)
hostname = next(iter(async_client_context.client.nodes))[0]
# The replica set members must be configured as 'localhost'.
if hostname == "localhost":
copts = async_client_context.default_client_options.copy()
# Remove tls since SRV parsing should add it automatically.
copts.pop("tls", None)
if async_client_context.tls:
# Our test certs don't support the SRV hosts used in these
# tests.
copts["tlsAllowInvalidHostnames"] = True
client = self.simple_client(uri, **copts)
if client._options.connect:
await client.aconnect()
if num_seeds is not None:
self.assertEqual(len(client._topology_settings.seeds), num_seeds)
if hosts is not None:
await async_wait_until(
lambda: hosts == client.nodes, "match test hosts to client nodes"
)
if num_hosts is not None:
await async_wait_until(
lambda: num_hosts == len(client.nodes), "wait to connect to num_hosts"
)
if test_case.get("ping", True):
await client.admin.command("ping")
# XXX: we should block until SRV poller runs at least once
# and re-run these assertions.
else:
try:
parse_uri(uri)
except (ConfigurationError, ValueError):
pass
else:
self.fail("failed to raise an exception")
return run_test
def create_tests(cls):
for filename in glob.glob(os.path.join(cls.TEST_PATH, "*.json")):
test_suffix, _ = os.path.splitext(os.path.basename(filename))
with open(filename) as dns_test_file:
test_method = create_test(json.load(dns_test_file))
setattr(cls, "test_" + test_suffix, test_method)
create_tests(TestDNSRepl)
create_tests(TestDNSLoadBalanced)
create_tests(TestDNSSharded)
class TestParsingErrors(AsyncPyMongoTestCase):
async def test_invalid_host(self):
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb is not",
self.simple_client,
"mongodb+srv://mongodb",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb.com is not",
self.simple_client,
"mongodb+srv://mongodb.com",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
self.simple_client,
"mongodb+srv://127.0.0.1",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
self.simple_client,
"mongodb+srv://[::1]",
)
class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
async def test_connect_case_insensitive(self):
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,39 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the AsyncGridFS unified spec tests."""
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,97 @@
# Copyright 2016-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the monitoring of the server heartbeats."""
from __future__ import annotations
import sys
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest
from test.utils import AsyncMockPool, HeartbeatEventListener, async_wait_until
from pymongo.asynchronous.monitor import Monitor
from pymongo.errors import ConnectionFailure
from pymongo.hello import Hello, HelloCompat
_IS_SYNC = False
class TestHeartbeatMonitoring(AsyncIntegrationTest):
async def create_mock_monitor(self, responses, uri, expected_results):
listener = HeartbeatEventListener()
with client_knobs(
heartbeat_frequency=0.1, min_heartbeat_interval=0.1, events_queue_frequency=0.1
):
class MockMonitor(Monitor):
async def _check_with_socket(self, *args, **kwargs):
if isinstance(responses[1], Exception):
raise responses[1]
return Hello(responses[1]), 99
_ = await self.async_single_client(
h=uri,
event_listeners=(listener,),
_monitor_class=MockMonitor,
_pool_class=AsyncMockPool,
connect=True,
)
expected_len = len(expected_results)
# Wait for *at least* expected_len number of results. The
# monitor thread may run multiple times during the execution
# of this test.
await async_wait_until(
lambda: len(listener.events) >= expected_len, "publish all events"
)
# zip gives us len(expected_results) pairs.
for expected, actual in zip(expected_results, listener.events):
self.assertEqual(expected, actual.__class__.__name__)
self.assertEqual(actual.connection_id, responses[0])
if expected != "ServerHeartbeatStartedEvent":
if isinstance(actual.reply, Hello):
self.assertEqual(actual.duration, 99)
self.assertEqual(actual.reply._doc, responses[1])
else:
self.assertEqual(actual.reply, responses[1])
async def test_standalone(self):
responses = (
("a", 27017),
{HelloCompat.LEGACY_CMD: True, "maxWireVersion": 4, "minWireVersion": 0, "ok": 1},
)
uri = "mongodb://a:27017"
expected_results = ["ServerHeartbeatStartedEvent", "ServerHeartbeatSucceededEvent"]
await self.create_mock_monitor(responses, uri, expected_results)
async def test_standalone_error(self):
responses = (("a", 27017), ConnectionFailure("SPECIAL MESSAGE"))
uri = "mongodb://a:27017"
# _check_with_socket failing results in a second attempt.
expected_results = [
"ServerHeartbeatStartedEvent",
"ServerHeartbeatFailedEvent",
"ServerHeartbeatStartedEvent",
"ServerHeartbeatFailedEvent",
]
await self.create_mock_monitor(responses, uri, expected_results)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,383 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the auth spec tests."""
from __future__ import annotations
import asyncio
import os
import pathlib
import sys
import time
import uuid
from typing import Any, Mapping
import pytest
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest
from test.asynchronous.unified_format import generate_test_classes
from test.utils import AllowListEventListener, OvertCommandListener
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
pytestmark = pytest.mark.index_management
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management")
_NAME = "test-search-index"
class TestCreateSearchIndex(AsyncIntegrationTest):
async def test_inputs(self):
if not os.environ.get("TEST_INDEX_MANAGEMENT"):
raise unittest.SkipTest("Skipping index management tests")
listener = AllowListEventListener("createSearchIndexes")
client = self.simple_client(event_listeners=[listener])
coll = client.test.test
await coll.drop()
definition = dict(mappings=dict(dynamic=True))
model_kwarg_list: list[Mapping[str, Any]] = [
dict(definition=definition, name=None),
dict(definition=definition, name="test"),
]
for model_kwargs in model_kwarg_list:
model = SearchIndexModel(**model_kwargs)
with self.assertRaises(OperationFailure):
await coll.create_search_index(model)
with self.assertRaises(OperationFailure):
await coll.create_search_index(model_kwargs)
listener.reset()
with self.assertRaises(OperationFailure):
await coll.create_search_index({"definition": definition, "arbitraryOption": 1})
self.assertEqual(
{"definition": definition, "arbitraryOption": 1},
listener.events[0].command["indexes"][0],
)
listener.reset()
with self.assertRaises(OperationFailure):
await coll.create_search_index({"definition": definition, "type": "search"})
self.assertEqual(
{"definition": definition, "type": "search"}, listener.events[0].command["indexes"][0]
)
class SearchIndexIntegrationBase(AsyncPyMongoTestCase):
db_name = "test_search_index_base"
@classmethod
def setUpClass(cls) -> None:
if not os.environ.get("TEST_INDEX_MANAGEMENT"):
raise unittest.SkipTest("Skipping index management tests")
cls.url = os.environ.get("MONGODB_URI")
cls.username = os.environ["DB_USER"]
cls.password = os.environ["DB_PASSWORD"]
cls.listener = OvertCommandListener()
async def asyncSetUp(self) -> None:
self.client = self.simple_client(
self.url,
username=self.username,
password=self.password,
event_listeners=[self.listener],
)
await self.client.drop_database(_NAME)
self.db = self.client[self.db_name]
async def asyncTearDown(self):
await self.client.drop_database(_NAME)
async def wait_for_ready(self, coll, name=_NAME, predicate=None):
"""Wait for a search index to be ready."""
indices: list[Mapping[str, Any]] = []
if predicate is None:
predicate = lambda index: index.get("queryable") is True
while True:
indices = await (await coll.list_search_indexes(name)).to_list()
if len(indices) and predicate(indices[0]):
return indices[0]
await asyncio.sleep(5)
class TestSearchIndexIntegration(SearchIndexIntegrationBase):
db_name = "test_search_index"
async def test_comment_field(self):
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
await coll0.insert_one({})
# Create a new search index on ``coll0`` that implicitly passes its type.
search_definition = {"mappings": {"dynamic": False}}
self.listener.reset()
implicit_search_resp = await coll0.create_search_index(
model={"name": _NAME + "-implicit", "definition": search_definition}, comment="foo"
)
event = self.listener.events[0]
self.assertEqual(event.command["comment"], "foo")
# Get the index definition.
self.listener.reset()
await (await coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next()
event = self.listener.events[0]
self.assertEqual(event.command["comment"], "foo")
class TestSearchIndexProse(SearchIndexIntegrationBase):
db_name = "test_search_index_prose"
async def test_case_1(self):
"""Driver can successfully create and list search indexes."""
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
# Create a new search index on ``coll0`` with the ``createSearchIndex`` helper. Use the following definition:
model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}}
await coll0.insert_one({})
resp = await coll0.create_search_index(model)
# Assert that the command returns the name of the index: ``"test-search-index"``.
self.assertEqual(resp, _NAME)
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``:
# An index with the ``name`` of ``test-search-index`` is present and the index has a field ``queryable`` with a value of ``true``.
index = await self.wait_for_ready(coll0)
# . Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }``
self.assertIn("latestDefinition", index)
self.assertEqual(index["latestDefinition"], model["definition"])
async def test_case_2(self):
"""Driver can successfully create multiple indexes in batch."""
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
await coll0.insert_one({})
# Create two new search indexes on ``coll0`` with the ``createSearchIndexes`` helper.
name1 = "test-search-index-1"
name2 = "test-search-index-2"
definition = {"mappings": {"dynamic": False}}
index_definitions: list[dict[str, Any]] = [
{"name": name1, "definition": definition},
{"name": name2, "definition": definition},
]
await coll0.create_search_indexes(
[SearchIndexModel(i["definition"], i["name"]) for i in index_definitions]
)
# .Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``.
indices = await (await coll0.list_search_indexes()).to_list()
names = [i["name"] for i in indices]
self.assertIn(name1, names)
self.assertIn(name2, names)
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied.
# An index with the ``name`` of ``test-search-index-1`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index1``.
# An index with the ``name`` of ``test-search-index-2`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index2``.
index1 = await self.wait_for_ready(coll0, name1)
index2 = await self.wait_for_ready(coll0, name2)
# Assert that ``index1`` and ``index2`` have the property ``latestDefinition`` whose value is ``{ "mappings" : { "dynamic" : false } }``
for index in [index1, index2]:
self.assertIn("latestDefinition", index)
self.assertEqual(index["latestDefinition"], definition)
async def test_case_3(self):
"""Driver can successfully drop search indexes."""
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
await coll0.insert_one({})
# Create a new search index on ``coll0``.
model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}}
resp = await coll0.create_search_index(model)
# Assert that the command returns the name of the index: ``"test-search-index"``.
self.assertEqual(resp, "test-search-index")
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied:
# An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``.
await self.wait_for_ready(coll0)
# Run a ``dropSearchIndex`` on ``coll0``, using ``test-search-index`` for the name.
await coll0.drop_search_index(_NAME)
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array.
t0 = time.time()
while True:
indices = await (await coll0.list_search_indexes()).to_list()
if indices:
break
if (time.time() - t0) / 60 > 5:
raise TimeoutError("Timed out waiting for index deletion")
await asyncio.sleep(5)
async def test_case_4(self):
"""Driver can update a search index."""
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
await coll0.insert_one({})
# Create a new search index on ``coll0``.
model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}}
resp = await coll0.create_search_index(model)
# Assert that the command returns the name of the index: ``"test-search-index"``.
self.assertEqual(resp, _NAME)
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied:
# An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``.
await self.wait_for_ready(coll0)
# Run a ``updateSearchIndex`` on ``coll0``.
# Assert that the command does not error and the server responds with a success.
model2: dict[str, Any] = {"name": _NAME, "definition": {"mappings": {"dynamic": True}}}
await coll0.update_search_index(_NAME, model2["definition"])
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied:
# An index with the ``name`` of ``test-search-index`` is present. This index is referred to as ``index``.
# The index has a field ``queryable`` with a value of ``true`` and has a field ``status`` with the value of ``READY``.
predicate = lambda index: index.get("queryable") is True and index.get("status") == "READY"
await self.wait_for_ready(coll0, predicate=predicate)
# Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``.
index = (await (await coll0.list_search_indexes(_NAME)).to_list())[0]
self.assertIn("latestDefinition", index)
self.assertEqual(index["latestDefinition"], model2["definition"])
async def test_case_5(self):
"""``dropSearchIndex`` suppresses namespace not found errors."""
# Create a driver-side collection object for a randomly generated collection name. Do not create this collection on the server.
coll0 = self.db[f"col{uuid.uuid4()}"]
# Run a ``dropSearchIndex`` command and assert that no error is thrown.
await coll0.drop_search_index("foo")
async def test_case_6(self):
"""Driver can successfully create and list search indexes with non-default readConcern and writeConcern."""
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
await coll0.insert_one({})
# Apply a write concern ``WriteConcern(w=1)`` and a read concern with ``ReadConcern(level="majority")`` to ``coll0``.
coll0 = coll0.with_options(
write_concern=WriteConcern(w="1"), read_concern=ReadConcern(level="majority")
)
# Create a new search index on ``coll0`` with the ``createSearchIndex`` helper.
name = "test-search-index-case6"
model = {"name": name, "definition": {"mappings": {"dynamic": False}}}
resp = await coll0.create_search_index(model)
# Assert that the command returns the name of the index: ``"test-search-index-case6"``.
self.assertEqual(resp, name)
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``:
# - An index with the ``name`` of ``test-search-index-case6`` is present and the index has a field ``queryable`` with a value of ``true``.
index = await self.wait_for_ready(coll0, name)
# Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }``
self.assertIn("latestDefinition", index)
self.assertEqual(index["latestDefinition"], model["definition"])
async def test_case_7(self):
"""Driver handles index types."""
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
coll0 = self.db[f"col{uuid.uuid4()}"]
await coll0.insert_one({})
# Use these search and vector search definitions for indexes.
search_definition = {"mappings": {"dynamic": False}}
vector_search_definition = {
"fields": [
{
"type": "vector",
"path": "plot_embedding",
"numDimensions": 1536,
"similarity": "euclidean",
},
]
}
# Create a new search index on ``coll0`` that implicitly passes its type.
implicit_search_resp = await coll0.create_search_index(
model={"name": _NAME + "-implicit", "definition": search_definition}
)
# Get the index definition.
resp = await (await coll0.list_search_indexes(name=implicit_search_resp)).next()
# Assert that the index model contains the correct index type: ``"search"``.
self.assertEqual(resp["type"], "search")
# Create a new search index on ``coll0`` that explicitly passes its type.
explicit_search_resp = await coll0.create_search_index(
model={"name": _NAME + "-explicit", "type": "search", "definition": search_definition}
)
# Get the index definition.
resp = await (await coll0.list_search_indexes(name=explicit_search_resp)).next()
# Assert that the index model contains the correct index type: ``"search"``.
self.assertEqual(resp["type"], "search")
# Create a new vector search index on ``coll0`` that explicitly passes its type.
explicit_vector_resp = await coll0.create_search_index(
model={
"name": _NAME + "-vector",
"type": "vectorSearch",
"definition": vector_search_definition,
}
)
# Get the index definition.
resp = await (await coll0.list_search_indexes(name=explicit_vector_resp)).next()
# Assert that the index model contains the correct index type: ``"vectorSearch"``.
self.assertEqual(resp["type"], "vectorSearch")
# Catch the error raised when trying to create a vector search index without specifying the type
with self.assertRaises(OperationFailure) as e:
await coll0.create_search_index(
model={"name": _NAME + "-error", "definition": vector_search_definition}
)
self.assertIn("Attribute mappings missing.", e.exception.details["errmsg"])
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,28 @@
from __future__ import annotations
from test.asynchronous import AsyncIntegrationTest
from typing import Any, List, MutableMapping
from bson import Binary, Code, DBRef, ObjectId, json_util
from bson.binary import USER_DEFINED_SUBTYPE
_IS_SYNC = False
class TestJsonUtilRoundtrip(AsyncIntegrationTest):
async def test_cursor(self):
db = self.db
await db.drop_collection("test")
docs: List[MutableMapping[str, Any]] = [
{"foo": [1, 2]},
{"bar": {"hello": "world"}},
{"code": Code("function x() { return 1; }")},
{"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)},
{"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}},
]
await db.test.insert_many(docs)
reloaded_docs = json_util.loads(json_util.dumps(await (db.test.find()).to_list()))
for doc in docs:
self.assertTrue(doc in reloaded_docs)

View File

@ -0,0 +1,149 @@
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test maxStalenessSeconds support."""
from __future__ import annotations
import asyncio
import os
import sys
import time
import warnings
from pathlib import Path
from pymongo import AsyncMongoClient
from pymongo.operations import _Op
sys.path[0:0] = [""]
from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest
from test.utils_selection_tests import create_selection_tests
from pymongo.errors import ConfigurationError
from pymongo.server_selectors import writable_server_selector
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness")
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
pass
class TestMaxStaleness(AsyncPyMongoTestCase):
async def test_max_staleness(self):
client = self.simple_client()
self.assertEqual(-1, client.read_preference.max_staleness)
client = self.simple_client("mongodb://a/?readPreference=secondary")
self.assertEqual(-1, client.read_preference.max_staleness)
# These tests are specified in max-staleness-tests.rst.
with self.assertRaises(ConfigurationError):
# Default read pref "primary" can't be used with max staleness.
self.simple_client("mongodb://a/?maxStalenessSeconds=120")
with self.assertRaises(ConfigurationError):
# Read pref "primary" can't be used with max staleness.
self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120")
client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1")
self.assertEqual(-1, client.read_preference.max_staleness)
client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1")
self.assertEqual(-1, client.read_preference.max_staleness)
client = self.simple_client(
"mongodb://host/?readPreference=secondary&maxStalenessSeconds=120"
)
self.assertEqual(120, client.read_preference.max_staleness)
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1")
self.assertEqual(1, client.read_preference.max_staleness)
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1")
self.assertEqual(-1, client.read_preference.max_staleness)
client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest")
self.assertEqual(-1, client.read_preference.max_staleness)
with self.assertRaises(TypeError):
# Prohibit None.
self.simple_client(maxStalenessSeconds=None, readPreference="nearest")
async def test_max_staleness_float(self):
with self.assertRaises(TypeError) as ctx:
await self.async_rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest")
self.assertIn("must be an integer", str(ctx.exception))
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter("always")
client = self.simple_client(
"mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest"
)
# Option was ignored.
self.assertEqual(-1, client.read_preference.max_staleness)
self.assertIn("must be an integer", str(ctx[0]))
async def test_max_staleness_zero(self):
# Zero is too small.
with self.assertRaises(ValueError) as ctx:
await self.async_rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest")
self.assertIn("must be a positive integer", str(ctx.exception))
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter("always")
client = self.simple_client(
"mongodb://host/?maxStalenessSeconds=0&readPreference=nearest"
)
# Option was ignored.
self.assertEqual(-1, client.read_preference.max_staleness)
self.assertIn("must be a positive integer", str(ctx[0]))
@async_client_context.require_replica_set
async def test_last_write_date(self):
# From max-staleness-tests.rst, "Parse lastWriteDate".
client = await self.async_rs_or_single_client(heartbeatFrequencyMS=500)
await client.pymongo_test.test.insert_one({})
# Wait for the server description to be updated.
await asyncio.sleep(1)
server = await client._topology.select_server(writable_server_selector, _Op.TEST)
first = server.description.last_write_date
self.assertTrue(first)
# The first last_write_date may correspond to a internal server write,
# sleep so that the next write does not occur within the same second.
await asyncio.sleep(1)
await client.pymongo_test.test.insert_one({})
# Wait for the server description to be updated.
await asyncio.sleep(1)
server = await client._topology.select_server(writable_server_selector, _Op.TEST)
second = server.description.last_write_date
assert first is not None
assert second is not None
self.assertGreater(second, first)
self.assertLess(second, first + 10)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,115 @@
# Copyright 2022-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test client side encryption with on demand credentials."""
from __future__ import annotations
import os
import sys
import unittest
import pytest
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context
from bson.codec_options import CodecOptions
from pymongo.asynchronous.encryption import (
_HAVE_PYMONGOCRYPT,
AsyncClientEncryption,
EncryptionError,
)
_IS_SYNC = False
pytestmark = pytest.mark.csfle
class TestonDemandGCPCredentials(AsyncIntegrationTest):
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@async_client_context.require_version_min(4, 2, -1)
async def asyncSetUp(self):
await super().asyncSetUp()
self.master_key = {
"projectId": "devprod-drivers",
"location": "global",
"keyRing": "key-ring-csfle",
"keyName": "key-name-csfle",
}
@unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto")
async def test_01_failure(self):
if os.environ["SUCCESS"].lower() == "true":
self.skipTest("Expecting success")
self.client_encryption = AsyncClientEncryption(
kms_providers={"gcp": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
codec_options=CodecOptions(),
)
with self.assertRaises(EncryptionError):
await self.client_encryption.create_data_key("gcp", self.master_key)
@unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto")
async def test_02_success(self):
if os.environ["SUCCESS"].lower() == "false":
self.skipTest("Expecting failure")
self.client_encryption = AsyncClientEncryption(
kms_providers={"gcp": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
codec_options=CodecOptions(),
)
await self.client_encryption.create_data_key("gcp", self.master_key)
class TestonDemandAzureCredentials(AsyncIntegrationTest):
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@async_client_context.require_version_min(4, 2, -1)
async def asyncSetUp(self):
await super().asyncSetUp()
self.master_key = {
"keyVaultEndpoint": os.environ["KEY_VAULT_ENDPOINT"],
"keyName": os.environ["KEY_NAME"],
}
@unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto")
async def test_01_failure(self):
if os.environ["SUCCESS"].lower() == "true":
self.skipTest("Expecting success")
self.client_encryption = AsyncClientEncryption(
kms_providers={"azure": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
codec_options=CodecOptions(),
)
with self.assertRaises(EncryptionError):
await self.client_encryption.create_data_key("azure", self.master_key)
@unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto")
async def test_02_success(self):
if os.environ["SUCCESS"].lower() == "false":
self.skipTest("Expecting failure")
self.client_encryption = AsyncClientEncryption(
kms_providers={"azure": {}},
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
codec_options=CodecOptions(),
)
await self.client_encryption.create_data_key("azure", self.master_key)
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@ -0,0 +1,122 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the read_concern module."""
from __future__ import annotations
import sys
import unittest
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context
from test.utils import OvertCommandListener
from bson.son import SON
from pymongo.errors import OperationFailure
from pymongo.read_concern import ReadConcern
_IS_SYNC = False
class TestReadConcern(AsyncIntegrationTest):
listener: OvertCommandListener
@async_client_context.require_connection
async def asyncSetUp(self):
await super().asyncSetUp()
self.listener = OvertCommandListener()
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
self.db = self.client.pymongo_test
await async_client_context.client.pymongo_test.create_collection("coll")
async def asyncTearDown(self):
await async_client_context.client.pymongo_test.drop_collection("coll")
def test_read_concern(self):
rc = ReadConcern()
self.assertIsNone(rc.level)
self.assertTrue(rc.ok_for_legacy)
rc = ReadConcern("majority")
self.assertEqual("majority", rc.level)
self.assertFalse(rc.ok_for_legacy)
rc = ReadConcern("local")
self.assertEqual("local", rc.level)
self.assertTrue(rc.ok_for_legacy)
self.assertRaises(TypeError, ReadConcern, 42)
async def test_read_concern_uri(self):
uri = f"mongodb://{await async_client_context.pair}/?readConcernLevel=majority"
client = await self.async_rs_or_single_client(uri, connect=False)
self.assertEqual(ReadConcern("majority"), client.read_concern)
async def test_invalid_read_concern(self):
coll = self.db.get_collection("coll", read_concern=ReadConcern("unknown"))
# We rely on the server to validate read concern.
with self.assertRaises(OperationFailure):
await coll.find_one()
async def test_find_command(self):
# readConcern not sent in command if not specified.
coll = self.db.coll
await coll.find({"field": "value"}).to_list()
self.assertNotIn("readConcern", self.listener.started_events[0].command)
self.listener.reset()
# Explicitly set readConcern to 'local'.
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
await coll.find({"field": "value"}).to_list()
self.assertEqualCommand(
SON(
[
("find", "coll"),
("filter", {"field": "value"}),
("readConcern", {"level": "local"}),
]
),
self.listener.started_events[0].command,
)
async def test_command_cursor(self):
# readConcern not sent in command if not specified.
coll = self.db.coll
await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list()
self.assertNotIn("readConcern", self.listener.started_events[0].command)
self.listener.reset()
# Explicitly set readConcern to 'local'.
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list()
self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"])
async def test_aggregate_out(self):
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
await (
await coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])
).to_list()
# Aggregate with $out supports readConcern MongoDB 4.2 onwards.
if async_client_context.version >= (4, 1):
self.assertIn("readConcern", self.listener.started_events[0].command)
else:
self.assertNotIn("readConcern", self.listener.started_events[0].command)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,730 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the replica_set_connection module."""
from __future__ import annotations
import contextlib
import copy
import pickle
import random
import sys
from typing import Any
from pymongo.operations import _Op
sys.path[0:0] = [""]
from test.asynchronous import (
AsyncIntegrationTest,
SkipTest,
async_client_context,
connected,
unittest,
)
from test.utils import (
OvertCommandListener,
async_wait_until,
one,
)
from test.version import Version
from bson.son import SON
from pymongo.asynchronous.helpers import anext
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.message import _maybe_add_read_preference
from pymongo.read_preferences import (
MovingAverage,
Nearest,
Primary,
PrimaryPreferred,
ReadPreference,
Secondary,
SecondaryPreferred,
)
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import Selection, readable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class TestSelections(AsyncIntegrationTest):
@async_client_context.require_connection
async def test_bool(self):
client = await self.async_single_client()
async def predicate():
return await client.address
await async_wait_until(predicate, "discover primary")
selection = Selection.from_topology_description(client._topology.description)
self.assertTrue(selection)
self.assertFalse(selection.with_server_descriptions([]))
class TestReadPreferenceObjects(unittest.TestCase):
prefs = [
Primary(),
PrimaryPreferred(),
Secondary(),
Nearest(tag_sets=[{"a": 1}, {"b": 2}]),
SecondaryPreferred(max_staleness=30),
]
def test_pickle(self):
for pref in self.prefs:
self.assertEqual(pref, pickle.loads(pickle.dumps(pref)))
def test_copy(self):
for pref in self.prefs:
self.assertEqual(pref, copy.copy(pref))
def test_deepcopy(self):
for pref in self.prefs:
self.assertEqual(pref, copy.deepcopy(pref))
class TestReadPreferencesBase(AsyncIntegrationTest):
@async_client_context.require_secondaries_count(1)
async def asyncSetUp(self):
await super().asyncSetUp()
# Insert some data so we can use cursors in read_from_which_host
await self.client.pymongo_test.test.drop()
await self.client.get_database(
"pymongo_test", write_concern=WriteConcern(w=async_client_context.w)
).test.insert_many([{"_id": i} for i in range(10)])
self.addAsyncCleanup(self.client.pymongo_test.test.drop)
async def read_from_which_host(self, client):
"""Do a find() on the client and return which host was used"""
cursor = client.pymongo_test.test.find()
await anext(cursor)
return cursor.address
async def read_from_which_kind(self, client):
"""Do a find() on the client and return 'primary' or 'secondary'
depending on which the client used.
"""
address = await self.read_from_which_host(client)
if address == await client.primary:
return "primary"
elif address in await client.secondaries:
return "secondary"
else:
self.fail(
f"Cursor used address {address}, expected either primary "
f"{client.primary} or secondaries {client.secondaries}"
)
async def assertReadsFrom(self, expected, **kwargs):
c = await self.async_rs_client(**kwargs)
async def predicate():
return len(c.nodes - await c.arbiters) == async_client_context.w
await async_wait_until(predicate, "discovered all nodes")
used = await self.read_from_which_kind(c)
self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}")
class TestSingleSecondaryOk(TestReadPreferencesBase):
async def test_reads_from_secondary(self):
host, port = next(iter(await self.client.secondaries))
# Direct connection to a secondary.
client = await self.async_single_client(host, port)
self.assertFalse(await client.is_primary)
# Regardless of read preference, we should be able to do
# "reads" with a direct connection to a secondary.
# See server-selection.rst#topology-type-single.
self.assertEqual(client.read_preference, ReadPreference.PRIMARY)
db = client.pymongo_test
coll = db.test
# Test find and find_one.
self.assertIsNotNone(await coll.find_one())
self.assertEqual(10, len(await coll.find().to_list()))
# Test some database helpers.
self.assertIsNotNone(await db.list_collection_names())
self.assertIsNotNone(await db.validate_collection("test"))
self.assertIsNotNone(await db.command("ping"))
# Test some collection helpers.
self.assertEqual(10, await coll.count_documents({}))
self.assertEqual(10, len(await coll.distinct("_id")))
self.assertIsNotNone(await coll.aggregate([]))
self.assertIsNotNone(await coll.index_information())
class TestReadPreferences(TestReadPreferencesBase):
async def test_mode_validation(self):
for mode in (
ReadPreference.PRIMARY,
ReadPreference.PRIMARY_PREFERRED,
ReadPreference.SECONDARY,
ReadPreference.SECONDARY_PREFERRED,
ReadPreference.NEAREST,
):
self.assertEqual(
mode, (await self.async_rs_client(read_preference=mode)).read_preference
)
with self.assertRaises(TypeError):
await self.async_rs_client(read_preference="foo")
async def test_tag_sets_validation(self):
S = Secondary(tag_sets=[{}])
self.assertEqual(
[{}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets
)
S = Secondary(tag_sets=[{"k": "v"}])
self.assertEqual(
[{"k": "v"}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets
)
S = Secondary(tag_sets=[{"k": "v"}, {}])
self.assertEqual(
[{"k": "v"}, {}],
(await self.async_rs_client(read_preference=S)).read_preference.tag_sets,
)
self.assertRaises(ValueError, Secondary, tag_sets=[])
# One dict not ok, must be a list of dicts
self.assertRaises(TypeError, Secondary, tag_sets={"k": "v"})
self.assertRaises(TypeError, Secondary, tag_sets="foo")
self.assertRaises(TypeError, Secondary, tag_sets=["foo"])
async def test_threshold_validation(self):
self.assertEqual(
17,
(
await self.async_rs_client(localThresholdMS=17, connect=False)
).options.local_threshold_ms,
)
self.assertEqual(
42,
(
await self.async_rs_client(localThresholdMS=42, connect=False)
).options.local_threshold_ms,
)
self.assertEqual(
666,
(
await self.async_rs_client(localThresholdMS=666, connect=False)
).options.local_threshold_ms,
)
self.assertEqual(
0,
(
await self.async_rs_client(localThresholdMS=0, connect=False)
).options.local_threshold_ms,
)
with self.assertRaises(ValueError):
await self.async_rs_client(localthresholdms=-1)
async def test_zero_latency(self):
ping_times: set = set()
# Generate unique ping times.
while len(ping_times) < len(self.client.nodes):
ping_times.add(random.random())
for ping_time, host in zip(ping_times, self.client.nodes):
ServerDescription._host_to_round_trip_time[host] = ping_time
try:
client = await connected(
await self.async_rs_client(readPreference="nearest", localThresholdMS=0)
)
await async_wait_until(
lambda: client.nodes == self.client.nodes, "discovered all nodes"
)
host = await self.read_from_which_host(client)
for _ in range(5):
self.assertEqual(host, await self.read_from_which_host(client))
finally:
ServerDescription._host_to_round_trip_time.clear()
async def test_primary(self):
await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY)
async def test_primary_with_tags(self):
# Tags not allowed with PRIMARY
with self.assertRaises(ConfigurationError):
await self.async_rs_client(tag_sets=[{"dc": "ny"}])
async def test_primary_preferred(self):
await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED)
async def test_secondary(self):
await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY)
async def test_secondary_preferred(self):
await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY_PREFERRED)
async def test_nearest(self):
# With high localThresholdMS, expect to read from any
# member
c = await self.async_rs_client(
read_preference=ReadPreference.NEAREST, localThresholdMS=10000
) # 10 seconds
data_members = {await self.client.primary} | await self.client.secondaries
# This is a probabilistic test; track which members we've read from so
# far, and keep reading until we've used all the members or give up.
# Chance of using only 2 of 3 members 10k times if there's no bug =
# 3 * (2/3)**10000, very low.
used: set = set()
i = 0
while data_members.difference(used) and i < 10000:
address = await self.read_from_which_host(c)
used.add(address)
i += 1
not_used = data_members.difference(used)
latencies = ", ".join(
"%s: %sms" % (server.description.address, server.description.round_trip_time)
for server in await (await c._get_topology()).select_servers(
readable_server_selector, _Op.TEST
)
)
self.assertFalse(
not_used,
"Expected to use primary and all secondaries for mode NEAREST,"
f" but didn't use {not_used}\nlatencies: {latencies}",
)
class ReadPrefTester(AsyncMongoClient):
def __init__(self, *args, **kwargs):
self.has_read_from = set()
client_options = async_client_context.client_options
client_options.update(kwargs)
super().__init__(*args, **client_options)
async def _conn_for_reads(self, read_preference, session, operation):
context = await super()._conn_for_reads(read_preference, session, operation)
return context
@contextlib.asynccontextmanager
async def _conn_from_server(self, read_preference, server, session):
context = super()._conn_from_server(read_preference, server, session)
async with context as (conn, read_preference):
await self.record_a_read(conn.address)
yield conn, read_preference
async def record_a_read(self, address):
server = await (await self._get_topology()).select_server_by_address(address, _Op.TEST, 0)
self.has_read_from.add(server)
_PREF_MAP = [
(Primary, SERVER_TYPE.RSPrimary),
(PrimaryPreferred, SERVER_TYPE.RSPrimary),
(Secondary, SERVER_TYPE.RSSecondary),
(SecondaryPreferred, SERVER_TYPE.RSSecondary),
(Nearest, "any"),
]
class TestCommandAndReadPreference(AsyncIntegrationTest):
c: ReadPrefTester
client_version: Version
@async_client_context.require_secondaries_count(1)
async def asyncSetUp(self):
await super().asyncSetUp()
self.c = ReadPrefTester(
# Ignore round trip times, to test ReadPreference modes only.
localThresholdMS=1000 * 1000,
)
self.client_version = await Version.async_from_client(self.c)
# mapReduce fails if the collection does not exist.
coll = self.c.pymongo_test.get_collection(
"test", write_concern=WriteConcern(w=async_client_context.w)
)
await coll.insert_one({})
async def asyncTearDown(self):
await self.c.drop_database("pymongo_test")
await self.c.close()
async def executed_on_which_server(self, client, fn, *args, **kwargs):
"""Execute fn(*args, **kwargs) and return the Server instance used."""
client.has_read_from.clear()
await fn(*args, **kwargs)
self.assertEqual(1, len(client.has_read_from))
return one(client.has_read_from)
async def assertExecutedOn(self, server_type, client, fn, *args, **kwargs):
server = await self.executed_on_which_server(client, fn, *args, **kwargs)
self.assertEqual(
SERVER_TYPE._fields[server_type], SERVER_TYPE._fields[server.description.server_type]
)
async def _test_fn(self, server_type, fn):
for _ in range(10):
if server_type == "any":
used = set()
for _ in range(1000):
server = await self.executed_on_which_server(self.c, fn)
used.add(server.description.address)
if len(used) == len(await self.c.secondaries) + 1:
# Success
break
assert await self.c.primary is not None
unused = (await self.c.secondaries).union({await self.c.primary}).difference(used)
if unused:
self.fail("Some members not used for NEAREST: %s" % (unused))
else:
await self.assertExecutedOn(server_type, self.c, fn)
async def _test_primary_helper(self, func):
# Helpers that ignore read preference.
await self._test_fn(SERVER_TYPE.RSPrimary, func)
async def _test_coll_helper(self, secondary_ok, coll, meth, *args, **kwargs):
for mode, server_type in _PREF_MAP:
new_coll = coll.with_options(read_preference=mode())
async def func():
return await getattr(new_coll, meth)(*args, **kwargs)
if secondary_ok:
await self._test_fn(server_type, func)
else:
await self._test_fn(SERVER_TYPE.RSPrimary, func)
async def test_command(self):
# Test that the generic command helper obeys the read preference
# passed to it.
for mode, server_type in _PREF_MAP:
async def func():
return await self.c.pymongo_test.command("dbStats", read_preference=mode())
await self._test_fn(server_type, func)
async def test_create_collection(self):
# create_collection runs listCollections on the primary to check if
# the collection already exists.
async def func():
return await self.c.pymongo_test.create_collection(
"some_collection%s" % random.randint(0, sys.maxsize)
)
await self._test_primary_helper(func)
async def test_count_documents(self):
await self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {})
async def test_estimated_document_count(self):
await self._test_coll_helper(True, self.c.pymongo_test.test, "estimated_document_count")
async def test_distinct(self):
await self._test_coll_helper(True, self.c.pymongo_test.test, "distinct", "a")
async def test_aggregate(self):
await self._test_coll_helper(
True, self.c.pymongo_test.test, "aggregate", [{"$project": {"_id": 1}}]
)
async def test_aggregate_write(self):
# 5.0 servers support $out on secondaries.
secondary_ok = async_client_context.version.at_least(5, 0)
await self._test_coll_helper(
secondary_ok,
self.c.pymongo_test.test,
"aggregate",
[{"$project": {"_id": 1}}, {"$out": "agg_write_test"}],
)
class TestMovingAverage(unittest.TestCase):
def test_moving_average(self):
avg = MovingAverage()
self.assertIsNone(avg.get())
avg.add_sample(10)
self.assertAlmostEqual(10, avg.get()) # type: ignore
avg.add_sample(20)
self.assertAlmostEqual(12, avg.get()) # type: ignore
avg.add_sample(30)
self.assertAlmostEqual(15.6, avg.get()) # type: ignore
class TestMongosAndReadPreference(AsyncIntegrationTest):
def test_read_preference_document(self):
pref = Primary()
self.assertEqual(pref.document, {"mode": "primary"})
pref = PrimaryPreferred()
self.assertEqual(pref.document, {"mode": "primaryPreferred"})
pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}])
self.assertEqual(pref.document, {"mode": "primaryPreferred", "tags": [{"dc": "sf"}]})
pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30)
self.assertEqual(
pref.document,
{"mode": "primaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30},
)
pref = Secondary()
self.assertEqual(pref.document, {"mode": "secondary"})
pref = Secondary(tag_sets=[{"dc": "sf"}])
self.assertEqual(pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}]})
pref = Secondary(tag_sets=[{"dc": "sf"}], max_staleness=30)
self.assertEqual(
pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}
)
pref = SecondaryPreferred()
self.assertEqual(pref.document, {"mode": "secondaryPreferred"})
pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}])
self.assertEqual(pref.document, {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}]})
pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30)
self.assertEqual(
pref.document,
{"mode": "secondaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30},
)
pref = Nearest()
self.assertEqual(pref.document, {"mode": "nearest"})
pref = Nearest(tag_sets=[{"dc": "sf"}])
self.assertEqual(pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}]})
pref = Nearest(tag_sets=[{"dc": "sf"}], max_staleness=30)
self.assertEqual(
pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}
)
with self.assertRaises(TypeError):
# Float is prohibited.
Nearest(max_staleness=1.5) # type: ignore
with self.assertRaises(ValueError):
Nearest(max_staleness=0)
with self.assertRaises(ValueError):
Nearest(max_staleness=-2)
def test_read_preference_document_hedge(self):
cases = {
"primaryPreferred": PrimaryPreferred,
"secondary": Secondary,
"secondaryPreferred": SecondaryPreferred,
"nearest": Nearest,
}
for mode, cls in cases.items():
with self.assertRaises(TypeError):
cls(hedge=[]) # type: ignore
pref = cls(hedge={})
self.assertEqual(pref.document, {"mode": mode})
out = _maybe_add_read_preference({}, pref)
if cls == SecondaryPreferred:
# SecondaryPreferred without hedge doesn't add $readPreference.
self.assertEqual(out, {})
else:
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge: dict[str, Any] = {"enabled": True}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False, "extra": "option"}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
async def test_send_hedge(self):
cases = {
"primaryPreferred": PrimaryPreferred,
"secondaryPreferred": SecondaryPreferred,
"nearest": Nearest,
}
if await async_client_context.supports_secondary_read_pref:
cases["secondary"] = Secondary
listener = OvertCommandListener()
client = await self.async_rs_client(event_listeners=[listener])
await client.admin.command("ping")
for _mode, cls in cases.items():
pref = cls(hedge={"enabled": True})
coll = client.test.get_collection("test", read_preference=pref)
listener.reset()
await coll.find_one()
started = listener.started_events
self.assertEqual(len(started), 1, started)
cmd = started[0].command
if async_client_context.is_rs or async_client_context.is_mongos:
self.assertIn("$readPreference", cmd)
self.assertEqual(cmd["$readPreference"], pref.document)
else:
self.assertNotIn("$readPreference", cmd)
def test_maybe_add_read_preference(self):
# Primary doesn't add $readPreference
out = _maybe_add_read_preference({}, Primary())
self.assertEqual(out, {})
pref = PrimaryPreferred()
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = PrimaryPreferred(tag_sets=[{"dc": "nyc"}])
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = Secondary()
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = Secondary(tag_sets=[{"dc": "nyc"}])
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
# SecondaryPreferred without tag_sets or max_staleness doesn't add
# $readPreference
pref = SecondaryPreferred()
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, {})
pref = SecondaryPreferred(tag_sets=[{"dc": "nyc"}])
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = SecondaryPreferred(max_staleness=120)
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = Nearest()
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = Nearest(tag_sets=[{"dc": "nyc"}])
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
criteria = SON([("$query", {}), ("$orderby", SON([("_id", 1)]))])
pref = Nearest()
out = _maybe_add_read_preference(criteria, pref)
self.assertEqual(
out,
SON(
[
("$query", {}),
("$orderby", SON([("_id", 1)])),
("$readPreference", pref.document),
]
),
)
pref = Nearest(tag_sets=[{"dc": "nyc"}])
out = _maybe_add_read_preference(criteria, pref)
self.assertEqual(
out,
SON(
[
("$query", {}),
("$orderby", SON([("_id", 1)])),
("$readPreference", pref.document),
]
),
)
@async_client_context.require_mongos
async def test_mongos(self):
res = await async_client_context.client.config.shards.find_one()
assert res is not None
shard = res["host"]
num_members = shard.count(",") + 1
if num_members == 1:
raise SkipTest("Need a replica set shard to test.")
coll = async_client_context.client.pymongo_test.get_collection(
"test", write_concern=WriteConcern(w=num_members)
)
await coll.drop()
res = await coll.insert_many([{} for _ in range(5)])
first_id = res.inserted_ids[0]
last_id = res.inserted_ids[-1]
# Note - this isn't a perfect test since there's no way to
# tell what shard member a query ran on.
for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()):
qcoll = coll.with_options(read_preference=pref)
results = await qcoll.find().sort([("_id", 1)]).to_list()
self.assertEqual(first_id, results[0]["_id"])
self.assertEqual(last_id, results[-1]["_id"])
results = await qcoll.find().sort([("_id", -1)]).to_list()
self.assertEqual(first_id, results[-1]["_id"])
self.assertEqual(last_id, results[0]["_id"])
@async_client_context.require_mongos
async def test_mongos_max_staleness(self):
# Sanity check that we're sending maxStalenessSeconds
coll = async_client_context.client.pymongo_test.get_collection(
"test", read_preference=SecondaryPreferred(max_staleness=120)
)
# No error
await coll.find_one()
coll = async_client_context.client.pymongo_test.get_collection(
"test", read_preference=SecondaryPreferred(max_staleness=10)
)
try:
await coll.find_one()
except OperationFailure as exc:
self.assertEqual(160, exc.code)
else:
self.fail("mongos accepted invalid staleness")
coll = (
await self.async_single_client(
readPreference="secondaryPreferred", maxStalenessSeconds=120
)
).pymongo_test.test
# No error
await coll.find_one()
coll = (
await self.async_single_client(
readPreference="secondaryPreferred", maxStalenessSeconds=10
)
).pymongo_test.test
try:
await coll.find_one()
except OperationFailure as exc:
self.assertEqual(160, exc.code)
else:
self.fail("mongos accepted invalid staleness")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,344 @@
# Copyright 2018-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the read and write concern tests."""
from __future__ import annotations
import json
import os
import sys
import warnings
from pathlib import Path
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous.unified_format import generate_test_classes
from test.utils import OvertCommandListener
from pymongo import DESCENDING
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.errors import (
BulkWriteError,
ConfigurationError,
WriteConcernError,
WriteError,
WTimeoutError,
)
from pymongo.operations import IndexModel, InsertOne
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern")
class TestReadWriteConcernSpec(AsyncIntegrationTest):
async def test_omit_default_read_write_concern(self):
listener = OvertCommandListener()
# Client with default readConcern and writeConcern
client = await self.async_rs_or_single_client(event_listeners=[listener])
collection = client.pymongo_test.collection
# Prepare for tests of find() and aggregate().
await collection.insert_many([{} for _ in range(10)])
self.addAsyncCleanup(collection.drop)
self.addAsyncCleanup(client.pymongo_test.collection2.drop)
# Commands MUST NOT send the default read/write concern to the server.
async def rename_and_drop():
# Ensure collection exists.
await collection.insert_one({})
await collection.rename("collection2")
await client.pymongo_test.collection2.drop()
async def insert_command_default_write_concern():
await collection.database.command(
"insert", "collection", documents=[{}], write_concern=WriteConcern()
)
async def aggregate_op():
await (await collection.aggregate([])).to_list()
ops = [
("aggregate", aggregate_op),
("find", lambda: collection.find().to_list()),
("insert_one", lambda: collection.insert_one({})),
("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})),
("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})),
("delete_one", lambda: collection.delete_one({})),
("delete_many", lambda: collection.delete_many({})),
("bulk_write", lambda: collection.bulk_write([InsertOne({})])),
("rename_and_drop", rename_and_drop),
("command", insert_command_default_write_concern),
]
for name, f in ops:
listener.reset()
await f()
self.assertGreaterEqual(len(listener.started_events), 1)
for _i, event in enumerate(listener.started_events):
self.assertNotIn(
"readConcern",
event.command,
f"{name} sent default readConcern with {event.command_name}",
)
self.assertNotIn(
"writeConcern",
event.command,
f"{name} sent default writeConcern with {event.command_name}",
)
async def assertWriteOpsRaise(self, write_concern, expected_exception):
wc = write_concern.document
# Set socket timeout to avoid indefinite stalls
client = await self.async_rs_or_single_client(
w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000
)
db = client.get_database("pymongo_test")
coll = db.test
async def insert_command():
await coll.database.command(
"insert",
"new_collection",
documents=[{}],
writeConcern=write_concern.document,
parse_write_concern_error=True,
)
ops = [
("insert_one", lambda: coll.insert_one({})),
("insert_many", lambda: coll.insert_many([{}, {}])),
("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})),
("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})),
("delete_one", lambda: coll.delete_one({})),
("delete_many", lambda: coll.delete_many({})),
("bulk_write", lambda: coll.bulk_write([InsertOne({})])),
("command", insert_command),
("aggregate", lambda: coll.aggregate([{"$out": "out"}])),
# SERVER-46668 Delete all the documents in the collection to
# workaround a hang in createIndexes.
("delete_many", lambda: coll.delete_many({})),
("create_index", lambda: coll.create_index([("a", DESCENDING)])),
("create_indexes", lambda: coll.create_indexes([IndexModel("b")])),
("drop_index", lambda: coll.drop_index([("a", DESCENDING)])),
("create", lambda: db.create_collection("new")),
("rename", lambda: coll.rename("new")),
("drop", lambda: db.new.drop()),
]
# SERVER-47194: dropDatabase does not respect wtimeout in 3.6.
if async_client_context.version[:2] != (3, 6):
ops.append(("drop_database", lambda: client.drop_database(db)))
for name, f in ops:
# Ensure insert_many and bulk_write still raise BulkWriteError.
if name in ("insert_many", "bulk_write"):
expected = BulkWriteError
else:
expected = expected_exception
with self.assertRaises(expected, msg=name) as cm:
await f()
if expected == BulkWriteError:
bulk_result = cm.exception.details
assert bulk_result is not None
wc_errors = bulk_result["writeConcernErrors"]
self.assertTrue(wc_errors)
@async_client_context.require_replica_set
async def test_raise_write_concern_error(self):
self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test")
assert async_client_context.w is not None
await self.assertWriteOpsRaise(
WriteConcern(w=async_client_context.w + 1, wtimeout=1), WriteConcernError
)
@async_client_context.require_secondaries_count(1)
@async_client_context.require_test_commands
async def test_raise_wtimeout(self):
self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test")
self.addAsyncCleanup(self.enable_replication, async_client_context.client)
# Disable replication to guarantee a wtimeout error.
await self.disable_replication(async_client_context.client)
await self.assertWriteOpsRaise(
WriteConcern(w=async_client_context.w, wtimeout=1), WTimeoutError
)
@async_client_context.require_failCommand_fail_point
async def test_error_includes_errInfo(self):
expected_wce = {
"code": 100,
"codeName": "UnsatisfiableWriteConcern",
"errmsg": "Not enough data-bearing nodes",
"errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}},
}
cause_wce = {
"configureFailPoint": "failCommand",
"mode": {"times": 2},
"data": {"failCommands": ["insert"], "writeConcernError": expected_wce},
}
async with self.fail_point(cause_wce):
# Write concern error on insert includes errInfo.
with self.assertRaises(WriteConcernError) as ctx:
await self.db.test.insert_one({})
self.assertEqual(ctx.exception.details, expected_wce)
# Test bulk_write as well.
with self.assertRaises(BulkWriteError) as ctx:
await self.db.test.bulk_write([InsertOne({})])
expected_details = {
"writeErrors": [],
"writeConcernErrors": [expected_wce],
"nInserted": 1,
"nUpserted": 0,
"nMatched": 0,
"nModified": 0,
"nRemoved": 0,
"upserted": [],
}
self.assertEqual(ctx.exception.details, expected_details)
@async_client_context.require_version_min(4, 9)
async def test_write_error_details_exposes_errinfo(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(event_listeners=[listener])
db = client.errinfotest
self.addAsyncCleanup(client.drop_database, "errinfotest")
validator = {"x": {"$type": "string"}}
await db.create_collection("test", validator=validator)
with self.assertRaises(WriteError) as ctx:
await db.test.insert_one({"x": 1})
self.assertEqual(ctx.exception.code, 121)
self.assertIsNotNone(ctx.exception.details)
assert ctx.exception.details is not None
self.assertIsNotNone(ctx.exception.details.get("errInfo"))
for event in listener.succeeded_events:
if event.command_name == "insert":
self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details)
break
else:
self.fail("Couldn't find insert event.")
def normalize_write_concern(concern):
result = {}
for key in concern:
if key.lower() == "wtimeoutms":
result["wtimeout"] = concern[key]
elif key == "journal":
result["j"] = concern[key]
else:
result[key] = concern[key]
return result
def create_connection_string_test(test_case):
def run_test(self):
uri = test_case["uri"]
valid = test_case["valid"]
warning = test_case["warning"]
if not valid:
if warning is False:
self.assertRaises(
(ConfigurationError, ValueError), AsyncMongoClient, uri, connect=False
)
else:
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
self.assertRaises(UserWarning, AsyncMongoClient, uri, connect=False)
else:
client = AsyncMongoClient(uri, connect=False)
if "writeConcern" in test_case:
document = client.write_concern.document
self.assertEqual(document, normalize_write_concern(test_case["writeConcern"]))
if "readConcern" in test_case:
document = client.read_concern.document
self.assertEqual(document, test_case["readConcern"])
return run_test
def create_document_test(test_case):
def run_test(self):
valid = test_case["valid"]
if "writeConcern" in test_case:
normalized = normalize_write_concern(test_case["writeConcern"])
if not valid:
self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized)
else:
write_concern = WriteConcern(**normalized)
self.assertEqual(write_concern.document, test_case["writeConcernDocument"])
self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"])
self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"])
if "readConcern" in test_case:
# Any string for 'level' is equally valid
read_concern = ReadConcern(**test_case["readConcern"])
self.assertEqual(read_concern.document, test_case["readConcernDocument"])
self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"])
return run_test
def create_tests():
for dirpath, _, filenames in os.walk(TEST_PATH):
dirname = os.path.split(dirpath)[-1]
if dirname == "operation":
# This directory is tested by TestOperations.
continue
elif dirname == "connection-string":
create_test = create_connection_string_test
else:
create_test = create_document_test
for filename in filenames:
with open(os.path.join(dirpath, filename)) as test_stream:
test_cases = json.load(test_stream)["tests"]
fname = os.path.splitext(filename)[0]
for test_case in test_cases:
new_test = create_test(test_case)
test_name = "test_{}_{}_{}".format(
dirname.replace("-", "_"),
fname.replace("-", "_"),
str(test_case["description"].lower().replace(" ", "_")),
)
new_test.__name__ = test_name
setattr(TestReadWriteConcernSpec, new_test.__name__, new_test)
create_tests()
# Generate unified tests.
# PyMongo does not support MapReduce.
globals().update(
generate_test_classes(
os.path.join(TEST_PATH, "operation"),
module=__name__,
expected_failures=["MapReduce .*"],
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,46 @@
# Copyright 2022-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the Retryable Reads unified spec tests."""
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified")
# Generate unified tests.
# PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects.
globals().update(
generate_test_classes(
TEST_PATH,
module=__name__,
expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"],
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,39 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the Retryable Writes unified spec tests."""
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,41 @@
# Copyright 2024-Present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run Command unified tests."""
from __future__ import annotations
import os
import unittest
from pathlib import Path
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command")
globals().update(
generate_test_classes(
os.path.join(TEST_PATH, "unified"),
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,45 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the server selection logging unified format spec tests."""
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging")
globals().update(
generate_test_classes(
TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,77 @@
# Copyright 2015 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the topology module."""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous import AsyncPyMongoTestCase
from pymongo.read_preferences import MovingAverage
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt")
class TestAllScenarios(AsyncPyMongoTestCase):
pass
def create_test(scenario_def):
def run_scenario(self):
moving_average = MovingAverage()
if scenario_def["avg_rtt_ms"] != "NULL":
moving_average.add_sample(scenario_def["avg_rtt_ms"])
if scenario_def["new_rtt_ms"] != "NULL":
moving_average.add_sample(scenario_def["new_rtt_ms"])
self.assertAlmostEqual(moving_average.get(), scenario_def["new_avg_rtt"])
return run_scenario
def create_tests():
for dirpath, _, filenames in os.walk(TEST_PATH):
dirname = os.path.split(dirpath)[-1]
for filename in filenames:
with open(os.path.join(dirpath, filename)) as scenario_stream:
scenario_def = json.load(scenario_stream)
# Construct test from scenario.
new_test = create_test(scenario_def)
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
new_test.__name__ = test_name
setattr(TestAllScenarios, new_test.__name__, new_test)
create_tests()
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,40 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the Sessions unified spec tests."""
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,361 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the SRV support tests."""
from __future__ import annotations
import asyncio
import sys
import time
from typing import Any
sys.path[0:0] = [""]
from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest
from test.utils import FunctionCallRecorder, async_wait_until
import pymongo
from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.srv_resolver import _have_dnspython
_IS_SYNC = False
WAIT_TIME = 0.1
class SrvPollingKnobs:
def __init__(
self,
ttl_time=None,
min_srv_rescan_interval=None,
nodelist_callback=None,
count_resolver_calls=False,
):
self.ttl_time = ttl_time
self.min_srv_rescan_interval = min_srv_rescan_interval
self.nodelist_callback = nodelist_callback
self.count_resolver_calls = count_resolver_calls
self.old_min_srv_rescan_interval = None
self.old_dns_resolver_response = None
def enable(self):
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
if self.min_srv_rescan_interval is not None:
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
def mock_get_hosts_and_min_ttl(resolver, *args):
assert self.old_dns_resolver_response is not None
nodes, ttl = self.old_dns_resolver_response(resolver)
if self.nodelist_callback is not None:
nodes = self.nodelist_callback()
if self.ttl_time is not None:
ttl = self.ttl_time
return nodes, ttl
patch_func: Any
if self.count_resolver_calls:
patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
else:
patch_func = mock_get_hosts_and_min_ttl
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
def __enter__(self):
self.enable()
def disable(self):
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
self.old_dns_resolver_response
)
def __exit__(self, exc_type, exc_val, exc_tb):
self.disable()
class TestSrvPolling(AsyncPyMongoTestCase):
BASE_SRV_RESPONSE = [
("localhost.test.build.10gen.cc", 27017),
("localhost.test.build.10gen.cc", 27018),
]
CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"
async def asyncSetUp(self):
# Patch timeouts to ensure short rescan SRV interval.
self.client_knobs = client_knobs(
heartbeat_frequency=WAIT_TIME,
min_heartbeat_interval=WAIT_TIME,
events_queue_frequency=WAIT_TIME,
)
self.client_knobs.enable()
async def asyncTearDown(self):
self.client_knobs.disable()
def get_nodelist(self, client):
return client._topology.description.server_descriptions().keys()
async def assert_nodelist_change(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
"""Check if the client._topology eventually sees all nodes in the
expected_nodelist.
"""
def predicate():
nodelist = self.get_nodelist(client)
if set(expected_nodelist) == set(nodelist):
return True
return False
await async_wait_until(predicate, "see expected nodelist", timeout=timeout)
async def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
"""Check if the client._topology ever deviates from seeing all nodes
in the expected_nodelist. Consistency is checked after sleeping for
(WAIT_TIME * 10) seconds. Also check that the resolver is called at
least once.
"""
def predicate():
if set(expected_nodelist) == set(self.get_nodelist(client)):
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
return False
await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
nodelist = self.get_nodelist(client)
if set(expected_nodelist) != set(nodelist):
msg = "Client nodelist %s changed unexpectedly (expected %s)"
raise self.fail(msg % (nodelist, expected_nodelist))
self.assertGreaterEqual(
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
1,
"resolver was never called",
)
return True
async def run_scenario(self, dns_response, expect_change):
self.assertEqual(_have_dnspython(), True)
if callable(dns_response):
dns_resolver_response = dns_response
else:
def dns_resolver_response():
return dns_response
if expect_change:
assertion_method = self.assert_nodelist_change
count_resolver_calls = False
expected_response = dns_response
else:
assertion_method = self.assert_nodelist_nochange
count_resolver_calls = True
expected_response = self.BASE_SRV_RESPONSE
# Patch timeouts to ensure short test running times.
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING)
await client.aconnect()
await self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
# Patch list of hosts returned by DNS query.
with SrvPollingKnobs(
nodelist_callback=dns_resolver_response, count_resolver_calls=count_resolver_calls
):
await assertion_method(expected_response, client)
async def test_addition(self):
response = self.BASE_SRV_RESPONSE[:]
response.append(("localhost.test.build.10gen.cc", 27019))
await self.run_scenario(response, True)
async def test_removal(self):
response = self.BASE_SRV_RESPONSE[:]
response.remove(("localhost.test.build.10gen.cc", 27018))
await self.run_scenario(response, True)
async def test_replace_one(self):
response = self.BASE_SRV_RESPONSE[:]
response.remove(("localhost.test.build.10gen.cc", 27018))
response.append(("localhost.test.build.10gen.cc", 27019))
await self.run_scenario(response, True)
async def test_replace_both_with_one(self):
response = [("localhost.test.build.10gen.cc", 27019)]
await self.run_scenario(response, True)
async def test_replace_both_with_two(self):
response = [
("localhost.test.build.10gen.cc", 27019),
("localhost.test.build.10gen.cc", 27020),
]
await self.run_scenario(response, True)
async def test_dns_failures(self):
from dns import exception
for exc in (exception.FormError, exception.TooBig, exception.Timeout):
def response_callback(*args):
raise exc("DNS Failure!")
await self.run_scenario(response_callback, False)
async def test_dns_record_lookup_empty(self):
response: list = []
await self.run_scenario(response, False)
async def _test_recover_from_initial(self, initial_callback):
# Construct a valid final response callback distinct from base.
response_final = self.BASE_SRV_RESPONSE[:]
response_final.pop()
def final_callback():
return response_final
with SrvPollingKnobs(
ttl_time=WAIT_TIME,
min_srv_rescan_interval=WAIT_TIME,
nodelist_callback=initial_callback,
count_resolver_calls=True,
):
# Client uses unpatched method to get initial nodelist
client = self.simple_client(self.CONNECTION_STRING)
await client.aconnect()
# Invalid DNS resolver response should not change nodelist.
await self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
with SrvPollingKnobs(
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, nodelist_callback=final_callback
):
# Nodelist should reflect new valid DNS resolver response.
await self.assert_nodelist_change(response_final, client)
async def test_recover_from_initially_empty_seedlist(self):
def empty_seedlist():
return []
await self._test_recover_from_initial(empty_seedlist)
async def test_recover_from_initially_erroring_seedlist(self):
def erroring_seedlist():
raise ConfigurationError
await self._test_recover_from_initial(erroring_seedlist)
async def test_10_all_dns_selected(self):
response = [
("localhost.test.build.10gen.cc", 27017),
("localhost.test.build.10gen.cc", 27019),
("localhost.test.build.10gen.cc", 27020),
]
def nodelist_callback():
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0)
await client.aconnect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
await self.assert_nodelist_change(response, client)
async def test_11_all_dns_selected(self):
response = [
("localhost.test.build.10gen.cc", 27019),
("localhost.test.build.10gen.cc", 27020),
]
def nodelist_callback():
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
await client.aconnect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
await self.assert_nodelist_change(response, client)
async def test_12_new_dns_randomly_selected(self):
response = [
("localhost.test.build.10gen.cc", 27020),
("localhost.test.build.10gen.cc", 27019),
("localhost.test.build.10gen.cc", 27017),
]
def nodelist_callback():
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
await client.aconnect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
await asyncio.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
final_topology = set(client.topology_description.server_descriptions())
self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology)
self.assertEqual(len(final_topology), 2)
async def test_does_not_flipflop(self):
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1)
await client.aconnect()
old = set(client.topology_description.server_descriptions())
await asyncio.sleep(4 * WAIT_TIME)
new = set(client.topology_description.server_descriptions())
self.assertSetEqual(old, new)
async def test_srv_service_name(self):
# Construct a valid final response callback distinct from base.
response = [
("localhost.test.build.10gen.cc.", 27019),
("localhost.test.build.10gen.cc.", 27020),
]
def nodelist_callback():
return response
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(
"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
)
await client.aconnect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
await self.assert_nodelist_change(response, client)
async def test_srv_waits_to_poll(self):
modified = [("localhost.test.build.10gen.cc", 27019)]
def resolver_response():
return modified
with SrvPollingKnobs(
ttl_time=WAIT_TIME,
min_srv_rescan_interval=WAIT_TIME,
nodelist_callback=resolver_response,
):
client = self.simple_client(self.CONNECTION_STRING)
await client.aconnect()
with self.assertRaises(AssertionError):
await self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2)
def test_import_dns_resolver(self):
# Regression test for PYTHON-4407
import dns.resolver
self.assertTrue(hasattr(dns.resolver, "resolve"))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,662 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SSL support."""
from __future__ import annotations
import os
import pathlib
import socket
import sys
sys.path[0:0] = [""]
from test.asynchronous import (
HAVE_IPADDRESS,
AsyncIntegrationTest,
AsyncPyMongoTestCase,
SkipTest,
async_client_context,
connected,
remove_all_users,
unittest,
)
from test.utils import (
EventListener,
OvertCommandListener,
cat_files,
ignore_deprecations,
)
from urllib.parse import quote_plus
from pymongo import AsyncMongoClient, ssl_support
from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure
from pymongo.hello import HelloCompat
from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context
from pymongo.write_concern import WriteConcern
_HAVE_PYOPENSSL = False
try:
# All of these must be available to use PyOpenSSL
import OpenSSL
import requests
import service_identity
# Ensure service_identity>=18.1 is installed
from service_identity.pyopenssl import verify_ip_address
from pymongo.ocsp_support import _load_trusted_ca_certs
_HAVE_PYOPENSSL = True
except ImportError:
_load_trusted_ca_certs = None # type: ignore
if HAVE_SSL:
import ssl
_IS_SYNC = False
if _IS_SYNC:
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates")
else:
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates")
CLIENT_PEM = os.path.join(CERT_PATH, "client.pem")
CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem")
CA_PEM = os.path.join(CERT_PATH, "ca.pem")
CA_BUNDLE_PEM = os.path.join(CERT_PATH, "trusted-ca.pem")
CRL_PEM = os.path.join(CERT_PATH, "crl.pem")
MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client"
# To fully test this start a mongod instance (built with SSL support) like so:
# mongod --dbpath /path/to/data/directory --sslOnNormalPorts \
# --sslPEMKeyFile /path/to/pymongo/test/certificates/server.pem \
# --sslCAFile /path/to/pymongo/test/certificates/ca.pem \
# --sslWeakCertificateValidation
# Also, make sure you have 'server' as an alias for localhost in /etc/hosts
#
# Note: For all replica set tests to pass, the replica set configuration must
# use 'localhost' for the hostname of all hosts.
class TestClientSSL(AsyncPyMongoTestCase):
@unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.")
def test_no_ssl_module(self):
# Explicit
self.assertRaises(ConfigurationError, self.simple_client, ssl=True)
# Implied
self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM)
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
@ignore_deprecations
def test_config_ssl(self):
# Tests various ssl configurations
self.assertRaises(ValueError, self.simple_client, ssl="foo")
self.assertRaises(
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
)
self.assertRaises(TypeError, self.simple_client, ssl=0)
self.assertRaises(TypeError, self.simple_client, ssl=5.5)
self.assertRaises(TypeError, self.simple_client, ssl=[])
self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile")
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True)
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[])
# Test invalid combinations
self.assertRaises(
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
)
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM)
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM)
self.assertRaises(
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False
)
self.assertRaises(
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False
)
self.assertRaises(
ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False
)
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_use_pyopenssl_when_available(self):
self.assertTrue(_ssl.IS_PYOPENSSL)
@unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL")
def test_load_trusted_ca_certs(self):
trusted_ca_certs = _load_trusted_ca_certs(CA_BUNDLE_PEM)
self.assertEqual(2, len(trusted_ca_certs))
class TestSSL(AsyncIntegrationTest):
saved_port: int
async def assertClientWorks(self, client):
coll = client.pymongo_test.ssl_test.with_options(
write_concern=WriteConcern(w=async_client_context.w)
)
await coll.drop()
await coll.insert_one({"ssl": True})
self.assertTrue((await coll.find_one())["ssl"])
await coll.drop()
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
async def asyncSetUp(self):
await super().asyncSetUp()
# MongoClient should connect to the primary by default.
self.saved_port = AsyncMongoClient.PORT
AsyncMongoClient.PORT = await async_client_context.port
async def asyncTearDown(self):
AsyncMongoClient.PORT = self.saved_port
@async_client_context.require_tls
async def test_simple_ssl(self):
# Expects the server to be running with ssl and with
# no --sslPEMKeyFile or with --sslWeakCertificateValidation
await self.assertClientWorks(self.client)
@async_client_context.require_tlsCertificateKeyFile
@ignore_deprecations
async def test_tlsCertificateKeyFilePassword(self):
# Expects the server to be running with server.pem and ca.pem
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
self.assertRaises(
ConfigurationError,
self.simple_client,
"localhost",
ssl=True,
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
tlsCertificateKeyFilePassword="qwerty",
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=1000,
)
else:
await connected(
self.simple_client(
"localhost",
ssl=True,
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
tlsCertificateKeyFilePassword="qwerty",
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=5000,
**self.credentials, # type: ignore[arg-type]
)
)
uri_fmt = (
"mongodb://localhost/?ssl=true"
"&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty"
"&tlsCAFile=%s&serverSelectionTimeoutMS=5000"
)
await connected(
self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
)
@async_client_context.require_tlsCertificateKeyFile
@async_client_context.require_no_auth
@ignore_deprecations
async def test_cert_ssl_implicitly_set(self):
# Expects the server to be running with server.pem and ca.pem
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
#
# test that setting tlsCertificateKeyFile causes ssl to be set to True
client = self.simple_client(
await async_client_context.host,
await async_client_context.port,
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CLIENT_PEM,
)
response = await client.admin.command(HelloCompat.LEGACY_CMD)
if "setName" in response:
client = self.simple_client(
await async_client_context.pair,
replicaSet=response["setName"],
w=len(response["hosts"]),
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CLIENT_PEM,
)
await self.assertClientWorks(client)
@async_client_context.require_tlsCertificateKeyFile
@async_client_context.require_no_auth
@ignore_deprecations
async def test_cert_ssl_validation(self):
# Expects the server to be running with server.pem and ca.pem
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
#
client = self.simple_client(
"localhost",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
)
response = await client.admin.command(HelloCompat.LEGACY_CMD)
if "setName" in response:
if response["primary"].split(":")[0] != "localhost":
raise SkipTest(
"No hosts in the replicaset for 'localhost'. "
"Cannot validate hostname in the certificate"
)
client = self.simple_client(
"localhost",
replicaSet=response["setName"],
w=len(response["hosts"]),
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
)
await self.assertClientWorks(client)
if HAVE_IPADDRESS:
client = self.simple_client(
"127.0.0.1",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
)
await self.assertClientWorks(client)
@async_client_context.require_tlsCertificateKeyFile
@async_client_context.require_no_auth
@ignore_deprecations
async def test_cert_ssl_uri_support(self):
# Expects the server to be running with server.pem and ca.pem
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
#
uri_fmt = (
"mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates"
"=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false"
)
client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM))
await self.assertClientWorks(client)
@async_client_context.require_tlsCertificateKeyFile
@async_client_context.require_server_resolvable
@ignore_deprecations
async def test_cert_ssl_validation_hostname_matching(self):
# Expects the server to be running with server.pem and ca.pem
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
ctx = get_ssl_context(None, None, None, None, True, True, False)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, True, False, False)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, False, True, False)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, False, False, False)
self.assertTrue(ctx.check_hostname)
response = await self.client.admin.command(HelloCompat.LEGACY_CMD)
with self.assertRaises(ConnectionFailure):
await connected(
self.simple_client(
"server",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=500,
**self.credentials, # type: ignore[arg-type]
)
)
await connected(
self.simple_client(
"server",
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
tlsAllowInvalidHostnames=True,
serverSelectionTimeoutMS=500,
**self.credentials, # type: ignore[arg-type]
)
)
if "setName" in response:
with self.assertRaises(ConnectionFailure):
await connected(
self.simple_client(
"server",
replicaSet=response["setName"],
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=500,
**self.credentials, # type: ignore[arg-type]
)
)
await connected(
self.simple_client(
"server",
replicaSet=response["setName"],
ssl=True,
tlsCertificateKeyFile=CLIENT_PEM,
tlsAllowInvalidCertificates=False,
tlsCAFile=CA_PEM,
tlsAllowInvalidHostnames=True,
serverSelectionTimeoutMS=500,
**self.credentials, # type: ignore[arg-type]
)
)
@async_client_context.require_tlsCertificateKeyFile
@ignore_deprecations
async def test_tlsCRLFile_support(self):
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
self.assertRaises(
ConfigurationError,
self.simple_client,
"localhost",
ssl=True,
tlsCAFile=CA_PEM,
tlsCRLFile=CRL_PEM,
serverSelectionTimeoutMS=1000,
)
else:
await connected(
self.simple_client(
"localhost",
ssl=True,
tlsCAFile=CA_PEM,
serverSelectionTimeoutMS=1000,
**self.credentials, # type: ignore[arg-type]
)
)
with self.assertRaises(ConnectionFailure):
await connected(
self.simple_client(
"localhost",
ssl=True,
tlsCAFile=CA_PEM,
tlsCRLFile=CRL_PEM,
serverSelectionTimeoutMS=1000,
**self.credentials, # type: ignore[arg-type]
)
)
uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000"
await connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore
uri_fmt = (
"mongodb://localhost/?ssl=true&tlsCRLFile=%s"
"&tlsCAFile=%s&serverSelectionTimeoutMS=1000"
)
with self.assertRaises(ConnectionFailure):
await connected(
self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
)
@async_client_context.require_tlsCertificateKeyFile
@async_client_context.require_server_resolvable
@ignore_deprecations
async def test_validation_with_system_ca_certs(self):
# Expects the server to be running with server.pem and ca.pem.
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
# --sslWeakCertificateValidation
#
self.patch_system_certs(CA_PEM)
with self.assertRaises(ConnectionFailure):
# Server cert is verified but hostname matching fails
await connected(
self.simple_client(
"server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
) # type: ignore[arg-type]
)
# Server cert is verified. Disable hostname matching.
await connected(
self.simple_client(
"server",
ssl=True,
tlsAllowInvalidHostnames=True,
serverSelectionTimeoutMS=1000,
**self.credentials, # type: ignore[arg-type]
)
)
# Server cert and hostname are verified.
await connected(
self.simple_client(
"localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
) # type: ignore[arg-type]
)
# Server cert and hostname are verified.
await connected(
self.simple_client(
"mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000",
**self.credentials, # type: ignore[arg-type]
)
)
def test_system_certs_config_error(self):
ctx = get_ssl_context(None, None, None, None, True, True, False)
if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr(
ctx, "load_default_certs"
):
raise SkipTest("Can't test when system CA certificates are loadable.")
have_certifi = ssl_support.HAVE_CERTIFI
have_wincertstore = ssl_support.HAVE_WINCERTSTORE
# Force the test regardless of environment.
ssl_support.HAVE_CERTIFI = False
ssl_support.HAVE_WINCERTSTORE = False
try:
with self.assertRaises(ConfigurationError):
self.simple_client("mongodb://localhost/?ssl=true")
finally:
ssl_support.HAVE_CERTIFI = have_certifi
ssl_support.HAVE_WINCERTSTORE = have_wincertstore
def test_certifi_support(self):
if hasattr(ssl, "SSLContext"):
# SSLSocket doesn't provide ca_certs attribute on pythons
# with SSLContext and SSLContext provides no information
# about ca_certs.
raise SkipTest("Can't test when SSLContext available.")
if not ssl_support.HAVE_CERTIFI:
raise SkipTest("Need certifi to test certifi support.")
have_wincertstore = ssl_support.HAVE_WINCERTSTORE
# Force the test on Windows, regardless of environment.
ssl_support.HAVE_WINCERTSTORE = False
try:
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where())
finally:
ssl_support.HAVE_WINCERTSTORE = have_wincertstore
def test_wincertstore(self):
if sys.platform != "win32":
raise SkipTest("Only valid on Windows.")
if hasattr(ssl, "SSLContext"):
# SSLSocket doesn't provide ca_certs attribute on pythons
# with SSLContext and SSLContext provides no information
# about ca_certs.
raise SkipTest("Can't test when SSLContext available.")
if not ssl_support.HAVE_WINCERTSTORE:
raise SkipTest("Need wincertstore to test wincertstore.")
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name)
@async_client_context.require_auth
@async_client_context.require_tlsCertificateKeyFile
@ignore_deprecations
async def test_mongodb_x509_auth(self):
host, port = await async_client_context.host, await async_client_context.port
self.addAsyncCleanup(remove_all_users, async_client_context.client["$external"])
# Give x509 user all necessary privileges.
await async_client_context.create_user(
"$external",
MONGODB_X509_USERNAME,
roles=[
{"role": "readWriteAnyDatabase", "db": "admin"},
{"role": "userAdminAnyDatabase", "db": "admin"},
],
)
noauth = self.simple_client(
await async_client_context.pair,
ssl=True,
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CLIENT_PEM,
)
with self.assertRaises(OperationFailure):
await noauth.pymongo_test.test.find_one()
listener = EventListener()
auth = self.simple_client(
await async_client_context.pair,
authMechanism="MONGODB-X509",
ssl=True,
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CLIENT_PEM,
event_listeners=[listener],
)
# No error
await auth.pymongo_test.test.find_one()
names = listener.started_command_names()
if async_client_context.version.at_least(4, 4, -1):
# Speculative auth skips the authenticate command.
self.assertEqual(names, ["find"])
else:
self.assertEqual(names, ["authenticate", "find"])
uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % (
quote_plus(MONGODB_X509_USERNAME),
host,
port,
)
client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
# No error
await client.pymongo_test.test.find_one()
uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port)
client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
# No error
await client.pymongo_test.test.find_one()
# Auth should fail if username and certificate do not match
uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % (
quote_plus("not the username"),
host,
port,
)
bad_client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
with self.assertRaises(OperationFailure):
await bad_client.pymongo_test.test.find_one()
bad_client = self.simple_client(
await async_client_context.pair,
username="not the username",
authMechanism="MONGODB-X509",
ssl=True,
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CLIENT_PEM,
)
with self.assertRaises(OperationFailure):
await bad_client.pymongo_test.test.find_one()
# Invalid certificate (using CA certificate as client certificate)
uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % (
quote_plus(MONGODB_X509_USERNAME),
host,
port,
)
try:
await connected(
self.simple_client(
uri,
ssl=True,
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CA_PEM,
serverSelectionTimeoutMS=1000,
)
)
except (ConnectionFailure, ConfigurationError):
pass
else:
self.fail("Invalid certificate accepted.")
@async_client_context.require_tlsCertificateKeyFile
@ignore_deprecations
async def test_connect_with_ca_bundle(self):
def remove(path):
try:
os.remove(path)
except OSError:
pass
temp_ca_bundle = os.path.join(CERT_PATH, "trusted-ca-bundle.pem")
self.addCleanup(remove, temp_ca_bundle)
# Add the CA cert file to the bundle.
cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM)
async with self.simple_client(
"localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle
) as client:
self.assertTrue(await client.admin.command("ping"))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,228 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the database module."""
from __future__ import annotations
import sys
import time
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import (
HeartbeatEventListener,
ServerEventListener,
async_wait_until,
)
from pymongo import monitoring
from pymongo.hello import HelloCompat
_IS_SYNC = False
class TestStreamingProtocol(AsyncIntegrationTest):
@async_client_context.require_failCommand_appName
async def test_failCommand_streaming(self):
listener = ServerEventListener()
hb_listener = HeartbeatEventListener()
client = await self.async_rs_or_single_client(
event_listeners=[listener, hb_listener],
heartbeatFrequencyMS=500,
appName="failingHeartbeatTest",
)
# Force a connection.
await client.admin.command("ping")
address = await client.address
listener.reset()
fail_hello = {
"configureFailPoint": "failCommand",
"mode": {"times": 4},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"closeConnection": False,
"errorCode": 10107,
"appName": "failingHeartbeatTest",
},
}
async with self.fail_point(fail_hello):
def _marked_unknown(event):
return (
event.server_address == address
and not event.new_description.is_server_type_known
)
def _discovered_node(event):
return (
event.server_address == address
and not event.previous_description.is_server_type_known
and event.new_description.is_server_type_known
)
def marked_unknown():
return len(listener.matching(_marked_unknown)) >= 1
def rediscovered():
return len(listener.matching(_discovered_node)) >= 1
# Topology events are not published synchronously
await async_wait_until(marked_unknown, "mark node unknown")
await async_wait_until(rediscovered, "rediscover node")
# Server should be selectable.
await client.admin.command("ping")
@async_client_context.require_failCommand_appName
async def test_streaming_rtt(self):
listener = ServerEventListener()
hb_listener = HeartbeatEventListener()
# On Windows, RTT can actually be 0.0 because time.time() only has
# 1-15 millisecond resolution. We need to delay the initial hello
# to ensure that RTT is never zero.
name = "streamingRttTest"
delay_hello: dict = {
"configureFailPoint": "failCommand",
"mode": {"times": 1000},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"blockConnection": True,
"blockTimeMS": 20,
# This can be uncommented after SERVER-49220 is fixed.
# 'appName': name,
},
}
async with self.fail_point(delay_hello):
client = await self.async_rs_or_single_client(
event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name
)
# Force a connection.
await client.admin.command("ping")
address = await client.address
delay_hello["data"]["blockTimeMS"] = 500
delay_hello["data"]["appName"] = name
async with self.fail_point(delay_hello):
def rtt_exceeds_250_ms():
# XXX: Add a public TopologyDescription getter to MongoClient?
topology = client._topology
sd = topology.description.server_descriptions()[address]
assert sd.round_trip_time is not None
return sd.round_trip_time > 0.250
await async_wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT")
# Server should be selectable.
await client.admin.command("ping")
def changed_event(event):
return event.server_address == address and isinstance(
event, monitoring.ServerDescriptionChangedEvent
)
# There should only be one event published, for the initial discovery.
events = listener.matching(changed_event)
self.assertEqual(1, len(events))
self.assertGreater(events[0].new_description.round_trip_time, 0)
@async_client_context.require_failCommand_appName
async def test_monitor_waits_after_server_check_error(self):
# This test implements:
# https://github.com/mongodb/specifications/blob/master/source/server-discovery-and-monitoring/server-discovery-and-monitoring-tests.md#monitors-sleep-at-least-minheartbeatfreqencyms-between-checks
fail_hello = {
"mode": {"times": 5},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"errorCode": 1234,
"appName": "SDAMMinHeartbeatFrequencyTest",
},
}
async with self.fail_point(fail_hello):
start = time.time()
client = await self.async_single_client(
appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000
)
# Force a connection.
await client.admin.command("ping")
duration = time.time() - start
# Explanation of the expected events:
# 0ms: run configureFailPoint
# 1ms: create MongoClient
# 2ms: failed monitor handshake, 1
# 502ms: failed monitor handshake, 2
# 1002ms: failed monitor handshake, 3
# 1502ms: failed monitor handshake, 4
# 2002ms: failed monitor handshake, 5
# 2502ms: monitor handshake succeeds
# 2503ms: run awaitable hello
# 2504ms: application handshake succeeds
# 2505ms: ping command succeeds
self.assertGreaterEqual(duration, 2)
self.assertLessEqual(duration, 3.5)
@async_client_context.require_failCommand_appName
async def test_heartbeat_awaited_flag(self):
hb_listener = HeartbeatEventListener()
client = await self.async_single_client(
event_listeners=[hb_listener],
heartbeatFrequencyMS=500,
appName="heartbeatEventAwaitedFlag",
)
# Force a connection.
await client.admin.command("ping")
def hb_succeeded(event):
return isinstance(event, monitoring.ServerHeartbeatSucceededEvent)
def hb_failed(event):
return isinstance(event, monitoring.ServerHeartbeatFailedEvent)
fail_heartbeat = {
"mode": {"times": 2},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"closeConnection": True,
"appName": "heartbeatEventAwaitedFlag",
},
}
async with self.fail_point(fail_heartbeat):
await async_wait_until(
lambda: hb_listener.matching(hb_failed), "published failed event"
)
# Reconnect.
await client.admin.command("ping")
hb_succeeded_events = hb_listener.matching(hb_succeeded)
hb_failed_events = hb_listener.matching(hb_failed)
self.assertFalse(hb_succeeded_events[0].awaited)
self.assertTrue(hb_failed_events[0].awaited)
# Depending on thread scheduling, the failed heartbeat could occur on
# the second or third check.
events = [type(e) for e in hb_listener.events[:4]]
if events == [
monitoring.ServerHeartbeatStartedEvent,
monitoring.ServerHeartbeatSucceededEvent,
monitoring.ServerHeartbeatStartedEvent,
monitoring.ServerHeartbeatFailedEvent,
]:
self.assertFalse(hb_succeeded_events[1].awaited)
else:
self.assertTrue(hb_succeeded_events[1].awaited)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,56 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the Transactions unified spec tests."""
from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import client_context, unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
@client_context.require_no_mmap
def setUpModule():
pass
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
# Location of JSON test specifications for transactions-convenient-api.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified")
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified"
)
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,99 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import Any
sys.path[0:0] = [""]
from test import UnitTest, unittest
from test.asynchronous.unified_format import MatchEvaluatorUtil, generate_test_classes
from bson import ObjectId
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format")
globals().update(
generate_test_classes(
os.path.join(TEST_PATH, "valid-pass"),
module=__name__,
class_name_prefix="UnifiedTestFormat",
expected_failures=[
"Client side error in command starting transaction", # PYTHON-1894
],
RUN_ON_SERVERLESS=False,
)
)
globals().update(
generate_test_classes(
os.path.join(TEST_PATH, "valid-fail"),
module=__name__,
class_name_prefix="UnifiedTestFormat",
bypass_test_generation_errors=True,
expected_failures=[
".*", # All tests expected to fail
],
RUN_ON_SERVERLESS=False,
)
)
class TestMatchEvaluatorUtil(UnitTest):
def setUp(self):
self.match_evaluator = MatchEvaluatorUtil(self)
def test_unsetOrMatches(self):
spec: dict[str, Any] = {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}
for actual in [{}, {"y": 2}, None]:
self.match_evaluator.match_result(spec, actual)
spec = {"x": {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}}
for actual in [{}, {"x": {}}, {"x": {"y": 2}}]:
self.match_evaluator.match_result(spec, actual)
spec = {"y": {"$$unsetOrMatches": {"$$exists": True}}}
self.match_evaluator.match_result(spec, {})
self.match_evaluator.match_result(spec, {"y": 2})
self.match_evaluator.match_result(spec, {"x": 1})
self.match_evaluator.match_result(spec, {"y": {}})
def test_type(self):
self.match_evaluator.match_result(
{
"operationType": "insert",
"ns": {"db": "change-stream-tests", "coll": "test"},
"fullDocument": {"_id": {"$$type": "objectId"}, "x": 1},
},
{
"operationType": "insert",
"fullDocument": {"_id": ObjectId("5fc93511ac93941052098f0c"), "x": 1},
"ns": {"db": "change-stream-tests", "coll": "test"},
},
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,86 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import sys
from pathlib import Path
from test.asynchronous.unified_format import generate_test_classes
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import OvertCommandListener
from pymongo.server_api import ServerApi
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
class TestServerApiIntegration(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True
def assertServerApi(self, event):
self.assertIn("apiVersion", event.command)
self.assertEqual(event.command["apiVersion"], "1")
def assertServerApiInAllCommands(self, events):
for event in events:
self.assertServerApi(event)
@async_client_context.require_version_min(4, 7)
async def test_command_options(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
server_api=ServerApi("1"), event_listeners=[listener]
)
coll = client.test.test
await coll.insert_many([{} for _ in range(100)])
self.addAsyncCleanup(coll.delete_many, {})
await coll.find(batch_size=25).to_list()
await client.admin.command("ping")
self.assertServerApiInAllCommands(listener.started_events)
@async_client_context.require_version_min(4, 7)
@async_client_context.require_transactions
async def test_command_options_txn(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
server_api=ServerApi("1"), event_listeners=[listener]
)
coll = client.test.test
await coll.insert_many([{} for _ in range(100)])
self.addAsyncCleanup(coll.delete_many, {})
listener.reset()
async with client.start_session() as s, await s.start_transaction():
await coll.insert_many([{} for _ in range(100)], session=s)
await coll.find(batch_size=25, session=s).to_list()
await client.test.command("find", "test", session=s)
self.assertServerApiInAllCommands(listener.started_events)
if __name__ == "__main__":
unittest.main()

View File

@ -35,6 +35,7 @@ from test.asynchronous import (
client_knobs,
unittest,
)
from test.asynchronous.utils_spec_runner import SpecRunnerTask
from test.unified_format_shared import (
KMS_TLS_OPTS,
PLACEHOLDER_MAP,
@ -58,7 +59,6 @@ from test.utils import (
snake_to_camel,
wait_until,
)
from test.utils_spec_runner import SpecRunnerThread
from test.version import Version
from typing import Any, Dict, List, Mapping, Optional
@ -382,8 +382,8 @@ class EntityMapUtil:
return
elif entity_type == "thread":
name = spec["id"]
thread = SpecRunnerThread(name)
thread.start()
thread = SpecRunnerTask(name)
await thread.start()
self[name] = thread
return
@ -711,7 +711,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
return await target.command(**kwargs)
async def _databaseOperation_runCursorCommand(self, target, **kwargs):
return list(await self._databaseOperation_createCommandCursor(target, **kwargs))
return await (await self._databaseOperation_createCommandCursor(target, **kwargs)).to_list()
async def _databaseOperation_createCommandCursor(self, target, **kwargs):
self.__raise_if_unsupported("createCommandCursor", target, AsyncDatabase)
@ -1177,16 +1177,16 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
wait_until(primary_changed, "change primary", timeout=timeout)
def _testOperation_runOnThread(self, spec):
async def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
def _testOperation_waitForThread(self, spec):
async def _testOperation_waitForThread(self, spec):
"""Run the 'waitForThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.stop()
thread.join(10)
await thread.stop()
await thread.join(10)
if thread.exc:
raise thread.exc
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))

View File

@ -18,11 +18,11 @@ from __future__ import annotations
import asyncio
import functools
import os
import threading
import unittest
from asyncio import iscoroutinefunction
from collections import abc
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
from test.asynchronous.helpers import ConcurrentRunner
from test.utils import (
CMAPListener,
CompareType,
@ -47,6 +47,7 @@ from pymongo.asynchronous import client_session
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
@ -55,38 +56,36 @@ from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class SpecRunnerThread(threading.Thread):
class SpecRunnerTask(ConcurrentRunner):
def __init__(self, name):
super().__init__()
self.name = name
super().__init__(name=name)
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.cond = _async_create_condition(_async_create_lock())
self.ops = []
self.stopped = False
def schedule(self, work):
async def schedule(self, work):
self.ops.append(work)
with self.cond:
async with self.cond:
self.cond.notify()
def stop(self):
async def stop(self):
self.stopped = True
with self.cond:
async with self.cond:
self.cond.notify()
def run(self):
async def run(self):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
async with self.cond:
await _async_cond_wait(self.cond, 10)
if self.ops:
try:
work = self.ops.pop(0)
work()
await work()
except Exception as exc:
self.exc = exc
self.stop()
await self.stop()
class AsyncSpecTestCreator:

View File

@ -15,6 +15,7 @@
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations
import asyncio
import base64
import gc
import multiprocessing
@ -30,6 +31,8 @@ import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo._asyncio_task import create_task
try:
import ipaddress
@ -369,3 +372,38 @@ class SystemCertsPatcher:
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs
if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object
class ConcurrentRunner(PARENT):
def __init__(self, **kwargs):
if _IS_SYNC:
super().__init__(**kwargs)
self.name = kwargs.get("name", "ConcurrentRunner")
self.stopped = False
self.task = None
self.target = kwargs.get("target", None)
self.args = kwargs.get("args", [])
if not _IS_SYNC:
def start(self):
self.task = create_task(self.run(), name=self.name)
def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
asyncio.wait([self.task], timeout=timeout)
def is_alive(self):
return not self.stopped
def run(self):
try:
self.target(*self.args)
finally:
self.stopped = True

View File

@ -2399,7 +2399,7 @@ class TestMongoClientFailover(MockClientTest):
# MongoClient discovers it's alone. The first attempt raises either
# ServerSelectionTimeoutError or AutoReconnect (from
# AsyncMockPool.get_socket).
# MockPool.get_socket).
with self.assertRaises(AutoReconnect):
c.db.collection.find_one()

View File

@ -18,22 +18,35 @@ from __future__ import annotations
import glob
import json
import os
import pathlib
import sys
sys.path[0:0] = [""]
from test import IntegrationTest, PyMongoTestCase, client_context, unittest
from test import (
IntegrationTest,
PyMongoTestCase,
client_context,
unittest,
)
from test.utils import wait_until
from pymongo.common import validate_read_preference_tags
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import parse_uri, split_hosts
_IS_SYNC = True
class TestDNSRepl(PyMongoTestCase):
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set"
)
if _IS_SYNC:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set"
)
else:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set"
)
load_balanced = False
@client_context.require_replica_set
@ -42,9 +55,14 @@ class TestDNSRepl(PyMongoTestCase):
class TestDNSLoadBalanced(PyMongoTestCase):
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced"
)
if _IS_SYNC:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced"
)
else:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced"
)
load_balanced = True
@client_context.require_load_balancer
@ -53,7 +71,12 @@ class TestDNSLoadBalanced(PyMongoTestCase):
class TestDNSSharded(PyMongoTestCase):
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded")
if _IS_SYNC:
TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded")
else:
TEST_PATH = os.path.join(
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded"
)
load_balanced = False
@client_context.require_mongos
@ -119,7 +142,9 @@ def create_test(test_case):
# tests.
copts["tlsAllowInvalidHostnames"] = True
client = PyMongoTestCase.unmanaged_simple_client(uri, **copts)
client = self.simple_client(uri, **copts)
if client._options.connect:
client._connect()
if num_seeds is not None:
self.assertEqual(len(client._topology_settings.seeds), num_seeds)
if hosts is not None:
@ -132,7 +157,6 @@ def create_test(test_case):
client.admin.command("ping")
# XXX: we should block until SRV poller runs at least once
# and re-run these assertions.
client.close()
else:
try:
parse_uri(uri)
@ -188,7 +212,6 @@ class TestParsingErrors(PyMongoTestCase):
class TestCaseInsensitive(IntegrationTest):
def test_connect_case_insensitive(self):
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
self.addCleanup(client.close)
self.assertGreater(len(client.topology_description.server_descriptions()), 1)

View File

@ -15,9 +15,13 @@
"""MongoDB documentation examples in Python."""
from __future__ import annotations
import asyncio
import datetime
import functools
import sys
import threading
import time
from test.helpers import ConcurrentRunner
sys.path[0:0] = [""]
@ -29,8 +33,11 @@ from pymongo.errors import ConnectionFailure, OperationFailure
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import ServerApi
from pymongo.synchronous.helpers import next
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class TestSampleShellCommands(IntegrationTest):
def setUp(self):
@ -62,7 +69,7 @@ class TestSampleShellCommands(IntegrationTest):
cursor = db.inventory.find({"item": "canvas"})
# End Example 2
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 3
db.inventory.insert_many(
@ -137,31 +144,31 @@ class TestSampleShellCommands(IntegrationTest):
cursor = db.inventory.find({})
# End Example 7
self.assertEqual(len(list(cursor)), 5)
self.assertEqual(len(cursor.to_list()), 5)
# Start Example 9
cursor = db.inventory.find({"status": "D"})
# End Example 9
self.assertEqual(len(list(cursor)), 2)
self.assertEqual(len(cursor.to_list()), 2)
# Start Example 10
cursor = db.inventory.find({"status": {"$in": ["A", "D"]}})
# End Example 10
self.assertEqual(len(list(cursor)), 5)
self.assertEqual(len(cursor.to_list()), 5)
# Start Example 11
cursor = db.inventory.find({"status": "A", "qty": {"$lt": 30}})
# End Example 11
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 12
cursor = db.inventory.find({"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]})
# End Example 12
self.assertEqual(len(list(cursor)), 3)
self.assertEqual(len(cursor.to_list()), 3)
# Start Example 13
cursor = db.inventory.find(
@ -169,7 +176,7 @@ class TestSampleShellCommands(IntegrationTest):
)
# End Example 13
self.assertEqual(len(list(cursor)), 2)
self.assertEqual(len(cursor.to_list()), 2)
def test_query_embedded_documents(self):
db = self.db
@ -219,31 +226,31 @@ class TestSampleShellCommands(IntegrationTest):
cursor = db.inventory.find({"size": SON([("h", 14), ("w", 21), ("uom", "cm")])})
# End Example 15
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 16
cursor = db.inventory.find({"size": SON([("w", 21), ("h", 14), ("uom", "cm")])})
# End Example 16
self.assertEqual(len(list(cursor)), 0)
self.assertEqual(len(cursor.to_list()), 0)
# Start Example 17
cursor = db.inventory.find({"size.uom": "in"})
# End Example 17
self.assertEqual(len(list(cursor)), 2)
self.assertEqual(len(cursor.to_list()), 2)
# Start Example 18
cursor = db.inventory.find({"size.h": {"$lt": 15}})
# End Example 18
self.assertEqual(len(list(cursor)), 4)
self.assertEqual(len(cursor.to_list()), 4)
# Start Example 19
cursor = db.inventory.find({"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"})
# End Example 19
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
def test_query_arrays(self):
db = self.db
@ -269,49 +276,49 @@ class TestSampleShellCommands(IntegrationTest):
cursor = db.inventory.find({"tags": ["red", "blank"]})
# End Example 21
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 22
cursor = db.inventory.find({"tags": {"$all": ["red", "blank"]}})
# End Example 22
self.assertEqual(len(list(cursor)), 4)
self.assertEqual(len(cursor.to_list()), 4)
# Start Example 23
cursor = db.inventory.find({"tags": "red"})
# End Example 23
self.assertEqual(len(list(cursor)), 4)
self.assertEqual(len(cursor.to_list()), 4)
# Start Example 24
cursor = db.inventory.find({"dim_cm": {"$gt": 25}})
# End Example 24
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 25
cursor = db.inventory.find({"dim_cm": {"$gt": 15, "$lt": 20}})
# End Example 25
self.assertEqual(len(list(cursor)), 4)
self.assertEqual(len(cursor.to_list()), 4)
# Start Example 26
cursor = db.inventory.find({"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}})
# End Example 26
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 27
cursor = db.inventory.find({"dim_cm.1": {"$gt": 25}})
# End Example 27
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 28
cursor = db.inventory.find({"tags": {"$size": 3}})
# End Example 28
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
def test_query_array_of_documents(self):
db = self.db
@ -360,49 +367,49 @@ class TestSampleShellCommands(IntegrationTest):
cursor = db.inventory.find({"instock": SON([("warehouse", "A"), ("qty", 5)])})
# End Example 30
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 31
cursor = db.inventory.find({"instock": SON([("qty", 5), ("warehouse", "A")])})
# End Example 31
self.assertEqual(len(list(cursor)), 0)
self.assertEqual(len(cursor.to_list()), 0)
# Start Example 32
cursor = db.inventory.find({"instock.0.qty": {"$lte": 20}})
# End Example 32
self.assertEqual(len(list(cursor)), 3)
self.assertEqual(len(cursor.to_list()), 3)
# Start Example 33
cursor = db.inventory.find({"instock.qty": {"$lte": 20}})
# End Example 33
self.assertEqual(len(list(cursor)), 5)
self.assertEqual(len(cursor.to_list()), 5)
# Start Example 34
cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}})
# End Example 34
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 35
cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}})
# End Example 35
self.assertEqual(len(list(cursor)), 3)
self.assertEqual(len(cursor.to_list()), 3)
# Start Example 36
cursor = db.inventory.find({"instock.qty": {"$gt": 10, "$lte": 20}})
# End Example 36
self.assertEqual(len(list(cursor)), 4)
self.assertEqual(len(cursor.to_list()), 4)
# Start Example 37
cursor = db.inventory.find({"instock.qty": 5, "instock.warehouse": "A"})
# End Example 37
self.assertEqual(len(list(cursor)), 2)
self.assertEqual(len(cursor.to_list()), 2)
def test_query_null(self):
db = self.db
@ -415,19 +422,19 @@ class TestSampleShellCommands(IntegrationTest):
cursor = db.inventory.find({"item": None})
# End Example 39
self.assertEqual(len(list(cursor)), 2)
self.assertEqual(len(cursor.to_list()), 2)
# Start Example 40
cursor = db.inventory.find({"item": {"$type": 10}})
# End Example 40
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
# Start Example 41
cursor = db.inventory.find({"item": {"$exists": False}})
# End Example 41
self.assertEqual(len(list(cursor)), 1)
self.assertEqual(len(cursor.to_list()), 1)
def test_projection(self):
db = self.db
@ -473,7 +480,7 @@ class TestSampleShellCommands(IntegrationTest):
cursor = db.inventory.find({"status": "A"})
# End Example 43
self.assertEqual(len(list(cursor)), 3)
self.assertEqual(len(cursor.to_list()), 3)
# Start Example 44
cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1})
@ -746,8 +753,9 @@ class TestSampleShellCommands(IntegrationTest):
while not done:
db.inventory.insert_one({"username": "alice"})
db.inventory.delete_one({"username": "alice"})
time.sleep(0.005)
t = threading.Thread(target=insert_docs)
t = ConcurrentRunner(target=insert_docs)
t.start()
try:
@ -1347,20 +1355,37 @@ class TestSnapshotQueryExamples(IntegrationTest):
db.drop_collection("dogs")
db.cats.insert_one({"name": "Whiskers", "color": "white", "age": 10, "adoptable": True})
db.dogs.insert_one({"name": "Pebbles", "color": "Brown", "age": 10, "adoptable": True})
wait_until(lambda: self.check_for_snapshot(db.cats), "success")
wait_until(lambda: self.check_for_snapshot(db.dogs), "success")
def predicate_one():
return self.check_for_snapshot(db.cats)
def predicate_two():
return self.check_for_snapshot(db.dogs)
wait_until(predicate_two, "success")
wait_until(predicate_one, "success")
# Start Snapshot Query Example 1
db = client.pets
with client.start_session(snapshot=True) as s:
adoptablePetsCount = db.cats.aggregate(
[{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], session=s
).next()["adoptableCatsCount"]
adoptablePetsCount = (
(
db.cats.aggregate(
[{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}],
session=s,
)
).next()
)["adoptableCatsCount"]
adoptablePetsCount += db.dogs.aggregate(
[{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], session=s
).next()["adoptableDogsCount"]
adoptablePetsCount += (
(
db.dogs.aggregate(
[{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}],
session=s,
)
).next()
)["adoptableDogsCount"]
print(adoptablePetsCount)
@ -1371,33 +1396,41 @@ class TestSnapshotQueryExamples(IntegrationTest):
saleDate = datetime.datetime.now()
db.sales.insert_one({"shoeType": "boot", "price": 30, "saleDate": saleDate})
wait_until(lambda: self.check_for_snapshot(db.sales), "success")
def predicate_three():
return self.check_for_snapshot(db.sales)
wait_until(predicate_three, "success")
# Start Snapshot Query Example 2
db = client.retail
with client.start_session(snapshot=True) as s:
db.sales.aggregate(
[
{
"$match": {
"$expr": {
"$gt": [
"$saleDate",
{
"$dateSubtract": {
"startDate": "$$NOW",
"unit": "day",
"amount": 1,
}
},
]
}
}
},
{"$count": "totalDailySales"},
],
session=s,
).next()["totalDailySales"]
_ = (
(
db.sales.aggregate(
[
{
"$match": {
"$expr": {
"$gt": [
"$saleDate",
{
"$dateSubtract": {
"startDate": "$$NOW",
"unit": "day",
"amount": 1,
}
},
]
}
}
},
{"$count": "totalDailySales"},
],
session=s,
)
).next()
)["totalDailySales"]
# End Snapshot Query Example 2

View File

@ -17,14 +17,20 @@ from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "gridfs")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))

View File

@ -26,6 +26,8 @@ from pymongo.errors import ConnectionFailure
from pymongo.hello import Hello, HelloCompat
from pymongo.synchronous.monitor import Monitor
_IS_SYNC = True
class TestHeartbeatMonitoring(IntegrationTest):
def create_mock_monitor(self, responses, uri, expected_results):
@ -40,8 +42,12 @@ class TestHeartbeatMonitoring(IntegrationTest):
raise responses[1]
return Hello(responses[1]), 99
m = self.single_client(
h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool
_ = self.single_client(
h=uri,
event_listeners=(listener,),
_monitor_class=MockMonitor,
_pool_class=MockPool,
connect=True,
)
expected_len = len(expected_results)
@ -50,20 +56,16 @@ class TestHeartbeatMonitoring(IntegrationTest):
# of this test.
wait_until(lambda: len(listener.events) >= expected_len, "publish all events")
try:
# zip gives us len(expected_results) pairs.
for expected, actual in zip(expected_results, listener.events):
self.assertEqual(expected, actual.__class__.__name__)
self.assertEqual(actual.connection_id, responses[0])
if expected != "ServerHeartbeatStartedEvent":
if isinstance(actual.reply, Hello):
self.assertEqual(actual.duration, 99)
self.assertEqual(actual.reply._doc, responses[1])
else:
self.assertEqual(actual.reply, responses[1])
finally:
m.close()
# zip gives us len(expected_results) pairs.
for expected, actual in zip(expected_results, listener.events):
self.assertEqual(expected, actual.__class__.__name__)
self.assertEqual(actual.connection_id, responses[0])
if expected != "ServerHeartbeatStartedEvent":
if isinstance(actual.reply, Hello):
self.assertEqual(actual.duration, 99)
self.assertEqual(actual.reply._doc, responses[1])
else:
self.assertEqual(actual.reply, responses[1])
def test_standalone(self):
responses = (

View File

@ -15,7 +15,9 @@
"""Run the auth spec tests."""
from __future__ import annotations
import asyncio
import os
import pathlib
import sys
import time
import uuid
@ -27,16 +29,22 @@ sys.path[0:0] = [""]
from test import IntegrationTest, PyMongoTestCase, unittest
from test.unified_format import generate_test_classes
from test.utils import AllowListEventListener, EventListener, OvertCommandListener
from test.utils import AllowListEventListener, OvertCommandListener
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
pytestmark = pytest.mark.index_management
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "index_management")
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management")
_NAME = "test-search-index"
@ -82,23 +90,25 @@ class SearchIndexIntegrationBase(PyMongoTestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
if not os.environ.get("TEST_INDEX_MANAGEMENT"):
raise unittest.SkipTest("Skipping index management tests")
url = os.environ.get("MONGODB_URI")
username = os.environ["DB_USER"]
password = os.environ["DB_PASSWORD"]
cls.listener = listener = OvertCommandListener()
cls.client = cls.unmanaged_simple_client(
url, username=username, password=password, event_listeners=[listener]
)
cls.client.drop_database(_NAME)
cls.db = cls.client[cls.db_name]
cls.url = os.environ.get("MONGODB_URI")
cls.username = os.environ["DB_USER"]
cls.password = os.environ["DB_PASSWORD"]
cls.listener = OvertCommandListener()
@classmethod
def tearDownClass(cls):
cls.client.drop_database(_NAME)
cls.client.close()
def setUp(self) -> None:
self.client = self.simple_client(
self.url,
username=self.username,
password=self.password,
event_listeners=[self.listener],
)
self.client.drop_database(_NAME)
self.db = self.client[self.db_name]
def tearDown(self):
self.client.drop_database(_NAME)
def wait_for_ready(self, coll, name=_NAME, predicate=None):
"""Wait for a search index to be ready."""
@ -107,10 +117,9 @@ class SearchIndexIntegrationBase(PyMongoTestCase):
predicate = lambda index: index.get("queryable") is True
while True:
indices = list(coll.list_search_indexes(name))
indices = (coll.list_search_indexes(name)).to_list()
if len(indices) and predicate(indices[0]):
return indices[0]
break
time.sleep(5)
@ -133,7 +142,7 @@ class TestSearchIndexIntegration(SearchIndexIntegrationBase):
# Get the index definition.
self.listener.reset()
coll0.list_search_indexes(name=implicit_search_resp, comment="foo").next()
(coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next()
event = self.listener.events[0]
self.assertEqual(event.command["comment"], "foo")
@ -183,7 +192,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
)
# .Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``.
indices = list(coll0.list_search_indexes())
indices = (coll0.list_search_indexes()).to_list()
names = [i["name"] for i in indices]
self.assertIn(name1, names)
self.assertIn(name2, names)
@ -223,7 +232,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array.
t0 = time.time()
while True:
indices = list(coll0.list_search_indexes())
indices = (coll0.list_search_indexes()).to_list()
if indices:
break
if (time.time() - t0) / 60 > 5:
@ -259,7 +268,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
self.wait_for_ready(coll0, predicate=predicate)
# Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``.
index = list(coll0.list_search_indexes(_NAME))[0]
index = ((coll0.list_search_indexes(_NAME)).to_list())[0]
self.assertIn("latestDefinition", index)
self.assertEqual(index["latestDefinition"], model2["definition"])
@ -324,7 +333,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
)
# Get the index definition.
resp = coll0.list_search_indexes(name=implicit_search_resp).next()
resp = (coll0.list_search_indexes(name=implicit_search_resp)).next()
# Assert that the index model contains the correct index type: ``"search"``.
self.assertEqual(resp["type"], "search")
@ -335,7 +344,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
)
# Get the index definition.
resp = coll0.list_search_indexes(name=explicit_search_resp).next()
resp = (coll0.list_search_indexes(name=explicit_search_resp)).next()
# Assert that the index model contains the correct index type: ``"search"``.
self.assertEqual(resp["type"], "search")
@ -350,7 +359,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
)
# Get the index definition.
resp = coll0.list_search_indexes(name=explicit_vector_resp).next()
resp = (coll0.list_search_indexes(name=explicit_vector_resp)).next()
# Assert that the index model contains the correct index type: ``"vectorSearch"``.
self.assertEqual(resp["type"], "vectorSearch")

View File

@ -21,13 +21,13 @@ import re
import sys
import uuid
from collections import OrderedDict
from typing import Any, List, MutableMapping, Tuple, Type
from typing import Any, Tuple, Type
from bson.codec_options import CodecOptions, DatetimeConversion
sys.path[0:0] = [""]
from test import IntegrationTest, unittest
from test import unittest
from bson import EPOCH_AWARE, EPOCH_NAIVE, SON, DatetimeMS, json_util
from bson.binary import (
@ -636,24 +636,5 @@ class TestJsonUtil(unittest.TestCase):
self.assertEqual(json_util.dumps(MyBinary(b"bin", USER_DEFINED_SUBTYPE)), expected_json)
class TestJsonUtilRoundtrip(IntegrationTest):
def test_cursor(self):
db = self.db
db.drop_collection("test")
docs: List[MutableMapping[str, Any]] = [
{"foo": [1, 2]},
{"bar": {"hello": "world"}},
{"code": Code("function x() { return 1; }")},
{"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)},
{"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}},
]
db.test.insert_many(docs)
reloaded_docs = json_util.loads(json_util.dumps(db.test.find()))
for doc in docs:
self.assertTrue(doc in reloaded_docs)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,28 @@
from __future__ import annotations
from test import IntegrationTest
from typing import Any, List, MutableMapping
from bson import Binary, Code, DBRef, ObjectId, json_util
from bson.binary import USER_DEFINED_SUBTYPE
_IS_SYNC = True
class TestJsonUtilRoundtrip(IntegrationTest):
def test_cursor(self):
db = self.db
db.drop_collection("test")
docs: List[MutableMapping[str, Any]] = [
{"foo": [1, 2]},
{"bar": {"hello": "world"}},
{"code": Code("function x() { return 1; }")},
{"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)},
{"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}},
]
db.test.insert_many(docs)
reloaded_docs = json_util.loads(json_util.dumps((db.test.find()).to_list()))
for doc in docs:
self.assertTrue(doc in reloaded_docs)

View File

@ -15,10 +15,12 @@
"""Test maxStalenessSeconds support."""
from __future__ import annotations
import asyncio
import os
import sys
import time
import warnings
from pathlib import Path
from pymongo import MongoClient
from pymongo.operations import _Op
@ -31,11 +33,16 @@ from test.utils_selection_tests import create_selection_tests
from pymongo.errors import ConfigurationError
from pymongo.server_selectors import writable_server_selector
_IS_SYNC = True
# Location of JSON test specifications.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "max_staleness")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness")
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
pass

View File

@ -26,18 +26,20 @@ sys.path[0:0] = [""]
from test import IntegrationTest, client_context
from bson.codec_options import CodecOptions
from pymongo.synchronous.encryption import _HAVE_PYMONGOCRYPT, ClientEncryption, EncryptionError
from pymongo.synchronous.encryption import (
_HAVE_PYMONGOCRYPT,
ClientEncryption,
EncryptionError,
)
_IS_SYNC = True
pytestmark = pytest.mark.csfle
class TestonDemandGCPCredentials(IntegrationTest):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@client_context.require_version_min(4, 2, -1)
def setUpClass(cls):
super().setUpClass()
def setUp(self):
super().setUp()
self.master_key = {
@ -74,12 +76,8 @@ class TestonDemandGCPCredentials(IntegrationTest):
class TestonDemandAzureCredentials(IntegrationTest):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
@client_context.require_version_min(4, 2, -1)
def setUpClass(cls):
super().setUpClass()
def setUp(self):
super().setUp()
self.master_key = {

View File

@ -27,6 +27,8 @@ from bson.son import SON
from pymongo.errors import OperationFailure
from pymongo.read_concern import ReadConcern
_IS_SYNC = True
class TestReadConcern(IntegrationTest):
listener: OvertCommandListener
@ -71,14 +73,14 @@ class TestReadConcern(IntegrationTest):
def test_find_command(self):
# readConcern not sent in command if not specified.
coll = self.db.coll
tuple(coll.find({"field": "value"}))
coll.find({"field": "value"}).to_list()
self.assertNotIn("readConcern", self.listener.started_events[0].command)
self.listener.reset()
# Explicitly set readConcern to 'local'.
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
tuple(coll.find({"field": "value"}))
coll.find({"field": "value"}).to_list()
self.assertEqualCommand(
SON(
[
@ -93,19 +95,19 @@ class TestReadConcern(IntegrationTest):
def test_command_cursor(self):
# readConcern not sent in command if not specified.
coll = self.db.coll
tuple(coll.aggregate([{"$match": {"field": "value"}}]))
(coll.aggregate([{"$match": {"field": "value"}}])).to_list()
self.assertNotIn("readConcern", self.listener.started_events[0].command)
self.listener.reset()
# Explicitly set readConcern to 'local'.
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
tuple(coll.aggregate([{"$match": {"field": "value"}}]))
(coll.aggregate([{"$match": {"field": "value"}}])).to_list()
self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"])
def test_aggregate_out(self):
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
tuple(coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}]))
(coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])).to_list()
# Aggregate with $out supports readConcern MongoDB 4.2 onwards.
if client_context.version >= (4, 1):

View File

@ -26,7 +26,13 @@ from pymongo.operations import _Op
sys.path[0:0] = [""]
from test import IntegrationTest, SkipTest, client_context, connected, unittest
from test import (
IntegrationTest,
SkipTest,
client_context,
connected,
unittest,
)
from test.utils import (
OvertCommandListener,
one,
@ -49,16 +55,22 @@ from pymongo.read_preferences import (
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import Selection, readable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous.helpers import next
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class TestSelections(IntegrationTest):
@client_context.require_connection
def test_bool(self):
client = self.single_client()
wait_until(lambda: client.address, "discover primary")
def predicate():
return client.address
wait_until(predicate, "discover primary")
selection = Selection.from_topology_description(client._topology.description)
self.assertTrue(selection)
@ -88,11 +100,7 @@ class TestReadPreferenceObjects(unittest.TestCase):
class TestReadPreferencesBase(IntegrationTest):
@classmethod
@client_context.require_secondaries_count(1)
def setUpClass(cls):
super().setUpClass()
def setUp(self):
super().setUp()
# Insert some data so we can use cursors in read_from_which_host
@ -123,11 +131,14 @@ class TestReadPreferencesBase(IntegrationTest):
f"Cursor used address {address}, expected either primary "
f"{client.primary} or secondaries {client.secondaries}"
)
return None
def assertReadsFrom(self, expected, **kwargs):
c = self.rs_client(**kwargs)
wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes")
def predicate():
return len(c.nodes - c.arbiters) == client_context.w
wait_until(predicate, "discovered all nodes")
used = self.read_from_which_kind(c)
self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}")
@ -150,7 +161,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase):
# Test find and find_one.
self.assertIsNotNone(coll.find_one())
self.assertEqual(10, len(list(coll.find())))
self.assertEqual(10, len(coll.find().to_list()))
# Test some database helpers.
self.assertIsNotNone(db.list_collection_names())
@ -173,20 +184,22 @@ class TestReadPreferences(TestReadPreferencesBase):
ReadPreference.SECONDARY_PREFERRED,
ReadPreference.NEAREST,
):
self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference)
self.assertEqual(mode, (self.rs_client(read_preference=mode)).read_preference)
self.assertRaises(TypeError, self.rs_client, read_preference="foo")
with self.assertRaises(TypeError):
self.rs_client(read_preference="foo")
def test_tag_sets_validation(self):
S = Secondary(tag_sets=[{}])
self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets)
self.assertEqual([{}], (self.rs_client(read_preference=S)).read_preference.tag_sets)
S = Secondary(tag_sets=[{"k": "v"}])
self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets)
self.assertEqual([{"k": "v"}], (self.rs_client(read_preference=S)).read_preference.tag_sets)
S = Secondary(tag_sets=[{"k": "v"}, {}])
self.assertEqual(
[{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets
[{"k": "v"}, {}],
(self.rs_client(read_preference=S)).read_preference.tag_sets,
)
self.assertRaises(ValueError, Secondary, tag_sets=[])
@ -200,22 +213,27 @@ class TestReadPreferences(TestReadPreferencesBase):
def test_threshold_validation(self):
self.assertEqual(
17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms
17,
(self.rs_client(localThresholdMS=17, connect=False)).options.local_threshold_ms,
)
self.assertEqual(
42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms
42,
(self.rs_client(localThresholdMS=42, connect=False)).options.local_threshold_ms,
)
self.assertEqual(
666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms
666,
(self.rs_client(localThresholdMS=666, connect=False)).options.local_threshold_ms,
)
self.assertEqual(
0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms
0,
(self.rs_client(localThresholdMS=0, connect=False)).options.local_threshold_ms,
)
self.assertRaises(ValueError, self.rs_client, localthresholdms=-1)
with self.assertRaises(ValueError):
self.rs_client(localthresholdms=-1)
def test_zero_latency(self):
ping_times: set = set()
@ -238,7 +256,8 @@ class TestReadPreferences(TestReadPreferencesBase):
def test_primary_with_tags(self):
# Tags not allowed with PRIMARY
self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}])
with self.assertRaises(ConfigurationError):
self.rs_client(tag_sets=[{"dc": "ny"}])
def test_primary_preferred(self):
self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED)
@ -272,7 +291,7 @@ class TestReadPreferences(TestReadPreferencesBase):
not_used = data_members.difference(used)
latencies = ", ".join(
"%s: %sms" % (server.description.address, server.description.round_trip_time)
for server in c._get_topology().select_servers(readable_server_selector, _Op.TEST)
for server in (c._get_topology()).select_servers(readable_server_selector, _Op.TEST)
)
self.assertFalse(
@ -289,12 +308,9 @@ class ReadPrefTester(MongoClient):
client_options.update(kwargs)
super().__init__(*args, **client_options)
@contextlib.contextmanager
def _conn_for_reads(self, read_preference, session, operation):
context = super()._conn_for_reads(read_preference, session, operation)
with context as (conn, read_preference):
self.record_a_read(conn.address)
yield conn, read_preference
return context
@contextlib.contextmanager
def _conn_from_server(self, read_preference, server, session):
@ -304,7 +320,7 @@ class ReadPrefTester(MongoClient):
yield conn, read_preference
def record_a_read(self, address):
server = self._get_topology().select_server_by_address(address, _Op.TEST, 0)
server = (self._get_topology()).select_server_by_address(address, _Op.TEST, 0)
self.has_read_from.add(server)
@ -321,25 +337,23 @@ class TestCommandAndReadPreference(IntegrationTest):
c: ReadPrefTester
client_version: Version
@classmethod
@client_context.require_secondaries_count(1)
def setUpClass(cls):
super().setUpClass()
cls.c = ReadPrefTester(
def setUp(self):
super().setUp()
self.c = ReadPrefTester(
# Ignore round trip times, to test ReadPreference modes only.
localThresholdMS=1000 * 1000,
)
cls.client_version = Version.from_client(cls.c)
self.client_version = Version.from_client(self.c)
# mapReduce fails if the collection does not exist.
coll = cls.c.pymongo_test.get_collection(
coll = self.c.pymongo_test.get_collection(
"test", write_concern=WriteConcern(w=client_context.w)
)
coll.insert_one({})
@classmethod
def tearDownClass(cls):
cls.c.drop_database("pymongo_test")
cls.c.close()
def tearDown(self):
self.c.drop_database("pymongo_test")
self.c.close()
def executed_on_which_server(self, client, fn, *args, **kwargs):
"""Execute fn(*args, **kwargs) and return the Server instance used."""
@ -366,7 +380,7 @@ class TestCommandAndReadPreference(IntegrationTest):
break
assert self.c.primary is not None
unused = self.c.secondaries.union({self.c.primary}).difference(used)
unused = (self.c.secondaries).union({self.c.primary}).difference(used)
if unused:
self.fail("Some members not used for NEAREST: %s" % (unused))
else:
@ -401,11 +415,12 @@ class TestCommandAndReadPreference(IntegrationTest):
def test_create_collection(self):
# create_collection runs listCollections on the primary to check if
# the collection already exists.
self._test_primary_helper(
lambda: self.c.pymongo_test.create_collection(
def func():
return self.c.pymongo_test.create_collection(
"some_collection%s" % random.randint(0, sys.maxsize)
)
)
self._test_primary_helper(func)
def test_count_documents(self):
self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {})
@ -545,7 +560,6 @@ class TestMongosAndReadPreference(IntegrationTest):
cases["secondary"] = Secondary
listener = OvertCommandListener()
client = self.rs_client(event_listeners=[listener])
self.addCleanup(client.close)
client.admin.command("ping")
for _mode, cls in cases.items():
pref = cls(hedge={"enabled": True})
@ -645,10 +659,10 @@ class TestMongosAndReadPreference(IntegrationTest):
# tell what shard member a query ran on.
for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()):
qcoll = coll.with_options(read_preference=pref)
results = list(qcoll.find().sort([("_id", 1)]))
results = qcoll.find().sort([("_id", 1)]).to_list()
self.assertEqual(first_id, results[0]["_id"])
self.assertEqual(last_id, results[-1]["_id"])
results = list(qcoll.find().sort([("_id", -1)]))
results = qcoll.find().sort([("_id", -1)]).to_list()
self.assertEqual(first_id, results[-1]["_id"])
self.assertEqual(last_id, results[0]["_id"])
@ -671,14 +685,14 @@ class TestMongosAndReadPreference(IntegrationTest):
else:
self.fail("mongos accepted invalid staleness")
coll = self.single_client(
readPreference="secondaryPreferred", maxStalenessSeconds=120
coll = (
self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=120)
).pymongo_test.test
# No error
coll.find_one()
coll = self.single_client(
readPreference="secondaryPreferred", maxStalenessSeconds=10
coll = (
self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=10)
).pymongo_test.test
try:
coll.find_one()

View File

@ -19,6 +19,7 @@ import json
import os
import sys
import warnings
from pathlib import Path
sys.path[0:0] = [""]
@ -39,7 +40,13 @@ from pymongo.read_concern import ReadConcern
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "read_write_concern")
_IS_SYNC = True
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern")
class TestReadWriteConcernSpec(IntegrationTest):
@ -47,7 +54,6 @@ class TestReadWriteConcernSpec(IntegrationTest):
listener = OvertCommandListener()
# Client with default readConcern and writeConcern
client = self.rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
collection = client.pymongo_test.collection
# Prepare for tests of find() and aggregate().
collection.insert_many([{} for _ in range(10)])
@ -66,9 +72,12 @@ class TestReadWriteConcernSpec(IntegrationTest):
"insert", "collection", documents=[{}], write_concern=WriteConcern()
)
def aggregate_op():
(collection.aggregate([])).to_list()
ops = [
("aggregate", lambda: list(collection.aggregate([]))),
("find", lambda: list(collection.find())),
("aggregate", aggregate_op),
("find", lambda: collection.find().to_list()),
("insert_one", lambda: collection.insert_one({})),
("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})),
("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})),
@ -207,7 +216,6 @@ class TestReadWriteConcernSpec(IntegrationTest):
def test_write_error_details_exposes_errinfo(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(event_listeners=[listener])
self.addCleanup(client.close)
db = client.errinfotest
self.addCleanup(client.drop_database, "errinfotest")
validator = {"x": {"$type": "string"}}
@ -286,7 +294,7 @@ def create_document_test(test_case):
def create_tests():
for dirpath, _, filenames in os.walk(_TEST_PATH):
for dirpath, _, filenames in os.walk(TEST_PATH):
dirname = os.path.split(dirpath)[-1]
if dirname == "operation":
@ -321,7 +329,7 @@ create_tests()
# PyMongo does not support MapReduce.
globals().update(
generate_test_classes(
os.path.join(_TEST_PATH, "operation"),
os.path.join(TEST_PATH, "operation"),
module=__name__,
expected_failures=["MapReduce .*"],
)

View File

@ -15,6 +15,7 @@
"""Test the Retryable Reads unified spec tests."""
from __future__ import annotations
import os
import sys
from pathlib import Path
@ -23,8 +24,13 @@ sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
TEST_PATH = Path(__file__).parent / "retryable_reads/unified"
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified")
# Generate unified tests.
# PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects.

View File

@ -17,14 +17,20 @@ from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_writes", "unified")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))

View File

@ -1,15 +1,37 @@
# Copyright 2024-Present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run Command unified tests."""
from __future__ import annotations
import os
import unittest
from pathlib import Path
from test.unified_format import generate_test_classes
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "run_command")
_IS_SYNC = True
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command")
globals().update(
generate_test_classes(
os.path.join(_TEST_PATH, "unified"),
os.path.join(TEST_PATH, "unified"),
module=__name__,
)
)

View File

@ -17,19 +17,25 @@ from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection_logging")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging")
globals().update(
generate_test_classes(
_TEST_PATH,
TEST_PATH,
module=__name__,
)
)

View File

@ -18,18 +18,24 @@ from __future__ import annotations
import json
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test import PyMongoTestCase, unittest
from pymongo.read_preferences import MovingAverage
_IS_SYNC = True
# Location of JSON test specifications.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection/rtt")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt")
class TestAllScenarios(unittest.TestCase):
class TestAllScenarios(PyMongoTestCase):
pass
@ -49,7 +55,7 @@ def create_test(scenario_def):
def create_tests():
for dirpath, _, filenames in os.walk(_TEST_PATH):
for dirpath, _, filenames in os.walk(TEST_PATH):
dirname = os.path.split(dirpath)[-1]
for filename in filenames:

View File

@ -17,14 +17,21 @@ from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sessions")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))

View File

@ -15,8 +15,9 @@
"""Run the SRV support tests."""
from __future__ import annotations
import asyncio
import sys
from time import sleep
import time
from typing import Any
sys.path[0:0] = [""]
@ -28,7 +29,8 @@ import pymongo
from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.srv_resolver import _have_dnspython
from pymongo.synchronous.mongo_client import MongoClient
_IS_SYNC = True
WAIT_TIME = 0.1
@ -168,6 +170,7 @@ class TestSrvPolling(PyMongoTestCase):
# Patch timeouts to ensure short test running times.
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING)
client._connect()
self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
# Patch list of hosts returned by DNS query.
with SrvPollingKnobs(
@ -232,6 +235,7 @@ class TestSrvPolling(PyMongoTestCase):
):
# Client uses unpatched method to get initial nodelist
client = self.simple_client(self.CONNECTION_STRING)
client._connect()
# Invalid DNS resolver response should not change nodelist.
self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
@ -265,6 +269,7 @@ class TestSrvPolling(PyMongoTestCase):
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0)
client._connect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
self.assert_nodelist_change(response, client)
@ -279,6 +284,7 @@ class TestSrvPolling(PyMongoTestCase):
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
client._connect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
self.assert_nodelist_change(response, client)
@ -294,8 +300,9 @@ class TestSrvPolling(PyMongoTestCase):
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
client._connect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
time.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
final_topology = set(client.topology_description.server_descriptions())
self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology)
self.assertEqual(len(final_topology), 2)
@ -303,8 +310,9 @@ class TestSrvPolling(PyMongoTestCase):
def test_does_not_flipflop(self):
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1)
client._connect()
old = set(client.topology_description.server_descriptions())
sleep(4 * WAIT_TIME)
time.sleep(4 * WAIT_TIME)
new = set(client.topology_description.server_descriptions())
self.assertSetEqual(old, new)
@ -322,6 +330,7 @@ class TestSrvPolling(PyMongoTestCase):
client = self.simple_client(
"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
)
client._connect()
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
self.assert_nodelist_change(response, client)
@ -337,9 +346,9 @@ class TestSrvPolling(PyMongoTestCase):
nodelist_callback=resolver_response,
):
client = self.simple_client(self.CONNECTION_STRING)
self.assertRaises(
AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2
)
client._connect()
with self.assertRaises(AssertionError):
self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2)
def test_import_dns_resolver(self):
# Regression test for PYTHON-4407

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import os
import pathlib
import socket
import sys
@ -65,7 +66,13 @@ except ImportError:
if HAVE_SSL:
import ssl
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
_IS_SYNC = True
if _IS_SYNC:
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates")
else:
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates")
CLIENT_PEM = os.path.join(CERT_PATH, "client.pem")
CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem")
CA_PEM = os.path.join(CERT_PATH, "ca.pem")
@ -144,21 +151,18 @@ class TestSSL(IntegrationTest):
)
coll.drop()
coll.insert_one({"ssl": True})
self.assertTrue(coll.find_one()["ssl"])
self.assertTrue((coll.find_one())["ssl"])
coll.drop()
@classmethod
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
def setUpClass(cls):
super().setUpClass()
def setUp(self):
super().setUp()
# MongoClient should connect to the primary by default.
cls.saved_port = MongoClient.PORT
self.saved_port = MongoClient.PORT
MongoClient.PORT = client_context.port
@classmethod
def tearDownClass(cls):
MongoClient.PORT = cls.saved_port
super().tearDownClass()
def tearDown(self):
MongoClient.PORT = self.saved_port
@client_context.require_tls
def test_simple_ssl(self):
@ -548,7 +552,6 @@ class TestSSL(IntegrationTest):
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CLIENT_PEM,
)
self.addCleanup(noauth.close)
with self.assertRaises(OperationFailure):
noauth.pymongo_test.test.find_one()
@ -562,7 +565,6 @@ class TestSSL(IntegrationTest):
tlsCertificateKeyFile=CLIENT_PEM,
event_listeners=[listener],
)
self.addCleanup(auth.close)
# No error
auth.pymongo_test.test.find_one()
@ -581,7 +583,6 @@ class TestSSL(IntegrationTest):
client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
self.addCleanup(client.close)
# No error
client.pymongo_test.test.find_one()
@ -589,7 +590,6 @@ class TestSSL(IntegrationTest):
client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
self.addCleanup(client.close)
# No error
client.pymongo_test.test.find_one()
# Auth should fail if username and certificate do not match
@ -602,7 +602,6 @@ class TestSSL(IntegrationTest):
bad_client = self.simple_client(
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
)
self.addCleanup(bad_client.close)
with self.assertRaises(OperationFailure):
bad_client.pymongo_test.test.find_one()
@ -615,7 +614,6 @@ class TestSSL(IntegrationTest):
tlsAllowInvalidCertificates=True,
tlsCertificateKeyFile=CLIENT_PEM,
)
self.addCleanup(bad_client.close)
with self.assertRaises(OperationFailure):
bad_client.pymongo_test.test.find_one()

View File

@ -30,6 +30,8 @@ from test.utils import (
from pymongo import monitoring
from pymongo.hello import HelloCompat
_IS_SYNC = True
class TestStreamingProtocol(IntegrationTest):
@client_context.require_failCommand_appName
@ -41,7 +43,6 @@ class TestStreamingProtocol(IntegrationTest):
heartbeatFrequencyMS=500,
appName="failingHeartbeatTest",
)
self.addCleanup(client.close)
# Force a connection.
client.admin.command("ping")
address = client.address
@ -78,7 +79,7 @@ class TestStreamingProtocol(IntegrationTest):
def rediscovered():
return len(listener.matching(_discovered_node)) >= 1
# Topology events are published asynchronously
# Topology events are not published synchronously
wait_until(marked_unknown, "mark node unknown")
wait_until(rediscovered, "rediscover node")
@ -108,7 +109,6 @@ class TestStreamingProtocol(IntegrationTest):
client = self.rs_or_single_client(
event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name
)
self.addCleanup(client.close)
# Force a connection.
client.admin.command("ping")
address = client.address
@ -156,7 +156,6 @@ class TestStreamingProtocol(IntegrationTest):
client = self.single_client(
appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000
)
self.addCleanup(client.close)
# Force a connection.
client.admin.command("ping")
duration = time.time() - start
@ -183,7 +182,6 @@ class TestStreamingProtocol(IntegrationTest):
heartbeatFrequencyMS=500,
appName="heartbeatEventAwaitedFlag",
)
self.addCleanup(client.close)
# Force a connection.
client.admin.command("ping")

View File

@ -17,12 +17,15 @@ from __future__ import annotations
import os
import sys
from pathlib import Path
sys.path[0:0] = [""]
from test import client_context, unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
@client_context.require_no_mmap
def setUpModule():
@ -30,15 +33,21 @@ def setUpModule():
# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "transactions", "unified")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
# Location of JSON test specifications for transactions-convenient-api.
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "transactions-convenient-api", "unified"
)
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified")
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified"
)
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))

Some files were not shown because too many files have changed in this diff Show More