diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 19db50de8..94168f200 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc] self.client_ref = None self.key_vault_coll = None if self.mongocryptd_client: - await self.mongocryptd_client.close() + await self.mongocryptd_client.aclose() self.mongocryptd_client = None @@ -439,7 +439,7 @@ class _Encrypter: self._closed = True await self._auto_encrypter.close() if self._internal_client: - await self._internal_client.close() + await self._internal_client.aclose() self._internal_client = None diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 27a420bb1..a37914d6f 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -861,6 +861,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # This will be used later if we fork. AsyncMongoClient._clients[self._topology._topology_id] = self + async def aconnect(self) -> None: + """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" + await self._get_topology() + def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) # Seed the topology with the old one's pid so we can detect clients @@ -1354,13 +1358,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - await self.close() + await self.aclose() # See PYTHON-3084. __iter__ = None def __next__(self) -> NoReturn: - raise TypeError("'MongoClient' object is not iterable") + raise TypeError("'AsyncMongoClient' object is not iterable") next = __next__ @@ -1490,7 +1494,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # command. pass - async def close(self) -> None: + async def aclose(self) -> None: """Cleanup client resources and disconnect from MongoDB. End all server sessions created by this client by sending one or more diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 6b9f6231b..bd14311b5 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -860,6 +860,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self + def _connect(self) -> None: + """Explicitly connect to MongoDB synchronously instead of on the first operation.""" + self._get_topology() + def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) # Seed the topology with the old one's pid so we can detect clients diff --git a/test/__init__.py b/test/__init__.py index 2106a37af..ede3e2387 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import base64 +import contextlib import gc import multiprocessing import os @@ -39,8 +40,6 @@ from test.helpers import ( TEST_SERVERLESS, TLS_OPTIONS, SystemCertsPatcher, - _all_users, - _create_user, client_knobs, db_pwd, db_user, @@ -62,9 +61,9 @@ try: except ImportError: HAVE_IPADDRESS = False from contextlib import contextmanager -from functools import wraps +from functools import partial, wraps from test.version import Version -from typing import Any, Callable, Dict, Generator +from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -812,6 +811,12 @@ class ClientContext: func=func, ) + def require_sync(self, func): + """Run a test only if using the synchronous API.""" + return self._require( + lambda: _IS_SYNC, "This test only works with the synchronous API", func=func + ) + def mongos_seeds(self): return ",".join("{}:{}".format(*address) for address in self.mongoses) @@ -919,6 +924,32 @@ class PyMongoTestCase(unittest.TestCase): self.assertEqual(proc.exitcode, 0) +class UnitTest(PyMongoTestCase): + """Async base class for TestCases that don't require a connection to MongoDB.""" + + @classmethod + def setUpClass(cls): + if _IS_SYNC: + cls._setup_class() + else: + asyncio.run(cls._setup_class()) + + @classmethod + def tearDownClass(cls): + if _IS_SYNC: + cls._tearDown_class() + else: + asyncio.run(cls._tearDown_class()) + + @classmethod + def _setup_class(cls): + cls._setup_class() + + @classmethod + def _tearDown_class(cls): + cls._tearDown_class() + + class IntegrationTest(PyMongoTestCase): """Async base class for TestCases that need a connection to MongoDB to pass.""" @@ -933,6 +964,13 @@ class IntegrationTest(PyMongoTestCase): else: asyncio.run(cls._setup_class()) + @classmethod + def tearDownClass(cls): + if _IS_SYNC: + cls._tearDown_class() + else: + asyncio.run(cls._tearDown_class()) + @classmethod @client_context.require_connection def _setup_class(cls): @@ -947,6 +985,10 @@ class IntegrationTest(PyMongoTestCase): else: cls.credentials = {} + @classmethod + def _tearDown_class(cls): + pass + def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" for c in collections: @@ -959,7 +1001,7 @@ class IntegrationTest(PyMongoTestCase): self.addCleanup(patcher.disable) -class MockClientTest(unittest.TestCase): +class MockClientTest(UnitTest): """Base class for TestCases that use MockClient. This class is *not* an IntegrationTest: if properly written, MockClient @@ -972,8 +1014,26 @@ class MockClientTest(unittest.TestCase): # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. @classmethod - @client_context.require_no_load_balancer def setUpClass(cls): + if _IS_SYNC: + cls._setup_class() + else: + asyncio.run(cls._setup_class()) + + @classmethod + def tearDownClass(cls): + if _IS_SYNC: + cls._tearDown_class() + else: + asyncio.run(cls._tearDown_class()) + + @classmethod + @client_context.require_no_load_balancer + def _setup_class(cls): + pass + + @classmethod + def _tearDown_class(cls): pass def setUp(self): @@ -1051,3 +1111,38 @@ def print_running_clients(): processed.add(obj._topology_id) except ReferenceError: pass + + +def _all_users(db): + return {u["user"] for u in (db.command("usersInfo")).get("users", [])} + + +def _create_user(authdb, user, pwd=None, roles=None, **kwargs): + cmd = SON([("createUser", user)]) + # X509 doesn't use a password + if pwd: + cmd["pwd"] = pwd + cmd["roles"] = roles or ["root"] + cmd.update(**kwargs) + return authdb.command(cmd) + + +def connected(client): + """Convenience to wait for a newly-constructed client to connect.""" + with warnings.catch_warnings(): + # Ignore warning that ping is always routed to primary even + # if client's read preference isn't PRIMARY. + warnings.simplefilter("ignore", UserWarning) + client.admin.command("ping") # Force connection. + + return client + + +def drop_collections(db: Database): + # Drop all non-system collections in this database. + for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}): + db.drop_collection(coll) + + +def remove_all_users(db: Database): + db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w}) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index d6d2d9bee..8ccb4ad1c 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import base64 +import contextlib import gc import multiprocessing import os @@ -39,8 +40,6 @@ from test.helpers import ( TEST_SERVERLESS, TLS_OPTIONS, SystemCertsPatcher, - _all_users, - _create_user, client_knobs, db_pwd, db_user, @@ -62,9 +61,9 @@ try: except ImportError: HAVE_IPADDRESS = False from contextlib import asynccontextmanager, contextmanager -from functools import wraps +from functools import partial, wraps from test.version import Version -from typing import Any, Callable, Dict, Generator +from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -184,7 +183,7 @@ class AsyncClientContext: self.connection_attempts.append(f"failed to connect client {client!r}: {exc}") return None finally: - await client.close() + await client.aclose() async def _init_client(self): self.client = await self._connect(host, port) @@ -229,7 +228,7 @@ class AsyncClientContext: if not self.serverless and not IS_SRV: # See if db_user already exists. if not await self._check_user_provided(): - _create_user(self.client.admin, db_user, db_pwd) + await _create_user(self.client.admin, db_user, db_pwd) self.client = await self._connect( host, @@ -304,7 +303,7 @@ class AsyncClientContext: params = self.cmd_line["parsed"].get("setParameter", {}) if params.get("enableTestCommands") == "1": self.test_commands_enabled = True - self.has_ipv6 = self._server_started_with_ipv6() + self.has_ipv6 = await self._server_started_with_ipv6() self.is_mongos = (await self.hello).get("msg") == "isdbgrid" if self.is_mongos: @@ -390,7 +389,7 @@ class AsyncClientContext: ) try: - return db_user in _all_users(client.admin) + return db_user in await _all_users(client.admin) except pymongo.errors.OperationFailure as e: assert e.details is not None msg = e.details.get("errmsg", "") @@ -400,7 +399,7 @@ class AsyncClientContext: else: raise finally: - await client.close() + await client.aclose() def _server_started_with_auth(self): # MongoDB >= 2.0 @@ -482,9 +481,9 @@ class AsyncClientContext: return decorate return make_wrapper(func) - def create_user(self, dbname, user, pwd=None, roles=None, **kwargs): + async def create_user(self, dbname, user, pwd=None, roles=None, **kwargs): kwargs["writeConcern"] = {"w": self.w} - return _create_user(self.client[dbname], user, pwd, roles, **kwargs) + return await _create_user(self.client[dbname], user, pwd, roles, **kwargs) async def drop_user(self, dbname, user): await self.client[dbname].command("dropUser", user, writeConcern={"w": self.w}) @@ -814,6 +813,12 @@ class AsyncClientContext: func=func, ) + def require_sync(self, func): + """Run a test only if using the synchronous API.""" + return self._require( + lambda: _IS_SYNC, "This test only works with the synchronous API", func=func + ) + def mongos_seeds(self): return ",".join("{}:{}".format(*address) for address in self.mongoses) @@ -921,6 +926,32 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): self.assertEqual(proc.exitcode, 0) +class AsyncUnitTest(AsyncPyMongoTestCase): + """Async base class for TestCases that don't require a connection to MongoDB.""" + + @classmethod + def setUpClass(cls): + if _IS_SYNC: + cls._setup_class() + else: + asyncio.run(cls._setup_class()) + + @classmethod + def tearDownClass(cls): + if _IS_SYNC: + cls._tearDown_class() + else: + asyncio.run(cls._tearDown_class()) + + @classmethod + async def _setup_class(cls): + await cls._setup_class() + + @classmethod + async def _tearDown_class(cls): + await cls._tearDown_class() + + class AsyncIntegrationTest(AsyncPyMongoTestCase): """Async base class for TestCases that need a connection to MongoDB to pass.""" @@ -935,6 +966,13 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): else: asyncio.run(cls._setup_class()) + @classmethod + def tearDownClass(cls): + if _IS_SYNC: + cls._tearDown_class() + else: + asyncio.run(cls._tearDown_class()) + @classmethod @async_client_context.require_connection async def _setup_class(cls): @@ -949,6 +987,10 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): else: cls.credentials = {} + @classmethod + async def _tearDown_class(cls): + pass + async def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" for c in collections: @@ -961,7 +1003,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): self.addCleanup(patcher.disable) -class AsyncMockClientTest(unittest.TestCase): +class AsyncMockClientTest(AsyncUnitTest): """Base class for TestCases that use MockClient. This class is *not* an IntegrationTest: if properly written, MockClient @@ -974,8 +1016,26 @@ class AsyncMockClientTest(unittest.TestCase): # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. @classmethod - @async_client_context.require_no_load_balancer def setUpClass(cls): + if _IS_SYNC: + cls._setup_class() + else: + asyncio.run(cls._setup_class()) + + @classmethod + def tearDownClass(cls): + if _IS_SYNC: + cls._tearDown_class() + else: + asyncio.run(cls._tearDown_class()) + + @classmethod + @async_client_context.require_no_load_balancer + async def _setup_class(cls): + pass + + @classmethod + async def _tearDown_class(cls): pass def setUp(self): @@ -1015,7 +1075,7 @@ async def async_teardown(): await c.drop_database("pymongo_test2") await c.drop_database("pymongo_test_mike") await c.drop_database("pymongo_test_bernie") - await c.close() + await c.aclose() print_running_clients() @@ -1053,3 +1113,38 @@ def print_running_clients(): processed.add(obj._topology_id) except ReferenceError: pass + + +async def _all_users(db): + return {u["user"] for u in (await db.command("usersInfo")).get("users", [])} + + +async def _create_user(authdb, user, pwd=None, roles=None, **kwargs): + cmd = SON([("createUser", user)]) + # X509 doesn't use a password + if pwd: + cmd["pwd"] = pwd + cmd["roles"] = roles or ["root"] + cmd.update(**kwargs) + return await authdb.command(cmd) + + +async def connected(client): + """Convenience to wait for a newly-constructed client to connect.""" + with warnings.catch_warnings(): + # Ignore warning that ping is always routed to primary even + # if client's read preference isn't PRIMARY. + warnings.simplefilter("ignore", UserWarning) + await client.admin.command("ping") # Force connection. + + return client + + +async def drop_collections(db: AsyncDatabase): + # Drop all non-system collections in this database. + for coll in await db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}): + await db.drop_collection(coll) + + +async def remove_all_users(db: AsyncDatabase): + await db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": async_client_context.w}) diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py new file mode 100644 index 000000000..ed2395bc9 --- /dev/null +++ b/test/asynchronous/pymongo_mocks.py @@ -0,0 +1,252 @@ +# Copyright 2013-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for mocking parts of PyMongo to test other parts.""" +from __future__ import annotations + +import contextlib +import weakref +from functools import partial +from test import client_context +from test.asynchronous import async_client_context + +from pymongo import AsyncMongoClient, common +from pymongo.asynchronous.monitor import Monitor +from pymongo.asynchronous.pool import Pool +from pymongo.errors import AutoReconnect, NetworkTimeout +from pymongo.hello import Hello, HelloCompat +from pymongo.server_description import ServerDescription + +_IS_SYNC = False + + +class MockPool(Pool): + def __init__(self, client, pair, *args, **kwargs): + # MockPool gets a 'client' arg, regular pools don't. Weakref it to + # avoid cycle with __del__, causing ResourceWarnings in Python 3.3. + self.client = weakref.proxy(client) + self.mock_host, self.mock_port = pair + + # Actually connect to the default server. + Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) + + @contextlib.asynccontextmanager + async def checkout(self, handler=None): + client = self.client + host_and_port = f"{self.mock_host}:{self.mock_port}" + if host_and_port in client.mock_down_hosts: + raise AutoReconnect("mock error") + + assert host_and_port in ( + client.mock_standalones + client.mock_members + client.mock_mongoses + ), "bad host: %s" % host_and_port + + async with Pool.checkout(self, handler) as conn: + conn.mock_host = self.mock_host + conn.mock_port = self.mock_port + yield conn + + +class DummyMonitor: + def __init__(self, server_description, topology, pool, topology_settings): + self._server_description = server_description + self.opened = False + + def cancel_check(self): + pass + + def join(self): + pass + + def open(self): + self.opened = True + + def request_check(self): + pass + + def close(self): + self.opened = False + + +class AsyncMockMonitor(Monitor): + def __init__(self, client, server_description, topology, pool, topology_settings): + # MockMonitor gets a 'client' arg, regular monitors don't. Weakref it + # to avoid cycles. + self.client = weakref.proxy(client) + Monitor.__init__(self, server_description, topology, pool, topology_settings) + + async def _check_once(self): + client = self.client + address = self._server_description.address + response, rtt = client.mock_hello("%s:%d" % address) # type: ignore[str-format] + return ServerDescription(address, Hello(response), rtt) + + +class AsyncMockClient(AsyncMongoClient): + def __init__( + self, + standalones, + members, + mongoses, + hello_hosts=None, + arbiters=None, + down_hosts=None, + *args, + **kwargs, + ): + """An AsyncMongoClient connected to the default server, with a mock topology. + + standalones, members, mongoses, arbiters, and down_hosts determine the + configuration of the topology. They are formatted like ['a:1', 'b:2']. + hello_hosts provides an alternative host list for the server's + mocked hello response; see test_connect_with_internal_ips. + """ + self.mock_standalones = standalones[:] + self.mock_members = members[:] + + if self.mock_members: + self.mock_primary = self.mock_members[0] + else: + self.mock_primary = None + + # Hosts that should be considered an arbiter. + self.mock_arbiters = arbiters[:] if arbiters else [] + + if hello_hosts is not None: + self.mock_hello_hosts = hello_hosts + else: + self.mock_hello_hosts = members[:] + + self.mock_mongoses = mongoses[:] + + # Hosts that should raise socket errors. + self.mock_down_hosts = down_hosts[:] if down_hosts else [] + + # Hostname -> (min wire version, max wire version) + self.mock_wire_versions = {} + + # Hostname -> max write batch size + self.mock_max_write_batch_sizes = {} + + # Hostname -> round trip time + self.mock_rtts = {} + + kwargs["_pool_class"] = partial(MockPool, self) + kwargs["_monitor_class"] = partial(AsyncMockMonitor, self) + + client_options = async_client_context.default_client_options.copy() + client_options.update(kwargs) + + super().__init__(*args, **client_options) + + @classmethod + async def get_async_mock_client( + cls, + standalones, + members, + mongoses, + hello_hosts=None, + arbiters=None, + down_hosts=None, + *args, + **kwargs, + ): + c = AsyncMockClient( + standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs + ) + + await c.aconnect() + return c + + def kill_host(self, host): + """Host is like 'a:1'.""" + self.mock_down_hosts.append(host) + + def revive_host(self, host): + """Host is like 'a:1'.""" + self.mock_down_hosts.remove(host) + + def set_wire_version_range(self, host, min_version, max_version): + self.mock_wire_versions[host] = (min_version, max_version) + + def set_max_write_batch_size(self, host, size): + self.mock_max_write_batch_sizes[host] = size + + def mock_hello(self, host): + """Return mock hello response (a dict) and round trip time.""" + if host in self.mock_wire_versions: + min_wire_version, max_wire_version = self.mock_wire_versions[host] + else: + min_wire_version = common.MIN_SUPPORTED_WIRE_VERSION + max_wire_version = common.MAX_SUPPORTED_WIRE_VERSION + + max_write_batch_size = self.mock_max_write_batch_sizes.get( + host, common.MAX_WRITE_BATCH_SIZE + ) + + rtt = self.mock_rtts.get(host, 0) + + # host is like 'a:1'. + if host in self.mock_down_hosts: + raise NetworkTimeout("mock timeout") + + elif host in self.mock_standalones: + response = { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "minWireVersion": min_wire_version, + "maxWireVersion": max_wire_version, + "maxWriteBatchSize": max_write_batch_size, + } + elif host in self.mock_members: + primary = host == self.mock_primary + + # Simulate a replica set member. + response = { + "ok": 1, + HelloCompat.LEGACY_CMD: primary, + "secondary": not primary, + "setName": "rs", + "hosts": self.mock_hello_hosts, + "minWireVersion": min_wire_version, + "maxWireVersion": max_wire_version, + "maxWriteBatchSize": max_write_batch_size, + } + + if self.mock_primary: + response["primary"] = self.mock_primary + + if host in self.mock_arbiters: + response["arbiterOnly"] = True + response["secondary"] = False + elif host in self.mock_mongoses: + response = { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "minWireVersion": min_wire_version, + "maxWireVersion": max_wire_version, + "msg": "isdbgrid", + "maxWriteBatchSize": max_write_batch_size, + } + else: + # In test_internal_ips(), we try to connect to a host listed + # in hello['hosts'] but not publicly accessible. + raise AutoReconnect("Unknown host: %s" % host) + + return response, rtt + + def _process_periodic_tasks(self): + # Avoid the background thread causing races, e.g. a surprising + # reconnect while we're trying to test a disconnected client. + pass diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py new file mode 100644 index 000000000..051b51254 --- /dev/null +++ b/test/asynchronous/test_client.py @@ -0,0 +1,2500 @@ +# Copyright 2013-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the mongo_client module.""" +from __future__ import annotations + +import _thread as thread +import asyncio +import contextlib +import copy +import datetime +import gc +import logging +import os +import re +import signal +import socket +import struct +import subprocess +import sys +import threading +import time +from typing import Iterable, Type, no_type_check +from unittest import mock +from unittest.mock import patch + +import pytest + +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import ( + HAVE_IPADDRESS, + AsyncIntegrationTest, + AsyncMockClientTest, + AsyncUnitTest, + SkipTest, + async_client_context, + client_knobs, + connected, + db_pwd, + db_user, + remove_all_users, + unittest, +) +from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.utils import ( + NTHREADS, + CMAPListener, + FunctionCallRecorder, + async_get_pool, + async_rs_client, + async_rs_or_single_client, + async_rs_or_single_client_noauth, + async_single_client, + async_wait_until, + asyncAssertRaisesExactly, + delay, + gevent_monkey_patched, + is_greenthread_patched, + lazy_client_trial, + one, + rs_or_single_client, + wait_until, +) + +import bson +import pymongo +from bson import encode +from bson.codec_options import ( + CodecOptions, + DatetimeConversion, + TypeEncoder, + TypeRegistry, +) +from bson.son import SON +from bson.tz_util import utc +from pymongo import event_loggers, message, monitoring +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.cursor import AsyncCursor, CursorType +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.helpers import anext +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.asynchronous.pool import ( + AsyncConnection, +) +from pymongo.asynchronous.settings import TOPOLOGY_TYPE +from pymongo.asynchronous.topology import _ErrorContext +from pymongo.client_options import ClientOptions +from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT +from pymongo.compression_support import _have_snappy, _have_zstd +from pymongo.driver_info import DriverInfo +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + ConnectionFailure, + InvalidName, + InvalidOperation, + InvalidURI, + NetworkTimeout, + OperationFailure, + ServerSelectionTimeoutError, + WriteConcernError, +) +from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent +from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions +from pymongo.read_preferences import ReadPreference +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import readable_server_selector, writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.topology_description import TopologyDescription +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class AsyncClientUnitTest(AsyncUnitTest): + """AsyncMongoClient tests that don't require a server.""" + + client: AsyncMongoClient + + @classmethod + async def _setup_class(cls): + cls.client = await async_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + + @classmethod + async def _tearDown_class(cls): + await cls.client.aclose() + + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self._caplog = caplog + + def test_keyword_arg_defaults(self): + client = AsyncMongoClient( + socketTimeoutMS=None, + connectTimeoutMS=20000, + waitQueueTimeoutMS=None, + replicaSet=None, + read_preference=ReadPreference.PRIMARY, + ssl=False, + tlsCertificateKeyFile=None, + tlsAllowInvalidCertificates=True, + tlsCAFile=None, + connect=False, + serverSelectionTimeoutMS=12000, + ) + + options = client.options + pool_opts = options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + # socket.Socket.settimeout takes a float in seconds + self.assertEqual(20.0, pool_opts.connect_timeout) + self.assertEqual(None, pool_opts.wait_queue_timeout) + self.assertEqual(None, pool_opts._ssl_context) + self.assertEqual(None, options.replica_set_name) + self.assertEqual(ReadPreference.PRIMARY, client.read_preference) + self.assertAlmostEqual(12, client.options.server_selection_timeout) + + def test_connect_timeout(self): + client = AsyncMongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + pool_opts = client.options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + self.assertEqual(None, pool_opts.connect_timeout) + client = AsyncMongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + pool_opts = client.options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + self.assertEqual(None, pool_opts.connect_timeout) + client = AsyncMongoClient( + "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False + ) + pool_opts = client.options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + self.assertEqual(None, pool_opts.connect_timeout) + + def test_types(self): + self.assertRaises(TypeError, AsyncMongoClient, 1) + self.assertRaises(TypeError, AsyncMongoClient, 1.14) + self.assertRaises(TypeError, AsyncMongoClient, "localhost", "27017") + self.assertRaises(TypeError, AsyncMongoClient, "localhost", 1.14) + self.assertRaises(TypeError, AsyncMongoClient, "localhost", []) + + self.assertRaises(ConfigurationError, AsyncMongoClient, []) + + def test_max_pool_size_zero(self): + AsyncMongoClient(maxPoolSize=0) + + def test_uri_detection(self): + self.assertRaises(ConfigurationError, AsyncMongoClient, "/foo") + self.assertRaises(ConfigurationError, AsyncMongoClient, "://") + self.assertRaises(ConfigurationError, AsyncMongoClient, "foo/") + + def test_get_db(self): + def make_db(base, name): + return base[name] + + self.assertRaises(InvalidName, make_db, self.client, "") + self.assertRaises(InvalidName, make_db, self.client, "te$t") + self.assertRaises(InvalidName, make_db, self.client, "te.t") + self.assertRaises(InvalidName, make_db, self.client, "te\\t") + self.assertRaises(InvalidName, make_db, self.client, "te/t") + self.assertRaises(InvalidName, make_db, self.client, "te st") + + self.assertTrue(isinstance(self.client.test, AsyncDatabase)) + self.assertEqual(self.client.test, self.client["test"]) + self.assertEqual(self.client.test, AsyncDatabase(self.client, "test")) + + def test_get_database(self): + codec_options = CodecOptions(tz_aware=True) + write_concern = WriteConcern(w=2, j=True) + db = self.client.get_database("foo", codec_options, ReadPreference.SECONDARY, write_concern) + self.assertEqual("foo", db.name) + self.assertEqual(codec_options, db.codec_options) + self.assertEqual(ReadPreference.SECONDARY, db.read_preference) + self.assertEqual(write_concern, db.write_concern) + + def test_getattr(self): + self.assertTrue(isinstance(self.client["_does_not_exist"], AsyncDatabase)) + + with self.assertRaises(AttributeError) as context: + self.client._does_not_exist + + # Message should be: + # "AttributeError: AsyncMongoClient has no attribute '_does_not_exist'. To + # access the _does_not_exist database, use client['_does_not_exist']". + self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) + + def test_iteration(self): + client = self.client + if "PyPy" in sys.version and sys.version_info < (3, 8, 15): + msg = "'NoneType' object is not callable" + else: + msg = "'AsyncMongoClient' object is not iterable" + # Iteration fails + with self.assertRaisesRegex(TypeError, msg): + for _ in client: # type: ignore[misc] # error: "None" not callable [misc] + break + # Index fails + with self.assertRaises(TypeError): + _ = client[0] + # next fails + with self.assertRaisesRegex(TypeError, "'AsyncMongoClient' object is not iterable"): + _ = next(client) + # .next() fails + with self.assertRaisesRegex(TypeError, "'AsyncMongoClient' object is not iterable"): + _ = client.next() + # Do not implement typing.Iterable. + self.assertNotIsInstance(client, Iterable) + + async def test_get_default_database(self): + c = await async_rs_or_single_client( + "mongodb://%s:%d/foo" + % (await async_client_context.host, await async_client_context.port), + connect=False, + ) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + # Test that default doesn't override the URI value. + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("bar")) + + codec_options = CodecOptions(tz_aware=True) + write_concern = WriteConcern(w=2, j=True) + db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) + self.assertEqual("foo", db.name) + self.assertEqual(codec_options, db.codec_options) + self.assertEqual(ReadPreference.SECONDARY, db.read_preference) + self.assertEqual(write_concern, db.write_concern) + + c = await async_rs_or_single_client( + "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), + connect=False, + ) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) + + async def test_get_default_database_error(self): + # URI with no database. + c = await async_rs_or_single_client( + "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), + connect=False, + ) + self.assertRaises(ConfigurationError, c.get_default_database) + + async def test_get_default_database_with_authsource(self): + # Ensure we distinguish database name from authSource. + uri = "mongodb://%s:%d/foo?authSource=src" % ( + await async_client_context.host, + await async_client_context.port, + ) + c = await async_rs_or_single_client(uri, connect=False) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + + async def test_get_database_default(self): + c = await async_rs_or_single_client( + "mongodb://%s:%d/foo" + % (await async_client_context.host, await async_client_context.port), + connect=False, + ) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) + + async def test_get_database_default_error(self): + # URI with no database. + c = await async_rs_or_single_client( + "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), + connect=False, + ) + self.assertRaises(ConfigurationError, c.get_database) + + async def test_get_database_default_with_authsource(self): + # Ensure we distinguish database name from authSource. + uri = "mongodb://%s:%d/foo?authSource=src" % ( + await async_client_context.host, + await async_client_context.port, + ) + c = await async_rs_or_single_client(uri, connect=False) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) + + def test_primary_read_pref_with_tags(self): + # No tags allowed with "primary". + with self.assertRaises(ConfigurationError): + AsyncMongoClient("mongodb://host/?readpreferencetags=dc:east") + + with self.assertRaises(ConfigurationError): + AsyncMongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + + async def test_read_preference(self): + c = await async_rs_or_single_client( + "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode + ) + self.assertEqual(c.read_preference, ReadPreference.NEAREST) + + def test_metadata(self): + metadata = copy.deepcopy(_METADATA) + metadata["application"] = {"name": "foobar"} + client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + client = AsyncMongoClient("foo", 27017, appname="foobar", connect=False) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + # No error + AsyncMongoClient(appname="x" * 128) + self.assertRaises(ValueError, AsyncMongoClient, appname="x" * 129) + # Bad "driver" options. + self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") + self.assertRaises(TypeError, DriverInfo, version="1", platform="a") + self.assertRaises(TypeError, DriverInfo) + self.assertRaises(TypeError, AsyncMongoClient, driver=1) + self.assertRaises(TypeError, AsyncMongoClient, driver="abc") + self.assertRaises(TypeError, AsyncMongoClient, driver=("Foo", "1", "a")) + # Test appending to driver info. + metadata["driver"]["name"] = "PyMongo|FooDriver" + metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) + client = AsyncMongoClient( + "foo", + 27017, + appname="foobar", + driver=DriverInfo("FooDriver", "1.2.3", None), + connect=False, + ) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) + client = AsyncMongoClient( + "foo", + 27017, + appname="foobar", + driver=DriverInfo("FooDriver", "1.2.3", "FooPlatform"), + connect=False, + ) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + # Test truncating driver info metadata. + client = AsyncMongoClient( + driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), + connect=False, + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + client = AsyncMongoClient( + driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), + connect=False, + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + + @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) + def test_container_metadata(self): + metadata = copy.deepcopy(_METADATA) + metadata["env"] = {} + metadata["env"]["container"] = {"orchestrator": "kubernetes"} + client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + options = client.options + self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) + + def test_kwargs_codec_options(self): + class MyFloatType: + def __init__(self, x): + self.__x = x + + @property + def x(self): + return self.__x + + class MyFloatAsIntEncoder(TypeEncoder): + python_type = MyFloatType + + def transform_python(self, value): + return int(value) + + # Ensure codec options are passed in correctly + document_class: Type[SON] = SON + type_registry = TypeRegistry([MyFloatAsIntEncoder()]) + tz_aware = True + uuid_representation_label = "javaLegacy" + unicode_decode_error_handler = "ignore" + tzinfo = utc + c = AsyncMongoClient( + document_class=document_class, + type_registry=type_registry, + tz_aware=tz_aware, + uuidrepresentation=uuid_representation_label, + unicode_decode_error_handler=unicode_decode_error_handler, + tzinfo=tzinfo, + connect=False, + ) + + self.assertEqual(c.codec_options.document_class, document_class) + self.assertEqual(c.codec_options.type_registry, type_registry) + self.assertEqual(c.codec_options.tz_aware, tz_aware) + self.assertEqual( + c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) + self.assertEqual(c.codec_options.tzinfo, tzinfo) + + async def test_uri_codec_options(self): + # Ensure codec options are passed in correctly + uuid_representation_label = "javaLegacy" + unicode_decode_error_handler = "ignore" + datetime_conversion = "DATETIME_CLAMP" + uri = ( + "mongodb://%s:%d/foo?tz_aware=true&uuidrepresentation=" + "%s&unicode_decode_error_handler=%s" + "&datetime_conversion=%s" + % ( + await async_client_context.host, + await async_client_context.port, + uuid_representation_label, + unicode_decode_error_handler, + datetime_conversion, + ) + ) + c = AsyncMongoClient(uri, connect=False) + + self.assertEqual(c.codec_options.tz_aware, True) + self.assertEqual( + c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) + + # Change the passed datetime_conversion to a number and re-assert. + uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") + c = AsyncMongoClient(uri, connect=False) + + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) + + def test_uri_option_precedence(self): + # Ensure kwarg options override connection string options. + uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" + c = AsyncMongoClient( + uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" + ) + clopts = c.options + opts = clopts._options + + self.assertEqual(opts["tls"], False) + self.assertEqual(clopts.replica_set_name, "newname") + self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) + + def test_connection_timeout_ms_propagates_to_DNS_resolver(self): + # Patch the resolver. + from pymongo.srv_resolver import _resolve + + patched_resolver = FunctionCallRecorder(_resolve) + pymongo.srv_resolver._resolve = patched_resolver + + def reset_resolver(): + pymongo.srv_resolver._resolve = _resolve + + self.addCleanup(reset_resolver) + + # Setup. + base_uri = "mongodb+srv://test5.test.build.10gen.cc" + connectTimeoutMS = 5000 + expected_kw_value = 5.0 + uri_with_timeout = base_uri + "/?connectTimeoutMS=6000" + expected_uri_value = 6.0 + + def test_scenario(args, kwargs, expected_value): + patched_resolver.reset() + AsyncMongoClient(*args, **kwargs) + for _, kw in patched_resolver.call_list(): + self.assertAlmostEqual(kw["lifetime"], expected_value) + + # No timeout specified. + test_scenario((base_uri,), {}, CONNECT_TIMEOUT) + + # Timeout only specified in connection string. + test_scenario((uri_with_timeout,), {}, expected_uri_value) + + # Timeout only specified in keyword arguments. + kwarg = {"connectTimeoutMS": connectTimeoutMS} + test_scenario((base_uri,), kwarg, expected_kw_value) + + # Timeout specified in both kwargs and connection string. + test_scenario((uri_with_timeout,), kwarg, expected_kw_value) + + def test_uri_security_options(self): + # Ensure that we don't silently override security-related options. + with self.assertRaises(InvalidURI): + AsyncMongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) + + # Matching SSL and TLS options should not cause errors. + c = AsyncMongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) + self.assertEqual(c.options._options["tls"], False) + + # Conflicting tlsInsecure options should raise an error. + with self.assertRaises(InvalidURI): + AsyncMongoClient( + "mongodb://localhost/?tlsInsecure=true", + connect=False, + tlsAllowInvalidHostnames=True, + ) + + # Conflicting legacy tlsInsecure options should also raise an error. + with self.assertRaises(InvalidURI): + AsyncMongoClient( + "mongodb://localhost/?tlsInsecure=true", + connect=False, + tlsAllowInvalidCertificates=False, + ) + + # Conflicting kwargs should raise InvalidURI + with self.assertRaises(InvalidURI): + AsyncMongoClient(ssl=True, tls=False) + + def test_event_listeners(self): + c = AsyncMongoClient(event_listeners=[], connect=False) + self.assertEqual(c.options.event_listeners, []) + listeners = [ + event_loggers.CommandLogger(), + event_loggers.HeartbeatLogger(), + event_loggers.ServerLogger(), + event_loggers.TopologyLogger(), + event_loggers.ConnectionPoolLogger(), + ] + c = AsyncMongoClient(event_listeners=listeners, connect=False) + self.assertEqual(c.options.event_listeners, listeners) + + def test_client_options(self): + c = AsyncMongoClient(connect=False) + self.assertIsInstance(c.options, ClientOptions) + self.assertIsInstance(c.options.pool_options, PoolOptions) + self.assertEqual(c.options.server_selection_timeout, 30) + self.assertEqual(c.options.pool_options.max_idle_time_seconds, None) + self.assertIsInstance(c.options.retry_writes, bool) + self.assertIsInstance(c.options.retry_reads, bool) + + def test_validate_suggestion(self): + """Validate kwargs in constructor.""" + for typo in ["auth", "Auth", "AUTH"]: + expected = f"Unknown option: {typo}. Did you mean one of (authsource, authmechanism, authoidcallowedhosts) or maybe a camelCase version of one? Refer to docstring." + expected = re.escape(expected) + with self.assertRaisesRegex(ConfigurationError, expected): + AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type] + + @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + def test_detected_environment_logging(self, mock_get_hosts): + normal_hosts = [ + "normal.host.com", + "host.cosmos.azure.com", + "host.docdb.amazonaws.com", + "host.docdb-elastic.amazonaws.com", + ] + srv_hosts = ["mongodb+srv://:@" + s for s in normal_hosts] + multi_host = ( + "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" + ) + with self.assertLogs("pymongo", level="INFO") as cm: + for host in normal_hosts: + AsyncMongoClient(host) + for host in srv_hosts: + mock_get_hosts.return_value = [(host, 1)] + AsyncMongoClient(host) + AsyncMongoClient(multi_host) + logs = [record.message for record in cm.records if record.name == "pymongo.client"] + self.assertEqual(len(logs), 7) + + @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + def test_detected_environment_warning(self, mock_get_hosts): + with self._caplog.at_level(logging.WARN): + normal_hosts = [ + "host.cosmos.azure.com", + "host.docdb.amazonaws.com", + "host.docdb-elastic.amazonaws.com", + ] + srv_hosts = ["mongodb+srv://:@" + s for s in normal_hosts] + multi_host = ( + "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" + ) + for host in normal_hosts: + with self.assertWarns(UserWarning): + AsyncMongoClient(host) + for host in srv_hosts: + mock_get_hosts.return_value = [(host, 1)] + with self.assertWarns(UserWarning): + AsyncMongoClient(host) + with self.assertWarns(UserWarning): + AsyncMongoClient(multi_host) + + +class TestClient(AsyncIntegrationTest): + def test_multiple_uris(self): + with self.assertRaises(ConfigurationError): + AsyncMongoClient( + host=[ + "mongodb+srv://cluster-a.abc12.mongodb.net", + "mongodb+srv://cluster-b.abc12.mongodb.net", + "mongodb+srv://cluster-c.abc12.mongodb.net", + ] + ) + + async def test_max_idle_time_reaper_default(self): + with client_knobs(kill_cursor_frequency=0.1): + # Assert reaper doesn't remove connections when maxIdleTimeMS not set + client = await async_rs_or_single_client() + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + async with server._pool.checkout() as conn: + pass + self.assertEqual(1, len(server._pool.conns)) + self.assertTrue(conn in server._pool.conns) + await client.aclose() + + async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): + with client_knobs(kill_cursor_frequency=0.1): + # Assert reaper removes idle socket and replaces it with a new one + client = await async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + async with server._pool.checkout() as conn: + pass + # When the reaper runs at the same time as the get_socket, two + # connections could be created and checked into the pool. + self.assertGreaterEqual(len(server._pool.conns), 1) + wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") + await client.aclose() + + async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): + with client_knobs(kill_cursor_frequency=0.1): + # Assert reaper respects maxPoolSize when adding new connections. + client = await async_rs_or_single_client( + maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1 + ) + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + async with server._pool.checkout() as conn: + pass + # When the reaper runs at the same time as the get_socket, + # maxPoolSize=1 should prevent two connections from being created. + self.assertEqual(1, len(server._pool.conns)) + wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") + await client.aclose() + + async def test_max_idle_time_reaper_removes_stale(self): + with client_knobs(kill_cursor_frequency=0.1): + # Assert reaper has removed idle socket and NOT replaced it + client = await async_rs_or_single_client(maxIdleTimeMS=500) + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + async with server._pool.checkout() as conn_one: + pass + # Assert that the pool does not close connections prematurely. + await asyncio.sleep(0.300) + async with server._pool.checkout() as conn_two: + pass + self.assertIs(conn_one, conn_two) + wait_until( + lambda: len(server._pool.conns) == 0, + "stale socket reaped and new one NOT added to the pool", + ) + await client.aclose() + + async def test_min_pool_size(self): + with client_knobs(kill_cursor_frequency=0.1): + client = await async_rs_or_single_client() + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + self.assertEqual(0, len(server._pool.conns)) + + # Assert that pool started up at minPoolSize + client = await async_rs_or_single_client(minPoolSize=10) + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + wait_until( + lambda: len(server._pool.conns) == 10, + "pool initialized with 10 connections", + ) + + # Assert that if a socket is closed, a new one takes its place + async with server._pool.checkout() as conn: + conn.close_conn(None) + wait_until( + lambda: len(server._pool.conns) == 10, + "a closed socket gets replaced from the pool", + ) + self.assertFalse(conn in server._pool.conns) + + async def test_max_idle_time_checkout(self): + # Use high frequency to test _get_socket_no_auth. + with client_knobs(kill_cursor_frequency=99999999): + client = await async_rs_or_single_client(maxIdleTimeMS=500) + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + async with server._pool.checkout() as conn: + pass + self.assertEqual(1, len(server._pool.conns)) + await asyncio.sleep(1) # Sleep so that the socket becomes stale. + + async with server._pool.checkout() as new_con: + self.assertNotEqual(conn, new_con) + self.assertEqual(1, len(server._pool.conns)) + self.assertFalse(conn in server._pool.conns) + self.assertTrue(new_con in server._pool.conns) + + # Test that connections are reused if maxIdleTimeMS is not set. + client = await async_rs_or_single_client() + server = await (await client._get_topology()).select_server( + readable_server_selector, _Op.TEST + ) + async with server._pool.checkout() as conn: + pass + self.assertEqual(1, len(server._pool.conns)) + await asyncio.sleep(1) + async with server._pool.checkout() as new_con: + self.assertEqual(conn, new_con) + self.assertEqual(1, len(server._pool.conns)) + + async def test_constants(self): + """This test uses AsyncMongoClient explicitly to make sure that host and + port are not overloaded. + """ + host, port = await async_client_context.host, await async_client_context.port + kwargs: dict = async_client_context.default_client_options.copy() + if async_client_context.auth_enabled: + kwargs["username"] = db_user + kwargs["password"] = db_pwd + + # Set bad defaults. + AsyncMongoClient.HOST = "somedomainthatdoesntexist.org" + AsyncMongoClient.PORT = 123456789 + with self.assertRaises(AutoReconnect): + await connected(AsyncMongoClient(serverSelectionTimeoutMS=10, **kwargs)) + + # Override the defaults. No error. + await connected(AsyncMongoClient(host, port, **kwargs)) + + # Set good defaults. + AsyncMongoClient.HOST = host + AsyncMongoClient.PORT = port + + # No error. + await connected(AsyncMongoClient(**kwargs)) + + async def test_init_disconnected(self): + host, port = await async_client_context.host, await async_client_context.port + c = await async_rs_or_single_client(connect=False) + # is_primary causes client to block until connected + self.assertIsInstance(await c.is_primary, bool) + + c = await async_rs_or_single_client(connect=False) + self.assertIsInstance(await c.is_mongos, bool) + c = await async_rs_or_single_client(connect=False) + self.assertIsInstance(c.options.pool_options.max_pool_size, int) + self.assertIsInstance(c.nodes, frozenset) + + c = await async_rs_or_single_client(connect=False) + self.assertEqual(c.codec_options, CodecOptions()) + c = await async_rs_or_single_client(connect=False) + self.assertFalse(await c.primary) + self.assertFalse(await c.secondaries) + c = await async_rs_or_single_client(connect=False) + self.assertIsInstance(c.topology_description, TopologyDescription) + self.assertEqual(c.topology_description, c._topology._description) + self.assertIsNone(await c.address) # PYTHON-2981 + await c.admin.command("ping") # connect + if async_client_context.is_rs: + # The primary's host and port are from the replica set config. + self.assertIsNotNone(await c.address) + else: + self.assertEqual(await c.address, (host, port)) + + bad_host = "somedomainthatdoesntexist.org" + c = AsyncMongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + with self.assertRaises(ConnectionFailure): + await c.pymongo_test.test.find_one() + + async def test_init_disconnected_with_auth(self): + uri = "mongodb://user:pass@somedomainthatdoesntexist" + c = AsyncMongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + with self.assertRaises(ConnectionFailure): + await c.pymongo_test.test.find_one() + + async def test_equality(self): + seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) + c = await async_rs_or_single_client(seed, connect=False) + self.addAsyncCleanup(c.aclose) + self.assertEqual(async_client_context.client, c) + # Explicitly test inequality + self.assertFalse(async_client_context.client != c) + + c = await async_rs_or_single_client("invalid.com", connect=False) + self.addAsyncCleanup(c.aclose) + self.assertNotEqual(async_client_context.client, c) + self.assertTrue(async_client_context.client != c) + # Seeds differ: + self.assertNotEqual( + AsyncMongoClient("a", connect=False), AsyncMongoClient("b", connect=False) + ) + # Same seeds but out of order still compares equal: + self.assertEqual( + AsyncMongoClient(["a", "b", "c"], connect=False), + AsyncMongoClient(["c", "a", "b"], connect=False), + ) + + async def test_hashable(self): + seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) + c = await async_rs_or_single_client(seed, connect=False) + self.addAsyncCleanup(c.aclose) + self.assertIn(c, {async_client_context.client}) + c = await async_rs_or_single_client("invalid.com", connect=False) + self.addAsyncCleanup(c.aclose) + self.assertNotIn(c, {async_client_context.client}) + + async def test_host_w_port(self): + with self.assertRaises(ValueError): + host = await async_client_context.host + await connected( + AsyncMongoClient( + f"{host}:1234567", + connectTimeoutMS=1, + serverSelectionTimeoutMS=10, + ) + ) + + def test_repr(self): + # Used to test 'eval' below. + import bson + + client = AsyncMongoClient( # type: ignore[type-var] + "mongodb://localhost:27017,localhost:27018/?replicaSet=replset" + "&connectTimeoutMS=12345&w=1&wtimeoutms=100", + connect=False, + document_class=SON, + ) + + the_repr = repr(client) + self.assertIn("AsyncMongoClient(host=", the_repr) + self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) + self.assertIn("connecttimeoutms=12345", the_repr) + self.assertIn("replicaset='replset'", the_repr) + self.assertIn("w=1", the_repr) + self.assertIn("wtimeoutms=100", the_repr) + + self.assertEqual(eval(the_repr), client) + + client = AsyncMongoClient( + "localhost:27017,localhost:27018", + replicaSet="replset", + connectTimeoutMS=12345, + socketTimeoutMS=None, + w=1, + wtimeoutms=100, + connect=False, + ) + the_repr = repr(client) + self.assertIn("AsyncMongoClient(host=", the_repr) + self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) + self.assertIn("connecttimeoutms=12345", the_repr) + self.assertIn("replicaset='replset'", the_repr) + self.assertIn("sockettimeoutms=None", the_repr) + self.assertIn("w=1", the_repr) + self.assertIn("wtimeoutms=100", the_repr) + + self.assertEqual(eval(the_repr), client) + + def test_getters(self): + wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") + + async def test_list_databases(self): + cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] + cursor = await self.client.list_databases() + self.assertIsInstance(cursor, AsyncCommandCursor) + helper_docs = await cursor.to_list() + self.assertTrue(len(helper_docs) > 0) + self.assertEqual(len(helper_docs), len(cmd_docs)) + # PYTHON-3529 Some fields may change between calls, just compare names. + for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): + self.assertIs(type(helper_doc), dict) + self.assertEqual(helper_doc.keys(), cmd_doc.keys()) + client = await async_rs_or_single_client(document_class=SON) + self.addAsyncCleanup(client.aclose) + async for doc in await client.list_databases(): + self.assertIs(type(doc), dict) + + await self.client.pymongo_test.test.insert_one({}) + cursor = await self.client.list_databases(filter={"name": "admin"}) + docs = await cursor.to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(docs[0]["name"], "admin") + + cursor = await self.client.list_databases(nameOnly=True) + async for doc in cursor: + self.assertEqual(["name"], list(doc)) + + async def test_list_database_names(self): + await self.client.pymongo_test.test.insert_one({"dummy": "object"}) + await self.client.pymongo_test_mike.test.insert_one({"dummy": "object"}) + cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] + cmd_names = [doc["name"] for doc in cmd_docs] + + db_names = await self.client.list_database_names() + self.assertTrue("pymongo_test" in db_names) + self.assertTrue("pymongo_test_mike" in db_names) + self.assertEqual(db_names, cmd_names) + + async def test_drop_database(self): + with self.assertRaises(TypeError): + await self.client.drop_database(5) # type: ignore[arg-type] + with self.assertRaises(TypeError): + await self.client.drop_database(None) # type: ignore[arg-type] + + await self.client.pymongo_test.test.insert_one({"dummy": "object"}) + await self.client.pymongo_test2.test.insert_one({"dummy": "object"}) + dbs = await self.client.list_database_names() + self.assertIn("pymongo_test", dbs) + self.assertIn("pymongo_test2", dbs) + await self.client.drop_database("pymongo_test") + + if async_client_context.is_rs: + wc_client = await async_rs_or_single_client(w=len(async_client_context.nodes) + 1) + with self.assertRaises(WriteConcernError): + await wc_client.drop_database("pymongo_test2") + + await self.client.drop_database(self.client.pymongo_test2) + dbs = await self.client.list_database_names() + self.assertNotIn("pymongo_test", dbs) + self.assertNotIn("pymongo_test2", dbs) + + async def test_close(self): + test_client = await async_rs_or_single_client() + coll = test_client.pymongo_test.bar + await test_client.aclose() + with self.assertRaises(InvalidOperation): + await coll.count_documents({}) + + async def test_close_kills_cursors(self): + if sys.platform.startswith("java"): + # We can't figure out how to make this test reliable with Jython. + raise SkipTest("Can't test with Jython") + test_client = await async_rs_or_single_client() + # Kill any cursors possibly queued up by previous tests. + gc.collect() + await test_client._process_periodic_tasks() + + # Add some test data. + coll = test_client.pymongo_test.test_close_kills_cursors + docs_inserted = 1000 + await coll.insert_many([{"i": i} for i in range(docs_inserted)]) + + # Open a cursor and leave it open on the server. + cursor = (await coll.find()).batch_size(10) + self.assertTrue(bool(await anext(cursor))) + self.assertLess(cursor.retrieved, docs_inserted) + + # Open a command cursor and leave it open on the server. + cursor = await coll.aggregate([], batchSize=10) + self.assertTrue(bool(await anext(cursor))) + del cursor + # Required for PyPy, Jython and other Python implementations that + # don't use reference counting garbage collection. + gc.collect() + + # Close the client and ensure the topology is closed. + self.assertTrue(test_client._topology._opened) + await test_client.aclose() + self.assertFalse(test_client._topology._opened) + test_client = await async_rs_or_single_client() + # The killCursors task should not need to re-open the topology. + await test_client._process_periodic_tasks() + self.assertTrue(test_client._topology._opened) + + async def test_close_stops_kill_cursors_thread(self): + client = await async_rs_client() + await client.test.test.find_one() + self.assertFalse(client._kill_cursors_executor._stopped) + + # Closing the client should stop the thread. + await client.aclose() + self.assertTrue(client._kill_cursors_executor._stopped) + + # Reusing the closed client should raise an InvalidOperation error. + with self.assertRaises(InvalidOperation): + await client.admin.command("ping") + # Thread is still stopped. + self.assertTrue(client._kill_cursors_executor._stopped) + + async def test_uri_connect_option(self): + # Ensure that topology is not opened if connect=False. + client = await async_rs_client(connect=False) + self.assertFalse(client._topology._opened) + + # Ensure kill cursors thread has not been started. + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + + # Using the client should open topology and start the thread. + await client.admin.command("ping") + self.assertTrue(client._topology._opened) + kc_thread = client._kill_cursors_executor._thread + self.assertTrue(kc_thread and kc_thread.is_alive()) + + # Tear down. + await client.aclose() + + async def test_close_does_not_open_servers(self): + client = await async_rs_client(connect=False) + topology = client._topology + self.assertEqual(topology._servers, {}) + await client.aclose() + self.assertEqual(topology._servers, {}) + + async def test_close_closes_sockets(self): + client = await async_rs_client() + self.addAsyncCleanup(client.aclose) + await client.test.test.find_one() + topology = client._topology + await client.aclose() + for server in topology._servers.values(): + self.assertFalse(server._pool.conns) + self.assertTrue(server._monitor._executor._stopped) + self.assertTrue(server._monitor._rtt_monitor._executor._stopped) + self.assertFalse(server._monitor._pool.conns) + self.assertFalse(server._monitor._rtt_monitor._pool.conns) + + def test_bad_uri(self): + with self.assertRaises(InvalidURI): + AsyncMongoClient("http://localhost") + + @async_client_context.require_auth + @async_client_context.require_no_fips + async def test_auth_from_uri(self): + host, port = await async_client_context.host, await async_client_context.port + await async_client_context.create_user("admin", "admin", "pass") + self.addAsyncCleanup(async_client_context.drop_user, "admin", "admin") + self.addAsyncCleanup(remove_all_users, self.client.pymongo_test) + + await async_client_context.create_user( + "pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"] + ) + + with self.assertRaises(OperationFailure): + await connected( + await async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) + ) + + # No error. + await connected( + await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + ) + + # Wrong database. + uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) + with self.assertRaises(OperationFailure): + await connected(await async_rs_or_single_client_noauth(uri)) + + # No error. + await connected( + await async_rs_or_single_client_noauth( + "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) + ) + ) + + # Auth with lazy connection. + await ( + await async_rs_or_single_client_noauth( + "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False + ) + ).pymongo_test.test.find_one() + + # Wrong password. + bad_client = await async_rs_or_single_client_noauth( + "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False + ) + + with self.assertRaises(OperationFailure): + await bad_client.pymongo_test.test.find_one() + + @async_client_context.require_auth + async def test_username_and_password(self): + await async_client_context.create_user("admin", "ad min", "pa/ss") + self.addAsyncCleanup(async_client_context.drop_user, "admin", "ad min") + + c = await async_rs_or_single_client_noauth(username="ad min", password="pa/ss") + + # Username and password aren't in strings that will likely be logged. + self.assertNotIn("ad min", repr(c)) + self.assertNotIn("ad min", str(c)) + self.assertNotIn("pa/ss", repr(c)) + self.assertNotIn("pa/ss", str(c)) + + # Auth succeeds. + await c.server_info() + + with self.assertRaises(OperationFailure): + await ( + await async_rs_or_single_client_noauth(username="ad min", password="foo") + ).server_info() + + @async_client_context.require_auth + @async_client_context.require_no_fips + async def test_lazy_auth_raises_operation_failure(self): + host = await async_client_context.host + lazy_client = await async_rs_or_single_client_noauth( + f"mongodb://user:wrong@{host}/pymongo_test", connect=False + ) + + await asyncAssertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one) + + @async_client_context.require_no_tls + async def test_unix_socket(self): + if not hasattr(socket, "AF_UNIX"): + raise SkipTest("UNIX-sockets are not supported on this system") + + mongodb_socket = "/tmp/mongodb-%d.sock" % (await async_client_context.port,) + encoded_socket = "%2Ftmp%2F" + "mongodb-%d.sock" % (await async_client_context.port,) + if not os.access(mongodb_socket, os.R_OK): + raise SkipTest("Socket file is not accessible") + + uri = "mongodb://%s" % encoded_socket + # Confirm we can do operations via the socket. + client = await async_rs_or_single_client(uri) + self.addAsyncCleanup(client.aclose) + await client.pymongo_test.test.insert_one({"dummy": "object"}) + dbs = await client.list_database_names() + self.assertTrue("pymongo_test" in dbs) + + self.assertTrue(mongodb_socket in repr(client)) + + # Confirm it fails with a missing socket. + with self.assertRaises(ConnectionFailure): + await connected( + AsyncMongoClient( + "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 + ), + ) + + async def test_document_class(self): + c = self.client + db = c.pymongo_test + await db.test.insert_one({"x": 1}) + + self.assertEqual(dict, c.codec_options.document_class) + self.assertTrue(isinstance(await db.test.find_one(), dict)) + self.assertFalse(isinstance(await db.test.find_one(), SON)) + + c = await async_rs_or_single_client(document_class=SON) + self.addAsyncCleanup(c.aclose) + db = c.pymongo_test + + self.assertEqual(SON, c.codec_options.document_class) + self.assertTrue(isinstance(await db.test.find_one(), SON)) + + async def test_timeouts(self): + client = await async_rs_or_single_client( + connectTimeoutMS=10500, + socketTimeoutMS=10500, + maxIdleTimeMS=10500, + serverSelectionTimeoutMS=10500, + ) + self.assertEqual(10.5, (await async_get_pool(client)).opts.connect_timeout) + self.assertEqual(10.5, (await async_get_pool(client)).opts.socket_timeout) + self.assertEqual(10.5, (await async_get_pool(client)).opts.max_idle_time_seconds) + self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds) + self.assertEqual(10.5, client.options.server_selection_timeout) + + async def test_socket_timeout_ms_validation(self): + c = await async_rs_or_single_client(socketTimeoutMS=10 * 1000) + self.assertEqual(10, (await async_get_pool(c)).opts.socket_timeout) + + c = await connected(await async_rs_or_single_client(socketTimeoutMS=None)) + self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) + + c = await connected(await async_rs_or_single_client(socketTimeoutMS=0)) + self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) + + with self.assertRaises(ValueError): + await async_rs_or_single_client(socketTimeoutMS=-1) + + with self.assertRaises(ValueError): + await async_rs_or_single_client(socketTimeoutMS=1e10) + + with self.assertRaises(ValueError): + await async_rs_or_single_client(socketTimeoutMS="foo") + + async def test_socket_timeout(self): + no_timeout = self.client + timeout_sec = 1 + timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) + self.addAsyncCleanup(timeout.aclose) + + await no_timeout.pymongo_test.drop_collection("test") + await no_timeout.pymongo_test.test.insert_one({"x": 1}) + + # A $where clause that takes a second longer than the timeout + where_func = delay(timeout_sec + 1) + + async def get_x(db): + doc = await anext((await db.test.find()).where(where_func)) + return doc["x"] + + self.assertEqual(1, await get_x(no_timeout.pymongo_test)) + with self.assertRaises(NetworkTimeout): + await get_x(timeout.pymongo_test) + + def test_server_selection_timeout(self): + client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) + self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + + client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) + + self.assertAlmostEqual(0, client.options.server_selection_timeout) + + self.assertRaises( + ValueError, AsyncMongoClient, serverSelectionTimeoutMS="foo", connect=False + ) + self.assertRaises(ValueError, AsyncMongoClient, serverSelectionTimeoutMS=-1, connect=False) + self.assertRaises( + ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False + ) + + client = AsyncMongoClient( + "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False + ) + self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + + client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) + self.assertAlmostEqual(0, client.options.server_selection_timeout) + + # Test invalid timeout in URI ignored and set to default. + client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) + self.assertAlmostEqual(30, client.options.server_selection_timeout) + + client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) + self.assertAlmostEqual(30, client.options.server_selection_timeout) + + async def test_waitQueueTimeoutMS(self): + client = await async_rs_or_single_client(waitQueueTimeoutMS=2000) + self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) + + async def test_socketKeepAlive(self): + pool = await async_get_pool(self.client) + async with pool.checkout() as conn: + keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + self.assertTrue(keepalive) + + @no_type_check + async def test_tz_aware(self): + self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo") + + aware = await async_rs_or_single_client(tz_aware=True) + self.addAsyncCleanup(aware.aclose) + naive = self.client + await aware.pymongo_test.drop_collection("test") + + now = datetime.datetime.now(tz=datetime.timezone.utc) + await aware.pymongo_test.test.insert_one({"x": now}) + + self.assertEqual(None, (await naive.pymongo_test.test.find_one())["x"].tzinfo) + self.assertEqual(utc, (await aware.pymongo_test.test.find_one())["x"].tzinfo) + self.assertEqual( + (await aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None), + (await naive.pymongo_test.test.find_one())["x"], + ) + + @async_client_context.require_ipv6 + async def test_ipv6(self): + if async_client_context.tls: + if not HAVE_IPADDRESS: + raise SkipTest("Need the ipaddress module to test with SSL") + + if async_client_context.auth_enabled: + auth_str = f"{db_user}:{db_pwd}@" + else: + auth_str = "" + + uri = "mongodb://%s[::1]:%d" % (auth_str, await async_client_context.port) + if async_client_context.is_rs: + uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") + + client = await async_rs_or_single_client_noauth(uri) + self.addAsyncCleanup(client.aclose) + await client.pymongo_test.test.insert_one({"dummy": "object"}) + await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) + + dbs = await client.list_database_names() + self.assertTrue("pymongo_test" in dbs) + self.assertTrue("pymongo_test_bernie" in dbs) + + async def test_contextlib(self): + client = await async_rs_or_single_client() + await client.pymongo_test.drop_collection("test") + await client.pymongo_test.test.insert_one({"foo": "bar"}) + + # The socket used for the previous commands has been returned to the + # pool + self.assertEqual(1, len((await async_get_pool(client)).conns)) + + # contextlib async support was added in Python 3.10 + if _IS_SYNC or sys.version_info >= (3, 10): + async with contextlib.aclosing(client): + self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) + with self.assertRaises(InvalidOperation): + await client.pymongo_test.test.find_one() + client = await async_rs_or_single_client() + async with client as client: + self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) + with self.assertRaises(InvalidOperation): + await client.pymongo_test.test.find_one() + + @async_client_context.require_sync + def test_interrupt_signal(self): + if sys.platform.startswith("java"): + # We can't figure out how to raise an exception on a thread that's + # blocked on a socket, whether that's the main thread or a worker, + # without simply killing the whole thread in Jython. This suggests + # PYTHON-294 can't actually occur in Jython. + raise SkipTest("Can't test interrupts in Jython") + if is_greenthread_patched(): + raise SkipTest("Can't reliably test interrupts with green threads") + + # Test fix for PYTHON-294 -- make sure AsyncMongoClient closes its + # socket if it gets an interrupt while waiting to recv() from it. + db = self.client.pymongo_test + + # A $where clause which takes 1.5 sec to execute + where = delay(1.5) + + # Need exactly 1 document so find() will execute its $where clause once + db.drop_collection("foo") + db.foo.insert_one({"_id": 1}) + + old_signal_handler = None + try: + # Platform-specific hacks for raising a KeyboardInterrupt on the + # main thread while find() is in-progress: On Windows, SIGALRM is + # unavailable so we use a second thread. In our Evergreen setup on + # Linux, the thread technique causes an error in the test at + # conn.recv(): TypeError: 'int' object is not callable + # We don't know what causes this, so we hack around it. + + if sys.platform == "win32": + + def interrupter(): + # Raises KeyboardInterrupt in the main thread + time.sleep(0.25) + thread.interrupt_main() + + thread.start_new_thread(interrupter, ()) + else: + # Convert SIGALRM to SIGINT -- it's hard to schedule a SIGINT + # for one second in the future, but easy to schedule SIGALRM. + def sigalarm(num, frame): + raise KeyboardInterrupt + + old_signal_handler = signal.signal(signal.SIGALRM, sigalarm) + signal.alarm(1) + + raised = False + try: + # Will be interrupted by a KeyboardInterrupt. + next(db.foo.find({"$where": where})) # type: ignore[call-overload] + except KeyboardInterrupt: + raised = True + + # Can't use self.assertRaises() because it doesn't catch system + # exceptions + self.assertTrue(raised, "Didn't raise expected KeyboardInterrupt") + + # Raises AssertionError due to PYTHON-294 -- Mongo's response to + # the previous find() is still waiting to be read on the socket, + # so the request id's don't match. + self.assertEqual({"_id": 1}, next(db.foo.find())) # type: ignore[call-overload] + finally: + if old_signal_handler: + signal.signal(signal.SIGALRM, old_signal_handler) + + async def test_operation_failure(self): + # Ensure AsyncMongoClient doesn't close socket after it gets an error + # response to getLastError. PYTHON-395. We need a new client here + # to avoid race conditions caused by replica set failover or idle + # socket reaping. + client = await async_single_client() + self.addAsyncCleanup(client.aclose) + await client.pymongo_test.test.find_one() + pool = await async_get_pool(client) + socket_count = len(pool.conns) + self.assertGreaterEqual(socket_count, 1) + old_conn = next(iter(pool.conns)) + await client.pymongo_test.test.drop() + await client.pymongo_test.test.insert_one({"_id": "foo"}) + with self.assertRaises(OperationFailure): + await client.pymongo_test.test.insert_one({"_id": "foo"}) + + self.assertEqual(socket_count, len(pool.conns)) + new_con = next(iter(pool.conns)) + self.assertEqual(old_conn, new_con) + + async def test_lazy_connect_w0(self): + # Ensure that connect-on-demand works when the first operation is + # an unacknowledged write. This exercises _writable_max_wire_version(). + + # Use a separate collection to avoid races where we're still + # completing an operation on a collection while the next test begins. + await async_client_context.client.drop_database("test_lazy_connect_w0") + self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") + + client = await async_rs_or_single_client(connect=False, w=0) + self.addAsyncCleanup(client.aclose) + await client.test_lazy_connect_w0.test.insert_one({}) + + async def predicate(): + return await client.test_lazy_connect_w0.test.count_documents({}) == 1 + + await async_wait_until(predicate, "find one document") + + client = await async_rs_or_single_client(connect=False, w=0) + self.addAsyncCleanup(client.aclose) + await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) + + async def predicate(): + return (await client.test_lazy_connect_w0.test.find_one()).get("x") == 1 + + await async_wait_until(predicate, "update one document") + + client = await async_rs_or_single_client(connect=False, w=0) + self.addAsyncCleanup(client.aclose) + await client.test_lazy_connect_w0.test.delete_one({}) + + async def predicate(): + return await client.test_lazy_connect_w0.test.count_documents({}) == 0 + + await async_wait_until(predicate, "delete one document") + + @async_client_context.require_no_mongos + async def test_exhaust_network_error(self): + # When doing an exhaust query, the socket stays checked out on success + # but must be checked in on error to avoid semaphore leaks. + client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False) + self.addAsyncCleanup(client.aclose) + collection = client.pymongo_test.test + pool = await async_get_pool(client) + pool._check_interval_seconds = None # Never check. + + # Ensure a socket. + await connected(client) + + # Cause a network error. + conn = one(pool.conns) + conn.conn.close() + cursor = await collection.find(cursor_type=CursorType.EXHAUST) + with self.assertRaises(ConnectionFailure): + await anext(cursor) + + self.assertTrue(conn.closed) + + # The semaphore was decremented despite the error. + self.assertEqual(0, pool.requests) + + @async_client_context.require_auth + async def test_auth_network_error(self): + # Make sure there's no semaphore leak if we get a network error + # when authenticating a new socket with cached credentials. + + # Get a client with one socket so we detect if it's leaked. + c = await connected( + await async_rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) + ) + + # Cause a network error on the actual socket. + pool = await async_get_pool(c) + conn = one(pool.conns) + conn.conn.close() + + # AsyncConnection.authenticate logs, but gets a socket.error. Should be + # reraised as AutoReconnect. + with self.assertRaises(AutoReconnect): + await c.test.collection.find_one() + + # No semaphore leak, the pool is allowed to make a new socket. + await c.test.collection.find_one() + + @async_client_context.require_no_replica_set + async def test_connect_to_standalone_using_replica_set_name(self): + client = await async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) + + with self.assertRaises(AutoReconnect): + await client.test.test.find_one() + + @async_client_context.require_replica_set + async def test_stale_getmore(self): + # A cursor is created, but its member goes down and is removed from + # the topology before the getMore message is sent. Test that + # AsyncMongoClient._run_operation_with_response handles the error. + with self.assertRaises(AutoReconnect): + client = await async_rs_client(connect=False, serverSelectionTimeoutMS=100) + await client._run_operation( + operation=message._GetMore( + "pymongo_test", + "collection", + 101, + 1234, + client.codec_options, + ReadPreference.PRIMARY, + None, + client, + None, + None, + False, + None, + ), + unpack_res=AsyncCursor(client.pymongo_test.collection)._unpack_response, + address=("not-a-member", 27017), + ) + + async def test_heartbeat_frequency_ms(self): + class HeartbeatStartedListener(ServerHeartbeatListener): + def __init__(self): + self.results = [] + + def started(self, event): + self.results.append(event) + + def succeeded(self, event): + pass + + def failed(self, event): + pass + + old_init = ServerHeartbeatStartedEvent.__init__ + heartbeat_times = [] + + def init(self, *args): + old_init(self, *args) + heartbeat_times.append(time.time()) + + try: + ServerHeartbeatStartedEvent.__init__ = init # type: ignore + listener = HeartbeatStartedListener() + uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % ( + await async_client_context.host, + await async_client_context.port, + ) + client = await async_single_client(uri, event_listeners=[listener]) + wait_until( + lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" + ) + + # Default heartbeatFrequencyMS is 10 sec. Check the interval was + # closer to 0.5 sec with heartbeatFrequencyMS configured. + self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) + + await client.aclose() + finally: + ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore + + def test_small_heartbeat_frequency_ms(self): + uri = "mongodb://example/?heartbeatFrequencyMS=499" + with self.assertRaises(ConfigurationError) as context: + AsyncMongoClient(uri) + + self.assertIn("heartbeatFrequencyMS", str(context.exception)) + + async def test_compression(self): + def compression_settings(client): + pool_options = client.options.pool_options + return pool_options._compression_settings + + uri = "mongodb://localhost:27017/?compressors=zlib" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, 4) + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017/?compressors=foobar" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017/?compressors=foobar,zlib" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) + + # According to the connection string spec, unsupported values + # just raise a warning and are ignored. + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) + uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) + + if not _have_snappy(): + uri = "mongodb://localhost:27017/?compressors=snappy" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + else: + uri = "mongodb://localhost:27017/?compressors=snappy" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy"]) + uri = "mongodb://localhost:27017/?compressors=snappy,zlib" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy", "zlib"]) + + if not _have_zstd(): + uri = "mongodb://localhost:27017/?compressors=zstd" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + else: + uri = "mongodb://localhost:27017/?compressors=zstd" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd"]) + uri = "mongodb://localhost:27017/?compressors=zstd,zlib" + client = AsyncMongoClient(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd", "zlib"]) + + options = async_client_context.default_client_options + if "compressors" in options and "zlib" in options["compressors"]: + for level in range(-1, 10): + client = await async_single_client(zlibcompressionlevel=level) + # No error + await client.pymongo_test.test.find_one() + + async def test_reset_during_update_pool(self): + client = await async_rs_or_single_client(minPoolSize=10) + self.addAsyncCleanup(client.aclose) + await client.admin.command("ping") + pool = await async_get_pool(client) + generation = pool.gen.get_overall() + + # Continuously reset the pool. + class ResetPoolThread(threading.Thread): + def __init__(self, pool): + super().__init__() + self.running = True + self.pool = pool + + def stop(self): + self.running = False + + async def _run(self): + while self.running: + exc = AutoReconnect("mock pool error") + ctx = _ErrorContext(exc, 0, pool.gen.get_overall(), False, None) + await client._topology.handle_error(pool.address, ctx) + await asyncio.sleep(0.001) + + def run(self): + if _IS_SYNC: + self._run() + else: + asyncio.run(self._run()) + + t = ResetPoolThread(pool) + t.start() + + # Ensure that update_pool completes without error even when the pool + # is reset concurrently. + try: + while True: + for _ in range(10): + await client._topology.update_pool() + if generation != pool.gen.get_overall(): + break + finally: + t.stop() + t.join() + await client.admin.command("ping") + + async def test_background_connections_do_not_hold_locks(self): + min_pool_size = 10 + client = await async_rs_or_single_client( + serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False + ) + self.addAsyncCleanup(client.aclose) + + # Create a single connection in the pool. + await client.admin.command("ping") + + # Cause new connections stall for a few seconds. + pool = await async_get_pool(client) + original_connect = pool.connect + + def stall_connect(*args, **kwargs): + time.sleep(2) + return original_connect(*args, **kwargs) + + pool.connect = stall_connect + # Un-patch Pool.connect to break the cyclic reference. + self.addCleanup(delattr, pool, "connect") + + # Wait for the background thread to start creating connections + wait_until(lambda: len(pool.conns) > 1, "start creating connections") + + # Assert that application operations do not block. + for _ in range(10): + start = time.monotonic() + await client.admin.command("ping") + total = time.monotonic() - start + # Each ping command should not take more than 2 seconds + self.assertLess(total, 2) + + @async_client_context.require_replica_set + async def test_direct_connection(self): + # direct_connection=True should result in Single topology. + client = await async_rs_or_single_client(directConnection=True) + await client.admin.command("ping") + self.assertEqual(len(client.nodes), 1) + self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) + await client.aclose() + + # direct_connection=False should result in RS topology. + client = await async_rs_or_single_client(directConnection=False) + await client.admin.command("ping") + self.assertGreaterEqual(len(client.nodes), 1) + self.assertIn( + client._topology_settings.get_topology_type(), + [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], + ) + await client.aclose() + + # directConnection=True, should error with multiple hosts as a list. + with self.assertRaises(ConfigurationError): + AsyncMongoClient(["host1", "host2"], directConnection=True) + + @unittest.skipIf("PyPy" in sys.version, "PYTHON-2927 fails often on PyPy") + async def test_continuous_network_errors(self): + def server_description_count(): + i = 0 + for obj in gc.get_objects(): + try: + if isinstance(obj, ServerDescription): + i += 1 + except ReferenceError: + pass + return i + + gc.collect() + with client_knobs(min_heartbeat_interval=0.003): + client = AsyncMongoClient( + "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 + ) + initial_count = server_description_count() + self.addAsyncCleanup(client.aclose) + with self.assertRaises(ServerSelectionTimeoutError): + await client.test.test.find_one() + gc.collect() + final_count = server_description_count() + # If a bug like PYTHON-2433 is reintroduced then too many + # ServerDescriptions will be kept alive and this test will fail: + # AssertionError: 19 != 46 within 15 delta (27 difference) + # On Python 3.11 we seem to get more of a delta. + self.assertAlmostEqual(initial_count, final_count, delta=20) + + @async_client_context.require_failCommand_fail_point + async def test_network_error_message(self): + client = await async_single_client(retryReads=False) + self.addAsyncCleanup(client.aclose) + await client.admin.command("ping") # connect + async with self.fail_point( + {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} + ): + assert await client.address is not None + expected = "{}:{}: ".format(*(await client.address)) + with self.assertRaisesRegex(AutoReconnect, expected): + await client.pymongo_test.test.find_one({}) + + @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") + async def test_process_periodic_tasks(self): + client = await async_rs_or_single_client() + coll = client.db.collection + await coll.insert_many([{} for _ in range(5)]) + cursor = await coll.find(batch_size=2) + await cursor.next() + c_id = cursor.cursor_id + self.assertIsNotNone(c_id) + await client.aclose() + # Add cursor to kill cursors queue + del cursor + wait_until( + lambda: client._kill_cursors_queue, + "waited for cursor to be added to queue", + ) + await client._process_periodic_tasks() # This must not raise or print any exceptions + with self.assertRaises(InvalidOperation): + await coll.insert_many([{} for _ in range(5)]) + + def test_service_name_from_kwargs(self): + client = AsyncMongoClient( + "mongodb+srv://user:password@test22.test.build.10gen.cc", + srvServiceName="customname", + connect=False, + ) + self.assertEqual(client._topology_settings.srv_service_name, "customname") + client = AsyncMongoClient( + "mongodb+srv://user:password@test22.test.build.10gen.cc" + "/?srvServiceName=shouldbeoverriden", + srvServiceName="customname", + connect=False, + ) + self.assertEqual(client._topology_settings.srv_service_name, "customname") + client = AsyncMongoClient( + "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", + connect=False, + ) + self.assertEqual(client._topology_settings.srv_service_name, "customname") + + def test_srv_max_hosts_kwarg(self): + client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") + self.assertGreater(len(client.topology_description.server_descriptions()), 1) + client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + self.assertEqual(len(client.topology_description.server_descriptions()), 1) + client = AsyncMongoClient( + "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 + ) + self.assertEqual(len(client.topology_description.server_descriptions()), 2) + + @unittest.skipIf( + async_client_context.load_balancer or async_client_context.serverless, + "loadBalanced clients do not run SDAM", + ) + @unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP") + @async_client_context.require_sync + def test_sigstop_sigcont(self): + test_dir = os.path.dirname(os.path.realpath(__file__)) + script = os.path.join(test_dir, "sigstop_sigcont.py") + p = subprocess.Popen( + [sys.executable, script, async_client_context.uri], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + self.addCleanup(p.wait, timeout=1) + self.addCleanup(p.kill) + time.sleep(1) + # Stop the child, sleep for twice the streaming timeout + # (heartbeatFrequencyMS + connectTimeoutMS), and restart. + os.kill(p.pid, signal.SIGSTOP) + time.sleep(2) + os.kill(p.pid, signal.SIGCONT) + time.sleep(0.5) + # Tell the script to exit gracefully. + outs, _ = p.communicate(input=b"q\n", timeout=10) + self.assertTrue(outs) + log_output = outs.decode("utf-8") + self.assertIn("TEST STARTED", log_output) + self.assertIn("ServerHeartbeatStartedEvent", log_output) + self.assertIn("ServerHeartbeatSucceededEvent", log_output) + self.assertIn("TEST COMPLETED", log_output) + self.assertNotIn("ServerHeartbeatFailedEvent", log_output) + + async def _test_handshake(self, env_vars, expected_env): + with patch.dict("os.environ", env_vars): + metadata = copy.deepcopy(_METADATA) + if expected_env is not None: + metadata["env"] = expected_env + + if "AWS_REGION" not in env_vars: + os.environ["AWS_REGION"] = "" + async with await async_rs_or_single_client(serverSelectionTimeoutMS=10000) as client: + await client.admin.command("ping") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + + async def test_handshake_01_aws(self): + await self._test_handshake( + { + "AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", + "AWS_REGION": "us-east-2", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", + }, + {"name": "aws.lambda", "region": "us-east-2", "memory_mb": 1024}, + ) + + async def test_handshake_02_azure(self): + await self._test_handshake({"FUNCTIONS_WORKER_RUNTIME": "python"}, {"name": "azure.func"}) + + async def test_handshake_03_gcp(self): + await self._test_handshake( + { + "K_SERVICE": "servicename", + "FUNCTION_MEMORY_MB": "1024", + "FUNCTION_TIMEOUT_SEC": "60", + "FUNCTION_REGION": "us-central1", + }, + {"name": "gcp.func", "region": "us-central1", "memory_mb": 1024, "timeout_sec": 60}, + ) + # Extra case for FUNCTION_NAME. + await self._test_handshake( + { + "FUNCTION_NAME": "funcname", + "FUNCTION_MEMORY_MB": "1024", + "FUNCTION_TIMEOUT_SEC": "60", + "FUNCTION_REGION": "us-central1", + }, + {"name": "gcp.func", "region": "us-central1", "memory_mb": 1024, "timeout_sec": 60}, + ) + + async def test_handshake_04_vercel(self): + await self._test_handshake( + {"VERCEL": "1", "VERCEL_REGION": "cdg1"}, {"name": "vercel", "region": "cdg1"} + ) + + async def test_handshake_05_multiple(self): + await self._test_handshake( + {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "FUNCTIONS_WORKER_RUNTIME": "python"}, + None, + ) + # Extra cases for other combos. + await self._test_handshake( + {"FUNCTIONS_WORKER_RUNTIME": "python", "K_SERVICE": "servicename"}, + None, + ) + await self._test_handshake({"K_SERVICE": "servicename", "VERCEL": "1"}, None) + + async def test_handshake_06_region_too_long(self): + await self._test_handshake( + {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "AWS_REGION": "a" * 512}, + {"name": "aws.lambda"}, + ) + + async def test_handshake_07_memory_invalid_int(self): + await self._test_handshake( + {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big"}, + {"name": "aws.lambda"}, + ) + + async def test_handshake_08_invalid_aws_ec2(self): + # AWS_EXECUTION_ENV needs to start with "AWS_Lambda_". + await self._test_handshake( + {"AWS_EXECUTION_ENV": "EC2"}, + None, + ) + + async def test_dict_hints(self): + await self.db.t.find(hint={"x": 1}) + + async def test_dict_hints_sort(self): + result = await self.db.t.find() + result.sort({"x": 1}) + + await self.db.t.find(sort={"x": 1}) + + async def test_dict_hints_create_index(self): + await self.db.t.create_index({"x": pymongo.ASCENDING}) + + +class TestExhaustCursor(AsyncIntegrationTest): + """Test that clients properly handle errors from exhaust cursors.""" + + def setUp(self): + super().setUp() + if async_client_context.is_mongos: + raise SkipTest("mongos doesn't support exhaust, SERVER-2627") + + async def test_exhaust_query_server_error(self): + # When doing an exhaust query, the socket stays checked out on success + # but must be checked in on error to avoid semaphore leaks. + client = await connected(await async_rs_or_single_client(maxPoolSize=1)) + + collection = client.pymongo_test.test + pool = await async_get_pool(client) + conn = one(pool.conns) + + # This will cause OperationFailure in all mongo versions since + # the value for $orderby must be a document. + cursor = await collection.find( + SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST + ) + + with self.assertRaises(OperationFailure): + await cursor.next() + self.assertFalse(conn.closed) + + # The socket was checked in and the semaphore was decremented. + self.assertIn(conn, pool.conns) + self.assertEqual(0, pool.requests) + + async def test_exhaust_getmore_server_error(self): + # When doing a getmore on an exhaust cursor, the socket stays checked + # out on success but it's checked in on error to avoid semaphore leaks. + client = await async_rs_or_single_client(maxPoolSize=1) + collection = client.pymongo_test.test + await collection.drop() + + await collection.insert_many([{} for _ in range(200)]) + self.addAsyncCleanup(async_client_context.client.pymongo_test.test.drop) + + pool = await async_get_pool(client) + pool._check_interval_seconds = None # Never check. + conn = one(pool.conns) + + cursor = await collection.find(cursor_type=CursorType.EXHAUST) + + # Initial query succeeds. + await cursor.next() + + # Cause a server error on getmore. + async def receive_message(request_id): + # Discard the actual server response. + await AsyncConnection.receive_message(conn, request_id) + + # responseFlags bit 1 is QueryFailure. + msg = struct.pack(" 0) self.assertEqual(len(helper_docs), len(cmd_docs)) # PYTHON-3529 Some fields may change between calls, just compare names. @@ -900,7 +920,7 @@ class TestClient(IntegrationTest): self.client.pymongo_test.test.insert_one({}) cursor = self.client.list_databases(filter={"name": "admin"}) - docs = list(cursor) + docs = cursor.to_list() self.assertEqual(1, len(docs)) self.assertEqual(docs[0]["name"], "admin") @@ -911,7 +931,7 @@ class TestClient(IntegrationTest): def test_list_database_names(self): self.client.pymongo_test.test.insert_one({"dummy": "object"}) self.client.pymongo_test_mike.test.insert_one({"dummy": "object"}) - cmd_docs = self.client.admin.command("listDatabases")["databases"] + cmd_docs = (self.client.admin.command("listDatabases"))["databases"] cmd_names = [doc["name"] for doc in cmd_docs] db_names = self.client.list_database_names() @@ -920,8 +940,10 @@ class TestClient(IntegrationTest): self.assertEqual(db_names, cmd_names) def test_drop_database(self): - self.assertRaises(TypeError, self.client.drop_database, 5) - self.assertRaises(TypeError, self.client.drop_database, None) + with self.assertRaises(TypeError): + self.client.drop_database(5) # type: ignore[arg-type] + with self.assertRaises(TypeError): + self.client.drop_database(None) # type: ignore[arg-type] self.client.pymongo_test.test.insert_one({"dummy": "object"}) self.client.pymongo_test2.test.insert_one({"dummy": "object"}) @@ -944,7 +966,8 @@ class TestClient(IntegrationTest): test_client = rs_or_single_client() coll = test_client.pymongo_test.bar test_client.close() - self.assertRaises(InvalidOperation, coll.count_documents, {}) + with self.assertRaises(InvalidOperation): + coll.count_documents({}) def test_close_kills_cursors(self): if sys.platform.startswith("java"): @@ -961,7 +984,7 @@ class TestClient(IntegrationTest): coll.insert_many([{"i": i} for i in range(docs_inserted)]) # Open a cursor and leave it open on the server. - cursor = coll.find().batch_size(10) + cursor = (coll.find()).batch_size(10) self.assertTrue(bool(next(cursor))) self.assertLess(cursor.retrieved, docs_inserted) @@ -992,7 +1015,8 @@ class TestClient(IntegrationTest): self.assertTrue(client._kill_cursors_executor._stopped) # Reusing the closed client should raise an InvalidOperation error. - self.assertRaises(InvalidOperation, client.admin.command, "ping") + with self.assertRaises(InvalidOperation): + client.admin.command("ping") # Thread is still stopped. self.assertTrue(client._kill_cursors_executor._stopped) @@ -1065,8 +1089,10 @@ class TestClient(IntegrationTest): ) # Auth with lazy connection. - rs_or_single_client_noauth( - "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False + ( + rs_or_single_client_noauth( + "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False + ) ).pymongo_test.test.find_one() # Wrong password. @@ -1074,7 +1100,8 @@ class TestClient(IntegrationTest): "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) - self.assertRaises(OperationFailure, bad_client.pymongo_test.test.find_one) + with self.assertRaises(OperationFailure): + bad_client.pymongo_test.test.find_one() @client_context.require_auth def test_username_and_password(self): @@ -1093,13 +1120,14 @@ class TestClient(IntegrationTest): c.server_info() with self.assertRaises(OperationFailure): - rs_or_single_client_noauth(username="ad min", password="foo").server_info() + (rs_or_single_client_noauth(username="ad min", password="foo")).server_info() @client_context.require_auth @client_context.require_no_fips def test_lazy_auth_raises_operation_failure(self): + host = client_context.host lazy_client = rs_or_single_client_noauth( - f"mongodb://user:wrong@{client_context.host}/pymongo_test", connect=False + f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one) @@ -1125,11 +1153,10 @@ class TestClient(IntegrationTest): self.assertTrue(mongodb_socket in repr(client)) # Confirm it fails with a missing socket. - self.assertRaises( - ConnectionFailure, - connected, - MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100), - ) + with self.assertRaises(ConnectionFailure): + connected( + MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100), + ) def test_document_class(self): c = self.client @@ -1154,27 +1181,30 @@ class TestClient(IntegrationTest): maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500, ) - self.assertEqual(10.5, get_pool(client).opts.connect_timeout) - self.assertEqual(10.5, get_pool(client).opts.socket_timeout) - self.assertEqual(10.5, get_pool(client).opts.max_idle_time_seconds) + self.assertEqual(10.5, (get_pool(client)).opts.connect_timeout) + self.assertEqual(10.5, (get_pool(client)).opts.socket_timeout) + self.assertEqual(10.5, (get_pool(client)).opts.max_idle_time_seconds) self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds) self.assertEqual(10.5, client.options.server_selection_timeout) def test_socket_timeout_ms_validation(self): c = rs_or_single_client(socketTimeoutMS=10 * 1000) - self.assertEqual(10, get_pool(c).opts.socket_timeout) + self.assertEqual(10, (get_pool(c)).opts.socket_timeout) c = connected(rs_or_single_client(socketTimeoutMS=None)) - self.assertEqual(None, get_pool(c).opts.socket_timeout) + self.assertEqual(None, (get_pool(c)).opts.socket_timeout) c = connected(rs_or_single_client(socketTimeoutMS=0)) - self.assertEqual(None, get_pool(c).opts.socket_timeout) + self.assertEqual(None, (get_pool(c)).opts.socket_timeout) - self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=-1) + with self.assertRaises(ValueError): + rs_or_single_client(socketTimeoutMS=-1) - self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=1e10) + with self.assertRaises(ValueError): + rs_or_single_client(socketTimeoutMS=1e10) - self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS="foo") + with self.assertRaises(ValueError): + rs_or_single_client(socketTimeoutMS="foo") def test_socket_timeout(self): no_timeout = self.client @@ -1189,11 +1219,12 @@ class TestClient(IntegrationTest): where_func = delay(timeout_sec + 1) def get_x(db): - doc = next(db.test.find().where(where_func)) + doc = next((db.test.find()).where(where_func)) return doc["x"] self.assertEqual(1, get_x(no_timeout.pymongo_test)) - self.assertRaises(NetworkTimeout, get_x, timeout.pymongo_test) + with self.assertRaises(NetworkTimeout): + get_x(timeout.pymongo_test) def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) @@ -1224,7 +1255,7 @@ class TestClient(IntegrationTest): def test_waitQueueTimeoutMS(self): client = rs_or_single_client(waitQueueTimeoutMS=2000) - self.assertEqual(get_pool(client).opts.wait_queue_timeout, 2) + self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) def test_socketKeepAlive(self): pool = get_pool(self.client) @@ -1244,11 +1275,11 @@ class TestClient(IntegrationTest): now = datetime.datetime.now(tz=datetime.timezone.utc) aware.pymongo_test.test.insert_one({"x": now}) - self.assertEqual(None, naive.pymongo_test.test.find_one()["x"].tzinfo) - self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo) + self.assertEqual(None, (naive.pymongo_test.test.find_one())["x"].tzinfo) + self.assertEqual(utc, (aware.pymongo_test.test.find_one())["x"].tzinfo) self.assertEqual( - aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None), - naive.pymongo_test.test.find_one()["x"], + (aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None), + (naive.pymongo_test.test.find_one())["x"], ) @client_context.require_ipv6 @@ -1282,18 +1313,21 @@ class TestClient(IntegrationTest): # The socket used for the previous commands has been returned to the # pool - self.assertEqual(1, len(get_pool(client).conns)) + self.assertEqual(1, len((get_pool(client)).conns)) - with contextlib.closing(client): - self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) - with self.assertRaises(InvalidOperation): - client.pymongo_test.test.find_one() - client = rs_or_single_client() - with client as client: - self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) - with self.assertRaises(InvalidOperation): - client.pymongo_test.test.find_one() + # contextlib async support was added in Python 3.10 + if _IS_SYNC or sys.version_info >= (3, 10): + with contextlib.closing(client): + self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) + with self.assertRaises(InvalidOperation): + client.pymongo_test.test.find_one() + client = rs_or_single_client() + with client as client: + self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) + with self.assertRaises(InvalidOperation): + client.pymongo_test.test.find_one() + @client_context.require_sync def test_interrupt_signal(self): if sys.platform.startswith("java"): # We can't figure out how to raise an exception on a thread that's @@ -1344,7 +1378,7 @@ class TestClient(IntegrationTest): raised = False try: # Will be interrupted by a KeyboardInterrupt. - next(db.foo.find({"$where": where})) + next(db.foo.find({"$where": where})) # type: ignore[call-overload] except KeyboardInterrupt: raised = True @@ -1355,7 +1389,7 @@ class TestClient(IntegrationTest): # Raises AssertionError due to PYTHON-294 -- Mongo's response to # the previous find() is still waiting to be read on the socket, # so the request id's don't match. - self.assertEqual({"_id": 1}, next(db.foo.find())) + self.assertEqual({"_id": 1}, next(db.foo.find())) # type: ignore[call-overload] finally: if old_signal_handler: signal.signal(signal.SIGALRM, old_signal_handler) @@ -1374,7 +1408,8 @@ class TestClient(IntegrationTest): old_conn = next(iter(pool.conns)) client.pymongo_test.test.drop() client.pymongo_test.test.insert_one({"_id": "foo"}) - self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"}) + with self.assertRaises(OperationFailure): + client.pymongo_test.test.insert_one({"_id": "foo"}) self.assertEqual(socket_count, len(pool.conns)) new_con = next(iter(pool.conns)) @@ -1392,23 +1427,29 @@ class TestClient(IntegrationTest): client = rs_or_single_client(connect=False, w=0) self.addCleanup(client.close) client.test_lazy_connect_w0.test.insert_one({}) - wait_until( - lambda: client.test_lazy_connect_w0.test.count_documents({}) == 1, "find one document" - ) + + def predicate(): + return client.test_lazy_connect_w0.test.count_documents({}) == 1 + + wait_until(predicate, "find one document") client = rs_or_single_client(connect=False, w=0) self.addCleanup(client.close) client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) - wait_until( - lambda: client.test_lazy_connect_w0.test.find_one().get("x") == 1, "update one document" - ) + + def predicate(): + return (client.test_lazy_connect_w0.test.find_one()).get("x") == 1 + + wait_until(predicate, "update one document") client = rs_or_single_client(connect=False, w=0) self.addCleanup(client.close) client.test_lazy_connect_w0.test.delete_one({}) - wait_until( - lambda: client.test_lazy_connect_w0.test.count_documents({}) == 0, "delete one document" - ) + + def predicate(): + return client.test_lazy_connect_w0.test.count_documents({}) == 0 + + wait_until(predicate, "delete one document") @client_context.require_no_mongos def test_exhaust_network_error(self): @@ -1445,12 +1486,13 @@ class TestClient(IntegrationTest): # Cause a network error on the actual socket. pool = get_pool(c) - socket_info = one(pool.conns) - socket_info.conn.close() + conn = one(pool.conns) + conn.conn.close() # Connection.authenticate logs, but gets a socket.error. Should be # reraised as AutoReconnect. - self.assertRaises(AutoReconnect, c.test.collection.find_one) + with self.assertRaises(AutoReconnect): + c.test.collection.find_one() # No semaphore leak, the pool is allowed to make a new socket. c.test.collection.find_one() @@ -1638,13 +1680,19 @@ class TestClient(IntegrationTest): def stop(self): self.running = False - def run(self): + def _run(self): while self.running: exc = AutoReconnect("mock pool error") ctx = _ErrorContext(exc, 0, pool.gen.get_overall(), False, None) client._topology.handle_error(pool.address, ctx) time.sleep(0.001) + def run(self): + if _IS_SYNC: + self._run() + else: + asyncio.run(self._run()) + t = ResetPoolThread(pool) t.start() @@ -1755,7 +1803,7 @@ class TestClient(IntegrationTest): {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} ): assert client.address is not None - expected = "{}:{}: ".format(*client.address) + expected = "{}:{}: ".format(*(client.address)) with self.assertRaisesRegex(AutoReconnect, expected): client.pymongo_test.test.find_one({}) @@ -1814,6 +1862,7 @@ class TestClient(IntegrationTest): "loadBalanced clients do not run SDAM", ) @unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP") + @client_context.require_sync def test_sigstop_sigcont(self): test_dir = os.path.dirname(os.path.realpath(__file__)) script = os.path.join(test_dir, "sigstop_sigcont.py") @@ -1961,7 +2010,8 @@ class TestExhaustCursor(IntegrationTest): SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST ) - self.assertRaises(OperationFailure, cursor.next) + with self.assertRaises(OperationFailure): + cursor.next() self.assertFalse(conn.closed) # The socket was checked in and the semaphore was decremented. @@ -1998,7 +2048,8 @@ class TestExhaustCursor(IntegrationTest): return message._OpReply.unpack(msg) conn.receive_message = receive_message - self.assertRaises(OperationFailure, list, cursor) + with self.assertRaises(OperationFailure): + cursor.to_list() # Unpatch the instance. del conn.receive_message @@ -2019,7 +2070,8 @@ class TestExhaustCursor(IntegrationTest): conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) - self.assertRaises(ConnectionFailure, cursor.next) + with self.assertRaises(ConnectionFailure): + cursor.next() self.assertTrue(conn.closed) # The socket was closed and the semaphore was decremented. @@ -2046,7 +2098,8 @@ class TestExhaustCursor(IntegrationTest): conn.conn.close() # A getmore fails. - self.assertRaises(ConnectionFailure, list, cursor) + with self.assertRaises(ConnectionFailure): + cursor.to_list() self.assertTrue(conn.closed) wait_until( @@ -2057,6 +2110,7 @@ class TestExhaustCursor(IntegrationTest): self.assertNotIn(conn, pool.conns) self.assertEqual(0, pool.requests) + @client_context.require_sync def test_gevent_task(self): if not gevent_monkey_patched(): raise SkipTest("Must be running monkey patched by gevent") @@ -2070,6 +2124,7 @@ class TestExhaustCursor(IntegrationTest): task.kill() self.assertTrue(task.dead) + @client_context.require_sync def test_gevent_timeout(self): if not gevent_monkey_patched(): raise SkipTest("Must be running monkey patched by gevent") @@ -2101,6 +2156,7 @@ class TestExhaustCursor(IntegrationTest): self.assertIsNone(tt.get()) self.assertIsNone(ct.get()) + @client_context.require_sync def test_gevent_timeout_when_creating_connection(self): if not gevent_monkey_patched(): raise SkipTest("Must be running monkey patched by gevent") @@ -2145,6 +2201,7 @@ class TestClientLazyConnect(IntegrationTest): def _get_client(self): return rs_or_single_client(connect=False) + @client_context.require_sync def test_insert_one(self): def reset(collection): collection.drop() @@ -2157,6 +2214,7 @@ class TestClientLazyConnect(IntegrationTest): lazy_client_trial(reset, insert_one, test, self._get_client) + @client_context.require_sync def test_update_one(self): def reset(collection): collection.drop() @@ -2171,6 +2229,7 @@ class TestClientLazyConnect(IntegrationTest): lazy_client_trial(reset, update_one, test, self._get_client) + @client_context.require_sync def test_delete_one(self): def reset(collection): collection.drop() @@ -2184,6 +2243,7 @@ class TestClientLazyConnect(IntegrationTest): lazy_client_trial(reset, delete_one, test, self._get_client) + @client_context.require_sync def test_find_one(self): results: list = [] @@ -2203,7 +2263,7 @@ class TestClientLazyConnect(IntegrationTest): class TestMongoClientFailover(MockClientTest): def test_discover_primary(self): - c = MockClient( + c = MockClient.get_mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2219,13 +2279,17 @@ class TestMongoClientFailover(MockClientTest): # Fail over. c.kill_host("a:1") c.mock_primary = "b:2" - wait_until(lambda: c.address == ("b", 2), "wait for server address to be updated") + + def predicate(): + return (c.address) == ("b", 2) + + wait_until(predicate, "wait for server address to be updated") # a:1 not longer in nodes. self.assertLess(len(c.nodes), 3) def test_reconnect(self): # Verify the node list isn't forgotten during a network failure. - c = MockClient( + c = MockClient.get_mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2245,12 +2309,13 @@ class TestMongoClientFailover(MockClientTest): # MongoClient discovers it's alone. The first attempt raises either # ServerSelectionTimeoutError or AutoReconnect (from - # MockPool.get_socket). - self.assertRaises(AutoReconnect, c.db.collection.find_one) + # AsyncMockPool.get_socket). + with self.assertRaises(AutoReconnect): + c.db.collection.find_one() # But it can reconnect. c.revive_host("a:1") - c._get_topology().select_servers(writable_server_selector, _Op.TEST) + (c._get_topology()).select_servers(writable_server_selector, _Op.TEST) self.assertEqual(c.address, ("a", 1)) def _test_network_error(self, operation_callback): @@ -2273,7 +2338,7 @@ class TestMongoClientFailover(MockClientTest): # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, 6) c.set_wire_version_range("b:2", 2, 7) - c._get_topology().select_servers(writable_server_selector, _Op.TEST) + (c._get_topology()).select_servers(writable_server_selector, _Op.TEST) wait_until(lambda: len(c.nodes) == 2, "connect") c.kill_host("a:1") @@ -2281,17 +2346,18 @@ class TestMongoClientFailover(MockClientTest): # MongoClient is disconnected from the primary. This raises either # ServerSelectionTimeoutError or AutoReconnect (from # MockPool.get_socket). - self.assertRaises(AutoReconnect, operation_callback, c) + with self.assertRaises(AutoReconnect): + operation_callback(c) # The primary's description is reset. - server_a = c._get_topology().get_server_by_address(("a", 1)) + server_a = (c._get_topology()).get_server_by_address(("a", 1)) sd_a = server_a.description self.assertEqual(SERVER_TYPE.Unknown, sd_a.server_type) self.assertEqual(0, sd_a.min_wire_version) self.assertEqual(0, sd_a.max_wire_version) # ...but not the secondary's. - server_b = c._get_topology().get_server_by_address(("b", 2)) + server_b = (c._get_topology()).get_server_by_address(("b", 2)) sd_b = server_b.description self.assertEqual(SERVER_TYPE.RSSecondary, sd_b.server_type) self.assertEqual(2, sd_b.min_wire_version) @@ -2332,7 +2398,7 @@ class TestClientPool(MockClientTest): @client_context.require_connection def test_rs_client_does_not_maintain_pool_to_arbiters(self): listener = CMAPListener() - c = MockClient( + c = MockClient.get_mock_client( standalones=[], members=["a:1", "b:2", "c:3", "d:4"], mongoses=[], @@ -2363,7 +2429,7 @@ class TestClientPool(MockClientTest): @client_context.require_connection def test_direct_client_maintains_pool_to_arbiter(self): listener = CMAPListener() - c = MockClient( + c = MockClient.get_mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], diff --git a/test/test_common.py b/test/test_common.py index fdd4513d0..358cd29b8 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -20,8 +20,8 @@ import uuid sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest -from test.utils import connected, rs_or_single_client, single_client +from test import IntegrationTest, client_context, connected, unittest +from test.utils import rs_or_single_client, single_client from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation from bson.codec_options import CodecOptions diff --git a/test/test_crud_v1.py b/test/test_crud_v1.py index d528a1dfe..12b3b8864 100644 --- a/test/test_crud_v1.py +++ b/test/test_crud_v1.py @@ -20,13 +20,12 @@ import sys sys.path[0:0] = [""] -from test import IntegrationTest, unittest +from test import IntegrationTest, drop_collections, unittest from test.utils import ( SpecTestCreator, camel_to_snake, camel_to_snake_args, camel_to_upper_camel, - drop_collections, ) from pymongo import WriteConcern, operations diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index f39a1cb03..7bc822546 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -22,9 +22,9 @@ from pymongo.operations import _Op sys.path[0:0] = [""] -from test import MockClientTest, client_context, unittest +from test import MockClientTest, client_context, connected, unittest from test.pymongo_mocks import MockClient -from test.utils import connected, wait_until +from test.utils import wait_until from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector diff --git a/test/test_monitor.py b/test/test_monitor.py index 3bf610294..fd82fc1ca 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -22,10 +22,9 @@ from functools import partial sys.path[0:0] = [""] -from test import IntegrationTest, unittest +from test import IntegrationTest, connected, unittest from test.utils import ( ServerAndTopologyEventListener, - connected, single_client, wait_until, ) diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 4df791e94..2cd3195f4 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -26,10 +26,9 @@ from pymongo.operations import _Op sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, unittest +from test import IntegrationTest, SkipTest, client_context, connected, unittest from test.utils import ( OvertCommandListener, - connected, one, rs_client, single_client, diff --git a/test/test_ssl.py b/test/test_ssl.py index 3b307df39..5b3855a82 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -21,13 +21,19 @@ import sys sys.path[0:0] = [""] -from test import HAVE_IPADDRESS, IntegrationTest, SkipTest, client_context, unittest +from test import ( + HAVE_IPADDRESS, + IntegrationTest, + SkipTest, + client_context, + connected, + remove_all_users, + unittest, +) from test.utils import ( EventListener, cat_files, - connected, ignore_deprecations, - remove_all_users, ) from urllib.parse import quote_plus diff --git a/test/utils.py b/test/utils.py index 98666e271..d4dddc844 100644 --- a/test/utils.py +++ b/test/utils.py @@ -621,7 +621,10 @@ async def _async_mongo_client(host, port, authenticate=True, directConnection=No ): client_options["username"] = db_user client_options["password"] = db_pwd - return AsyncMongoClient(uri, port, **client_options) + client = AsyncMongoClient(uri, port, **client_options) + if client._options.connect: + await client.aconnect() + return client def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: @@ -843,16 +846,6 @@ def server_started_with_auth(client): return "--auth" in argv or "--keyFile" in argv -def drop_collections(db): - # Drop all non-system collections in this database. - for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}): - db.drop_collection(coll) - - -def remove_all_users(db): - db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w}) - - def joinall(threads): """Join threads with a 5-minute timeout, assert joins succeeded""" for t in threads: @@ -860,17 +853,6 @@ def joinall(threads): assert not t.is_alive(), "Thread %s hung" % t -def connected(client): - """Convenience to wait for a newly-constructed client to connect.""" - with warnings.catch_warnings(): - # Ignore warning that ping is always routed to primary even - # if client's read preference isn't PRIMARY. - warnings.simplefilter("ignore", UserWarning) - client.admin.command("ping") # Force connection. - - return client - - def wait_until(predicate, success_description, timeout=10): """Wait up to 10 seconds (by default) for predicate to be true. @@ -957,6 +939,20 @@ def assertRaisesExactly(cls, fn, *args, **kwargs): raise AssertionError("%s not raised" % cls) +async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs): + """ + Unlike the standard assertRaises, this checks that a function raises a + specific class of exception, and not a subclass. E.g., check that + MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. + """ + try: + await fn(*args, **kwargs) + except Exception as e: + assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" + else: + raise AssertionError("%s not raised" % cls) + + @contextlib.contextmanager def _ignore_deprecations(): with warnings.catch_warnings(): diff --git a/tools/synchro.py b/tools/synchro.py index 5468e4932..0c608e4e6 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -72,11 +72,20 @@ replacements = { "addAsyncCleanup": "addCleanup", "async_setup_class": "setup_class", "IsolatedAsyncioTestCase": "TestCase", + "AsyncUnitTest": "UnitTest", + "AsyncMockClient": "MockClient", "async_get_pool": "get_pool", "async_is_mongos": "is_mongos", "async_rs_or_single_client": "rs_or_single_client", + "async_rs_or_single_client_noauth": "rs_or_single_client_noauth", + "async_rs_client": "rs_client", "async_single_client": "single_client", "async_from_client": "from_client", + "aclosing": "closing", + "asyncAssertRaisesExactly": "assertRaisesExactly", + "get_async_mock_client": "get_mock_client", + "aconnect": "_connect", + "aclose": "close", } docstring_replacements: dict[tuple[str, str], str] = { @@ -131,6 +140,8 @@ sync_gridfs_files = [ converted_tests = [ "__init__.py", "conftest.py", + "pymongo_mocks.py", + "test_client.py", "test_collection.py", "test_database.py", ]