Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
0609631a0c
@ -1386,7 +1386,7 @@ def is_valid(bson: bytes) -> bool:
|
||||
:param bson: the data to be validated
|
||||
"""
|
||||
if not isinstance(bson, bytes):
|
||||
raise TypeError("BSON data must be an instance of a subclass of bytes")
|
||||
raise TypeError(f"BSON data must be an instance of a subclass of bytes, not {type(bson)}")
|
||||
|
||||
try:
|
||||
_bson_to_dict(bson, DEFAULT_CODEC_OPTIONS)
|
||||
|
||||
@ -290,7 +290,7 @@ class Binary(bytes):
|
||||
subtype: int = BINARY_SUBTYPE,
|
||||
) -> Binary:
|
||||
if not isinstance(subtype, int):
|
||||
raise TypeError("subtype must be an instance of int")
|
||||
raise TypeError(f"subtype must be an instance of int, not {type(subtype)}")
|
||||
if subtype >= 256 or subtype < 0:
|
||||
raise ValueError("subtype must be contained in [0, 256)")
|
||||
# Support any type that implements the buffer protocol.
|
||||
@ -321,7 +321,7 @@ class Binary(bytes):
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
if not isinstance(uuid, UUID):
|
||||
raise TypeError("uuid must be an instance of uuid.UUID")
|
||||
raise TypeError(f"uuid must be an instance of uuid.UUID, not {type(uuid)}")
|
||||
|
||||
if uuid_representation not in ALL_UUID_REPRESENTATIONS:
|
||||
raise ValueError(
|
||||
@ -470,7 +470,7 @@ class Binary(bytes):
|
||||
"""
|
||||
|
||||
if self.subtype != VECTOR_SUBTYPE:
|
||||
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector.")
|
||||
raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")
|
||||
|
||||
position = 0
|
||||
dtype, padding = struct.unpack_from("<sB", self, position)
|
||||
|
||||
@ -56,7 +56,7 @@ class Code(str):
|
||||
**kwargs: Any,
|
||||
) -> Code:
|
||||
if not isinstance(code, str):
|
||||
raise TypeError("code must be an instance of str")
|
||||
raise TypeError(f"code must be an instance of str, not {type(code)}")
|
||||
|
||||
self = str.__new__(cls, code)
|
||||
|
||||
@ -67,7 +67,7 @@ class Code(str):
|
||||
|
||||
if scope is not None:
|
||||
if not isinstance(scope, _Mapping):
|
||||
raise TypeError("scope must be an instance of dict")
|
||||
raise TypeError(f"scope must be an instance of dict, not {type(scope)}")
|
||||
if self.__scope is not None:
|
||||
self.__scope.update(scope) # type: ignore
|
||||
else:
|
||||
|
||||
@ -401,17 +401,23 @@ else:
|
||||
"uuid_representation must be a value from bson.binary.UuidRepresentation"
|
||||
)
|
||||
if not isinstance(unicode_decode_error_handler, str):
|
||||
raise ValueError("unicode_decode_error_handler must be a string")
|
||||
raise ValueError(
|
||||
f"unicode_decode_error_handler must be a string, not {type(unicode_decode_error_handler)}"
|
||||
)
|
||||
if tzinfo is not None:
|
||||
if not isinstance(tzinfo, datetime.tzinfo):
|
||||
raise TypeError("tzinfo must be an instance of datetime.tzinfo")
|
||||
raise TypeError(
|
||||
f"tzinfo must be an instance of datetime.tzinfo, not {type(tzinfo)}"
|
||||
)
|
||||
if not tz_aware:
|
||||
raise ValueError("cannot specify tzinfo without also setting tz_aware=True")
|
||||
|
||||
type_registry = type_registry or TypeRegistry()
|
||||
|
||||
if not isinstance(type_registry, TypeRegistry):
|
||||
raise TypeError("type_registry must be an instance of TypeRegistry")
|
||||
raise TypeError(
|
||||
f"type_registry must be an instance of TypeRegistry, not {type(type_registry)}"
|
||||
)
|
||||
|
||||
return tuple.__new__(
|
||||
cls,
|
||||
|
||||
@ -56,9 +56,9 @@ class DBRef:
|
||||
.. seealso:: The MongoDB documentation on `dbrefs <https://dochub.mongodb.org/core/dbrefs>`_.
|
||||
"""
|
||||
if not isinstance(collection, str):
|
||||
raise TypeError("collection must be an instance of str")
|
||||
raise TypeError(f"collection must be an instance of str, not {type(collection)}")
|
||||
if database is not None and not isinstance(database, str):
|
||||
raise TypeError("database must be an instance of str")
|
||||
raise TypeError(f"database must be an instance of str, not {type(database)}")
|
||||
|
||||
self.__collection = collection
|
||||
self.__id = id
|
||||
|
||||
@ -277,7 +277,7 @@ class Decimal128:
|
||||
point in Binary Integer Decimal (BID) format).
|
||||
"""
|
||||
if not isinstance(value, bytes):
|
||||
raise TypeError("value must be an instance of bytes")
|
||||
raise TypeError(f"value must be an instance of bytes, not {type(value)}")
|
||||
if len(value) != 16:
|
||||
raise ValueError("value must be exactly 16 bytes")
|
||||
return cls((_UNPACK_64(value[8:])[0], _UNPACK_64(value[:8])[0])) # type: ignore
|
||||
|
||||
@ -58,9 +58,9 @@ class Timestamp:
|
||||
time = time - offset
|
||||
time = int(calendar.timegm(time.timetuple()))
|
||||
if not isinstance(time, int):
|
||||
raise TypeError("time must be an instance of int")
|
||||
raise TypeError(f"time must be an instance of int, not {type(time)}")
|
||||
if not isinstance(inc, int):
|
||||
raise TypeError("inc must be an instance of int")
|
||||
raise TypeError(f"inc must be an instance of int, not {type(inc)}")
|
||||
if not 0 <= time < UPPERBOUND:
|
||||
raise ValueError("time must be contained in [0, 2**32)")
|
||||
if not 0 <= inc < UPPERBOUND:
|
||||
|
||||
@ -67,6 +67,14 @@ uMongo
|
||||
mongomock. The source `is available on GitHub
|
||||
<https://github.com/Scille/umongo>`_
|
||||
|
||||
Django MongoDB Backend
|
||||
`Django MongoDB Backend <https://django-mongodb-backend.readthedocs.io>`_ is a
|
||||
database backend library specifically made for Django. The integration takes
|
||||
advantage of MongoDB's unique document model capabilities, which align
|
||||
naturally with Django's philosophy of simplified data modeling and
|
||||
reduced development complexity. The source is available
|
||||
`on GitHub <https://github.com/mongodb-labs/django-mongodb-backend>`_.
|
||||
|
||||
No longer maintained
|
||||
""""""""""""""""""""
|
||||
|
||||
|
||||
@ -100,7 +100,7 @@ class AsyncGridFS:
|
||||
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
|
||||
"""
|
||||
if not isinstance(database, AsyncDatabase):
|
||||
raise TypeError("database must be an instance of Database")
|
||||
raise TypeError(f"database must be an instance of Database, not {type(database)}")
|
||||
|
||||
database = _clear_entity_type_registry(database)
|
||||
|
||||
@ -503,7 +503,7 @@ class AsyncGridFSBucket:
|
||||
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
|
||||
"""
|
||||
if not isinstance(db, AsyncDatabase):
|
||||
raise TypeError("database must be an instance of AsyncDatabase")
|
||||
raise TypeError(f"database must be an instance of AsyncDatabase, not {type(db)}")
|
||||
|
||||
db = _clear_entity_type_registry(db)
|
||||
|
||||
@ -1082,7 +1082,9 @@ class AsyncGridIn:
|
||||
:attr:`~pymongo.collection.AsyncCollection.write_concern`
|
||||
"""
|
||||
if not isinstance(root_collection, AsyncCollection):
|
||||
raise TypeError("root_collection must be an instance of AsyncCollection")
|
||||
raise TypeError(
|
||||
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
|
||||
)
|
||||
|
||||
if not root_collection.write_concern.acknowledged:
|
||||
raise ConfigurationError("root_collection must use acknowledged write_concern")
|
||||
@ -1436,7 +1438,9 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore
|
||||
from the server. Metadata is fetched when first needed.
|
||||
"""
|
||||
if not isinstance(root_collection, AsyncCollection):
|
||||
raise TypeError("root_collection must be an instance of AsyncCollection")
|
||||
raise TypeError(
|
||||
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
|
||||
)
|
||||
_disallow_transactions(session)
|
||||
|
||||
root_collection = _clear_entity_type_registry(root_collection)
|
||||
|
||||
@ -100,7 +100,7 @@ class GridFS:
|
||||
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
|
||||
"""
|
||||
if not isinstance(database, Database):
|
||||
raise TypeError("database must be an instance of Database")
|
||||
raise TypeError(f"database must be an instance of Database, not {type(database)}")
|
||||
|
||||
database = _clear_entity_type_registry(database)
|
||||
|
||||
@ -501,7 +501,7 @@ class GridFSBucket:
|
||||
.. seealso:: The MongoDB documentation on `gridfs <https://dochub.mongodb.org/core/gridfs>`_.
|
||||
"""
|
||||
if not isinstance(db, Database):
|
||||
raise TypeError("database must be an instance of Database")
|
||||
raise TypeError(f"database must be an instance of Database, not {type(db)}")
|
||||
|
||||
db = _clear_entity_type_registry(db)
|
||||
|
||||
@ -1076,7 +1076,9 @@ class GridIn:
|
||||
:attr:`~pymongo.collection.Collection.write_concern`
|
||||
"""
|
||||
if not isinstance(root_collection, Collection):
|
||||
raise TypeError("root_collection must be an instance of Collection")
|
||||
raise TypeError(
|
||||
f"root_collection must be an instance of Collection, not {type(root_collection)}"
|
||||
)
|
||||
|
||||
if not root_collection.write_concern.acknowledged:
|
||||
raise ConfigurationError("root_collection must use acknowledged write_concern")
|
||||
@ -1426,7 +1428,9 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore
|
||||
from the server. Metadata is fetched when first needed.
|
||||
"""
|
||||
if not isinstance(root_collection, Collection):
|
||||
raise TypeError("root_collection must be an instance of Collection")
|
||||
raise TypeError(
|
||||
f"root_collection must be an instance of Collection, not {type(root_collection)}"
|
||||
)
|
||||
_disallow_transactions(session)
|
||||
|
||||
root_collection = _clear_entity_type_registry(root_collection)
|
||||
|
||||
@ -160,7 +160,7 @@ def timeout(seconds: Optional[float]) -> ContextManager[None]:
|
||||
.. versionadded:: 4.2
|
||||
"""
|
||||
if not isinstance(seconds, (int, float, type(None))):
|
||||
raise TypeError("timeout must be None, an int, or a float")
|
||||
raise TypeError(f"timeout must be None, an int, or a float, not {type(seconds)}")
|
||||
if seconds and seconds < 0:
|
||||
raise ValueError("timeout cannot be negative")
|
||||
if seconds is not None:
|
||||
|
||||
@ -160,7 +160,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin):
|
||||
self._locked = False
|
||||
self._wake_up_first()
|
||||
else:
|
||||
raise RuntimeError("Lock is not acquired.")
|
||||
raise RuntimeError("Lock is not acquired")
|
||||
|
||||
def _wake_up_first(self) -> None:
|
||||
"""Ensure that the first waiter will wake up."""
|
||||
|
||||
@ -46,7 +46,7 @@ def _get_azure_response(
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except Exception:
|
||||
raise ValueError("Azure IMDS response must be in JSON format.") from None
|
||||
raise ValueError("Azure IMDS response must be in JSON format") from None
|
||||
|
||||
for key in ["access_token", "expires_in"]:
|
||||
if not data.get(key):
|
||||
|
||||
@ -161,7 +161,7 @@ def _password_digest(username: str, password: str) -> str:
|
||||
if len(password) == 0:
|
||||
raise ValueError("password can't be empty")
|
||||
if not isinstance(username, str):
|
||||
raise TypeError("username must be an instance of str")
|
||||
raise TypeError(f"username must be an instance of str, not {type(username)}")
|
||||
|
||||
md5hash = hashlib.md5() # noqa: S324
|
||||
data = f"{username}:mongo:{password}"
|
||||
|
||||
@ -213,7 +213,9 @@ class _OIDCAuthenticator:
|
||||
)
|
||||
resp = cb.fetch(context)
|
||||
if not isinstance(resp, OIDCCallbackResult):
|
||||
raise ValueError("Callback result must be of type OIDCCallbackResult")
|
||||
raise ValueError(
|
||||
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
|
||||
)
|
||||
self.refresh_token = resp.refresh_token
|
||||
self.access_token = resp.access_token
|
||||
self.token_gen_id += 1
|
||||
|
||||
@ -310,7 +310,9 @@ class TransactionOptions:
|
||||
)
|
||||
if max_commit_time_ms is not None:
|
||||
if not isinstance(max_commit_time_ms, int):
|
||||
raise TypeError("max_commit_time_ms must be an integer or None")
|
||||
raise TypeError(
|
||||
f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def read_concern(self) -> Optional[ReadConcern]:
|
||||
@ -902,7 +904,9 @@ class AsyncClientSession:
|
||||
another `AsyncClientSession` instance.
|
||||
"""
|
||||
if not isinstance(cluster_time, _Mapping):
|
||||
raise TypeError("cluster_time must be a subclass of collections.Mapping")
|
||||
raise TypeError(
|
||||
f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}"
|
||||
)
|
||||
if not isinstance(cluster_time.get("clusterTime"), Timestamp):
|
||||
raise ValueError("Invalid cluster_time")
|
||||
self._advance_cluster_time(cluster_time)
|
||||
@ -923,7 +927,9 @@ class AsyncClientSession:
|
||||
another `AsyncClientSession` instance.
|
||||
"""
|
||||
if not isinstance(operation_time, Timestamp):
|
||||
raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp")
|
||||
raise TypeError(
|
||||
f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}"
|
||||
)
|
||||
self._advance_operation_time(operation_time)
|
||||
|
||||
def _process_response(self, reply: Mapping[str, Any]) -> None:
|
||||
|
||||
@ -228,7 +228,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern or database.read_concern,
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
raise TypeError(f"name must be an instance of str, not {type(name)}")
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
|
||||
if not isinstance(database, AsyncDatabase):
|
||||
@ -2475,7 +2475,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
name = helpers_shared._gen_index_name(index_or_name)
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("index_or_name must be an instance of str or list")
|
||||
raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}")
|
||||
|
||||
cmd = {"dropIndexes": self._name, "index": name}
|
||||
cmd.update(kwargs)
|
||||
@ -3078,7 +3078,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
"""
|
||||
if not isinstance(new_name, str):
|
||||
raise TypeError("new_name must be an instance of str")
|
||||
raise TypeError(f"new_name must be an instance of str, not {type(new_name)}")
|
||||
|
||||
if not new_name or ".." in new_name:
|
||||
raise InvalidName("collection names cannot be empty")
|
||||
@ -3148,7 +3148,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
"""
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("key must be an instance of str")
|
||||
raise TypeError(f"key must be an instance of str, not {type(key)}")
|
||||
cmd = {"distinct": self._name, "key": key}
|
||||
if filter is not None:
|
||||
if "query" in kwargs:
|
||||
@ -3196,7 +3196,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
common.validate_is_mapping("filter", filter)
|
||||
if not isinstance(return_document, bool):
|
||||
raise ValueError(
|
||||
"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER"
|
||||
f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}"
|
||||
)
|
||||
collation = validate_collation_or_none(kwargs.pop("collation", None))
|
||||
cmd = {"findAndModify": self._name, "query": filter, "new": return_document}
|
||||
|
||||
@ -94,7 +94,9 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
self.batch_size(batch_size)
|
||||
|
||||
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
|
||||
raise TypeError("max_await_time_ms must be an integer or None")
|
||||
raise TypeError(
|
||||
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
self._die_no_lock()
|
||||
@ -115,7 +117,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
:param batch_size: The size of each batch of results requested.
|
||||
"""
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
|
||||
|
||||
@ -146,9 +146,9 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
spec: Mapping[str, Any] = filter or {}
|
||||
validate_is_mapping("filter", spec)
|
||||
if not isinstance(skip, int):
|
||||
raise TypeError("skip must be an instance of int")
|
||||
raise TypeError(f"skip must be an instance of int, not {type(skip)}")
|
||||
if not isinstance(limit, int):
|
||||
raise TypeError("limit must be an instance of int")
|
||||
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
|
||||
validate_boolean("no_cursor_timeout", no_cursor_timeout)
|
||||
if no_cursor_timeout and not self._explicit_session:
|
||||
warnings.warn(
|
||||
@ -171,7 +171,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
validate_boolean("allow_partial_results", allow_partial_results)
|
||||
validate_boolean("oplog_replay", oplog_replay)
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
# Only set if allow_disk_use is provided by the user, else None.
|
||||
@ -388,7 +388,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
cursor.add_option(2)
|
||||
"""
|
||||
if not isinstance(mask, int):
|
||||
raise TypeError("mask must be an int")
|
||||
raise TypeError(f"mask must be an int, not {type(mask)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
if mask & _QUERY_OPTIONS["exhaust"]:
|
||||
@ -408,7 +408,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
cursor.remove_option(2)
|
||||
"""
|
||||
if not isinstance(mask, int):
|
||||
raise TypeError("mask must be an int")
|
||||
raise TypeError(f"mask must be an int, not {type(mask)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
if mask & _QUERY_OPTIONS["exhaust"]:
|
||||
@ -432,7 +432,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
if not isinstance(allow_disk_use, bool):
|
||||
raise TypeError("allow_disk_use must be a bool")
|
||||
raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
self._allow_disk_use = allow_disk_use
|
||||
@ -451,7 +451,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
.. seealso:: The MongoDB documentation on `limit <https://dochub.mongodb.org/core/limit>`_.
|
||||
"""
|
||||
if not isinstance(limit, int):
|
||||
raise TypeError("limit must be an integer")
|
||||
raise TypeError(f"limit must be an integer, not {type(limit)}")
|
||||
if self._exhaust:
|
||||
raise InvalidOperation("Can't use limit and exhaust together.")
|
||||
self._check_okay_to_chain()
|
||||
@ -479,7 +479,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
:param batch_size: The size of each batch of results requested.
|
||||
"""
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
self._check_okay_to_chain()
|
||||
@ -499,7 +499,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
:param skip: the number of results to skip
|
||||
"""
|
||||
if not isinstance(skip, int):
|
||||
raise TypeError("skip must be an integer")
|
||||
raise TypeError(f"skip must be an integer, not {type(skip)}")
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be >= 0")
|
||||
self._check_okay_to_chain()
|
||||
@ -520,7 +520,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
:param max_time_ms: the time limit after which the operation is aborted
|
||||
"""
|
||||
if not isinstance(max_time_ms, int) and max_time_ms is not None:
|
||||
raise TypeError("max_time_ms must be an integer or None")
|
||||
raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
self._max_time_ms = max_time_ms
|
||||
@ -543,7 +543,9 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 3.2
|
||||
"""
|
||||
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
|
||||
raise TypeError("max_await_time_ms must be an integer or None")
|
||||
raise TypeError(
|
||||
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
|
||||
)
|
||||
self._check_okay_to_chain()
|
||||
|
||||
# Ignore max_await_time_ms if not tailable or await_data is False.
|
||||
@ -679,7 +681,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
if not isinstance(spec, (list, tuple)):
|
||||
raise TypeError("spec must be an instance of list or tuple")
|
||||
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
|
||||
|
||||
self._check_okay_to_chain()
|
||||
self._max = dict(spec)
|
||||
@ -701,7 +703,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
if not isinstance(spec, (list, tuple)):
|
||||
raise TypeError("spec must be an instance of list or tuple")
|
||||
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
|
||||
|
||||
self._check_okay_to_chain()
|
||||
self._min = dict(spec)
|
||||
|
||||
@ -122,7 +122,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
raise TypeError(f"name must be an instance of str, not {type(name)}")
|
||||
|
||||
if not isinstance(client, AsyncMongoClient):
|
||||
# This is for compatibility with mocked and subclassed types, such as in Motor.
|
||||
@ -1310,7 +1310,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
name = name.name
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_collection must be an instance of str")
|
||||
raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}")
|
||||
encrypted_fields = await self._get_encrypted_fields(
|
||||
{"encryptedFields": encrypted_fields},
|
||||
name,
|
||||
@ -1374,7 +1374,9 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
name = name.name
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_collection must be an instance of str or AsyncCollection")
|
||||
raise TypeError(
|
||||
f"name_or_collection must be an instance of str or AsyncCollection, not {type(name)}"
|
||||
)
|
||||
cmd = {"validate": name, "scandata": scandata, "full": full}
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
|
||||
@ -322,7 +322,9 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS)
|
||||
data_key_id = raw_doc.get("_id")
|
||||
if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE:
|
||||
raise TypeError("data_key _id must be Binary with a UUID subtype")
|
||||
raise TypeError(
|
||||
f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}"
|
||||
)
|
||||
|
||||
assert self.key_vault_coll is not None
|
||||
await self.key_vault_coll.insert_one(raw_doc)
|
||||
@ -644,7 +646,9 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
if not isinstance(codec_options, CodecOptions):
|
||||
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
|
||||
raise TypeError(
|
||||
f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
|
||||
)
|
||||
|
||||
if not isinstance(key_vault_client, AsyncMongoClient):
|
||||
# This is for compatibility with mocked and subclassed types, such as in Motor.
|
||||
|
||||
@ -750,7 +750,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if port is None:
|
||||
port = self.PORT
|
||||
if not isinstance(port, int):
|
||||
raise TypeError("port must be an instance of int")
|
||||
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
||||
|
||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||
# customization of PyMongo, e.g. Motor.
|
||||
@ -1565,6 +1565,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
|
||||
await self._encrypter.close()
|
||||
self._closed = True
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
|
||||
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if not _IS_SYNC:
|
||||
# Add support for contextlib.aclosing.
|
||||
@ -1971,7 +1977,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
The cursor is closed synchronously on the current thread.
|
||||
"""
|
||||
if not isinstance(cursor_id, int):
|
||||
raise TypeError("cursor_id must be an instance of int")
|
||||
raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}")
|
||||
|
||||
try:
|
||||
if conn_mgr:
|
||||
@ -2093,7 +2099,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session is not None:
|
||||
if not isinstance(session, client_session.AsyncClientSession):
|
||||
raise ValueError("'session' argument must be an AsyncClientSession or None.")
|
||||
raise ValueError(
|
||||
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
|
||||
)
|
||||
# Don't call end_session.
|
||||
yield session
|
||||
return
|
||||
@ -2247,7 +2255,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
name = name.name
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_database must be an instance of str or a AsyncDatabase")
|
||||
raise TypeError(
|
||||
f"name_or_database must be an instance of str or a AsyncDatabase, not {type(name)}"
|
||||
)
|
||||
|
||||
async with await self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
|
||||
await self[name]._command(
|
||||
|
||||
@ -112,9 +112,9 @@ class MonitorBase:
|
||||
"""
|
||||
self.gc_safe_close()
|
||||
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
async def join(self) -> None:
|
||||
"""Wait for the monitor to stop."""
|
||||
await self._executor.join(timeout)
|
||||
await self._executor.join()
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
@ -189,6 +189,11 @@ class Monitor(MonitorBase):
|
||||
self._rtt_monitor.gc_safe_close()
|
||||
self.cancel_check()
|
||||
|
||||
async def join(self) -> None:
|
||||
await asyncio.gather(
|
||||
self._executor.join(), self._rtt_monitor.join(), return_exceptions=True
|
||||
) # type: ignore[func-returns-value]
|
||||
|
||||
async def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
await self._rtt_monitor.close()
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
@ -29,7 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.asynchronous.monitor import SrvMonitor
|
||||
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.errors import (
|
||||
@ -207,6 +208,9 @@ class Topology:
|
||||
if self._settings.fqdn is not None and not self._settings.load_balanced:
|
||||
self._srv_monitor = SrvMonitor(self, self._settings)
|
||||
|
||||
# Stores all monitor tasks that need to be joined on close or server selection
|
||||
self._monitor_tasks: list[MonitorBase] = []
|
||||
|
||||
async def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
@ -241,6 +245,8 @@ class Topology:
|
||||
# Close servers and clear the pools.
|
||||
for server in self._servers.values():
|
||||
await server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
# Reset the session pool to avoid duplicate sessions in
|
||||
# the child process.
|
||||
self._session_pool.reset()
|
||||
@ -283,6 +289,10 @@ class Topology:
|
||||
else:
|
||||
server_timeout = server_selection_timeout
|
||||
|
||||
# Cleanup any completed monitor tasks safely
|
||||
if not _IS_SYNC and self._monitor_tasks:
|
||||
await self.cleanup_monitors()
|
||||
|
||||
async with self._lock:
|
||||
server_descriptions = await self._select_servers_loop(
|
||||
selector, server_timeout, operation, operation_id, address
|
||||
@ -520,6 +530,8 @@ class Topology:
|
||||
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
|
||||
):
|
||||
await self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
# Clear the pool from a failed heartbeat.
|
||||
if reset_pool:
|
||||
@ -695,6 +707,8 @@ class Topology:
|
||||
old_td = self._description
|
||||
for server in self._servers.values():
|
||||
await server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
|
||||
# Mark all servers Unknown.
|
||||
self._description = self._description.reset()
|
||||
@ -705,6 +719,8 @@ class Topology:
|
||||
# Stop SRV polling thread.
|
||||
if self._srv_monitor:
|
||||
await self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
self._opened = False
|
||||
self._closed = True
|
||||
@ -944,6 +960,8 @@ class Topology:
|
||||
for address, server in list(self._servers.items()):
|
||||
if not self._description.has_server(address):
|
||||
await server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
self._servers.pop(address)
|
||||
|
||||
def _create_pool_for_server(self, address: _Address) -> Pool:
|
||||
@ -1031,6 +1049,15 @@ class Topology:
|
||||
else:
|
||||
return ",".join(str(server.error) for server in servers if server.error)
|
||||
|
||||
async def cleanup_monitors(self) -> None:
|
||||
tasks = []
|
||||
try:
|
||||
while self._monitor_tasks:
|
||||
tasks.append(self._monitor_tasks.pop())
|
||||
except IndexError:
|
||||
pass
|
||||
await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = ""
|
||||
if not self._opened:
|
||||
|
||||
@ -107,7 +107,7 @@ def _build_credentials_tuple(
|
||||
) -> MongoCredential:
|
||||
"""Build and return a mechanism specific credentials tuple."""
|
||||
if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
|
||||
raise ConfigurationError(f"{mech} requires a username.")
|
||||
raise ConfigurationError(f"{mech} requires a username")
|
||||
if mech == "GSSAPI":
|
||||
if source is not None and source != "$external":
|
||||
raise ValueError("authentication source must be $external or None for GSSAPI")
|
||||
@ -219,7 +219,7 @@ def _build_credentials_tuple(
|
||||
else:
|
||||
source_database = source or database or "admin"
|
||||
if passwd is None:
|
||||
raise ConfigurationError("A password is required.")
|
||||
raise ConfigurationError("A password is required")
|
||||
return MongoCredential(mech, source_database, user, passwd, None, _Cache())
|
||||
|
||||
|
||||
|
||||
@ -223,4 +223,4 @@ def validate_collation_or_none(
|
||||
return value.document
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
raise TypeError("collation must be a dict, an instance of collation.Collation, or None.")
|
||||
raise TypeError("collation must be a dict, an instance of collation.Collation, or None")
|
||||
|
||||
@ -202,7 +202,7 @@ def validate_integer(option: str, value: Any) -> int:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"The value of {option} must be an integer") from None
|
||||
raise TypeError(f"Wrong type for {option}, value must be an integer")
|
||||
raise TypeError(f"Wrong type for {option}, value must be an integer, not {type(value)}")
|
||||
|
||||
|
||||
def validate_positive_integer(option: str, value: Any) -> int:
|
||||
@ -250,7 +250,7 @@ def validate_string(option: str, value: Any) -> str:
|
||||
"""Validates that 'value' is an instance of `str`."""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
raise TypeError(f"Wrong type for {option}, value must be an instance of str")
|
||||
raise TypeError(f"Wrong type for {option}, value must be an instance of str, not {type(value)}")
|
||||
|
||||
|
||||
def validate_string_or_none(option: str, value: Any) -> Optional[str]:
|
||||
@ -269,7 +269,9 @@ def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
raise TypeError(f"Wrong type for {option}, value must be an integer or a string")
|
||||
raise TypeError(
|
||||
f"Wrong type for {option}, value must be an integer or a string, not {type(value)}"
|
||||
)
|
||||
|
||||
|
||||
def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]:
|
||||
@ -282,7 +284,9 @@ def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[in
|
||||
except ValueError:
|
||||
return value
|
||||
return validate_non_negative_integer(option, val)
|
||||
raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string")
|
||||
raise TypeError(
|
||||
f"Wrong type for {option}, value must be an non negative integer or a string, not {type(value)}"
|
||||
)
|
||||
|
||||
|
||||
def validate_positive_float(option: str, value: Any) -> float:
|
||||
@ -365,7 +369,7 @@ def validate_max_staleness(option: str, value: Any) -> int:
|
||||
def validate_read_preference(dummy: Any, value: Any) -> _ServerMode:
|
||||
"""Validate a read preference."""
|
||||
if not isinstance(value, _ServerMode):
|
||||
raise TypeError(f"{value!r} is not a read preference.")
|
||||
raise TypeError(f"{value!r} is not a read preference")
|
||||
return value
|
||||
|
||||
|
||||
@ -441,7 +445,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
|
||||
props: dict[str, Any] = {}
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("Auth mechanism properties must be given as a string or a dictionary")
|
||||
raise ValueError(
|
||||
f"Auth mechanism properties must be given as a string or a dictionary, not {type(value)}"
|
||||
)
|
||||
for key, value in value.items(): # noqa: B020
|
||||
if isinstance(value, str):
|
||||
props[key] = value
|
||||
@ -453,7 +459,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
|
||||
from pymongo.auth_oidc_shared import OIDCCallback
|
||||
|
||||
if not isinstance(value, OIDCCallback):
|
||||
raise ValueError("callback must be an OIDCCallback object")
|
||||
raise ValueError(f"callback must be an OIDCCallback object, not {type(value)}")
|
||||
props[key] = value
|
||||
else:
|
||||
raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}")
|
||||
@ -476,7 +482,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
|
||||
raise ValueError(
|
||||
f"{key} is not a supported auth "
|
||||
"mechanism property. Must be one of "
|
||||
f"{tuple(_MECHANISM_PROPS)}."
|
||||
f"{tuple(_MECHANISM_PROPS)}"
|
||||
)
|
||||
|
||||
if key == "CANONICALIZE_HOST_NAME":
|
||||
@ -520,7 +526,7 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
|
||||
def validate_list(option: str, value: Any) -> list:
|
||||
"""Validates that 'value' is a list."""
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"{option} must be a list")
|
||||
raise TypeError(f"{option} must be a list, not {type(value)}")
|
||||
return value
|
||||
|
||||
|
||||
@ -587,7 +593,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]:
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, ServerApi):
|
||||
raise TypeError(f"{option} must be an instance of ServerApi")
|
||||
raise TypeError(f"{option} must be an instance of ServerApi, not {type(value)}")
|
||||
return value
|
||||
|
||||
|
||||
@ -596,7 +602,7 @@ def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]:
|
||||
if value is None:
|
||||
return value
|
||||
if not callable(value):
|
||||
raise ValueError(f"{option} must be a callable")
|
||||
raise ValueError(f"{option} must be a callable, not {type(value)}")
|
||||
return value
|
||||
|
||||
|
||||
@ -651,7 +657,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A
|
||||
from pymongo.encryption_options import AutoEncryptionOpts
|
||||
|
||||
if not isinstance(value, AutoEncryptionOpts):
|
||||
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts")
|
||||
raise TypeError(f"{option} must be an instance of AutoEncryptionOpts, not {type(value)}")
|
||||
|
||||
return value
|
||||
|
||||
@ -668,7 +674,9 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo
|
||||
elif isinstance(value, int):
|
||||
return DatetimeConversion(value)
|
||||
|
||||
raise TypeError(f"{option} must be a str or int representing DatetimeConversion")
|
||||
raise TypeError(
|
||||
f"{option} must be a str or int representing DatetimeConversion, not {type(value)}"
|
||||
)
|
||||
|
||||
|
||||
def validate_server_monitoring_mode(option: str, value: str) -> str:
|
||||
@ -928,12 +936,14 @@ class BaseObject:
|
||||
|
||||
if not isinstance(write_concern, WriteConcern):
|
||||
raise TypeError(
|
||||
"write_concern must be an instance of pymongo.write_concern.WriteConcern"
|
||||
f"write_concern must be an instance of pymongo.write_concern.WriteConcern, not {type(write_concern)}"
|
||||
)
|
||||
self._write_concern = write_concern
|
||||
|
||||
if not isinstance(read_concern, ReadConcern):
|
||||
raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern")
|
||||
raise TypeError(
|
||||
f"read_concern must be an instance of pymongo.read_concern.ReadConcern, not {type(read_concern)}"
|
||||
)
|
||||
self._read_concern = read_concern
|
||||
|
||||
@property
|
||||
|
||||
@ -91,7 +91,7 @@ def validate_zlib_compression_level(option: str, value: Any) -> int:
|
||||
try:
|
||||
level = int(value)
|
||||
except Exception:
|
||||
raise TypeError(f"{option} must be an integer, not {value!r}.") from None
|
||||
raise TypeError(f"{option} must be an integer, not {value!r}") from None
|
||||
if level < -1 or level > 9:
|
||||
raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
|
||||
return level
|
||||
|
||||
@ -39,7 +39,7 @@ class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])):
|
||||
for key, value in self._asdict().items():
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"Wrong type for DriverInfo {key} option, value must be an instance of str"
|
||||
f"Wrong type for DriverInfo {key} option, value must be an instance of str, not {type(value)}"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@ -225,7 +225,9 @@ class AutoEncryptionOpts:
|
||||
mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"]
|
||||
self._mongocryptd_spawn_args = mongocryptd_spawn_args
|
||||
if not isinstance(self._mongocryptd_spawn_args, list):
|
||||
raise TypeError("mongocryptd_spawn_args must be a list")
|
||||
raise TypeError(
|
||||
f"mongocryptd_spawn_args must be a list, not {type(self._mongocryptd_spawn_args)}"
|
||||
)
|
||||
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
|
||||
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
|
||||
# Maps KMS provider name to a SSLContext.
|
||||
|
||||
@ -122,7 +122,7 @@ def _index_list(
|
||||
"""
|
||||
if direction is not None:
|
||||
if not isinstance(key_or_list, str):
|
||||
raise TypeError("Expected a string and a direction")
|
||||
raise TypeError(f"Expected a string and a direction, not {type(key_or_list)}")
|
||||
return [(key_or_list, direction)]
|
||||
else:
|
||||
if isinstance(key_or_list, str):
|
||||
@ -132,7 +132,9 @@ def _index_list(
|
||||
elif isinstance(key_or_list, abc.Mapping):
|
||||
return list(key_or_list.items())
|
||||
elif not isinstance(key_or_list, (list, tuple)):
|
||||
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
|
||||
raise TypeError(
|
||||
f"if no direction is specified, key_or_list must be an instance of list, not {type(key_or_list)}"
|
||||
)
|
||||
values: list[tuple[str, int]] = []
|
||||
for item in key_or_list:
|
||||
if isinstance(item, str):
|
||||
@ -172,11 +174,12 @@ def _index_document(index_list: _IndexList) -> dict[str, Any]:
|
||||
|
||||
def _validate_index_key_pair(key: Any, value: Any) -> None:
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("first item in each key pair must be an instance of str")
|
||||
raise TypeError(f"first item in each key pair must be an instance of str, not {type(key)}")
|
||||
if not isinstance(value, (str, int, abc.Mapping)):
|
||||
raise TypeError(
|
||||
"second item in each key pair must be 1, -1, "
|
||||
"'2d', or another valid MongoDB index specifier."
|
||||
f", not {type(value)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -472,14 +472,15 @@ def _validate_event_listeners(
|
||||
) -> Sequence[_EventListeners]:
|
||||
"""Validate event listeners"""
|
||||
if not isinstance(listeners, abc.Sequence):
|
||||
raise TypeError(f"{option} must be a list or tuple")
|
||||
raise TypeError(f"{option} must be a list or tuple, not {type(listeners)}")
|
||||
for listener in listeners:
|
||||
if not isinstance(listener, _EventListener):
|
||||
raise TypeError(
|
||||
f"Listeners for {option} must be either a "
|
||||
"CommandListener, ServerHeartbeatListener, "
|
||||
"ServerListener, TopologyListener, or "
|
||||
"ConnectionPoolListener."
|
||||
"ConnectionPoolListener,"
|
||||
f"not {type(listener)}"
|
||||
)
|
||||
return listeners
|
||||
|
||||
@ -496,7 +497,8 @@ def register(listener: _EventListener) -> None:
|
||||
f"Listeners for {listener} must be either a "
|
||||
"CommandListener, ServerHeartbeatListener, "
|
||||
"ServerListener, TopologyListener, or "
|
||||
"ConnectionPoolListener."
|
||||
"ConnectionPoolListener,"
|
||||
f"not {type(listener)}"
|
||||
)
|
||||
if isinstance(listener, CommandListener):
|
||||
_LISTENERS.command_listeners.append(listener)
|
||||
|
||||
@ -75,6 +75,8 @@ class AsyncPeriodicExecutor:
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
if self._task is not None:
|
||||
self._task.cancel()
|
||||
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._task is not None:
|
||||
|
||||
@ -38,7 +38,7 @@ class ReadConcern:
|
||||
if level is None or isinstance(level, str):
|
||||
self.__level = level
|
||||
else:
|
||||
raise TypeError("level must be a string or None.")
|
||||
raise TypeError(f"level must be a string or None, not {type(level)}")
|
||||
|
||||
@property
|
||||
def level(self) -> Optional[str]:
|
||||
|
||||
@ -115,4 +115,4 @@ else:
|
||||
|
||||
def get_ssl_context(*dummy): # type: ignore
|
||||
"""No ssl module, raise ConfigurationError."""
|
||||
raise ConfigurationError("The ssl module is not available.")
|
||||
raise ConfigurationError("The ssl module is not available")
|
||||
|
||||
@ -158,7 +158,7 @@ def _password_digest(username: str, password: str) -> str:
|
||||
if len(password) == 0:
|
||||
raise ValueError("password can't be empty")
|
||||
if not isinstance(username, str):
|
||||
raise TypeError("username must be an instance of str")
|
||||
raise TypeError(f"username must be an instance of str, not {type(username)}")
|
||||
|
||||
md5hash = hashlib.md5() # noqa: S324
|
||||
data = f"{username}:mongo:{password}"
|
||||
|
||||
@ -213,7 +213,9 @@ class _OIDCAuthenticator:
|
||||
)
|
||||
resp = cb.fetch(context)
|
||||
if not isinstance(resp, OIDCCallbackResult):
|
||||
raise ValueError("Callback result must be of type OIDCCallbackResult")
|
||||
raise ValueError(
|
||||
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
|
||||
)
|
||||
self.refresh_token = resp.refresh_token
|
||||
self.access_token = resp.access_token
|
||||
self.token_gen_id += 1
|
||||
|
||||
@ -309,7 +309,9 @@ class TransactionOptions:
|
||||
)
|
||||
if max_commit_time_ms is not None:
|
||||
if not isinstance(max_commit_time_ms, int):
|
||||
raise TypeError("max_commit_time_ms must be an integer or None")
|
||||
raise TypeError(
|
||||
f"max_commit_time_ms must be an integer or None, not {type(max_commit_time_ms)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def read_concern(self) -> Optional[ReadConcern]:
|
||||
@ -897,7 +899,9 @@ class ClientSession:
|
||||
another `ClientSession` instance.
|
||||
"""
|
||||
if not isinstance(cluster_time, _Mapping):
|
||||
raise TypeError("cluster_time must be a subclass of collections.Mapping")
|
||||
raise TypeError(
|
||||
f"cluster_time must be a subclass of collections.Mapping, not {type(cluster_time)}"
|
||||
)
|
||||
if not isinstance(cluster_time.get("clusterTime"), Timestamp):
|
||||
raise ValueError("Invalid cluster_time")
|
||||
self._advance_cluster_time(cluster_time)
|
||||
@ -918,7 +922,9 @@ class ClientSession:
|
||||
another `ClientSession` instance.
|
||||
"""
|
||||
if not isinstance(operation_time, Timestamp):
|
||||
raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp")
|
||||
raise TypeError(
|
||||
f"operation_time must be an instance of bson.timestamp.Timestamp, not {type(operation_time)}"
|
||||
)
|
||||
self._advance_operation_time(operation_time)
|
||||
|
||||
def _process_response(self, reply: Mapping[str, Any]) -> None:
|
||||
|
||||
@ -231,7 +231,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern or database.read_concern,
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
raise TypeError(f"name must be an instance of str, not {type(name)}")
|
||||
from pymongo.synchronous.database import Database
|
||||
|
||||
if not isinstance(database, Database):
|
||||
@ -2472,7 +2472,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
name = helpers_shared._gen_index_name(index_or_name)
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("index_or_name must be an instance of str or list")
|
||||
raise TypeError(f"index_or_name must be an instance of str or list, not {type(name)}")
|
||||
|
||||
cmd = {"dropIndexes": self._name, "index": name}
|
||||
cmd.update(kwargs)
|
||||
@ -3071,7 +3071,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
"""
|
||||
if not isinstance(new_name, str):
|
||||
raise TypeError("new_name must be an instance of str")
|
||||
raise TypeError(f"new_name must be an instance of str, not {type(new_name)}")
|
||||
|
||||
if not new_name or ".." in new_name:
|
||||
raise InvalidName("collection names cannot be empty")
|
||||
@ -3141,7 +3141,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
"""
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("key must be an instance of str")
|
||||
raise TypeError(f"key must be an instance of str, not {type(key)}")
|
||||
cmd = {"distinct": self._name, "key": key}
|
||||
if filter is not None:
|
||||
if "query" in kwargs:
|
||||
@ -3189,7 +3189,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
common.validate_is_mapping("filter", filter)
|
||||
if not isinstance(return_document, bool):
|
||||
raise ValueError(
|
||||
"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER"
|
||||
f"return_document must be ReturnDocument.BEFORE or ReturnDocument.AFTER, not {type(return_document)}"
|
||||
)
|
||||
collation = validate_collation_or_none(kwargs.pop("collation", None))
|
||||
cmd = {"findAndModify": self._name, "query": filter, "new": return_document}
|
||||
|
||||
@ -94,7 +94,9 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
self.batch_size(batch_size)
|
||||
|
||||
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
|
||||
raise TypeError("max_await_time_ms must be an integer or None")
|
||||
raise TypeError(
|
||||
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
self._die_no_lock()
|
||||
@ -115,7 +117,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
:param batch_size: The size of each batch of results requested.
|
||||
"""
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
|
||||
|
||||
@ -146,9 +146,9 @@ class Cursor(Generic[_DocumentType]):
|
||||
spec: Mapping[str, Any] = filter or {}
|
||||
validate_is_mapping("filter", spec)
|
||||
if not isinstance(skip, int):
|
||||
raise TypeError("skip must be an instance of int")
|
||||
raise TypeError(f"skip must be an instance of int, not {type(skip)}")
|
||||
if not isinstance(limit, int):
|
||||
raise TypeError("limit must be an instance of int")
|
||||
raise TypeError(f"limit must be an instance of int, not {type(limit)}")
|
||||
validate_boolean("no_cursor_timeout", no_cursor_timeout)
|
||||
if no_cursor_timeout and not self._explicit_session:
|
||||
warnings.warn(
|
||||
@ -171,7 +171,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
validate_boolean("allow_partial_results", allow_partial_results)
|
||||
validate_boolean("oplog_replay", oplog_replay)
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
# Only set if allow_disk_use is provided by the user, else None.
|
||||
@ -388,7 +388,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
cursor.add_option(2)
|
||||
"""
|
||||
if not isinstance(mask, int):
|
||||
raise TypeError("mask must be an int")
|
||||
raise TypeError(f"mask must be an int, not {type(mask)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
if mask & _QUERY_OPTIONS["exhaust"]:
|
||||
@ -408,7 +408,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
cursor.remove_option(2)
|
||||
"""
|
||||
if not isinstance(mask, int):
|
||||
raise TypeError("mask must be an int")
|
||||
raise TypeError(f"mask must be an int, not {type(mask)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
if mask & _QUERY_OPTIONS["exhaust"]:
|
||||
@ -432,7 +432,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 3.11
|
||||
"""
|
||||
if not isinstance(allow_disk_use, bool):
|
||||
raise TypeError("allow_disk_use must be a bool")
|
||||
raise TypeError(f"allow_disk_use must be a bool, not {type(allow_disk_use)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
self._allow_disk_use = allow_disk_use
|
||||
@ -451,7 +451,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
.. seealso:: The MongoDB documentation on `limit <https://dochub.mongodb.org/core/limit>`_.
|
||||
"""
|
||||
if not isinstance(limit, int):
|
||||
raise TypeError("limit must be an integer")
|
||||
raise TypeError(f"limit must be an integer, not {type(limit)}")
|
||||
if self._exhaust:
|
||||
raise InvalidOperation("Can't use limit and exhaust together.")
|
||||
self._check_okay_to_chain()
|
||||
@ -479,7 +479,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
:param batch_size: The size of each batch of results requested.
|
||||
"""
|
||||
if not isinstance(batch_size, int):
|
||||
raise TypeError("batch_size must be an integer")
|
||||
raise TypeError(f"batch_size must be an integer, not {type(batch_size)}")
|
||||
if batch_size < 0:
|
||||
raise ValueError("batch_size must be >= 0")
|
||||
self._check_okay_to_chain()
|
||||
@ -499,7 +499,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
:param skip: the number of results to skip
|
||||
"""
|
||||
if not isinstance(skip, int):
|
||||
raise TypeError("skip must be an integer")
|
||||
raise TypeError(f"skip must be an integer, not {type(skip)}")
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be >= 0")
|
||||
self._check_okay_to_chain()
|
||||
@ -520,7 +520,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
:param max_time_ms: the time limit after which the operation is aborted
|
||||
"""
|
||||
if not isinstance(max_time_ms, int) and max_time_ms is not None:
|
||||
raise TypeError("max_time_ms must be an integer or None")
|
||||
raise TypeError(f"max_time_ms must be an integer or None, not {type(max_time_ms)}")
|
||||
self._check_okay_to_chain()
|
||||
|
||||
self._max_time_ms = max_time_ms
|
||||
@ -543,7 +543,9 @@ class Cursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 3.2
|
||||
"""
|
||||
if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None:
|
||||
raise TypeError("max_await_time_ms must be an integer or None")
|
||||
raise TypeError(
|
||||
f"max_await_time_ms must be an integer or None, not {type(max_await_time_ms)}"
|
||||
)
|
||||
self._check_okay_to_chain()
|
||||
|
||||
# Ignore max_await_time_ms if not tailable or await_data is False.
|
||||
@ -677,7 +679,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
if not isinstance(spec, (list, tuple)):
|
||||
raise TypeError("spec must be an instance of list or tuple")
|
||||
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
|
||||
|
||||
self._check_okay_to_chain()
|
||||
self._max = dict(spec)
|
||||
@ -699,7 +701,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
if not isinstance(spec, (list, tuple)):
|
||||
raise TypeError("spec must be an instance of list or tuple")
|
||||
raise TypeError(f"spec must be an instance of list or tuple, not {type(spec)}")
|
||||
|
||||
self._check_okay_to_chain()
|
||||
self._min = dict(spec)
|
||||
|
||||
@ -122,7 +122,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
raise TypeError(f"name must be an instance of str, not {type(name)}")
|
||||
|
||||
if not isinstance(client, MongoClient):
|
||||
# This is for compatibility with mocked and subclassed types, such as in Motor.
|
||||
@ -1303,7 +1303,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
name = name.name
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_collection must be an instance of str")
|
||||
raise TypeError(f"name_or_collection must be an instance of str, not {type(name)}")
|
||||
encrypted_fields = self._get_encrypted_fields(
|
||||
{"encryptedFields": encrypted_fields},
|
||||
name,
|
||||
@ -1367,7 +1367,9 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
name = name.name
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_collection must be an instance of str or Collection")
|
||||
raise TypeError(
|
||||
f"name_or_collection must be an instance of str or Collection, not {type(name)}"
|
||||
)
|
||||
cmd = {"validate": name, "scandata": scandata, "full": full}
|
||||
if comment is not None:
|
||||
cmd["comment"] = comment
|
||||
|
||||
@ -320,7 +320,9 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS)
|
||||
data_key_id = raw_doc.get("_id")
|
||||
if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE:
|
||||
raise TypeError("data_key _id must be Binary with a UUID subtype")
|
||||
raise TypeError(
|
||||
f"data_key _id must be Binary with a UUID subtype, not {type(data_key_id)}"
|
||||
)
|
||||
|
||||
assert self.key_vault_coll is not None
|
||||
self.key_vault_coll.insert_one(raw_doc)
|
||||
@ -642,7 +644,9 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
)
|
||||
|
||||
if not isinstance(codec_options, CodecOptions):
|
||||
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
|
||||
raise TypeError(
|
||||
f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}"
|
||||
)
|
||||
|
||||
if not isinstance(key_vault_client, MongoClient):
|
||||
# This is for compatibility with mocked and subclassed types, such as in Motor.
|
||||
|
||||
@ -748,7 +748,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if port is None:
|
||||
port = self.PORT
|
||||
if not isinstance(port, int):
|
||||
raise TypeError("port must be an instance of int")
|
||||
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
||||
|
||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||
# customization of PyMongo, e.g. Motor.
|
||||
@ -1559,6 +1559,12 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
|
||||
self._encrypter.close()
|
||||
self._closed = True
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
|
||||
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if not _IS_SYNC:
|
||||
# Add support for contextlib.closing.
|
||||
@ -1965,7 +1971,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
The cursor is closed synchronously on the current thread.
|
||||
"""
|
||||
if not isinstance(cursor_id, int):
|
||||
raise TypeError("cursor_id must be an instance of int")
|
||||
raise TypeError(f"cursor_id must be an instance of int, not {type(cursor_id)}")
|
||||
|
||||
try:
|
||||
if conn_mgr:
|
||||
@ -2087,7 +2093,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session is not None:
|
||||
if not isinstance(session, client_session.ClientSession):
|
||||
raise ValueError("'session' argument must be a ClientSession or None.")
|
||||
raise ValueError(
|
||||
f"'session' argument must be a ClientSession or None, not {type(session)}"
|
||||
)
|
||||
# Don't call end_session.
|
||||
yield session
|
||||
return
|
||||
@ -2235,7 +2243,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
name = name.name
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name_or_database must be an instance of str or a Database")
|
||||
raise TypeError(
|
||||
f"name_or_database must be an instance of str or a Database, not {type(name)}"
|
||||
)
|
||||
|
||||
with self._conn_for_writes(session, operation=_Op.DROP_DATABASE) as conn:
|
||||
self[name]._command(
|
||||
|
||||
@ -112,9 +112,9 @@ class MonitorBase:
|
||||
"""
|
||||
self.gc_safe_close()
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
def join(self) -> None:
|
||||
"""Wait for the monitor to stop."""
|
||||
self._executor.join(timeout)
|
||||
self._executor.join()
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
@ -189,6 +189,9 @@ class Monitor(MonitorBase):
|
||||
self._rtt_monitor.gc_safe_close()
|
||||
self.cancel_check()
|
||||
|
||||
def join(self) -> None:
|
||||
asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value]
|
||||
|
||||
def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
self._rtt_monitor.close()
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
@ -61,7 +62,7 @@ from pymongo.server_selectors import (
|
||||
writable_server_selector,
|
||||
)
|
||||
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.synchronous.monitor import SrvMonitor
|
||||
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
|
||||
from pymongo.synchronous.pool import Pool
|
||||
from pymongo.synchronous.server import Server
|
||||
from pymongo.topology_description import (
|
||||
@ -207,6 +208,9 @@ class Topology:
|
||||
if self._settings.fqdn is not None and not self._settings.load_balanced:
|
||||
self._srv_monitor = SrvMonitor(self, self._settings)
|
||||
|
||||
# Stores all monitor tasks that need to be joined on close or server selection
|
||||
self._monitor_tasks: list[MonitorBase] = []
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
@ -241,6 +245,8 @@ class Topology:
|
||||
# Close servers and clear the pools.
|
||||
for server in self._servers.values():
|
||||
server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
# Reset the session pool to avoid duplicate sessions in
|
||||
# the child process.
|
||||
self._session_pool.reset()
|
||||
@ -283,6 +289,10 @@ class Topology:
|
||||
else:
|
||||
server_timeout = server_selection_timeout
|
||||
|
||||
# Cleanup any completed monitor tasks safely
|
||||
if not _IS_SYNC and self._monitor_tasks:
|
||||
self.cleanup_monitors()
|
||||
|
||||
with self._lock:
|
||||
server_descriptions = self._select_servers_loop(
|
||||
selector, server_timeout, operation, operation_id, address
|
||||
@ -520,6 +530,8 @@ class Topology:
|
||||
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
|
||||
):
|
||||
self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
# Clear the pool from a failed heartbeat.
|
||||
if reset_pool:
|
||||
@ -693,6 +705,8 @@ class Topology:
|
||||
old_td = self._description
|
||||
for server in self._servers.values():
|
||||
server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
|
||||
# Mark all servers Unknown.
|
||||
self._description = self._description.reset()
|
||||
@ -703,6 +717,8 @@ class Topology:
|
||||
# Stop SRV polling thread.
|
||||
if self._srv_monitor:
|
||||
self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
self._opened = False
|
||||
self._closed = True
|
||||
@ -942,6 +958,8 @@ class Topology:
|
||||
for address, server in list(self._servers.items()):
|
||||
if not self._description.has_server(address):
|
||||
server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
self._servers.pop(address)
|
||||
|
||||
def _create_pool_for_server(self, address: _Address) -> Pool:
|
||||
@ -1029,6 +1047,15 @@ class Topology:
|
||||
else:
|
||||
return ",".join(str(server.error) for server in servers if server.error)
|
||||
|
||||
def cleanup_monitors(self) -> None:
|
||||
tasks = []
|
||||
try:
|
||||
while self._monitor_tasks:
|
||||
tasks.append(self._monitor_tasks.pop())
|
||||
except IndexError:
|
||||
pass
|
||||
asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = ""
|
||||
if not self._opened:
|
||||
|
||||
@ -91,7 +91,7 @@ def parse_userinfo(userinfo: str) -> tuple[str, str]:
|
||||
user, _, passwd = userinfo.partition(":")
|
||||
# No password is expected with GSSAPI authentication.
|
||||
if not user:
|
||||
raise InvalidURI("The empty string is not valid username.")
|
||||
raise InvalidURI("The empty string is not valid username")
|
||||
|
||||
return unquote_plus(user), unquote_plus(passwd)
|
||||
|
||||
@ -347,7 +347,7 @@ def split_options(
|
||||
semi_idx = opts.find(";")
|
||||
try:
|
||||
if and_idx >= 0 and semi_idx >= 0:
|
||||
raise InvalidURI("Can not mix '&' and ';' for option separators.")
|
||||
raise InvalidURI("Can not mix '&' and ';' for option separators")
|
||||
elif and_idx >= 0:
|
||||
options = _parse_options(opts, "&")
|
||||
elif semi_idx >= 0:
|
||||
@ -357,7 +357,7 @@ def split_options(
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidURI("MongoDB URI options are key=value pairs.") from None
|
||||
raise InvalidURI("MongoDB URI options are key=value pairs") from None
|
||||
|
||||
options = _handle_security_options(options)
|
||||
|
||||
@ -389,7 +389,7 @@ def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[
|
||||
nodes = []
|
||||
for entity in hosts.split(","):
|
||||
if not entity:
|
||||
raise ConfigurationError("Empty host (or extra comma in host list).")
|
||||
raise ConfigurationError("Empty host (or extra comma in host list)")
|
||||
port = default_port
|
||||
# Unix socket entities don't have ports
|
||||
if entity.endswith(".sock"):
|
||||
@ -502,7 +502,7 @@ def parse_uri(
|
||||
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
|
||||
|
||||
if not scheme_free:
|
||||
raise InvalidURI("Must provide at least one hostname or IP.")
|
||||
raise InvalidURI("Must provide at least one hostname or IP")
|
||||
|
||||
user = None
|
||||
passwd = None
|
||||
|
||||
@ -74,7 +74,7 @@ class WriteConcern:
|
||||
|
||||
if wtimeout is not None:
|
||||
if not isinstance(wtimeout, int):
|
||||
raise TypeError("wtimeout must be an integer")
|
||||
raise TypeError(f"wtimeout must be an integer, not {type(wtimeout)}")
|
||||
if wtimeout < 0:
|
||||
raise ValueError("wtimeout cannot be less than 0")
|
||||
self.__document["wtimeout"] = wtimeout
|
||||
@ -98,7 +98,7 @@ class WriteConcern:
|
||||
raise ValueError("w cannot be less than 0")
|
||||
self.__acknowledged = w > 0
|
||||
elif not isinstance(w, str):
|
||||
raise TypeError("w must be an integer or string")
|
||||
raise TypeError(f"w must be an integer or string, not {type(w)}")
|
||||
self.__document["w"] = w
|
||||
|
||||
self.__server_default = not self.__document
|
||||
|
||||
123
test/__init__.py
123
test/__init__.py
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -30,6 +31,33 @@ import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.helpers import (
|
||||
COMPRESSORS,
|
||||
IS_SRV,
|
||||
@ -52,31 +80,7 @@ from test.helpers import (
|
||||
sanitize_cmd,
|
||||
sanitize_reply,
|
||||
)
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, wraps
|
||||
from test.version import Version
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -589,7 +593,7 @@ class ClientContext:
|
||||
if self.has_secondaries:
|
||||
return True
|
||||
if self.is_mongos:
|
||||
shard = self.client.config.shards.find_one()["host"] # type:ignore[index]
|
||||
shard = (self.client.config.shards.find_one())["host"] # type:ignore[index]
|
||||
num_members = shard.count(",") + 1
|
||||
return num_members > 1
|
||||
return False
|
||||
@ -863,18 +867,66 @@ class ClientContext:
|
||||
# Reusable client context
|
||||
client_context = ClientContext()
|
||||
|
||||
# Global event loop for async tests.
|
||||
LOOP = None
|
||||
|
||||
def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif client_context.client is not None:
|
||||
client_context.client.close()
|
||||
client_context.client = None
|
||||
client_context._init_client()
|
||||
|
||||
def get_loop() -> asyncio.AbstractEventLoop:
|
||||
"""Get the test suite's global event loop."""
|
||||
global LOOP
|
||||
if LOOP is None:
|
||||
try:
|
||||
LOOP = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# no running event loop, fallback to get_event_loop.
|
||||
try:
|
||||
# Ignore DeprecationWarning: There is no current event loop
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
LOOP = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
LOOP = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(LOOP)
|
||||
return LOOP
|
||||
|
||||
|
||||
class PyMongoTestCase(unittest.TestCase):
|
||||
if not _IS_SYNC:
|
||||
# An async TestCase that uses a single event loop for all tests.
|
||||
# Inspired by TestCase.
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def addCleanup(self, func, /, *args, **kwargs):
|
||||
self.addCleanup(*(func, *args), **kwargs)
|
||||
|
||||
def _callSetUp(self):
|
||||
self.setUp()
|
||||
self._callAsync(self.setUp)
|
||||
|
||||
def _callTestMethod(self, method):
|
||||
self._callMaybeAsync(method)
|
||||
|
||||
def _callTearDown(self):
|
||||
self._callAsync(self.tearDown)
|
||||
self.tearDown()
|
||||
|
||||
def _callCleanup(self, function, *args, **kwargs):
|
||||
self._callMaybeAsync(function, *args, **kwargs)
|
||||
|
||||
def _callAsync(self, func, /, *args, **kwargs):
|
||||
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
|
||||
def _callMaybeAsync(self, func, /, *args, **kwargs):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
|
||||
@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase):
|
||||
|
||||
@client_context.require_connection
|
||||
def setUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
reset_client_context()
|
||||
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
@ -1186,6 +1236,9 @@ class MockClientTest(UnitTest):
|
||||
|
||||
|
||||
def setup():
|
||||
if not _IS_SYNC:
|
||||
# Set up the event loop.
|
||||
get_loop()
|
||||
client_context.init()
|
||||
warnings.resetwarnings()
|
||||
warnings.simplefilter("always")
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -30,6 +31,33 @@ import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.helpers import (
|
||||
COMPRESSORS,
|
||||
IS_SRV,
|
||||
@ -52,31 +80,7 @@ from test.helpers import (
|
||||
sanitize_cmd,
|
||||
sanitize_reply,
|
||||
)
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from functools import partial, wraps
|
||||
from test.version import Version
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -588,10 +592,10 @@ class AsyncClientContext:
|
||||
|
||||
@property
|
||||
async def supports_secondary_read_pref(self):
|
||||
if self.has_secondaries:
|
||||
if await self.has_secondaries:
|
||||
return True
|
||||
if self.is_mongos:
|
||||
shard = await self.client.config.shards.find_one()["host"] # type:ignore[index]
|
||||
shard = (await self.client.config.shards.find_one())["host"] # type:ignore[index]
|
||||
num_members = shard.count(",") + 1
|
||||
return num_members > 1
|
||||
return False
|
||||
@ -865,18 +869,66 @@ class AsyncClientContext:
|
||||
# Reusable client context
|
||||
async_client_context = AsyncClientContext()
|
||||
|
||||
|
||||
async def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif async_client_context.client is not None:
|
||||
await async_client_context.client.close()
|
||||
async_client_context.client = None
|
||||
await async_client_context._init_client()
|
||||
# Global event loop for async tests.
|
||||
LOOP = None
|
||||
|
||||
|
||||
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
def get_loop() -> asyncio.AbstractEventLoop:
|
||||
"""Get the test suite's global event loop."""
|
||||
global LOOP
|
||||
if LOOP is None:
|
||||
try:
|
||||
LOOP = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# no running event loop, fallback to get_event_loop.
|
||||
try:
|
||||
# Ignore DeprecationWarning: There is no current event loop
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
LOOP = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
LOOP = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(LOOP)
|
||||
return LOOP
|
||||
|
||||
|
||||
class AsyncPyMongoTestCase(unittest.TestCase):
|
||||
if not _IS_SYNC:
|
||||
# An async TestCase that uses a single event loop for all tests.
|
||||
# Inspired by IsolatedAsyncioTestCase.
|
||||
async def asyncSetUp(self):
|
||||
pass
|
||||
|
||||
async def asyncTearDown(self):
|
||||
pass
|
||||
|
||||
def addAsyncCleanup(self, func, /, *args, **kwargs):
|
||||
self.addCleanup(*(func, *args), **kwargs)
|
||||
|
||||
def _callSetUp(self):
|
||||
self.setUp()
|
||||
self._callAsync(self.asyncSetUp)
|
||||
|
||||
def _callTestMethod(self, method):
|
||||
self._callMaybeAsync(method)
|
||||
|
||||
def _callTearDown(self):
|
||||
self._callAsync(self.asyncTearDown)
|
||||
self.tearDown()
|
||||
|
||||
def _callCleanup(self, function, *args, **kwargs):
|
||||
self._callMaybeAsync(function, *args, **kwargs)
|
||||
|
||||
def _callAsync(self, func, /, *args, **kwargs):
|
||||
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
|
||||
def _callMaybeAsync(self, func, /, *args, **kwargs):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
|
||||
@ -1124,15 +1176,15 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def disable_replication(self, client):
|
||||
"""Disable replication on all secondaries."""
|
||||
for h, p in client.secondaries:
|
||||
for h, p in await client.secondaries:
|
||||
secondary = await self.async_single_client(h, p)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
|
||||
await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
|
||||
|
||||
async def enable_replication(self, client):
|
||||
"""Enable replication on all secondaries."""
|
||||
for h, p in client.secondaries:
|
||||
for h, p in await client.secondaries:
|
||||
secondary = await self.async_single_client(h, p)
|
||||
secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
|
||||
await secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
|
||||
|
||||
|
||||
class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||
@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def asyncSetUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
await reset_client_context()
|
||||
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
@ -1204,6 +1254,9 @@ class AsyncMockClientTest(AsyncUnitTest):
|
||||
|
||||
|
||||
async def async_setup():
|
||||
if not _IS_SYNC:
|
||||
# Set up the event loop.
|
||||
get_loop()
|
||||
await async_client_context.init()
|
||||
warnings.resetwarnings()
|
||||
warnings.simplefilter("always")
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import multiprocessing
|
||||
@ -30,6 +31,8 @@ import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo._asyncio_task import create_task
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
@ -369,3 +372,38 @@ class SystemCertsPatcher:
|
||||
os.environ.pop("SSL_CERT_FILE")
|
||||
else:
|
||||
os.environ["SSL_CERT_FILE"] = self.original_certs
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
PARENT = threading.Thread
|
||||
else:
|
||||
PARENT = object
|
||||
|
||||
|
||||
class ConcurrentRunner(PARENT):
|
||||
def __init__(self, **kwargs):
|
||||
if _IS_SYNC:
|
||||
super().__init__(**kwargs)
|
||||
self.name = kwargs.get("name", "ConcurrentRunner")
|
||||
self.stopped = False
|
||||
self.task = None
|
||||
self.target = kwargs.get("target", None)
|
||||
self.args = kwargs.get("args", [])
|
||||
|
||||
if not _IS_SYNC:
|
||||
|
||||
async def start(self):
|
||||
self.task = create_task(self.run(), name=self.name)
|
||||
|
||||
async def join(self, timeout: float | None = 0): # type: ignore[override]
|
||||
if self.task is not None:
|
||||
await asyncio.wait([self.task], timeout=timeout)
|
||||
|
||||
def is_alive(self):
|
||||
return not self.stopped
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
await self.target(*self.args)
|
||||
finally:
|
||||
self.stopped = True
|
||||
|
||||
221
test/asynchronous/test_dns.py
Normal file
221
test/asynchronous/test_dns.py
Normal file
@ -0,0 +1,221 @@
|
||||
# Copyright 2017 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the SRV support tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
AsyncPyMongoTestCase,
|
||||
async_client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import async_wait_until
|
||||
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import parse_uri, split_hosts
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestDNSRepl(AsyncPyMongoTestCase):
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set"
|
||||
)
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set"
|
||||
)
|
||||
load_balanced = False
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
def asyncSetUp(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestDNSLoadBalanced(AsyncPyMongoTestCase):
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced"
|
||||
)
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced"
|
||||
)
|
||||
load_balanced = True
|
||||
|
||||
@async_client_context.require_load_balancer
|
||||
def asyncSetUp(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestDNSSharded(AsyncPyMongoTestCase):
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded")
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded"
|
||||
)
|
||||
load_balanced = False
|
||||
|
||||
@async_client_context.require_mongos
|
||||
def asyncSetUp(self):
|
||||
pass
|
||||
|
||||
|
||||
def create_test(test_case):
|
||||
async def run_test(self):
|
||||
uri = test_case["uri"]
|
||||
seeds = test_case.get("seeds")
|
||||
num_seeds = test_case.get("numSeeds", len(seeds or []))
|
||||
hosts = test_case.get("hosts")
|
||||
num_hosts = test_case.get("numHosts", len(hosts or []))
|
||||
|
||||
options = test_case.get("options", {})
|
||||
if "ssl" in options:
|
||||
options["tls"] = options.pop("ssl")
|
||||
parsed_options = test_case.get("parsed_options")
|
||||
# See DRIVERS-1324, unless tls is explicitly set to False we need TLS.
|
||||
needs_tls = not (options and (options.get("ssl") is False or options.get("tls") is False))
|
||||
if needs_tls and not async_client_context.tls:
|
||||
self.skipTest("this test requires a TLS cluster")
|
||||
if not needs_tls and async_client_context.tls:
|
||||
self.skipTest("this test requires a non-TLS cluster")
|
||||
|
||||
if seeds:
|
||||
seeds = split_hosts(",".join(seeds))
|
||||
if hosts:
|
||||
hosts = frozenset(split_hosts(",".join(hosts)))
|
||||
|
||||
if seeds or num_seeds:
|
||||
result = parse_uri(uri, validate=True)
|
||||
if seeds is not None:
|
||||
self.assertEqual(sorted(result["nodelist"]), sorted(seeds))
|
||||
if num_seeds is not None:
|
||||
self.assertEqual(len(result["nodelist"]), num_seeds)
|
||||
if options:
|
||||
opts = result["options"]
|
||||
if "readpreferencetags" in opts:
|
||||
rpts = validate_read_preference_tags(
|
||||
"readPreferenceTags", opts.pop("readpreferencetags")
|
||||
)
|
||||
opts["readPreferenceTags"] = rpts
|
||||
self.assertEqual(result["options"], options)
|
||||
if parsed_options:
|
||||
for opt, expected in parsed_options.items():
|
||||
if opt == "user":
|
||||
self.assertEqual(result["username"], expected)
|
||||
elif opt == "password":
|
||||
self.assertEqual(result["password"], expected)
|
||||
elif opt == "auth_database" or opt == "db":
|
||||
self.assertEqual(result["database"], expected)
|
||||
|
||||
hostname = next(iter(async_client_context.client.nodes))[0]
|
||||
# The replica set members must be configured as 'localhost'.
|
||||
if hostname == "localhost":
|
||||
copts = async_client_context.default_client_options.copy()
|
||||
# Remove tls since SRV parsing should add it automatically.
|
||||
copts.pop("tls", None)
|
||||
if async_client_context.tls:
|
||||
# Our test certs don't support the SRV hosts used in these
|
||||
# tests.
|
||||
copts["tlsAllowInvalidHostnames"] = True
|
||||
|
||||
client = self.simple_client(uri, **copts)
|
||||
if client._options.connect:
|
||||
await client.aconnect()
|
||||
if num_seeds is not None:
|
||||
self.assertEqual(len(client._topology_settings.seeds), num_seeds)
|
||||
if hosts is not None:
|
||||
await async_wait_until(
|
||||
lambda: hosts == client.nodes, "match test hosts to client nodes"
|
||||
)
|
||||
if num_hosts is not None:
|
||||
await async_wait_until(
|
||||
lambda: num_hosts == len(client.nodes), "wait to connect to num_hosts"
|
||||
)
|
||||
if test_case.get("ping", True):
|
||||
await client.admin.command("ping")
|
||||
# XXX: we should block until SRV poller runs at least once
|
||||
# and re-run these assertions.
|
||||
else:
|
||||
try:
|
||||
parse_uri(uri)
|
||||
except (ConfigurationError, ValueError):
|
||||
pass
|
||||
else:
|
||||
self.fail("failed to raise an exception")
|
||||
|
||||
return run_test
|
||||
|
||||
|
||||
def create_tests(cls):
|
||||
for filename in glob.glob(os.path.join(cls.TEST_PATH, "*.json")):
|
||||
test_suffix, _ = os.path.splitext(os.path.basename(filename))
|
||||
with open(filename) as dns_test_file:
|
||||
test_method = create_test(json.load(dns_test_file))
|
||||
setattr(cls, "test_" + test_suffix, test_method)
|
||||
|
||||
|
||||
create_tests(TestDNSRepl)
|
||||
create_tests(TestDNSLoadBalanced)
|
||||
create_tests(TestDNSSharded)
|
||||
|
||||
|
||||
class TestParsingErrors(AsyncPyMongoTestCase):
|
||||
async def test_invalid_host(self):
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb.com is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb.com",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://127.0.0.1",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://[::1]",
|
||||
)
|
||||
|
||||
|
||||
class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
|
||||
async def test_connect_case_insensitive(self):
|
||||
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
1461
test/asynchronous/test_examples.py
Normal file
1461
test/asynchronous/test_examples.py
Normal file
File diff suppressed because it is too large
Load Diff
39
test/asynchronous/test_gridfs_spec.py
Normal file
39
test/asynchronous/test_gridfs_spec.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the AsyncGridFS unified spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
97
test/asynchronous/test_heartbeat_monitoring.py
Normal file
97
test/asynchronous/test_heartbeat_monitoring.py
Normal file
@ -0,0 +1,97 @@
|
||||
# Copyright 2016-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the monitoring of the server heartbeats."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest
|
||||
from test.utils import AsyncMockPool, HeartbeatEventListener, async_wait_until
|
||||
|
||||
from pymongo.asynchronous.monitor import Monitor
|
||||
from pymongo.errors import ConnectionFailure
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestHeartbeatMonitoring(AsyncIntegrationTest):
|
||||
async def create_mock_monitor(self, responses, uri, expected_results):
|
||||
listener = HeartbeatEventListener()
|
||||
with client_knobs(
|
||||
heartbeat_frequency=0.1, min_heartbeat_interval=0.1, events_queue_frequency=0.1
|
||||
):
|
||||
|
||||
class MockMonitor(Monitor):
|
||||
async def _check_with_socket(self, *args, **kwargs):
|
||||
if isinstance(responses[1], Exception):
|
||||
raise responses[1]
|
||||
return Hello(responses[1]), 99
|
||||
|
||||
_ = await self.async_single_client(
|
||||
h=uri,
|
||||
event_listeners=(listener,),
|
||||
_monitor_class=MockMonitor,
|
||||
_pool_class=AsyncMockPool,
|
||||
connect=True,
|
||||
)
|
||||
|
||||
expected_len = len(expected_results)
|
||||
# Wait for *at least* expected_len number of results. The
|
||||
# monitor thread may run multiple times during the execution
|
||||
# of this test.
|
||||
await async_wait_until(
|
||||
lambda: len(listener.events) >= expected_len, "publish all events"
|
||||
)
|
||||
|
||||
# zip gives us len(expected_results) pairs.
|
||||
for expected, actual in zip(expected_results, listener.events):
|
||||
self.assertEqual(expected, actual.__class__.__name__)
|
||||
self.assertEqual(actual.connection_id, responses[0])
|
||||
if expected != "ServerHeartbeatStartedEvent":
|
||||
if isinstance(actual.reply, Hello):
|
||||
self.assertEqual(actual.duration, 99)
|
||||
self.assertEqual(actual.reply._doc, responses[1])
|
||||
else:
|
||||
self.assertEqual(actual.reply, responses[1])
|
||||
|
||||
async def test_standalone(self):
|
||||
responses = (
|
||||
("a", 27017),
|
||||
{HelloCompat.LEGACY_CMD: True, "maxWireVersion": 4, "minWireVersion": 0, "ok": 1},
|
||||
)
|
||||
uri = "mongodb://a:27017"
|
||||
expected_results = ["ServerHeartbeatStartedEvent", "ServerHeartbeatSucceededEvent"]
|
||||
|
||||
await self.create_mock_monitor(responses, uri, expected_results)
|
||||
|
||||
async def test_standalone_error(self):
|
||||
responses = (("a", 27017), ConnectionFailure("SPECIAL MESSAGE"))
|
||||
uri = "mongodb://a:27017"
|
||||
# _check_with_socket failing results in a second attempt.
|
||||
expected_results = [
|
||||
"ServerHeartbeatStartedEvent",
|
||||
"ServerHeartbeatFailedEvent",
|
||||
"ServerHeartbeatStartedEvent",
|
||||
"ServerHeartbeatFailedEvent",
|
||||
]
|
||||
|
||||
await self.create_mock_monitor(responses, uri, expected_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
383
test/asynchronous/test_index_management.py
Normal file
383
test/asynchronous/test_index_management.py
Normal file
@ -0,0 +1,383 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the auth spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Mapping
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
from test.utils import AllowListEventListener, OvertCommandListener
|
||||
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
pytestmark = pytest.mark.index_management
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management")
|
||||
|
||||
_NAME = "test-search-index"
|
||||
|
||||
|
||||
class TestCreateSearchIndex(AsyncIntegrationTest):
|
||||
async def test_inputs(self):
|
||||
if not os.environ.get("TEST_INDEX_MANAGEMENT"):
|
||||
raise unittest.SkipTest("Skipping index management tests")
|
||||
listener = AllowListEventListener("createSearchIndexes")
|
||||
client = self.simple_client(event_listeners=[listener])
|
||||
coll = client.test.test
|
||||
await coll.drop()
|
||||
definition = dict(mappings=dict(dynamic=True))
|
||||
model_kwarg_list: list[Mapping[str, Any]] = [
|
||||
dict(definition=definition, name=None),
|
||||
dict(definition=definition, name="test"),
|
||||
]
|
||||
for model_kwargs in model_kwarg_list:
|
||||
model = SearchIndexModel(**model_kwargs)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.create_search_index(model)
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.create_search_index(model_kwargs)
|
||||
|
||||
listener.reset()
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.create_search_index({"definition": definition, "arbitraryOption": 1})
|
||||
self.assertEqual(
|
||||
{"definition": definition, "arbitraryOption": 1},
|
||||
listener.events[0].command["indexes"][0],
|
||||
)
|
||||
|
||||
listener.reset()
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.create_search_index({"definition": definition, "type": "search"})
|
||||
self.assertEqual(
|
||||
{"definition": definition, "type": "search"}, listener.events[0].command["indexes"][0]
|
||||
)
|
||||
|
||||
|
||||
class SearchIndexIntegrationBase(AsyncPyMongoTestCase):
|
||||
db_name = "test_search_index_base"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
if not os.environ.get("TEST_INDEX_MANAGEMENT"):
|
||||
raise unittest.SkipTest("Skipping index management tests")
|
||||
cls.url = os.environ.get("MONGODB_URI")
|
||||
cls.username = os.environ["DB_USER"]
|
||||
cls.password = os.environ["DB_PASSWORD"]
|
||||
cls.listener = OvertCommandListener()
|
||||
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.client = self.simple_client(
|
||||
self.url,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
event_listeners=[self.listener],
|
||||
)
|
||||
await self.client.drop_database(_NAME)
|
||||
self.db = self.client[self.db_name]
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.client.drop_database(_NAME)
|
||||
|
||||
async def wait_for_ready(self, coll, name=_NAME, predicate=None):
|
||||
"""Wait for a search index to be ready."""
|
||||
indices: list[Mapping[str, Any]] = []
|
||||
if predicate is None:
|
||||
predicate = lambda index: index.get("queryable") is True
|
||||
|
||||
while True:
|
||||
indices = await (await coll.list_search_indexes(name)).to_list()
|
||||
if len(indices) and predicate(indices[0]):
|
||||
return indices[0]
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
class TestSearchIndexIntegration(SearchIndexIntegrationBase):
|
||||
db_name = "test_search_index"
|
||||
|
||||
async def test_comment_field(self):
|
||||
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
await coll0.insert_one({})
|
||||
|
||||
# Create a new search index on ``coll0`` that implicitly passes its type.
|
||||
search_definition = {"mappings": {"dynamic": False}}
|
||||
self.listener.reset()
|
||||
implicit_search_resp = await coll0.create_search_index(
|
||||
model={"name": _NAME + "-implicit", "definition": search_definition}, comment="foo"
|
||||
)
|
||||
event = self.listener.events[0]
|
||||
self.assertEqual(event.command["comment"], "foo")
|
||||
|
||||
# Get the index definition.
|
||||
self.listener.reset()
|
||||
await (await coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next()
|
||||
event = self.listener.events[0]
|
||||
self.assertEqual(event.command["comment"], "foo")
|
||||
|
||||
|
||||
class TestSearchIndexProse(SearchIndexIntegrationBase):
|
||||
db_name = "test_search_index_prose"
|
||||
|
||||
async def test_case_1(self):
|
||||
"""Driver can successfully create and list search indexes."""
|
||||
|
||||
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
|
||||
# Create a new search index on ``coll0`` with the ``createSearchIndex`` helper. Use the following definition:
|
||||
model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}}
|
||||
await coll0.insert_one({})
|
||||
resp = await coll0.create_search_index(model)
|
||||
|
||||
# Assert that the command returns the name of the index: ``"test-search-index"``.
|
||||
self.assertEqual(resp, _NAME)
|
||||
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``:
|
||||
# An index with the ``name`` of ``test-search-index`` is present and the index has a field ``queryable`` with a value of ``true``.
|
||||
index = await self.wait_for_ready(coll0)
|
||||
|
||||
# . Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }``
|
||||
self.assertIn("latestDefinition", index)
|
||||
self.assertEqual(index["latestDefinition"], model["definition"])
|
||||
|
||||
async def test_case_2(self):
|
||||
"""Driver can successfully create multiple indexes in batch."""
|
||||
|
||||
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
await coll0.insert_one({})
|
||||
|
||||
# Create two new search indexes on ``coll0`` with the ``createSearchIndexes`` helper.
|
||||
name1 = "test-search-index-1"
|
||||
name2 = "test-search-index-2"
|
||||
definition = {"mappings": {"dynamic": False}}
|
||||
index_definitions: list[dict[str, Any]] = [
|
||||
{"name": name1, "definition": definition},
|
||||
{"name": name2, "definition": definition},
|
||||
]
|
||||
await coll0.create_search_indexes(
|
||||
[SearchIndexModel(i["definition"], i["name"]) for i in index_definitions]
|
||||
)
|
||||
|
||||
# .Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``.
|
||||
indices = await (await coll0.list_search_indexes()).to_list()
|
||||
names = [i["name"] for i in indices]
|
||||
self.assertIn(name1, names)
|
||||
self.assertIn(name2, names)
|
||||
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied.
|
||||
# An index with the ``name`` of ``test-search-index-1`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index1``.
|
||||
# An index with the ``name`` of ``test-search-index-2`` is present and index has a field ``queryable`` with the value of ``true``. Store result in ``index2``.
|
||||
index1 = await self.wait_for_ready(coll0, name1)
|
||||
index2 = await self.wait_for_ready(coll0, name2)
|
||||
|
||||
# Assert that ``index1`` and ``index2`` have the property ``latestDefinition`` whose value is ``{ "mappings" : { "dynamic" : false } }``
|
||||
for index in [index1, index2]:
|
||||
self.assertIn("latestDefinition", index)
|
||||
self.assertEqual(index["latestDefinition"], definition)
|
||||
|
||||
async def test_case_3(self):
|
||||
"""Driver can successfully drop search indexes."""
|
||||
|
||||
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
await coll0.insert_one({})
|
||||
|
||||
# Create a new search index on ``coll0``.
|
||||
model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}}
|
||||
resp = await coll0.create_search_index(model)
|
||||
|
||||
# Assert that the command returns the name of the index: ``"test-search-index"``.
|
||||
self.assertEqual(resp, "test-search-index")
|
||||
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied:
|
||||
# An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``.
|
||||
await self.wait_for_ready(coll0)
|
||||
|
||||
# Run a ``dropSearchIndex`` on ``coll0``, using ``test-search-index`` for the name.
|
||||
await coll0.drop_search_index(_NAME)
|
||||
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array.
|
||||
t0 = time.time()
|
||||
while True:
|
||||
indices = await (await coll0.list_search_indexes()).to_list()
|
||||
if indices:
|
||||
break
|
||||
if (time.time() - t0) / 60 > 5:
|
||||
raise TimeoutError("Timed out waiting for index deletion")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def test_case_4(self):
|
||||
"""Driver can update a search index."""
|
||||
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
await coll0.insert_one({})
|
||||
|
||||
# Create a new search index on ``coll0``.
|
||||
model = {"name": _NAME, "definition": {"mappings": {"dynamic": False}}}
|
||||
resp = await coll0.create_search_index(model)
|
||||
|
||||
# Assert that the command returns the name of the index: ``"test-search-index"``.
|
||||
self.assertEqual(resp, _NAME)
|
||||
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied:
|
||||
# An index with the ``name`` of ``test-search-index`` is present and index has a field ``queryable`` with the value of ``true``.
|
||||
await self.wait_for_ready(coll0)
|
||||
|
||||
# Run a ``updateSearchIndex`` on ``coll0``.
|
||||
# Assert that the command does not error and the server responds with a success.
|
||||
model2: dict[str, Any] = {"name": _NAME, "definition": {"mappings": {"dynamic": True}}}
|
||||
await coll0.update_search_index(_NAME, model2["definition"])
|
||||
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied:
|
||||
# An index with the ``name`` of ``test-search-index`` is present. This index is referred to as ``index``.
|
||||
# The index has a field ``queryable`` with a value of ``true`` and has a field ``status`` with the value of ``READY``.
|
||||
predicate = lambda index: index.get("queryable") is True and index.get("status") == "READY"
|
||||
await self.wait_for_ready(coll0, predicate=predicate)
|
||||
|
||||
# Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``.
|
||||
index = (await (await coll0.list_search_indexes(_NAME)).to_list())[0]
|
||||
self.assertIn("latestDefinition", index)
|
||||
self.assertEqual(index["latestDefinition"], model2["definition"])
|
||||
|
||||
async def test_case_5(self):
|
||||
"""``dropSearchIndex`` suppresses namespace not found errors."""
|
||||
# Create a driver-side collection object for a randomly generated collection name. Do not create this collection on the server.
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
|
||||
# Run a ``dropSearchIndex`` command and assert that no error is thrown.
|
||||
await coll0.drop_search_index("foo")
|
||||
|
||||
async def test_case_6(self):
|
||||
"""Driver can successfully create and list search indexes with non-default readConcern and writeConcern."""
|
||||
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
await coll0.insert_one({})
|
||||
|
||||
# Apply a write concern ``WriteConcern(w=1)`` and a read concern with ``ReadConcern(level="majority")`` to ``coll0``.
|
||||
coll0 = coll0.with_options(
|
||||
write_concern=WriteConcern(w="1"), read_concern=ReadConcern(level="majority")
|
||||
)
|
||||
|
||||
# Create a new search index on ``coll0`` with the ``createSearchIndex`` helper.
|
||||
name = "test-search-index-case6"
|
||||
model = {"name": name, "definition": {"mappings": {"dynamic": False}}}
|
||||
resp = await coll0.create_search_index(model)
|
||||
|
||||
# Assert that the command returns the name of the index: ``"test-search-index-case6"``.
|
||||
self.assertEqual(resp, name)
|
||||
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until the following condition is satisfied and store the value in a variable ``index``:
|
||||
# - An index with the ``name`` of ``test-search-index-case6`` is present and the index has a field ``queryable`` with a value of ``true``.
|
||||
index = await self.wait_for_ready(coll0, name)
|
||||
|
||||
# Assert that ``index`` has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': false } }``
|
||||
self.assertIn("latestDefinition", index)
|
||||
self.assertEqual(index["latestDefinition"], model["definition"])
|
||||
|
||||
async def test_case_7(self):
|
||||
"""Driver handles index types."""
|
||||
|
||||
# Create a collection with the "create" command using a randomly generated name (referred to as ``coll0``).
|
||||
coll0 = self.db[f"col{uuid.uuid4()}"]
|
||||
await coll0.insert_one({})
|
||||
|
||||
# Use these search and vector search definitions for indexes.
|
||||
search_definition = {"mappings": {"dynamic": False}}
|
||||
vector_search_definition = {
|
||||
"fields": [
|
||||
{
|
||||
"type": "vector",
|
||||
"path": "plot_embedding",
|
||||
"numDimensions": 1536,
|
||||
"similarity": "euclidean",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
# Create a new search index on ``coll0`` that implicitly passes its type.
|
||||
implicit_search_resp = await coll0.create_search_index(
|
||||
model={"name": _NAME + "-implicit", "definition": search_definition}
|
||||
)
|
||||
|
||||
# Get the index definition.
|
||||
resp = await (await coll0.list_search_indexes(name=implicit_search_resp)).next()
|
||||
|
||||
# Assert that the index model contains the correct index type: ``"search"``.
|
||||
self.assertEqual(resp["type"], "search")
|
||||
|
||||
# Create a new search index on ``coll0`` that explicitly passes its type.
|
||||
explicit_search_resp = await coll0.create_search_index(
|
||||
model={"name": _NAME + "-explicit", "type": "search", "definition": search_definition}
|
||||
)
|
||||
|
||||
# Get the index definition.
|
||||
resp = await (await coll0.list_search_indexes(name=explicit_search_resp)).next()
|
||||
|
||||
# Assert that the index model contains the correct index type: ``"search"``.
|
||||
self.assertEqual(resp["type"], "search")
|
||||
|
||||
# Create a new vector search index on ``coll0`` that explicitly passes its type.
|
||||
explicit_vector_resp = await coll0.create_search_index(
|
||||
model={
|
||||
"name": _NAME + "-vector",
|
||||
"type": "vectorSearch",
|
||||
"definition": vector_search_definition,
|
||||
}
|
||||
)
|
||||
|
||||
# Get the index definition.
|
||||
resp = await (await coll0.list_search_indexes(name=explicit_vector_resp)).next()
|
||||
|
||||
# Assert that the index model contains the correct index type: ``"vectorSearch"``.
|
||||
self.assertEqual(resp["type"], "vectorSearch")
|
||||
|
||||
# Catch the error raised when trying to create a vector search index without specifying the type
|
||||
with self.assertRaises(OperationFailure) as e:
|
||||
await coll0.create_search_index(
|
||||
model={"name": _NAME + "-error", "definition": vector_search_definition}
|
||||
)
|
||||
self.assertIn("Attribute mappings missing.", e.exception.details["errmsg"])
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
_TEST_PATH,
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
28
test/asynchronous/test_json_util_integration.py
Normal file
28
test/asynchronous/test_json_util_integration.py
Normal file
@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest
|
||||
from typing import Any, List, MutableMapping
|
||||
|
||||
from bson import Binary, Code, DBRef, ObjectId, json_util
|
||||
from bson.binary import USER_DEFINED_SUBTYPE
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestJsonUtilRoundtrip(AsyncIntegrationTest):
|
||||
async def test_cursor(self):
|
||||
db = self.db
|
||||
|
||||
await db.drop_collection("test")
|
||||
docs: List[MutableMapping[str, Any]] = [
|
||||
{"foo": [1, 2]},
|
||||
{"bar": {"hello": "world"}},
|
||||
{"code": Code("function x() { return 1; }")},
|
||||
{"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)},
|
||||
{"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}},
|
||||
]
|
||||
|
||||
await db.test.insert_many(docs)
|
||||
reloaded_docs = json_util.loads(json_util.dumps(await (db.test.find()).to_list()))
|
||||
for doc in docs:
|
||||
self.assertTrue(doc in reloaded_docs)
|
||||
149
test/asynchronous/test_max_staleness.py
Normal file
149
test/asynchronous/test_max_staleness.py
Normal file
@ -0,0 +1,149 @@
|
||||
# Copyright 2016 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test maxStalenessSeconds support."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest
|
||||
from test.utils_selection_tests import create_selection_tests
|
||||
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness")
|
||||
|
||||
|
||||
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class TestMaxStaleness(AsyncPyMongoTestCase):
|
||||
async def test_max_staleness(self):
|
||||
client = self.simple_client()
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = self.simple_client("mongodb://a/?readPreference=secondary")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
# These tests are specified in max-staleness-tests.rst.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
# Default read pref "primary" can't be used with max staleness.
|
||||
self.simple_client("mongodb://a/?maxStalenessSeconds=120")
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
# Read pref "primary" can't be used with max staleness.
|
||||
self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120")
|
||||
|
||||
client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = self.simple_client(
|
||||
"mongodb://host/?readPreference=secondary&maxStalenessSeconds=120"
|
||||
)
|
||||
self.assertEqual(120, client.read_preference.max_staleness)
|
||||
|
||||
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1")
|
||||
self.assertEqual(1, client.read_preference.max_staleness)
|
||||
|
||||
client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest")
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
# Prohibit None.
|
||||
self.simple_client(maxStalenessSeconds=None, readPreference="nearest")
|
||||
|
||||
async def test_max_staleness_float(self):
|
||||
with self.assertRaises(TypeError) as ctx:
|
||||
await self.async_rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest")
|
||||
|
||||
self.assertIn("must be an integer", str(ctx.exception))
|
||||
|
||||
with warnings.catch_warnings(record=True) as ctx:
|
||||
warnings.simplefilter("always")
|
||||
client = self.simple_client(
|
||||
"mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest"
|
||||
)
|
||||
|
||||
# Option was ignored.
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
self.assertIn("must be an integer", str(ctx[0]))
|
||||
|
||||
async def test_max_staleness_zero(self):
|
||||
# Zero is too small.
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
await self.async_rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest")
|
||||
|
||||
self.assertIn("must be a positive integer", str(ctx.exception))
|
||||
|
||||
with warnings.catch_warnings(record=True) as ctx:
|
||||
warnings.simplefilter("always")
|
||||
client = self.simple_client(
|
||||
"mongodb://host/?maxStalenessSeconds=0&readPreference=nearest"
|
||||
)
|
||||
|
||||
# Option was ignored.
|
||||
self.assertEqual(-1, client.read_preference.max_staleness)
|
||||
self.assertIn("must be a positive integer", str(ctx[0]))
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
async def test_last_write_date(self):
|
||||
# From max-staleness-tests.rst, "Parse lastWriteDate".
|
||||
client = await self.async_rs_or_single_client(heartbeatFrequencyMS=500)
|
||||
await client.pymongo_test.test.insert_one({})
|
||||
# Wait for the server description to be updated.
|
||||
await asyncio.sleep(1)
|
||||
server = await client._topology.select_server(writable_server_selector, _Op.TEST)
|
||||
first = server.description.last_write_date
|
||||
self.assertTrue(first)
|
||||
# The first last_write_date may correspond to a internal server write,
|
||||
# sleep so that the next write does not occur within the same second.
|
||||
await asyncio.sleep(1)
|
||||
await client.pymongo_test.test.insert_one({})
|
||||
# Wait for the server description to be updated.
|
||||
await asyncio.sleep(1)
|
||||
server = await client._topology.select_server(writable_server_selector, _Op.TEST)
|
||||
second = server.description.last_write_date
|
||||
assert first is not None
|
||||
|
||||
assert second is not None
|
||||
self.assertGreater(second, first)
|
||||
self.assertLess(second, first + 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
115
test/asynchronous/test_on_demand_csfle.py
Normal file
115
test/asynchronous/test_on_demand_csfle.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright 2022-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test client side encryption with on demand credentials."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context
|
||||
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.asynchronous.encryption import (
|
||||
_HAVE_PYMONGOCRYPT,
|
||||
AsyncClientEncryption,
|
||||
EncryptionError,
|
||||
)
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
pytestmark = pytest.mark.csfle
|
||||
|
||||
|
||||
class TestonDemandGCPCredentials(AsyncIntegrationTest):
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@async_client_context.require_version_min(4, 2, -1)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.master_key = {
|
||||
"projectId": "devprod-drivers",
|
||||
"location": "global",
|
||||
"keyRing": "key-ring-csfle",
|
||||
"keyName": "key-name-csfle",
|
||||
}
|
||||
|
||||
@unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto")
|
||||
async def test_01_failure(self):
|
||||
if os.environ["SUCCESS"].lower() == "true":
|
||||
self.skipTest("Expecting success")
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
kms_providers={"gcp": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
codec_options=CodecOptions(),
|
||||
)
|
||||
with self.assertRaises(EncryptionError):
|
||||
await self.client_encryption.create_data_key("gcp", self.master_key)
|
||||
|
||||
@unittest.skipIf(not os.getenv("TEST_FLE_GCP_AUTO"), "Not testing FLE GCP auto")
|
||||
async def test_02_success(self):
|
||||
if os.environ["SUCCESS"].lower() == "false":
|
||||
self.skipTest("Expecting failure")
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
kms_providers={"gcp": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
codec_options=CodecOptions(),
|
||||
)
|
||||
await self.client_encryption.create_data_key("gcp", self.master_key)
|
||||
|
||||
|
||||
class TestonDemandAzureCredentials(AsyncIntegrationTest):
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@async_client_context.require_version_min(4, 2, -1)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.master_key = {
|
||||
"keyVaultEndpoint": os.environ["KEY_VAULT_ENDPOINT"],
|
||||
"keyName": os.environ["KEY_NAME"],
|
||||
}
|
||||
|
||||
@unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto")
|
||||
async def test_01_failure(self):
|
||||
if os.environ["SUCCESS"].lower() == "true":
|
||||
self.skipTest("Expecting success")
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
kms_providers={"azure": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
codec_options=CodecOptions(),
|
||||
)
|
||||
with self.assertRaises(EncryptionError):
|
||||
await self.client_encryption.create_data_key("azure", self.master_key)
|
||||
|
||||
@unittest.skipIf(not os.getenv("TEST_FLE_AZURE_AUTO"), "Not testing FLE Azure auto")
|
||||
async def test_02_success(self):
|
||||
if os.environ["SUCCESS"].lower() == "false":
|
||||
self.skipTest("Expecting failure")
|
||||
self.client_encryption = AsyncClientEncryption(
|
||||
kms_providers={"azure": {}},
|
||||
key_vault_namespace="keyvault.datakeys",
|
||||
key_vault_client=async_client_context.client,
|
||||
codec_options=CodecOptions(),
|
||||
)
|
||||
await self.client_encryption.create_data_key("azure", self.master_key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
122
test/asynchronous/test_read_concern.py
Normal file
122
test/asynchronous/test_read_concern.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the read_concern module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context
|
||||
from test.utils import OvertCommandListener
|
||||
|
||||
from bson.son import SON
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.read_concern import ReadConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestReadConcern(AsyncIntegrationTest):
|
||||
listener: OvertCommandListener
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.listener = OvertCommandListener()
|
||||
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
|
||||
self.db = self.client.pymongo_test
|
||||
await async_client_context.client.pymongo_test.create_collection("coll")
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await async_client_context.client.pymongo_test.drop_collection("coll")
|
||||
|
||||
def test_read_concern(self):
|
||||
rc = ReadConcern()
|
||||
self.assertIsNone(rc.level)
|
||||
self.assertTrue(rc.ok_for_legacy)
|
||||
|
||||
rc = ReadConcern("majority")
|
||||
self.assertEqual("majority", rc.level)
|
||||
self.assertFalse(rc.ok_for_legacy)
|
||||
|
||||
rc = ReadConcern("local")
|
||||
self.assertEqual("local", rc.level)
|
||||
self.assertTrue(rc.ok_for_legacy)
|
||||
|
||||
self.assertRaises(TypeError, ReadConcern, 42)
|
||||
|
||||
async def test_read_concern_uri(self):
|
||||
uri = f"mongodb://{await async_client_context.pair}/?readConcernLevel=majority"
|
||||
client = await self.async_rs_or_single_client(uri, connect=False)
|
||||
self.assertEqual(ReadConcern("majority"), client.read_concern)
|
||||
|
||||
async def test_invalid_read_concern(self):
|
||||
coll = self.db.get_collection("coll", read_concern=ReadConcern("unknown"))
|
||||
# We rely on the server to validate read concern.
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.find_one()
|
||||
|
||||
async def test_find_command(self):
|
||||
# readConcern not sent in command if not specified.
|
||||
coll = self.db.coll
|
||||
await coll.find({"field": "value"}).to_list()
|
||||
self.assertNotIn("readConcern", self.listener.started_events[0].command)
|
||||
|
||||
self.listener.reset()
|
||||
|
||||
# Explicitly set readConcern to 'local'.
|
||||
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
|
||||
await coll.find({"field": "value"}).to_list()
|
||||
self.assertEqualCommand(
|
||||
SON(
|
||||
[
|
||||
("find", "coll"),
|
||||
("filter", {"field": "value"}),
|
||||
("readConcern", {"level": "local"}),
|
||||
]
|
||||
),
|
||||
self.listener.started_events[0].command,
|
||||
)
|
||||
|
||||
async def test_command_cursor(self):
|
||||
# readConcern not sent in command if not specified.
|
||||
coll = self.db.coll
|
||||
await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list()
|
||||
self.assertNotIn("readConcern", self.listener.started_events[0].command)
|
||||
|
||||
self.listener.reset()
|
||||
|
||||
# Explicitly set readConcern to 'local'.
|
||||
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
|
||||
await (await coll.aggregate([{"$match": {"field": "value"}}])).to_list()
|
||||
self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"])
|
||||
|
||||
async def test_aggregate_out(self):
|
||||
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
|
||||
await (
|
||||
await coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])
|
||||
).to_list()
|
||||
|
||||
# Aggregate with $out supports readConcern MongoDB 4.2 onwards.
|
||||
if async_client_context.version >= (4, 1):
|
||||
self.assertIn("readConcern", self.listener.started_events[0].command)
|
||||
else:
|
||||
self.assertNotIn("readConcern", self.listener.started_events[0].command)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
730
test/asynchronous/test_read_preferences.py
Normal file
730
test/asynchronous/test_read_preferences.py
Normal file
@ -0,0 +1,730 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the replica_set_connection module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import (
|
||||
AsyncIntegrationTest,
|
||||
SkipTest,
|
||||
async_client_context,
|
||||
connected,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
async_wait_until,
|
||||
one,
|
||||
)
|
||||
from test.version import Version
|
||||
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.message import _maybe_add_read_preference
|
||||
from pymongo.read_preferences import (
|
||||
MovingAverage,
|
||||
Nearest,
|
||||
Primary,
|
||||
PrimaryPreferred,
|
||||
ReadPreference,
|
||||
Secondary,
|
||||
SecondaryPreferred,
|
||||
)
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import Selection, readable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestSelections(AsyncIntegrationTest):
|
||||
@async_client_context.require_connection
|
||||
async def test_bool(self):
|
||||
client = await self.async_single_client()
|
||||
|
||||
async def predicate():
|
||||
return await client.address
|
||||
|
||||
await async_wait_until(predicate, "discover primary")
|
||||
selection = Selection.from_topology_description(client._topology.description)
|
||||
|
||||
self.assertTrue(selection)
|
||||
self.assertFalse(selection.with_server_descriptions([]))
|
||||
|
||||
|
||||
class TestReadPreferenceObjects(unittest.TestCase):
|
||||
prefs = [
|
||||
Primary(),
|
||||
PrimaryPreferred(),
|
||||
Secondary(),
|
||||
Nearest(tag_sets=[{"a": 1}, {"b": 2}]),
|
||||
SecondaryPreferred(max_staleness=30),
|
||||
]
|
||||
|
||||
def test_pickle(self):
|
||||
for pref in self.prefs:
|
||||
self.assertEqual(pref, pickle.loads(pickle.dumps(pref)))
|
||||
|
||||
def test_copy(self):
|
||||
for pref in self.prefs:
|
||||
self.assertEqual(pref, copy.copy(pref))
|
||||
|
||||
def test_deepcopy(self):
|
||||
for pref in self.prefs:
|
||||
self.assertEqual(pref, copy.deepcopy(pref))
|
||||
|
||||
|
||||
class TestReadPreferencesBase(AsyncIntegrationTest):
|
||||
@async_client_context.require_secondaries_count(1)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
# Insert some data so we can use cursors in read_from_which_host
|
||||
await self.client.pymongo_test.test.drop()
|
||||
await self.client.get_database(
|
||||
"pymongo_test", write_concern=WriteConcern(w=async_client_context.w)
|
||||
).test.insert_many([{"_id": i} for i in range(10)])
|
||||
|
||||
self.addAsyncCleanup(self.client.pymongo_test.test.drop)
|
||||
|
||||
async def read_from_which_host(self, client):
|
||||
"""Do a find() on the client and return which host was used"""
|
||||
cursor = client.pymongo_test.test.find()
|
||||
await anext(cursor)
|
||||
return cursor.address
|
||||
|
||||
async def read_from_which_kind(self, client):
|
||||
"""Do a find() on the client and return 'primary' or 'secondary'
|
||||
depending on which the client used.
|
||||
"""
|
||||
address = await self.read_from_which_host(client)
|
||||
if address == await client.primary:
|
||||
return "primary"
|
||||
elif address in await client.secondaries:
|
||||
return "secondary"
|
||||
else:
|
||||
self.fail(
|
||||
f"Cursor used address {address}, expected either primary "
|
||||
f"{client.primary} or secondaries {client.secondaries}"
|
||||
)
|
||||
|
||||
async def assertReadsFrom(self, expected, **kwargs):
|
||||
c = await self.async_rs_client(**kwargs)
|
||||
|
||||
async def predicate():
|
||||
return len(c.nodes - await c.arbiters) == async_client_context.w
|
||||
|
||||
await async_wait_until(predicate, "discovered all nodes")
|
||||
|
||||
used = await self.read_from_which_kind(c)
|
||||
self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}")
|
||||
|
||||
|
||||
class TestSingleSecondaryOk(TestReadPreferencesBase):
|
||||
async def test_reads_from_secondary(self):
|
||||
host, port = next(iter(await self.client.secondaries))
|
||||
# Direct connection to a secondary.
|
||||
client = await self.async_single_client(host, port)
|
||||
self.assertFalse(await client.is_primary)
|
||||
|
||||
# Regardless of read preference, we should be able to do
|
||||
# "reads" with a direct connection to a secondary.
|
||||
# See server-selection.rst#topology-type-single.
|
||||
self.assertEqual(client.read_preference, ReadPreference.PRIMARY)
|
||||
|
||||
db = client.pymongo_test
|
||||
coll = db.test
|
||||
|
||||
# Test find and find_one.
|
||||
self.assertIsNotNone(await coll.find_one())
|
||||
self.assertEqual(10, len(await coll.find().to_list()))
|
||||
|
||||
# Test some database helpers.
|
||||
self.assertIsNotNone(await db.list_collection_names())
|
||||
self.assertIsNotNone(await db.validate_collection("test"))
|
||||
self.assertIsNotNone(await db.command("ping"))
|
||||
|
||||
# Test some collection helpers.
|
||||
self.assertEqual(10, await coll.count_documents({}))
|
||||
self.assertEqual(10, len(await coll.distinct("_id")))
|
||||
self.assertIsNotNone(await coll.aggregate([]))
|
||||
self.assertIsNotNone(await coll.index_information())
|
||||
|
||||
|
||||
class TestReadPreferences(TestReadPreferencesBase):
|
||||
async def test_mode_validation(self):
|
||||
for mode in (
|
||||
ReadPreference.PRIMARY,
|
||||
ReadPreference.PRIMARY_PREFERRED,
|
||||
ReadPreference.SECONDARY,
|
||||
ReadPreference.SECONDARY_PREFERRED,
|
||||
ReadPreference.NEAREST,
|
||||
):
|
||||
self.assertEqual(
|
||||
mode, (await self.async_rs_client(read_preference=mode)).read_preference
|
||||
)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
await self.async_rs_client(read_preference="foo")
|
||||
|
||||
async def test_tag_sets_validation(self):
|
||||
S = Secondary(tag_sets=[{}])
|
||||
self.assertEqual(
|
||||
[{}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets
|
||||
)
|
||||
|
||||
S = Secondary(tag_sets=[{"k": "v"}])
|
||||
self.assertEqual(
|
||||
[{"k": "v"}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets
|
||||
)
|
||||
|
||||
S = Secondary(tag_sets=[{"k": "v"}, {}])
|
||||
self.assertEqual(
|
||||
[{"k": "v"}, {}],
|
||||
(await self.async_rs_client(read_preference=S)).read_preference.tag_sets,
|
||||
)
|
||||
|
||||
self.assertRaises(ValueError, Secondary, tag_sets=[])
|
||||
|
||||
# One dict not ok, must be a list of dicts
|
||||
self.assertRaises(TypeError, Secondary, tag_sets={"k": "v"})
|
||||
|
||||
self.assertRaises(TypeError, Secondary, tag_sets="foo")
|
||||
|
||||
self.assertRaises(TypeError, Secondary, tag_sets=["foo"])
|
||||
|
||||
async def test_threshold_validation(self):
|
||||
self.assertEqual(
|
||||
17,
|
||||
(
|
||||
await self.async_rs_client(localThresholdMS=17, connect=False)
|
||||
).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
42,
|
||||
(
|
||||
await self.async_rs_client(localThresholdMS=42, connect=False)
|
||||
).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
666,
|
||||
(
|
||||
await self.async_rs_client(localThresholdMS=666, connect=False)
|
||||
).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
0,
|
||||
(
|
||||
await self.async_rs_client(localThresholdMS=0, connect=False)
|
||||
).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
await self.async_rs_client(localthresholdms=-1)
|
||||
|
||||
async def test_zero_latency(self):
|
||||
ping_times: set = set()
|
||||
# Generate unique ping times.
|
||||
while len(ping_times) < len(self.client.nodes):
|
||||
ping_times.add(random.random())
|
||||
for ping_time, host in zip(ping_times, self.client.nodes):
|
||||
ServerDescription._host_to_round_trip_time[host] = ping_time
|
||||
try:
|
||||
client = await connected(
|
||||
await self.async_rs_client(readPreference="nearest", localThresholdMS=0)
|
||||
)
|
||||
await async_wait_until(
|
||||
lambda: client.nodes == self.client.nodes, "discovered all nodes"
|
||||
)
|
||||
host = await self.read_from_which_host(client)
|
||||
for _ in range(5):
|
||||
self.assertEqual(host, await self.read_from_which_host(client))
|
||||
finally:
|
||||
ServerDescription._host_to_round_trip_time.clear()
|
||||
|
||||
async def test_primary(self):
|
||||
await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY)
|
||||
|
||||
async def test_primary_with_tags(self):
|
||||
# Tags not allowed with PRIMARY
|
||||
with self.assertRaises(ConfigurationError):
|
||||
await self.async_rs_client(tag_sets=[{"dc": "ny"}])
|
||||
|
||||
async def test_primary_preferred(self):
|
||||
await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED)
|
||||
|
||||
async def test_secondary(self):
|
||||
await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY)
|
||||
|
||||
async def test_secondary_preferred(self):
|
||||
await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY_PREFERRED)
|
||||
|
||||
async def test_nearest(self):
|
||||
# With high localThresholdMS, expect to read from any
|
||||
# member
|
||||
c = await self.async_rs_client(
|
||||
read_preference=ReadPreference.NEAREST, localThresholdMS=10000
|
||||
) # 10 seconds
|
||||
|
||||
data_members = {await self.client.primary} | await self.client.secondaries
|
||||
|
||||
# This is a probabilistic test; track which members we've read from so
|
||||
# far, and keep reading until we've used all the members or give up.
|
||||
# Chance of using only 2 of 3 members 10k times if there's no bug =
|
||||
# 3 * (2/3)**10000, very low.
|
||||
used: set = set()
|
||||
i = 0
|
||||
while data_members.difference(used) and i < 10000:
|
||||
address = await self.read_from_which_host(c)
|
||||
used.add(address)
|
||||
i += 1
|
||||
|
||||
not_used = data_members.difference(used)
|
||||
latencies = ", ".join(
|
||||
"%s: %sms" % (server.description.address, server.description.round_trip_time)
|
||||
for server in await (await c._get_topology()).select_servers(
|
||||
readable_server_selector, _Op.TEST
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(
|
||||
not_used,
|
||||
"Expected to use primary and all secondaries for mode NEAREST,"
|
||||
f" but didn't use {not_used}\nlatencies: {latencies}",
|
||||
)
|
||||
|
||||
|
||||
class ReadPrefTester(AsyncMongoClient):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.has_read_from = set()
|
||||
client_options = async_client_context.client_options
|
||||
client_options.update(kwargs)
|
||||
super().__init__(*args, **client_options)
|
||||
|
||||
async def _conn_for_reads(self, read_preference, session, operation):
|
||||
context = await super()._conn_for_reads(read_preference, session, operation)
|
||||
return context
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _conn_from_server(self, read_preference, server, session):
|
||||
context = super()._conn_from_server(read_preference, server, session)
|
||||
async with context as (conn, read_preference):
|
||||
await self.record_a_read(conn.address)
|
||||
yield conn, read_preference
|
||||
|
||||
async def record_a_read(self, address):
|
||||
server = await (await self._get_topology()).select_server_by_address(address, _Op.TEST, 0)
|
||||
self.has_read_from.add(server)
|
||||
|
||||
|
||||
_PREF_MAP = [
|
||||
(Primary, SERVER_TYPE.RSPrimary),
|
||||
(PrimaryPreferred, SERVER_TYPE.RSPrimary),
|
||||
(Secondary, SERVER_TYPE.RSSecondary),
|
||||
(SecondaryPreferred, SERVER_TYPE.RSSecondary),
|
||||
(Nearest, "any"),
|
||||
]
|
||||
|
||||
|
||||
class TestCommandAndReadPreference(AsyncIntegrationTest):
|
||||
c: ReadPrefTester
|
||||
client_version: Version
|
||||
|
||||
@async_client_context.require_secondaries_count(1)
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.c = ReadPrefTester(
|
||||
# Ignore round trip times, to test ReadPreference modes only.
|
||||
localThresholdMS=1000 * 1000,
|
||||
)
|
||||
self.client_version = await Version.async_from_client(self.c)
|
||||
# mapReduce fails if the collection does not exist.
|
||||
coll = self.c.pymongo_test.get_collection(
|
||||
"test", write_concern=WriteConcern(w=async_client_context.w)
|
||||
)
|
||||
await coll.insert_one({})
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.c.drop_database("pymongo_test")
|
||||
await self.c.close()
|
||||
|
||||
async def executed_on_which_server(self, client, fn, *args, **kwargs):
|
||||
"""Execute fn(*args, **kwargs) and return the Server instance used."""
|
||||
client.has_read_from.clear()
|
||||
await fn(*args, **kwargs)
|
||||
self.assertEqual(1, len(client.has_read_from))
|
||||
return one(client.has_read_from)
|
||||
|
||||
async def assertExecutedOn(self, server_type, client, fn, *args, **kwargs):
|
||||
server = await self.executed_on_which_server(client, fn, *args, **kwargs)
|
||||
self.assertEqual(
|
||||
SERVER_TYPE._fields[server_type], SERVER_TYPE._fields[server.description.server_type]
|
||||
)
|
||||
|
||||
async def _test_fn(self, server_type, fn):
|
||||
for _ in range(10):
|
||||
if server_type == "any":
|
||||
used = set()
|
||||
for _ in range(1000):
|
||||
server = await self.executed_on_which_server(self.c, fn)
|
||||
used.add(server.description.address)
|
||||
if len(used) == len(await self.c.secondaries) + 1:
|
||||
# Success
|
||||
break
|
||||
|
||||
assert await self.c.primary is not None
|
||||
unused = (await self.c.secondaries).union({await self.c.primary}).difference(used)
|
||||
if unused:
|
||||
self.fail("Some members not used for NEAREST: %s" % (unused))
|
||||
else:
|
||||
await self.assertExecutedOn(server_type, self.c, fn)
|
||||
|
||||
async def _test_primary_helper(self, func):
|
||||
# Helpers that ignore read preference.
|
||||
await self._test_fn(SERVER_TYPE.RSPrimary, func)
|
||||
|
||||
async def _test_coll_helper(self, secondary_ok, coll, meth, *args, **kwargs):
|
||||
for mode, server_type in _PREF_MAP:
|
||||
new_coll = coll.with_options(read_preference=mode())
|
||||
|
||||
async def func():
|
||||
return await getattr(new_coll, meth)(*args, **kwargs)
|
||||
|
||||
if secondary_ok:
|
||||
await self._test_fn(server_type, func)
|
||||
else:
|
||||
await self._test_fn(SERVER_TYPE.RSPrimary, func)
|
||||
|
||||
async def test_command(self):
|
||||
# Test that the generic command helper obeys the read preference
|
||||
# passed to it.
|
||||
for mode, server_type in _PREF_MAP:
|
||||
|
||||
async def func():
|
||||
return await self.c.pymongo_test.command("dbStats", read_preference=mode())
|
||||
|
||||
await self._test_fn(server_type, func)
|
||||
|
||||
async def test_create_collection(self):
|
||||
# create_collection runs listCollections on the primary to check if
|
||||
# the collection already exists.
|
||||
async def func():
|
||||
return await self.c.pymongo_test.create_collection(
|
||||
"some_collection%s" % random.randint(0, sys.maxsize)
|
||||
)
|
||||
|
||||
await self._test_primary_helper(func)
|
||||
|
||||
async def test_count_documents(self):
|
||||
await self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {})
|
||||
|
||||
async def test_estimated_document_count(self):
|
||||
await self._test_coll_helper(True, self.c.pymongo_test.test, "estimated_document_count")
|
||||
|
||||
async def test_distinct(self):
|
||||
await self._test_coll_helper(True, self.c.pymongo_test.test, "distinct", "a")
|
||||
|
||||
async def test_aggregate(self):
|
||||
await self._test_coll_helper(
|
||||
True, self.c.pymongo_test.test, "aggregate", [{"$project": {"_id": 1}}]
|
||||
)
|
||||
|
||||
async def test_aggregate_write(self):
|
||||
# 5.0 servers support $out on secondaries.
|
||||
secondary_ok = async_client_context.version.at_least(5, 0)
|
||||
await self._test_coll_helper(
|
||||
secondary_ok,
|
||||
self.c.pymongo_test.test,
|
||||
"aggregate",
|
||||
[{"$project": {"_id": 1}}, {"$out": "agg_write_test"}],
|
||||
)
|
||||
|
||||
|
||||
class TestMovingAverage(unittest.TestCase):
|
||||
def test_moving_average(self):
|
||||
avg = MovingAverage()
|
||||
self.assertIsNone(avg.get())
|
||||
avg.add_sample(10)
|
||||
self.assertAlmostEqual(10, avg.get()) # type: ignore
|
||||
avg.add_sample(20)
|
||||
self.assertAlmostEqual(12, avg.get()) # type: ignore
|
||||
avg.add_sample(30)
|
||||
self.assertAlmostEqual(15.6, avg.get()) # type: ignore
|
||||
|
||||
|
||||
class TestMongosAndReadPreference(AsyncIntegrationTest):
|
||||
def test_read_preference_document(self):
|
||||
pref = Primary()
|
||||
self.assertEqual(pref.document, {"mode": "primary"})
|
||||
|
||||
pref = PrimaryPreferred()
|
||||
self.assertEqual(pref.document, {"mode": "primaryPreferred"})
|
||||
pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}])
|
||||
self.assertEqual(pref.document, {"mode": "primaryPreferred", "tags": [{"dc": "sf"}]})
|
||||
pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30)
|
||||
self.assertEqual(
|
||||
pref.document,
|
||||
{"mode": "primaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30},
|
||||
)
|
||||
|
||||
pref = Secondary()
|
||||
self.assertEqual(pref.document, {"mode": "secondary"})
|
||||
pref = Secondary(tag_sets=[{"dc": "sf"}])
|
||||
self.assertEqual(pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}]})
|
||||
pref = Secondary(tag_sets=[{"dc": "sf"}], max_staleness=30)
|
||||
self.assertEqual(
|
||||
pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}
|
||||
)
|
||||
|
||||
pref = SecondaryPreferred()
|
||||
self.assertEqual(pref.document, {"mode": "secondaryPreferred"})
|
||||
pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}])
|
||||
self.assertEqual(pref.document, {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}]})
|
||||
pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30)
|
||||
self.assertEqual(
|
||||
pref.document,
|
||||
{"mode": "secondaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30},
|
||||
)
|
||||
|
||||
pref = Nearest()
|
||||
self.assertEqual(pref.document, {"mode": "nearest"})
|
||||
pref = Nearest(tag_sets=[{"dc": "sf"}])
|
||||
self.assertEqual(pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}]})
|
||||
pref = Nearest(tag_sets=[{"dc": "sf"}], max_staleness=30)
|
||||
self.assertEqual(
|
||||
pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}
|
||||
)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
# Float is prohibited.
|
||||
Nearest(max_staleness=1.5) # type: ignore
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
Nearest(max_staleness=0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
Nearest(max_staleness=-2)
|
||||
|
||||
def test_read_preference_document_hedge(self):
|
||||
cases = {
|
||||
"primaryPreferred": PrimaryPreferred,
|
||||
"secondary": Secondary,
|
||||
"secondaryPreferred": SecondaryPreferred,
|
||||
"nearest": Nearest,
|
||||
}
|
||||
for mode, cls in cases.items():
|
||||
with self.assertRaises(TypeError):
|
||||
cls(hedge=[]) # type: ignore
|
||||
|
||||
pref = cls(hedge={})
|
||||
self.assertEqual(pref.document, {"mode": mode})
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
if cls == SecondaryPreferred:
|
||||
# SecondaryPreferred without hedge doesn't add $readPreference.
|
||||
self.assertEqual(out, {})
|
||||
else:
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
hedge: dict[str, Any] = {"enabled": True}
|
||||
pref = cls(hedge=hedge)
|
||||
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
hedge = {"enabled": False}
|
||||
pref = cls(hedge=hedge)
|
||||
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
hedge = {"enabled": False, "extra": "option"}
|
||||
pref = cls(hedge=hedge)
|
||||
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
async def test_send_hedge(self):
|
||||
cases = {
|
||||
"primaryPreferred": PrimaryPreferred,
|
||||
"secondaryPreferred": SecondaryPreferred,
|
||||
"nearest": Nearest,
|
||||
}
|
||||
if await async_client_context.supports_secondary_read_pref:
|
||||
cases["secondary"] = Secondary
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_client(event_listeners=[listener])
|
||||
await client.admin.command("ping")
|
||||
for _mode, cls in cases.items():
|
||||
pref = cls(hedge={"enabled": True})
|
||||
coll = client.test.get_collection("test", read_preference=pref)
|
||||
listener.reset()
|
||||
await coll.find_one()
|
||||
started = listener.started_events
|
||||
self.assertEqual(len(started), 1, started)
|
||||
cmd = started[0].command
|
||||
if async_client_context.is_rs or async_client_context.is_mongos:
|
||||
self.assertIn("$readPreference", cmd)
|
||||
self.assertEqual(cmd["$readPreference"], pref.document)
|
||||
else:
|
||||
self.assertNotIn("$readPreference", cmd)
|
||||
|
||||
def test_maybe_add_read_preference(self):
|
||||
# Primary doesn't add $readPreference
|
||||
out = _maybe_add_read_preference({}, Primary())
|
||||
self.assertEqual(out, {})
|
||||
|
||||
pref = PrimaryPreferred()
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
pref = PrimaryPreferred(tag_sets=[{"dc": "nyc"}])
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
pref = Secondary()
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
pref = Secondary(tag_sets=[{"dc": "nyc"}])
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
# SecondaryPreferred without tag_sets or max_staleness doesn't add
|
||||
# $readPreference
|
||||
pref = SecondaryPreferred()
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, {})
|
||||
pref = SecondaryPreferred(tag_sets=[{"dc": "nyc"}])
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
pref = SecondaryPreferred(max_staleness=120)
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
pref = Nearest()
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
pref = Nearest(tag_sets=[{"dc": "nyc"}])
|
||||
out = _maybe_add_read_preference({}, pref)
|
||||
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
|
||||
|
||||
criteria = SON([("$query", {}), ("$orderby", SON([("_id", 1)]))])
|
||||
pref = Nearest()
|
||||
out = _maybe_add_read_preference(criteria, pref)
|
||||
self.assertEqual(
|
||||
out,
|
||||
SON(
|
||||
[
|
||||
("$query", {}),
|
||||
("$orderby", SON([("_id", 1)])),
|
||||
("$readPreference", pref.document),
|
||||
]
|
||||
),
|
||||
)
|
||||
pref = Nearest(tag_sets=[{"dc": "nyc"}])
|
||||
out = _maybe_add_read_preference(criteria, pref)
|
||||
self.assertEqual(
|
||||
out,
|
||||
SON(
|
||||
[
|
||||
("$query", {}),
|
||||
("$orderby", SON([("_id", 1)])),
|
||||
("$readPreference", pref.document),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
@async_client_context.require_mongos
|
||||
async def test_mongos(self):
|
||||
res = await async_client_context.client.config.shards.find_one()
|
||||
assert res is not None
|
||||
shard = res["host"]
|
||||
num_members = shard.count(",") + 1
|
||||
if num_members == 1:
|
||||
raise SkipTest("Need a replica set shard to test.")
|
||||
coll = async_client_context.client.pymongo_test.get_collection(
|
||||
"test", write_concern=WriteConcern(w=num_members)
|
||||
)
|
||||
await coll.drop()
|
||||
res = await coll.insert_many([{} for _ in range(5)])
|
||||
first_id = res.inserted_ids[0]
|
||||
last_id = res.inserted_ids[-1]
|
||||
|
||||
# Note - this isn't a perfect test since there's no way to
|
||||
# tell what shard member a query ran on.
|
||||
for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()):
|
||||
qcoll = coll.with_options(read_preference=pref)
|
||||
results = await qcoll.find().sort([("_id", 1)]).to_list()
|
||||
self.assertEqual(first_id, results[0]["_id"])
|
||||
self.assertEqual(last_id, results[-1]["_id"])
|
||||
results = await qcoll.find().sort([("_id", -1)]).to_list()
|
||||
self.assertEqual(first_id, results[-1]["_id"])
|
||||
self.assertEqual(last_id, results[0]["_id"])
|
||||
|
||||
@async_client_context.require_mongos
|
||||
async def test_mongos_max_staleness(self):
|
||||
# Sanity check that we're sending maxStalenessSeconds
|
||||
coll = async_client_context.client.pymongo_test.get_collection(
|
||||
"test", read_preference=SecondaryPreferred(max_staleness=120)
|
||||
)
|
||||
# No error
|
||||
await coll.find_one()
|
||||
|
||||
coll = async_client_context.client.pymongo_test.get_collection(
|
||||
"test", read_preference=SecondaryPreferred(max_staleness=10)
|
||||
)
|
||||
try:
|
||||
await coll.find_one()
|
||||
except OperationFailure as exc:
|
||||
self.assertEqual(160, exc.code)
|
||||
else:
|
||||
self.fail("mongos accepted invalid staleness")
|
||||
|
||||
coll = (
|
||||
await self.async_single_client(
|
||||
readPreference="secondaryPreferred", maxStalenessSeconds=120
|
||||
)
|
||||
).pymongo_test.test
|
||||
# No error
|
||||
await coll.find_one()
|
||||
|
||||
coll = (
|
||||
await self.async_single_client(
|
||||
readPreference="secondaryPreferred", maxStalenessSeconds=10
|
||||
)
|
||||
).pymongo_test.test
|
||||
try:
|
||||
await coll.find_one()
|
||||
except OperationFailure as exc:
|
||||
self.assertEqual(160, exc.code)
|
||||
else:
|
||||
self.fail("mongos accepted invalid staleness")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
344
test/asynchronous/test_read_write_concern_spec.py
Normal file
344
test/asynchronous/test_read_write_concern_spec.py
Normal file
@ -0,0 +1,344 @@
|
||||
# Copyright 2018-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the read and write concern tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
from test.utils import OvertCommandListener
|
||||
|
||||
from pymongo import DESCENDING
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.errors import (
|
||||
BulkWriteError,
|
||||
ConfigurationError,
|
||||
WriteConcernError,
|
||||
WriteError,
|
||||
WTimeoutError,
|
||||
)
|
||||
from pymongo.operations import IndexModel, InsertOne
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern")
|
||||
|
||||
|
||||
class TestReadWriteConcernSpec(AsyncIntegrationTest):
|
||||
async def test_omit_default_read_write_concern(self):
|
||||
listener = OvertCommandListener()
|
||||
# Client with default readConcern and writeConcern
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
collection = client.pymongo_test.collection
|
||||
# Prepare for tests of find() and aggregate().
|
||||
await collection.insert_many([{} for _ in range(10)])
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
self.addAsyncCleanup(client.pymongo_test.collection2.drop)
|
||||
# Commands MUST NOT send the default read/write concern to the server.
|
||||
|
||||
async def rename_and_drop():
|
||||
# Ensure collection exists.
|
||||
await collection.insert_one({})
|
||||
await collection.rename("collection2")
|
||||
await client.pymongo_test.collection2.drop()
|
||||
|
||||
async def insert_command_default_write_concern():
|
||||
await collection.database.command(
|
||||
"insert", "collection", documents=[{}], write_concern=WriteConcern()
|
||||
)
|
||||
|
||||
async def aggregate_op():
|
||||
await (await collection.aggregate([])).to_list()
|
||||
|
||||
ops = [
|
||||
("aggregate", aggregate_op),
|
||||
("find", lambda: collection.find().to_list()),
|
||||
("insert_one", lambda: collection.insert_one({})),
|
||||
("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})),
|
||||
("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})),
|
||||
("delete_one", lambda: collection.delete_one({})),
|
||||
("delete_many", lambda: collection.delete_many({})),
|
||||
("bulk_write", lambda: collection.bulk_write([InsertOne({})])),
|
||||
("rename_and_drop", rename_and_drop),
|
||||
("command", insert_command_default_write_concern),
|
||||
]
|
||||
|
||||
for name, f in ops:
|
||||
listener.reset()
|
||||
await f()
|
||||
|
||||
self.assertGreaterEqual(len(listener.started_events), 1)
|
||||
for _i, event in enumerate(listener.started_events):
|
||||
self.assertNotIn(
|
||||
"readConcern",
|
||||
event.command,
|
||||
f"{name} sent default readConcern with {event.command_name}",
|
||||
)
|
||||
self.assertNotIn(
|
||||
"writeConcern",
|
||||
event.command,
|
||||
f"{name} sent default writeConcern with {event.command_name}",
|
||||
)
|
||||
|
||||
async def assertWriteOpsRaise(self, write_concern, expected_exception):
|
||||
wc = write_concern.document
|
||||
# Set socket timeout to avoid indefinite stalls
|
||||
client = await self.async_rs_or_single_client(
|
||||
w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000
|
||||
)
|
||||
db = client.get_database("pymongo_test")
|
||||
coll = db.test
|
||||
|
||||
async def insert_command():
|
||||
await coll.database.command(
|
||||
"insert",
|
||||
"new_collection",
|
||||
documents=[{}],
|
||||
writeConcern=write_concern.document,
|
||||
parse_write_concern_error=True,
|
||||
)
|
||||
|
||||
ops = [
|
||||
("insert_one", lambda: coll.insert_one({})),
|
||||
("insert_many", lambda: coll.insert_many([{}, {}])),
|
||||
("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})),
|
||||
("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})),
|
||||
("delete_one", lambda: coll.delete_one({})),
|
||||
("delete_many", lambda: coll.delete_many({})),
|
||||
("bulk_write", lambda: coll.bulk_write([InsertOne({})])),
|
||||
("command", insert_command),
|
||||
("aggregate", lambda: coll.aggregate([{"$out": "out"}])),
|
||||
# SERVER-46668 Delete all the documents in the collection to
|
||||
# workaround a hang in createIndexes.
|
||||
("delete_many", lambda: coll.delete_many({})),
|
||||
("create_index", lambda: coll.create_index([("a", DESCENDING)])),
|
||||
("create_indexes", lambda: coll.create_indexes([IndexModel("b")])),
|
||||
("drop_index", lambda: coll.drop_index([("a", DESCENDING)])),
|
||||
("create", lambda: db.create_collection("new")),
|
||||
("rename", lambda: coll.rename("new")),
|
||||
("drop", lambda: db.new.drop()),
|
||||
]
|
||||
# SERVER-47194: dropDatabase does not respect wtimeout in 3.6.
|
||||
if async_client_context.version[:2] != (3, 6):
|
||||
ops.append(("drop_database", lambda: client.drop_database(db)))
|
||||
|
||||
for name, f in ops:
|
||||
# Ensure insert_many and bulk_write still raise BulkWriteError.
|
||||
if name in ("insert_many", "bulk_write"):
|
||||
expected = BulkWriteError
|
||||
else:
|
||||
expected = expected_exception
|
||||
with self.assertRaises(expected, msg=name) as cm:
|
||||
await f()
|
||||
if expected == BulkWriteError:
|
||||
bulk_result = cm.exception.details
|
||||
assert bulk_result is not None
|
||||
wc_errors = bulk_result["writeConcernErrors"]
|
||||
self.assertTrue(wc_errors)
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
async def test_raise_write_concern_error(self):
|
||||
self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test")
|
||||
assert async_client_context.w is not None
|
||||
await self.assertWriteOpsRaise(
|
||||
WriteConcern(w=async_client_context.w + 1, wtimeout=1), WriteConcernError
|
||||
)
|
||||
|
||||
@async_client_context.require_secondaries_count(1)
|
||||
@async_client_context.require_test_commands
|
||||
async def test_raise_wtimeout(self):
|
||||
self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test")
|
||||
self.addAsyncCleanup(self.enable_replication, async_client_context.client)
|
||||
# Disable replication to guarantee a wtimeout error.
|
||||
await self.disable_replication(async_client_context.client)
|
||||
await self.assertWriteOpsRaise(
|
||||
WriteConcern(w=async_client_context.w, wtimeout=1), WTimeoutError
|
||||
)
|
||||
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def test_error_includes_errInfo(self):
|
||||
expected_wce = {
|
||||
"code": 100,
|
||||
"codeName": "UnsatisfiableWriteConcern",
|
||||
"errmsg": "Not enough data-bearing nodes",
|
||||
"errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}},
|
||||
}
|
||||
cause_wce = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 2},
|
||||
"data": {"failCommands": ["insert"], "writeConcernError": expected_wce},
|
||||
}
|
||||
async with self.fail_point(cause_wce):
|
||||
# Write concern error on insert includes errInfo.
|
||||
with self.assertRaises(WriteConcernError) as ctx:
|
||||
await self.db.test.insert_one({})
|
||||
self.assertEqual(ctx.exception.details, expected_wce)
|
||||
|
||||
# Test bulk_write as well.
|
||||
with self.assertRaises(BulkWriteError) as ctx:
|
||||
await self.db.test.bulk_write([InsertOne({})])
|
||||
expected_details = {
|
||||
"writeErrors": [],
|
||||
"writeConcernErrors": [expected_wce],
|
||||
"nInserted": 1,
|
||||
"nUpserted": 0,
|
||||
"nMatched": 0,
|
||||
"nModified": 0,
|
||||
"nRemoved": 0,
|
||||
"upserted": [],
|
||||
}
|
||||
self.assertEqual(ctx.exception.details, expected_details)
|
||||
|
||||
@async_client_context.require_version_min(4, 9)
|
||||
async def test_write_error_details_exposes_errinfo(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
db = client.errinfotest
|
||||
self.addAsyncCleanup(client.drop_database, "errinfotest")
|
||||
validator = {"x": {"$type": "string"}}
|
||||
await db.create_collection("test", validator=validator)
|
||||
with self.assertRaises(WriteError) as ctx:
|
||||
await db.test.insert_one({"x": 1})
|
||||
self.assertEqual(ctx.exception.code, 121)
|
||||
self.assertIsNotNone(ctx.exception.details)
|
||||
assert ctx.exception.details is not None
|
||||
self.assertIsNotNone(ctx.exception.details.get("errInfo"))
|
||||
for event in listener.succeeded_events:
|
||||
if event.command_name == "insert":
|
||||
self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details)
|
||||
break
|
||||
else:
|
||||
self.fail("Couldn't find insert event.")
|
||||
|
||||
|
||||
def normalize_write_concern(concern):
|
||||
result = {}
|
||||
for key in concern:
|
||||
if key.lower() == "wtimeoutms":
|
||||
result["wtimeout"] = concern[key]
|
||||
elif key == "journal":
|
||||
result["j"] = concern[key]
|
||||
else:
|
||||
result[key] = concern[key]
|
||||
return result
|
||||
|
||||
|
||||
def create_connection_string_test(test_case):
|
||||
def run_test(self):
|
||||
uri = test_case["uri"]
|
||||
valid = test_case["valid"]
|
||||
warning = test_case["warning"]
|
||||
|
||||
if not valid:
|
||||
if warning is False:
|
||||
self.assertRaises(
|
||||
(ConfigurationError, ValueError), AsyncMongoClient, uri, connect=False
|
||||
)
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", UserWarning)
|
||||
self.assertRaises(UserWarning, AsyncMongoClient, uri, connect=False)
|
||||
else:
|
||||
client = AsyncMongoClient(uri, connect=False)
|
||||
if "writeConcern" in test_case:
|
||||
document = client.write_concern.document
|
||||
self.assertEqual(document, normalize_write_concern(test_case["writeConcern"]))
|
||||
if "readConcern" in test_case:
|
||||
document = client.read_concern.document
|
||||
self.assertEqual(document, test_case["readConcern"])
|
||||
|
||||
return run_test
|
||||
|
||||
|
||||
def create_document_test(test_case):
|
||||
def run_test(self):
|
||||
valid = test_case["valid"]
|
||||
|
||||
if "writeConcern" in test_case:
|
||||
normalized = normalize_write_concern(test_case["writeConcern"])
|
||||
if not valid:
|
||||
self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized)
|
||||
else:
|
||||
write_concern = WriteConcern(**normalized)
|
||||
self.assertEqual(write_concern.document, test_case["writeConcernDocument"])
|
||||
self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"])
|
||||
self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"])
|
||||
if "readConcern" in test_case:
|
||||
# Any string for 'level' is equally valid
|
||||
read_concern = ReadConcern(**test_case["readConcern"])
|
||||
self.assertEqual(read_concern.document, test_case["readConcernDocument"])
|
||||
self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"])
|
||||
|
||||
return run_test
|
||||
|
||||
|
||||
def create_tests():
|
||||
for dirpath, _, filenames in os.walk(TEST_PATH):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
if dirname == "operation":
|
||||
# This directory is tested by TestOperations.
|
||||
continue
|
||||
elif dirname == "connection-string":
|
||||
create_test = create_connection_string_test
|
||||
else:
|
||||
create_test = create_document_test
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as test_stream:
|
||||
test_cases = json.load(test_stream)["tests"]
|
||||
|
||||
fname = os.path.splitext(filename)[0]
|
||||
for test_case in test_cases:
|
||||
new_test = create_test(test_case)
|
||||
test_name = "test_{}_{}_{}".format(
|
||||
dirname.replace("-", "_"),
|
||||
fname.replace("-", "_"),
|
||||
str(test_case["description"].lower().replace(" ", "_")),
|
||||
)
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(TestReadWriteConcernSpec, new_test.__name__, new_test)
|
||||
|
||||
|
||||
create_tests()
|
||||
|
||||
|
||||
# Generate unified tests.
|
||||
# PyMongo does not support MapReduce.
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(TEST_PATH, "operation"),
|
||||
module=__name__,
|
||||
expected_failures=["MapReduce .*"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
46
test/asynchronous/test_retryable_reads_unified.py
Normal file
46
test/asynchronous/test_retryable_reads_unified.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2022-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the Retryable Reads unified spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified")
|
||||
|
||||
# Generate unified tests.
|
||||
# PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects.
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
TEST_PATH,
|
||||
module=__name__,
|
||||
expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"],
|
||||
)
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
39
test/asynchronous/test_retryable_writes_unified.py
Normal file
39
test/asynchronous/test_retryable_writes_unified.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2021-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the Retryable Writes unified spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
41
test/asynchronous/test_run_command.py
Normal file
41
test/asynchronous/test_run_command.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright 2024-Present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run Command unified tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(TEST_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
45
test/asynchronous/test_server_selection_logging.py
Normal file
45
test/asynchronous/test_server_selection_logging.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the server selection logging unified format spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
TEST_PATH,
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
77
test/asynchronous/test_server_selection_rtt.py
Normal file
77
test/asynchronous/test_server_selection_rtt.py
Normal file
@ -0,0 +1,77 @@
|
||||
# Copyright 2015 MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the topology module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous import AsyncPyMongoTestCase
|
||||
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt")
|
||||
|
||||
|
||||
class TestAllScenarios(AsyncPyMongoTestCase):
|
||||
pass
|
||||
|
||||
|
||||
def create_test(scenario_def):
|
||||
def run_scenario(self):
|
||||
moving_average = MovingAverage()
|
||||
|
||||
if scenario_def["avg_rtt_ms"] != "NULL":
|
||||
moving_average.add_sample(scenario_def["avg_rtt_ms"])
|
||||
|
||||
if scenario_def["new_rtt_ms"] != "NULL":
|
||||
moving_average.add_sample(scenario_def["new_rtt_ms"])
|
||||
|
||||
self.assertAlmostEqual(moving_average.get(), scenario_def["new_avg_rtt"])
|
||||
|
||||
return run_scenario
|
||||
|
||||
|
||||
def create_tests():
|
||||
for dirpath, _, filenames in os.walk(TEST_PATH):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream:
|
||||
scenario_def = json.load(scenario_stream)
|
||||
|
||||
# Construct test from scenario.
|
||||
new_test = create_test(scenario_def)
|
||||
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(TestAllScenarios, new_test.__name__, new_test)
|
||||
|
||||
|
||||
create_tests()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
40
test/asynchronous/test_sessions_unified.py
Normal file
40
test/asynchronous/test_sessions_unified.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright 2021-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the Sessions unified spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions")
|
||||
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
361
test/asynchronous/test_srv_polling.py
Normal file
361
test/asynchronous/test_srv_polling.py
Normal file
@ -0,0 +1,361 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the SRV support tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest
|
||||
from test.utils import FunctionCallRecorder, async_wait_until
|
||||
|
||||
import pymongo
|
||||
from pymongo import common
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.srv_resolver import _have_dnspython
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
WAIT_TIME = 0.1
|
||||
|
||||
|
||||
class SrvPollingKnobs:
|
||||
def __init__(
|
||||
self,
|
||||
ttl_time=None,
|
||||
min_srv_rescan_interval=None,
|
||||
nodelist_callback=None,
|
||||
count_resolver_calls=False,
|
||||
):
|
||||
self.ttl_time = ttl_time
|
||||
self.min_srv_rescan_interval = min_srv_rescan_interval
|
||||
self.nodelist_callback = nodelist_callback
|
||||
self.count_resolver_calls = count_resolver_calls
|
||||
|
||||
self.old_min_srv_rescan_interval = None
|
||||
self.old_dns_resolver_response = None
|
||||
|
||||
def enable(self):
|
||||
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
||||
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||
|
||||
if self.min_srv_rescan_interval is not None:
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
||||
|
||||
def mock_get_hosts_and_min_ttl(resolver, *args):
|
||||
assert self.old_dns_resolver_response is not None
|
||||
nodes, ttl = self.old_dns_resolver_response(resolver)
|
||||
if self.nodelist_callback is not None:
|
||||
nodes = self.nodelist_callback()
|
||||
if self.ttl_time is not None:
|
||||
ttl = self.ttl_time
|
||||
return nodes, ttl
|
||||
|
||||
patch_func: Any
|
||||
if self.count_resolver_calls:
|
||||
patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
|
||||
else:
|
||||
patch_func = mock_get_hosts_and_min_ttl
|
||||
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||
|
||||
def __enter__(self):
|
||||
self.enable()
|
||||
|
||||
def disable(self):
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||
self.old_dns_resolver_response
|
||||
)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.disable()
|
||||
|
||||
|
||||
class TestSrvPolling(AsyncPyMongoTestCase):
|
||||
BASE_SRV_RESPONSE = [
|
||||
("localhost.test.build.10gen.cc", 27017),
|
||||
("localhost.test.build.10gen.cc", 27018),
|
||||
]
|
||||
|
||||
CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"
|
||||
|
||||
async def asyncSetUp(self):
|
||||
# Patch timeouts to ensure short rescan SRV interval.
|
||||
self.client_knobs = client_knobs(
|
||||
heartbeat_frequency=WAIT_TIME,
|
||||
min_heartbeat_interval=WAIT_TIME,
|
||||
events_queue_frequency=WAIT_TIME,
|
||||
)
|
||||
self.client_knobs.enable()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
self.client_knobs.disable()
|
||||
|
||||
def get_nodelist(self, client):
|
||||
return client._topology.description.server_descriptions().keys()
|
||||
|
||||
async def assert_nodelist_change(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
|
||||
"""Check if the client._topology eventually sees all nodes in the
|
||||
expected_nodelist.
|
||||
"""
|
||||
|
||||
def predicate():
|
||||
nodelist = self.get_nodelist(client)
|
||||
if set(expected_nodelist) == set(nodelist):
|
||||
return True
|
||||
return False
|
||||
|
||||
await async_wait_until(predicate, "see expected nodelist", timeout=timeout)
|
||||
|
||||
async def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
|
||||
"""Check if the client._topology ever deviates from seeing all nodes
|
||||
in the expected_nodelist. Consistency is checked after sleeping for
|
||||
(WAIT_TIME * 10) seconds. Also check that the resolver is called at
|
||||
least once.
|
||||
"""
|
||||
|
||||
def predicate():
|
||||
if set(expected_nodelist) == set(self.get_nodelist(client)):
|
||||
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
|
||||
return False
|
||||
|
||||
await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
|
||||
|
||||
nodelist = self.get_nodelist(client)
|
||||
if set(expected_nodelist) != set(nodelist):
|
||||
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
||||
raise self.fail(msg % (nodelist, expected_nodelist))
|
||||
self.assertGreaterEqual(
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||
1,
|
||||
"resolver was never called",
|
||||
)
|
||||
return True
|
||||
|
||||
async def run_scenario(self, dns_response, expect_change):
|
||||
self.assertEqual(_have_dnspython(), True)
|
||||
if callable(dns_response):
|
||||
dns_resolver_response = dns_response
|
||||
else:
|
||||
|
||||
def dns_resolver_response():
|
||||
return dns_response
|
||||
|
||||
if expect_change:
|
||||
assertion_method = self.assert_nodelist_change
|
||||
count_resolver_calls = False
|
||||
expected_response = dns_response
|
||||
else:
|
||||
assertion_method = self.assert_nodelist_nochange
|
||||
count_resolver_calls = True
|
||||
expected_response = self.BASE_SRV_RESPONSE
|
||||
|
||||
# Patch timeouts to ensure short test running times.
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
await client.aconnect()
|
||||
await self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
|
||||
# Patch list of hosts returned by DNS query.
|
||||
with SrvPollingKnobs(
|
||||
nodelist_callback=dns_resolver_response, count_resolver_calls=count_resolver_calls
|
||||
):
|
||||
await assertion_method(expected_response, client)
|
||||
|
||||
async def test_addition(self):
|
||||
response = self.BASE_SRV_RESPONSE[:]
|
||||
response.append(("localhost.test.build.10gen.cc", 27019))
|
||||
await self.run_scenario(response, True)
|
||||
|
||||
async def test_removal(self):
|
||||
response = self.BASE_SRV_RESPONSE[:]
|
||||
response.remove(("localhost.test.build.10gen.cc", 27018))
|
||||
await self.run_scenario(response, True)
|
||||
|
||||
async def test_replace_one(self):
|
||||
response = self.BASE_SRV_RESPONSE[:]
|
||||
response.remove(("localhost.test.build.10gen.cc", 27018))
|
||||
response.append(("localhost.test.build.10gen.cc", 27019))
|
||||
await self.run_scenario(response, True)
|
||||
|
||||
async def test_replace_both_with_one(self):
|
||||
response = [("localhost.test.build.10gen.cc", 27019)]
|
||||
await self.run_scenario(response, True)
|
||||
|
||||
async def test_replace_both_with_two(self):
|
||||
response = [
|
||||
("localhost.test.build.10gen.cc", 27019),
|
||||
("localhost.test.build.10gen.cc", 27020),
|
||||
]
|
||||
await self.run_scenario(response, True)
|
||||
|
||||
async def test_dns_failures(self):
|
||||
from dns import exception
|
||||
|
||||
for exc in (exception.FormError, exception.TooBig, exception.Timeout):
|
||||
|
||||
def response_callback(*args):
|
||||
raise exc("DNS Failure!")
|
||||
|
||||
await self.run_scenario(response_callback, False)
|
||||
|
||||
async def test_dns_record_lookup_empty(self):
|
||||
response: list = []
|
||||
await self.run_scenario(response, False)
|
||||
|
||||
async def _test_recover_from_initial(self, initial_callback):
|
||||
# Construct a valid final response callback distinct from base.
|
||||
response_final = self.BASE_SRV_RESPONSE[:]
|
||||
response_final.pop()
|
||||
|
||||
def final_callback():
|
||||
return response_final
|
||||
|
||||
with SrvPollingKnobs(
|
||||
ttl_time=WAIT_TIME,
|
||||
min_srv_rescan_interval=WAIT_TIME,
|
||||
nodelist_callback=initial_callback,
|
||||
count_resolver_calls=True,
|
||||
):
|
||||
# Client uses unpatched method to get initial nodelist
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
await client.aconnect()
|
||||
# Invalid DNS resolver response should not change nodelist.
|
||||
await self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
|
||||
|
||||
with SrvPollingKnobs(
|
||||
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, nodelist_callback=final_callback
|
||||
):
|
||||
# Nodelist should reflect new valid DNS resolver response.
|
||||
await self.assert_nodelist_change(response_final, client)
|
||||
|
||||
async def test_recover_from_initially_empty_seedlist(self):
|
||||
def empty_seedlist():
|
||||
return []
|
||||
|
||||
await self._test_recover_from_initial(empty_seedlist)
|
||||
|
||||
async def test_recover_from_initially_erroring_seedlist(self):
|
||||
def erroring_seedlist():
|
||||
raise ConfigurationError
|
||||
|
||||
await self._test_recover_from_initial(erroring_seedlist)
|
||||
|
||||
async def test_10_all_dns_selected(self):
|
||||
response = [
|
||||
("localhost.test.build.10gen.cc", 27017),
|
||||
("localhost.test.build.10gen.cc", 27019),
|
||||
("localhost.test.build.10gen.cc", 27020),
|
||||
]
|
||||
|
||||
def nodelist_callback():
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0)
|
||||
await client.aconnect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
await self.assert_nodelist_change(response, client)
|
||||
|
||||
async def test_11_all_dns_selected(self):
|
||||
response = [
|
||||
("localhost.test.build.10gen.cc", 27019),
|
||||
("localhost.test.build.10gen.cc", 27020),
|
||||
]
|
||||
|
||||
def nodelist_callback():
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
await client.aconnect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
await self.assert_nodelist_change(response, client)
|
||||
|
||||
async def test_12_new_dns_randomly_selected(self):
|
||||
response = [
|
||||
("localhost.test.build.10gen.cc", 27020),
|
||||
("localhost.test.build.10gen.cc", 27019),
|
||||
("localhost.test.build.10gen.cc", 27017),
|
||||
]
|
||||
|
||||
def nodelist_callback():
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
await client.aconnect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
await asyncio.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
|
||||
final_topology = set(client.topology_description.server_descriptions())
|
||||
self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology)
|
||||
self.assertEqual(len(final_topology), 2)
|
||||
|
||||
async def test_does_not_flipflop(self):
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1)
|
||||
await client.aconnect()
|
||||
old = set(client.topology_description.server_descriptions())
|
||||
await asyncio.sleep(4 * WAIT_TIME)
|
||||
new = set(client.topology_description.server_descriptions())
|
||||
self.assertSetEqual(old, new)
|
||||
|
||||
async def test_srv_service_name(self):
|
||||
# Construct a valid final response callback distinct from base.
|
||||
response = [
|
||||
("localhost.test.build.10gen.cc.", 27019),
|
||||
("localhost.test.build.10gen.cc.", 27020),
|
||||
]
|
||||
|
||||
def nodelist_callback():
|
||||
return response
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(
|
||||
"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
|
||||
)
|
||||
await client.aconnect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
await self.assert_nodelist_change(response, client)
|
||||
|
||||
async def test_srv_waits_to_poll(self):
|
||||
modified = [("localhost.test.build.10gen.cc", 27019)]
|
||||
|
||||
def resolver_response():
|
||||
return modified
|
||||
|
||||
with SrvPollingKnobs(
|
||||
ttl_time=WAIT_TIME,
|
||||
min_srv_rescan_interval=WAIT_TIME,
|
||||
nodelist_callback=resolver_response,
|
||||
):
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
await client.aconnect()
|
||||
with self.assertRaises(AssertionError):
|
||||
await self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2)
|
||||
|
||||
def test_import_dns_resolver(self):
|
||||
# Regression test for PYTHON-4407
|
||||
import dns.resolver
|
||||
|
||||
self.assertTrue(hasattr(dns.resolver, "resolve"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
662
test/asynchronous/test_ssl.py
Normal file
662
test/asynchronous/test_ssl.py
Normal file
@ -0,0 +1,662 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for SSL support."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import socket
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import (
|
||||
HAVE_IPADDRESS,
|
||||
AsyncIntegrationTest,
|
||||
AsyncPyMongoTestCase,
|
||||
SkipTest,
|
||||
async_client_context,
|
||||
connected,
|
||||
remove_all_users,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
cat_files,
|
||||
ignore_deprecations,
|
||||
)
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from pymongo import AsyncMongoClient, ssl_support
|
||||
from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_HAVE_PYOPENSSL = False
|
||||
try:
|
||||
# All of these must be available to use PyOpenSSL
|
||||
import OpenSSL
|
||||
import requests
|
||||
import service_identity
|
||||
|
||||
# Ensure service_identity>=18.1 is installed
|
||||
from service_identity.pyopenssl import verify_ip_address
|
||||
|
||||
from pymongo.ocsp_support import _load_trusted_ca_certs
|
||||
|
||||
_HAVE_PYOPENSSL = True
|
||||
except ImportError:
|
||||
_load_trusted_ca_certs = None # type: ignore
|
||||
|
||||
|
||||
if HAVE_SSL:
|
||||
import ssl
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
if _IS_SYNC:
|
||||
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates")
|
||||
else:
|
||||
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates")
|
||||
|
||||
CLIENT_PEM = os.path.join(CERT_PATH, "client.pem")
|
||||
CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem")
|
||||
CA_PEM = os.path.join(CERT_PATH, "ca.pem")
|
||||
CA_BUNDLE_PEM = os.path.join(CERT_PATH, "trusted-ca.pem")
|
||||
CRL_PEM = os.path.join(CERT_PATH, "crl.pem")
|
||||
MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client"
|
||||
|
||||
# To fully test this start a mongod instance (built with SSL support) like so:
|
||||
# mongod --dbpath /path/to/data/directory --sslOnNormalPorts \
|
||||
# --sslPEMKeyFile /path/to/pymongo/test/certificates/server.pem \
|
||||
# --sslCAFile /path/to/pymongo/test/certificates/ca.pem \
|
||||
# --sslWeakCertificateValidation
|
||||
# Also, make sure you have 'server' as an alias for localhost in /etc/hosts
|
||||
#
|
||||
# Note: For all replica set tests to pass, the replica set configuration must
|
||||
# use 'localhost' for the hostname of all hosts.
|
||||
|
||||
|
||||
class TestClientSSL(AsyncPyMongoTestCase):
|
||||
@unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.")
|
||||
def test_no_ssl_module(self):
|
||||
# Explicit
|
||||
self.assertRaises(ConfigurationError, self.simple_client, ssl=True)
|
||||
|
||||
# Implied
|
||||
self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM)
|
||||
|
||||
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
|
||||
@ignore_deprecations
|
||||
def test_config_ssl(self):
|
||||
# Tests various ssl configurations
|
||||
self.assertRaises(ValueError, self.simple_client, ssl="foo")
|
||||
self.assertRaises(
|
||||
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.assertRaises(TypeError, self.simple_client, ssl=0)
|
||||
self.assertRaises(TypeError, self.simple_client, ssl=5.5)
|
||||
self.assertRaises(TypeError, self.simple_client, ssl=[])
|
||||
|
||||
self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile")
|
||||
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True)
|
||||
self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[])
|
||||
|
||||
# Test invalid combinations
|
||||
self.assertRaises(
|
||||
ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM)
|
||||
self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM)
|
||||
self.assertRaises(
|
||||
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False
|
||||
)
|
||||
self.assertRaises(
|
||||
ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False
|
||||
)
|
||||
self.assertRaises(
|
||||
ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False
|
||||
)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
|
||||
def test_use_pyopenssl_when_available(self):
|
||||
self.assertTrue(_ssl.IS_PYOPENSSL)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL")
|
||||
def test_load_trusted_ca_certs(self):
|
||||
trusted_ca_certs = _load_trusted_ca_certs(CA_BUNDLE_PEM)
|
||||
self.assertEqual(2, len(trusted_ca_certs))
|
||||
|
||||
|
||||
class TestSSL(AsyncIntegrationTest):
|
||||
saved_port: int
|
||||
|
||||
async def assertClientWorks(self, client):
|
||||
coll = client.pymongo_test.ssl_test.with_options(
|
||||
write_concern=WriteConcern(w=async_client_context.w)
|
||||
)
|
||||
await coll.drop()
|
||||
await coll.insert_one({"ssl": True})
|
||||
self.assertTrue((await coll.find_one())["ssl"])
|
||||
await coll.drop()
|
||||
|
||||
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
# MongoClient should connect to the primary by default.
|
||||
self.saved_port = AsyncMongoClient.PORT
|
||||
AsyncMongoClient.PORT = await async_client_context.port
|
||||
|
||||
async def asyncTearDown(self):
|
||||
AsyncMongoClient.PORT = self.saved_port
|
||||
|
||||
@async_client_context.require_tls
|
||||
async def test_simple_ssl(self):
|
||||
# Expects the server to be running with ssl and with
|
||||
# no --sslPEMKeyFile or with --sslWeakCertificateValidation
|
||||
await self.assertClientWorks(self.client)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@ignore_deprecations
|
||||
async def test_tlsCertificateKeyFilePassword(self):
|
||||
# Expects the server to be running with server.pem and ca.pem
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self.simple_client,
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
|
||||
tlsCertificateKeyFilePassword="qwerty",
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
)
|
||||
else:
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM,
|
||||
tlsCertificateKeyFilePassword="qwerty",
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=5000,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
uri_fmt = (
|
||||
"mongodb://localhost/?ssl=true"
|
||||
"&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty"
|
||||
"&tlsCAFile=%s&serverSelectionTimeoutMS=5000"
|
||||
)
|
||||
await connected(
|
||||
self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@async_client_context.require_no_auth
|
||||
@ignore_deprecations
|
||||
async def test_cert_ssl_implicitly_set(self):
|
||||
# Expects the server to be running with server.pem and ca.pem
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
#
|
||||
|
||||
# test that setting tlsCertificateKeyFile causes ssl to be set to True
|
||||
client = self.simple_client(
|
||||
await async_client_context.host,
|
||||
await async_client_context.port,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
)
|
||||
response = await client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
if "setName" in response:
|
||||
client = self.simple_client(
|
||||
await async_client_context.pair,
|
||||
replicaSet=response["setName"],
|
||||
w=len(response["hosts"]),
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
)
|
||||
|
||||
await self.assertClientWorks(client)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@async_client_context.require_no_auth
|
||||
@ignore_deprecations
|
||||
async def test_cert_ssl_validation(self):
|
||||
# Expects the server to be running with server.pem and ca.pem
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
#
|
||||
client = self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
)
|
||||
response = await client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
if "setName" in response:
|
||||
if response["primary"].split(":")[0] != "localhost":
|
||||
raise SkipTest(
|
||||
"No hosts in the replicaset for 'localhost'. "
|
||||
"Cannot validate hostname in the certificate"
|
||||
)
|
||||
|
||||
client = self.simple_client(
|
||||
"localhost",
|
||||
replicaSet=response["setName"],
|
||||
w=len(response["hosts"]),
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
)
|
||||
|
||||
await self.assertClientWorks(client)
|
||||
|
||||
if HAVE_IPADDRESS:
|
||||
client = self.simple_client(
|
||||
"127.0.0.1",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
)
|
||||
await self.assertClientWorks(client)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@async_client_context.require_no_auth
|
||||
@ignore_deprecations
|
||||
async def test_cert_ssl_uri_support(self):
|
||||
# Expects the server to be running with server.pem and ca.pem
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
#
|
||||
uri_fmt = (
|
||||
"mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates"
|
||||
"=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false"
|
||||
)
|
||||
client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM))
|
||||
await self.assertClientWorks(client)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@async_client_context.require_server_resolvable
|
||||
@ignore_deprecations
|
||||
async def test_cert_ssl_validation_hostname_matching(self):
|
||||
# Expects the server to be running with server.pem and ca.pem
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, True, False, False)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, False, True, False)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
self.assertTrue(ctx.check_hostname)
|
||||
|
||||
response = await self.client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"server",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"server",
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
tlsAllowInvalidHostnames=True,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
if "setName" in response:
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"server",
|
||||
replicaSet=response["setName"],
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"server",
|
||||
replicaSet=response["setName"],
|
||||
ssl=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
tlsAllowInvalidCertificates=False,
|
||||
tlsCAFile=CA_PEM,
|
||||
tlsAllowInvalidHostnames=True,
|
||||
serverSelectionTimeoutMS=500,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@ignore_deprecations
|
||||
async def test_tlsCRLFile_support(self):
|
||||
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self.simple_client,
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCAFile=CA_PEM,
|
||||
tlsCRLFile=CRL_PEM,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
)
|
||||
else:
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCAFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"localhost",
|
||||
ssl=True,
|
||||
tlsCAFile=CA_PEM,
|
||||
tlsCRLFile=CRL_PEM,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000"
|
||||
await connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore
|
||||
|
||||
uri_fmt = (
|
||||
"mongodb://localhost/?ssl=true&tlsCRLFile=%s"
|
||||
"&tlsCAFile=%s&serverSelectionTimeoutMS=1000"
|
||||
)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
await connected(
|
||||
self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@async_client_context.require_server_resolvable
|
||||
@ignore_deprecations
|
||||
async def test_validation_with_system_ca_certs(self):
|
||||
# Expects the server to be running with server.pem and ca.pem.
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
# --sslWeakCertificateValidation
|
||||
#
|
||||
self.patch_system_certs(CA_PEM)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
# Server cert is verified but hostname matching fails
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
|
||||
) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Server cert is verified. Disable hostname matching.
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"server",
|
||||
ssl=True,
|
||||
tlsAllowInvalidHostnames=True,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
# Server cert and hostname are verified.
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials
|
||||
) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Server cert and hostname are verified.
|
||||
await connected(
|
||||
self.simple_client(
|
||||
"mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000",
|
||||
**self.credentials, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
def test_system_certs_config_error(self):
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False)
|
||||
if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr(
|
||||
ctx, "load_default_certs"
|
||||
):
|
||||
raise SkipTest("Can't test when system CA certificates are loadable.")
|
||||
|
||||
have_certifi = ssl_support.HAVE_CERTIFI
|
||||
have_wincertstore = ssl_support.HAVE_WINCERTSTORE
|
||||
# Force the test regardless of environment.
|
||||
ssl_support.HAVE_CERTIFI = False
|
||||
ssl_support.HAVE_WINCERTSTORE = False
|
||||
try:
|
||||
with self.assertRaises(ConfigurationError):
|
||||
self.simple_client("mongodb://localhost/?ssl=true")
|
||||
finally:
|
||||
ssl_support.HAVE_CERTIFI = have_certifi
|
||||
ssl_support.HAVE_WINCERTSTORE = have_wincertstore
|
||||
|
||||
def test_certifi_support(self):
|
||||
if hasattr(ssl, "SSLContext"):
|
||||
# SSLSocket doesn't provide ca_certs attribute on pythons
|
||||
# with SSLContext and SSLContext provides no information
|
||||
# about ca_certs.
|
||||
raise SkipTest("Can't test when SSLContext available.")
|
||||
if not ssl_support.HAVE_CERTIFI:
|
||||
raise SkipTest("Need certifi to test certifi support.")
|
||||
|
||||
have_wincertstore = ssl_support.HAVE_WINCERTSTORE
|
||||
# Force the test on Windows, regardless of environment.
|
||||
ssl_support.HAVE_WINCERTSTORE = False
|
||||
try:
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
|
||||
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where())
|
||||
finally:
|
||||
ssl_support.HAVE_WINCERTSTORE = have_wincertstore
|
||||
|
||||
def test_wincertstore(self):
|
||||
if sys.platform != "win32":
|
||||
raise SkipTest("Only valid on Windows.")
|
||||
if hasattr(ssl, "SSLContext"):
|
||||
# SSLSocket doesn't provide ca_certs attribute on pythons
|
||||
# with SSLContext and SSLContext provides no information
|
||||
# about ca_certs.
|
||||
raise SkipTest("Can't test when SSLContext available.")
|
||||
if not ssl_support.HAVE_WINCERTSTORE:
|
||||
raise SkipTest("Need wincertstore to test wincertstore.")
|
||||
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
|
||||
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name)
|
||||
|
||||
@async_client_context.require_auth
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@ignore_deprecations
|
||||
async def test_mongodb_x509_auth(self):
|
||||
host, port = await async_client_context.host, await async_client_context.port
|
||||
self.addAsyncCleanup(remove_all_users, async_client_context.client["$external"])
|
||||
|
||||
# Give x509 user all necessary privileges.
|
||||
await async_client_context.create_user(
|
||||
"$external",
|
||||
MONGODB_X509_USERNAME,
|
||||
roles=[
|
||||
{"role": "readWriteAnyDatabase", "db": "admin"},
|
||||
{"role": "userAdminAnyDatabase", "db": "admin"},
|
||||
],
|
||||
)
|
||||
|
||||
noauth = self.simple_client(
|
||||
await async_client_context.pair,
|
||||
ssl=True,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
)
|
||||
|
||||
with self.assertRaises(OperationFailure):
|
||||
await noauth.pymongo_test.test.find_one()
|
||||
|
||||
listener = EventListener()
|
||||
auth = self.simple_client(
|
||||
await async_client_context.pair,
|
||||
authMechanism="MONGODB-X509",
|
||||
ssl=True,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
event_listeners=[listener],
|
||||
)
|
||||
|
||||
# No error
|
||||
await auth.pymongo_test.test.find_one()
|
||||
names = listener.started_command_names()
|
||||
if async_client_context.version.at_least(4, 4, -1):
|
||||
# Speculative auth skips the authenticate command.
|
||||
self.assertEqual(names, ["find"])
|
||||
else:
|
||||
self.assertEqual(names, ["authenticate", "find"])
|
||||
|
||||
uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % (
|
||||
quote_plus(MONGODB_X509_USERNAME),
|
||||
host,
|
||||
port,
|
||||
)
|
||||
client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
# No error
|
||||
await client.pymongo_test.test.find_one()
|
||||
|
||||
uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port)
|
||||
client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
# No error
|
||||
await client.pymongo_test.test.find_one()
|
||||
# Auth should fail if username and certificate do not match
|
||||
uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % (
|
||||
quote_plus("not the username"),
|
||||
host,
|
||||
port,
|
||||
)
|
||||
|
||||
bad_client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
|
||||
with self.assertRaises(OperationFailure):
|
||||
await bad_client.pymongo_test.test.find_one()
|
||||
|
||||
bad_client = self.simple_client(
|
||||
await async_client_context.pair,
|
||||
username="not the username",
|
||||
authMechanism="MONGODB-X509",
|
||||
ssl=True,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
)
|
||||
|
||||
with self.assertRaises(OperationFailure):
|
||||
await bad_client.pymongo_test.test.find_one()
|
||||
|
||||
# Invalid certificate (using CA certificate as client certificate)
|
||||
uri = "mongodb://%s@%s:%d/?authMechanism=MONGODB-X509" % (
|
||||
quote_plus(MONGODB_X509_USERNAME),
|
||||
host,
|
||||
port,
|
||||
)
|
||||
try:
|
||||
await connected(
|
||||
self.simple_client(
|
||||
uri,
|
||||
ssl=True,
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CA_PEM,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
)
|
||||
)
|
||||
except (ConnectionFailure, ConfigurationError):
|
||||
pass
|
||||
else:
|
||||
self.fail("Invalid certificate accepted.")
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@ignore_deprecations
|
||||
async def test_connect_with_ca_bundle(self):
|
||||
def remove(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
temp_ca_bundle = os.path.join(CERT_PATH, "trusted-ca-bundle.pem")
|
||||
self.addCleanup(remove, temp_ca_bundle)
|
||||
# Add the CA cert file to the bundle.
|
||||
cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM)
|
||||
async with self.simple_client(
|
||||
"localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle
|
||||
) as client:
|
||||
self.assertTrue(await client.admin.command("ping"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
228
test/asynchronous/test_streaming_protocol.py
Normal file
228
test/asynchronous/test_streaming_protocol.py
Normal file
@ -0,0 +1,228 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the database module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import (
|
||||
HeartbeatEventListener,
|
||||
ServerEventListener,
|
||||
async_wait_until,
|
||||
)
|
||||
|
||||
from pymongo import monitoring
|
||||
from pymongo.hello import HelloCompat
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestStreamingProtocol(AsyncIntegrationTest):
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_failCommand_streaming(self):
|
||||
listener = ServerEventListener()
|
||||
hb_listener = HeartbeatEventListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
event_listeners=[listener, hb_listener],
|
||||
heartbeatFrequencyMS=500,
|
||||
appName="failingHeartbeatTest",
|
||||
)
|
||||
# Force a connection.
|
||||
await client.admin.command("ping")
|
||||
address = await client.address
|
||||
listener.reset()
|
||||
|
||||
fail_hello = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 4},
|
||||
"data": {
|
||||
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
|
||||
"closeConnection": False,
|
||||
"errorCode": 10107,
|
||||
"appName": "failingHeartbeatTest",
|
||||
},
|
||||
}
|
||||
async with self.fail_point(fail_hello):
|
||||
|
||||
def _marked_unknown(event):
|
||||
return (
|
||||
event.server_address == address
|
||||
and not event.new_description.is_server_type_known
|
||||
)
|
||||
|
||||
def _discovered_node(event):
|
||||
return (
|
||||
event.server_address == address
|
||||
and not event.previous_description.is_server_type_known
|
||||
and event.new_description.is_server_type_known
|
||||
)
|
||||
|
||||
def marked_unknown():
|
||||
return len(listener.matching(_marked_unknown)) >= 1
|
||||
|
||||
def rediscovered():
|
||||
return len(listener.matching(_discovered_node)) >= 1
|
||||
|
||||
# Topology events are not published synchronously
|
||||
await async_wait_until(marked_unknown, "mark node unknown")
|
||||
await async_wait_until(rediscovered, "rediscover node")
|
||||
|
||||
# Server should be selectable.
|
||||
await client.admin.command("ping")
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_streaming_rtt(self):
|
||||
listener = ServerEventListener()
|
||||
hb_listener = HeartbeatEventListener()
|
||||
# On Windows, RTT can actually be 0.0 because time.time() only has
|
||||
# 1-15 millisecond resolution. We need to delay the initial hello
|
||||
# to ensure that RTT is never zero.
|
||||
name = "streamingRttTest"
|
||||
delay_hello: dict = {
|
||||
"configureFailPoint": "failCommand",
|
||||
"mode": {"times": 1000},
|
||||
"data": {
|
||||
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
|
||||
"blockConnection": True,
|
||||
"blockTimeMS": 20,
|
||||
# This can be uncommented after SERVER-49220 is fixed.
|
||||
# 'appName': name,
|
||||
},
|
||||
}
|
||||
async with self.fail_point(delay_hello):
|
||||
client = await self.async_rs_or_single_client(
|
||||
event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name
|
||||
)
|
||||
# Force a connection.
|
||||
await client.admin.command("ping")
|
||||
address = await client.address
|
||||
|
||||
delay_hello["data"]["blockTimeMS"] = 500
|
||||
delay_hello["data"]["appName"] = name
|
||||
async with self.fail_point(delay_hello):
|
||||
|
||||
def rtt_exceeds_250_ms():
|
||||
# XXX: Add a public TopologyDescription getter to MongoClient?
|
||||
topology = client._topology
|
||||
sd = topology.description.server_descriptions()[address]
|
||||
assert sd.round_trip_time is not None
|
||||
return sd.round_trip_time > 0.250
|
||||
|
||||
await async_wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT")
|
||||
|
||||
# Server should be selectable.
|
||||
await client.admin.command("ping")
|
||||
|
||||
def changed_event(event):
|
||||
return event.server_address == address and isinstance(
|
||||
event, monitoring.ServerDescriptionChangedEvent
|
||||
)
|
||||
|
||||
# There should only be one event published, for the initial discovery.
|
||||
events = listener.matching(changed_event)
|
||||
self.assertEqual(1, len(events))
|
||||
self.assertGreater(events[0].new_description.round_trip_time, 0)
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_monitor_waits_after_server_check_error(self):
|
||||
# This test implements:
|
||||
# https://github.com/mongodb/specifications/blob/master/source/server-discovery-and-monitoring/server-discovery-and-monitoring-tests.md#monitors-sleep-at-least-minheartbeatfreqencyms-between-checks
|
||||
fail_hello = {
|
||||
"mode": {"times": 5},
|
||||
"data": {
|
||||
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
|
||||
"errorCode": 1234,
|
||||
"appName": "SDAMMinHeartbeatFrequencyTest",
|
||||
},
|
||||
}
|
||||
async with self.fail_point(fail_hello):
|
||||
start = time.time()
|
||||
client = await self.async_single_client(
|
||||
appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000
|
||||
)
|
||||
# Force a connection.
|
||||
await client.admin.command("ping")
|
||||
duration = time.time() - start
|
||||
# Explanation of the expected events:
|
||||
# 0ms: run configureFailPoint
|
||||
# 1ms: create MongoClient
|
||||
# 2ms: failed monitor handshake, 1
|
||||
# 502ms: failed monitor handshake, 2
|
||||
# 1002ms: failed monitor handshake, 3
|
||||
# 1502ms: failed monitor handshake, 4
|
||||
# 2002ms: failed monitor handshake, 5
|
||||
# 2502ms: monitor handshake succeeds
|
||||
# 2503ms: run awaitable hello
|
||||
# 2504ms: application handshake succeeds
|
||||
# 2505ms: ping command succeeds
|
||||
self.assertGreaterEqual(duration, 2)
|
||||
self.assertLessEqual(duration, 3.5)
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
async def test_heartbeat_awaited_flag(self):
|
||||
hb_listener = HeartbeatEventListener()
|
||||
client = await self.async_single_client(
|
||||
event_listeners=[hb_listener],
|
||||
heartbeatFrequencyMS=500,
|
||||
appName="heartbeatEventAwaitedFlag",
|
||||
)
|
||||
# Force a connection.
|
||||
await client.admin.command("ping")
|
||||
|
||||
def hb_succeeded(event):
|
||||
return isinstance(event, monitoring.ServerHeartbeatSucceededEvent)
|
||||
|
||||
def hb_failed(event):
|
||||
return isinstance(event, monitoring.ServerHeartbeatFailedEvent)
|
||||
|
||||
fail_heartbeat = {
|
||||
"mode": {"times": 2},
|
||||
"data": {
|
||||
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
|
||||
"closeConnection": True,
|
||||
"appName": "heartbeatEventAwaitedFlag",
|
||||
},
|
||||
}
|
||||
async with self.fail_point(fail_heartbeat):
|
||||
await async_wait_until(
|
||||
lambda: hb_listener.matching(hb_failed), "published failed event"
|
||||
)
|
||||
# Reconnect.
|
||||
await client.admin.command("ping")
|
||||
|
||||
hb_succeeded_events = hb_listener.matching(hb_succeeded)
|
||||
hb_failed_events = hb_listener.matching(hb_failed)
|
||||
self.assertFalse(hb_succeeded_events[0].awaited)
|
||||
self.assertTrue(hb_failed_events[0].awaited)
|
||||
# Depending on thread scheduling, the failed heartbeat could occur on
|
||||
# the second or third check.
|
||||
events = [type(e) for e in hb_listener.events[:4]]
|
||||
if events == [
|
||||
monitoring.ServerHeartbeatStartedEvent,
|
||||
monitoring.ServerHeartbeatSucceededEvent,
|
||||
monitoring.ServerHeartbeatStartedEvent,
|
||||
monitoring.ServerHeartbeatFailedEvent,
|
||||
]:
|
||||
self.assertFalse(hb_succeeded_events[1].awaited)
|
||||
else:
|
||||
self.assertTrue(hb_succeeded_events[1].awaited)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
56
test/asynchronous/test_transactions_unified.py
Normal file
56
test/asynchronous/test_transactions_unified.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2021-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the Transactions unified spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import client_context, unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
@client_context.require_no_mmap
|
||||
def setUpModule():
|
||||
pass
|
||||
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
# Location of JSON test specifications for transactions-convenient-api.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified"
|
||||
)
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
99
test/asynchronous/test_unified_format.py
Normal file
99
test/asynchronous/test_unified_format.py
Normal file
@ -0,0 +1,99 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import UnitTest, unittest
|
||||
from test.asynchronous.unified_format import MatchEvaluatorUtil, generate_test_classes
|
||||
|
||||
from bson import ObjectId
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(TEST_PATH, "valid-pass"),
|
||||
module=__name__,
|
||||
class_name_prefix="UnifiedTestFormat",
|
||||
expected_failures=[
|
||||
"Client side error in command starting transaction", # PYTHON-1894
|
||||
],
|
||||
RUN_ON_SERVERLESS=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(TEST_PATH, "valid-fail"),
|
||||
module=__name__,
|
||||
class_name_prefix="UnifiedTestFormat",
|
||||
bypass_test_generation_errors=True,
|
||||
expected_failures=[
|
||||
".*", # All tests expected to fail
|
||||
],
|
||||
RUN_ON_SERVERLESS=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestMatchEvaluatorUtil(UnitTest):
|
||||
def setUp(self):
|
||||
self.match_evaluator = MatchEvaluatorUtil(self)
|
||||
|
||||
def test_unsetOrMatches(self):
|
||||
spec: dict[str, Any] = {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}
|
||||
for actual in [{}, {"y": 2}, None]:
|
||||
self.match_evaluator.match_result(spec, actual)
|
||||
|
||||
spec = {"x": {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}}
|
||||
for actual in [{}, {"x": {}}, {"x": {"y": 2}}]:
|
||||
self.match_evaluator.match_result(spec, actual)
|
||||
|
||||
spec = {"y": {"$$unsetOrMatches": {"$$exists": True}}}
|
||||
self.match_evaluator.match_result(spec, {})
|
||||
self.match_evaluator.match_result(spec, {"y": 2})
|
||||
self.match_evaluator.match_result(spec, {"x": 1})
|
||||
self.match_evaluator.match_result(spec, {"y": {}})
|
||||
|
||||
def test_type(self):
|
||||
self.match_evaluator.match_result(
|
||||
{
|
||||
"operationType": "insert",
|
||||
"ns": {"db": "change-stream-tests", "coll": "test"},
|
||||
"fullDocument": {"_id": {"$$type": "objectId"}, "x": 1},
|
||||
},
|
||||
{
|
||||
"operationType": "insert",
|
||||
"fullDocument": {"_id": ObjectId("5fc93511ac93941052098f0c"), "x": 1},
|
||||
"ns": {"db": "change-stream-tests", "coll": "test"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
86
test/asynchronous/test_versioned_api_integration.py
Normal file
86
test/asynchronous/test_versioned_api_integration.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2020-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import OvertCommandListener
|
||||
|
||||
from pymongo.server_api import ServerApi
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api")
|
||||
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
|
||||
class TestServerApiIntegration(AsyncIntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
RUN_ON_SERVERLESS = True
|
||||
|
||||
def assertServerApi(self, event):
|
||||
self.assertIn("apiVersion", event.command)
|
||||
self.assertEqual(event.command["apiVersion"], "1")
|
||||
|
||||
def assertServerApiInAllCommands(self, events):
|
||||
for event in events:
|
||||
self.assertServerApi(event)
|
||||
|
||||
@async_client_context.require_version_min(4, 7)
|
||||
async def test_command_options(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
server_api=ServerApi("1"), event_listeners=[listener]
|
||||
)
|
||||
coll = client.test.test
|
||||
await coll.insert_many([{} for _ in range(100)])
|
||||
self.addAsyncCleanup(coll.delete_many, {})
|
||||
await coll.find(batch_size=25).to_list()
|
||||
await client.admin.command("ping")
|
||||
self.assertServerApiInAllCommands(listener.started_events)
|
||||
|
||||
@async_client_context.require_version_min(4, 7)
|
||||
@async_client_context.require_transactions
|
||||
async def test_command_options_txn(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(
|
||||
server_api=ServerApi("1"), event_listeners=[listener]
|
||||
)
|
||||
coll = client.test.test
|
||||
await coll.insert_many([{} for _ in range(100)])
|
||||
self.addAsyncCleanup(coll.delete_many, {})
|
||||
|
||||
listener.reset()
|
||||
async with client.start_session() as s, await s.start_transaction():
|
||||
await coll.insert_many([{} for _ in range(100)], session=s)
|
||||
await coll.find(batch_size=25, session=s).to_list()
|
||||
await client.test.command("find", "test", session=s)
|
||||
self.assertServerApiInAllCommands(listener.started_events)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -35,6 +35,7 @@ from test.asynchronous import (
|
||||
client_knobs,
|
||||
unittest,
|
||||
)
|
||||
from test.asynchronous.utils_spec_runner import SpecRunnerTask
|
||||
from test.unified_format_shared import (
|
||||
KMS_TLS_OPTS,
|
||||
PLACEHOLDER_MAP,
|
||||
@ -58,7 +59,6 @@ from test.utils import (
|
||||
snake_to_camel,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunnerThread
|
||||
from test.version import Version
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
@ -382,8 +382,8 @@ class EntityMapUtil:
|
||||
return
|
||||
elif entity_type == "thread":
|
||||
name = spec["id"]
|
||||
thread = SpecRunnerThread(name)
|
||||
thread.start()
|
||||
thread = SpecRunnerTask(name)
|
||||
await thread.start()
|
||||
self[name] = thread
|
||||
return
|
||||
|
||||
@ -711,7 +711,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
return await target.command(**kwargs)
|
||||
|
||||
async def _databaseOperation_runCursorCommand(self, target, **kwargs):
|
||||
return list(await self._databaseOperation_createCommandCursor(target, **kwargs))
|
||||
return await (await self._databaseOperation_createCommandCursor(target, **kwargs)).to_list()
|
||||
|
||||
async def _databaseOperation_createCommandCursor(self, target, **kwargs):
|
||||
self.__raise_if_unsupported("createCommandCursor", target, AsyncDatabase)
|
||||
@ -1177,16 +1177,16 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
|
||||
wait_until(primary_changed, "change primary", timeout=timeout)
|
||||
|
||||
def _testOperation_runOnThread(self, spec):
|
||||
async def _testOperation_runOnThread(self, spec):
|
||||
"""Run the 'runOnThread' operation."""
|
||||
thread = self.entity_map[spec["thread"]]
|
||||
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
|
||||
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
|
||||
|
||||
def _testOperation_waitForThread(self, spec):
|
||||
async def _testOperation_waitForThread(self, spec):
|
||||
"""Run the 'waitForThread' operation."""
|
||||
thread = self.entity_map[spec["thread"]]
|
||||
thread.stop()
|
||||
thread.join(10)
|
||||
await thread.stop()
|
||||
await thread.join(10)
|
||||
if thread.exc:
|
||||
raise thread.exc
|
||||
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))
|
||||
|
||||
@ -18,11 +18,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import threading
|
||||
import unittest
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import abc
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
|
||||
from test.asynchronous.helpers import ConcurrentRunner
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
CompareType,
|
||||
@ -47,6 +47,7 @@ from pymongo.asynchronous import client_session
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
|
||||
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.results import BulkWriteResult, _WriteResult
|
||||
@ -55,38 +56,36 @@ from pymongo.write_concern import WriteConcern
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class SpecRunnerThread(threading.Thread):
|
||||
class SpecRunnerTask(ConcurrentRunner):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
super().__init__(name=name)
|
||||
self.exc = None
|
||||
self.daemon = True
|
||||
self.cond = threading.Condition()
|
||||
self.cond = _async_create_condition(_async_create_lock())
|
||||
self.ops = []
|
||||
self.stopped = False
|
||||
|
||||
def schedule(self, work):
|
||||
async def schedule(self, work):
|
||||
self.ops.append(work)
|
||||
with self.cond:
|
||||
async with self.cond:
|
||||
self.cond.notify()
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
self.stopped = True
|
||||
with self.cond:
|
||||
async with self.cond:
|
||||
self.cond.notify()
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
while not self.stopped or self.ops:
|
||||
if not self.ops:
|
||||
with self.cond:
|
||||
self.cond.wait(10)
|
||||
async with self.cond:
|
||||
await _async_cond_wait(self.cond, 10)
|
||||
if self.ops:
|
||||
try:
|
||||
work = self.ops.pop(0)
|
||||
work()
|
||||
await work()
|
||||
except Exception as exc:
|
||||
self.exc = exc
|
||||
self.stop()
|
||||
await self.stop()
|
||||
|
||||
|
||||
class AsyncSpecTestCreator:
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import multiprocessing
|
||||
@ -30,6 +31,8 @@ import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo._asyncio_task import create_task
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
@ -369,3 +372,38 @@ class SystemCertsPatcher:
|
||||
os.environ.pop("SSL_CERT_FILE")
|
||||
else:
|
||||
os.environ["SSL_CERT_FILE"] = self.original_certs
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
PARENT = threading.Thread
|
||||
else:
|
||||
PARENT = object
|
||||
|
||||
|
||||
class ConcurrentRunner(PARENT):
|
||||
def __init__(self, **kwargs):
|
||||
if _IS_SYNC:
|
||||
super().__init__(**kwargs)
|
||||
self.name = kwargs.get("name", "ConcurrentRunner")
|
||||
self.stopped = False
|
||||
self.task = None
|
||||
self.target = kwargs.get("target", None)
|
||||
self.args = kwargs.get("args", [])
|
||||
|
||||
if not _IS_SYNC:
|
||||
|
||||
def start(self):
|
||||
self.task = create_task(self.run(), name=self.name)
|
||||
|
||||
def join(self, timeout: float | None = 0): # type: ignore[override]
|
||||
if self.task is not None:
|
||||
asyncio.wait([self.task], timeout=timeout)
|
||||
|
||||
def is_alive(self):
|
||||
return not self.stopped
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.target(*self.args)
|
||||
finally:
|
||||
self.stopped = True
|
||||
|
||||
@ -2399,7 +2399,7 @@ class TestMongoClientFailover(MockClientTest):
|
||||
|
||||
# MongoClient discovers it's alone. The first attempt raises either
|
||||
# ServerSelectionTimeoutError or AutoReconnect (from
|
||||
# AsyncMockPool.get_socket).
|
||||
# MockPool.get_socket).
|
||||
with self.assertRaises(AutoReconnect):
|
||||
c.db.collection.find_one()
|
||||
|
||||
|
||||
@ -18,22 +18,35 @@ from __future__ import annotations
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, PyMongoTestCase, client_context, unittest
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
PyMongoTestCase,
|
||||
client_context,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import wait_until
|
||||
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import parse_uri, split_hosts
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestDNSRepl(PyMongoTestCase):
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set"
|
||||
)
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set"
|
||||
)
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set"
|
||||
)
|
||||
load_balanced = False
|
||||
|
||||
@client_context.require_replica_set
|
||||
@ -42,9 +55,14 @@ class TestDNSRepl(PyMongoTestCase):
|
||||
|
||||
|
||||
class TestDNSLoadBalanced(PyMongoTestCase):
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced"
|
||||
)
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced"
|
||||
)
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced"
|
||||
)
|
||||
load_balanced = True
|
||||
|
||||
@client_context.require_load_balancer
|
||||
@ -53,7 +71,12 @@ class TestDNSLoadBalanced(PyMongoTestCase):
|
||||
|
||||
|
||||
class TestDNSSharded(PyMongoTestCase):
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded")
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded"
|
||||
)
|
||||
load_balanced = False
|
||||
|
||||
@client_context.require_mongos
|
||||
@ -119,7 +142,9 @@ def create_test(test_case):
|
||||
# tests.
|
||||
copts["tlsAllowInvalidHostnames"] = True
|
||||
|
||||
client = PyMongoTestCase.unmanaged_simple_client(uri, **copts)
|
||||
client = self.simple_client(uri, **copts)
|
||||
if client._options.connect:
|
||||
client._connect()
|
||||
if num_seeds is not None:
|
||||
self.assertEqual(len(client._topology_settings.seeds), num_seeds)
|
||||
if hosts is not None:
|
||||
@ -132,7 +157,6 @@ def create_test(test_case):
|
||||
client.admin.command("ping")
|
||||
# XXX: we should block until SRV poller runs at least once
|
||||
# and re-run these assertions.
|
||||
client.close()
|
||||
else:
|
||||
try:
|
||||
parse_uri(uri)
|
||||
@ -188,7 +212,6 @@ class TestParsingErrors(PyMongoTestCase):
|
||||
class TestCaseInsensitive(IntegrationTest):
|
||||
def test_connect_case_insensitive(self):
|
||||
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||
self.addCleanup(client.close)
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
|
||||
|
||||
|
||||
@ -15,9 +15,13 @@
|
||||
"""MongoDB documentation examples in Python."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import functools
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from test.helpers import ConcurrentRunner
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -29,8 +33,11 @@ from pymongo.errors import ConnectionFailure, OperationFailure
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestSampleShellCommands(IntegrationTest):
|
||||
def setUp(self):
|
||||
@ -62,7 +69,7 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
cursor = db.inventory.find({"item": "canvas"})
|
||||
# End Example 2
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 3
|
||||
db.inventory.insert_many(
|
||||
@ -137,31 +144,31 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
cursor = db.inventory.find({})
|
||||
# End Example 7
|
||||
|
||||
self.assertEqual(len(list(cursor)), 5)
|
||||
self.assertEqual(len(cursor.to_list()), 5)
|
||||
|
||||
# Start Example 9
|
||||
cursor = db.inventory.find({"status": "D"})
|
||||
# End Example 9
|
||||
|
||||
self.assertEqual(len(list(cursor)), 2)
|
||||
self.assertEqual(len(cursor.to_list()), 2)
|
||||
|
||||
# Start Example 10
|
||||
cursor = db.inventory.find({"status": {"$in": ["A", "D"]}})
|
||||
# End Example 10
|
||||
|
||||
self.assertEqual(len(list(cursor)), 5)
|
||||
self.assertEqual(len(cursor.to_list()), 5)
|
||||
|
||||
# Start Example 11
|
||||
cursor = db.inventory.find({"status": "A", "qty": {"$lt": 30}})
|
||||
# End Example 11
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 12
|
||||
cursor = db.inventory.find({"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]})
|
||||
# End Example 12
|
||||
|
||||
self.assertEqual(len(list(cursor)), 3)
|
||||
self.assertEqual(len(cursor.to_list()), 3)
|
||||
|
||||
# Start Example 13
|
||||
cursor = db.inventory.find(
|
||||
@ -169,7 +176,7 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
)
|
||||
# End Example 13
|
||||
|
||||
self.assertEqual(len(list(cursor)), 2)
|
||||
self.assertEqual(len(cursor.to_list()), 2)
|
||||
|
||||
def test_query_embedded_documents(self):
|
||||
db = self.db
|
||||
@ -219,31 +226,31 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
cursor = db.inventory.find({"size": SON([("h", 14), ("w", 21), ("uom", "cm")])})
|
||||
# End Example 15
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 16
|
||||
cursor = db.inventory.find({"size": SON([("w", 21), ("h", 14), ("uom", "cm")])})
|
||||
# End Example 16
|
||||
|
||||
self.assertEqual(len(list(cursor)), 0)
|
||||
self.assertEqual(len(cursor.to_list()), 0)
|
||||
|
||||
# Start Example 17
|
||||
cursor = db.inventory.find({"size.uom": "in"})
|
||||
# End Example 17
|
||||
|
||||
self.assertEqual(len(list(cursor)), 2)
|
||||
self.assertEqual(len(cursor.to_list()), 2)
|
||||
|
||||
# Start Example 18
|
||||
cursor = db.inventory.find({"size.h": {"$lt": 15}})
|
||||
# End Example 18
|
||||
|
||||
self.assertEqual(len(list(cursor)), 4)
|
||||
self.assertEqual(len(cursor.to_list()), 4)
|
||||
|
||||
# Start Example 19
|
||||
cursor = db.inventory.find({"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"})
|
||||
# End Example 19
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
def test_query_arrays(self):
|
||||
db = self.db
|
||||
@ -269,49 +276,49 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
cursor = db.inventory.find({"tags": ["red", "blank"]})
|
||||
# End Example 21
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 22
|
||||
cursor = db.inventory.find({"tags": {"$all": ["red", "blank"]}})
|
||||
# End Example 22
|
||||
|
||||
self.assertEqual(len(list(cursor)), 4)
|
||||
self.assertEqual(len(cursor.to_list()), 4)
|
||||
|
||||
# Start Example 23
|
||||
cursor = db.inventory.find({"tags": "red"})
|
||||
# End Example 23
|
||||
|
||||
self.assertEqual(len(list(cursor)), 4)
|
||||
self.assertEqual(len(cursor.to_list()), 4)
|
||||
|
||||
# Start Example 24
|
||||
cursor = db.inventory.find({"dim_cm": {"$gt": 25}})
|
||||
# End Example 24
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 25
|
||||
cursor = db.inventory.find({"dim_cm": {"$gt": 15, "$lt": 20}})
|
||||
# End Example 25
|
||||
|
||||
self.assertEqual(len(list(cursor)), 4)
|
||||
self.assertEqual(len(cursor.to_list()), 4)
|
||||
|
||||
# Start Example 26
|
||||
cursor = db.inventory.find({"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}})
|
||||
# End Example 26
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 27
|
||||
cursor = db.inventory.find({"dim_cm.1": {"$gt": 25}})
|
||||
# End Example 27
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 28
|
||||
cursor = db.inventory.find({"tags": {"$size": 3}})
|
||||
# End Example 28
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
def test_query_array_of_documents(self):
|
||||
db = self.db
|
||||
@ -360,49 +367,49 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
cursor = db.inventory.find({"instock": SON([("warehouse", "A"), ("qty", 5)])})
|
||||
# End Example 30
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 31
|
||||
cursor = db.inventory.find({"instock": SON([("qty", 5), ("warehouse", "A")])})
|
||||
# End Example 31
|
||||
|
||||
self.assertEqual(len(list(cursor)), 0)
|
||||
self.assertEqual(len(cursor.to_list()), 0)
|
||||
|
||||
# Start Example 32
|
||||
cursor = db.inventory.find({"instock.0.qty": {"$lte": 20}})
|
||||
# End Example 32
|
||||
|
||||
self.assertEqual(len(list(cursor)), 3)
|
||||
self.assertEqual(len(cursor.to_list()), 3)
|
||||
|
||||
# Start Example 33
|
||||
cursor = db.inventory.find({"instock.qty": {"$lte": 20}})
|
||||
# End Example 33
|
||||
|
||||
self.assertEqual(len(list(cursor)), 5)
|
||||
self.assertEqual(len(cursor.to_list()), 5)
|
||||
|
||||
# Start Example 34
|
||||
cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}})
|
||||
# End Example 34
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 35
|
||||
cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}})
|
||||
# End Example 35
|
||||
|
||||
self.assertEqual(len(list(cursor)), 3)
|
||||
self.assertEqual(len(cursor.to_list()), 3)
|
||||
|
||||
# Start Example 36
|
||||
cursor = db.inventory.find({"instock.qty": {"$gt": 10, "$lte": 20}})
|
||||
# End Example 36
|
||||
|
||||
self.assertEqual(len(list(cursor)), 4)
|
||||
self.assertEqual(len(cursor.to_list()), 4)
|
||||
|
||||
# Start Example 37
|
||||
cursor = db.inventory.find({"instock.qty": 5, "instock.warehouse": "A"})
|
||||
# End Example 37
|
||||
|
||||
self.assertEqual(len(list(cursor)), 2)
|
||||
self.assertEqual(len(cursor.to_list()), 2)
|
||||
|
||||
def test_query_null(self):
|
||||
db = self.db
|
||||
@ -415,19 +422,19 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
cursor = db.inventory.find({"item": None})
|
||||
# End Example 39
|
||||
|
||||
self.assertEqual(len(list(cursor)), 2)
|
||||
self.assertEqual(len(cursor.to_list()), 2)
|
||||
|
||||
# Start Example 40
|
||||
cursor = db.inventory.find({"item": {"$type": 10}})
|
||||
# End Example 40
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
# Start Example 41
|
||||
cursor = db.inventory.find({"item": {"$exists": False}})
|
||||
# End Example 41
|
||||
|
||||
self.assertEqual(len(list(cursor)), 1)
|
||||
self.assertEqual(len(cursor.to_list()), 1)
|
||||
|
||||
def test_projection(self):
|
||||
db = self.db
|
||||
@ -473,7 +480,7 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
cursor = db.inventory.find({"status": "A"})
|
||||
# End Example 43
|
||||
|
||||
self.assertEqual(len(list(cursor)), 3)
|
||||
self.assertEqual(len(cursor.to_list()), 3)
|
||||
|
||||
# Start Example 44
|
||||
cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1})
|
||||
@ -746,8 +753,9 @@ class TestSampleShellCommands(IntegrationTest):
|
||||
while not done:
|
||||
db.inventory.insert_one({"username": "alice"})
|
||||
db.inventory.delete_one({"username": "alice"})
|
||||
time.sleep(0.005)
|
||||
|
||||
t = threading.Thread(target=insert_docs)
|
||||
t = ConcurrentRunner(target=insert_docs)
|
||||
t.start()
|
||||
|
||||
try:
|
||||
@ -1347,20 +1355,37 @@ class TestSnapshotQueryExamples(IntegrationTest):
|
||||
db.drop_collection("dogs")
|
||||
db.cats.insert_one({"name": "Whiskers", "color": "white", "age": 10, "adoptable": True})
|
||||
db.dogs.insert_one({"name": "Pebbles", "color": "Brown", "age": 10, "adoptable": True})
|
||||
wait_until(lambda: self.check_for_snapshot(db.cats), "success")
|
||||
wait_until(lambda: self.check_for_snapshot(db.dogs), "success")
|
||||
|
||||
def predicate_one():
|
||||
return self.check_for_snapshot(db.cats)
|
||||
|
||||
def predicate_two():
|
||||
return self.check_for_snapshot(db.dogs)
|
||||
|
||||
wait_until(predicate_two, "success")
|
||||
wait_until(predicate_one, "success")
|
||||
|
||||
# Start Snapshot Query Example 1
|
||||
|
||||
db = client.pets
|
||||
with client.start_session(snapshot=True) as s:
|
||||
adoptablePetsCount = db.cats.aggregate(
|
||||
[{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}], session=s
|
||||
).next()["adoptableCatsCount"]
|
||||
adoptablePetsCount = (
|
||||
(
|
||||
db.cats.aggregate(
|
||||
[{"$match": {"adoptable": True}}, {"$count": "adoptableCatsCount"}],
|
||||
session=s,
|
||||
)
|
||||
).next()
|
||||
)["adoptableCatsCount"]
|
||||
|
||||
adoptablePetsCount += db.dogs.aggregate(
|
||||
[{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}], session=s
|
||||
).next()["adoptableDogsCount"]
|
||||
adoptablePetsCount += (
|
||||
(
|
||||
db.dogs.aggregate(
|
||||
[{"$match": {"adoptable": True}}, {"$count": "adoptableDogsCount"}],
|
||||
session=s,
|
||||
)
|
||||
).next()
|
||||
)["adoptableDogsCount"]
|
||||
|
||||
print(adoptablePetsCount)
|
||||
|
||||
@ -1371,33 +1396,41 @@ class TestSnapshotQueryExamples(IntegrationTest):
|
||||
|
||||
saleDate = datetime.datetime.now()
|
||||
db.sales.insert_one({"shoeType": "boot", "price": 30, "saleDate": saleDate})
|
||||
wait_until(lambda: self.check_for_snapshot(db.sales), "success")
|
||||
|
||||
def predicate_three():
|
||||
return self.check_for_snapshot(db.sales)
|
||||
|
||||
wait_until(predicate_three, "success")
|
||||
|
||||
# Start Snapshot Query Example 2
|
||||
db = client.retail
|
||||
with client.start_session(snapshot=True) as s:
|
||||
db.sales.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"$expr": {
|
||||
"$gt": [
|
||||
"$saleDate",
|
||||
{
|
||||
"$dateSubtract": {
|
||||
"startDate": "$$NOW",
|
||||
"unit": "day",
|
||||
"amount": 1,
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$count": "totalDailySales"},
|
||||
],
|
||||
session=s,
|
||||
).next()["totalDailySales"]
|
||||
_ = (
|
||||
(
|
||||
db.sales.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"$expr": {
|
||||
"$gt": [
|
||||
"$saleDate",
|
||||
{
|
||||
"$dateSubtract": {
|
||||
"startDate": "$$NOW",
|
||||
"unit": "day",
|
||||
"amount": 1,
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$count": "totalDailySales"},
|
||||
],
|
||||
session=s,
|
||||
)
|
||||
).next()
|
||||
)["totalDailySales"]
|
||||
|
||||
# End Snapshot Query Example 2
|
||||
|
||||
|
||||
@ -17,14 +17,20 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "gridfs")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
@ -26,6 +26,8 @@ from pymongo.errors import ConnectionFailure
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.synchronous.monitor import Monitor
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestHeartbeatMonitoring(IntegrationTest):
|
||||
def create_mock_monitor(self, responses, uri, expected_results):
|
||||
@ -40,8 +42,12 @@ class TestHeartbeatMonitoring(IntegrationTest):
|
||||
raise responses[1]
|
||||
return Hello(responses[1]), 99
|
||||
|
||||
m = self.single_client(
|
||||
h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool
|
||||
_ = self.single_client(
|
||||
h=uri,
|
||||
event_listeners=(listener,),
|
||||
_monitor_class=MockMonitor,
|
||||
_pool_class=MockPool,
|
||||
connect=True,
|
||||
)
|
||||
|
||||
expected_len = len(expected_results)
|
||||
@ -50,20 +56,16 @@ class TestHeartbeatMonitoring(IntegrationTest):
|
||||
# of this test.
|
||||
wait_until(lambda: len(listener.events) >= expected_len, "publish all events")
|
||||
|
||||
try:
|
||||
# zip gives us len(expected_results) pairs.
|
||||
for expected, actual in zip(expected_results, listener.events):
|
||||
self.assertEqual(expected, actual.__class__.__name__)
|
||||
self.assertEqual(actual.connection_id, responses[0])
|
||||
if expected != "ServerHeartbeatStartedEvent":
|
||||
if isinstance(actual.reply, Hello):
|
||||
self.assertEqual(actual.duration, 99)
|
||||
self.assertEqual(actual.reply._doc, responses[1])
|
||||
else:
|
||||
self.assertEqual(actual.reply, responses[1])
|
||||
|
||||
finally:
|
||||
m.close()
|
||||
# zip gives us len(expected_results) pairs.
|
||||
for expected, actual in zip(expected_results, listener.events):
|
||||
self.assertEqual(expected, actual.__class__.__name__)
|
||||
self.assertEqual(actual.connection_id, responses[0])
|
||||
if expected != "ServerHeartbeatStartedEvent":
|
||||
if isinstance(actual.reply, Hello):
|
||||
self.assertEqual(actual.duration, 99)
|
||||
self.assertEqual(actual.reply._doc, responses[1])
|
||||
else:
|
||||
self.assertEqual(actual.reply, responses[1])
|
||||
|
||||
def test_standalone(self):
|
||||
responses = (
|
||||
|
||||
@ -15,7 +15,9 @@
|
||||
"""Run the auth spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
@ -27,16 +29,22 @@ sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, PyMongoTestCase, unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
from test.utils import AllowListEventListener, EventListener, OvertCommandListener
|
||||
from test.utils import AllowListEventListener, OvertCommandListener
|
||||
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
pytestmark = pytest.mark.index_management
|
||||
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "index_management")
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management")
|
||||
|
||||
_NAME = "test-search-index"
|
||||
|
||||
@ -82,23 +90,25 @@ class SearchIndexIntegrationBase(PyMongoTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
if not os.environ.get("TEST_INDEX_MANAGEMENT"):
|
||||
raise unittest.SkipTest("Skipping index management tests")
|
||||
url = os.environ.get("MONGODB_URI")
|
||||
username = os.environ["DB_USER"]
|
||||
password = os.environ["DB_PASSWORD"]
|
||||
cls.listener = listener = OvertCommandListener()
|
||||
cls.client = cls.unmanaged_simple_client(
|
||||
url, username=username, password=password, event_listeners=[listener]
|
||||
)
|
||||
cls.client.drop_database(_NAME)
|
||||
cls.db = cls.client[cls.db_name]
|
||||
cls.url = os.environ.get("MONGODB_URI")
|
||||
cls.username = os.environ["DB_USER"]
|
||||
cls.password = os.environ["DB_PASSWORD"]
|
||||
cls.listener = OvertCommandListener()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.client.drop_database(_NAME)
|
||||
cls.client.close()
|
||||
def setUp(self) -> None:
|
||||
self.client = self.simple_client(
|
||||
self.url,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
event_listeners=[self.listener],
|
||||
)
|
||||
self.client.drop_database(_NAME)
|
||||
self.db = self.client[self.db_name]
|
||||
|
||||
def tearDown(self):
|
||||
self.client.drop_database(_NAME)
|
||||
|
||||
def wait_for_ready(self, coll, name=_NAME, predicate=None):
|
||||
"""Wait for a search index to be ready."""
|
||||
@ -107,10 +117,9 @@ class SearchIndexIntegrationBase(PyMongoTestCase):
|
||||
predicate = lambda index: index.get("queryable") is True
|
||||
|
||||
while True:
|
||||
indices = list(coll.list_search_indexes(name))
|
||||
indices = (coll.list_search_indexes(name)).to_list()
|
||||
if len(indices) and predicate(indices[0]):
|
||||
return indices[0]
|
||||
break
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
@ -133,7 +142,7 @@ class TestSearchIndexIntegration(SearchIndexIntegrationBase):
|
||||
|
||||
# Get the index definition.
|
||||
self.listener.reset()
|
||||
coll0.list_search_indexes(name=implicit_search_resp, comment="foo").next()
|
||||
(coll0.list_search_indexes(name=implicit_search_resp, comment="foo")).next()
|
||||
event = self.listener.events[0]
|
||||
self.assertEqual(event.command["comment"], "foo")
|
||||
|
||||
@ -183,7 +192,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
|
||||
)
|
||||
|
||||
# .Assert that the command returns an array containing the new indexes' names: ``["test-search-index-1", "test-search-index-2"]``.
|
||||
indices = list(coll0.list_search_indexes())
|
||||
indices = (coll0.list_search_indexes()).to_list()
|
||||
names = [i["name"] for i in indices]
|
||||
self.assertIn(name1, names)
|
||||
self.assertIn(name2, names)
|
||||
@ -223,7 +232,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
|
||||
# Run ``coll0.listSearchIndexes()`` repeatedly every 5 seconds until ``listSearchIndexes`` returns an empty array.
|
||||
t0 = time.time()
|
||||
while True:
|
||||
indices = list(coll0.list_search_indexes())
|
||||
indices = (coll0.list_search_indexes()).to_list()
|
||||
if indices:
|
||||
break
|
||||
if (time.time() - t0) / 60 > 5:
|
||||
@ -259,7 +268,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
|
||||
self.wait_for_ready(coll0, predicate=predicate)
|
||||
|
||||
# Assert that an index is present with the name ``test-search-index`` and the definition has a property ``latestDefinition`` whose value is ``{ 'mappings': { 'dynamic': true } }``.
|
||||
index = list(coll0.list_search_indexes(_NAME))[0]
|
||||
index = ((coll0.list_search_indexes(_NAME)).to_list())[0]
|
||||
self.assertIn("latestDefinition", index)
|
||||
self.assertEqual(index["latestDefinition"], model2["definition"])
|
||||
|
||||
@ -324,7 +333,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
|
||||
)
|
||||
|
||||
# Get the index definition.
|
||||
resp = coll0.list_search_indexes(name=implicit_search_resp).next()
|
||||
resp = (coll0.list_search_indexes(name=implicit_search_resp)).next()
|
||||
|
||||
# Assert that the index model contains the correct index type: ``"search"``.
|
||||
self.assertEqual(resp["type"], "search")
|
||||
@ -335,7 +344,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
|
||||
)
|
||||
|
||||
# Get the index definition.
|
||||
resp = coll0.list_search_indexes(name=explicit_search_resp).next()
|
||||
resp = (coll0.list_search_indexes(name=explicit_search_resp)).next()
|
||||
|
||||
# Assert that the index model contains the correct index type: ``"search"``.
|
||||
self.assertEqual(resp["type"], "search")
|
||||
@ -350,7 +359,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase):
|
||||
)
|
||||
|
||||
# Get the index definition.
|
||||
resp = coll0.list_search_indexes(name=explicit_vector_resp).next()
|
||||
resp = (coll0.list_search_indexes(name=explicit_vector_resp)).next()
|
||||
|
||||
# Assert that the index model contains the correct index type: ``"vectorSearch"``.
|
||||
self.assertEqual(resp["type"], "vectorSearch")
|
||||
|
||||
@ -21,13 +21,13 @@ import re
|
||||
import sys
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import Any, List, MutableMapping, Tuple, Type
|
||||
from typing import Any, Tuple, Type
|
||||
|
||||
from bson.codec_options import CodecOptions, DatetimeConversion
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, unittest
|
||||
from test import unittest
|
||||
|
||||
from bson import EPOCH_AWARE, EPOCH_NAIVE, SON, DatetimeMS, json_util
|
||||
from bson.binary import (
|
||||
@ -636,24 +636,5 @@ class TestJsonUtil(unittest.TestCase):
|
||||
self.assertEqual(json_util.dumps(MyBinary(b"bin", USER_DEFINED_SUBTYPE)), expected_json)
|
||||
|
||||
|
||||
class TestJsonUtilRoundtrip(IntegrationTest):
|
||||
def test_cursor(self):
|
||||
db = self.db
|
||||
|
||||
db.drop_collection("test")
|
||||
docs: List[MutableMapping[str, Any]] = [
|
||||
{"foo": [1, 2]},
|
||||
{"bar": {"hello": "world"}},
|
||||
{"code": Code("function x() { return 1; }")},
|
||||
{"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)},
|
||||
{"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}},
|
||||
]
|
||||
|
||||
db.test.insert_many(docs)
|
||||
reloaded_docs = json_util.loads(json_util.dumps(db.test.find()))
|
||||
for doc in docs:
|
||||
self.assertTrue(doc in reloaded_docs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
28
test/test_json_util_integration.py
Normal file
28
test/test_json_util_integration.py
Normal file
@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from test import IntegrationTest
|
||||
from typing import Any, List, MutableMapping
|
||||
|
||||
from bson import Binary, Code, DBRef, ObjectId, json_util
|
||||
from bson.binary import USER_DEFINED_SUBTYPE
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestJsonUtilRoundtrip(IntegrationTest):
|
||||
def test_cursor(self):
|
||||
db = self.db
|
||||
|
||||
db.drop_collection("test")
|
||||
docs: List[MutableMapping[str, Any]] = [
|
||||
{"foo": [1, 2]},
|
||||
{"bar": {"hello": "world"}},
|
||||
{"code": Code("function x() { return 1; }")},
|
||||
{"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)},
|
||||
{"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}},
|
||||
]
|
||||
|
||||
db.test.insert_many(docs)
|
||||
reloaded_docs = json_util.loads(json_util.dumps((db.test.find()).to_list()))
|
||||
for doc in docs:
|
||||
self.assertTrue(doc in reloaded_docs)
|
||||
@ -15,10 +15,12 @@
|
||||
"""Test maxStalenessSeconds support."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import _Op
|
||||
@ -31,11 +33,16 @@ from test.utils_selection_tests import create_selection_tests
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "max_staleness")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "max_staleness")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "max_staleness")
|
||||
|
||||
|
||||
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
|
||||
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@ -26,18 +26,20 @@ sys.path[0:0] = [""]
|
||||
from test import IntegrationTest, client_context
|
||||
|
||||
from bson.codec_options import CodecOptions
|
||||
from pymongo.synchronous.encryption import _HAVE_PYMONGOCRYPT, ClientEncryption, EncryptionError
|
||||
from pymongo.synchronous.encryption import (
|
||||
_HAVE_PYMONGOCRYPT,
|
||||
ClientEncryption,
|
||||
EncryptionError,
|
||||
)
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
pytestmark = pytest.mark.csfle
|
||||
|
||||
|
||||
class TestonDemandGCPCredentials(IntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@client_context.require_version_min(4, 2, -1)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.master_key = {
|
||||
@ -74,12 +76,8 @@ class TestonDemandGCPCredentials(IntegrationTest):
|
||||
|
||||
|
||||
class TestonDemandAzureCredentials(IntegrationTest):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
@client_context.require_version_min(4, 2, -1)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.master_key = {
|
||||
|
||||
@ -27,6 +27,8 @@ from bson.son import SON
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.read_concern import ReadConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestReadConcern(IntegrationTest):
|
||||
listener: OvertCommandListener
|
||||
@ -71,14 +73,14 @@ class TestReadConcern(IntegrationTest):
|
||||
def test_find_command(self):
|
||||
# readConcern not sent in command if not specified.
|
||||
coll = self.db.coll
|
||||
tuple(coll.find({"field": "value"}))
|
||||
coll.find({"field": "value"}).to_list()
|
||||
self.assertNotIn("readConcern", self.listener.started_events[0].command)
|
||||
|
||||
self.listener.reset()
|
||||
|
||||
# Explicitly set readConcern to 'local'.
|
||||
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
|
||||
tuple(coll.find({"field": "value"}))
|
||||
coll.find({"field": "value"}).to_list()
|
||||
self.assertEqualCommand(
|
||||
SON(
|
||||
[
|
||||
@ -93,19 +95,19 @@ class TestReadConcern(IntegrationTest):
|
||||
def test_command_cursor(self):
|
||||
# readConcern not sent in command if not specified.
|
||||
coll = self.db.coll
|
||||
tuple(coll.aggregate([{"$match": {"field": "value"}}]))
|
||||
(coll.aggregate([{"$match": {"field": "value"}}])).to_list()
|
||||
self.assertNotIn("readConcern", self.listener.started_events[0].command)
|
||||
|
||||
self.listener.reset()
|
||||
|
||||
# Explicitly set readConcern to 'local'.
|
||||
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
|
||||
tuple(coll.aggregate([{"$match": {"field": "value"}}]))
|
||||
(coll.aggregate([{"$match": {"field": "value"}}])).to_list()
|
||||
self.assertEqual({"level": "local"}, self.listener.started_events[0].command["readConcern"])
|
||||
|
||||
def test_aggregate_out(self):
|
||||
coll = self.db.get_collection("coll", read_concern=ReadConcern("local"))
|
||||
tuple(coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}]))
|
||||
(coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])).to_list()
|
||||
|
||||
# Aggregate with $out supports readConcern MongoDB 4.2 onwards.
|
||||
if client_context.version >= (4, 1):
|
||||
|
||||
@ -26,7 +26,13 @@ from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, SkipTest, client_context, connected, unittest
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
SkipTest,
|
||||
client_context,
|
||||
connected,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
one,
|
||||
@ -49,16 +55,22 @@ from pymongo.read_preferences import (
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import Selection, readable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestSelections(IntegrationTest):
|
||||
@client_context.require_connection
|
||||
def test_bool(self):
|
||||
client = self.single_client()
|
||||
|
||||
wait_until(lambda: client.address, "discover primary")
|
||||
def predicate():
|
||||
return client.address
|
||||
|
||||
wait_until(predicate, "discover primary")
|
||||
selection = Selection.from_topology_description(client._topology.description)
|
||||
|
||||
self.assertTrue(selection)
|
||||
@ -88,11 +100,7 @@ class TestReadPreferenceObjects(unittest.TestCase):
|
||||
|
||||
|
||||
class TestReadPreferencesBase(IntegrationTest):
|
||||
@classmethod
|
||||
@client_context.require_secondaries_count(1)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Insert some data so we can use cursors in read_from_which_host
|
||||
@ -123,11 +131,14 @@ class TestReadPreferencesBase(IntegrationTest):
|
||||
f"Cursor used address {address}, expected either primary "
|
||||
f"{client.primary} or secondaries {client.secondaries}"
|
||||
)
|
||||
return None
|
||||
|
||||
def assertReadsFrom(self, expected, **kwargs):
|
||||
c = self.rs_client(**kwargs)
|
||||
wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes")
|
||||
|
||||
def predicate():
|
||||
return len(c.nodes - c.arbiters) == client_context.w
|
||||
|
||||
wait_until(predicate, "discovered all nodes")
|
||||
|
||||
used = self.read_from_which_kind(c)
|
||||
self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}")
|
||||
@ -150,7 +161,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase):
|
||||
|
||||
# Test find and find_one.
|
||||
self.assertIsNotNone(coll.find_one())
|
||||
self.assertEqual(10, len(list(coll.find())))
|
||||
self.assertEqual(10, len(coll.find().to_list()))
|
||||
|
||||
# Test some database helpers.
|
||||
self.assertIsNotNone(db.list_collection_names())
|
||||
@ -173,20 +184,22 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
ReadPreference.SECONDARY_PREFERRED,
|
||||
ReadPreference.NEAREST,
|
||||
):
|
||||
self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference)
|
||||
self.assertEqual(mode, (self.rs_client(read_preference=mode)).read_preference)
|
||||
|
||||
self.assertRaises(TypeError, self.rs_client, read_preference="foo")
|
||||
with self.assertRaises(TypeError):
|
||||
self.rs_client(read_preference="foo")
|
||||
|
||||
def test_tag_sets_validation(self):
|
||||
S = Secondary(tag_sets=[{}])
|
||||
self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets)
|
||||
self.assertEqual([{}], (self.rs_client(read_preference=S)).read_preference.tag_sets)
|
||||
|
||||
S = Secondary(tag_sets=[{"k": "v"}])
|
||||
self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets)
|
||||
self.assertEqual([{"k": "v"}], (self.rs_client(read_preference=S)).read_preference.tag_sets)
|
||||
|
||||
S = Secondary(tag_sets=[{"k": "v"}, {}])
|
||||
self.assertEqual(
|
||||
[{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets
|
||||
[{"k": "v"}, {}],
|
||||
(self.rs_client(read_preference=S)).read_preference.tag_sets,
|
||||
)
|
||||
|
||||
self.assertRaises(ValueError, Secondary, tag_sets=[])
|
||||
@ -200,22 +213,27 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
|
||||
def test_threshold_validation(self):
|
||||
self.assertEqual(
|
||||
17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms
|
||||
17,
|
||||
(self.rs_client(localThresholdMS=17, connect=False)).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms
|
||||
42,
|
||||
(self.rs_client(localThresholdMS=42, connect=False)).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms
|
||||
666,
|
||||
(self.rs_client(localThresholdMS=666, connect=False)).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms
|
||||
0,
|
||||
(self.rs_client(localThresholdMS=0, connect=False)).options.local_threshold_ms,
|
||||
)
|
||||
|
||||
self.assertRaises(ValueError, self.rs_client, localthresholdms=-1)
|
||||
with self.assertRaises(ValueError):
|
||||
self.rs_client(localthresholdms=-1)
|
||||
|
||||
def test_zero_latency(self):
|
||||
ping_times: set = set()
|
||||
@ -238,7 +256,8 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
|
||||
def test_primary_with_tags(self):
|
||||
# Tags not allowed with PRIMARY
|
||||
self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}])
|
||||
with self.assertRaises(ConfigurationError):
|
||||
self.rs_client(tag_sets=[{"dc": "ny"}])
|
||||
|
||||
def test_primary_preferred(self):
|
||||
self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED)
|
||||
@ -272,7 +291,7 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
not_used = data_members.difference(used)
|
||||
latencies = ", ".join(
|
||||
"%s: %sms" % (server.description.address, server.description.round_trip_time)
|
||||
for server in c._get_topology().select_servers(readable_server_selector, _Op.TEST)
|
||||
for server in (c._get_topology()).select_servers(readable_server_selector, _Op.TEST)
|
||||
)
|
||||
|
||||
self.assertFalse(
|
||||
@ -289,12 +308,9 @@ class ReadPrefTester(MongoClient):
|
||||
client_options.update(kwargs)
|
||||
super().__init__(*args, **client_options)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _conn_for_reads(self, read_preference, session, operation):
|
||||
context = super()._conn_for_reads(read_preference, session, operation)
|
||||
with context as (conn, read_preference):
|
||||
self.record_a_read(conn.address)
|
||||
yield conn, read_preference
|
||||
return context
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _conn_from_server(self, read_preference, server, session):
|
||||
@ -304,7 +320,7 @@ class ReadPrefTester(MongoClient):
|
||||
yield conn, read_preference
|
||||
|
||||
def record_a_read(self, address):
|
||||
server = self._get_topology().select_server_by_address(address, _Op.TEST, 0)
|
||||
server = (self._get_topology()).select_server_by_address(address, _Op.TEST, 0)
|
||||
self.has_read_from.add(server)
|
||||
|
||||
|
||||
@ -321,25 +337,23 @@ class TestCommandAndReadPreference(IntegrationTest):
|
||||
c: ReadPrefTester
|
||||
client_version: Version
|
||||
|
||||
@classmethod
|
||||
@client_context.require_secondaries_count(1)
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.c = ReadPrefTester(
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.c = ReadPrefTester(
|
||||
# Ignore round trip times, to test ReadPreference modes only.
|
||||
localThresholdMS=1000 * 1000,
|
||||
)
|
||||
cls.client_version = Version.from_client(cls.c)
|
||||
self.client_version = Version.from_client(self.c)
|
||||
# mapReduce fails if the collection does not exist.
|
||||
coll = cls.c.pymongo_test.get_collection(
|
||||
coll = self.c.pymongo_test.get_collection(
|
||||
"test", write_concern=WriteConcern(w=client_context.w)
|
||||
)
|
||||
coll.insert_one({})
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.c.drop_database("pymongo_test")
|
||||
cls.c.close()
|
||||
def tearDown(self):
|
||||
self.c.drop_database("pymongo_test")
|
||||
self.c.close()
|
||||
|
||||
def executed_on_which_server(self, client, fn, *args, **kwargs):
|
||||
"""Execute fn(*args, **kwargs) and return the Server instance used."""
|
||||
@ -366,7 +380,7 @@ class TestCommandAndReadPreference(IntegrationTest):
|
||||
break
|
||||
|
||||
assert self.c.primary is not None
|
||||
unused = self.c.secondaries.union({self.c.primary}).difference(used)
|
||||
unused = (self.c.secondaries).union({self.c.primary}).difference(used)
|
||||
if unused:
|
||||
self.fail("Some members not used for NEAREST: %s" % (unused))
|
||||
else:
|
||||
@ -401,11 +415,12 @@ class TestCommandAndReadPreference(IntegrationTest):
|
||||
def test_create_collection(self):
|
||||
# create_collection runs listCollections on the primary to check if
|
||||
# the collection already exists.
|
||||
self._test_primary_helper(
|
||||
lambda: self.c.pymongo_test.create_collection(
|
||||
def func():
|
||||
return self.c.pymongo_test.create_collection(
|
||||
"some_collection%s" % random.randint(0, sys.maxsize)
|
||||
)
|
||||
)
|
||||
|
||||
self._test_primary_helper(func)
|
||||
|
||||
def test_count_documents(self):
|
||||
self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {})
|
||||
@ -545,7 +560,6 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
cases["secondary"] = Secondary
|
||||
listener = OvertCommandListener()
|
||||
client = self.rs_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command("ping")
|
||||
for _mode, cls in cases.items():
|
||||
pref = cls(hedge={"enabled": True})
|
||||
@ -645,10 +659,10 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
# tell what shard member a query ran on.
|
||||
for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()):
|
||||
qcoll = coll.with_options(read_preference=pref)
|
||||
results = list(qcoll.find().sort([("_id", 1)]))
|
||||
results = qcoll.find().sort([("_id", 1)]).to_list()
|
||||
self.assertEqual(first_id, results[0]["_id"])
|
||||
self.assertEqual(last_id, results[-1]["_id"])
|
||||
results = list(qcoll.find().sort([("_id", -1)]))
|
||||
results = qcoll.find().sort([("_id", -1)]).to_list()
|
||||
self.assertEqual(first_id, results[-1]["_id"])
|
||||
self.assertEqual(last_id, results[0]["_id"])
|
||||
|
||||
@ -671,14 +685,14 @@ class TestMongosAndReadPreference(IntegrationTest):
|
||||
else:
|
||||
self.fail("mongos accepted invalid staleness")
|
||||
|
||||
coll = self.single_client(
|
||||
readPreference="secondaryPreferred", maxStalenessSeconds=120
|
||||
coll = (
|
||||
self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=120)
|
||||
).pymongo_test.test
|
||||
# No error
|
||||
coll.find_one()
|
||||
|
||||
coll = self.single_client(
|
||||
readPreference="secondaryPreferred", maxStalenessSeconds=10
|
||||
coll = (
|
||||
self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=10)
|
||||
).pymongo_test.test
|
||||
try:
|
||||
coll.find_one()
|
||||
|
||||
@ -19,6 +19,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
@ -39,7 +40,13 @@ from pymongo.read_concern import ReadConcern
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "read_write_concern")
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern")
|
||||
|
||||
|
||||
class TestReadWriteConcernSpec(IntegrationTest):
|
||||
@ -47,7 +54,6 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
listener = OvertCommandListener()
|
||||
# Client with default readConcern and writeConcern
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
collection = client.pymongo_test.collection
|
||||
# Prepare for tests of find() and aggregate().
|
||||
collection.insert_many([{} for _ in range(10)])
|
||||
@ -66,9 +72,12 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
"insert", "collection", documents=[{}], write_concern=WriteConcern()
|
||||
)
|
||||
|
||||
def aggregate_op():
|
||||
(collection.aggregate([])).to_list()
|
||||
|
||||
ops = [
|
||||
("aggregate", lambda: list(collection.aggregate([]))),
|
||||
("find", lambda: list(collection.find())),
|
||||
("aggregate", aggregate_op),
|
||||
("find", lambda: collection.find().to_list()),
|
||||
("insert_one", lambda: collection.insert_one({})),
|
||||
("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})),
|
||||
("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})),
|
||||
@ -207,7 +216,6 @@ class TestReadWriteConcernSpec(IntegrationTest):
|
||||
def test_write_error_details_exposes_errinfo(self):
|
||||
listener = OvertCommandListener()
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
self.addCleanup(client.close)
|
||||
db = client.errinfotest
|
||||
self.addCleanup(client.drop_database, "errinfotest")
|
||||
validator = {"x": {"$type": "string"}}
|
||||
@ -286,7 +294,7 @@ def create_document_test(test_case):
|
||||
|
||||
|
||||
def create_tests():
|
||||
for dirpath, _, filenames in os.walk(_TEST_PATH):
|
||||
for dirpath, _, filenames in os.walk(TEST_PATH):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
if dirname == "operation":
|
||||
@ -321,7 +329,7 @@ create_tests()
|
||||
# PyMongo does not support MapReduce.
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(_TEST_PATH, "operation"),
|
||||
os.path.join(TEST_PATH, "operation"),
|
||||
module=__name__,
|
||||
expected_failures=["MapReduce .*"],
|
||||
)
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Test the Retryable Reads unified spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@ -23,8 +24,13 @@ sys.path[0:0] = [""]
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = Path(__file__).parent / "retryable_reads/unified"
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified")
|
||||
|
||||
# Generate unified tests.
|
||||
# PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects.
|
||||
|
||||
@ -17,14 +17,20 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_writes", "unified")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
@ -1,15 +1,37 @@
|
||||
# Copyright 2024-Present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run Command unified tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "run_command")
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(_TEST_PATH, "unified"),
|
||||
os.path.join(TEST_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
@ -17,19 +17,25 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection_logging")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
_TEST_PATH,
|
||||
TEST_PATH,
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
@ -18,18 +18,24 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test import PyMongoTestCase, unittest
|
||||
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection/rtt")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection/rtt")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection/rtt")
|
||||
|
||||
|
||||
class TestAllScenarios(unittest.TestCase):
|
||||
class TestAllScenarios(PyMongoTestCase):
|
||||
pass
|
||||
|
||||
|
||||
@ -49,7 +55,7 @@ def create_test(scenario_def):
|
||||
|
||||
|
||||
def create_tests():
|
||||
for dirpath, _, filenames in os.walk(_TEST_PATH):
|
||||
for dirpath, _, filenames in os.walk(TEST_PATH):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
|
||||
@ -17,14 +17,21 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sessions")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions")
|
||||
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
@ -15,8 +15,9 @@
|
||||
"""Run the SRV support tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from time import sleep
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -28,7 +29,8 @@ import pymongo
|
||||
from pymongo import common
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.srv_resolver import _have_dnspython
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
WAIT_TIME = 0.1
|
||||
|
||||
@ -168,6 +170,7 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
# Patch timeouts to ensure short test running times.
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
client._connect()
|
||||
self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
|
||||
# Patch list of hosts returned by DNS query.
|
||||
with SrvPollingKnobs(
|
||||
@ -232,6 +235,7 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
):
|
||||
# Client uses unpatched method to get initial nodelist
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
client._connect()
|
||||
# Invalid DNS resolver response should not change nodelist.
|
||||
self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
|
||||
|
||||
@ -265,6 +269,7 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0)
|
||||
client._connect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
self.assert_nodelist_change(response, client)
|
||||
|
||||
@ -279,6 +284,7 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
client._connect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
self.assert_nodelist_change(response, client)
|
||||
|
||||
@ -294,8 +300,9 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
|
||||
client._connect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
|
||||
time.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
|
||||
final_topology = set(client.topology_description.server_descriptions())
|
||||
self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology)
|
||||
self.assertEqual(len(final_topology), 2)
|
||||
@ -303,8 +310,9 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
def test_does_not_flipflop(self):
|
||||
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1)
|
||||
client._connect()
|
||||
old = set(client.topology_description.server_descriptions())
|
||||
sleep(4 * WAIT_TIME)
|
||||
time.sleep(4 * WAIT_TIME)
|
||||
new = set(client.topology_description.server_descriptions())
|
||||
self.assertSetEqual(old, new)
|
||||
|
||||
@ -322,6 +330,7 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
client = self.simple_client(
|
||||
"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
|
||||
)
|
||||
client._connect()
|
||||
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
|
||||
self.assert_nodelist_change(response, client)
|
||||
|
||||
@ -337,9 +346,9 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
nodelist_callback=resolver_response,
|
||||
):
|
||||
client = self.simple_client(self.CONNECTION_STRING)
|
||||
self.assertRaises(
|
||||
AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2
|
||||
)
|
||||
client._connect()
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2)
|
||||
|
||||
def test_import_dns_resolver(self):
|
||||
# Regression test for PYTHON-4407
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import socket
|
||||
import sys
|
||||
|
||||
@ -65,7 +66,13 @@ except ImportError:
|
||||
if HAVE_SSL:
|
||||
import ssl
|
||||
|
||||
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
|
||||
_IS_SYNC = True
|
||||
|
||||
if _IS_SYNC:
|
||||
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "certificates")
|
||||
else:
|
||||
CERT_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "certificates")
|
||||
|
||||
CLIENT_PEM = os.path.join(CERT_PATH, "client.pem")
|
||||
CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem")
|
||||
CA_PEM = os.path.join(CERT_PATH, "ca.pem")
|
||||
@ -144,21 +151,18 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
coll.drop()
|
||||
coll.insert_one({"ssl": True})
|
||||
self.assertTrue(coll.find_one()["ssl"])
|
||||
self.assertTrue((coll.find_one())["ssl"])
|
||||
coll.drop()
|
||||
|
||||
@classmethod
|
||||
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# MongoClient should connect to the primary by default.
|
||||
cls.saved_port = MongoClient.PORT
|
||||
self.saved_port = MongoClient.PORT
|
||||
MongoClient.PORT = client_context.port
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
MongoClient.PORT = cls.saved_port
|
||||
super().tearDownClass()
|
||||
def tearDown(self):
|
||||
MongoClient.PORT = self.saved_port
|
||||
|
||||
@client_context.require_tls
|
||||
def test_simple_ssl(self):
|
||||
@ -548,7 +552,6 @@ class TestSSL(IntegrationTest):
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
)
|
||||
self.addCleanup(noauth.close)
|
||||
|
||||
with self.assertRaises(OperationFailure):
|
||||
noauth.pymongo_test.test.find_one()
|
||||
@ -562,7 +565,6 @@ class TestSSL(IntegrationTest):
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
event_listeners=[listener],
|
||||
)
|
||||
self.addCleanup(auth.close)
|
||||
|
||||
# No error
|
||||
auth.pymongo_test.test.find_one()
|
||||
@ -581,7 +583,6 @@ class TestSSL(IntegrationTest):
|
||||
client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
# No error
|
||||
client.pymongo_test.test.find_one()
|
||||
|
||||
@ -589,7 +590,6 @@ class TestSSL(IntegrationTest):
|
||||
client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
# No error
|
||||
client.pymongo_test.test.find_one()
|
||||
# Auth should fail if username and certificate do not match
|
||||
@ -602,7 +602,6 @@ class TestSSL(IntegrationTest):
|
||||
bad_client = self.simple_client(
|
||||
uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM
|
||||
)
|
||||
self.addCleanup(bad_client.close)
|
||||
|
||||
with self.assertRaises(OperationFailure):
|
||||
bad_client.pymongo_test.test.find_one()
|
||||
@ -615,7 +614,6 @@ class TestSSL(IntegrationTest):
|
||||
tlsAllowInvalidCertificates=True,
|
||||
tlsCertificateKeyFile=CLIENT_PEM,
|
||||
)
|
||||
self.addCleanup(bad_client.close)
|
||||
|
||||
with self.assertRaises(OperationFailure):
|
||||
bad_client.pymongo_test.test.find_one()
|
||||
|
||||
@ -30,6 +30,8 @@ from test.utils import (
|
||||
from pymongo import monitoring
|
||||
from pymongo.hello import HelloCompat
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestStreamingProtocol(IntegrationTest):
|
||||
@client_context.require_failCommand_appName
|
||||
@ -41,7 +43,6 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
heartbeatFrequencyMS=500,
|
||||
appName="failingHeartbeatTest",
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
# Force a connection.
|
||||
client.admin.command("ping")
|
||||
address = client.address
|
||||
@ -78,7 +79,7 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
def rediscovered():
|
||||
return len(listener.matching(_discovered_node)) >= 1
|
||||
|
||||
# Topology events are published asynchronously
|
||||
# Topology events are not published synchronously
|
||||
wait_until(marked_unknown, "mark node unknown")
|
||||
wait_until(rediscovered, "rediscover node")
|
||||
|
||||
@ -108,7 +109,6 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
client = self.rs_or_single_client(
|
||||
event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
# Force a connection.
|
||||
client.admin.command("ping")
|
||||
address = client.address
|
||||
@ -156,7 +156,6 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
client = self.single_client(
|
||||
appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
# Force a connection.
|
||||
client.admin.command("ping")
|
||||
duration = time.time() - start
|
||||
@ -183,7 +182,6 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
heartbeatFrequencyMS=500,
|
||||
appName="heartbeatEventAwaitedFlag",
|
||||
)
|
||||
self.addCleanup(client.close)
|
||||
# Force a connection.
|
||||
client.admin.command("ping")
|
||||
|
||||
|
||||
@ -17,12 +17,15 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import client_context, unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
@client_context.require_no_mmap
|
||||
def setUpModule():
|
||||
@ -30,15 +33,21 @@ def setUpModule():
|
||||
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "transactions", "unified")
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
# Location of JSON test specifications for transactions-convenient-api.
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "transactions-convenient-api", "unified"
|
||||
)
|
||||
if _IS_SYNC:
|
||||
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified")
|
||||
else:
|
||||
TEST_PATH = os.path.join(
|
||||
Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified"
|
||||
)
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user