From 3b2151760851c1c6cd09a62ab4126cb9580d731e Mon Sep 17 00:00:00 2001 From: Jib Date: Mon, 16 Sep 2024 22:23:09 -0400 Subject: [PATCH 1/6] PYTHON-4752 Migrate docs links to Internal Docs Where Possible (#1715) Co-authored-by: Steven Silvester --- CONTRIBUTING.md | 11 ++++------- README.md | 2 +- bson/_cbsonmodule.c | 2 +- bson/datetime_ms.py | 2 +- doc/index.rst | 7 +++++-- pymongo/asynchronous/topology.py | 5 +++-- pymongo/synchronous/topology.py | 5 +++-- pyproject.toml | 2 +- 8 files changed, 19 insertions(+), 17 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 42cc8dc1b..2c2a5f431 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -163,13 +163,10 @@ hatch run lint:build-manual ## Documentation -To contribute to the [API -documentation](https://pymongo.readthedocs.io/en/stable/) just make your -changes to the inline documentation of the appropriate [source -code](https://github.com/mongodb/mongo-python-driver) or [rst -file](https://github.com/mongodb/mongo-python-driver/tree/master/doc) in -a branch and submit a [pull -request](https://help.github.com/articles/using-pull-requests). You +To contribute to the [API documentation](https://pymongo.readthedocs.io/en/stable/) just make your +changes to the inline documentation of the appropriate [source code](https://github.com/mongodb/mongo-python-driver) or +[rst file](https://github.com/mongodb/mongo-python-driver/tree/master/doc) in +a branch and submit a [pull request](https://help.github.com/articles/using-pull-requests). You might also use the GitHub [Edit](https://github.com/blog/844-forking-with-the-edit-button) button. diff --git a/README.md b/README.md index bb773b795..1076b6637 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![PyPI Version](https://img.shields.io/pypi/v/pymongo)](https://pypi.org/project/pymongo) [![Python Versions](https://img.shields.io/pypi/pyversions/pymongo)](https://pypi.org/project/pymongo) [![Monthly Downloads](https://static.pepy.tech/badge/pymongo/month)](https://pepy.tech/project/pymongo) -[![Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/?badge=stable) +[![API Documentation Status](https://readthedocs.org/projects/pymongo/badge/?version=stable)](http://pymongo.readthedocs.io/en/stable/api?badge=stable) ## About diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 3e9d5ecc2..34b407b94 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -306,7 +306,7 @@ static PyObject* datetime_from_millis(long long millis) { if (evalue) { PyObject* err_msg = PyObject_Str(evalue); if (err_msg) { - PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes"); + PyObject* appendage = PyUnicode_FromString(" (Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO) or MongoClient(datetime_conversion='DATETIME_AUTO')). See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes"); if (appendage) { PyObject* msg = PyUnicode_Concat(err_msg, appendage); if (msg) { diff --git a/bson/datetime_ms.py b/bson/datetime_ms.py index 1b6fa2279..679524cb6 100644 --- a/bson/datetime_ms.py +++ b/bson/datetime_ms.py @@ -31,7 +31,7 @@ EPOCH_NAIVE = EPOCH_AWARE.replace(tzinfo=None) _DATETIME_ERROR_SUGGESTION = ( "(Consider Using CodecOptions(datetime_conversion=DATETIME_AUTO)" " or MongoClient(datetime_conversion='DATETIME_AUTO'))." - " See: https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes" + " See: https://www.mongodb.com/docs/languages/python/pymongo-driver/current/data-formats/dates-and-times/#handling-out-of-range-datetimes" ) diff --git a/doc/index.rst b/doc/index.rst index 71e142381..0ac8bdec6 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,6 +1,11 @@ PyMongo |release| Documentation =============================== +.. note:: The PyMongo documentation has been migrated to the + `MongoDB Documentation site `_. + As of PyMongo 4.10, the ReadTheDocs site will contain the detailed changelog and API docs, while the + rest of the documentation will only appear on the MongoDB Documentation site. + Overview -------- **PyMongo** is a Python distribution containing tools for working with @@ -95,8 +100,6 @@ pull request. Changes ------- See the :doc:`changelog` for a full list of changes to PyMongo. -For older versions of the documentation please see the -`archive list `_. About This Documentation ------------------------ diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 9dd1a1c76..4e778cbc1 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -227,8 +227,9 @@ class Topology: warnings.warn( # type: ignore[call-overload] # noqa: B028 "AsyncMongoClient opened before fork. May not be entirely fork-safe, " "proceed with caution. See PyMongo's documentation for details: " - "https://pymongo.readthedocs.io/en/stable/faq.html#" - "is-pymongo-fork-safe", + "https://www.mongodb.com/docs/languages/" + "python/pymongo-driver/current/faq/" + "#is-pymongo-fork-safe-", **kwargs, ) async with self._lock: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 414865154..e8070e30a 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -227,8 +227,9 @@ class Topology: warnings.warn( # type: ignore[call-overload] # noqa: B028 "MongoClient opened before fork. May not be entirely fork-safe, " "proceed with caution. See PyMongo's documentation for details: " - "https://pymongo.readthedocs.io/en/stable/faq.html#" - "is-pymongo-fork-safe", + "https://www.mongodb.com/docs/languages/" + "python/pymongo-driver/current/faq/" + "#is-pymongo-fork-safe-", **kwargs, ) with self._lock: diff --git a/pyproject.toml b/pyproject.toml index b64c7d603..2df172fde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ classifiers = [ [project.urls] Homepage = "https://www.mongodb.org" -Documentation = "https://pymongo.readthedocs.io" +Documentation = "https://www.mongodb.com/docs/languages/python/pymongo-driver/current/" Source = "https://github.com/mongodb/mongo-python-driver" Tracker = "https://jira.mongodb.org/projects/PYTHON/issues" From fb51c11cacce56dca3bf48a810947e45d6c01d2f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 16 Sep 2024 21:23:40 -0500 Subject: [PATCH 2/6] PYTHON-4756 Add changelog note about dropping srv extra (#1861) --- doc/changelog.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index ba3cba832..69fbb6f8f 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -93,6 +93,10 @@ Unavoidable breaking changes - Since we are now using ``hatch`` as our build backend, we no longer have a usable ``setup.py`` file and require installation using ``pip``. Attempts to invoke the ``setup.py`` file will raise an exception. Additionally, ``pip`` >= 21.3 is now required for editable installs. +- We no longer support the ``srv`` extra, since ``dnspython`` is included as a dependency in PyMongo 4.7+. + Instead of ``pip install pymongo[srv]``, use ``pip install pymongo``. +- We no longer support the ``tls`` extra, which was only valid for Python 2. + Instead of ``pip install pymongo[tls]``, use ``pip install pymongo``. Issues Resolved ............... From 739510214b799664829b1b085918d8c7eb4d67a1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 17 Sep 2024 09:22:17 -0400 Subject: [PATCH 3/6] PYTHON-4731 - Explicitly close all MongoClients opened during tests (#1855) --- pymongo/asynchronous/mongo_client.py | 1 - pymongo/synchronous/mongo_client.py | 1 - pyproject.toml | 3 - test/__init__.py | 189 +++++++- test/asynchronous/__init__.py | 205 +++++++- test/asynchronous/test_auth.py | 119 +++-- test/asynchronous/test_auth_spec.py | 5 +- test/asynchronous/test_bulk.py | 19 +- test/asynchronous/test_change_stream.py | 22 +- test/asynchronous/test_client.py | 438 +++++++++--------- test/asynchronous/test_client_bulk_write.py | 46 +- test/asynchronous/test_collection.py | 20 +- test/asynchronous/test_cursor.py | 34 +- test/asynchronous/test_database.py | 8 +- test/asynchronous/test_encryption.py | 218 ++++----- test/asynchronous/test_grid_file.py | 8 +- test/asynchronous/test_logger.py | 3 +- test/asynchronous/test_monitoring.py | 8 +- test/asynchronous/test_session.py | 18 +- test/asynchronous/test_transactions.py | 50 +- test/asynchronous/utils_spec_runner.py | 5 +- test/auth_aws/test_auth_aws.py | 35 +- test/auth_oidc/test_auth_oidc.py | 8 +- test/mockupdb/test_cursor.py | 7 +- test/ocsp/test_ocsp.py | 7 +- test/test_auth.py | 123 ++--- test/test_auth_spec.py | 5 +- test/test_bulk.py | 19 +- test/test_change_stream.py | 20 +- test/test_client.py | 386 ++++++++------- test/test_client_bulk_write.py | 46 +- test/test_collation.py | 4 +- test/test_collection.py | 20 +- test/test_comment.py | 8 +- test/test_common.py | 19 +- test/test_connection_monitoring.py | 19 +- ...nnections_survive_primary_stepdown_spec.py | 3 +- test/test_cursor.py | 34 +- test/test_custom_types.py | 3 +- test/test_data_lake.py | 8 +- test/test_database.py | 8 +- test/test_discovery_and_monitoring.py | 18 +- test/test_dns.py | 24 +- test/test_encryption.py | 216 ++++----- test/test_examples.py | 8 +- test/test_grid_file.py | 8 +- test/test_gridfs.py | 12 +- test/test_gridfs_bucket.py | 12 +- test/test_heartbeat_monitoring.py | 4 +- test/test_load_balancer.py | 8 +- test/test_logger.py | 3 +- test/test_max_staleness.py | 45 +- test/test_monitor.py | 37 +- test/test_monitoring.py | 10 +- test/test_pooling.py | 18 +- test/test_read_concern.py | 6 +- test/test_read_preferences.py | 46 +- test/test_read_write_concern_spec.py | 19 +- test/test_retryable_reads.py | 9 +- test/test_retryable_writes.py | 27 +- test/test_sdam_monitoring_spec.py | 3 +- test/test_server_selection.py | 7 +- test/test_server_selection_in_window.py | 3 +- test/test_session.py | 17 +- test/test_srv_polling.py | 24 +- test/test_ssl.py | 101 ++-- test/test_streaming_protocol.py | 10 +- test/test_transactions.py | 44 +- test/test_typing.py | 9 +- test/test_versioned_api.py | 6 +- test/unified_format.py | 8 +- test/utils.py | 159 ------- test/utils_spec_runner.py | 5 +- 73 files changed, 1608 insertions(+), 1520 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index f7fc8e5e8..6d0e5d528 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1193,7 +1193,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): ), ResourceWarning, stacklevel=2, - source=self, ) except AttributeError: pass diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 5786bbf5a..b2dff5b4a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1193,7 +1193,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): ), ResourceWarning, stacklevel=2, - source=self, ) except AttributeError: pass diff --git a/pyproject.toml b/pyproject.toml index 2df172fde..30c7c046b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,9 +96,6 @@ filterwarnings = [ "module:please use dns.resolver.Resolver.resolve:DeprecationWarning", # https://github.com/dateutil/dateutil/issues/1314 "module:datetime.datetime.utc:DeprecationWarning:dateutil", - # TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731 - "ignore:Unclosed AsyncMongoClient*", - "ignore:Unclosed MongoClient*", ] markers = [ "auth_aws: tests that rely on pymongo-auth-aws", diff --git a/test/__init__.py b/test/__init__.py index 41af81f97..1a17ff14c 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -16,8 +16,6 @@ from __future__ import annotations import asyncio -import base64 -import contextlib import gc import multiprocessing import os @@ -27,7 +25,6 @@ import subprocess import sys import threading import time -import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -54,6 +51,8 @@ from test.helpers import ( sanitize_reply, ) +from pymongo.uri_parser import parse_uri + try: import ipaddress @@ -80,6 +79,12 @@ from pymongo.synchronous.mongo_client import MongoClient _IS_SYNC = True +def _connection_string(h): + if h.startswith(("mongodb://", "mongodb+srv://")): + return h + return f"mongodb://{h!s}" + + class ClientContext: client: MongoClient @@ -230,6 +235,9 @@ class ClientContext: if not self._check_user_provided(): _create_user(self.client.admin, db_user, db_pwd) + if self.client: + self.client.close() + self.client = self._connect( host, port, @@ -256,6 +264,8 @@ class ClientContext: if "setName" in hello: self.replica_set_name = str(hello["setName"]) self.is_rs = True + if self.client: + self.client.close() if self.auth_enabled: # It doesn't matter which member we use as the seed here. self.client = pymongo.MongoClient( @@ -318,6 +328,7 @@ class ClientContext: hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") == "isdbgrid": self.mongoses.append(next_address) + mongos_client.close() def init(self): with self.conn_lock: @@ -537,12 +548,6 @@ class ClientContext: lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func ) - def require_no_fips(self, func): - """Run a test only if the host does not have FIPS enabled.""" - return self._require( - lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func - ) - def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" return self._require( @@ -930,6 +935,172 @@ class PyMongoTestCase(unittest.TestCase): self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?") self.assertEqual(proc.exitcode, 0) + @classmethod + def _unmanaged_async_mongo_client( + cls, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or client_context.host + port = port or client_context.port + client_options: dict = client_context.default_client_options.copy() + if client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = MongoClient(uri, port, **client_options) + if client._options.connect: + client._connect() + return client + + def _async_mongo_client(self, host, port, authenticate=True, directConnection=None, **kwargs): + """Create a new client over SSL/TLS if necessary.""" + host = host or client_context.host + port = port or client_context.port + client_options: dict = client_context.default_client_options.copy() + if client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = MongoClient(uri, port, **client_options) + if client._options.connect: + client._connect() + self.addCleanup(client.close) + return client + + @classmethod + def unmanaged_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return cls._unmanaged_async_mongo_client( + h, p, authenticate=False, directConnection=True, **kwargs + ) + + @classmethod + def unmanaged_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs) + + @classmethod + def unmanaged_rs_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return cls._unmanaged_async_mongo_client(h, p, **kwargs) + + @classmethod + def unmanaged_rs_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + + @classmethod + def unmanaged_rs_or_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + + @classmethod + def unmanaged_rs_or_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return cls._unmanaged_async_mongo_client(h, p, **kwargs) + + def single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return self._async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) + + def single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Make a direct connection, and authenticate if necessary.""" + return self._async_mongo_client(h, p, directConnection=True, **kwargs) + + def rs_client_noauth(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Connect to the replica set. Don't authenticate.""" + return self._async_mongo_client(h, p, authenticate=False, **kwargs) + + def rs_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return self._async_mongo_client(h, p, **kwargs) + + def rs_or_single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Connect to the replica set if there is one, otherwise the standalone. + + Like rs_or_single_client, but does not authenticate. + """ + return self._async_mongo_client(h, p, authenticate=False, **kwargs) + + def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]: + """Connect to the replica set if there is one, otherwise the standalone. + + Authenticates if necessary. + """ + return self._async_mongo_client(h, p, **kwargs) + + def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient: + if not h and not p: + client = MongoClient(**kwargs) + else: + client = MongoClient(h, p, **kwargs) + self.addCleanup(client.close) + return client + + @classmethod + def unmanaged_simple_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient: + if not h and not p: + client = MongoClient(**kwargs) + else: + client = MongoClient(h, p, **kwargs) + return client + + def disable_replication(self, client): + """Disable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = self.single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + + def enable_replication(self, client): + """Enable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = self.single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + class UnitTest(PyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index d1af89c18..0d9433158 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -16,8 +16,6 @@ from __future__ import annotations import asyncio -import base64 -import contextlib import gc import multiprocessing import os @@ -27,7 +25,6 @@ import subprocess import sys import threading import time -import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -54,6 +51,8 @@ from test.helpers import ( sanitize_reply, ) +from pymongo.uri_parser import parse_uri + try: import ipaddress @@ -80,6 +79,12 @@ from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] _IS_SYNC = False +def _connection_string(h): + if h.startswith(("mongodb://", "mongodb+srv://")): + return h + return f"mongodb://{h!s}" + + class AsyncClientContext: client: AsyncMongoClient @@ -230,6 +235,9 @@ class AsyncClientContext: if not await self._check_user_provided(): await _create_user(self.client.admin, db_user, db_pwd) + if self.client: + await self.client.close() + self.client = await self._connect( host, port, @@ -256,6 +264,8 @@ class AsyncClientContext: if "setName" in hello: self.replica_set_name = str(hello["setName"]) self.is_rs = True + if self.client: + await self.client.close() if self.auth_enabled: # It doesn't matter which member we use as the seed here. self.client = pymongo.AsyncMongoClient( @@ -320,6 +330,7 @@ class AsyncClientContext: hello = await mongos_client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") == "isdbgrid": self.mongoses.append(next_address) + await mongos_client.close() async def init(self): with self.conn_lock: @@ -539,12 +550,6 @@ class AsyncClientContext: lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func ) - def require_no_fips(self, func): - """Run a test only if the host does not have FIPS enabled.""" - return self._require( - lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func - ) - def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" return self._require( @@ -932,6 +937,188 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?") self.assertEqual(proc.exitcode, 0) + @classmethod + async def _unmanaged_async_mongo_client( + cls, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or await async_client_context.host + port = port or await async_client_context.port + client_options: dict = async_client_context.default_client_options.copy() + if async_client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = async_client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = AsyncMongoClient(uri, port, **client_options) + if client._options.connect: + await client.aconnect() + return client + + async def _async_mongo_client( + self, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or await async_client_context.host + port = port or await async_client_context.port + client_options: dict = async_client_context.default_client_options.copy() + if async_client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = async_client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = AsyncMongoClient(uri, port, **client_options) + if client._options.connect: + await client.aconnect() + self.addAsyncCleanup(client.close) + return client + + @classmethod + async def unmanaged_async_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await cls._unmanaged_async_mongo_client( + h, p, authenticate=False, directConnection=True, **kwargs + ) + + @classmethod + async def unmanaged_async_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs) + + @classmethod + async def unmanaged_async_rs_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return await cls._unmanaged_async_mongo_client(h, p, **kwargs) + + @classmethod + async def unmanaged_async_rs_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + + @classmethod + async def unmanaged_async_rs_or_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + + @classmethod + async def unmanaged_async_rs_or_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await cls._unmanaged_async_mongo_client(h, p, **kwargs) + + async def async_single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await self._async_mongo_client( + h, p, authenticate=False, directConnection=True, **kwargs + ) + + async def async_single_client( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection, and authenticate if necessary.""" + return await self._async_mongo_client(h, p, directConnection=True, **kwargs) + + async def async_rs_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set. Don't authenticate.""" + return await self._async_mongo_client(h, p, authenticate=False, **kwargs) + + async def async_rs_client( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return await self._async_mongo_client(h, p, **kwargs) + + async def async_rs_or_single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set if there is one, otherwise the standalone. + + Like rs_or_single_client, but does not authenticate. + """ + return await self._async_mongo_client(h, p, authenticate=False, **kwargs) + + async def async_rs_or_single_client( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[Any]: + """Connect to the replica set if there is one, otherwise the standalone. + + Authenticates if necessary. + """ + return await self._async_mongo_client(h, p, **kwargs) + + def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient: + if not h and not p: + client = AsyncMongoClient(**kwargs) + else: + client = AsyncMongoClient(h, p, **kwargs) + self.addAsyncCleanup(client.close) + return client + + @classmethod + def unmanaged_simple_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient: + if not h and not p: + client = AsyncMongoClient(**kwargs) + else: + client = AsyncMongoClient(h, p, **kwargs) + return client + + async def disable_replication(self, client): + """Disable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = await self.async_single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + + async def enable_replication(self, client): + """Enable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = await self.async_single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + class AsyncUnitTest(AsyncPyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 06f7fb9ca..fbaca41f0 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -23,16 +23,14 @@ from urllib.parse import quote_plus sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, SkipTest, async_client_context, unittest -from test.utils import ( - AllowListEventListener, - async_rs_or_single_client, - async_rs_or_single_client_noauth, - async_single_client, - async_single_client_noauth, - delay, - ignore_deprecations, +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + SkipTest, + async_client_context, + unittest, ) +from test.utils import AllowListEventListener, delay, ignore_deprecations from pymongo import AsyncMongoClient, monitoring from pymongo.asynchronous.auth import HAVE_KERBEROS @@ -81,7 +79,7 @@ class AutoAuthenticateThread(threading.Thread): self.success = True -class TestGSSAPI(unittest.IsolatedAsyncioTestCase): +class TestGSSAPI(AsyncPyMongoTestCase): mech_properties: str service_realm_required: bool @@ -138,7 +136,7 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase): if not self.service_realm_required: # Without authMechanismProperties. - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -149,11 +147,11 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase): await client[GSSAPI_DB].collection.find_one() # Log in using URI, without authMechanismProperties. - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -166,14 +164,14 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase): # Log in using URI, with authMechanismProperties. mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" - client = AsyncMongoClient(mech_uri) + client = self.simple_client(mech_uri) await client[GSSAPI_DB].collection.find_one() set_name = async_client_context.replica_set_name if set_name: if not self.service_realm_required: # Without authMechanismProperties - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -185,11 +183,11 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase): await client[GSSAPI_DB].list_collection_names() uri = uri + f"&replicaSet={set_name!s}" - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client[GSSAPI_DB].list_collection_names() # With authMechanismProperties - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -202,13 +200,13 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase): await client[GSSAPI_DB].list_collection_names() mech_uri = mech_uri + f"&replicaSet={set_name!s}" - client = AsyncMongoClient(mech_uri) + client = self.simple_client(mech_uri) await client[GSSAPI_DB].list_collection_names() @ignore_deprecations @async_client_context.require_sync async def test_gssapi_threaded(self): - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -244,7 +242,7 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase): set_name = async_client_context.replica_set_name if set_name: - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -267,14 +265,14 @@ class TestGSSAPI(unittest.IsolatedAsyncioTestCase): self.assertTrue(thread.success) -class TestSASLPlain(unittest.IsolatedAsyncioTestCase): +class TestSASLPlain(AsyncPyMongoTestCase): @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL") async def test_sasl_plain(self): - client = AsyncMongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, username=SASL_USER, @@ -293,12 +291,12 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase): SASL_PORT, SASL_DB, ) - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client.ldap.test.find_one() set_name = async_client_context.replica_set_name if set_name: - client = AsyncMongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, replicaSet=set_name, @@ -317,7 +315,7 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase): SASL_DB, str(set_name), ) - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client.ldap.test.find_one() async def test_sasl_plain_bad_credentials(self): @@ -331,8 +329,8 @@ class TestSASLPlain(unittest.IsolatedAsyncioTestCase): ) return uri - bad_user = AsyncMongoClient(auth_string("not-user", SASL_PASS)) - bad_pwd = AsyncMongoClient(auth_string(SASL_USER, "not-pwd")) + bad_user = self.simple_client(auth_string("not-user", SASL_PASS)) + bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. with self.assertRaises(OperationFailure): await bad_user.admin.command("ping") @@ -356,7 +354,7 @@ class TestSCRAMSHA1(AsyncIntegrationTest): async def test_scram_sha1(self): host, port = await async_client_context.host, await async_client_context.port - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port) ) await client.pymongo_test.command("dbstats") @@ -367,7 +365,7 @@ class TestSCRAMSHA1(AsyncIntegrationTest): "@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" "&replicaSet=%s" % (host, port, async_client_context.replica_set_name) ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) await client.pymongo_test.command("dbstats") db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) await db.command("dbstats") @@ -395,7 +393,7 @@ class TestSCRAM(AsyncIntegrationTest): "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", event_listeners=[listener] ) await client.testscram.command("dbstats") @@ -432,38 +430,38 @@ class TestSCRAM(AsyncIntegrationTest): ) # Step 2: verify auth success cases - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) await client.testscram.command("dbstats") # Step 2: SCRAM-SHA-1 and SCRAM-SHA-256 - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) await client.testscram.command("dbstats") self.listener.reset() - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", event_listeners=[self.listener] ) await client.testscram.command("dbstats") @@ -476,19 +474,19 @@ class TestSCRAM(AsyncIntegrationTest): self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256") # Step 3: verify auth failure conditions - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) with self.assertRaises(OperationFailure): await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) with self.assertRaises(OperationFailure): await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="not-a-user", password="pwd", authSource="testscram" ) with self.assertRaises(OperationFailure): @@ -501,7 +499,7 @@ class TestSCRAM(AsyncIntegrationTest): port, async_client_context.replica_set_name, ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) await client.testscram.command("dbstats") db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY) await db.command("dbstats") @@ -521,12 +519,12 @@ class TestSCRAM(AsyncIntegrationTest): "testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram", @@ -534,17 +532,17 @@ class TestSCRAM(AsyncIntegrationTest): ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="\u2168", password="IV", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram", @@ -552,31 +550,31 @@ class TestSCRAM(AsyncIntegrationTest): ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://\u2168:IV@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://IX:IX@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") async def test_cache(self): - client = await async_single_client() + client = await self.async_single_client() credentials = client.options.pool_options._credentials cache = credentials.cache self.assertIsNotNone(cache) @@ -601,8 +599,7 @@ class TestSCRAM(AsyncIntegrationTest): await coll.insert_one({"_id": 1}) # The first thread to call find() will authenticate - client = await async_rs_or_single_client() - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client() coll = client.db.test threads = [] for _ in range(4): @@ -631,7 +628,9 @@ class TestAuthURIOptions(AsyncIntegrationTest): async def test_uri_options(self): # Test default to admin host, port = await async_client_context.host, await async_client_context.port - client = await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + client = await self.async_rs_or_single_client_noauth( + "mongodb://admin:pass@%s:%d" % (host, port) + ) self.assertTrue(await client.admin.command("dbstats")) if async_client_context.is_rs: @@ -640,14 +639,14 @@ class TestAuthURIOptions(AsyncIntegrationTest): port, async_client_context.replica_set_name, ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) self.assertTrue(await client.admin.command("dbstats")) db = client.get_database("admin", read_preference=ReadPreference.SECONDARY) self.assertTrue(await db.command("dbstats")) # Test explicit database uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) - client = await async_rs_or_single_client_noauth(uri) + client = await self.async_rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.admin.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) @@ -658,7 +657,7 @@ class TestAuthURIOptions(AsyncIntegrationTest): port, async_client_context.replica_set_name, ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.admin.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) @@ -667,7 +666,7 @@ class TestAuthURIOptions(AsyncIntegrationTest): # Test authSource uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port) - client = await async_rs_or_single_client_noauth(uri) + client = await self.async_rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.pymongo_test2.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) @@ -677,7 +676,7 @@ class TestAuthURIOptions(AsyncIntegrationTest): "mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=" "%s;authSource=pymongo_test" % (host, port, async_client_context.replica_set_name) ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.pymongo_test2.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index 329b3eec6..a6ab1cb33 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -20,6 +20,7 @@ import json import os import sys import warnings +from test.asynchronous import AsyncPyMongoTestCase sys.path[0:0] = [""] @@ -34,7 +35,7 @@ _IS_SYNC = False _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") -class TestAuthSpec(unittest.IsolatedAsyncioTestCase): +class TestAuthSpec(AsyncPyMongoTestCase): pass @@ -54,7 +55,7 @@ def create_test(test_case): warnings.simplefilter("default") self.assertRaises(Exception, AsyncMongoClient, uri, connect=False) else: - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 79d8e1a0f..42a331107 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -24,23 +24,14 @@ from pymongo.asynchronous.mongo_client import AsyncMongoClient sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest -from test.utils import ( - async_rs_or_single_client_noauth, - async_single_client, - async_wait_until, -) +from test.utils import async_wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.objectid import ObjectId from pymongo.asynchronous.collection import AsyncCollection from pymongo.common import partition_node -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure, -) +from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import * from pymongo.write_concern import WriteConcern @@ -915,7 +906,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase): async def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = await async_rs_or_single_client_noauth( + cli = await self.async_rs_or_single_client_noauth( username="readonly", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -926,7 +917,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase): async def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = await async_rs_or_single_client_noauth( + cli = await self.async_rs_or_single_client_noauth( username="noremove", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -954,7 +945,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase): if cls.w is not None and cls.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = await async_single_client(*partition_node(member)) + cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) break @classmethod diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 1b89c43bb..883ed72c4 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -28,12 +28,17 @@ from typing import no_type_check sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, Version, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + Version, + async_client_context, + unittest, +) from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, EventListener, - async_rs_or_single_client, async_wait_until, ) @@ -69,8 +74,7 @@ class TestAsyncChangeStreamBase(AsyncIntegrationTest): async def client_with_listener(self, *commands): """Return a client with a AllowListEventListener.""" listener = AllowListEventListener(*commands) - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) return client, listener def watched_collection(self, *args, **kwargs): @@ -176,7 +180,7 @@ class APITestsMixin: @no_type_check async def test_try_next_runs_one_getmore(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. await client.admin.command("ping") listener.reset() @@ -234,7 +238,7 @@ class APITestsMixin: @no_type_check async def test_batch_size_is_honored(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. await client.admin.command("ping") listener.reset() @@ -481,7 +485,9 @@ class ProseSpecTestsMixin: @no_type_check async def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( + event_listeners=[listener] + ) self.addAsyncCleanup(client.close) return client, listener @@ -1131,7 +1137,7 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) @classmethod async def _tearDown_class(cls): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 97cbdf6db..f610f3277 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -36,6 +36,7 @@ from unittest import mock from unittest.mock import patch import pytest +import pytest_asyncio from pymongo.operations import _Op @@ -61,10 +62,6 @@ from test.utils import ( CMAPListener, FunctionCallRecorder, async_get_pool, - async_rs_client, - async_rs_or_single_client, - async_rs_or_single_client_noauth, - async_single_client, async_wait_until, asyncAssertRaisesExactly, delay, @@ -72,7 +69,6 @@ from test.utils import ( is_greenthread_patched, lazy_client_trial, one, - rs_or_single_client, wait_until, ) @@ -133,7 +129,9 @@ class AsyncClientUnitTest(AsyncUnitTest): @classmethod async def _setup_class(cls): - cls.client = await async_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + cls.client = await cls.unmanaged_async_rs_or_single_client( + connect=False, serverSelectionTimeoutMS=100 + ) @classmethod async def _tearDown_class(cls): @@ -143,8 +141,8 @@ class AsyncClientUnitTest(AsyncUnitTest): def inject_fixtures(self, caplog): self._caplog = caplog - def test_keyword_arg_defaults(self): - client = AsyncMongoClient( + async def test_keyword_arg_defaults(self): + client = self.simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -169,16 +167,18 @@ class AsyncClientUnitTest(AsyncUnitTest): self.assertEqual(ReadPreference.PRIMARY, client.read_preference) self.assertAlmostEqual(12, client.options.server_selection_timeout) - def test_connect_timeout(self): - client = AsyncMongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + async def test_connect_timeout(self): + client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = AsyncMongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + + client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = AsyncMongoClient( + + client = self.simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options @@ -194,8 +194,8 @@ class AsyncClientUnitTest(AsyncUnitTest): self.assertRaises(ConfigurationError, AsyncMongoClient, []) - def test_max_pool_size_zero(self): - AsyncMongoClient(maxPoolSize=0) + async def test_max_pool_size_zero(self): + self.simple_client(maxPoolSize=0) def test_uri_detection(self): self.assertRaises(ConfigurationError, AsyncMongoClient, "/foo") @@ -260,7 +260,7 @@ class AsyncClientUnitTest(AsyncUnitTest): self.assertNotIsInstance(client, Iterable) async def test_get_default_database(self): - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, @@ -277,7 +277,7 @@ class AsyncClientUnitTest(AsyncUnitTest): self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) @@ -285,7 +285,7 @@ class AsyncClientUnitTest(AsyncUnitTest): async def test_get_default_database_error(self): # URI with no database. - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) @@ -297,11 +297,11 @@ class AsyncClientUnitTest(AsyncUnitTest): await async_client_context.host, await async_client_context.port, ) - c = await async_rs_or_single_client(uri, connect=False) + c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) async def test_get_database_default(self): - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, @@ -310,7 +310,7 @@ class AsyncClientUnitTest(AsyncUnitTest): async def test_get_database_default_error(self): # URI with no database. - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) @@ -322,47 +322,53 @@ class AsyncClientUnitTest(AsyncUnitTest): await async_client_context.host, await async_client_context.port, ) - c = await async_rs_or_single_client(uri, connect=False) + c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) - def test_primary_read_pref_with_tags(self): + async def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". with self.assertRaises(ConfigurationError): - AsyncMongoClient("mongodb://host/?readpreferencetags=dc:east") + await self.async_single_client("mongodb://host/?readpreferencetags=dc:east") with self.assertRaises(ConfigurationError): - AsyncMongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + await self.async_single_client( + "mongodb://host/?readpreference=primary&readpreferencetags=dc:east" + ) async def test_read_preference(self): - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) self.assertEqual(c.read_preference, ReadPreference.NEAREST) - def test_metadata(self): + async def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} - client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata, metadata) - client = AsyncMongoClient("foo", 27017, appname="foobar", connect=False) + client = self.simple_client("foo", 27017, appname="foobar", connect=False) options = client.options self.assertEqual(options.pool_options.metadata, metadata) # No error - AsyncMongoClient(appname="x" * 128) - self.assertRaises(ValueError, AsyncMongoClient, appname="x" * 129) + self.simple_client(appname="x" * 128) + with self.assertRaises(ValueError): + self.simple_client(appname="x" * 129) # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) - self.assertRaises(TypeError, AsyncMongoClient, driver=1) - self.assertRaises(TypeError, AsyncMongoClient, driver="abc") - self.assertRaises(TypeError, AsyncMongoClient, driver=("Foo", "1", "a")) + with self.assertRaises(TypeError): + self.simple_client(driver=1) + with self.assertRaises(TypeError): + self.simple_client(driver="abc") + with self.assertRaises(TypeError): + self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = AsyncMongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -372,7 +378,7 @@ class AsyncClientUnitTest(AsyncUnitTest): options = client.options self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = AsyncMongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -382,7 +388,7 @@ class AsyncClientUnitTest(AsyncUnitTest): options = client.options self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = AsyncMongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -391,7 +397,7 @@ class AsyncClientUnitTest(AsyncUnitTest): len(bson.encode(options.pool_options.metadata)), _MAX_METADATA_SIZE, ) - client = AsyncMongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -407,11 +413,11 @@ class AsyncClientUnitTest(AsyncUnitTest): metadata["driver"]["name"] = "PyMongo|async" metadata["env"] = {} metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) - def test_kwargs_codec_options(self): + async def test_kwargs_codec_options(self): class MyFloatType: def __init__(self, x): self.__x = x @@ -433,7 +439,7 @@ class AsyncClientUnitTest(AsyncUnitTest): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = AsyncMongoClient( + c = self.simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -442,12 +448,12 @@ class AsyncClientUnitTest(AsyncUnitTest): tzinfo=tzinfo, connect=False, ) - self.assertEqual(c.codec_options.document_class, document_class) self.assertEqual(c.codec_options.type_registry, type_registry) self.assertEqual(c.codec_options.tz_aware, tz_aware) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual(c.codec_options.tzinfo, tzinfo) @@ -469,11 +475,11 @@ class AsyncClientUnitTest(AsyncUnitTest): datetime_conversion, ) ) - c = AsyncMongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual(c.codec_options.tz_aware, True) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual( @@ -482,16 +488,15 @@ class AsyncClientUnitTest(AsyncUnitTest): # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = AsyncMongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual( c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] ) - def test_uri_option_precedence(self): + async def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = AsyncMongoClient( + c = self.simple_client( uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" ) clopts = c.options @@ -501,7 +506,7 @@ class AsyncClientUnitTest(AsyncUnitTest): self.assertEqual(clopts.replica_set_name, "newname") self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) - def test_connection_timeout_ms_propagates_to_DNS_resolver(self): + async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. from pymongo.srv_resolver import _resolve @@ -520,37 +525,37 @@ class AsyncClientUnitTest(AsyncUnitTest): uri_with_timeout = base_uri + "/?connectTimeoutMS=6000" expected_uri_value = 6.0 - def test_scenario(args, kwargs, expected_value): + async def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - AsyncMongoClient(*args, **kwargs) + self.simple_client(*args, **kwargs) for _, kw in patched_resolver.call_list(): self.assertAlmostEqual(kw["lifetime"], expected_value) # No timeout specified. - test_scenario((base_uri,), {}, CONNECT_TIMEOUT) + await test_scenario((base_uri,), {}, CONNECT_TIMEOUT) # Timeout only specified in connection string. - test_scenario((uri_with_timeout,), {}, expected_uri_value) + await test_scenario((uri_with_timeout,), {}, expected_uri_value) # Timeout only specified in keyword arguments. kwarg = {"connectTimeoutMS": connectTimeoutMS} - test_scenario((base_uri,), kwarg, expected_kw_value) + await test_scenario((base_uri,), kwarg, expected_kw_value) # Timeout specified in both kwargs and connection string. - test_scenario((uri_with_timeout,), kwarg, expected_kw_value) + await test_scenario((uri_with_timeout,), kwarg, expected_kw_value) - def test_uri_security_options(self): + async def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - AsyncMongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) + self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - c = AsyncMongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) + c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - AsyncMongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, @@ -558,7 +563,7 @@ class AsyncClientUnitTest(AsyncUnitTest): # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): - AsyncMongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, @@ -566,10 +571,10 @@ class AsyncClientUnitTest(AsyncUnitTest): # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - AsyncMongoClient(ssl=True, tls=False) + self.simple_client(ssl=True, tls=False) - def test_event_listeners(self): - c = AsyncMongoClient(event_listeners=[], connect=False) + async def test_event_listeners(self): + c = self.simple_client(event_listeners=[], connect=False) self.assertEqual(c.options.event_listeners, []) listeners = [ event_loggers.CommandLogger(), @@ -578,11 +583,11 @@ class AsyncClientUnitTest(AsyncUnitTest): event_loggers.TopologyLogger(), event_loggers.ConnectionPoolLogger(), ] - c = AsyncMongoClient(event_listeners=listeners, connect=False) + c = self.simple_client(event_listeners=listeners, connect=False) self.assertEqual(c.options.event_listeners, listeners) - def test_client_options(self): - c = AsyncMongoClient(connect=False) + async def test_client_options(self): + c = self.simple_client(connect=False) self.assertIsInstance(c.options, ClientOptions) self.assertIsInstance(c.options.pool_options, PoolOptions) self.assertEqual(c.options.server_selection_timeout, 30) @@ -612,16 +617,16 @@ class AsyncClientUnitTest(AsyncUnitTest): ) with self.assertLogs("pymongo", level="INFO") as cm: for host in normal_hosts: - AsyncMongoClient(host) + AsyncMongoClient(host, connect=False) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] - AsyncMongoClient(host) - AsyncMongoClient(multi_host) + AsyncMongoClient(host, connect=False) + AsyncMongoClient(multi_host, connect=False) logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - def test_detected_environment_warning(self, mock_get_hosts): + async def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ "host.cosmos.azure.com", @@ -634,13 +639,13 @@ class AsyncClientUnitTest(AsyncUnitTest): ) for host in normal_hosts: with self.assertWarns(UserWarning): - AsyncMongoClient(host) + self.simple_client(host) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] with self.assertWarns(UserWarning): - AsyncMongoClient(host) + self.simple_client(host) with self.assertWarns(UserWarning): - AsyncMongoClient(multi_host) + self.simple_client(multi_host) class TestClient(AsyncIntegrationTest): @@ -657,7 +662,7 @@ class TestClient(AsyncIntegrationTest): async def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -665,12 +670,11 @@ class TestClient(AsyncIntegrationTest): pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - await client.close() async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = await async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -681,12 +685,11 @@ class TestClient(AsyncIntegrationTest): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - await client.close() async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1 ) server = await (await client._get_topology()).select_server( @@ -699,12 +702,11 @@ class TestClient(AsyncIntegrationTest): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - await client.close() async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it - client = await async_rs_or_single_client(maxIdleTimeMS=500) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -719,18 +721,17 @@ class TestClient(AsyncIntegrationTest): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - await client.close() async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize - client = await async_rs_or_single_client(minPoolSize=10) + client = await self.async_rs_or_single_client(minPoolSize=10) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -751,7 +752,7 @@ class TestClient(AsyncIntegrationTest): async def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = await async_rs_or_single_client(maxIdleTimeMS=500) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -767,7 +768,7 @@ class TestClient(AsyncIntegrationTest): self.assertTrue(new_con in server._pool.conns) # Test that connections are reused if maxIdleTimeMS is not set. - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -793,36 +794,38 @@ class TestClient(AsyncIntegrationTest): AsyncMongoClient.HOST = "somedomainthatdoesntexist.org" AsyncMongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - await connected(AsyncMongoClient(serverSelectionTimeoutMS=10, **kwargs)) + c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) + await connected(c) + c = self.simple_client(host, port, **kwargs) # Override the defaults. No error. - await connected(AsyncMongoClient(host, port, **kwargs)) + await connected(c) # Set good defaults. AsyncMongoClient.HOST = host AsyncMongoClient.PORT = port # No error. - await connected(AsyncMongoClient(**kwargs)) + c = self.simple_client(**kwargs) + await connected(c) async def test_init_disconnected(self): host, port = await async_client_context.host, await async_client_context.port - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) # is_primary causes client to block until connected self.assertIsInstance(await c.is_primary, bool) - - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(await c.is_mongos, bool) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertFalse(await c.primary) self.assertFalse(await c.secondaries) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(await c.address) # PYTHON-2981 @@ -834,45 +837,44 @@ class TestClient(AsyncIntegrationTest): self.assertEqual(await c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" - c = AsyncMongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = AsyncMongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client(seed, connect=False) self.assertEqual(async_client_context.client, c) # Explicitly test inequality self.assertFalse(async_client_context.client != c) - c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client("invalid.com", connect=False) self.assertNotEqual(async_client_context.client, c) self.assertTrue(async_client_context.client != c) + + c1 = self.simple_client("a", connect=False) + c2 = self.simple_client("b", connect=False) + # Seeds differ: - self.assertNotEqual( - AsyncMongoClient("a", connect=False), AsyncMongoClient("b", connect=False) - ) + self.assertNotEqual(c1, c2) + + c1 = self.simple_client(["a", "b", "c"], connect=False) + c2 = self.simple_client(["c", "a", "b"], connect=False) + # Same seeds but out of order still compares equal: - self.assertEqual( - AsyncMongoClient(["a", "b", "c"], connect=False), - AsyncMongoClient(["c", "a", "b"], connect=False), - ) + self.assertEqual(c1, c2) async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client(seed, connect=False) self.assertIn(c, {async_client_context.client}) - c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client("invalid.com", connect=False) self.assertNotIn(c, {async_client_context.client}) async def test_host_w_port(self): @@ -886,7 +888,7 @@ class TestClient(AsyncIntegrationTest): ) ) - def test_repr(self): + async def test_repr(self): # Used to test 'eval' below. import bson @@ -905,9 +907,10 @@ class TestClient(AsyncIntegrationTest): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + async with eval(the_repr) as client_two: + self.assertEqual(client_two, client) - client = AsyncMongoClient( + client = self.simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -925,7 +928,8 @@ class TestClient(AsyncIntegrationTest): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + async with eval(the_repr) as client_two: + self.assertEqual(client_two, client) def test_getters(self): wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") @@ -941,8 +945,7 @@ class TestClient(AsyncIntegrationTest): for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(document_class=SON) async for doc in await client.list_databases(): self.assertIs(type(doc), dict) @@ -981,7 +984,7 @@ class TestClient(AsyncIntegrationTest): await self.client.drop_database("pymongo_test") if async_client_context.is_rs: - wc_client = await async_rs_or_single_client(w=len(async_client_context.nodes) + 1) + wc_client = await self.async_rs_or_single_client(w=len(async_client_context.nodes) + 1) with self.assertRaises(WriteConcernError): await wc_client.drop_database("pymongo_test2") @@ -991,7 +994,7 @@ class TestClient(AsyncIntegrationTest): self.assertNotIn("pymongo_test2", dbs) async def test_close(self): - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() coll = test_client.pymongo_test.bar await test_client.close() with self.assertRaises(InvalidOperation): @@ -1001,7 +1004,7 @@ class TestClient(AsyncIntegrationTest): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() await test_client._process_periodic_tasks() @@ -1028,13 +1031,13 @@ class TestClient(AsyncIntegrationTest): self.assertTrue(test_client._topology._opened) await test_client.close() self.assertFalse(test_client._topology._opened) - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() # The killCursors task should not need to re-open the topology. await test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) async def test_close_stops_kill_cursors_thread(self): - client = await async_rs_client() + client = await self.async_rs_client() await client.test.test.find_one() self.assertFalse(client._kill_cursors_executor._stopped) @@ -1050,7 +1053,7 @@ class TestClient(AsyncIntegrationTest): async def test_uri_connect_option(self): # Ensure that topology is not opened if connect=False. - client = await async_rs_client(connect=False) + client = await self.async_rs_client(connect=False) self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. @@ -1063,19 +1066,15 @@ class TestClient(AsyncIntegrationTest): kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) - # Tear down. - await client.close() - async def test_close_does_not_open_servers(self): - client = await async_rs_client(connect=False) + client = await self.async_rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) await client.close() self.assertEqual(topology._servers, {}) async def test_close_closes_sockets(self): - client = await async_rs_client() - self.addAsyncCleanup(client.close) + client = await self.async_rs_client() await client.test.test.find_one() topology = client._topology await client.close() @@ -1104,35 +1103,35 @@ class TestClient(AsyncIntegrationTest): with self.assertRaises(OperationFailure): await connected( - await async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) + await self.async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) ) # No error. await connected( - await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + await self.async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) ) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) with self.assertRaises(OperationFailure): - await connected(await async_rs_or_single_client_noauth(uri)) + await connected(await self.async_rs_or_single_client_noauth(uri)) # No error. await connected( - await async_rs_or_single_client_noauth( + await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) ) ) # Auth with lazy connection. await ( - await async_rs_or_single_client_noauth( + await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = await async_rs_or_single_client_noauth( + bad_client = await self.async_rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) @@ -1144,7 +1143,7 @@ class TestClient(AsyncIntegrationTest): await async_client_context.create_user("admin", "ad min", "pa/ss") self.addAsyncCleanup(async_client_context.drop_user, "admin", "ad min") - c = await async_rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = await self.async_rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. self.assertNotIn("ad min", repr(c)) @@ -1157,14 +1156,14 @@ class TestClient(AsyncIntegrationTest): with self.assertRaises(OperationFailure): await ( - await async_rs_or_single_client_noauth(username="ad min", password="foo") + await self.async_rs_or_single_client_noauth(username="ad min", password="foo") ).server_info() @async_client_context.require_auth @async_client_context.require_no_fips async def test_lazy_auth_raises_operation_failure(self): host = await async_client_context.host - lazy_client = await async_rs_or_single_client_noauth( + lazy_client = await self.async_rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) @@ -1182,8 +1181,7 @@ class TestClient(AsyncIntegrationTest): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = await async_rs_or_single_client(uri) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(uri) await client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1192,11 +1190,10 @@ class TestClient(AsyncIntegrationTest): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - await connected( - AsyncMongoClient( - "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 - ), + c = self.simple_client( + "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 ) + await connected(c) async def test_document_class(self): c = self.client @@ -1207,15 +1204,15 @@ class TestClient(AsyncIntegrationTest): self.assertTrue(isinstance(await db.test.find_one(), dict)) self.assertFalse(isinstance(await db.test.find_one(), SON)) - c = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client(document_class=SON) + db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) self.assertTrue(isinstance(await db.test.find_one(), SON)) async def test_timeouts(self): - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, @@ -1228,28 +1225,31 @@ class TestClient(AsyncIntegrationTest): self.assertEqual(10.5, client.options.server_selection_timeout) async def test_socket_timeout_ms_validation(self): - c = await async_rs_or_single_client(socketTimeoutMS=10 * 1000) + c = await self.async_rs_or_single_client(socketTimeoutMS=10 * 1000) self.assertEqual(10, (await async_get_pool(c)).opts.socket_timeout) - c = await connected(await async_rs_or_single_client(socketTimeoutMS=None)) + c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=None)) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) - c = await connected(await async_rs_or_single_client(socketTimeoutMS=0)) + c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=0)) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS=-1) + async with await self.async_rs_or_single_client(socketTimeoutMS=-1): + pass with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS=1e10) + async with await self.async_rs_or_single_client(socketTimeoutMS=1e10): + pass with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS="foo") + async with await self.async_rs_or_single_client(socketTimeoutMS="foo"): + pass async def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 - timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) + timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) self.addAsyncCleanup(timeout.close) await no_timeout.pymongo_test.drop_collection("test") @@ -1266,7 +1266,7 @@ class TestClient(AsyncIntegrationTest): with self.assertRaises(NetworkTimeout): await get_x(timeout.pymongo_test) - def test_server_selection_timeout(self): + async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) @@ -1298,7 +1298,7 @@ class TestClient(AsyncIntegrationTest): self.assertAlmostEqual(30, client.options.server_selection_timeout) async def test_waitQueueTimeoutMS(self): - client = await async_rs_or_single_client(waitQueueTimeoutMS=2000) + client = await self.async_rs_or_single_client(waitQueueTimeoutMS=2000) self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) async def test_socketKeepAlive(self): @@ -1311,7 +1311,7 @@ class TestClient(AsyncIntegrationTest): async def test_tz_aware(self): self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo") - aware = await async_rs_or_single_client(tz_aware=True) + aware = await self.async_rs_or_single_client(tz_aware=True) self.addAsyncCleanup(aware.close) naive = self.client await aware.pymongo_test.drop_collection("test") @@ -1341,8 +1341,7 @@ class TestClient(AsyncIntegrationTest): if async_client_context.is_rs: uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") - client = await async_rs_or_single_client_noauth(uri) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client_noauth(uri) await client.pymongo_test.test.insert_one({"dummy": "object"}) await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1351,7 +1350,7 @@ class TestClient(AsyncIntegrationTest): self.assertTrue("pymongo_test_bernie" in dbs) async def test_contextlib(self): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() await client.pymongo_test.drop_collection("test") await client.pymongo_test.test.insert_one({"foo": "bar"}) @@ -1365,7 +1364,7 @@ class TestClient(AsyncIntegrationTest): self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): await client.pymongo_test.test.find_one() - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() async with client as client: self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): @@ -1443,8 +1442,7 @@ class TestClient(AsyncIntegrationTest): # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. - client = await async_single_client() - self.addAsyncCleanup(client.close) + client = await self.async_single_client() await client.pymongo_test.test.find_one() pool = await async_get_pool(client) socket_count = len(pool.conns) @@ -1468,8 +1466,7 @@ class TestClient(AsyncIntegrationTest): await async_client_context.client.drop_database("test_lazy_connect_w0") self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") - client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.insert_one({}) async def predicate(): @@ -1477,8 +1474,7 @@ class TestClient(AsyncIntegrationTest): await async_wait_until(predicate, "find one document") - client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) async def predicate(): @@ -1486,8 +1482,7 @@ class TestClient(AsyncIntegrationTest): await async_wait_until(predicate, "update one document") - client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.delete_one({}) async def predicate(): @@ -1499,8 +1494,7 @@ class TestClient(AsyncIntegrationTest): async def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1527,7 +1521,9 @@ class TestClient(AsyncIntegrationTest): # Get a client with one socket so we detect if it's leaked. c = await connected( - await async_rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) + await self.async_rs_or_single_client( + maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False + ) ) # Cause a network error on the actual socket. @@ -1545,8 +1541,7 @@ class TestClient(AsyncIntegrationTest): @async_client_context.require_no_replica_set async def test_connect_to_standalone_using_replica_set_name(self): - client = await async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - + client = await self.async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) with self.assertRaises(AutoReconnect): await client.test.test.find_one() @@ -1556,7 +1551,7 @@ class TestClient(AsyncIntegrationTest): # the topology before the getMore message is sent. Test that # AsyncMongoClient._run_operation_with_response handles the error. with self.assertRaises(AutoReconnect): - client = await async_rs_client(connect=False, serverSelectionTimeoutMS=100) + client = await self.async_rs_client(connect=False, serverSelectionTimeoutMS=100) await client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1604,7 +1599,7 @@ class TestClient(AsyncIntegrationTest): await async_client_context.host, await async_client_context.port, ) - client = await async_single_client(uri, event_listeners=[listener]) + await self.async_single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1613,7 +1608,6 @@ class TestClient(AsyncIntegrationTest): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - await client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1630,31 +1624,31 @@ class TestClient(AsyncIntegrationTest): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) @@ -1662,56 +1656,55 @@ class TestClient(AsyncIntegrationTest): # According to the connection string spec, unsupported values # just raise a warning and are ignored. uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = async_client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = await async_single_client(zlibcompressionlevel=level) + client = await self.async_single_client(zlibcompressionlevel=level) # No error await client.pymongo_test.test.find_one() async def test_reset_during_update_pool(self): - client = await async_rs_or_single_client(minPoolSize=10) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(minPoolSize=10) await client.admin.command("ping") pool = await async_get_pool(client) generation = pool.gen.get_overall() @@ -1757,11 +1750,9 @@ class TestClient(AsyncIntegrationTest): async def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addAsyncCleanup(client.close) - # Create a single connection in the pool. await client.admin.command("ping") @@ -1791,21 +1782,19 @@ class TestClient(AsyncIntegrationTest): @async_client_context.require_replica_set async def test_direct_connection(self): # direct_connection=True should result in Single topology. - client = await async_rs_or_single_client(directConnection=True) + client = await self.async_rs_or_single_client(directConnection=True) await client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - await client.close() # direct_connection=False should result in RS topology. - client = await async_rs_or_single_client(directConnection=False) + client = await self.async_rs_or_single_client(directConnection=False) await client.admin.command("ping") self.assertGreaterEqual(len(client.nodes), 1) self.assertIn( client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - await client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -1825,11 +1814,10 @@ class TestClient(AsyncIntegrationTest): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = AsyncMongoClient( + client = self.simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addAsyncCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): await client.test.test.find_one() gc.collect() @@ -1842,8 +1830,7 @@ class TestClient(AsyncIntegrationTest): @async_client_context.require_failCommand_fail_point async def test_network_error_message(self): - client = await async_single_client(retryReads=False) - self.addAsyncCleanup(client.close) + client = await self.async_single_client(retryReads=False) await client.admin.command("ping") # connect async with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1855,7 +1842,7 @@ class TestClient(AsyncIntegrationTest): @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") async def test_process_periodic_tasks(self): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() coll = client.db.collection await coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) @@ -1873,7 +1860,7 @@ class TestClient(AsyncIntegrationTest): with self.assertRaises(InvalidOperation): await coll.insert_many([{} for _ in range(5)]) - def test_service_name_from_kwargs(self): + async def test_service_name_from_kwargs(self): client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc", srvServiceName="customname", @@ -1893,12 +1880,12 @@ class TestClient(AsyncIntegrationTest): ) self.assertEqual(client._topology_settings.srv_service_name, "customname") - def test_srv_max_hosts_kwarg(self): - client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") + async def test_srv_max_hosts_kwarg(self): + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = AsyncMongoClient( + client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @@ -1946,10 +1933,10 @@ class TestClient(AsyncIntegrationTest): if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - async with await async_rs_or_single_client(serverSelectionTimeoutMS=10000) as client: - await client.admin.command("ping") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = await self.async_rs_or_single_client(serverSelectionTimeoutMS=10000) + await client.admin.command("ping") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) async def test_handshake_01_aws(self): await self._test_handshake( @@ -2045,7 +2032,7 @@ class TestExhaustCursor(AsyncIntegrationTest): async def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected(await async_rs_or_single_client(maxPoolSize=1)) + client = await connected(await self.async_rs_or_single_client(maxPoolSize=1)) collection = client.pymongo_test.test pool = await async_get_pool(client) @@ -2068,7 +2055,7 @@ class TestExhaustCursor(AsyncIntegrationTest): async def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test await collection.drop() @@ -2107,7 +2094,9 @@ class TestExhaustCursor(AsyncIntegrationTest): async def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected(await async_rs_or_single_client(maxPoolSize=1, retryReads=False)) + client = await connected( + await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) + ) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2128,7 +2117,7 @@ class TestExhaustCursor(AsyncIntegrationTest): async def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test await collection.drop() await collection.insert_many([{} for _ in range(200)]) # More than one batch. @@ -2177,7 +2166,7 @@ class TestExhaustCursor(AsyncIntegrationTest): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client(maxPoolSize=1) + client = self.async_rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2209,7 +2198,7 @@ class TestExhaustCursor(AsyncIntegrationTest): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client() + client = self.async_rs_or_single_client() self.addCleanup(client.close) coll = client.pymongo_test.test pool = async_get_pool(client) @@ -2246,7 +2235,7 @@ class TestClientLazyConnect(AsyncIntegrationTest): """Test concurrent operations on a lazily-connecting MongoClient.""" def _get_client(self): - return rs_or_single_client(connect=False) + return self.async_rs_or_single_client(connect=False) @async_client_context.require_sync def test_insert_one(self): @@ -2380,6 +2369,7 @@ class TestMongoClientFailover(AsyncMockClientTest): retryReads=False, serverSelectionTimeoutMS=1000, ) + self.addAsyncCleanup(c.close) # Set host-specific information so we can test whether it is reset. diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index c35e823d0..3a1729945 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -27,7 +27,6 @@ from test.asynchronous import ( ) from test.utils import ( OvertCommandListener, - async_rs_or_single_client, ) from unittest.mock import patch @@ -39,7 +38,6 @@ from pymongo.errors import ( InvalidOperation, NetworkTimeout, ) -from pymongo.monitoring import * from pymongo.operations import * from pymongo.write_concern import WriteConcern @@ -97,8 +95,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) models = [] for _ in range(self.max_write_batch_size + 1): @@ -123,8 +120,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -157,11 +153,10 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_failCommand_fail_point async def test_collects_write_concern_errors_across_batches(self): listener = OvertCommandListener() - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], retryWrites=False, ) - self.addAsyncCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -200,8 +195,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -231,8 +225,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -262,8 +255,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -304,8 +296,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_standalone async def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -348,8 +339,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_failCommand_fail_point async def test_handles_getMore_error(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -403,8 +393,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) b_repeated = "b" * self.max_bson_object_size @@ -460,8 +449,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) num_models, models = await self._setup_namespace_test_models() models.append( @@ -492,8 +480,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_no_serverless async def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) num_models, models = await self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -530,8 +517,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_version_min(8, 0, 0, -24) @async_client_context.require_no_serverless async def test_returns_error_if_no_writes_can_be_added_to_ops(self): - client = await async_rs_or_single_client() - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client() # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -554,8 +540,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): key_vault_namespace="db.coll", kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -580,7 +565,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest): async def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 - internal_client = await async_rs_or_single_client(timeoutMS=None) + internal_client = await self.async_rs_or_single_client(timeoutMS=None) self.addAsyncCleanup(internal_client.close) collection = internal_client.db["coll"] @@ -605,14 +590,13 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest): ) listener = OvertCommandListener() - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], readConcernLevel="majority", readPreference="primary", timeoutMS=2000, w="majority", ) - self.addAsyncCleanup(client.close) await client.admin.command("ping") # Init the client first. with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 10d64a525..74a4a5151 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -30,6 +30,7 @@ sys.path[0:0] = [""] from test import unittest from test.asynchronous import ( # TODO: fix sync imports in PYTHON-4528 AsyncIntegrationTest, + AsyncUnitTest, async_client_context, ) from test.utils import ( @@ -37,8 +38,6 @@ from test.utils import ( EventListener, async_get_pool, async_is_mongos, - async_rs_or_single_client, - async_single_client, async_wait_until, wait_until, ) @@ -82,14 +81,20 @@ from pymongo.write_concern import WriteConcern _IS_SYNC = False -class TestCollectionNoConnect(unittest.TestCase): +class TestCollectionNoConnect(AsyncUnitTest): """Test Collection features on a client that does not connect.""" db: AsyncDatabase + client: AsyncMongoClient @classmethod - def setUpClass(cls): - cls.db = AsyncMongoClient(connect=False).pymongo_test + async def _setup_class(cls): + cls.client = AsyncMongoClient(connect=False) + cls.db = cls.client.pymongo_test + + @classmethod + async def _tearDown_class(cls): + await cls.client.close() def test_collection(self): self.assertRaises(TypeError, AsyncCollection, self.db, 5) @@ -1819,8 +1824,7 @@ class AsyncTestCollection(AsyncIntegrationTest): # Insert enough documents to require more than one batch await self.db.test.insert_many([{"i": i} for i in range(150)]) - client = await async_rs_or_single_client(maxPoolSize=1) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(maxPoolSize=1) pool = await async_get_pool(client) # Make sure the socket is returned after exhaustion. @@ -2100,7 +2104,7 @@ class AsyncTestCollection(AsyncIntegrationTest): async def test_find_one_and_write_concern(self): listener = EventListener() - db = (await async_single_client(event_listeners=[listener]))[self.db.name] + db = (await self.async_single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 6967205fe..d6773d832 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -34,7 +34,6 @@ from test.utils import ( AllowListEventListener, EventListener, OvertCommandListener, - async_rs_or_single_client, ignore_deprecations, wait_until, ) @@ -232,7 +231,7 @@ class TestCursor(AsyncIntegrationTest): self.assertEqual(90, cursor._max_await_time_ms) listener = AllowListEventListener("find", "getMore") - coll = (await async_rs_or_single_client(event_listeners=[listener]))[ + coll = (await self.async_rs_or_single_client(event_listeners=[listener]))[ self.db.name ].pymongo_test @@ -353,8 +352,7 @@ class TestCursor(AsyncIntegrationTest): async def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(await coll.find().explain()) started = listener.started_events @@ -1261,8 +1259,7 @@ class TestCursor(AsyncIntegrationTest): await self.client._process_periodic_tasks() listener = AllowListEventListener("killCursors") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1300,8 +1297,7 @@ class TestCursor(AsyncIntegrationTest): @async_client_context.require_failCommand_appName async def test_timeout_kills_cursor_asynchronously(self): listener = AllowListEventListener("killCursors") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1358,8 +1354,7 @@ class TestCursor(AsyncIntegrationTest): async def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1463,7 +1458,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) async with client.start_session() as session: async with await session.start_transaction(): batches = await ( @@ -1493,7 +1488,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): @@ -1514,7 +1509,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) @@ -1577,7 +1572,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): async def test_monitoring(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() await c.insert_many([{"_id": i} for i in range(10)]) @@ -1643,7 +1638,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) async with client.start_session() as session: async with await session.start_transaction(): batches = await ( @@ -1674,7 +1669,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}} ): @@ -1698,7 +1693,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) @@ -1744,7 +1739,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest): async def test_monitoring(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() await c.insert_many([{"_id": i} for i in range(10)]) @@ -1788,8 +1783,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest): @async_client_context.require_no_mongos async def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.delete_many({}) await c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 8f6886a2a..c5d62323d 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -29,7 +29,6 @@ from test.test_custom_types import DECIMAL_CODECOPTS from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, - async_rs_or_single_client, async_wait_until, ) @@ -208,7 +207,7 @@ class TestDatabase(AsyncIntegrationTest): async def test_list_collection_names_filter(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] await db.capped.drop() await db.create_collection("capped", capped=True, size=4096) @@ -235,8 +234,7 @@ class TestDatabase(AsyncIntegrationTest): async def test_check_exists(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] await db.drop_collection("unique") await db.create_collection("unique", check_exists=True) @@ -326,7 +324,7 @@ class TestDatabase(AsyncIntegrationTest): await self.client.drop_database("pymongo_test") async def test_list_collection_names_single_socket(self): - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) await client.drop_database("test_collection_names_single_socket") db = client.test_collection_names_single_socket for i in range(200): diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 030f468db..3f3714eeb 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -31,7 +31,7 @@ import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context from test.asynchronous.test_bulk import AsyncBulkTestBase from threading import Thread -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional import pytest @@ -44,6 +44,8 @@ sys.path[0:0] = [""] from test import ( unittest, ) +from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.utils_spec_runner import AsyncSpecRunner from test.helpers import ( AWS_CREDS, AZURE_CREDS, @@ -59,12 +61,10 @@ from test.utils import ( OvertCommandListener, SpecTestCreator, TopologyEventListener, - async_rs_or_single_client, async_wait_until, camel_to_snake_args, is_greenthread_patched, ) -from test.utils_spec_runner import SpecRunner from bson import DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation @@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase): @unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed") async def test_crypt_shared(self): # Test that we can pick up crypt_shared lib automatically - client = AsyncMongoClient( + self.simple_client( auto_encryption_opts=AutoEncryptionOpts( KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True ), connect=False, ) - self.addAsyncCleanup(client.aclose) @unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed") def test_init_requires_pymongocrypt(self): @@ -196,19 +195,16 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase): class TestClientOptions(AsyncPyMongoTestCase): async def test_default(self): - client = AsyncMongoClient(connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) - client = AsyncMongoClient(auto_encryption_opts=None, connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(auto_encryption_opts=None, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") async def test_kwargs(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = AsyncMongoClient(auto_encryption_opts=opts, connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(auto_encryption_opts=opts, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) @@ -229,6 +225,34 @@ class AsyncEncryptionIntegrationTest(AsyncIntegrationTest): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, UUID_SUBTYPE) + def create_client_encryption( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: AsyncMongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = AsyncClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + self.addAsyncCleanup(client_encryption.close) + return client_encryption + + @classmethod + def unmanaged_create_client_encryption( + cls, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: AsyncMongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = AsyncClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + return client_encryption + # Location of JSON test files. if _IS_SYNC: @@ -260,8 +284,7 @@ def bson_data(*paths): class TestClientSimple(AsyncEncryptionIntegrationTest): async def _test_auto_encrypt(self, opts): - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) # Create the encrypted field's data key. key_vault = await create_key_vault( @@ -342,8 +365,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest): async def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client.admin.command("ping") await client.aclose() @@ -360,8 +382,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest): ) async def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) async def target(): with warnings.catch_warnings(): @@ -375,8 +396,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest): class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest): async def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) encrypted_coll = client.pymongo_test.test @@ -416,8 +436,7 @@ class TestClientMaxWireVersion(AsyncIntegrationTest): @async_client_context.require_version_max(4, 0, 99) async def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): await client.test.test.insert_one({}) @@ -430,8 +449,7 @@ class TestClientMaxWireVersion(AsyncIntegrationTest): async def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): await client.test.test.find_raw_batches({}) @@ -450,10 +468,9 @@ class TestClientMaxWireVersion(AsyncIntegrationTest): class TestExplicitSimple(AsyncEncryptionIntegrationTest): async def test_encrypt_decrypt(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) # Use standard UUID representation. key_vault = async_client_context.client.keyvault.get_collection( "datakeys", codec_options=OPTS @@ -495,10 +512,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest): self.assertEqual(decrypted_ssn, doc["ssn"]) async def test_validation(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) msg = "value to decrypt must be a bson.binary.Binary with subtype 6" with self.assertRaisesRegex(TypeError, msg): @@ -512,10 +528,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest): await client_encryption.encrypt("str", algo, key_id=Binary(b"123")) async def test_bson_errors(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. unencodable_value = object() @@ -528,7 +543,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest): async def test_codec_options(self): with self.assertRaisesRegex(TypeError, "codec_options must be"): - AsyncClientEncryption( + self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, @@ -536,10 +551,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest): ) opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY) - client_encryption_legacy = AsyncClientEncryption( + client_encryption_legacy = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts ) - self.addAsyncCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. key_id = await client_encryption_legacy.create_data_key("local") @@ -554,10 +568,9 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest): # Encrypt the same UUID with STANDARD codec options. opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts ) - self.addAsyncCleanup(client_encryption.close) encrypted_standard = await client_encryption.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id ) @@ -573,7 +586,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest): self.assertNotEqual(await client_encryption.decrypt(encrypted_legacy), value) async def test_close(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) await client_encryption.close() @@ -589,7 +602,7 @@ class TestExplicitSimple(AsyncEncryptionIntegrationTest): await client_encryption.decrypt(Binary(b"", 6)) async def test_with_statement(self): - async with AsyncClientEncryption( + async with self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) as client_encryption: pass @@ -613,7 +626,7 @@ KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PE if _IS_SYNC: # TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): + class TestSpec(AsyncSpecRunner): @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def setUpClass(cls): @@ -811,7 +824,7 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = OvertCommandListener() - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) await cls.client.db.coll.drop() cls.vault = await create_key_vault(cls.client.keyvault.datakeys) @@ -833,10 +846,10 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = await async_rs_or_single_client( + cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = AsyncClientEncryption( + cls.client_encryption = cls.unmanaged_create_client_encryption( cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) @@ -923,10 +936,9 @@ class TestExternalKeyVault(AsyncEncryptionIntegrationTest): # Configure the encrypted field via the local schema_map option. schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: - key_vault_client = await async_rs_or_single_client( + key_vault_client = await self.async_rs_or_single_client( username="fake-user", password="fake-pwd" ) - self.addAsyncCleanup(key_vault_client.close) else: key_vault_client = async_client_context.client opts = AutoEncryptionOpts( @@ -936,15 +948,13 @@ class TestExternalKeyVault(AsyncEncryptionIntegrationTest): key_vault_client=key_vault_client, ) - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addAsyncCleanup(client_encrypted.close) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS ) - self.addAsyncCleanup(client_encryption.close) if with_external_key_vault: # Authentication error. @@ -990,10 +1000,9 @@ class TestViews(AsyncEncryptionIntegrationTest): self.addAsyncCleanup(self.client.db.view.drop) opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys") - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addAsyncCleanup(client_encrypted.aclose) with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"): await client_encrypted.db.view.insert_one({}) @@ -1050,17 +1059,15 @@ class TestCorpus(AsyncEncryptionIntegrationTest): ) self.addAsyncCleanup(vault.drop) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.close) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", async_client_context.client, OPTS, kms_tls_options=KMS_TLS_OPTS, ) - self.addAsyncCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data("corpus", "corpus.json")) corpus_copied: SON = SON() @@ -1203,7 +1210,7 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = await async_rs_or_single_client( + cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll @@ -1291,7 +1298,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest): "gcp": GCP_CREDS, "kmip": KMIP_CREDS, } - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers=kms_providers, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -1303,7 +1310,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest): kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698" - self.client_encryption_invalid = AsyncClientEncryption( + self.client_encryption_invalid = self.create_client_encryption( kms_providers=kms_providers_invalid, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -1484,7 +1491,7 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest): await self.client_encryption.create_data_key("kmip", key) -class AzureGCPEncryptionTestMixin: +class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): DEK = None KMS_PROVIDER_MAP = None KEYVAULT_DB = "keyvault" @@ -1496,7 +1503,7 @@ class AzureGCPEncryptionTestMixin: await create_key_vault(keyvault, self.DEK) async def _test_explicit(self, expectation): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), async_client_context.client, @@ -1525,7 +1532,7 @@ class AzureGCPEncryptionTestMixin: ) insert_listener = AllowListEventListener("insert") - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) self.addAsyncCleanup(client.aclose) @@ -1604,19 +1611,17 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationT # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests class TestDeadlockProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): - self.client_test = await async_rs_or_single_client( + self.client_test = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) - self.addAsyncCleanup(self.client_test.aclose) self.client_keyvault_listener = OvertCommandListener() - self.client_keyvault = await async_rs_or_single_client( + self.client_keyvault = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", event_listeners=[self.client_keyvault_listener], ) - self.addAsyncCleanup(self.client_keyvault.aclose) await self.client_test.keyvault.datakeys.drop() await self.client_test.db.coll.drop() @@ -1629,7 +1634,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest): codec_options=OPTS, ) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, key_vault_namespace="keyvault.datakeys", key_vault_client=self.client_test, @@ -1645,7 +1650,7 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest): self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") async def _run_test(self, max_pool_size, auto_encryption_opts): - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( readConcernLevel="majority", w="majority", maxPoolSize=max_pool_size, @@ -1663,8 +1668,6 @@ class TestDeadlockProse(AsyncEncryptionIntegrationTest): result = await client_encrypted.db.coll.find_one({"_id": 0}) self.assertEqual(result, {"_id": 0, "encrypted": "string0"}) - self.addAsyncCleanup(client_encrypted.close) - async def test_case_1(self): await self._run_test( max_pool_size=1, @@ -1840,7 +1843,7 @@ class TestDecryptProse(AsyncEncryptionIntegrationTest): await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) keyID = await self.client_encryption.create_data_key("local") @@ -1855,10 +1858,9 @@ class TestDecryptProse(AsyncEncryptionIntegrationTest): key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map ) self.listener = AllowListEventListener("aggregate") - self.encrypted_client = await async_rs_or_single_client( + self.encrypted_client = await self.async_rs_or_single_client( auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] ) - self.addAsyncCleanup(self.encrypted_client.close) async def test_01_command_error(self): async with self.fail_point( @@ -1935,8 +1937,7 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest): "--port=27027", ], ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.close) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) with self.assertRaisesRegex(EncryptionError, "Timeout"): await client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1950,11 +1951,10 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest): "--port=27027", ], ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: - mongocryptd_client = AsyncMongoClient( + mongocryptd_client = self.simple_client( "mongodb://localhost:27027/?serverSelectionTimeoutMS=500" ) with self.assertRaises(ServerSelectionTimeoutError): @@ -1978,15 +1978,13 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest): ], crypt_shared_lib_required=True, ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) self.assertEncrypted((await async_client_context.client.db.coll.find_one({}))["encrypted"]) - no_mongocryptd_client = AsyncMongoClient( + no_mongocryptd_client = self.simple_client( host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000" ) - self.addAsyncCleanup(no_mongocryptd_client.aclose) with self.assertRaises(ServerSelectionTimeoutError): await no_mongocryptd_client.db.command("ping") @@ -2020,8 +2018,7 @@ class TestBypassSpawningMongocryptdProse(AsyncEncryptionIntegrationTest): mongocryptd_uri="mongodb://localhost:47021", crypt_shared_lib_required=False, ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) server.shutdown() @@ -2035,10 +2032,9 @@ class TestKmsTLSProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): await super().asyncSetUp() self.patch_system_certs(CA_PEM) - self.client_encrypted = AsyncClientEncryption( + self.client_encrypted = self.create_client_encryption( {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS ) - self.addAsyncCleanup(self.client_encrypted.close) async def test_invalid_kms_certificate_expired(self): key = { @@ -2083,36 +2079,32 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest): "gcp": {"tlsCAFile": CA_PEM}, "kmip": {"tlsCAFile": CA_PEM}, } - self.client_encryption_no_client_cert = AsyncClientEncryption( + self.client_encryption_no_client_cert = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_no_client_cert.close) # 2, same providers as above but with tlsCertificateKeyFile. kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only) for p in kms_tls_opts: kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM - self.client_encryption_with_tls = AsyncClientEncryption( + self.client_encryption_with_tls = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts ) - self.addAsyncCleanup(self.client_encryption_with_tls.close) # 3, update endpoints to expired host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000" providers["gcp"]["endpoint"] = "127.0.0.1:9000" providers["kmip"]["endpoint"] = "127.0.0.1:9000" - self.client_encryption_expired = AsyncClientEncryption( + self.client_encryption_expired = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_expired.close) # 3, update endpoints to invalid host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001" providers["gcp"]["endpoint"] = "127.0.0.1:9001" providers["kmip"]["endpoint"] = "127.0.0.1:9001" - self.client_encryption_invalid_hostname = AsyncClientEncryption( + self.client_encryption_invalid_hostname = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_invalid_hostname.close) # Errors when client has no cert, some examples: # [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623) self.cert_error = ( @@ -2150,7 +2142,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest): "gcp:with_tls": with_cert, "kmip:with_tls": with_cert, } - self.client_encryption_with_names = AsyncClientEncryption( + self.client_encryption_with_names = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4 ) @@ -2232,10 +2224,9 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest): async def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}} options = {"aws": {"tlsDisableOCSPEndpointCheck": True}} - encryption = AsyncClientEncryption( + encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - self.addAsyncCleanup(encryption.close) ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") @@ -2285,7 +2276,7 @@ class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest): self.client = async_client_context.client await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) self.def_key_id = await self.client_encryption.create_data_key( @@ -2327,17 +2318,15 @@ class TestExplicitQueryableEncryption(AsyncEncryptionIntegrationTest): key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(self.encrypted_client.aclose) + self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) async def test_01_insert_encrypted_indexed_and_find(self): val = "encrypted indexed value" @@ -2464,14 +2453,13 @@ class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest): await self.client.keyvault.drop_collection("datakeys") # Step 2. Create a ``AsyncClientEncryption`` object named ``client_encryption1`` - client_encryption1 = AsyncClientEncryption( + client_encryption1 = self.create_client_encryption( key_vault_client=self.client, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addAsyncCleanup(client_encryption1.close) # Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``. key_id = await client_encryption1.create_data_key( @@ -2484,16 +2472,14 @@ class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest): ) # Step 5. Create a ``AsyncClientEncryption`` object named ``client_encryption2`` - client2 = await async_rs_or_single_client() - self.addAsyncCleanup(client2.aclose) - client_encryption2 = AsyncClientEncryption( + client2 = await self.async_rs_or_single_client() + client_encryption2 = self.create_client_encryption( key_vault_client=client2, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addAsyncCleanup(client_encryption2.close) # Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``. rewrap_many_data_key_result = await client_encryption2.rewrap_many_data_key( @@ -2528,7 +2514,7 @@ class TestOnDemandAWSCredentials(AsyncEncryptionIntegrationTest): @unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set") async def test_01_failure(self): - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -2539,7 +2525,7 @@ class TestOnDemandAWSCredentials(AsyncEncryptionIntegrationTest): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") async def test_02_success(self): - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -2559,8 +2545,7 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest): # AsyncMongoClient to use in testing that handles auth/tls/etc, # and cleanup. async def AsyncMongoClient(**kwargs): - c = await async_rs_or_single_client(**kwargs) - self.addAsyncCleanup(c.aclose) + c = await self.async_rs_or_single_client(**kwargs) return c # Drop data from prior test runs. @@ -2571,7 +2556,7 @@ class TestQueryableEncryptionDocsExample(AsyncEncryptionIntegrationTest): # Create two data keys. key_vault_client = await AsyncMongoClient() - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions() ) key1_id = await client_encryption.create_data_key("local") @@ -2652,18 +2637,16 @@ class TestRangeQueryProse(AsyncEncryptionIntegrationTest): key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db - self.addAsyncCleanup(self.encrypted_client.aclose) async def run_expression_find( self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None @@ -2860,10 +2843,9 @@ class TestRangeQueryDefaultsProse(AsyncEncryptionIntegrationTest): await super().asyncSetUp() await self.client.drop_database(self.db) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) self.key_id = await self.client_encryption.create_data_key("local") opts = RangeOpts(min=0, max=1000) self.payload_defaults = await self.client_encryption.encrypt( @@ -2896,13 +2878,12 @@ class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest): await self.client.drop_database(self.db) self.key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addAsyncCleanup(self.key_vault.drop) - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, self.key_vault.full_name, self.client, OPTS, ) - self.addAsyncCleanup(self.client_encryption.close) async def test_01_simple_create(self): coll, _ = await self.client_encryption.create_encrypted_collection( @@ -3118,10 +3099,9 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest): async def asyncSetUp(self) -> None: self.listener = OvertCommandListener() - self.mongocryptd_client = AsyncMongoClient( + self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] ) - self.addAsyncCleanup(self.mongocryptd_client.aclose) hello = await self.mongocryptd_client.db.command("hello") self.assertNotIn("logicalSessionTimeoutMinutes", hello) diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 6d589dc01..9c57c15c5 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -33,7 +33,7 @@ from pymongo.asynchronous.database import AsyncDatabase sys.path[0:0] = [""] -from test.utils import EventListener, async_rs_or_single_client +from test.utils import EventListener from bson.objectid import ObjectId from gridfs.asynchronous.grid_file import ( @@ -792,7 +792,7 @@ Bye""" await outfile.readchunk() async def test_grid_in_lazy_connect(self): - client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = AsyncGridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): @@ -803,7 +803,7 @@ Bye""" async def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs) + AsyncGridIn((await self.async_rs_or_single_client(w=0)).pymongo_test.fs) async def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. @@ -811,7 +811,7 @@ Bye""" chunk_size = 1024 data = b"d" * (102 * chunk_size) listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile: await infile.write(data) diff --git a/test/asynchronous/test_logger.py b/test/asynchronous/test_logger.py index b219d530e..a2e8b35c5 100644 --- a/test/asynchronous/test_logger.py +++ b/test/asynchronous/test_logger.py @@ -16,7 +16,6 @@ from __future__ import annotations import os from test import unittest from test.asynchronous import AsyncIntegrationTest -from test.utils import async_single_client from unittest.mock import patch from bson import json_util @@ -86,7 +85,7 @@ class TestLogger(AsyncIntegrationTest): self.assertEqual(last_3_bytes, str_to_repeat) async def test_logging_without_listeners(self): - c = await async_single_client() + c = await self.async_single_client() self.assertEqual(len(c._event_listeners.event_listeners()), 0) with self.assertLogs("pymongo.connection", level="DEBUG") as cm: await c.db.test.insert_one({"x": "1"}) diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index 3f6563ee5..b5d8708dc 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -31,8 +31,6 @@ from test.asynchronous import ( ) from test.utils import ( EventListener, - async_rs_or_single_client, - async_single_client, async_wait_until, ) @@ -57,7 +55,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = EventListener() - cls.client = await async_rs_or_single_client( + cls.client = await cls.unmanaged_async_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False ) @@ -407,7 +405,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): @async_client_context.require_secondaries_count(1) async def test_not_primary_error(self): address = next(iter(await async_client_context.client.secondaries)) - client = await async_single_client(*address, event_listeners=[self.listener]) + client = await self.async_single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. await client.admin.command("ping") self.listener.reset() @@ -1146,7 +1144,7 @@ class AsyncTestGlobalListener(AsyncIntegrationTest): # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = await async_single_client() + cls.client = await cls.unmanaged_async_single_client() # Get one (authenticated) socket in the pool. await cls.client.pymongo_test.command("ping") diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 1e1f5659b..d264b5ecb 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -36,9 +36,7 @@ from test.asynchronous import ( from test.utils import ( EventListener, ExceptionCatchingThread, - async_rs_or_single_client, async_wait_until, - rs_or_single_client, wait_until, ) @@ -90,7 +88,7 @@ class TestSession(AsyncIntegrationTest): await super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await async_rs_or_single_client() + cls.client2 = await cls.unmanaged_async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -105,7 +103,7 @@ class TestSession(AsyncIntegrationTest): async def asyncSetUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() - self.client = await async_rs_or_single_client( + self.client = await self.async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) self.addAsyncCleanup(self.client.close) @@ -202,7 +200,7 @@ class TestSession(AsyncIntegrationTest): failures = 0 for _ in range(5): listener = EventListener() - client = async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -285,7 +283,7 @@ class TestSession(AsyncIntegrationTest): async def test_end_sessions(self): # Use a new client so that the tearDown hook does not error. listener = SessionTestListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: @@ -789,8 +787,7 @@ class TestSession(AsyncIntegrationTest): async def test_unacknowledged_writes(self): # Ensure the collection exists. await self.client.pymongo_test.test_unacked_writes.insert_one({}) - client = await async_rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(w=0, event_listeners=[self.listener]) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -838,7 +835,7 @@ class TestCausalConsistency(AsyncUnitTest): @classmethod async def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) @classmethod async def _tearDown_class(cls): @@ -1153,10 +1150,9 @@ class TestClusterTime(AsyncIntegrationTest): async def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], heartbeatFrequencyMS=999999 ) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). await collection.insert_many([{} for _ in range(10)]) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 4034c8e2c..b5d068641 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -17,6 +17,7 @@ from __future__ import annotations import sys from io import BytesIO +from test.asynchronous.utils_spec_runner import AsyncSpecRunner from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket @@ -25,8 +26,6 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.utils import ( OvertCommandListener, - async_rs_client, - async_single_client, wait_until, ) from typing import List @@ -59,7 +58,18 @@ _IS_SYNC = False UNPIN_TEST_MAX_ATTEMPTS = 50 -class TestTransactions(AsyncIntegrationTest): +class AsyncTransactionsBase(AsyncSpecRunner): + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + if ( + "secondary" in self.id() + and not async_client_context.is_mongos + and not async_client_context.has_secondaries + ): + raise unittest.SkipTest("No secondaries") + + +class TestTransactions(AsyncTransactionsBase): RUN_ON_SERVERLESS = True @async_client_context.require_transactions @@ -92,8 +102,7 @@ class TestTransactions(AsyncIntegrationTest): @async_client_context.require_transactions async def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" - client = await async_rs_client(w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(w=0) db = client.test coll = db.test await coll.insert_one({}) @@ -150,12 +159,13 @@ class TestTransactions(AsyncIntegrationTest): async def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000) + client = await self.async_rs_client( + async_client_context.mongos_seeds(), localThresholdMS=1000 + ) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -178,12 +188,13 @@ class TestTransactions(AsyncIntegrationTest): async def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000) + client = await self.async_rs_client( + async_client_context.mongos_seeds(), localThresholdMS=1000 + ) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -307,11 +318,10 @@ class TestTransactions(AsyncIntegrationTest): # Start a transaction with a batch of operations that needs to be # split. listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test await coll.delete_many({}) listener.reset() - self.addAsyncCleanup(client.close) self.addAsyncCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -336,8 +346,7 @@ class TestTransactions(AsyncIntegrationTest): @async_client_context.require_transactions async def test_transaction_direct_connection(self): - client = await async_single_client() - self.addAsyncCleanup(client.close) + client = await self.async_single_client() coll = client.pymongo_test.test # Make sure the collection exists. @@ -393,14 +402,16 @@ class PatchSessionTimeout: client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout -class TestTransactionsConvenientAPI(AsyncIntegrationTest): +class TestTransactionsConvenientAPI(AsyncTransactionsBase): @classmethod async def _setup_class(cls): await super()._setup_class() cls.mongos_clients = [] if async_client_context.supports_transactions(): for address in async_client_context.mongoses: - cls.mongos_clients.append(await async_single_client("{}:{}".format(*address))) + cls.mongos_clients.append( + await cls.unmanaged_async_single_client("{}:{}".format(*address)) + ) @classmethod async def _tearDown_class(cls): @@ -450,8 +461,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest): @async_client_context.require_transactions async def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): @@ -479,8 +489,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest): @async_client_context.require_transactions async def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): @@ -514,8 +523,7 @@ class TestTransactionsConvenientAPI(AsyncIntegrationTest): @async_client_context.require_transactions async def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 71044d153..12cb13c2c 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -25,7 +25,6 @@ from test.utils import ( EventListener, OvertCommandListener, ServerAndTopologyEventListener, - async_rs_client, camel_to_snake, camel_to_snake_args, parse_spec_options, @@ -101,6 +100,8 @@ class AsyncSpecRunner(AsyncIntegrationTest): @classmethod async def _tearDown_class(cls): cls.knobs.disable() + for client in cls.mongos_clients: + await client.close() await super()._tearDown_class() def setUp(self): @@ -527,7 +528,7 @@ class AsyncSpecRunner(AsyncIntegrationTest): host = async_client_context.MULTI_MONGOS_LB_URI elif async_client_context.is_mongos: host = async_client_context.mongos_seeds() - client = await async_rs_client( + client = await self.async_rs_client( h=host, event_listeners=[listener, pool_listener, server_listener], **client_options ) self.scenario_client = client diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index 10416ae5f..a7660f2f6 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -18,6 +18,7 @@ from __future__ import annotations import os import sys import unittest +from test import PyMongoTestCase from unittest.mock import patch import pytest @@ -36,7 +37,7 @@ from pymongo.uri_parser import parse_uri pytestmark = pytest.mark.auth_aws -class TestAuthAWS(unittest.TestCase): +class TestAuthAWS(PyMongoTestCase): uri: str @classmethod @@ -69,7 +70,7 @@ class TestAuthAWS(unittest.TestCase): self.skipTest("Not testing cached credentials") # Make a connection to ensure that we enable caching. - client = MongoClient(self.uri) + client = self.simple_client(self.uri) client.get_database().test.find_one() client.close() @@ -79,7 +80,7 @@ class TestAuthAWS(unittest.TestCase): auth.set_cached_credentials(None) self.assertEqual(auth.get_cached_credentials(), None) - client = MongoClient(self.uri) + client = self.simple_client(self.uri) client.get_database().test.find_one() client.close() return auth.get_cached_credentials() @@ -90,8 +91,7 @@ class TestAuthAWS(unittest.TestCase): def test_cache_about_to_expire(self): creds = self.setup_cache() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) # Make the creds about to expire. creds = auth.get_cached_credentials() @@ -107,8 +107,7 @@ class TestAuthAWS(unittest.TestCase): def test_poisoned_cache(self): creds = self.setup_cache() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) # Poison the creds with invalid password. assert creds is not None @@ -130,8 +129,7 @@ class TestAuthAWS(unittest.TestCase): self.assertIsNotNone(creds) os.environ.copy() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) client.get_database().test.find_one() @@ -149,8 +147,7 @@ class TestAuthAWS(unittest.TestCase): auth.set_cached_credentials(None) - client2 = MongoClient(self.uri) - self.addCleanup(client2.close) + client2 = self.simple_client(self.uri) with patch.dict("os.environ", mock_env): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") @@ -166,8 +163,7 @@ class TestAuthAWS(unittest.TestCase): if creds.token: mock_env["AWS_SESSION_TOKEN"] = creds.token - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) with patch.dict(os.environ, mock_env): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], creds.username) @@ -177,22 +173,19 @@ class TestAuthAWS(unittest.TestCase): mock_env["AWS_ACCESS_KEY_ID"] = "foo" - client2 = MongoClient(self.uri) - self.addCleanup(client2.close) + client2 = self.simple_client(self.uri) with patch.dict("os.environ", mock_env), self.assertRaises(OperationFailure): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") client2.get_database().test.find_one() -class TestAWSLambdaExamples(unittest.TestCase): +class TestAWSLambdaExamples(PyMongoTestCase): def test_shared_client(self): # Start AWS Lambda Example 1 import os - from pymongo import MongoClient - - client = MongoClient(host=os.environ["MONGODB_URI"]) + client = self.simple_client(host=os.environ["MONGODB_URI"]) def lambda_handler(event, context): return client.db.command("ping") @@ -203,9 +196,7 @@ class TestAWSLambdaExamples(unittest.TestCase): # Start AWS Lambda Example 2 import os - from pymongo import MongoClient - - client = MongoClient( + client = self.simple_client( host=os.environ["MONGODB_URI"], authSource="$external", authMechanism="MONGODB-AWS", diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index fa4b7d669..6d31f3db4 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -23,6 +23,7 @@ import unittest import warnings from contextlib import contextmanager from pathlib import Path +from test import PyMongoTestCase from typing import Dict import pytest @@ -56,7 +57,7 @@ globals().update(generate_test_classes(str(TEST_PATH), module=__name__)) pytestmark = pytest.mark.auth_oidc -class OIDCTestBase(unittest.TestCase): +class OIDCTestBase(PyMongoTestCase): @classmethod def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] @@ -94,6 +95,7 @@ class OIDCTestBase(unittest.TestCase): yield finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") + client.close() @pytest.mark.auth_oidc @@ -149,7 +151,9 @@ class TestAuthOIDCHuman(OIDCTestBase): if not len(args): args = [self.uri_single] - return MongoClient(*args, authmechanismproperties=props, **kwargs) + client = self.simple_client(*args, authmechanismproperties=props, **kwargs) + + return client def test_1_1_single_principal_implicit_username(self): # Create default OIDC client with authMechanism=MONGODB-OIDC. diff --git a/test/mockupdb/test_cursor.py b/test/mockupdb/test_cursor.py index 46af39c7b..230029721 100644 --- a/test/mockupdb/test_cursor.py +++ b/test/mockupdb/test_cursor.py @@ -29,13 +29,12 @@ except ImportError: from bson.objectid import ObjectId -from pymongo import MongoClient from pymongo.errors import OperationFailure pytestmark = pytest.mark.mockupdb -class TestCursor(unittest.TestCase): +class TestCursor(PyMongoTestCase): def test_getmore_load_balanced(self): server = MockupDB() server.autoresponds( @@ -50,7 +49,7 @@ class TestCursor(unittest.TestCase): server.run() self.addCleanup(server.stop) - client = MongoClient(server.uri, loadBalanced=True) + client = self.simple_client(server.uri, loadBalanced=True) self.addCleanup(client.close) collection = client.db.coll cursor = collection.find() @@ -77,7 +76,7 @@ class TestRetryableErrorCodeCatch(PyMongoTestCase): self.addCleanup(server.stop) server.autoresponds("ismaster", maxWireVersion=6) - client = MongoClient(server.uri) + client = self.simple_client(server.uri) with going(lambda: server.receives(OpMsg({"find": "collection"})).command_err(code=code)): cursor = client.db.collection.find() diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index fe7f21160..a42b3a34e 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -48,8 +48,11 @@ else: def _connect(options): uri = f"mongodb://localhost:27017/?serverSelectionTimeoutMS={TIMEOUT_MS}&tlsCAFile={CA_FILE}&{options}" print(uri) - client = pymongo.MongoClient(uri) - client.admin.command("ping") + try: + client = pymongo.MongoClient(uri) + client.admin.command("ping") + finally: + client.close() class TestOCSP(unittest.TestCase): diff --git a/test/test_auth.py b/test/test_auth.py index fa3d0905b..b311d330b 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -23,16 +23,14 @@ from urllib.parse import quote_plus sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, unittest -from test.utils import ( - AllowListEventListener, - delay, - ignore_deprecations, - rs_or_single_client, - rs_or_single_client_noauth, - single_client, - single_client_noauth, +from test import ( + IntegrationTest, + PyMongoTestCase, + SkipTest, + client_context, + unittest, ) +from test.utils import AllowListEventListener, delay, ignore_deprecations from pymongo import MongoClient, monitoring from pymongo.auth_shared import _build_credentials_tuple @@ -81,7 +79,7 @@ class AutoAuthenticateThread(threading.Thread): self.success = True -class TestGSSAPI(unittest.TestCase): +class TestGSSAPI(PyMongoTestCase): mech_properties: str service_realm_required: bool @@ -138,7 +136,7 @@ class TestGSSAPI(unittest.TestCase): if not self.service_realm_required: # Without authMechanismProperties. - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -149,11 +147,11 @@ class TestGSSAPI(unittest.TestCase): client[GSSAPI_DB].collection.find_one() # Log in using URI, without authMechanismProperties. - client = MongoClient(uri) + client = self.simple_client(uri) client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -166,14 +164,14 @@ class TestGSSAPI(unittest.TestCase): # Log in using URI, with authMechanismProperties. mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" - client = MongoClient(mech_uri) + client = self.simple_client(mech_uri) client[GSSAPI_DB].collection.find_one() set_name = client_context.replica_set_name if set_name: if not self.service_realm_required: # Without authMechanismProperties - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -185,11 +183,11 @@ class TestGSSAPI(unittest.TestCase): client[GSSAPI_DB].list_collection_names() uri = uri + f"&replicaSet={set_name!s}" - client = MongoClient(uri) + client = self.simple_client(uri) client[GSSAPI_DB].list_collection_names() # With authMechanismProperties - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -202,13 +200,13 @@ class TestGSSAPI(unittest.TestCase): client[GSSAPI_DB].list_collection_names() mech_uri = mech_uri + f"&replicaSet={set_name!s}" - client = MongoClient(mech_uri) + client = self.simple_client(mech_uri) client[GSSAPI_DB].list_collection_names() @ignore_deprecations @client_context.require_sync def test_gssapi_threaded(self): - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -244,7 +242,7 @@ class TestGSSAPI(unittest.TestCase): set_name = client_context.replica_set_name if set_name: - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -267,14 +265,14 @@ class TestGSSAPI(unittest.TestCase): self.assertTrue(thread.success) -class TestSASLPlain(unittest.TestCase): +class TestSASLPlain(PyMongoTestCase): @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL") def test_sasl_plain(self): - client = MongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, username=SASL_USER, @@ -293,12 +291,12 @@ class TestSASLPlain(unittest.TestCase): SASL_PORT, SASL_DB, ) - client = MongoClient(uri) + client = self.simple_client(uri) client.ldap.test.find_one() set_name = client_context.replica_set_name if set_name: - client = MongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, replicaSet=set_name, @@ -317,7 +315,7 @@ class TestSASLPlain(unittest.TestCase): SASL_DB, str(set_name), ) - client = MongoClient(uri) + client = self.simple_client(uri) client.ldap.test.find_one() def test_sasl_plain_bad_credentials(self): @@ -331,8 +329,8 @@ class TestSASLPlain(unittest.TestCase): ) return uri - bad_user = MongoClient(auth_string("not-user", SASL_PASS)) - bad_pwd = MongoClient(auth_string(SASL_USER, "not-pwd")) + bad_user = self.simple_client(auth_string("not-user", SASL_PASS)) + bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. with self.assertRaises(OperationFailure): bad_user.admin.command("ping") @@ -354,7 +352,7 @@ class TestSCRAMSHA1(IntegrationTest): def test_scram_sha1(self): host, port = client_context.host, client_context.port - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port) ) client.pymongo_test.command("dbstats") @@ -365,7 +363,7 @@ class TestSCRAMSHA1(IntegrationTest): "@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" "&replicaSet=%s" % (host, port, client_context.replica_set_name) ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) client.pymongo_test.command("dbstats") db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) db.command("dbstats") @@ -393,7 +391,7 @@ class TestSCRAM(IntegrationTest): "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", event_listeners=[listener] ) client.testscram.command("dbstats") @@ -430,36 +428,38 @@ class TestSCRAM(IntegrationTest): ) # Step 2: verify auth success cases - client = rs_or_single_client_noauth(username="sha1", password="pwd", authSource="testscram") + client = self.rs_or_single_client_noauth( + username="sha1", password="pwd", authSource="testscram" + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") # Step 2: SCRAM-SHA-1 and SCRAM-SHA-256 - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") self.listener.reset() - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", event_listeners=[self.listener] ) client.testscram.command("dbstats") @@ -472,19 +472,19 @@ class TestSCRAM(IntegrationTest): self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256") # Step 3: verify auth failure conditions - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) with self.assertRaises(OperationFailure): client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) with self.assertRaises(OperationFailure): client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="not-a-user", password="pwd", authSource="testscram" ) with self.assertRaises(OperationFailure): @@ -497,7 +497,7 @@ class TestSCRAM(IntegrationTest): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) client.testscram.command("dbstats") db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY) db.command("dbstats") @@ -517,12 +517,12 @@ class TestSCRAM(IntegrationTest): "testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram", @@ -530,17 +530,17 @@ class TestSCRAM(IntegrationTest): ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="IV", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram", @@ -548,25 +548,29 @@ class TestSCRAM(IntegrationTest): ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( "mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port) ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://\u2168:IV@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth( + "mongodb://\u2168:IV@%s:%d/testscram" % (host, port) + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth( + "mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port) + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port)) client.testscram.command("dbstats") def test_cache(self): - client = single_client() + client = self.single_client() credentials = client.options.pool_options._credentials cache = credentials.cache self.assertIsNotNone(cache) @@ -591,8 +595,7 @@ class TestSCRAM(IntegrationTest): coll.insert_one({"_id": 1}) # The first thread to call find() will authenticate - client = rs_or_single_client() - self.addCleanup(client.close) + client = self.rs_or_single_client() coll = client.db.test threads = [] for _ in range(4): @@ -619,7 +622,7 @@ class TestAuthURIOptions(IntegrationTest): def test_uri_options(self): # Test default to admin host, port = client_context.host, client_context.port - client = rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + client = self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) self.assertTrue(client.admin.command("dbstats")) if client_context.is_rs: @@ -628,14 +631,14 @@ class TestAuthURIOptions(IntegrationTest): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) self.assertTrue(client.admin.command("dbstats")) db = client.get_database("admin", read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command("dbstats")) # Test explicit database uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) - client = rs_or_single_client_noauth(uri) + client = self.rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): client.admin.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -646,7 +649,7 @@ class TestAuthURIOptions(IntegrationTest): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) with self.assertRaises(OperationFailure): client.admin.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -655,7 +658,7 @@ class TestAuthURIOptions(IntegrationTest): # Test authSource uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port) - client = rs_or_single_client_noauth(uri) + client = self.rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): client.pymongo_test2.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -665,7 +668,7 @@ class TestAuthURIOptions(IntegrationTest): "mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=" "%s;authSource=pymongo_test" % (host, port, client_context.replica_set_name) ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) with self.assertRaises(OperationFailure): client.pymongo_test2.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 38e5f19bf..3c3a1a67a 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -20,6 +20,7 @@ import json import os import sys import warnings +from test import PyMongoTestCase sys.path[0:0] = [""] @@ -34,7 +35,7 @@ _IS_SYNC = True _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") -class TestAuthSpec(unittest.TestCase): +class TestAuthSpec(PyMongoTestCase): pass @@ -54,7 +55,7 @@ def create_test(test_case): warnings.simplefilter("default") self.assertRaises(Exception, MongoClient, uri, connect=False) else: - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) diff --git a/test/test_bulk.py b/test/test_bulk.py index 63b8c7790..64fd48e8c 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -24,22 +24,13 @@ from pymongo.synchronous.mongo_client import MongoClient sys.path[0:0] = [""] from test import IntegrationTest, client_context, remove_all_users, unittest -from test.utils import ( - rs_or_single_client_noauth, - single_client, - wait_until, -) +from test.utils import wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.objectid import ObjectId from pymongo.common import partition_node -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure, -) +from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import * from pymongo.synchronous.collection import Collection from pymongo.write_concern import WriteConcern @@ -913,7 +904,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase): def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth( + cli = self.rs_or_single_client_noauth( username="readonly", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -924,7 +915,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase): def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth( + cli = self.rs_or_single_client_noauth( username="noremove", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -952,7 +943,7 @@ class TestBulkWriteConcern(BulkTestBase): if cls.w is not None and cls.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = single_client(*partition_node(member)) + cls.secondary = cls.unmanaged_single_client(*partition_node(member)) break @classmethod diff --git a/test/test_change_stream.py b/test/test_change_stream.py index cb19452ae..dae224c5e 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -28,12 +28,17 @@ from typing import no_type_check sys.path[0:0] = [""] -from test import IntegrationTest, Version, client_context, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + Version, + client_context, + unittest, +) from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, EventListener, - rs_or_single_client, wait_until, ) @@ -69,8 +74,7 @@ class TestChangeStreamBase(IntegrationTest): def client_with_listener(self, *commands): """Return a client with a AllowListEventListener.""" listener = AllowListEventListener(*commands) - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) return client, listener def watched_collection(self, *args, **kwargs): @@ -174,7 +178,7 @@ class APITestsMixin: @no_type_check def test_try_next_runs_one_getmore(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command("ping") listener.reset() @@ -232,7 +236,7 @@ class APITestsMixin: @no_type_check def test_batch_size_is_honored(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command("ping") listener.reset() @@ -473,7 +477,7 @@ class ProseSpecTestsMixin: @no_type_check def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) - client = rs_or_single_client(event_listeners=[listener]) + client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener @@ -1111,7 +1115,7 @@ class TestAllLegacyScenarios(IntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) @classmethod def _tearDown_class(cls): diff --git a/test/test_client.py b/test/test_client.py index 785139d6a..bc45325f0 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -67,10 +67,6 @@ from test.utils import ( is_greenthread_patched, lazy_client_trial, one, - rs_client, - rs_or_single_client, - rs_or_single_client_noauth, - single_client, wait_until, ) @@ -131,7 +127,7 @@ class ClientUnitTest(UnitTest): @classmethod def _setup_class(cls): - cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @classmethod def _tearDown_class(cls): @@ -142,7 +138,7 @@ class ClientUnitTest(UnitTest): self._caplog = caplog def test_keyword_arg_defaults(self): - client = MongoClient( + client = self.simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -168,15 +164,17 @@ class ClientUnitTest(UnitTest): self.assertAlmostEqual(12, client.options.server_selection_timeout) def test_connect_timeout(self): - client = MongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = MongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + + client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = MongoClient( + + client = self.simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options @@ -193,7 +191,7 @@ class ClientUnitTest(UnitTest): self.assertRaises(ConfigurationError, MongoClient, []) def test_max_pool_size_zero(self): - MongoClient(maxPoolSize=0) + self.simple_client(maxPoolSize=0) def test_uri_detection(self): self.assertRaises(ConfigurationError, MongoClient, "/foo") @@ -258,7 +256,7 @@ class ClientUnitTest(UnitTest): self.assertNotIsInstance(client, Iterable) def test_get_default_database(self): - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, ) @@ -274,7 +272,7 @@ class ClientUnitTest(UnitTest): self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) @@ -282,7 +280,7 @@ class ClientUnitTest(UnitTest): def test_get_default_database_error(self): # URI with no database. - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) @@ -294,11 +292,11 @@ class ClientUnitTest(UnitTest): client_context.host, client_context.port, ) - c = rs_or_single_client(uri, connect=False) + c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, ) @@ -306,7 +304,7 @@ class ClientUnitTest(UnitTest): def test_get_database_default_error(self): # URI with no database. - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) @@ -318,19 +316,19 @@ class ClientUnitTest(UnitTest): client_context.host, client_context.port, ) - c = rs_or_single_client(uri, connect=False) + c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_database()) def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". with self.assertRaises(ConfigurationError): - MongoClient("mongodb://host/?readpreferencetags=dc:east") + self.single_client("mongodb://host/?readpreferencetags=dc:east") with self.assertRaises(ConfigurationError): - MongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + self.single_client("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") def test_read_preference(self): - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) self.assertEqual(c.read_preference, ReadPreference.NEAREST) @@ -339,26 +337,30 @@ class ClientUnitTest(UnitTest): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} - client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata, metadata) - client = MongoClient("foo", 27017, appname="foobar", connect=False) + client = self.simple_client("foo", 27017, appname="foobar", connect=False) options = client.options self.assertEqual(options.pool_options.metadata, metadata) # No error - MongoClient(appname="x" * 128) - self.assertRaises(ValueError, MongoClient, appname="x" * 129) + self.simple_client(appname="x" * 128) + with self.assertRaises(ValueError): + self.simple_client(appname="x" * 129) # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) - self.assertRaises(TypeError, MongoClient, driver=1) - self.assertRaises(TypeError, MongoClient, driver="abc") - self.assertRaises(TypeError, MongoClient, driver=("Foo", "1", "a")) + with self.assertRaises(TypeError): + self.simple_client(driver=1) + with self.assertRaises(TypeError): + self.simple_client(driver="abc") + with self.assertRaises(TypeError): + self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = MongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -368,7 +370,7 @@ class ClientUnitTest(UnitTest): options = client.options self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = MongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -378,7 +380,7 @@ class ClientUnitTest(UnitTest): options = client.options self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = MongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -387,7 +389,7 @@ class ClientUnitTest(UnitTest): len(bson.encode(options.pool_options.metadata)), _MAX_METADATA_SIZE, ) - client = MongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -403,7 +405,7 @@ class ClientUnitTest(UnitTest): metadata["driver"]["name"] = "PyMongo" metadata["env"] = {} metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) @@ -429,7 +431,7 @@ class ClientUnitTest(UnitTest): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = MongoClient( + c = self.simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -438,12 +440,12 @@ class ClientUnitTest(UnitTest): tzinfo=tzinfo, connect=False, ) - self.assertEqual(c.codec_options.document_class, document_class) self.assertEqual(c.codec_options.type_registry, type_registry) self.assertEqual(c.codec_options.tz_aware, tz_aware) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual(c.codec_options.tzinfo, tzinfo) @@ -465,11 +467,11 @@ class ClientUnitTest(UnitTest): datetime_conversion, ) ) - c = MongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual(c.codec_options.tz_aware, True) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual( @@ -478,8 +480,7 @@ class ClientUnitTest(UnitTest): # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = MongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual( c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] ) @@ -487,7 +488,9 @@ class ClientUnitTest(UnitTest): def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = MongoClient(uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred") + c = self.simple_client( + uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" + ) clopts = c.options opts = clopts._options @@ -516,7 +519,7 @@ class ClientUnitTest(UnitTest): def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - MongoClient(*args, **kwargs) + self.simple_client(*args, **kwargs) for _, kw in patched_resolver.call_list(): self.assertAlmostEqual(kw["lifetime"], expected_value) @@ -536,15 +539,15 @@ class ClientUnitTest(UnitTest): def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - MongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) + self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - c = MongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) + c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - MongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, @@ -552,7 +555,7 @@ class ClientUnitTest(UnitTest): # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): - MongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, @@ -560,10 +563,10 @@ class ClientUnitTest(UnitTest): # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - MongoClient(ssl=True, tls=False) + self.simple_client(ssl=True, tls=False) def test_event_listeners(self): - c = MongoClient(event_listeners=[], connect=False) + c = self.simple_client(event_listeners=[], connect=False) self.assertEqual(c.options.event_listeners, []) listeners = [ event_loggers.CommandLogger(), @@ -572,11 +575,11 @@ class ClientUnitTest(UnitTest): event_loggers.TopologyLogger(), event_loggers.ConnectionPoolLogger(), ] - c = MongoClient(event_listeners=listeners, connect=False) + c = self.simple_client(event_listeners=listeners, connect=False) self.assertEqual(c.options.event_listeners, listeners) def test_client_options(self): - c = MongoClient(connect=False) + c = self.simple_client(connect=False) self.assertIsInstance(c.options, ClientOptions) self.assertIsInstance(c.options.pool_options, PoolOptions) self.assertEqual(c.options.server_selection_timeout, 30) @@ -606,11 +609,11 @@ class ClientUnitTest(UnitTest): ) with self.assertLogs("pymongo", level="INFO") as cm: for host in normal_hosts: - MongoClient(host) + MongoClient(host, connect=False) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] - MongoClient(host) - MongoClient(multi_host) + MongoClient(host, connect=False) + MongoClient(multi_host, connect=False) logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) @@ -628,13 +631,13 @@ class ClientUnitTest(UnitTest): ) for host in normal_hosts: with self.assertWarns(UserWarning): - MongoClient(host) + self.simple_client(host) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] with self.assertWarns(UserWarning): - MongoClient(host) + self.simple_client(host) with self.assertWarns(UserWarning): - MongoClient(multi_host) + self.simple_client(multi_host) class TestClient(IntegrationTest): @@ -651,18 +654,17 @@ class TestClient(IntegrationTest): def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = rs_or_single_client() + client = self.rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - client.close() def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -671,12 +673,11 @@ class TestClient(IntegrationTest): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - client.close() def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) + client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -685,12 +686,11 @@ class TestClient(IntegrationTest): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - client.close() def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it - client = rs_or_single_client(maxIdleTimeMS=500) + client = self.rs_or_single_client(maxIdleTimeMS=500) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn_one: pass @@ -703,16 +703,15 @@ class TestClient(IntegrationTest): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - client.close() def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): - client = rs_or_single_client() + client = self.rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize - client = rs_or_single_client(minPoolSize=10) + client = self.rs_or_single_client(minPoolSize=10) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) wait_until( lambda: len(server._pool.conns) == 10, @@ -731,7 +730,7 @@ class TestClient(IntegrationTest): def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = rs_or_single_client(maxIdleTimeMS=500) + client = self.rs_or_single_client(maxIdleTimeMS=500) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -745,7 +744,7 @@ class TestClient(IntegrationTest): self.assertTrue(new_con in server._pool.conns) # Test that connections are reused if maxIdleTimeMS is not set. - client = rs_or_single_client() + client = self.rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -769,36 +768,38 @@ class TestClient(IntegrationTest): MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - connected(MongoClient(serverSelectionTimeoutMS=10, **kwargs)) + c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) + connected(c) + c = self.simple_client(host, port, **kwargs) # Override the defaults. No error. - connected(MongoClient(host, port, **kwargs)) + connected(c) # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port # No error. - connected(MongoClient(**kwargs)) + c = self.simple_client(**kwargs) + connected(c) def test_init_disconnected(self): host, port = client_context.host, client_context.port - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) # is_primary causes client to block until connected self.assertIsInstance(c.is_primary, bool) - - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.is_mongos, bool) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertFalse(c.primary) self.assertFalse(c.secondaries) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(c.address) # PYTHON-2981 @@ -810,43 +811,44 @@ class TestClient(IntegrationTest): self.assertEqual(c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" - c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = rs_or_single_client(seed, connect=False) - self.addCleanup(c.close) + c = self.rs_or_single_client(seed, connect=False) self.assertEqual(client_context.client, c) # Explicitly test inequality self.assertFalse(client_context.client != c) - c = rs_or_single_client("invalid.com", connect=False) - self.addCleanup(c.close) + c = self.rs_or_single_client("invalid.com", connect=False) self.assertNotEqual(client_context.client, c) self.assertTrue(client_context.client != c) + + c1 = self.simple_client("a", connect=False) + c2 = self.simple_client("b", connect=False) + # Seeds differ: - self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False)) + self.assertNotEqual(c1, c2) + + c1 = self.simple_client(["a", "b", "c"], connect=False) + c2 = self.simple_client(["c", "a", "b"], connect=False) + # Same seeds but out of order still compares equal: - self.assertEqual( - MongoClient(["a", "b", "c"], connect=False), - MongoClient(["c", "a", "b"], connect=False), - ) + self.assertEqual(c1, c2) def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = rs_or_single_client(seed, connect=False) - self.addCleanup(c.close) + c = self.rs_or_single_client(seed, connect=False) self.assertIn(c, {client_context.client}) - c = rs_or_single_client("invalid.com", connect=False) - self.addCleanup(c.close) + c = self.rs_or_single_client("invalid.com", connect=False) self.assertNotIn(c, {client_context.client}) def test_host_w_port(self): @@ -879,9 +881,10 @@ class TestClient(IntegrationTest): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + with eval(the_repr) as client_two: + self.assertEqual(client_two, client) - client = MongoClient( + client = self.simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -899,7 +902,8 @@ class TestClient(IntegrationTest): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + with eval(the_repr) as client_two: + self.assertEqual(client_two, client) def test_getters(self): wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") @@ -915,8 +919,7 @@ class TestClient(IntegrationTest): for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = rs_or_single_client(document_class=SON) - self.addCleanup(client.close) + client = self.rs_or_single_client(document_class=SON) for doc in client.list_databases(): self.assertIs(type(doc), dict) @@ -955,7 +958,7 @@ class TestClient(IntegrationTest): self.client.drop_database("pymongo_test") if client_context.is_rs: - wc_client = rs_or_single_client(w=len(client_context.nodes) + 1) + wc_client = self.rs_or_single_client(w=len(client_context.nodes) + 1) with self.assertRaises(WriteConcernError): wc_client.drop_database("pymongo_test2") @@ -965,7 +968,7 @@ class TestClient(IntegrationTest): self.assertNotIn("pymongo_test2", dbs) def test_close(self): - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() coll = test_client.pymongo_test.bar test_client.close() with self.assertRaises(InvalidOperation): @@ -975,7 +978,7 @@ class TestClient(IntegrationTest): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() test_client._process_periodic_tasks() @@ -1002,13 +1005,13 @@ class TestClient(IntegrationTest): self.assertTrue(test_client._topology._opened) test_client.close() self.assertFalse(test_client._topology._opened) - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() # The killCursors task should not need to re-open the topology. test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) def test_close_stops_kill_cursors_thread(self): - client = rs_client() + client = self.rs_client() client.test.test.find_one() self.assertFalse(client._kill_cursors_executor._stopped) @@ -1024,7 +1027,7 @@ class TestClient(IntegrationTest): def test_uri_connect_option(self): # Ensure that topology is not opened if connect=False. - client = rs_client(connect=False) + client = self.rs_client(connect=False) self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. @@ -1037,19 +1040,15 @@ class TestClient(IntegrationTest): kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) - # Tear down. - client.close() - def test_close_does_not_open_servers(self): - client = rs_client(connect=False) + client = self.rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) client.close() self.assertEqual(topology._servers, {}) def test_close_closes_sockets(self): - client = rs_client() - self.addCleanup(client.close) + client = self.rs_client() client.test.test.find_one() topology = client._topology client.close() @@ -1075,30 +1074,30 @@ class TestClient(IntegrationTest): client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"]) with self.assertRaises(OperationFailure): - connected(rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) + connected(self.rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) # No error. - connected(rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) + connected(self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) with self.assertRaises(OperationFailure): - connected(rs_or_single_client_noauth(uri)) + connected(self.rs_or_single_client_noauth(uri)) # No error. connected( - rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) + self.rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) ) # Auth with lazy connection. ( - rs_or_single_client_noauth( + self.rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = rs_or_single_client_noauth( + bad_client = self.rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) @@ -1110,7 +1109,7 @@ class TestClient(IntegrationTest): client_context.create_user("admin", "ad min", "pa/ss") self.addCleanup(client_context.drop_user, "admin", "ad min") - c = rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = self.rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. self.assertNotIn("ad min", repr(c)) @@ -1122,13 +1121,13 @@ class TestClient(IntegrationTest): c.server_info() with self.assertRaises(OperationFailure): - (rs_or_single_client_noauth(username="ad min", password="foo")).server_info() + (self.rs_or_single_client_noauth(username="ad min", password="foo")).server_info() @client_context.require_auth @client_context.require_no_fips def test_lazy_auth_raises_operation_failure(self): host = client_context.host - lazy_client = rs_or_single_client_noauth( + lazy_client = self.rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) @@ -1146,8 +1145,7 @@ class TestClient(IntegrationTest): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = rs_or_single_client(uri) - self.addCleanup(client.close) + client = self.rs_or_single_client(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1156,9 +1154,10 @@ class TestClient(IntegrationTest): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - connected( - MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100), + c = self.simple_client( + "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 ) + connected(c) def test_document_class(self): c = self.client @@ -1169,15 +1168,15 @@ class TestClient(IntegrationTest): self.assertTrue(isinstance(db.test.find_one(), dict)) self.assertFalse(isinstance(db.test.find_one(), SON)) - c = rs_or_single_client(document_class=SON) - self.addCleanup(c.close) + c = self.rs_or_single_client(document_class=SON) + db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) self.assertTrue(isinstance(db.test.find_one(), SON)) def test_timeouts(self): - client = rs_or_single_client( + client = self.rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, @@ -1190,28 +1189,31 @@ class TestClient(IntegrationTest): self.assertEqual(10.5, client.options.server_selection_timeout) def test_socket_timeout_ms_validation(self): - c = rs_or_single_client(socketTimeoutMS=10 * 1000) + c = self.rs_or_single_client(socketTimeoutMS=10 * 1000) self.assertEqual(10, (get_pool(c)).opts.socket_timeout) - c = connected(rs_or_single_client(socketTimeoutMS=None)) + c = connected(self.rs_or_single_client(socketTimeoutMS=None)) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) - c = connected(rs_or_single_client(socketTimeoutMS=0)) + c = connected(self.rs_or_single_client(socketTimeoutMS=0)) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS=-1) + with self.rs_or_single_client(socketTimeoutMS=-1): + pass with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS=1e10) + with self.rs_or_single_client(socketTimeoutMS=1e10): + pass with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS="foo") + with self.rs_or_single_client(socketTimeoutMS="foo"): + pass def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 - timeout = rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) + timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) self.addCleanup(timeout.close) no_timeout.pymongo_test.drop_collection("test") @@ -1256,7 +1258,7 @@ class TestClient(IntegrationTest): self.assertAlmostEqual(30, client.options.server_selection_timeout) def test_waitQueueTimeoutMS(self): - client = rs_or_single_client(waitQueueTimeoutMS=2000) + client = self.rs_or_single_client(waitQueueTimeoutMS=2000) self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) def test_socketKeepAlive(self): @@ -1269,7 +1271,7 @@ class TestClient(IntegrationTest): def test_tz_aware(self): self.assertRaises(ValueError, MongoClient, tz_aware="foo") - aware = rs_or_single_client(tz_aware=True) + aware = self.rs_or_single_client(tz_aware=True) self.addCleanup(aware.close) naive = self.client aware.pymongo_test.drop_collection("test") @@ -1299,8 +1301,7 @@ class TestClient(IntegrationTest): if client_context.is_rs: uri += "/?replicaSet=" + (client_context.replica_set_name or "") - client = rs_or_single_client_noauth(uri) - self.addCleanup(client.close) + client = self.rs_or_single_client_noauth(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1309,7 +1310,7 @@ class TestClient(IntegrationTest): self.assertTrue("pymongo_test_bernie" in dbs) def test_contextlib(self): - client = rs_or_single_client() + client = self.rs_or_single_client() client.pymongo_test.drop_collection("test") client.pymongo_test.test.insert_one({"foo": "bar"}) @@ -1323,7 +1324,7 @@ class TestClient(IntegrationTest): self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): client.pymongo_test.test.find_one() - client = rs_or_single_client() + client = self.rs_or_single_client() with client as client: self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): @@ -1401,8 +1402,7 @@ class TestClient(IntegrationTest): # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. - client = single_client() - self.addCleanup(client.close) + client = self.single_client() client.pymongo_test.test.find_one() pool = get_pool(client) socket_count = len(pool.conns) @@ -1426,8 +1426,7 @@ class TestClient(IntegrationTest): client_context.client.drop_database("test_lazy_connect_w0") self.addCleanup(client_context.client.drop_database, "test_lazy_connect_w0") - client = rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) + client = self.rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.insert_one({}) def predicate(): @@ -1435,8 +1434,7 @@ class TestClient(IntegrationTest): wait_until(predicate, "find one document") - client = rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) + client = self.rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) def predicate(): @@ -1444,8 +1442,7 @@ class TestClient(IntegrationTest): wait_until(predicate, "update one document") - client = rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) + client = self.rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.delete_one({}) def predicate(): @@ -1457,8 +1454,7 @@ class TestClient(IntegrationTest): def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addCleanup(client.close) + client = self.rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1484,7 +1480,9 @@ class TestClient(IntegrationTest): # when authenticating a new socket with cached credentials. # Get a client with one socket so we detect if it's leaked. - c = connected(rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False)) + c = connected( + self.rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) + ) # Cause a network error on the actual socket. pool = get_pool(c) @@ -1501,8 +1499,7 @@ class TestClient(IntegrationTest): @client_context.require_no_replica_set def test_connect_to_standalone_using_replica_set_name(self): - client = single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - + client = self.single_client(replicaSet="anything", serverSelectionTimeoutMS=100) with self.assertRaises(AutoReconnect): client.test.test.find_one() @@ -1512,7 +1509,7 @@ class TestClient(IntegrationTest): # the topology before the getMore message is sent. Test that # MongoClient._run_operation_with_response handles the error. with self.assertRaises(AutoReconnect): - client = rs_client(connect=False, serverSelectionTimeoutMS=100) + client = self.rs_client(connect=False, serverSelectionTimeoutMS=100) client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1560,7 +1557,7 @@ class TestClient(IntegrationTest): client_context.host, client_context.port, ) - client = single_client(uri, event_listeners=[listener]) + self.single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1569,7 +1566,6 @@ class TestClient(IntegrationTest): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1586,31 +1582,31 @@ class TestClient(IntegrationTest): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) @@ -1618,56 +1614,55 @@ class TestClient(IntegrationTest): # According to the connection string spec, unsupported values # just raise a warning and are ignored. uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = single_client(zlibcompressionlevel=level) + client = self.single_client(zlibcompressionlevel=level) # No error client.pymongo_test.test.find_one() def test_reset_during_update_pool(self): - client = rs_or_single_client(minPoolSize=10) - self.addCleanup(client.close) + client = self.rs_or_single_client(minPoolSize=10) client.admin.command("ping") pool = get_pool(client) generation = pool.gen.get_overall() @@ -1713,11 +1708,9 @@ class TestClient(IntegrationTest): def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 - client = rs_or_single_client( + client = self.rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addCleanup(client.close) - # Create a single connection in the pool. client.admin.command("ping") @@ -1747,21 +1740,19 @@ class TestClient(IntegrationTest): @client_context.require_replica_set def test_direct_connection(self): # direct_connection=True should result in Single topology. - client = rs_or_single_client(directConnection=True) + client = self.rs_or_single_client(directConnection=True) client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - client.close() # direct_connection=False should result in RS topology. - client = rs_or_single_client(directConnection=False) + client = self.rs_or_single_client(directConnection=False) client.admin.command("ping") self.assertGreaterEqual(len(client.nodes), 1) self.assertIn( client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -1781,11 +1772,10 @@ class TestClient(IntegrationTest): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = MongoClient( + client = self.simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): client.test.test.find_one() gc.collect() @@ -1798,8 +1788,7 @@ class TestClient(IntegrationTest): @client_context.require_failCommand_fail_point def test_network_error_message(self): - client = single_client(retryReads=False) - self.addCleanup(client.close) + client = self.single_client(retryReads=False) client.admin.command("ping") # connect with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1811,7 +1800,7 @@ class TestClient(IntegrationTest): @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") def test_process_periodic_tasks(self): - client = rs_or_single_client() + client = self.rs_or_single_client() coll = client.db.collection coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) @@ -1850,11 +1839,11 @@ class TestClient(IntegrationTest): self.assertEqual(client._topology_settings.srv_service_name, "customname") def test_srv_max_hosts_kwarg(self): - client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = MongoClient( + client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @@ -1902,10 +1891,10 @@ class TestClient(IntegrationTest): if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - with rs_or_single_client(serverSelectionTimeoutMS=10000) as client: - client.admin.command("ping") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = self.rs_or_single_client(serverSelectionTimeoutMS=10000) + client.admin.command("ping") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) def test_handshake_01_aws(self): self._test_handshake( @@ -2001,7 +1990,7 @@ class TestExhaustCursor(IntegrationTest): def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(rs_or_single_client(maxPoolSize=1)) + client = connected(self.rs_or_single_client(maxPoolSize=1)) collection = client.pymongo_test.test pool = get_pool(client) @@ -2024,7 +2013,7 @@ class TestExhaustCursor(IntegrationTest): def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test collection.drop() @@ -2063,7 +2052,7 @@ class TestExhaustCursor(IntegrationTest): def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(rs_or_single_client(maxPoolSize=1, retryReads=False)) + client = connected(self.rs_or_single_client(maxPoolSize=1, retryReads=False)) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2084,7 +2073,7 @@ class TestExhaustCursor(IntegrationTest): def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test collection.drop() collection.insert_many([{} for _ in range(200)]) # More than one batch. @@ -2133,7 +2122,7 @@ class TestExhaustCursor(IntegrationTest): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2165,7 +2154,7 @@ class TestExhaustCursor(IntegrationTest): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) coll = client.pymongo_test.test pool = get_pool(client) @@ -2202,7 +2191,7 @@ class TestClientLazyConnect(IntegrationTest): """Test concurrent operations on a lazily-connecting MongoClient.""" def _get_client(self): - return rs_or_single_client(connect=False) + return self.rs_or_single_client(connect=False) @client_context.require_sync def test_insert_one(self): @@ -2336,6 +2325,7 @@ class TestMongoClientFailover(MockClientTest): retryReads=False, serverSelectionTimeoutMS=1000, ) + self.addCleanup(c.close) # Set host-specific information so we can test whether it is reset. diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index ee19a0417..ebbdc74c1 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -27,7 +27,6 @@ from test import ( ) from test.utils import ( OvertCommandListener, - rs_or_single_client, ) from unittest.mock import patch @@ -38,7 +37,6 @@ from pymongo.errors import ( InvalidOperation, NetworkTimeout, ) -from pymongo.monitoring import * from pymongo.operations import * from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.write_concern import WriteConcern @@ -97,8 +95,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) models = [] for _ in range(self.max_write_batch_size + 1): @@ -123,8 +120,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -157,11 +153,10 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_failCommand_fail_point def test_collects_write_concern_errors_across_batches(self): listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener], retryWrites=False, ) - self.addCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -200,8 +195,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -231,8 +225,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -262,8 +255,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -304,8 +296,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_standalone def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -348,8 +339,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_failCommand_fail_point def test_handles_getMore_error(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -403,8 +393,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) b_repeated = "b" * self.max_bson_object_size @@ -460,8 +449,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) num_models, models = self._setup_namespace_test_models() models.append( @@ -492,8 +480,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_no_serverless def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) num_models, models = self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -530,8 +517,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): @client_context.require_version_min(8, 0, 0, -24) @client_context.require_no_serverless def test_returns_error_if_no_writes_can_be_added_to_ops(self): - client = rs_or_single_client() - self.addCleanup(client.close) + client = self.rs_or_single_client() # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -554,8 +540,7 @@ class TestClientBulkWriteCRUD(IntegrationTest): key_vault_namespace="db.coll", kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -580,7 +565,7 @@ class TestClientBulkWriteCSOT(IntegrationTest): def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 - internal_client = rs_or_single_client(timeoutMS=None) + internal_client = self.rs_or_single_client(timeoutMS=None) self.addCleanup(internal_client.close) collection = internal_client.db["coll"] @@ -605,14 +590,13 @@ class TestClientBulkWriteCSOT(IntegrationTest): ) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener], readConcernLevel="majority", readPreference="primary", timeoutMS=2000, w="majority", ) - self.addCleanup(client.close) client.admin.command("ping") # Init the client first. with self.assertRaises(ClientBulkWriteException) as context: client.bulk_write(models=models) diff --git a/test/test_collation.py b/test/test_collation.py index bedf0a2ea..19df25c1c 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -18,7 +18,7 @@ from __future__ import annotations import functools import warnings from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from typing import Any from pymongo.collation import ( @@ -99,7 +99,7 @@ class TestCollation(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = EventListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test cls.collation = Collation("en_US") cls.warn_context = warnings.catch_warnings() diff --git a/test/test_collection.py b/test/test_collection.py index b68aa74f7..dab59cf1b 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -29,6 +29,7 @@ sys.path[0:0] = [""] from test import ( # TODO: fix sync imports in PYTHON-4528 IntegrationTest, + UnitTest, client_context, unittest, ) @@ -37,8 +38,6 @@ from test.utils import ( EventListener, get_pool, is_mongos, - rs_or_single_client, - single_client, wait_until, ) @@ -81,14 +80,20 @@ from pymongo.write_concern import WriteConcern _IS_SYNC = True -class TestCollectionNoConnect(unittest.TestCase): +class TestCollectionNoConnect(UnitTest): """Test Collection features on a client that does not connect.""" db: Database + client: MongoClient @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test + def _setup_class(cls): + cls.client = MongoClient(connect=False) + cls.db = cls.client.pymongo_test + + @classmethod + def _tearDown_class(cls): + cls.client.close() def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -1800,8 +1805,7 @@ class TestCollection(IntegrationTest): # Insert enough documents to require more than one batch self.db.test.insert_many([{"i": i} for i in range(150)]) - client = rs_or_single_client(maxPoolSize=1) - self.addCleanup(client.close) + client = self.rs_or_single_client(maxPoolSize=1) pool = get_pool(client) # Make sure the socket is returned after exhaustion. @@ -2077,7 +2081,7 @@ class TestCollection(IntegrationTest): def test_find_one_and_write_concern(self): listener = EventListener() - db = (single_client(event_listeners=[listener]))[self.db.name] + db = (self.single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. diff --git a/test/test_comment.py b/test/test_comment.py index 931446ef3..c0f037ea4 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -22,7 +22,7 @@ import sys sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from bson.dbref import DBRef from pymongo.operations import IndexModel @@ -109,7 +109,7 @@ class TestComment(IntegrationTest): @client_context.require_replica_set def test_database_helpers(self): listener = EventListener() - db = rs_or_single_client(event_listeners=[listener]).db + db = self.rs_or_single_client(event_listeners=[listener]).db helpers = [ (db.watch, []), (db.command, ["hello"]), @@ -126,7 +126,7 @@ class TestComment(IntegrationTest): @client_context.require_replica_set def test_client_helpers(self): listener = EventListener() - cli = rs_or_single_client(event_listeners=[listener]) + cli = self.rs_or_single_client(event_listeners=[listener]) helpers = [ (cli.watch, []), (cli.list_databases, []), @@ -141,7 +141,7 @@ class TestComment(IntegrationTest): @client_context.require_version_min(4, 7, -1) def test_collection_helpers(self): listener = EventListener() - db = rs_or_single_client(event_listeners=[listener])[self.db.name] + db = self.rs_or_single_client(event_listeners=[listener])[self.db.name] coll = db.get_collection("test") helpers = [ diff --git a/test/test_common.py b/test/test_common.py index 358cd29b8..3228dc97f 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -21,7 +21,6 @@ import uuid sys.path[0:0] = [""] from test import IntegrationTest, client_context, connected, unittest -from test.utils import rs_or_single_client, single_client from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation from bson.codec_options import CodecOptions @@ -111,10 +110,10 @@ class TestCommon(IntegrationTest): ) def test_write_concern(self): - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertEqual(WriteConcern(), c.write_concern) - c = rs_or_single_client(connect=False, w=2, wTimeoutMS=1000) + c = self.rs_or_single_client(connect=False, w=2, wTimeoutMS=1000) wc = WriteConcern(w=2, wtimeout=1000) self.assertEqual(wc, c.write_concern) @@ -134,7 +133,7 @@ class TestCommon(IntegrationTest): def test_mongo_client(self): pair = client_context.pair - m = rs_or_single_client(w=0) + m = self.rs_or_single_client(w=0) coll = m.pymongo_test.write_concern_test coll.drop() doc = {"_id": ObjectId()} @@ -143,17 +142,19 @@ class TestCommon(IntegrationTest): coll = coll.with_options(write_concern=WriteConcern(w=1)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client() + m = self.rs_or_single_client() coll = m.pymongo_test.write_concern_test new_coll = coll.with_options(write_concern=WriteConcern(w=0)) self.assertTrue(new_coll.insert_one(doc)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client(f"mongodb://{pair}/", replicaSet=client_context.replica_set_name) + m = self.rs_or_single_client( + f"mongodb://{pair}/", replicaSet=client_context.replica_set_name + ) coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client( + m = self.rs_or_single_client( f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name ) @@ -161,8 +162,8 @@ class TestCommon(IntegrationTest): coll.insert_one(doc) # Equality tests - direct = connected(single_client(w=0)) - direct2 = connected(single_client(f"mongodb://{pair}/?w=0", **self.credentials)) + direct = connected(self.single_client(w=0)) + direct2 = connected(self.single_client(f"mongodb://{pair}/?w=0", **self.credentials)) self.assertEqual(direct, direct2) self.assertFalse(direct != direct2) diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 9ee3202e1..142af0f9a 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -30,9 +30,6 @@ from test.utils import ( client_context, get_pool, get_pools, - rs_or_single_client, - single_client, - single_client_noauth, wait_until, ) from test.utils_spec_runner import SpecRunnerThread @@ -250,7 +247,7 @@ class TestCMAP(IntegrationTest): else: kill_cursor_frequency = interval / 1000.0 with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): - client = single_client(**opts) + client = self.single_client(**opts) # Update the SD to a known type because the DummyMonitor will not. # Note we cannot simply call topology.on_change because that would # internally call pool.ready() which introduces unexpected @@ -323,13 +320,13 @@ class TestCMAP(IntegrationTest): # Prose tests. Numbers correspond to the prose test number in the spec. # def test_1_client_connection_pool_options(self): - client = rs_or_single_client(**self.POOL_OPTIONS) + client = self.rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): - client = rs_or_single_client(**self.POOL_OPTIONS) + client = self.rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) client.admin.command("ping") # Discover at least one secondary. @@ -345,14 +342,14 @@ class TestCMAP(IntegrationTest): def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{client_context.pair}/?{opts}" - client = rs_or_single_client(uri) + client = self.rs_or_single_client(uri) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) @@ -376,7 +373,7 @@ class TestCMAP(IntegrationTest): def test_5_check_out_fails_connection_error(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) self.addCleanup(client.close) pool = get_pool(client) @@ -403,7 +400,7 @@ class TestCMAP(IntegrationTest): @client_context.require_no_fips def test_5_check_out_fails_auth_error(self): listener = CMAPListener() - client = single_client_noauth( + client = self.single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) self.addCleanup(client.close) @@ -449,7 +446,7 @@ class TestCMAP(IntegrationTest): def test_close_leaves_pool_unpaused(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) client.admin.command("ping") pool = get_pool(client) client.close() diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 674612693..fba767574 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -24,7 +24,6 @@ from test.utils import ( CMAPListener, ensure_all_connected, repl_set_step_down, - rs_or_single_client, ) from bson import SON @@ -43,7 +42,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = CMAPListener() - cls.client = rs_or_single_client( + cls.client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 ) diff --git a/test/test_cursor.py b/test/test_cursor.py index 8e6fade1e..9bc22aca3 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -35,7 +35,6 @@ from test.utils import ( EventListener, OvertCommandListener, ignore_deprecations, - rs_or_single_client, wait_until, ) @@ -230,7 +229,7 @@ class TestCursor(IntegrationTest): self.assertEqual(90, cursor._max_await_time_ms) listener = AllowListEventListener("find", "getMore") - coll = (rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test + coll = (self.rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test # Tailable_defaults. coll.find(cursor_type=CursorType.TAILABLE_AWAIT).to_list() @@ -345,8 +344,7 @@ class TestCursor(IntegrationTest): def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(coll.find().explain()) started = listener.started_events @@ -1252,8 +1250,7 @@ class TestCursor(IntegrationTest): self.client._process_periodic_tasks() listener = AllowListEventListener("killCursors") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1291,8 +1288,7 @@ class TestCursor(IntegrationTest): @client_context.require_failCommand_appName def test_timeout_kills_cursor_synchronously(self): listener = AllowListEventListener("killCursors") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1349,8 +1345,7 @@ class TestCursor(IntegrationTest): def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1454,7 +1449,7 @@ class TestRawBatchCursor(IntegrationTest): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): batches = ( @@ -1484,7 +1479,7 @@ class TestRawBatchCursor(IntegrationTest): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): @@ -1505,7 +1500,7 @@ class TestRawBatchCursor(IntegrationTest): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) @@ -1566,7 +1561,7 @@ class TestRawBatchCursor(IntegrationTest): def test_monitoring(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{"_id": i} for i in range(10)]) @@ -1632,7 +1627,7 @@ class TestRawBatchCommandCursor(IntegrationTest): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): batches = ( @@ -1663,7 +1658,7 @@ class TestRawBatchCommandCursor(IntegrationTest): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}} ): @@ -1687,7 +1682,7 @@ class TestRawBatchCommandCursor(IntegrationTest): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) @@ -1733,7 +1728,7 @@ class TestRawBatchCommandCursor(IntegrationTest): def test_monitoring(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{"_id": i} for i in range(10)]) @@ -1777,8 +1772,7 @@ class TestRawBatchCommandCursor(IntegrationTest): @client_context.require_no_mongos def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.delete_many({}) c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index c30c62b1b..abaa820cb 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -27,7 +27,6 @@ sys.path[0:0] = [""] from test import client_context, unittest from test.test_client import IntegrationTest -from test.utils import rs_client from bson import ( _BUILT_IN_TYPES, @@ -971,7 +970,7 @@ class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustom if codec_options: kwargs["type_registry"] = codec_options.type_registry kwargs["document_class"] = codec_options.document_class - self.watched_target = rs_client(*args, **kwargs) + self.watched_target = self.rs_client(*args, **kwargs) self.addCleanup(self.watched_target.close) self.input_target = self.watched_target[self.db.name].test # Insert a record to ensure db, coll are created. diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 8ba83ab19..a374db550 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -27,8 +27,6 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes from test.utils import ( OvertCommandListener, - rs_client_noauth, - rs_or_single_client, ) pytestmark = pytest.mark.data_lake @@ -65,7 +63,7 @@ class TestDataLakeProse(IntegrationTest): # Test killCursors def test_1(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2) next(cursor) @@ -90,13 +88,13 @@ class TestDataLakeProse(IntegrationTest): # Test no auth def test_2(self): - client = rs_client_noauth() + client = self.rs_client_noauth() client.admin.command("ping") # Test with auth def test_3(self): for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]: - client = rs_or_single_client(authMechanism=mechanism) + client = self.rs_or_single_client(authMechanism=mechanism) client[self.TEST_DB][self.TEST_COLLECTION].find_one() diff --git a/test/test_database.py b/test/test_database.py index 12d4eb666..fe07f343c 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -28,7 +28,6 @@ from test.test_custom_types import DECIMAL_CODECOPTS from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, - rs_or_single_client, wait_until, ) @@ -207,7 +206,7 @@ class TestDatabase(IntegrationTest): def test_list_collection_names_filter(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] db.capped.drop() db.create_collection("capped", capped=True, size=4096) @@ -234,8 +233,7 @@ class TestDatabase(IntegrationTest): def test_check_exists(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] db.drop_collection("unique") db.create_collection("unique", check_exists=True) @@ -323,7 +321,7 @@ class TestDatabase(IntegrationTest): self.client.drop_database("pymongo_test") def test_list_collection_names_single_socket(self): - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) client.drop_database("test_collection_names_single_socket") db = client.test_collection_names_single_socket for i in range(200): diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ef32afbcd..3554619f1 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -22,7 +22,7 @@ import threading sys.path[0:0] = [""] -from test import IntegrationTest, unittest +from test import IntegrationTest, PyMongoTestCase, unittest from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes from test.utils import ( @@ -32,9 +32,7 @@ from test.utils import ( assertion_context, client_context, get_pool, - rs_or_single_client, server_name_to_type, - single_client, wait_until, ) from unittest.mock import patch @@ -272,7 +270,7 @@ class TestIgnoreStaleErrors(IntegrationTest): def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) - client = rs_or_single_client(minPoolSize=N_THREADS) + client = self.rs_or_single_client(minPoolSize=N_THREADS) self.addCleanup(client.close) # Wait for initial discovery. @@ -319,7 +317,7 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = single_client( + client = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) self.addCleanup(client.close) @@ -353,7 +351,7 @@ class TestServerMonitoringMode(IntegrationTest): super().setUp() def test_rtt_connection_is_enabled_stream(self): - client = rs_or_single_client(serverMonitoringMode="stream") + client = self.rs_or_single_client(serverMonitoringMode="stream") self.addCleanup(client.close) client.admin.command("ping") @@ -373,7 +371,7 @@ class TestServerMonitoringMode(IntegrationTest): wait_until(predicate, "find all RTT monitors") def test_rtt_connection_is_disabled_poll(self): - client = rs_or_single_client(serverMonitoringMode="poll") + client = self.rs_or_single_client(serverMonitoringMode="poll") self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) @@ -387,7 +385,7 @@ class TestServerMonitoringMode(IntegrationTest): ] for env in envs: with patch.dict("os.environ", env): - client = rs_or_single_client(serverMonitoringMode="auto") + client = self.rs_or_single_client(serverMonitoringMode="auto") self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) @@ -415,7 +413,7 @@ class TCPServer(socketserver.TCPServer): self.server_close() -class TestHeartbeatStartOrdering(unittest.TestCase): +class TestHeartbeatStartOrdering(PyMongoTestCase): def test_heartbeat_start_ordering(self): events = [] listener = HeartbeatEventsListListener(events) @@ -423,7 +421,7 @@ class TestHeartbeatStartOrdering(unittest.TestCase): server.events = events server_thread = threading.Thread(target=server.handle_request_and_shutdown) server_thread.start() - _c = MongoClient( + _c = self.simple_client( "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) ) server_thread.join() diff --git a/test/test_dns.py b/test/test_dns.py index b4c5e3684..f2185efb1 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -22,16 +22,15 @@ import sys sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import IntegrationTest, PyMongoTestCase, client_context, unittest from test.utils import wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.synchronous.mongo_client import MongoClient from pymongo.uri_parser import parse_uri, split_hosts -class TestDNSRepl(unittest.TestCase): +class TestDNSRepl(PyMongoTestCase): TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set" ) @@ -42,7 +41,7 @@ class TestDNSRepl(unittest.TestCase): pass -class TestDNSLoadBalanced(unittest.TestCase): +class TestDNSLoadBalanced(PyMongoTestCase): TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced" ) @@ -53,7 +52,7 @@ class TestDNSLoadBalanced(unittest.TestCase): pass -class TestDNSSharded(unittest.TestCase): +class TestDNSSharded(PyMongoTestCase): TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded") load_balanced = False @@ -120,7 +119,7 @@ def create_test(test_case): # tests. copts["tlsAllowInvalidHostnames"] = True - client = MongoClient(uri, **copts) + client = PyMongoTestCase.unmanaged_simple_client(uri, **copts) if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: @@ -133,6 +132,7 @@ def create_test(test_case): client.admin.command("ping") # XXX: we should block until SRV poller runs at least once # and re-run these assertions. + client.close() else: try: parse_uri(uri) @@ -157,37 +157,37 @@ create_tests(TestDNSLoadBalanced) create_tests(TestDNSSharded) -class TestParsingErrors(unittest.TestCase): +class TestParsingErrors(PyMongoTestCase): def test_invalid_host(self): self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb is not", - MongoClient, + self.simple_client, "mongodb+srv://mongodb", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb.com is not", - MongoClient, + self.simple_client, "mongodb+srv://mongodb.com", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, + self.simple_client, "mongodb+srv://127.0.0.1", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, + self.simple_client, "mongodb+srv://[::1]", ) class TestCaseInsensitive(IntegrationTest): def test_connect_case_insensitive(self): - client = MongoClient("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") self.addCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) diff --git a/test/test_encryption.py b/test/test_encryption.py index 5e02e4d62..96d40c4a3 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -31,7 +31,7 @@ import warnings from test import IntegrationTest, PyMongoTestCase, client_context from test.test_bulk import BulkTestBase from threading import Thread -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional import pytest @@ -53,6 +53,7 @@ from test.helpers import ( KMIP_CREDS, LOCAL_MASTER_KEY, ) +from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, @@ -61,7 +62,6 @@ from test.utils import ( TopologyEventListener, camel_to_snake_args, is_greenthread_patched, - rs_or_single_client, wait_until, ) from test.utils_spec_runner import SpecRunner @@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(PyMongoTestCase): @unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed") def test_crypt_shared(self): # Test that we can pick up crypt_shared lib automatically - client = MongoClient( + self.simple_client( auto_encryption_opts=AutoEncryptionOpts( KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True ), connect=False, ) - self.addCleanup(client.close) @unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed") def test_init_requires_pymongocrypt(self): @@ -196,19 +195,16 @@ class TestAutoEncryptionOpts(PyMongoTestCase): class TestClientOptions(PyMongoTestCase): def test_default(self): - client = MongoClient(connect=False) - self.addCleanup(client.close) + client = self.simple_client(connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) - client = MongoClient(auto_encryption_opts=None, connect=False) - self.addCleanup(client.close) + client = self.simple_client(auto_encryption_opts=None, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_kwargs(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = MongoClient(auto_encryption_opts=opts, connect=False) - self.addCleanup(client.close) + client = self.simple_client(auto_encryption_opts=opts, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) @@ -229,6 +225,34 @@ class EncryptionIntegrationTest(IntegrationTest): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, UUID_SUBTYPE) + def create_client_encryption( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = ClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + self.addCleanup(client_encryption.close) + return client_encryption + + @classmethod + def unmanaged_create_client_encryption( + cls, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = ClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + return client_encryption + # Location of JSON test files. if _IS_SYNC: @@ -260,8 +284,7 @@ def bson_data(*paths): class TestClientSimple(EncryptionIntegrationTest): def _test_auto_encrypt(self, opts): - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) # Create the encrypted field's data key. key_vault = create_key_vault( @@ -342,8 +365,7 @@ class TestClientSimple(EncryptionIntegrationTest): def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) client.admin.command("ping") client.close() @@ -360,8 +382,7 @@ class TestClientSimple(EncryptionIntegrationTest): ) def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) def target(): with warnings.catch_warnings(): @@ -375,8 +396,7 @@ class TestClientSimple(EncryptionIntegrationTest): class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest): def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) encrypted_coll = client.pymongo_test.test @@ -416,8 +436,7 @@ class TestClientMaxWireVersion(IntegrationTest): @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.insert_one({}) @@ -430,8 +449,7 @@ class TestClientMaxWireVersion(IntegrationTest): def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): client.test.test.find_raw_batches({}) @@ -450,10 +468,9 @@ class TestClientMaxWireVersion(IntegrationTest): class TestExplicitSimple(EncryptionIntegrationTest): def test_encrypt_decrypt(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) # Use standard UUID representation. key_vault = client_context.client.keyvault.get_collection("datakeys", codec_options=OPTS) self.addCleanup(key_vault.drop) @@ -493,10 +510,9 @@ class TestExplicitSimple(EncryptionIntegrationTest): self.assertEqual(decrypted_ssn, doc["ssn"]) def test_validation(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) msg = "value to decrypt must be a bson.binary.Binary with subtype 6" with self.assertRaisesRegex(TypeError, msg): @@ -510,10 +526,9 @@ class TestExplicitSimple(EncryptionIntegrationTest): client_encryption.encrypt("str", algo, key_id=Binary(b"123")) def test_bson_errors(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. unencodable_value = object() @@ -526,7 +541,7 @@ class TestExplicitSimple(EncryptionIntegrationTest): def test_codec_options(self): with self.assertRaisesRegex(TypeError, "codec_options must be"): - ClientEncryption( + self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, @@ -534,10 +549,9 @@ class TestExplicitSimple(EncryptionIntegrationTest): ) opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY) - client_encryption_legacy = ClientEncryption( + client_encryption_legacy = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts ) - self.addCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. key_id = client_encryption_legacy.create_data_key("local") @@ -552,10 +566,9 @@ class TestExplicitSimple(EncryptionIntegrationTest): # Encrypt the same UUID with STANDARD codec options. opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts ) - self.addCleanup(client_encryption.close) encrypted_standard = client_encryption.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id ) @@ -571,7 +584,7 @@ class TestExplicitSimple(EncryptionIntegrationTest): self.assertNotEqual(client_encryption.decrypt(encrypted_legacy), value) def test_close(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) client_encryption.close() @@ -587,7 +600,7 @@ class TestExplicitSimple(EncryptionIntegrationTest): client_encryption.decrypt(Binary(b"", 6)) def test_with_statement(self): - with ClientEncryption( + with self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) as client_encryption: pass @@ -807,7 +820,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.client.db.coll.drop() cls.vault = create_key_vault(cls.client.keyvault.datakeys) @@ -829,10 +842,10 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = rs_or_single_client( + cls.client_encrypted = cls.unmanaged_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = ClientEncryption( + cls.client_encryption = cls.unmanaged_create_client_encryption( cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) @@ -919,8 +932,7 @@ class TestExternalKeyVault(EncryptionIntegrationTest): # Configure the encrypted field via the local schema_map option. schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: - key_vault_client = rs_or_single_client(username="fake-user", password="fake-pwd") - self.addCleanup(key_vault_client.close) + key_vault_client = self.rs_or_single_client(username="fake-user", password="fake-pwd") else: key_vault_client = client_context.client opts = AutoEncryptionOpts( @@ -930,15 +942,13 @@ class TestExternalKeyVault(EncryptionIntegrationTest): key_vault_client=key_vault_client, ) - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addCleanup(client_encrypted.close) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS ) - self.addCleanup(client_encryption.close) if with_external_key_vault: # Authentication error. @@ -984,10 +994,9 @@ class TestViews(EncryptionIntegrationTest): self.addCleanup(self.client.db.view.drop) opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys") - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addCleanup(client_encrypted.close) with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"): client_encrypted.db.view.insert_one({}) @@ -1044,17 +1053,15 @@ class TestCorpus(EncryptionIntegrationTest): ) self.addCleanup(vault.drop) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", client_context.client, OPTS, kms_tls_options=KMS_TLS_OPTS, ) - self.addCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data("corpus", "corpus.json")) corpus_copied: SON = SON() @@ -1197,7 +1204,7 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = rs_or_single_client( + cls.client_encrypted = cls.unmanaged_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll @@ -1285,7 +1292,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest): "gcp": GCP_CREDS, "kmip": KMIP_CREDS, } - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers=kms_providers, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -1297,7 +1304,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest): kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698" - self.client_encryption_invalid = ClientEncryption( + self.client_encryption_invalid = self.create_client_encryption( kms_providers=kms_providers_invalid, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -1476,7 +1483,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest): self.client_encryption.create_data_key("kmip", key) -class AzureGCPEncryptionTestMixin: +class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): DEK = None KMS_PROVIDER_MAP = None KEYVAULT_DB = "keyvault" @@ -1488,7 +1495,7 @@ class AzureGCPEncryptionTestMixin: create_key_vault(keyvault, self.DEK) def _test_explicit(self, expectation): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, @@ -1517,7 +1524,7 @@ class AzureGCPEncryptionTestMixin: ) insert_listener = AllowListEventListener("insert") - client = rs_or_single_client( + client = self.rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) self.addCleanup(client.close) @@ -1596,19 +1603,17 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests class TestDeadlockProse(EncryptionIntegrationTest): def setUp(self): - self.client_test = rs_or_single_client( + self.client_test = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) - self.addCleanup(self.client_test.close) self.client_keyvault_listener = OvertCommandListener() - self.client_keyvault = rs_or_single_client( + self.client_keyvault = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", event_listeners=[self.client_keyvault_listener], ) - self.addCleanup(self.client_keyvault.close) self.client_test.keyvault.datakeys.drop() self.client_test.db.coll.drop() @@ -1619,7 +1624,7 @@ class TestDeadlockProse(EncryptionIntegrationTest): codec_options=OPTS, ) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, key_vault_namespace="keyvault.datakeys", key_vault_client=self.client_test, @@ -1635,7 +1640,7 @@ class TestDeadlockProse(EncryptionIntegrationTest): self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") def _run_test(self, max_pool_size, auto_encryption_opts): - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( readConcernLevel="majority", w="majority", maxPoolSize=max_pool_size, @@ -1653,8 +1658,6 @@ class TestDeadlockProse(EncryptionIntegrationTest): result = client_encrypted.db.coll.find_one({"_id": 0}) self.assertEqual(result, {"_id": 0, "encrypted": "string0"}) - self.addCleanup(client_encrypted.close) - def test_case_1(self): self._run_test( max_pool_size=1, @@ -1830,7 +1833,7 @@ class TestDecryptProse(EncryptionIntegrationTest): create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) keyID = self.client_encryption.create_data_key("local") @@ -1845,10 +1848,9 @@ class TestDecryptProse(EncryptionIntegrationTest): key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map ) self.listener = AllowListEventListener("aggregate") - self.encrypted_client = rs_or_single_client( + self.encrypted_client = self.rs_or_single_client( auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] ) - self.addCleanup(self.encrypted_client.close) def test_01_command_error(self): with self.fail_point( @@ -1925,8 +1927,7 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest): "--port=27027", ], ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) with self.assertRaisesRegex(EncryptionError, "Timeout"): client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1940,11 +1941,12 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest): "--port=27027", ], ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: - mongocryptd_client = MongoClient("mongodb://localhost:27027/?serverSelectionTimeoutMS=500") + mongocryptd_client = self.simple_client( + "mongodb://localhost:27027/?serverSelectionTimeoutMS=500" + ) with self.assertRaises(ServerSelectionTimeoutError): mongocryptd_client.admin.command("ping") @@ -1966,15 +1968,13 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest): ], crypt_shared_lib_required=True, ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) self.assertEncrypted((client_context.client.db.coll.find_one({}))["encrypted"]) - no_mongocryptd_client = MongoClient( + no_mongocryptd_client = self.simple_client( host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000" ) - self.addCleanup(no_mongocryptd_client.close) with self.assertRaises(ServerSelectionTimeoutError): no_mongocryptd_client.db.command("ping") @@ -2008,8 +2008,7 @@ class TestBypassSpawningMongocryptdProse(EncryptionIntegrationTest): mongocryptd_uri="mongodb://localhost:47021", crypt_shared_lib_required=False, ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) server.shutdown() @@ -2023,10 +2022,9 @@ class TestKmsTLSProse(EncryptionIntegrationTest): def setUp(self): super().setUp() self.patch_system_certs(CA_PEM) - self.client_encrypted = ClientEncryption( + self.client_encrypted = self.create_client_encryption( {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS ) - self.addCleanup(self.client_encrypted.close) def test_invalid_kms_certificate_expired(self): key = { @@ -2071,36 +2069,32 @@ class TestKmsTLSOptions(EncryptionIntegrationTest): "gcp": {"tlsCAFile": CA_PEM}, "kmip": {"tlsCAFile": CA_PEM}, } - self.client_encryption_no_client_cert = ClientEncryption( + self.client_encryption_no_client_cert = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_no_client_cert.close) # 2, same providers as above but with tlsCertificateKeyFile. kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only) for p in kms_tls_opts: kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM - self.client_encryption_with_tls = ClientEncryption( + self.client_encryption_with_tls = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts ) - self.addCleanup(self.client_encryption_with_tls.close) # 3, update endpoints to expired host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000" providers["gcp"]["endpoint"] = "127.0.0.1:9000" providers["kmip"]["endpoint"] = "127.0.0.1:9000" - self.client_encryption_expired = ClientEncryption( + self.client_encryption_expired = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_expired.close) # 3, update endpoints to invalid host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001" providers["gcp"]["endpoint"] = "127.0.0.1:9001" providers["kmip"]["endpoint"] = "127.0.0.1:9001" - self.client_encryption_invalid_hostname = ClientEncryption( + self.client_encryption_invalid_hostname = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_invalid_hostname.close) # Errors when client has no cert, some examples: # [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623) self.cert_error = ( @@ -2138,7 +2132,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest): "gcp:with_tls": with_cert, "kmip:with_tls": with_cert, } - self.client_encryption_with_names = ClientEncryption( + self.client_encryption_with_names = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4 ) @@ -2220,10 +2214,9 @@ class TestKmsTLSOptions(EncryptionIntegrationTest): def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}} options = {"aws": {"tlsDisableOCSPEndpointCheck": True}} - encryption = ClientEncryption( + encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - self.addCleanup(encryption.close) ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") @@ -2273,7 +2266,7 @@ class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest): self.client = client_context.client create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) self.def_key_id = self.client_encryption.create_data_key("local", key_alt_names=["def"]) @@ -2311,17 +2304,15 @@ class TestExplicitQueryableEncryption(EncryptionIntegrationTest): key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(self.encrypted_client.close) + self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) def test_01_insert_encrypted_indexed_and_find(self): val = "encrypted indexed value" @@ -2444,14 +2435,13 @@ class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest): self.client.keyvault.drop_collection("datakeys") # Step 2. Create a ``ClientEncryption`` object named ``client_encryption1`` - client_encryption1 = ClientEncryption( + client_encryption1 = self.create_client_encryption( key_vault_client=self.client, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addCleanup(client_encryption1.close) # Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``. key_id = client_encryption1.create_data_key( @@ -2464,16 +2454,14 @@ class TestRewrapWithSeparateClientEncryption(EncryptionIntegrationTest): ) # Step 5. Create a ``ClientEncryption`` object named ``client_encryption2`` - client2 = rs_or_single_client() - self.addCleanup(client2.close) - client_encryption2 = ClientEncryption( + client2 = self.rs_or_single_client() + client_encryption2 = self.create_client_encryption( key_vault_client=client2, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addCleanup(client_encryption2.close) # Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``. rewrap_many_data_key_result = client_encryption2.rewrap_many_data_key( @@ -2508,7 +2496,7 @@ class TestOnDemandAWSCredentials(EncryptionIntegrationTest): @unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set") def test_01_failure(self): - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -2519,7 +2507,7 @@ class TestOnDemandAWSCredentials(EncryptionIntegrationTest): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_02_success(self): - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -2539,8 +2527,7 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest): # MongoClient to use in testing that handles auth/tls/etc, # and cleanup. def MongoClient(**kwargs): - c = rs_or_single_client(**kwargs) - self.addCleanup(c.close) + c = self.rs_or_single_client(**kwargs) return c # Drop data from prior test runs. @@ -2551,7 +2538,7 @@ class TestQueryableEncryptionDocsExample(EncryptionIntegrationTest): # Create two data keys. key_vault_client = MongoClient() - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions() ) key1_id = client_encryption.create_data_key("local") @@ -2632,18 +2619,16 @@ class TestRangeQueryProse(EncryptionIntegrationTest): key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db - self.addCleanup(self.encrypted_client.close) def run_expression_find( self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None @@ -2838,10 +2823,9 @@ class TestRangeQueryDefaultsProse(EncryptionIntegrationTest): super().setUp() self.client.drop_database(self.db) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) self.key_id = self.client_encryption.create_data_key("local") opts = RangeOpts(min=0, max=1000) self.payload_defaults = self.client_encryption.encrypt( @@ -2874,13 +2858,12 @@ class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): self.client.drop_database(self.db) self.key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(self.key_vault.drop) - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, self.key_vault.full_name, self.client, OPTS, ) - self.addCleanup(self.client_encryption.close) def test_01_simple_create(self): coll, _ = self.client_encryption.create_encrypted_collection( @@ -3096,10 +3079,9 @@ class TestNoSessionsSupport(EncryptionIntegrationTest): def setUp(self) -> None: self.listener = OvertCommandListener() - self.mongocryptd_client = MongoClient( + self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] ) - self.addCleanup(self.mongocryptd_client.close) hello = self.mongocryptd_client.db.command("hello") self.assertNotIn("logicalSessionTimeoutMinutes", hello) diff --git a/test/test_examples.py b/test/test_examples.py index 296283db2..ebf1d784a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -22,7 +22,7 @@ import threading sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import rs_client, wait_until +from test.utils import wait_until import pymongo from pymongo.errors import ConnectionFailure, OperationFailure @@ -1128,7 +1128,7 @@ class TestTransactionExamples(IntegrationTest): self.assertEqual(employee["status"], "Inactive") def MongoClient(_): - return rs_client() + return self.rs_client() uriString = None @@ -1220,7 +1220,7 @@ class TestVersionedApiExamples(IntegrationTest): def test_versioned_api(self): # Versioned API examples def MongoClient(_, server_api): - return rs_client(server_api=server_api, connect=False) + return self.rs_client(server_api=server_api, connect=False) uri = None @@ -1251,7 +1251,7 @@ class TestVersionedApiExamples(IntegrationTest): ): self.skipTest("This test needs MongoDB 5.0.2 or newer") - client = rs_client(server_api=ServerApi("1", strict=True)) + client = self.rs_client(server_api=ServerApi("1", strict=True)) client.db.sales.drop() # Start Versioned API Example 5 diff --git a/test/test_grid_file.py b/test/test_grid_file.py index bd89235b7..fe88aec5f 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -33,7 +33,7 @@ from pymongo.synchronous.database import Database sys.path[0:0] = [""] -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from bson.objectid import ObjectId from gridfs.errors import NoFile @@ -790,7 +790,7 @@ Bye""" outfile.readchunk() def test_grid_in_lazy_connect(self): - client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = GridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): @@ -801,7 +801,7 @@ Bye""" def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - GridIn((rs_or_single_client(w=0)).pymongo_test.fs) + GridIn((self.rs_or_single_client(w=0)).pymongo_test.fs) def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. @@ -809,7 +809,7 @@ Bye""" chunk_size = 1024 data = b"d" * (102 * chunk_size) listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test with GridIn(db.fs, chunk_size=chunk_size) as infile: infile.write(data) diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 19ec152bd..549dc0b20 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -26,7 +26,7 @@ from unittest.mock import patch sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one, rs_client, rs_or_single_client, single_client +from test.utils import joinall, one import gridfs from bson.binary import Binary @@ -411,7 +411,7 @@ class TestGridfs(IntegrationTest): self.assertTrue(iterate_file(f)) def test_gridfs_lazy_connect(self): - client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10) db = client.db gfs = gridfs.GridFS(db) self.assertRaises(ServerSelectionTimeoutError, gfs.list) @@ -492,7 +492,7 @@ class TestGridfs(IntegrationTest): def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - gridfs.GridFS(rs_or_single_client(w=0).pymongo_test) + gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test) def test_md5(self): gin = self.fs.new_file() @@ -519,7 +519,7 @@ class TestGridfsReplicaSet(IntegrationTest): client_context.client.drop_database("gfsreplica") def test_gridfs_replica_set(self): - rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) + rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) fs = gridfs.GridFS(rsc.gfsreplica, "gfsreplicatest") @@ -532,7 +532,7 @@ class TestGridfsReplicaSet(IntegrationTest): def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) - secondary_connection = single_client( + secondary_connection = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY ) @@ -547,7 +547,7 @@ class TestGridfsReplicaSet(IntegrationTest): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.client.secondaries) - client = single_client( + client = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False ) diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index c3945d105..28adb7051 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -27,7 +27,7 @@ from unittest.mock import patch sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one, rs_client, rs_or_single_client, single_client +from test.utils import joinall, one import gridfs from bson.binary import Binary @@ -345,7 +345,7 @@ class TestGridfs(IntegrationTest): self.assertTrue(iterate_file(fstr)) def test_gridfs_lazy_connect(self): - client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=0) + client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0) cdb = client.db gfs = gridfs.GridFSBucket(cdb) self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0) @@ -391,7 +391,7 @@ class TestGridfs(IntegrationTest): def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - gridfs.GridFSBucket(rs_or_single_client(w=0).pymongo_test) + gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test) def test_rename(self): _id = self.fs.upload_from_stream("first_name", b"testing") @@ -489,7 +489,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest): client_context.client.drop_database("gfsbucketreplica") def test_gridfs_replica_set(self): - rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) + rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") oid = gfs.upload_from_stream("test_filename", b"foo") @@ -498,7 +498,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest): def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) - secondary_connection = single_client( + secondary_connection = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY ) @@ -513,7 +513,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.client.secondaries) - client = single_client( + client = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False ) diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 1302df8fd..5e203a33b 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -20,7 +20,7 @@ import sys sys.path[0:0] = [""] from test import IntegrationTest, client_knobs, unittest -from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until +from test.utils import HeartbeatEventListener, MockPool, wait_until from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat @@ -40,7 +40,7 @@ class TestHeartbeatMonitoring(IntegrationTest): raise responses[1] return Hello(responses[1]), 99 - m = single_client( + m = self.single_client( h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool ) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index a4db7395f..23bea4d98 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ExceptionCatchingThread, get_pool, rs_client, wait_until +from test.utils import ExceptionCatchingThread, get_pool, wait_until pytestmark = pytest.mark.load_balancer @@ -54,7 +54,7 @@ class TestLB(IntegrationTest): @client_context.require_load_balancer def test_unpin_committed_transaction(self): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test @@ -85,7 +85,7 @@ class TestLB(IntegrationTest): self._test_no_gc_deadlock(create_resource) def _test_no_gc_deadlock(self, create_resource): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test @@ -124,7 +124,7 @@ class TestLB(IntegrationTest): @client_context.require_transactions def test_session_gc(self): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) session = client.start_session() diff --git a/test/test_logger.py b/test/test_logger.py index c0011ec3a..b3c8e6d17 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -15,7 +15,6 @@ from __future__ import annotations import os from test import IntegrationTest, unittest -from test.utils import single_client from unittest.mock import patch from bson import json_util @@ -85,7 +84,7 @@ class TestLogger(IntegrationTest): self.assertEqual(last_3_bytes, str_to_repeat) def test_logging_without_listeners(self): - c = single_client() + c = self.single_client() self.assertEqual(len(c._event_listeners.event_listeners()), 0) with self.assertLogs("pymongo.connection", level="DEBUG") as cm: c.db.test.insert_one({"x": "1"}) diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 1b0130f7d..32d09ada9 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -20,15 +20,14 @@ import sys import time import warnings +from pymongo import MongoClient from pymongo.operations import _Op sys.path[0:0] = [""] -from test import client_context, unittest -from test.utils import rs_or_single_client +from test import PyMongoTestCase, client_context, unittest from test.utils_selection_tests import create_selection_tests -from pymongo import MongoClient from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector @@ -40,54 +39,58 @@ class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore pass -class TestMaxStaleness(unittest.TestCase): +class TestMaxStaleness(PyMongoTestCase): def test_max_staleness(self): - client = MongoClient() + client = self.simple_client() self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary") + client = self.simple_client("mongodb://a/?readPreference=secondary") self.assertEqual(-1, client.read_preference.max_staleness) # These tests are specified in max-staleness-tests.rst. with self.assertRaises(ConfigurationError): # Default read pref "primary" can't be used with max staleness. - MongoClient("mongodb://a/?maxStalenessSeconds=120") + self.simple_client("mongodb://a/?maxStalenessSeconds=120") with self.assertRaises(ConfigurationError): # Read pref "primary" can't be used with max staleness. - MongoClient("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") + self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") - client = MongoClient("mongodb://host/?maxStalenessSeconds=-1") + client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") + client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://host/?readPreference=secondary&maxStalenessSeconds=120") + client = self.simple_client( + "mongodb://host/?readPreference=secondary&maxStalenessSeconds=120" + ) self.assertEqual(120, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") self.assertEqual(1, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient(maxStalenessSeconds=-1, readPreference="nearest") + client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest") self.assertEqual(-1, client.read_preference.max_staleness) with self.assertRaises(TypeError): # Prohibit None. - MongoClient(maxStalenessSeconds=None, readPreference="nearest") + self.simple_client(maxStalenessSeconds=None, readPreference="nearest") def test_max_staleness_float(self): with self.assertRaises(TypeError) as ctx: - rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") + self.rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") self.assertIn("must be an integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = MongoClient("mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest") + client = self.simple_client( + "mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest" + ) # Option was ignored. self.assertEqual(-1, client.read_preference.max_staleness) @@ -96,13 +99,15 @@ class TestMaxStaleness(unittest.TestCase): def test_max_staleness_zero(self): # Zero is too small. with self.assertRaises(ValueError) as ctx: - rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") + self.rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") self.assertIn("must be a positive integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = MongoClient("mongodb://host/?maxStalenessSeconds=0&readPreference=nearest") + client = self.simple_client( + "mongodb://host/?maxStalenessSeconds=0&readPreference=nearest" + ) # Option was ignored. self.assertEqual(-1, client.read_preference.max_staleness) @@ -111,7 +116,7 @@ class TestMaxStaleness(unittest.TestCase): @client_context.require_replica_set def test_last_write_date(self): # From max-staleness-tests.rst, "Parse lastWriteDate". - client = rs_or_single_client(heartbeatFrequencyMS=500) + client = self.rs_or_single_client(heartbeatFrequencyMS=500) client.pymongo_test.test.insert_one({}) # Wait for the server description to be updated. time.sleep(1) diff --git a/test/test_monitor.py b/test/test_monitor.py index fd82fc1ca..f8e9443fa 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -18,6 +18,7 @@ from __future__ import annotations import gc import subprocess import sys +import warnings from functools import partial sys.path[0:0] = [""] @@ -25,7 +26,6 @@ sys.path[0:0] = [""] from test import IntegrationTest, connected, unittest from test.utils import ( ServerAndTopologyEventListener, - single_client, wait_until, ) @@ -47,30 +47,31 @@ def get_executors(client): return [e for e in executors if e is not None] -def create_client(): - listener = ServerAndTopologyEventListener() - client = single_client(event_listeners=[listener]) - connected(client) - return client - - class TestMonitor(IntegrationTest): + def create_client(self): + listener = ServerAndTopologyEventListener() + client = self.unmanaged_single_client(event_listeners=[listener]) + connected(client) + return client + def test_cleanup_executors_on_client_del(self): - client = create_client() - executors = get_executors(client) - self.assertEqual(len(executors), 4) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client = self.create_client() + executors = get_executors(client) + self.assertEqual(len(executors), 4) - # Each executor stores a weakref to itself in _EXECUTORS. - executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] + # Each executor stores a weakref to itself in _EXECUTORS. + executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] - del executors - del client + del executors + del client - for ref, name in executor_refs: - wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) + for ref, name in executor_refs: + wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) def test_cleanup_executors_on_client_close(self): - client = create_client() + client = self.create_client() executors = get_executors(client) self.assertEqual(len(executors), 4) diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 8322e2991..a0c520ed2 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -31,8 +31,6 @@ from test import ( ) from test.utils import ( EventListener, - rs_or_single_client, - single_client, wait_until, ) @@ -57,7 +55,9 @@ class TestCommandMonitoring(IntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = EventListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False) + cls.client = cls.unmanaged_rs_or_single_client( + event_listeners=[cls.listener], retryWrites=False + ) @classmethod def _tearDown_class(cls): @@ -405,7 +405,7 @@ class TestCommandMonitoring(IntegrationTest): @client_context.require_secondaries_count(1) def test_not_primary_error(self): address = next(iter(client_context.client.secondaries)) - client = single_client(*address, event_listeners=[self.listener]) + client = self.single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. client.admin.command("ping") self.listener.reset() @@ -1144,7 +1144,7 @@ class TestGlobalListener(IntegrationTest): # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = single_client() + cls.client = cls.unmanaged_single_client() # Get one (authenticated) socket in the pool. cls.client.pymongo_test.command("ping") diff --git a/test/test_pooling.py b/test/test_pooling.py index 31259d7b3..3b867965b 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -31,7 +31,7 @@ from pymongo.hello import HelloCompat sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import delay, get_pool, joinall, rs_or_single_client +from test.utils import delay, get_pool, joinall from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions @@ -151,7 +151,7 @@ class _TestPoolingBase(IntegrationTest): def setUp(self): super().setUp() - self.c = rs_or_single_client() + self.c = self.rs_or_single_client() db = self.c[DB] db.unique.drop() db.test.drop() @@ -378,7 +378,7 @@ class TestPooling(_TestPoolingBase): socket_info.close_conn(None) def test_maxConnecting(self): - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) self.client.test.test.insert_one({}) self.addCleanup(self.client.test.test.delete_many, {}) @@ -415,7 +415,7 @@ class TestPooling(_TestPoolingBase): @client_context.require_failCommand_appName def test_csot_timeout_message(self): - client = rs_or_single_client(appName="connectionTimeoutApp") + client = self.rs_or_single_client(appName="connectionTimeoutApp") self.addCleanup(client.close) # Mock an operation failing due to pymongo.timeout(). mock_connection_timeout = { @@ -440,7 +440,7 @@ class TestPooling(_TestPoolingBase): @client_context.require_failCommand_appName def test_socket_timeout_message(self): - client = rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") + client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") self.addCleanup(client.close) # Mock an operation failing due to socketTimeoutMS. mock_connection_timeout = { @@ -479,7 +479,7 @@ class TestPooling(_TestPoolingBase): }, } - client = rs_or_single_client( + client = self.rs_or_single_client( connectTimeoutMS=500, socketTimeoutMS=500, appName="connectionTimeoutApp", @@ -502,7 +502,7 @@ class TestPooling(_TestPoolingBase): class TestPoolMaxSize(_TestPoolingBase): def test_max_pool_size(self): max_pool_size = 4 - c = rs_or_single_client(maxPoolSize=max_pool_size) + c = self.rs_or_single_client(maxPoolSize=max_pool_size) self.addCleanup(c.close) collection = c[DB].test @@ -538,7 +538,7 @@ class TestPoolMaxSize(_TestPoolingBase): self.assertEqual(0, cx_pool.requests) def test_max_pool_size_none(self): - c = rs_or_single_client(maxPoolSize=None) + c = self.rs_or_single_client(maxPoolSize=None) self.addCleanup(c.close) collection = c[DB].test @@ -570,7 +570,7 @@ class TestPoolMaxSize(_TestPoolingBase): self.assertEqual(cx_pool.max_pool_size, float("inf")) def test_max_pool_size_zero(self): - c = rs_or_single_client(maxPoolSize=0) + c = self.rs_or_single_client(maxPoolSize=0) self.addCleanup(c.close) pool = get_pool(c) self.assertEqual(pool.max_pool_size, float("inf")) diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 97855872c..ea9ce49a3 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -21,7 +21,7 @@ import unittest sys.path[0:0] = [""] from test import IntegrationTest, client_context -from test.utils import OvertCommandListener, rs_or_single_client +from test.utils import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure @@ -36,7 +36,7 @@ class TestReadConcern(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test client_context.client.pymongo_test.create_collection("coll") @@ -67,7 +67,7 @@ class TestReadConcern(IntegrationTest): def test_read_concern_uri(self): uri = f"mongodb://{client_context.pair}/?readConcernLevel=majority" - client = rs_or_single_client(uri, connect=False) + client = self.rs_or_single_client(uri, connect=False) self.assertEqual(ReadConcern("majority"), client.read_concern) def test_invalid_read_concern(self): diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 2cd3195f4..32883399e 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -30,8 +30,6 @@ from test import IntegrationTest, SkipTest, client_context, connected, unittest from test.utils import ( OvertCommandListener, one, - rs_client, - single_client, wait_until, ) from test.version import Version @@ -58,7 +56,7 @@ from pymongo.write_concern import WriteConcern class TestSelections(IntegrationTest): @client_context.require_connection def test_bool(self): - client = single_client() + client = self.single_client() wait_until(lambda: client.address, "discover primary") selection = Selection.from_topology_description(client._topology.description) @@ -128,7 +126,7 @@ class TestReadPreferencesBase(IntegrationTest): return None def assertReadsFrom(self, expected, **kwargs): - c = rs_client(**kwargs) + c = self.rs_client(**kwargs) wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes") used = self.read_from_which_kind(c) @@ -139,7 +137,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase): def test_reads_from_secondary(self): host, port = next(iter(self.client.secondaries)) # Direct connection to a secondary. - client = single_client(host, port) + client = self.single_client(host, port) self.assertFalse(client.is_primary) # Regardless of read preference, we should be able to do @@ -175,19 +173,21 @@ class TestReadPreferences(TestReadPreferencesBase): ReadPreference.SECONDARY_PREFERRED, ReadPreference.NEAREST, ): - self.assertEqual(mode, rs_client(read_preference=mode).read_preference) + self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference) - self.assertRaises(TypeError, rs_client, read_preference="foo") + self.assertRaises(TypeError, self.rs_client, read_preference="foo") def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) - self.assertEqual([{}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}]) - self.assertEqual([{"k": "v"}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}, {}]) - self.assertEqual([{"k": "v"}, {}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual( + [{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets + ) self.assertRaises(ValueError, Secondary, tag_sets=[]) @@ -200,20 +200,22 @@ class TestReadPreferences(TestReadPreferencesBase): def test_threshold_validation(self): self.assertEqual( - 17, rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms + 17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms ) self.assertEqual( - 42, rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms + 42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms ) self.assertEqual( - 666, rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms + 666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms ) - self.assertEqual(0, rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms) + self.assertEqual( + 0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms + ) - self.assertRaises(ValueError, rs_client, localthresholdms=-1) + self.assertRaises(ValueError, self.rs_client, localthresholdms=-1) def test_zero_latency(self): ping_times: set = set() @@ -223,7 +225,7 @@ class TestReadPreferences(TestReadPreferencesBase): for ping_time, host in zip(ping_times, self.client.nodes): ServerDescription._host_to_round_trip_time[host] = ping_time try: - client = connected(rs_client(readPreference="nearest", localThresholdMS=0)) + client = connected(self.rs_client(readPreference="nearest", localThresholdMS=0)) wait_until(lambda: client.nodes == self.client.nodes, "discovered all nodes") host = self.read_from_which_host(client) for _ in range(5): @@ -236,7 +238,7 @@ class TestReadPreferences(TestReadPreferencesBase): def test_primary_with_tags(self): # Tags not allowed with PRIMARY - self.assertRaises(ConfigurationError, rs_client, tag_sets=[{"dc": "ny"}]) + self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}]) def test_primary_preferred(self): self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) @@ -250,7 +252,9 @@ class TestReadPreferences(TestReadPreferencesBase): def test_nearest(self): # With high localThresholdMS, expect to read from any # member - c = rs_client(read_preference=ReadPreference.NEAREST, localThresholdMS=10000) # 10 seconds + c = self.rs_client( + read_preference=ReadPreference.NEAREST, localThresholdMS=10000 + ) # 10 seconds data_members = {self.client.primary} | self.client.secondaries @@ -540,7 +544,7 @@ class TestMongosAndReadPreference(IntegrationTest): if client_context.supports_secondary_read_pref: cases["secondary"] = Secondary listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) self.addCleanup(client.close) client.admin.command("ping") for _mode, cls in cases.items(): @@ -667,13 +671,13 @@ class TestMongosAndReadPreference(IntegrationTest): else: self.fail("mongos accepted invalid staleness") - coll = single_client( + coll = self.single_client( readPreference="secondaryPreferred", maxStalenessSeconds=120 ).pymongo_test.test # No error coll.find_one() - coll = single_client( + coll = self.single_client( readPreference="secondaryPreferred", maxStalenessSeconds=10 ).pymongo_test.test try: diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 3e37e8f9a..67943d495 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -24,12 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( - EventListener, - disable_replication, - enable_replication, - rs_or_single_client, -) +from test.utils import EventListener from pymongo import DESCENDING from pymongo.errors import ( @@ -51,7 +46,7 @@ class TestReadWriteConcernSpec(IntegrationTest): def test_omit_default_read_write_concern(self): listener = EventListener() # Client with default readConcern and writeConcern - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). @@ -104,7 +99,9 @@ class TestReadWriteConcernSpec(IntegrationTest): def assertWriteOpsRaise(self, write_concern, expected_exception): wc = write_concern.document # Set socket timeout to avoid indefinite stalls - client = rs_or_single_client(w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000) + client = self.rs_or_single_client( + w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 + ) db = client.get_database("pymongo_test") coll = db.test @@ -167,9 +164,9 @@ class TestReadWriteConcernSpec(IntegrationTest): @client_context.require_test_commands def test_raise_wtimeout(self): self.addCleanup(client_context.client.drop_database, "pymongo_test") - self.addCleanup(enable_replication, client_context.client) + self.addCleanup(self.enable_replication, client_context.client) # Disable replication to guarantee a wtimeout error. - disable_replication(client_context.client) + self.disable_replication(client_context.client) self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError) @client_context.require_failCommand_fail_point @@ -209,7 +206,7 @@ class TestReadWriteConcernSpec(IntegrationTest): @client_context.require_version_min(4, 9) def test_write_error_details_exposes_errinfo(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) db = client.errinfotest self.addCleanup(client.drop_database, "errinfotest") diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index b0fa42a0c..a1c72bb7b 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -34,7 +34,6 @@ from test import ( from test.utils import ( CMAPListener, OvertCommandListener, - rs_or_single_client, set_fail_point, ) @@ -93,7 +92,9 @@ class TestPoolPausedError(IntegrationTest): self.skipTest("Test is flakey on PyPy") cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) + client = self.rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() @@ -163,13 +164,13 @@ class TestRetryableReads(IntegrationTest): mongos_clients = [] for mongos in client_context.mongos_seeds().split(","): - client = rs_or_single_client(mongos) + client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( client_context.mongos_seeds(), appName="retryableReadTest", event_listeners=[listener], diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 2938b7efa..89454ad23 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -28,7 +28,6 @@ from test.utils import ( DeprecationFilter, EventListener, OvertCommandListener, - rs_or_single_client, set_fail_point, ) from test.version import Version @@ -145,7 +144,7 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() - cls.client = rs_or_single_client(retryWrites=True) + cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) cls.db = cls.client.pymongo_test @classmethod @@ -181,7 +180,9 @@ class TestRetryableWrites(IgnoreDeprecationsTest): cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(retryWrites=True, event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client( + retryWrites=True, event_listeners=[cls.listener] + ) cls.db = cls.client.pymongo_test @classmethod @@ -204,7 +205,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest): def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=False, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener]) self.addCleanup(client.close) for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" @@ -297,7 +298,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest): def test_server_selection_timeout_not_retried(self): """A ServerSelectionTimeoutError is not retried.""" listener = OvertCommandListener() - client = MongoClient( + client = self.simple_client( "somedomainthatdoesntexist.org", serverSelectionTimeoutMS=1, retryWrites=True, @@ -317,7 +318,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest): original error. """ listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server @@ -443,13 +444,13 @@ class TestRetryableWrites(IgnoreDeprecationsTest): mongos_clients = [] for mongos in client_context.mongos_seeds().split(","): - client = rs_or_single_client(mongos) + client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( client_context.mongos_seeds(), appName="retryableWriteTest", event_listeners=[listener], @@ -492,7 +493,7 @@ class TestWriteConcernError(IntegrationTest): @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) def test_RetryableWriteError_error_label(self): listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) # Ensure collection exists. @@ -551,7 +552,9 @@ class TestPoolPausedError(IntegrationTest): def test_pool_paused_error_is_retryable(self): cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) + client = self.rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() @@ -613,7 +616,7 @@ class TestPoolPausedError(IntegrationTest): self, ): cmd_listener = InsertEventListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener]) client.test.test.drop() self.addCleanup(client.close) cmd_listener.reset() @@ -650,7 +653,7 @@ class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest): the first attempt fails before sending the command. """ listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 8e0a3cbbb..81b208d51 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -25,7 +25,6 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, client_knobs, unittest from test.utils import ( ServerAndTopologyEventListener, - rs_or_single_client, server_name_to_type, wait_until, ) @@ -279,7 +278,7 @@ class TestSdamMonitoring(IntegrationTest): cls.knobs.enable() cls.listener = ServerAndTopologyEventListener() retry_writes = client_context.supports_transactions() - cls.test_client = rs_or_single_client( + cls.test_client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=retry_writes ) cls.coll = cls.test_client[cls.client.db.name].test diff --git a/test/test_server_selection.py b/test/test_server_selection.py index d3526617f..67e9716bf 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -33,7 +33,6 @@ from test import IntegrationTest, client_context, unittest from test.utils import ( EventListener, FunctionCallRecorder, - rs_or_single_client, wait_until, ) from test.utils_selection_tests import ( @@ -76,7 +75,9 @@ class TestCustomServerSelectorFunction(IntegrationTest): # Initialize client with appropriate listeners. listener = EventListener() - client = rs_or_single_client(server_selector=custom_selector, event_listeners=[listener]) + client = self.rs_or_single_client( + server_selector=custom_selector, event_listeners=[listener] + ) self.addCleanup(client.close) coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll self.addCleanup(client.drop_database, "testdb") @@ -117,7 +118,7 @@ class TestCustomServerSelectorFunction(IntegrationTest): selector = FunctionCallRecorder(lambda x: x) # Client setup. - mongo_client = rs_or_single_client(server_selector=selector) + mongo_client = self.rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection self.addCleanup(mongo_client.close) self.addCleanup(mongo_client.drop_database, "testdb") diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 9dced595c..8e030f61e 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -22,7 +22,6 @@ from test.utils import ( OvertCommandListener, SpecTestCreator, get_pool, - rs_client, wait_until, ) from test.utils_selection_tests import create_topology @@ -134,7 +133,7 @@ class TestProse(IntegrationTest): listener = OvertCommandListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. - client = rs_client( + client = self.rs_client( client_context.mongos_seeds(), appName="loadBalancingTest", event_listeners=[listener], diff --git a/test/test_session.py b/test/test_session.py index 563b33c70..9f94ded92 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -36,7 +36,6 @@ from test import ( from test.utils import ( EventListener, ExceptionCatchingThread, - rs_or_single_client, wait_until, ) @@ -88,7 +87,7 @@ class TestSession(IntegrationTest): super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = rs_or_single_client() + cls.client2 = cls.unmanaged_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -103,7 +102,7 @@ class TestSession(IntegrationTest): def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() - self.client = rs_or_single_client( + self.client = self.rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) self.addCleanup(self.client.close) @@ -200,7 +199,7 @@ class TestSession(IntegrationTest): failures = 0 for _ in range(5): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -283,7 +282,7 @@ class TestSession(IntegrationTest): def test_end_sessions(self): # Use a new client so that the tearDown hook does not error. listener = SessionTestListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: @@ -787,8 +786,7 @@ class TestSession(IntegrationTest): def test_unacknowledged_writes(self): # Ensure the collection exists. self.client.pymongo_test.test_unacked_writes.insert_one({}) - client = rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(w=0, event_listeners=[self.listener]) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -836,7 +834,7 @@ class TestCausalConsistency(UnitTest): @classmethod def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) @classmethod def _tearDown_class(cls): @@ -1137,8 +1135,7 @@ class TestClusterTime(IntegrationTest): def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. - client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 405db14ac..e01552bf7 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -21,7 +21,7 @@ from typing import Any sys.path[0:0] = [""] -from test import client_knobs, unittest +from test import PyMongoTestCase, client_knobs, unittest from test.utils import FunctionCallRecorder, wait_until import pymongo @@ -86,7 +86,7 @@ class SrvPollingKnobs: self.disable() -class TestSrvPolling(unittest.TestCase): +class TestSrvPolling(PyMongoTestCase): BASE_SRV_RESPONSE = [ ("localhost.test.build.10gen.cc", 27017), ("localhost.test.build.10gen.cc", 27018), @@ -167,7 +167,7 @@ class TestSrvPolling(unittest.TestCase): # Patch timeouts to ensure short test running times. with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) # Patch list of hosts returned by DNS query. with SrvPollingKnobs( @@ -231,7 +231,7 @@ class TestSrvPolling(unittest.TestCase): count_resolver_calls=True, ): # Client uses unpatched method to get initial nodelist - client = MongoClient(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) # Invalid DNS resolver response should not change nodelist. self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) @@ -264,8 +264,7 @@ class TestSrvPolling(unittest.TestCase): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=0) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -279,8 +278,7 @@ class TestSrvPolling(unittest.TestCase): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -295,8 +293,7 @@ class TestSrvPolling(unittest.TestCase): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) with SrvPollingKnobs(nodelist_callback=nodelist_callback): sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) final_topology = set(client.topology_description.server_descriptions()) @@ -305,8 +302,7 @@ class TestSrvPolling(unittest.TestCase): def test_does_not_flipflop(self): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=1) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1) old = set(client.topology_description.server_descriptions()) sleep(4 * WAIT_TIME) new = set(client.topology_description.server_descriptions()) @@ -323,7 +319,7 @@ class TestSrvPolling(unittest.TestCase): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient( + client = self.simple_client( "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname" ) with SrvPollingKnobs(nodelist_callback=nodelist_callback): @@ -340,7 +336,7 @@ class TestSrvPolling(unittest.TestCase): min_srv_rescan_interval=WAIT_TIME, nodelist_callback=resolver_response, ): - client = MongoClient(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) self.assertRaises( AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2 ) diff --git a/test/test_ssl.py b/test/test_ssl.py index 5b3855a82..36d7ba12b 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -24,6 +24,7 @@ sys.path[0:0] = [""] from test import ( HAVE_IPADDRESS, IntegrationTest, + PyMongoTestCase, SkipTest, client_context, connected, @@ -82,45 +83,45 @@ MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=cl # use 'localhost' for the hostname of all hosts. -class TestClientSSL(unittest.TestCase): +class TestClientSSL(PyMongoTestCase): @unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.") def test_no_ssl_module(self): # Explicit - self.assertRaises(ConfigurationError, MongoClient, ssl=True) + self.assertRaises(ConfigurationError, self.simple_client, ssl=True) # Implied - self.assertRaises(ConfigurationError, MongoClient, tlsCertificateKeyFile=CLIENT_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM) @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") @ignore_deprecations def test_config_ssl(self): # Tests various ssl configurations - self.assertRaises(ValueError, MongoClient, ssl="foo") + self.assertRaises(ValueError, self.simple_client, ssl="foo") self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM ) - self.assertRaises(TypeError, MongoClient, ssl=0) - self.assertRaises(TypeError, MongoClient, ssl=5.5) - self.assertRaises(TypeError, MongoClient, ssl=[]) + self.assertRaises(TypeError, self.simple_client, ssl=0) + self.assertRaises(TypeError, self.simple_client, ssl=5.5) + self.assertRaises(TypeError, self.simple_client, ssl=[]) - self.assertRaises(IOError, MongoClient, tlsCertificateKeyFile="NoSuchFile") - self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=True) - self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=[]) + self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile") + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True) + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[]) # Test invalid combinations self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM ) - self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCAFile=CA_PEM) - self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCRLFile=CRL_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsAllowInvalidCertificates=False + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False ) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsAllowInvalidHostnames=False + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False ) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsDisableOCSPEndpointCheck=False + ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False ) @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") @@ -174,7 +175,7 @@ class TestSSL(IntegrationTest): if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, - MongoClient, + self.simple_client, "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, @@ -184,7 +185,7 @@ class TestSSL(IntegrationTest): ) else: connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, @@ -201,7 +202,7 @@ class TestSSL(IntegrationTest): "&tlsCAFile=%s&serverSelectionTimeoutMS=5000" ) connected( - MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] ) @client_context.require_tlsCertificateKeyFile @@ -215,7 +216,7 @@ class TestSSL(IntegrationTest): # # test that setting tlsCertificateKeyFile causes ssl to be set to True - client = MongoClient( + client = self.simple_client( client_context.host, client_context.port, tlsAllowInvalidCertificates=True, @@ -223,7 +224,7 @@ class TestSSL(IntegrationTest): ) response = client.admin.command(HelloCompat.LEGACY_CMD) if "setName" in response: - client = MongoClient( + client = self.simple_client( client_context.pair, replicaSet=response["setName"], w=len(response["hosts"]), @@ -242,7 +243,7 @@ class TestSSL(IntegrationTest): # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # - client = MongoClient( + client = self.simple_client( "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -257,7 +258,7 @@ class TestSSL(IntegrationTest): "Cannot validate hostname in the certificate" ) - client = MongoClient( + client = self.simple_client( "localhost", replicaSet=response["setName"], w=len(response["hosts"]), @@ -270,7 +271,7 @@ class TestSSL(IntegrationTest): self.assertClientWorks(client) if HAVE_IPADDRESS: - client = MongoClient( + client = self.simple_client( "127.0.0.1", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -292,7 +293,7 @@ class TestSSL(IntegrationTest): "mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates" "=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false" ) - client = MongoClient(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) + client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) self.assertClientWorks(client) @client_context.require_tlsCertificateKeyFile @@ -316,7 +317,7 @@ class TestSSL(IntegrationTest): with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -328,7 +329,7 @@ class TestSSL(IntegrationTest): ) connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -343,7 +344,7 @@ class TestSSL(IntegrationTest): if "setName" in response: with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "server", replicaSet=response["setName"], ssl=True, @@ -356,7 +357,7 @@ class TestSSL(IntegrationTest): ) connected( - MongoClient( + self.simple_client( "server", replicaSet=response["setName"], ssl=True, @@ -375,7 +376,7 @@ class TestSSL(IntegrationTest): if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, - MongoClient, + self.simple_client, "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -384,7 +385,7 @@ class TestSSL(IntegrationTest): ) else: connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -395,7 +396,7 @@ class TestSSL(IntegrationTest): with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -406,7 +407,7 @@ class TestSSL(IntegrationTest): ) uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000" - connected(MongoClient(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore + connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore uri_fmt = ( "mongodb://localhost/?ssl=true&tlsCRLFile=%s" @@ -414,7 +415,7 @@ class TestSSL(IntegrationTest): ) with self.assertRaises(ConnectionFailure): connected( - MongoClient(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] ) @client_context.require_tlsCertificateKeyFile @@ -431,12 +432,14 @@ class TestSSL(IntegrationTest): with self.assertRaises(ConnectionFailure): # Server cert is verified but hostname matching fails connected( - MongoClient("server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type] + self.simple_client( + "server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] ) # Server cert is verified. Disable hostname matching. connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsAllowInvalidHostnames=True, @@ -447,12 +450,14 @@ class TestSSL(IntegrationTest): # Server cert and hostname are verified. connected( - MongoClient("localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type] + self.simple_client( + "localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] ) # Server cert and hostname are verified. connected( - MongoClient( + self.simple_client( "mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000", **self.credentials, # type: ignore[arg-type] ) @@ -472,7 +477,7 @@ class TestSSL(IntegrationTest): ssl_support.HAVE_WINCERTSTORE = False try: with self.assertRaises(ConfigurationError): - MongoClient("mongodb://localhost/?ssl=true") + self.simple_client("mongodb://localhost/?ssl=true") finally: ssl_support.HAVE_CERTIFI = have_certifi ssl_support.HAVE_WINCERTSTORE = have_wincertstore @@ -536,7 +541,7 @@ class TestSSL(IntegrationTest): ], ) - noauth = MongoClient( + noauth = self.simple_client( client_context.pair, ssl=True, tlsAllowInvalidCertificates=True, @@ -548,7 +553,7 @@ class TestSSL(IntegrationTest): noauth.pymongo_test.test.find_one() listener = EventListener() - auth = MongoClient( + auth = self.simple_client( client_context.pair, authMechanism="MONGODB-X509", ssl=True, @@ -572,7 +577,7 @@ class TestSSL(IntegrationTest): host, port, ) - client = MongoClient( + client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(client.close) @@ -580,7 +585,7 @@ class TestSSL(IntegrationTest): client.pymongo_test.test.find_one() uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port) - client = MongoClient( + client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(client.close) @@ -593,7 +598,7 @@ class TestSSL(IntegrationTest): port, ) - bad_client = MongoClient( + bad_client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(bad_client.close) @@ -601,7 +606,7 @@ class TestSSL(IntegrationTest): with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() - bad_client = MongoClient( + bad_client = self.simple_client( client_context.pair, username="not the username", authMechanism="MONGODB-X509", @@ -622,7 +627,7 @@ class TestSSL(IntegrationTest): ) try: connected( - MongoClient( + self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, @@ -648,7 +653,7 @@ class TestSSL(IntegrationTest): self.addCleanup(remove, temp_ca_bundle) # Add the CA cert file to the bundle. cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM) - with MongoClient( + with self.simple_client( "localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle ) as client: self.assertTrue(client.admin.command("ping")) diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 9bca899a4..b3b68703a 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -24,8 +24,6 @@ from test import IntegrationTest, client_context, unittest from test.utils import ( HeartbeatEventListener, ServerEventListener, - rs_or_single_client, - single_client, wait_until, ) @@ -38,7 +36,7 @@ class TestStreamingProtocol(IntegrationTest): def test_failCommand_streaming(self): listener = ServerEventListener() hb_listener = HeartbeatEventListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName="failingHeartbeatTest", @@ -107,7 +105,7 @@ class TestStreamingProtocol(IntegrationTest): }, } with self.fail_point(delay_hello): - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name ) self.addCleanup(client.close) @@ -155,7 +153,7 @@ class TestStreamingProtocol(IntegrationTest): } with self.fail_point(fail_hello): start = time.time() - client = single_client( + client = self.single_client( appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 ) self.addCleanup(client.close) @@ -180,7 +178,7 @@ class TestStreamingProtocol(IntegrationTest): @client_context.require_failCommand_appName def test_heartbeat_awaited_flag(self): hb_listener = HeartbeatEventListener() - client = single_client( + client = self.single_client( event_listeners=[hb_listener], heartbeatFrequencyMS=500, appName="heartbeatEventAwaitedFlag", diff --git a/test/test_transactions.py b/test/test_transactions.py index c8c3c32d5..3cecbe9d3 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -17,6 +17,7 @@ from __future__ import annotations import sys from io import BytesIO +from test.utils_spec_runner import SpecRunner from gridfs.synchronous.grid_file import GridFS, GridFSBucket @@ -25,8 +26,6 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest from test.utils import ( OvertCommandListener, - rs_client, - single_client, wait_until, ) from typing import List @@ -59,7 +58,18 @@ _IS_SYNC = True UNPIN_TEST_MAX_ATTEMPTS = 50 -class TestTransactions(IntegrationTest): +class TransactionsBase(SpecRunner): + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + if ( + "secondary" in self.id() + and not client_context.is_mongos + and not client_context.has_secondaries + ): + raise unittest.SkipTest("No secondaries") + + +class TestTransactions(TransactionsBase): RUN_ON_SERVERLESS = True @client_context.require_transactions @@ -92,8 +102,7 @@ class TestTransactions(IntegrationTest): @client_context.require_transactions def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" - client = rs_client(w=0) - self.addCleanup(client.close) + client = self.rs_client(w=0) db = client.test coll = db.test coll.insert_one({}) @@ -146,12 +155,11 @@ class TestTransactions(IntegrationTest): def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) + client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. coll.insert_one({}) - self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. with s.start_transaction(): @@ -174,12 +182,11 @@ class TestTransactions(IntegrationTest): def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) + client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. coll.insert_one({}) - self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. with s.start_transaction(): @@ -303,11 +310,10 @@ class TestTransactions(IntegrationTest): # Start a transaction with a batch of operations that needs to be # split. listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test coll.delete_many({}) listener.reset() - self.addCleanup(client.close) self.addCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -332,8 +338,7 @@ class TestTransactions(IntegrationTest): @client_context.require_transactions def test_transaction_direct_connection(self): - client = single_client() - self.addCleanup(client.close) + client = self.single_client() coll = client.pymongo_test.test # Make sure the collection exists. @@ -389,14 +394,14 @@ class PatchSessionTimeout: client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout -class TestTransactionsConvenientAPI(IntegrationTest): +class TestTransactionsConvenientAPI(TransactionsBase): @classmethod def _setup_class(cls): super()._setup_class() cls.mongos_clients = [] if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(single_client("{}:{}".format(*address))) + cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) @classmethod def _tearDown_class(cls): @@ -446,8 +451,7 @@ class TestTransactionsConvenientAPI(IntegrationTest): @client_context.require_transactions def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test def callback(session): @@ -475,8 +479,7 @@ class TestTransactionsConvenientAPI(IntegrationTest): @client_context.require_transactions def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test def callback(session): @@ -508,8 +511,7 @@ class TestTransactionsConvenientAPI(IntegrationTest): @client_context.require_transactions def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test def callback(session): diff --git a/test/test_typing.py b/test/test_typing.py index f423b70a3..6cfe40537 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -68,8 +68,7 @@ except ImportError: sys.path[0:0] = [""] -from test import IntegrationTest, client_context -from test.utils import rs_or_single_client +from test import IntegrationTest, PyMongoTestCase, client_context from bson import CodecOptions, decode, decode_all, decode_file_iter, decode_iter, encode from bson.raw_bson import RawBSONDocument @@ -194,7 +193,7 @@ class TestPymongo(IntegrationTest): value.items() def test_default_document_type(self) -> None: - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) coll = client.test.test doc = {"my": "doc"} @@ -366,7 +365,7 @@ class TestDecode(unittest.TestCase): doc["a"] = 2 -class TestDocumentType(unittest.TestCase): +class TestDocumentType(PyMongoTestCase): @only_type_check def test_default(self) -> None: client: MongoClient = MongoClient() @@ -480,7 +479,7 @@ class TestDocumentType(unittest.TestCase): def test_typeddict_find_notrequired(self): if NotRequired is None or ImplicitMovie is None: raise unittest.SkipTest("Python 3.11+ is required to use NotRequired.") - client: MongoClient[ImplicitMovie] = rs_or_single_client() + client: MongoClient[ImplicitMovie] = self.rs_or_single_client() coll = client.test.test coll.insert_one(ImplicitMovie(name="THX-1138", year=1971)) out = coll.find_one({}) diff --git a/test/test_versioned_api.py b/test/test_versioned_api.py index 7fe8ebd76..7a25a507d 100644 --- a/test/test_versioned_api.py +++ b/test/test_versioned_api.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener, rs_or_single_client +from test.utils import OvertCommandListener from pymongo.server_api import ServerApi, ServerApiVersion from pymongo.synchronous.mongo_client import MongoClient @@ -77,7 +77,7 @@ class TestServerApi(IntegrationTest): @client_context.require_version_min(4, 7) def test_command_options(self): listener = OvertCommandListener() - client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) self.addCleanup(client.close) coll = client.test.test coll.insert_many([{} for _ in range(100)]) @@ -90,7 +90,7 @@ class TestServerApi(IntegrationTest): @client_context.require_transactions def test_command_options_txn(self): listener = OvertCommandListener() - client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) self.addCleanup(client.close) coll = client.test.test coll.insert_many([{} for _ in range(100)]) diff --git a/test/unified_format.py b/test/unified_format.py index 78fc63878..62211d3d2 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -55,8 +55,6 @@ from test.utils import ( parse_collection_options, parse_spec_options, prepare_spec_arguments, - rs_or_single_client, - single_client, snake_to_camel, wait_until, ) @@ -574,7 +572,7 @@ class EntityMapUtil: ) if uri: kwargs["h"] = uri - client = rs_or_single_client(**kwargs) + client = self.test.rs_or_single_client(**kwargs) self[spec["id"]] = client self.test.addCleanup(client.close) return @@ -1115,7 +1113,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): and not client_context.serverless ): for address in client_context.mongoses: - cls.mongos_clients.append(single_client("{}:{}".format(*address))) + cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs( @@ -1646,7 +1644,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): ) ) - client = single_client("{}:{}".format(*session._pinned_address)) + client = self.single_client("{}:{}".format(*session._pinned_address)) self.addCleanup(client.close) self.__set_fail_point(client=client, command_args=spec["failPoint"]) diff --git a/test/utils.py b/test/utils.py index fa198b1c6..6eefd1c7e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -565,151 +565,6 @@ class SpecTestCreator: setattr(self._test_class, new_test.__name__, new_test) -def _connection_string(h): - if h.startswith(("mongodb://", "mongodb+srv://")): - return h - return f"mongodb://{h!s}" - - -def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs): - """Create a new client over SSL/TLS if necessary.""" - host = host or client_context.host - port = port or client_context.port - client_options: dict = client_context.default_client_options.copy() - if client_context.replica_set_name and not directConnection: - client_options["replicaSet"] = client_context.replica_set_name - if directConnection is not None: - client_options["directConnection"] = directConnection - client_options.update(kwargs) - - uri = _connection_string(host) - auth_mech = kwargs.get("authMechanism", "") - if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": - # Only add the default username or password if one is not provided. - res = parse_uri(uri) - if ( - not res["username"] - and not res["password"] - and "username" not in client_options - and "password" not in client_options - ): - client_options["username"] = db_user - client_options["password"] = db_pwd - return MongoClient(uri, port, **client_options) - - -async def _async_mongo_client(host, port, authenticate=True, directConnection=None, **kwargs): - """Create a new client over SSL/TLS if necessary.""" - host = host or await async_client_context.host - port = port or await async_client_context.port - client_options: dict = async_client_context.default_client_options.copy() - if async_client_context.replica_set_name and not directConnection: - client_options["replicaSet"] = async_client_context.replica_set_name - if directConnection is not None: - client_options["directConnection"] = directConnection - client_options.update(kwargs) - - uri = _connection_string(host) - auth_mech = kwargs.get("authMechanism", "") - if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": - # Only add the default username or password if one is not provided. - res = parse_uri(uri) - if ( - not res["username"] - and not res["password"] - and "username" not in client_options - and "password" not in client_options - ): - client_options["username"] = db_user - client_options["password"] = db_pwd - client = AsyncMongoClient(uri, port, **client_options) - if client._options.connect: - await client.aconnect() - return client - - -def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Make a direct connection. Don't authenticate.""" - return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) - - -def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Make a direct connection, and authenticate if necessary.""" - return _mongo_client(h, p, directConnection=True, **kwargs) - - -def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Connect to the replica set. Don't authenticate.""" - return _mongo_client(h, p, authenticate=False, **kwargs) - - -def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Connect to the replica set and authenticate if necessary.""" - return _mongo_client(h, p, **kwargs) - - -def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Connect to the replica set if there is one, otherwise the standalone. - - Like rs_or_single_client, but does not authenticate. - """ - return _mongo_client(h, p, authenticate=False, **kwargs) - - -def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]: - """Connect to the replica set if there is one, otherwise the standalone. - - Authenticates if necessary. - """ - return _mongo_client(h, p, **kwargs) - - -async def async_single_client_noauth( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Make a direct connection. Don't authenticate.""" - return await _async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) - - -async def async_single_client( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Make a direct connection, and authenticate if necessary.""" - return await _async_mongo_client(h, p, directConnection=True, **kwargs) - - -async def async_rs_client_noauth( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Connect to the replica set. Don't authenticate.""" - return await _async_mongo_client(h, p, authenticate=False, **kwargs) - - -async def async_rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient[dict]: - """Connect to the replica set and authenticate if necessary.""" - return await _async_mongo_client(h, p, **kwargs) - - -async def async_rs_or_single_client_noauth( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Connect to the replica set if there is one, otherwise the standalone. - - Like rs_or_single_client, but does not authenticate. - """ - return await _async_mongo_client(h, p, authenticate=False, **kwargs) - - -async def async_rs_or_single_client( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[Any]: - """Connect to the replica set if there is one, otherwise the standalone. - - Authenticates if necessary. - """ - return await _async_mongo_client(h, p, **kwargs) - - def ensure_all_connected(client: MongoClient) -> None: """Ensure that the client's connection pool has socket connections to all members of a replica set. Raises ConfigurationError when called with a @@ -1108,20 +963,6 @@ def is_greenthread_patched(): return gevent_monkey_patched() or eventlet_monkey_patched() -def disable_replication(client): - """Disable replication on all secondaries.""" - for host, port in client.secondaries: - secondary = single_client(host, port) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") - - -def enable_replication(client): - """Enable replication on all secondaries.""" - for host, port in client.secondaries: - secondary = single_client(host, port) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") - - class ExceptionCatchingThread(threading.Thread): """A thread that stores any exception encountered from run().""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 0b882a8bc..06a40351c 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -29,7 +29,6 @@ from test.utils import ( camel_to_snake_args, parse_spec_options, prepare_spec_arguments, - rs_client, ) from typing import List @@ -101,6 +100,8 @@ class SpecRunner(IntegrationTest): @classmethod def _tearDown_class(cls): cls.knobs.disable() + for client in cls.mongos_clients: + client.close() super()._tearDown_class() def setUp(self): @@ -524,7 +525,7 @@ class SpecRunner(IntegrationTest): host = client_context.MULTI_MONGOS_LB_URI elif client_context.is_mongos: host = client_context.mongos_seeds() - client = rs_client( + client = self.rs_client( h=host, event_listeners=[listener, pool_listener, server_listener], **client_options ) self.scenario_client = client From 163e3d4a0db5f49548040afde6a04db502431cd1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 17 Sep 2024 12:56:03 -0400 Subject: [PATCH 4/6] PYTHON-4738 - Make test_encryption.TestClientSimple.test_fork sync-only (#1862) --- test/asynchronous/test_encryption.py | 1 + test/test_encryption.py | 1 + 2 files changed, 2 insertions(+) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 3f3714eeb..f29b0f824 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -380,6 +380,7 @@ class TestClientSimple(AsyncEncryptionIntegrationTest): is_greenthread_patched(), "gevent and eventlet do not support POSIX-style forking.", ) + @async_client_context.require_sync async def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = await self.async_rs_or_single_client(auto_encryption_opts=opts) diff --git a/test/test_encryption.py b/test/test_encryption.py index 96d40c4a3..512c92f4d 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -380,6 +380,7 @@ class TestClientSimple(EncryptionIntegrationTest): is_greenthread_patched(), "gevent and eventlet do not support POSIX-style forking.", ) + @client_context.require_sync def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = self.rs_or_single_client(auto_encryption_opts=opts) From 40ebc1644c89f352c35aa100ed6548023101b72a Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 17 Sep 2024 15:16:55 -0500 Subject: [PATCH 5/6] PYTHON-4764 Update to use current supported EVG hosts (#1858) --- .evergreen/config.yml | 50 +++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index a9dc66971..14e3426b3 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2081,10 +2081,10 @@ axes: batchtime: 10080 # 7 days - id: rhel8 display_name: "RHEL 8.x" - run_on: rhel87-small + run_on: rhel8.8-small batchtime: 10080 # 7 days - - id: rhel92-fips - display_name: "RHEL 9.2 FIPS" + - id: rhel9-fips + display_name: "RHEL 9 FIPS" run_on: rhel92-fips batchtime: 10080 # 7 days - id: ubuntu-22.04 @@ -2095,24 +2095,24 @@ axes: display_name: "Ubuntu 20.04" run_on: ubuntu2004-small batchtime: 10080 # 7 days - - id: rhel83-zseries - display_name: "RHEL 8.3 (zSeries)" - run_on: rhel83-zseries-small + - id: rhel8-zseries + display_name: "RHEL 8 (zSeries)" + run_on: rhel8-zseries-small batchtime: 10080 # 7 days variables: SKIP_HATCH: true - - id: rhel81-power8 - display_name: "RHEL 8.1 (POWER8)" - run_on: rhel81-power8-small + - id: rhel8-power8 + display_name: "RHEL 8 (POWER8)" + run_on: rhel8-power-small batchtime: 10080 # 7 days variables: SKIP_HATCH: true - - id: rhel82-arm64 - display_name: "RHEL 8.2 (ARM64)" + - id: rhel8-arm64 + display_name: "RHEL 8 (ARM64)" run_on: rhel82-arm64-small batchtime: 10080 # 7 days variables: - - id: windows-64-vsMulti-small + - id: windows display_name: "Windows 64" run_on: windows-64-vsMulti-small batchtime: 10080 # 7 days @@ -2470,7 +2470,7 @@ buildvariants: - matrix_name: "tests-fips" matrix_spec: platform: - - rhel92-fips + - rhel9-fips auth: "auth" ssl: "ssl" display_name: "${platform} ${auth} ${ssl}" @@ -2547,9 +2547,9 @@ buildvariants: - matrix_name: "test-different-cpu-architectures" matrix_spec: platform: - - rhel83-zseries # Added in 5.0.8 (SERVER-44074) - - rhel81-power8 # Added in 4.2.7 (SERVER-44072) - - rhel82-arm64 # Added in 4.4.2 (SERVER-48282) + - rhel8-zseries # Added in 5.0.8 (SERVER-44074) + - rhel8-power8 # Added in 4.2.7 (SERVER-44072) + - rhel8-arm64 # Added in 4.4.2 (SERVER-48282) auth-ssl: "*" display_name: "${platform} ${auth-ssl}" tasks: @@ -2606,7 +2606,7 @@ buildvariants: - matrix_name: "tests-pyopenssl-windows" matrix_spec: - platform: windows-64-vsMulti-small + platform: windows python-version-windows: "*" auth: "auth" ssl: "ssl" @@ -2698,7 +2698,7 @@ buildvariants: - matrix_name: "tests-windows-python-version" matrix_spec: - platform: windows-64-vsMulti-small + platform: windows python-version-windows: "*" auth-ssl: "*" display_name: "${platform} ${python-version-windows} ${auth-ssl}" @@ -2706,7 +2706,7 @@ buildvariants: - matrix_name: "tests-windows-python-version-32-bit" matrix_spec: - platform: windows-64-vsMulti-small + platform: windows python-version-windows-32: "*" auth-ssl: "*" display_name: "${platform} ${python-version-windows-32} ${auth-ssl}" @@ -2724,7 +2724,7 @@ buildvariants: - matrix_name: "tests-windows-encryption" matrix_spec: - platform: windows-64-vsMulti-small + platform: windows python-version-windows: "*" auth-ssl: "*" encryption: "*" @@ -2733,7 +2733,7 @@ buildvariants: rules: - if: encryption: ["encryption", "encryption_crypt_shared"] - platform: windows-64-vsMulti-small + platform: windows python-version-windows: "*" auth-ssl: "*" then: @@ -2795,7 +2795,7 @@ buildvariants: - matrix_name: "tests-windows-enterprise-auth" matrix_spec: - platform: windows-64-vsMulti-small + platform: windows python-version-windows: "*" auth: "auth" display_name: "Enterprise ${auth} ${platform} ${python-version-windows}" @@ -2907,7 +2907,7 @@ buildvariants: - matrix_name: "ocsp-test-windows" matrix_spec: - platform: windows-64-vsMulti-small + platform: windows python-version-windows: ["3.8", "3.10"] mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"] auth: "noauth" @@ -2932,7 +2932,7 @@ buildvariants: - matrix_name: "oidc-auth-test" matrix_spec: - platform: [ rhel8, macos, windows-64-vsMulti-small ] + platform: [ rhel8, macos, windows ] display_name: "OIDC Auth ${platform}" tasks: - name: testoidc_task_group @@ -2981,7 +2981,7 @@ buildvariants: - matrix_name: "aws-auth-test-windows" matrix_spec: - platform: [windows-64-vsMulti-small] + platform: [windows] python-version-windows: "*" display_name: "MONGODB-AWS Auth ${platform} ${python-version-windows}" tasks: From c136684047e54f30c8949d3c9be30017e7cd0213 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:38:24 -0700 Subject: [PATCH 6/6] PYTHON-4585 Cursor.to_list does not apply client's timeoutMS setting (#1860) --- pymongo/asynchronous/command_cursor.py | 3 +++ pymongo/asynchronous/cursor.py | 4 ++- pymongo/synchronous/command_cursor.py | 3 +++ pymongo/synchronous/cursor.py | 4 ++- test/asynchronous/test_cursor.py | 34 +++++++++++++++++++++++++- test/test_cursor.py | 34 +++++++++++++++++++++++++- 6 files changed, 78 insertions(+), 4 deletions(-) diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index b2cd345f6..5a4559bd7 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -29,6 +29,7 @@ from typing import ( ) from bson import CodecOptions, _convert_raw_document_lists_to_streams +from pymongo import _csot from pymongo.asynchronous.cursor import _ConnectionManager from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure @@ -77,6 +78,7 @@ class AsyncCommandCursor(Generic[_DocumentType]): self._address = address self._batch_size = batch_size self._max_await_time_ms = max_await_time_ms + self._timeout = self._collection.database.client.options.timeout self._session = session self._explicit_session = explicit_session self._killed = self._id == 0 @@ -385,6 +387,7 @@ class AsyncCommandCursor(Generic[_DocumentType]): async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.close() + @_csot.apply async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``. diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index bae77bb30..4b4bb52a8 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -36,7 +36,7 @@ from typing import ( from bson import RE_TYPE, _convert_raw_document_lists_to_streams from bson.code import Code from bson.son import SON -from pymongo import helpers_shared +from pymongo import _csot, helpers_shared from pymongo.asynchronous.helpers import anext from pymongo.collation import validate_collation_or_none from pymongo.common import ( @@ -196,6 +196,7 @@ class AsyncCursor(Generic[_DocumentType]): self._explain = False self._comment = comment self._max_time_ms = max_time_ms + self._timeout = self._collection.database.client.options.timeout self._max_await_time_ms: Optional[int] = None self._max: Optional[Union[dict[Any, Any], _Sort]] = max self._min: Optional[Union[dict[Any, Any], _Sort]] = min @@ -1290,6 +1291,7 @@ class AsyncCursor(Generic[_DocumentType]): async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.close() + @_csot.apply async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``. diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index da05bf1a3..3a4372856 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -29,6 +29,7 @@ from typing import ( ) from bson import CodecOptions, _convert_raw_document_lists_to_streams +from pymongo import _csot from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure from pymongo.message import ( @@ -77,6 +78,7 @@ class CommandCursor(Generic[_DocumentType]): self._address = address self._batch_size = batch_size self._max_await_time_ms = max_await_time_ms + self._timeout = self._collection.database.client.options.timeout self._session = session self._explicit_session = explicit_session self._killed = self._id == 0 @@ -385,6 +387,7 @@ class CommandCursor(Generic[_DocumentType]): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() + @_csot.apply def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``. diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index c352b6409..27a76cf91 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -36,7 +36,7 @@ from typing import ( from bson import RE_TYPE, _convert_raw_document_lists_to_streams from bson.code import Code from bson.son import SON -from pymongo import helpers_shared +from pymongo import _csot, helpers_shared from pymongo.collation import validate_collation_or_none from pymongo.common import ( validate_is_document_type, @@ -196,6 +196,7 @@ class Cursor(Generic[_DocumentType]): self._explain = False self._comment = comment self._max_time_ms = max_time_ms + self._timeout = self._collection.database.client.options.timeout self._max_await_time_ms: Optional[int] = None self._max: Optional[Union[dict[Any, Any], _Sort]] = max self._min: Optional[Union[dict[Any, Any], _Sort]] = min @@ -1288,6 +1289,7 @@ class Cursor(Generic[_DocumentType]): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() + @_csot.apply def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``. diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index d6773d832..33eaacee9 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -34,6 +34,7 @@ from test.utils import ( AllowListEventListener, EventListener, OvertCommandListener, + delay, ignore_deprecations, wait_until, ) @@ -44,7 +45,7 @@ from pymongo import ASCENDING, DESCENDING from pymongo.asynchronous.cursor import AsyncCursor, CursorType from pymongo.asynchronous.helpers import anext from pymongo.collation import Collation -from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure +from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure, PyMongoError from pymongo.operations import _IndexList from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference @@ -1410,6 +1411,18 @@ class TestCursor(AsyncIntegrationTest): docs = await c.to_list(3) self.assertEqual(len(docs), 2) + async def test_to_list_csot_applied(self): + client = await self.async_single_client(timeoutMS=500) + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(2): + await client.admin.command("ping") + coll = client.pymongo.test + await coll.insert_many([{} for _ in range(5)]) + cursor = coll.find({"$where": delay(1)}) + with self.assertRaises(PyMongoError) as ctx: + await cursor.to_list() + self.assertTrue(ctx.exception.timeout) + @async_client_context.require_change_streams async def test_command_cursor_to_list(self): # Set maxAwaitTimeMS=1 to speed up the test. @@ -1439,6 +1452,25 @@ class TestCursor(AsyncIntegrationTest): result = await db.test.aggregate([pipeline]) self.assertEqual(len(await result.to_list(1)), 1) + @async_client_context.require_failCommand_blockConnection + async def test_command_cursor_to_list_csot_applied(self): + client = await self.async_single_client(timeoutMS=500) + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(2): + await client.admin.command("ping") + coll = client.pymongo.test + await coll.insert_many([{} for _ in range(5)]) + fail_command = { + "configureFailPoint": "failCommand", + "mode": {"times": 5}, + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 1000}, + } + cursor = await coll.aggregate([], batchSize=1) + async with self.fail_point(fail_command): + with self.assertRaises(PyMongoError) as ctx: + await cursor.to_list() + self.assertTrue(ctx.exception.timeout) + class TestRawBatchCursor(AsyncIntegrationTest): async def test_find_raw(self): diff --git a/test/test_cursor.py b/test/test_cursor.py index 9bc22aca3..d99732aec 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -34,6 +34,7 @@ from test.utils import ( AllowListEventListener, EventListener, OvertCommandListener, + delay, ignore_deprecations, wait_until, ) @@ -42,7 +43,7 @@ from bson import decode_all from bson.code import Code from pymongo import ASCENDING, DESCENDING from pymongo.collation import Collation -from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure +from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure, PyMongoError from pymongo.operations import _IndexList from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference @@ -1401,6 +1402,18 @@ class TestCursor(IntegrationTest): docs = c.to_list(3) self.assertEqual(len(docs), 2) + def test_to_list_csot_applied(self): + client = self.single_client(timeoutMS=500) + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(2): + client.admin.command("ping") + coll = client.pymongo.test + coll.insert_many([{} for _ in range(5)]) + cursor = coll.find({"$where": delay(1)}) + with self.assertRaises(PyMongoError) as ctx: + cursor.to_list() + self.assertTrue(ctx.exception.timeout) + @client_context.require_change_streams def test_command_cursor_to_list(self): # Set maxAwaitTimeMS=1 to speed up the test. @@ -1430,6 +1443,25 @@ class TestCursor(IntegrationTest): result = db.test.aggregate([pipeline]) self.assertEqual(len(result.to_list(1)), 1) + @client_context.require_failCommand_blockConnection + def test_command_cursor_to_list_csot_applied(self): + client = self.single_client(timeoutMS=500) + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(2): + client.admin.command("ping") + coll = client.pymongo.test + coll.insert_many([{} for _ in range(5)]) + fail_command = { + "configureFailPoint": "failCommand", + "mode": {"times": 5}, + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 1000}, + } + cursor = coll.aggregate([], batchSize=1) + with self.fail_point(fail_command): + with self.assertRaises(PyMongoError) as ctx: + cursor.to_list() + self.assertTrue(ctx.exception.timeout) + class TestRawBatchCursor(IntegrationTest): def test_find_raw(self):