PYTHON-4533 - Convert test/test_client.py to async (#1730)

This commit is contained in:
Noah Stapp 2024-07-10 13:15:13 -07:00 committed by GitHub
parent 554ce7d984
commit d0193eb045
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 3216 additions and 171 deletions

View File

@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
self.client_ref = None
self.key_vault_coll = None
if self.mongocryptd_client:
await self.mongocryptd_client.close()
await self.mongocryptd_client.aclose()
self.mongocryptd_client = None
@ -439,7 +439,7 @@ class _Encrypter:
self._closed = True
await self._auto_encrypter.close()
if self._internal_client:
await self._internal_client.close()
await self._internal_client.aclose()
self._internal_client = None

View File

@ -861,6 +861,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# This will be used later if we fork.
AsyncMongoClient._clients[self._topology._topology_id] = self
async def aconnect(self) -> None:
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
await self._get_topology()
def _init_background(self, old_pid: Optional[int] = None) -> None:
self._topology = Topology(self._topology_settings)
# Seed the topology with the old one's pid so we can detect clients
@ -1354,13 +1358,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
await self.aclose()
# See PYTHON-3084.
__iter__ = None
def __next__(self) -> NoReturn:
raise TypeError("'MongoClient' object is not iterable")
raise TypeError("'AsyncMongoClient' object is not iterable")
next = __next__
@ -1490,7 +1494,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# command.
pass
async def close(self) -> None:
async def aclose(self) -> None:
"""Cleanup client resources and disconnect from MongoDB.
End all server sessions created by this client by sending one or more

View File

@ -860,6 +860,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# This will be used later if we fork.
MongoClient._clients[self._topology._topology_id] = self
def _connect(self) -> None:
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
self._get_topology()
def _init_background(self, old_pid: Optional[int] = None) -> None:
self._topology = Topology(self._topology_settings)
# Seed the topology with the old one's pid so we can detect clients

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import base64
import contextlib
import gc
import multiprocessing
import os
@ -39,8 +40,6 @@ from test.helpers import (
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
_all_users,
_create_user,
client_knobs,
db_pwd,
db_user,
@ -62,9 +61,9 @@ try:
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import wraps
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
@ -812,6 +811,12 @@ class ClientContext:
func=func,
)
def require_sync(self, func):
"""Run a test only if using the synchronous API."""
return self._require(
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
)
def mongos_seeds(self):
return ",".join("{}:{}".format(*address) for address in self.mongoses)
@ -919,6 +924,32 @@ class PyMongoTestCase(unittest.TestCase):
self.assertEqual(proc.exitcode, 0)
class UnitTest(PyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
def _setup_class(cls):
cls._setup_class()
@classmethod
def _tearDown_class(cls):
cls._tearDown_class()
class IntegrationTest(PyMongoTestCase):
"""Async base class for TestCases that need a connection to MongoDB to pass."""
@ -933,6 +964,13 @@ class IntegrationTest(PyMongoTestCase):
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_connection
def _setup_class(cls):
@ -947,6 +985,10 @@ class IntegrationTest(PyMongoTestCase):
else:
cls.credentials = {}
@classmethod
def _tearDown_class(cls):
pass
def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection."""
for c in collections:
@ -959,7 +1001,7 @@ class IntegrationTest(PyMongoTestCase):
self.addCleanup(patcher.disable)
class MockClientTest(unittest.TestCase):
class MockClientTest(UnitTest):
"""Base class for TestCases that use MockClient.
This class is *not* an IntegrationTest: if properly written, MockClient
@ -972,8 +1014,26 @@ class MockClientTest(unittest.TestCase):
# multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True.
@classmethod
@client_context.require_no_load_balancer
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@client_context.require_no_load_balancer
def _setup_class(cls):
pass
@classmethod
def _tearDown_class(cls):
pass
def setUp(self):
@ -1051,3 +1111,38 @@ def print_running_clients():
processed.add(obj._topology_id)
except ReferenceError:
pass
def _all_users(db):
return {u["user"] for u in (db.command("usersInfo")).get("users", [])}
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
cmd = SON([("createUser", user)])
# X509 doesn't use a password
if pwd:
cmd["pwd"] = pwd
cmd["roles"] = roles or ["root"]
cmd.update(**kwargs)
return authdb.command(cmd)
def connected(client):
"""Convenience to wait for a newly-constructed client to connect."""
with warnings.catch_warnings():
# Ignore warning that ping is always routed to primary even
# if client's read preference isn't PRIMARY.
warnings.simplefilter("ignore", UserWarning)
client.admin.command("ping") # Force connection.
return client
def drop_collections(db: Database):
# Drop all non-system collections in this database.
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
db.drop_collection(coll)
def remove_all_users(db: Database):
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import base64
import contextlib
import gc
import multiprocessing
import os
@ -39,8 +40,6 @@ from test.helpers import (
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
_all_users,
_create_user,
client_knobs,
db_pwd,
db_user,
@ -62,9 +61,9 @@ try:
except ImportError:
HAVE_IPADDRESS = False
from contextlib import asynccontextmanager, contextmanager
from functools import wraps
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
@ -184,7 +183,7 @@ class AsyncClientContext:
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
return None
finally:
await client.close()
await client.aclose()
async def _init_client(self):
self.client = await self._connect(host, port)
@ -229,7 +228,7 @@ class AsyncClientContext:
if not self.serverless and not IS_SRV:
# See if db_user already exists.
if not await self._check_user_provided():
_create_user(self.client.admin, db_user, db_pwd)
await _create_user(self.client.admin, db_user, db_pwd)
self.client = await self._connect(
host,
@ -304,7 +303,7 @@ class AsyncClientContext:
params = self.cmd_line["parsed"].get("setParameter", {})
if params.get("enableTestCommands") == "1":
self.test_commands_enabled = True
self.has_ipv6 = self._server_started_with_ipv6()
self.has_ipv6 = await self._server_started_with_ipv6()
self.is_mongos = (await self.hello).get("msg") == "isdbgrid"
if self.is_mongos:
@ -390,7 +389,7 @@ class AsyncClientContext:
)
try:
return db_user in _all_users(client.admin)
return db_user in await _all_users(client.admin)
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get("errmsg", "")
@ -400,7 +399,7 @@ class AsyncClientContext:
else:
raise
finally:
await client.close()
await client.aclose()
def _server_started_with_auth(self):
# MongoDB >= 2.0
@ -482,9 +481,9 @@ class AsyncClientContext:
return decorate
return make_wrapper(func)
def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
async def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
kwargs["writeConcern"] = {"w": self.w}
return _create_user(self.client[dbname], user, pwd, roles, **kwargs)
return await _create_user(self.client[dbname], user, pwd, roles, **kwargs)
async def drop_user(self, dbname, user):
await self.client[dbname].command("dropUser", user, writeConcern={"w": self.w})
@ -814,6 +813,12 @@ class AsyncClientContext:
func=func,
)
def require_sync(self, func):
"""Run a test only if using the synchronous API."""
return self._require(
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
)
def mongos_seeds(self):
return ",".join("{}:{}".format(*address) for address in self.mongoses)
@ -921,6 +926,32 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
self.assertEqual(proc.exitcode, 0)
class AsyncUnitTest(AsyncPyMongoTestCase):
"""Async base class for TestCases that don't require a connection to MongoDB."""
@classmethod
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
async def _setup_class(cls):
await cls._setup_class()
@classmethod
async def _tearDown_class(cls):
await cls._tearDown_class()
class AsyncIntegrationTest(AsyncPyMongoTestCase):
"""Async base class for TestCases that need a connection to MongoDB to pass."""
@ -935,6 +966,13 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
@ -949,6 +987,10 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
else:
cls.credentials = {}
@classmethod
async def _tearDown_class(cls):
pass
async def cleanup_colls(self, *collections):
"""Cleanup collections faster than drop_collection."""
for c in collections:
@ -961,7 +1003,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
self.addCleanup(patcher.disable)
class AsyncMockClientTest(unittest.TestCase):
class AsyncMockClientTest(AsyncUnitTest):
"""Base class for TestCases that use MockClient.
This class is *not* an IntegrationTest: if properly written, MockClient
@ -974,8 +1016,26 @@ class AsyncMockClientTest(unittest.TestCase):
# multiple seed addresses, or wait for heartbeat events are incompatible
# with loadBalanced=True.
@classmethod
@async_client_context.require_no_load_balancer
def setUpClass(cls):
if _IS_SYNC:
cls._setup_class()
else:
asyncio.run(cls._setup_class())
@classmethod
def tearDownClass(cls):
if _IS_SYNC:
cls._tearDown_class()
else:
asyncio.run(cls._tearDown_class())
@classmethod
@async_client_context.require_no_load_balancer
async def _setup_class(cls):
pass
@classmethod
async def _tearDown_class(cls):
pass
def setUp(self):
@ -1015,7 +1075,7 @@ async def async_teardown():
await c.drop_database("pymongo_test2")
await c.drop_database("pymongo_test_mike")
await c.drop_database("pymongo_test_bernie")
await c.close()
await c.aclose()
print_running_clients()
@ -1053,3 +1113,38 @@ def print_running_clients():
processed.add(obj._topology_id)
except ReferenceError:
pass
async def _all_users(db):
return {u["user"] for u in (await db.command("usersInfo")).get("users", [])}
async def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
cmd = SON([("createUser", user)])
# X509 doesn't use a password
if pwd:
cmd["pwd"] = pwd
cmd["roles"] = roles or ["root"]
cmd.update(**kwargs)
return await authdb.command(cmd)
async def connected(client):
"""Convenience to wait for a newly-constructed client to connect."""
with warnings.catch_warnings():
# Ignore warning that ping is always routed to primary even
# if client's read preference isn't PRIMARY.
warnings.simplefilter("ignore", UserWarning)
await client.admin.command("ping") # Force connection.
return client
async def drop_collections(db: AsyncDatabase):
# Drop all non-system collections in this database.
for coll in await db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
await db.drop_collection(coll)
async def remove_all_users(db: AsyncDatabase):
await db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": async_client_context.w})

View File

@ -0,0 +1,252 @@
# Copyright 2013-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for mocking parts of PyMongo to test other parts."""
from __future__ import annotations
import contextlib
import weakref
from functools import partial
from test import client_context
from test.asynchronous import async_client_context
from pymongo import AsyncMongoClient, common
from pymongo.asynchronous.monitor import Monitor
from pymongo.asynchronous.pool import Pool
from pymongo.errors import AutoReconnect, NetworkTimeout
from pymongo.hello import Hello, HelloCompat
from pymongo.server_description import ServerDescription
_IS_SYNC = False
class MockPool(Pool):
def __init__(self, client, pair, *args, **kwargs):
# MockPool gets a 'client' arg, regular pools don't. Weakref it to
# avoid cycle with __del__, causing ResourceWarnings in Python 3.3.
self.client = weakref.proxy(client)
self.mock_host, self.mock_port = pair
# Actually connect to the default server.
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
@contextlib.asynccontextmanager
async def checkout(self, handler=None):
client = self.client
host_and_port = f"{self.mock_host}:{self.mock_port}"
if host_and_port in client.mock_down_hosts:
raise AutoReconnect("mock error")
assert host_and_port in (
client.mock_standalones + client.mock_members + client.mock_mongoses
), "bad host: %s" % host_and_port
async with Pool.checkout(self, handler) as conn:
conn.mock_host = self.mock_host
conn.mock_port = self.mock_port
yield conn
class DummyMonitor:
def __init__(self, server_description, topology, pool, topology_settings):
self._server_description = server_description
self.opened = False
def cancel_check(self):
pass
def join(self):
pass
def open(self):
self.opened = True
def request_check(self):
pass
def close(self):
self.opened = False
class AsyncMockMonitor(Monitor):
def __init__(self, client, server_description, topology, pool, topology_settings):
# MockMonitor gets a 'client' arg, regular monitors don't. Weakref it
# to avoid cycles.
self.client = weakref.proxy(client)
Monitor.__init__(self, server_description, topology, pool, topology_settings)
async def _check_once(self):
client = self.client
address = self._server_description.address
response, rtt = client.mock_hello("%s:%d" % address) # type: ignore[str-format]
return ServerDescription(address, Hello(response), rtt)
class AsyncMockClient(AsyncMongoClient):
def __init__(
self,
standalones,
members,
mongoses,
hello_hosts=None,
arbiters=None,
down_hosts=None,
*args,
**kwargs,
):
"""An AsyncMongoClient connected to the default server, with a mock topology.
standalones, members, mongoses, arbiters, and down_hosts determine the
configuration of the topology. They are formatted like ['a:1', 'b:2'].
hello_hosts provides an alternative host list for the server's
mocked hello response; see test_connect_with_internal_ips.
"""
self.mock_standalones = standalones[:]
self.mock_members = members[:]
if self.mock_members:
self.mock_primary = self.mock_members[0]
else:
self.mock_primary = None
# Hosts that should be considered an arbiter.
self.mock_arbiters = arbiters[:] if arbiters else []
if hello_hosts is not None:
self.mock_hello_hosts = hello_hosts
else:
self.mock_hello_hosts = members[:]
self.mock_mongoses = mongoses[:]
# Hosts that should raise socket errors.
self.mock_down_hosts = down_hosts[:] if down_hosts else []
# Hostname -> (min wire version, max wire version)
self.mock_wire_versions = {}
# Hostname -> max write batch size
self.mock_max_write_batch_sizes = {}
# Hostname -> round trip time
self.mock_rtts = {}
kwargs["_pool_class"] = partial(MockPool, self)
kwargs["_monitor_class"] = partial(AsyncMockMonitor, self)
client_options = async_client_context.default_client_options.copy()
client_options.update(kwargs)
super().__init__(*args, **client_options)
@classmethod
async def get_async_mock_client(
cls,
standalones,
members,
mongoses,
hello_hosts=None,
arbiters=None,
down_hosts=None,
*args,
**kwargs,
):
c = AsyncMockClient(
standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs
)
await c.aconnect()
return c
def kill_host(self, host):
"""Host is like 'a:1'."""
self.mock_down_hosts.append(host)
def revive_host(self, host):
"""Host is like 'a:1'."""
self.mock_down_hosts.remove(host)
def set_wire_version_range(self, host, min_version, max_version):
self.mock_wire_versions[host] = (min_version, max_version)
def set_max_write_batch_size(self, host, size):
self.mock_max_write_batch_sizes[host] = size
def mock_hello(self, host):
"""Return mock hello response (a dict) and round trip time."""
if host in self.mock_wire_versions:
min_wire_version, max_wire_version = self.mock_wire_versions[host]
else:
min_wire_version = common.MIN_SUPPORTED_WIRE_VERSION
max_wire_version = common.MAX_SUPPORTED_WIRE_VERSION
max_write_batch_size = self.mock_max_write_batch_sizes.get(
host, common.MAX_WRITE_BATCH_SIZE
)
rtt = self.mock_rtts.get(host, 0)
# host is like 'a:1'.
if host in self.mock_down_hosts:
raise NetworkTimeout("mock timeout")
elif host in self.mock_standalones:
response = {
"ok": 1,
HelloCompat.LEGACY_CMD: True,
"minWireVersion": min_wire_version,
"maxWireVersion": max_wire_version,
"maxWriteBatchSize": max_write_batch_size,
}
elif host in self.mock_members:
primary = host == self.mock_primary
# Simulate a replica set member.
response = {
"ok": 1,
HelloCompat.LEGACY_CMD: primary,
"secondary": not primary,
"setName": "rs",
"hosts": self.mock_hello_hosts,
"minWireVersion": min_wire_version,
"maxWireVersion": max_wire_version,
"maxWriteBatchSize": max_write_batch_size,
}
if self.mock_primary:
response["primary"] = self.mock_primary
if host in self.mock_arbiters:
response["arbiterOnly"] = True
response["secondary"] = False
elif host in self.mock_mongoses:
response = {
"ok": 1,
HelloCompat.LEGACY_CMD: True,
"minWireVersion": min_wire_version,
"maxWireVersion": max_wire_version,
"msg": "isdbgrid",
"maxWriteBatchSize": max_write_batch_size,
}
else:
# In test_internal_ips(), we try to connect to a host listed
# in hello['hosts'] but not publicly accessible.
raise AutoReconnect("Unknown host: %s" % host)
return response, rtt
def _process_periodic_tasks(self):
# Avoid the background thread causing races, e.g. a surprising
# reconnect while we're trying to test a disconnected client.
pass

File diff suppressed because it is too large Load Diff

View File

@ -1825,7 +1825,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
await self.db.test.insert_many([{"i": i} for i in range(150)])
client = await async_rs_or_single_client(maxPoolSize=1)
self.addAsyncCleanup(client.close)
self.addAsyncCleanup(client.aclose)
pool = await async_get_pool(client)
# Make sure the socket is returned after exhaustion.

View File

@ -236,7 +236,7 @@ class TestDatabase(AsyncIntegrationTest):
async def test_check_exists(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.close)
self.addAsyncCleanup(client.aclose)
db = client[self.db.name]
await db.drop_collection("unique")
await db.create_collection("unique", check_exists=True)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared constants and helper method for pymongo, bson, and gridfs test suites."""
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations
import base64

View File

@ -20,14 +20,15 @@ import weakref
from functools import partial
from test import client_context
from pymongo import common
from pymongo import MongoClient, common
from pymongo.errors import AutoReconnect, NetworkTimeout
from pymongo.hello import Hello, HelloCompat
from pymongo.server_description import ServerDescription
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.monitor import Monitor
from pymongo.synchronous.pool import Pool
_IS_SYNC = True
class MockPool(Pool):
def __init__(self, client, pair, *args, **kwargs):
@ -77,7 +78,7 @@ class DummyMonitor:
self.opened = False
class MockMonitor(Monitor):
class SyncMockMonitor(Monitor):
def __init__(self, client, server_description, topology, pool, topology_settings):
# MockMonitor gets a 'client' arg, regular monitors don't. Weakref it
# to avoid cycles.
@ -141,13 +142,32 @@ class MockClient(MongoClient):
self.mock_rtts = {}
kwargs["_pool_class"] = partial(MockPool, self)
kwargs["_monitor_class"] = partial(MockMonitor, self)
kwargs["_monitor_class"] = partial(SyncMockMonitor, self)
client_options = client_context.default_client_options.copy()
client_options.update(kwargs)
super().__init__(*args, **client_options)
@classmethod
def get_mock_client(
cls,
standalones,
members,
mongoses,
hello_hosts=None,
arbiters=None,
down_hosts=None,
*args,
**kwargs,
):
c = MockClient(
standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs
)
c._connect()
return c
def kill_host(self, host):
"""Host is like 'a:1'."""
self.mock_down_hosts.append(host)

View File

@ -23,9 +23,8 @@ from pymongo.synchronous.mongo_client import MongoClient
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test import IntegrationTest, client_context, remove_all_users, unittest
from test.utils import (
remove_all_users,
rs_or_single_client_noauth,
single_client,
wait_until,

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import _thread as thread
import asyncio
import contextlib
import copy
import datetime
@ -45,10 +46,13 @@ from test import (
IntegrationTest,
MockClientTest,
SkipTest,
UnitTest,
client_context,
client_knobs,
connected,
db_pwd,
db_user,
remove_all_users,
unittest,
)
from test.pymongo_mocks import MockClient
@ -57,14 +61,12 @@ from test.utils import (
CMAPListener,
FunctionCallRecorder,
assertRaisesExactly,
connected,
delay,
get_pool,
gevent_monkey_patched,
is_greenthread_patched,
lazy_client_trial,
one,
remove_all_users,
rs_client,
rs_or_single_client,
rs_or_single_client_noauth,
@ -109,6 +111,7 @@ from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.cursor import Cursor, CursorType
from pymongo.synchronous.database import Database
from pymongo.synchronous.helpers import next
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import (
Connection,
@ -118,18 +121,20 @@ from pymongo.synchronous.topology import _ErrorContext
from pymongo.topology_description import TopologyDescription
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class ClientUnitTest(unittest.TestCase):
class ClientUnitTest(UnitTest):
"""MongoClient tests that don't require a server."""
client: MongoClient
@classmethod
def setUpClass(cls):
def _setup_class(cls):
cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.client.close()
@pytest.fixture(autouse=True)
@ -254,7 +259,8 @@ class ClientUnitTest(unittest.TestCase):
def test_get_default_database(self):
c = rs_or_single_client(
"mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False
"mongodb://%s:%d/foo" % (client_context.host, client_context.port),
connect=False,
)
self.assertEqual(Database(c, "foo"), c.get_default_database())
# Test that default doesn't override the URI value.
@ -269,39 +275,49 @@ class ClientUnitTest(unittest.TestCase):
self.assertEqual(write_concern, db.write_concern)
c = rs_or_single_client(
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
"mongodb://%s:%d/" % (client_context.host, client_context.port),
connect=False,
)
self.assertEqual(Database(c, "foo"), c.get_default_database("foo"))
def test_get_default_database_error(self):
# URI with no database.
c = rs_or_single_client(
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
"mongodb://%s:%d/" % (client_context.host, client_context.port),
connect=False,
)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port)
uri = "mongodb://%s:%d/foo?authSource=src" % (
client_context.host,
client_context.port,
)
c = rs_or_single_client(uri, connect=False)
self.assertEqual(Database(c, "foo"), c.get_default_database())
def test_get_database_default(self):
c = rs_or_single_client(
"mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False
"mongodb://%s:%d/foo" % (client_context.host, client_context.port),
connect=False,
)
self.assertEqual(Database(c, "foo"), c.get_database())
def test_get_database_default_error(self):
# URI with no database.
c = rs_or_single_client(
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
"mongodb://%s:%d/" % (client_context.host, client_context.port),
connect=False,
)
self.assertRaises(ConfigurationError, c.get_database)
def test_get_database_default_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port)
uri = "mongodb://%s:%d/foo?authSource=src" % (
client_context.host,
client_context.port,
)
c = rs_or_single_client(uri, connect=False)
self.assertEqual(Database(c, "foo"), c.get_database())
@ -634,7 +650,7 @@ class TestClient(IntegrationTest):
with client_knobs(kill_cursor_frequency=0.1):
# Assert reaper doesn't remove connections when maxIdleTimeMS not set
client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
with server._pool.checkout() as conn:
pass
self.assertEqual(1, len(server._pool.conns))
@ -645,7 +661,7 @@ class TestClient(IntegrationTest):
with client_knobs(kill_cursor_frequency=0.1):
# Assert reaper removes idle socket and replaces it with a new one
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
with server._pool.checkout() as conn:
pass
# When the reaper runs at the same time as the get_socket, two
@ -659,7 +675,7 @@ class TestClient(IntegrationTest):
with client_knobs(kill_cursor_frequency=0.1):
# Assert reaper respects maxPoolSize when adding new connections.
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
with server._pool.checkout() as conn:
pass
# When the reaper runs at the same time as the get_socket,
@ -673,7 +689,7 @@ class TestClient(IntegrationTest):
with client_knobs(kill_cursor_frequency=0.1):
# Assert reaper has removed idle socket and NOT replaced it
client = rs_or_single_client(maxIdleTimeMS=500)
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
with server._pool.checkout() as conn_one:
pass
# Assert that the pool does not close connections prematurely.
@ -690,12 +706,12 @@ class TestClient(IntegrationTest):
def test_min_pool_size(self):
with client_knobs(kill_cursor_frequency=0.1):
client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
self.assertEqual(0, len(server._pool.conns))
# Assert that pool started up at minPoolSize
client = rs_or_single_client(minPoolSize=10)
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
wait_until(
lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections",
@ -714,7 +730,7 @@ class TestClient(IntegrationTest):
# Use high frequency to test _get_socket_no_auth.
with client_knobs(kill_cursor_frequency=99999999):
client = rs_or_single_client(maxIdleTimeMS=500)
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
with server._pool.checkout() as conn:
pass
self.assertEqual(1, len(server._pool.conns))
@ -728,7 +744,7 @@ class TestClient(IntegrationTest):
# Test that connections are reused if maxIdleTimeMS is not set.
client = rs_or_single_client()
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
with server._pool.checkout() as conn:
pass
self.assertEqual(1, len(server._pool.conns))
@ -793,12 +809,14 @@ class TestClient(IntegrationTest):
bad_host = "somedomainthatdoesntexist.org"
c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
with self.assertRaises(ConnectionFailure):
c.pymongo_test.test.find_one()
def test_init_disconnected_with_auth(self):
uri = "mongodb://user:pass@somedomainthatdoesntexist"
c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
with self.assertRaises(ConnectionFailure):
c.pymongo_test.test.find_one()
def test_equality(self):
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
@ -816,7 +834,8 @@ class TestClient(IntegrationTest):
self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False))
# Same seeds but out of order still compares equal:
self.assertEqual(
MongoClient(["a", "b", "c"], connect=False), MongoClient(["c", "a", "b"], connect=False)
MongoClient(["a", "b", "c"], connect=False),
MongoClient(["c", "a", "b"], connect=False),
)
def test_hashable(self):
@ -830,9 +849,10 @@ class TestClient(IntegrationTest):
def test_host_w_port(self):
with self.assertRaises(ValueError):
host = client_context.host
connected(
MongoClient(
f"{client_context.host}:1234567",
f"{host}:1234567",
connectTimeoutMS=1,
serverSelectionTimeoutMS=10,
)
@ -883,10 +903,10 @@ class TestClient(IntegrationTest):
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
def test_list_databases(self):
cmd_docs = self.client.admin.command("listDatabases")["databases"]
cmd_docs = (self.client.admin.command("listDatabases"))["databases"]
cursor = self.client.list_databases()
self.assertIsInstance(cursor, CommandCursor)
helper_docs = list(cursor)
helper_docs = cursor.to_list()
self.assertTrue(len(helper_docs) > 0)
self.assertEqual(len(helper_docs), len(cmd_docs))
# PYTHON-3529 Some fields may change between calls, just compare names.
@ -900,7 +920,7 @@ class TestClient(IntegrationTest):
self.client.pymongo_test.test.insert_one({})
cursor = self.client.list_databases(filter={"name": "admin"})
docs = list(cursor)
docs = cursor.to_list()
self.assertEqual(1, len(docs))
self.assertEqual(docs[0]["name"], "admin")
@ -911,7 +931,7 @@ class TestClient(IntegrationTest):
def test_list_database_names(self):
self.client.pymongo_test.test.insert_one({"dummy": "object"})
self.client.pymongo_test_mike.test.insert_one({"dummy": "object"})
cmd_docs = self.client.admin.command("listDatabases")["databases"]
cmd_docs = (self.client.admin.command("listDatabases"))["databases"]
cmd_names = [doc["name"] for doc in cmd_docs]
db_names = self.client.list_database_names()
@ -920,8 +940,10 @@ class TestClient(IntegrationTest):
self.assertEqual(db_names, cmd_names)
def test_drop_database(self):
self.assertRaises(TypeError, self.client.drop_database, 5)
self.assertRaises(TypeError, self.client.drop_database, None)
with self.assertRaises(TypeError):
self.client.drop_database(5) # type: ignore[arg-type]
with self.assertRaises(TypeError):
self.client.drop_database(None) # type: ignore[arg-type]
self.client.pymongo_test.test.insert_one({"dummy": "object"})
self.client.pymongo_test2.test.insert_one({"dummy": "object"})
@ -944,7 +966,8 @@ class TestClient(IntegrationTest):
test_client = rs_or_single_client()
coll = test_client.pymongo_test.bar
test_client.close()
self.assertRaises(InvalidOperation, coll.count_documents, {})
with self.assertRaises(InvalidOperation):
coll.count_documents({})
def test_close_kills_cursors(self):
if sys.platform.startswith("java"):
@ -961,7 +984,7 @@ class TestClient(IntegrationTest):
coll.insert_many([{"i": i} for i in range(docs_inserted)])
# Open a cursor and leave it open on the server.
cursor = coll.find().batch_size(10)
cursor = (coll.find()).batch_size(10)
self.assertTrue(bool(next(cursor)))
self.assertLess(cursor.retrieved, docs_inserted)
@ -992,7 +1015,8 @@ class TestClient(IntegrationTest):
self.assertTrue(client._kill_cursors_executor._stopped)
# Reusing the closed client should raise an InvalidOperation error.
self.assertRaises(InvalidOperation, client.admin.command, "ping")
with self.assertRaises(InvalidOperation):
client.admin.command("ping")
# Thread is still stopped.
self.assertTrue(client._kill_cursors_executor._stopped)
@ -1065,8 +1089,10 @@ class TestClient(IntegrationTest):
)
# Auth with lazy connection.
rs_or_single_client_noauth(
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False
(
rs_or_single_client_noauth(
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False
)
).pymongo_test.test.find_one()
# Wrong password.
@ -1074,7 +1100,8 @@ class TestClient(IntegrationTest):
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False
)
self.assertRaises(OperationFailure, bad_client.pymongo_test.test.find_one)
with self.assertRaises(OperationFailure):
bad_client.pymongo_test.test.find_one()
@client_context.require_auth
def test_username_and_password(self):
@ -1093,13 +1120,14 @@ class TestClient(IntegrationTest):
c.server_info()
with self.assertRaises(OperationFailure):
rs_or_single_client_noauth(username="ad min", password="foo").server_info()
(rs_or_single_client_noauth(username="ad min", password="foo")).server_info()
@client_context.require_auth
@client_context.require_no_fips
def test_lazy_auth_raises_operation_failure(self):
host = client_context.host
lazy_client = rs_or_single_client_noauth(
f"mongodb://user:wrong@{client_context.host}/pymongo_test", connect=False
f"mongodb://user:wrong@{host}/pymongo_test", connect=False
)
assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one)
@ -1125,11 +1153,10 @@ class TestClient(IntegrationTest):
self.assertTrue(mongodb_socket in repr(client))
# Confirm it fails with a missing socket.
self.assertRaises(
ConnectionFailure,
connected,
MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100),
)
with self.assertRaises(ConnectionFailure):
connected(
MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100),
)
def test_document_class(self):
c = self.client
@ -1154,27 +1181,30 @@ class TestClient(IntegrationTest):
maxIdleTimeMS=10500,
serverSelectionTimeoutMS=10500,
)
self.assertEqual(10.5, get_pool(client).opts.connect_timeout)
self.assertEqual(10.5, get_pool(client).opts.socket_timeout)
self.assertEqual(10.5, get_pool(client).opts.max_idle_time_seconds)
self.assertEqual(10.5, (get_pool(client)).opts.connect_timeout)
self.assertEqual(10.5, (get_pool(client)).opts.socket_timeout)
self.assertEqual(10.5, (get_pool(client)).opts.max_idle_time_seconds)
self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds)
self.assertEqual(10.5, client.options.server_selection_timeout)
def test_socket_timeout_ms_validation(self):
c = rs_or_single_client(socketTimeoutMS=10 * 1000)
self.assertEqual(10, get_pool(c).opts.socket_timeout)
self.assertEqual(10, (get_pool(c)).opts.socket_timeout)
c = connected(rs_or_single_client(socketTimeoutMS=None))
self.assertEqual(None, get_pool(c).opts.socket_timeout)
self.assertEqual(None, (get_pool(c)).opts.socket_timeout)
c = connected(rs_or_single_client(socketTimeoutMS=0))
self.assertEqual(None, get_pool(c).opts.socket_timeout)
self.assertEqual(None, (get_pool(c)).opts.socket_timeout)
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=-1)
with self.assertRaises(ValueError):
rs_or_single_client(socketTimeoutMS=-1)
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=1e10)
with self.assertRaises(ValueError):
rs_or_single_client(socketTimeoutMS=1e10)
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS="foo")
with self.assertRaises(ValueError):
rs_or_single_client(socketTimeoutMS="foo")
def test_socket_timeout(self):
no_timeout = self.client
@ -1189,11 +1219,12 @@ class TestClient(IntegrationTest):
where_func = delay(timeout_sec + 1)
def get_x(db):
doc = next(db.test.find().where(where_func))
doc = next((db.test.find()).where(where_func))
return doc["x"]
self.assertEqual(1, get_x(no_timeout.pymongo_test))
self.assertRaises(NetworkTimeout, get_x, timeout.pymongo_test)
with self.assertRaises(NetworkTimeout):
get_x(timeout.pymongo_test)
def test_server_selection_timeout(self):
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
@ -1224,7 +1255,7 @@ class TestClient(IntegrationTest):
def test_waitQueueTimeoutMS(self):
client = rs_or_single_client(waitQueueTimeoutMS=2000)
self.assertEqual(get_pool(client).opts.wait_queue_timeout, 2)
self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2)
def test_socketKeepAlive(self):
pool = get_pool(self.client)
@ -1244,11 +1275,11 @@ class TestClient(IntegrationTest):
now = datetime.datetime.now(tz=datetime.timezone.utc)
aware.pymongo_test.test.insert_one({"x": now})
self.assertEqual(None, naive.pymongo_test.test.find_one()["x"].tzinfo)
self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo)
self.assertEqual(None, (naive.pymongo_test.test.find_one())["x"].tzinfo)
self.assertEqual(utc, (aware.pymongo_test.test.find_one())["x"].tzinfo)
self.assertEqual(
aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None),
naive.pymongo_test.test.find_one()["x"],
(aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None),
(naive.pymongo_test.test.find_one())["x"],
)
@client_context.require_ipv6
@ -1282,18 +1313,21 @@ class TestClient(IntegrationTest):
# The socket used for the previous commands has been returned to the
# pool
self.assertEqual(1, len(get_pool(client).conns))
self.assertEqual(1, len((get_pool(client)).conns))
with contextlib.closing(client):
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
with self.assertRaises(InvalidOperation):
client.pymongo_test.test.find_one()
client = rs_or_single_client()
with client as client:
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
with self.assertRaises(InvalidOperation):
client.pymongo_test.test.find_one()
# contextlib async support was added in Python 3.10
if _IS_SYNC or sys.version_info >= (3, 10):
with contextlib.closing(client):
self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"])
with self.assertRaises(InvalidOperation):
client.pymongo_test.test.find_one()
client = rs_or_single_client()
with client as client:
self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"])
with self.assertRaises(InvalidOperation):
client.pymongo_test.test.find_one()
@client_context.require_sync
def test_interrupt_signal(self):
if sys.platform.startswith("java"):
# We can't figure out how to raise an exception on a thread that's
@ -1344,7 +1378,7 @@ class TestClient(IntegrationTest):
raised = False
try:
# Will be interrupted by a KeyboardInterrupt.
next(db.foo.find({"$where": where}))
next(db.foo.find({"$where": where})) # type: ignore[call-overload]
except KeyboardInterrupt:
raised = True
@ -1355,7 +1389,7 @@ class TestClient(IntegrationTest):
# Raises AssertionError due to PYTHON-294 -- Mongo's response to
# the previous find() is still waiting to be read on the socket,
# so the request id's don't match.
self.assertEqual({"_id": 1}, next(db.foo.find()))
self.assertEqual({"_id": 1}, next(db.foo.find())) # type: ignore[call-overload]
finally:
if old_signal_handler:
signal.signal(signal.SIGALRM, old_signal_handler)
@ -1374,7 +1408,8 @@ class TestClient(IntegrationTest):
old_conn = next(iter(pool.conns))
client.pymongo_test.test.drop()
client.pymongo_test.test.insert_one({"_id": "foo"})
self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"})
with self.assertRaises(OperationFailure):
client.pymongo_test.test.insert_one({"_id": "foo"})
self.assertEqual(socket_count, len(pool.conns))
new_con = next(iter(pool.conns))
@ -1392,23 +1427,29 @@ class TestClient(IntegrationTest):
client = rs_or_single_client(connect=False, w=0)
self.addCleanup(client.close)
client.test_lazy_connect_w0.test.insert_one({})
wait_until(
lambda: client.test_lazy_connect_w0.test.count_documents({}) == 1, "find one document"
)
def predicate():
return client.test_lazy_connect_w0.test.count_documents({}) == 1
wait_until(predicate, "find one document")
client = rs_or_single_client(connect=False, w=0)
self.addCleanup(client.close)
client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}})
wait_until(
lambda: client.test_lazy_connect_w0.test.find_one().get("x") == 1, "update one document"
)
def predicate():
return (client.test_lazy_connect_w0.test.find_one()).get("x") == 1
wait_until(predicate, "update one document")
client = rs_or_single_client(connect=False, w=0)
self.addCleanup(client.close)
client.test_lazy_connect_w0.test.delete_one({})
wait_until(
lambda: client.test_lazy_connect_w0.test.count_documents({}) == 0, "delete one document"
)
def predicate():
return client.test_lazy_connect_w0.test.count_documents({}) == 0
wait_until(predicate, "delete one document")
@client_context.require_no_mongos
def test_exhaust_network_error(self):
@ -1445,12 +1486,13 @@ class TestClient(IntegrationTest):
# Cause a network error on the actual socket.
pool = get_pool(c)
socket_info = one(pool.conns)
socket_info.conn.close()
conn = one(pool.conns)
conn.conn.close()
# Connection.authenticate logs, but gets a socket.error. Should be
# reraised as AutoReconnect.
self.assertRaises(AutoReconnect, c.test.collection.find_one)
with self.assertRaises(AutoReconnect):
c.test.collection.find_one()
# No semaphore leak, the pool is allowed to make a new socket.
c.test.collection.find_one()
@ -1638,13 +1680,19 @@ class TestClient(IntegrationTest):
def stop(self):
self.running = False
def run(self):
def _run(self):
while self.running:
exc = AutoReconnect("mock pool error")
ctx = _ErrorContext(exc, 0, pool.gen.get_overall(), False, None)
client._topology.handle_error(pool.address, ctx)
time.sleep(0.001)
def run(self):
if _IS_SYNC:
self._run()
else:
asyncio.run(self._run())
t = ResetPoolThread(pool)
t.start()
@ -1755,7 +1803,7 @@ class TestClient(IntegrationTest):
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
):
assert client.address is not None
expected = "{}:{}: ".format(*client.address)
expected = "{}:{}: ".format(*(client.address))
with self.assertRaisesRegex(AutoReconnect, expected):
client.pymongo_test.test.find_one({})
@ -1814,6 +1862,7 @@ class TestClient(IntegrationTest):
"loadBalanced clients do not run SDAM",
)
@unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP")
@client_context.require_sync
def test_sigstop_sigcont(self):
test_dir = os.path.dirname(os.path.realpath(__file__))
script = os.path.join(test_dir, "sigstop_sigcont.py")
@ -1961,7 +2010,8 @@ class TestExhaustCursor(IntegrationTest):
SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST
)
self.assertRaises(OperationFailure, cursor.next)
with self.assertRaises(OperationFailure):
cursor.next()
self.assertFalse(conn.closed)
# The socket was checked in and the semaphore was decremented.
@ -1998,7 +2048,8 @@ class TestExhaustCursor(IntegrationTest):
return message._OpReply.unpack(msg)
conn.receive_message = receive_message
self.assertRaises(OperationFailure, list, cursor)
with self.assertRaises(OperationFailure):
cursor.to_list()
# Unpatch the instance.
del conn.receive_message
@ -2019,7 +2070,8 @@ class TestExhaustCursor(IntegrationTest):
conn.conn.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST)
self.assertRaises(ConnectionFailure, cursor.next)
with self.assertRaises(ConnectionFailure):
cursor.next()
self.assertTrue(conn.closed)
# The socket was closed and the semaphore was decremented.
@ -2046,7 +2098,8 @@ class TestExhaustCursor(IntegrationTest):
conn.conn.close()
# A getmore fails.
self.assertRaises(ConnectionFailure, list, cursor)
with self.assertRaises(ConnectionFailure):
cursor.to_list()
self.assertTrue(conn.closed)
wait_until(
@ -2057,6 +2110,7 @@ class TestExhaustCursor(IntegrationTest):
self.assertNotIn(conn, pool.conns)
self.assertEqual(0, pool.requests)
@client_context.require_sync
def test_gevent_task(self):
if not gevent_monkey_patched():
raise SkipTest("Must be running monkey patched by gevent")
@ -2070,6 +2124,7 @@ class TestExhaustCursor(IntegrationTest):
task.kill()
self.assertTrue(task.dead)
@client_context.require_sync
def test_gevent_timeout(self):
if not gevent_monkey_patched():
raise SkipTest("Must be running monkey patched by gevent")
@ -2101,6 +2156,7 @@ class TestExhaustCursor(IntegrationTest):
self.assertIsNone(tt.get())
self.assertIsNone(ct.get())
@client_context.require_sync
def test_gevent_timeout_when_creating_connection(self):
if not gevent_monkey_patched():
raise SkipTest("Must be running monkey patched by gevent")
@ -2145,6 +2201,7 @@ class TestClientLazyConnect(IntegrationTest):
def _get_client(self):
return rs_or_single_client(connect=False)
@client_context.require_sync
def test_insert_one(self):
def reset(collection):
collection.drop()
@ -2157,6 +2214,7 @@ class TestClientLazyConnect(IntegrationTest):
lazy_client_trial(reset, insert_one, test, self._get_client)
@client_context.require_sync
def test_update_one(self):
def reset(collection):
collection.drop()
@ -2171,6 +2229,7 @@ class TestClientLazyConnect(IntegrationTest):
lazy_client_trial(reset, update_one, test, self._get_client)
@client_context.require_sync
def test_delete_one(self):
def reset(collection):
collection.drop()
@ -2184,6 +2243,7 @@ class TestClientLazyConnect(IntegrationTest):
lazy_client_trial(reset, delete_one, test, self._get_client)
@client_context.require_sync
def test_find_one(self):
results: list = []
@ -2203,7 +2263,7 @@ class TestClientLazyConnect(IntegrationTest):
class TestMongoClientFailover(MockClientTest):
def test_discover_primary(self):
c = MockClient(
c = MockClient.get_mock_client(
standalones=[],
members=["a:1", "b:2", "c:3"],
mongoses=[],
@ -2219,13 +2279,17 @@ class TestMongoClientFailover(MockClientTest):
# Fail over.
c.kill_host("a:1")
c.mock_primary = "b:2"
wait_until(lambda: c.address == ("b", 2), "wait for server address to be updated")
def predicate():
return (c.address) == ("b", 2)
wait_until(predicate, "wait for server address to be updated")
# a:1 not longer in nodes.
self.assertLess(len(c.nodes), 3)
def test_reconnect(self):
# Verify the node list isn't forgotten during a network failure.
c = MockClient(
c = MockClient.get_mock_client(
standalones=[],
members=["a:1", "b:2", "c:3"],
mongoses=[],
@ -2245,12 +2309,13 @@ class TestMongoClientFailover(MockClientTest):
# MongoClient discovers it's alone. The first attempt raises either
# ServerSelectionTimeoutError or AutoReconnect (from
# MockPool.get_socket).
self.assertRaises(AutoReconnect, c.db.collection.find_one)
# AsyncMockPool.get_socket).
with self.assertRaises(AutoReconnect):
c.db.collection.find_one()
# But it can reconnect.
c.revive_host("a:1")
c._get_topology().select_servers(writable_server_selector, _Op.TEST)
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
self.assertEqual(c.address, ("a", 1))
def _test_network_error(self, operation_callback):
@ -2273,7 +2338,7 @@ class TestMongoClientFailover(MockClientTest):
# Set host-specific information so we can test whether it is reset.
c.set_wire_version_range("a:1", 2, 6)
c.set_wire_version_range("b:2", 2, 7)
c._get_topology().select_servers(writable_server_selector, _Op.TEST)
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
wait_until(lambda: len(c.nodes) == 2, "connect")
c.kill_host("a:1")
@ -2281,17 +2346,18 @@ class TestMongoClientFailover(MockClientTest):
# MongoClient is disconnected from the primary. This raises either
# ServerSelectionTimeoutError or AutoReconnect (from
# MockPool.get_socket).
self.assertRaises(AutoReconnect, operation_callback, c)
with self.assertRaises(AutoReconnect):
operation_callback(c)
# The primary's description is reset.
server_a = c._get_topology().get_server_by_address(("a", 1))
server_a = (c._get_topology()).get_server_by_address(("a", 1))
sd_a = server_a.description
self.assertEqual(SERVER_TYPE.Unknown, sd_a.server_type)
self.assertEqual(0, sd_a.min_wire_version)
self.assertEqual(0, sd_a.max_wire_version)
# ...but not the secondary's.
server_b = c._get_topology().get_server_by_address(("b", 2))
server_b = (c._get_topology()).get_server_by_address(("b", 2))
sd_b = server_b.description
self.assertEqual(SERVER_TYPE.RSSecondary, sd_b.server_type)
self.assertEqual(2, sd_b.min_wire_version)
@ -2332,7 +2398,7 @@ class TestClientPool(MockClientTest):
@client_context.require_connection
def test_rs_client_does_not_maintain_pool_to_arbiters(self):
listener = CMAPListener()
c = MockClient(
c = MockClient.get_mock_client(
standalones=[],
members=["a:1", "b:2", "c:3", "d:4"],
mongoses=[],
@ -2363,7 +2429,7 @@ class TestClientPool(MockClientTest):
@client_context.require_connection
def test_direct_client_maintains_pool_to_arbiter(self):
listener = CMAPListener()
c = MockClient(
c = MockClient.get_mock_client(
standalones=[],
members=["a:1", "b:2", "c:3"],
mongoses=[],

View File

@ -20,8 +20,8 @@ import uuid
sys.path[0:0] = [""]
from test import IntegrationTest, client_context, unittest
from test.utils import connected, rs_or_single_client, single_client
from test import IntegrationTest, client_context, connected, unittest
from test.utils import rs_or_single_client, single_client
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
from bson.codec_options import CodecOptions

View File

@ -20,13 +20,12 @@ import sys
sys.path[0:0] = [""]
from test import IntegrationTest, unittest
from test import IntegrationTest, drop_collections, unittest
from test.utils import (
SpecTestCreator,
camel_to_snake,
camel_to_snake_args,
camel_to_upper_camel,
drop_collections,
)
from pymongo import WriteConcern, operations

View File

@ -22,9 +22,9 @@ from pymongo.operations import _Op
sys.path[0:0] = [""]
from test import MockClientTest, client_context, unittest
from test import MockClientTest, client_context, connected, unittest
from test.pymongo_mocks import MockClient
from test.utils import connected, wait_until
from test.utils import wait_until
from pymongo.errors import AutoReconnect, InvalidOperation
from pymongo.server_selectors import writable_server_selector

View File

@ -22,10 +22,9 @@ from functools import partial
sys.path[0:0] = [""]
from test import IntegrationTest, unittest
from test import IntegrationTest, connected, unittest
from test.utils import (
ServerAndTopologyEventListener,
connected,
single_client,
wait_until,
)

View File

@ -26,10 +26,9 @@ from pymongo.operations import _Op
sys.path[0:0] = [""]
from test import IntegrationTest, SkipTest, client_context, unittest
from test import IntegrationTest, SkipTest, client_context, connected, unittest
from test.utils import (
OvertCommandListener,
connected,
one,
rs_client,
single_client,

View File

@ -21,13 +21,19 @@ import sys
sys.path[0:0] = [""]
from test import HAVE_IPADDRESS, IntegrationTest, SkipTest, client_context, unittest
from test import (
HAVE_IPADDRESS,
IntegrationTest,
SkipTest,
client_context,
connected,
remove_all_users,
unittest,
)
from test.utils import (
EventListener,
cat_files,
connected,
ignore_deprecations,
remove_all_users,
)
from urllib.parse import quote_plus

View File

@ -621,7 +621,10 @@ async def _async_mongo_client(host, port, authenticate=True, directConnection=No
):
client_options["username"] = db_user
client_options["password"] = db_pwd
return AsyncMongoClient(uri, port, **client_options)
client = AsyncMongoClient(uri, port, **client_options)
if client._options.connect:
await client.aconnect()
return client
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
@ -843,16 +846,6 @@ def server_started_with_auth(client):
return "--auth" in argv or "--keyFile" in argv
def drop_collections(db):
# Drop all non-system collections in this database.
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
db.drop_collection(coll)
def remove_all_users(db):
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})
def joinall(threads):
"""Join threads with a 5-minute timeout, assert joins succeeded"""
for t in threads:
@ -860,17 +853,6 @@ def joinall(threads):
assert not t.is_alive(), "Thread %s hung" % t
def connected(client):
"""Convenience to wait for a newly-constructed client to connect."""
with warnings.catch_warnings():
# Ignore warning that ping is always routed to primary even
# if client's read preference isn't PRIMARY.
warnings.simplefilter("ignore", UserWarning)
client.admin.command("ping") # Force connection.
return client
def wait_until(predicate, success_description, timeout=10):
"""Wait up to 10 seconds (by default) for predicate to be true.
@ -957,6 +939,20 @@ def assertRaisesExactly(cls, fn, *args, **kwargs):
raise AssertionError("%s not raised" % cls)
async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs):
"""
Unlike the standard assertRaises, this checks that a function raises a
specific class of exception, and not a subclass. E.g., check that
MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect.
"""
try:
await fn(*args, **kwargs)
except Exception as e:
assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}"
else:
raise AssertionError("%s not raised" % cls)
@contextlib.contextmanager
def _ignore_deprecations():
with warnings.catch_warnings():

View File

@ -72,11 +72,20 @@ replacements = {
"addAsyncCleanup": "addCleanup",
"async_setup_class": "setup_class",
"IsolatedAsyncioTestCase": "TestCase",
"AsyncUnitTest": "UnitTest",
"AsyncMockClient": "MockClient",
"async_get_pool": "get_pool",
"async_is_mongos": "is_mongos",
"async_rs_or_single_client": "rs_or_single_client",
"async_rs_or_single_client_noauth": "rs_or_single_client_noauth",
"async_rs_client": "rs_client",
"async_single_client": "single_client",
"async_from_client": "from_client",
"aclosing": "closing",
"asyncAssertRaisesExactly": "assertRaisesExactly",
"get_async_mock_client": "get_mock_client",
"aconnect": "_connect",
"aclose": "close",
}
docstring_replacements: dict[tuple[str, str], str] = {
@ -131,6 +140,8 @@ sync_gridfs_files = [
converted_tests = [
"__init__.py",
"conftest.py",
"pymongo_mocks.py",
"test_client.py",
"test_collection.py",
"test_database.py",
]