PYTHON-4533 - Convert test/test_client.py to async (#1730)
This commit is contained in:
parent
554ce7d984
commit
d0193eb045
@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
self.client_ref = None
|
||||
self.key_vault_coll = None
|
||||
if self.mongocryptd_client:
|
||||
await self.mongocryptd_client.close()
|
||||
await self.mongocryptd_client.aclose()
|
||||
self.mongocryptd_client = None
|
||||
|
||||
|
||||
@ -439,7 +439,7 @@ class _Encrypter:
|
||||
self._closed = True
|
||||
await self._auto_encrypter.close()
|
||||
if self._internal_client:
|
||||
await self._internal_client.close()
|
||||
await self._internal_client.aclose()
|
||||
self._internal_client = None
|
||||
|
||||
|
||||
|
||||
@ -861,6 +861,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# This will be used later if we fork.
|
||||
AsyncMongoClient._clients[self._topology._topology_id] = self
|
||||
|
||||
async def aconnect(self) -> None:
|
||||
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
|
||||
await self._get_topology()
|
||||
|
||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||
self._topology = Topology(self._topology_settings)
|
||||
# Seed the topology with the old one's pid so we can detect clients
|
||||
@ -1354,13 +1358,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
await self.aclose()
|
||||
|
||||
# See PYTHON-3084.
|
||||
__iter__ = None
|
||||
|
||||
def __next__(self) -> NoReturn:
|
||||
raise TypeError("'MongoClient' object is not iterable")
|
||||
raise TypeError("'AsyncMongoClient' object is not iterable")
|
||||
|
||||
next = __next__
|
||||
|
||||
@ -1490,7 +1494,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# command.
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
async def aclose(self) -> None:
|
||||
"""Cleanup client resources and disconnect from MongoDB.
|
||||
|
||||
End all server sessions created by this client by sending one or more
|
||||
|
||||
@ -860,6 +860,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# This will be used later if we fork.
|
||||
MongoClient._clients[self._topology._topology_id] = self
|
||||
|
||||
def _connect(self) -> None:
|
||||
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
|
||||
self._get_topology()
|
||||
|
||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||
self._topology = Topology(self._topology_settings)
|
||||
# Seed the topology with the old one's pid so we can detect clients
|
||||
|
||||
107
test/__init__.py
107
test/__init__.py
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import gc
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -39,8 +40,6 @@ from test.helpers import (
|
||||
TEST_SERVERLESS,
|
||||
TLS_OPTIONS,
|
||||
SystemCertsPatcher,
|
||||
_all_users,
|
||||
_create_user,
|
||||
client_knobs,
|
||||
db_pwd,
|
||||
db_user,
|
||||
@ -62,9 +61,9 @@ try:
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from functools import partial, wraps
|
||||
from test.version import Version
|
||||
from typing import Any, Callable, Dict, Generator
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
@ -812,6 +811,12 @@ class ClientContext:
|
||||
func=func,
|
||||
)
|
||||
|
||||
def require_sync(self, func):
|
||||
"""Run a test only if using the synchronous API."""
|
||||
return self._require(
|
||||
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
|
||||
)
|
||||
|
||||
def mongos_seeds(self):
|
||||
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
||||
|
||||
@ -919,6 +924,32 @@ class PyMongoTestCase(unittest.TestCase):
|
||||
self.assertEqual(proc.exitcode, 0)
|
||||
|
||||
|
||||
class UnitTest(PyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls._setup_class()
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls._tearDown_class()
|
||||
|
||||
|
||||
class IntegrationTest(PyMongoTestCase):
|
||||
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
||||
|
||||
@ -933,6 +964,13 @@ class IntegrationTest(PyMongoTestCase):
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def _setup_class(cls):
|
||||
@ -947,6 +985,10 @@ class IntegrationTest(PyMongoTestCase):
|
||||
else:
|
||||
cls.credentials = {}
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
def cleanup_colls(self, *collections):
|
||||
"""Cleanup collections faster than drop_collection."""
|
||||
for c in collections:
|
||||
@ -959,7 +1001,7 @@ class IntegrationTest(PyMongoTestCase):
|
||||
self.addCleanup(patcher.disable)
|
||||
|
||||
|
||||
class MockClientTest(unittest.TestCase):
|
||||
class MockClientTest(UnitTest):
|
||||
"""Base class for TestCases that use MockClient.
|
||||
|
||||
This class is *not* an IntegrationTest: if properly written, MockClient
|
||||
@ -972,8 +1014,26 @@ class MockClientTest(unittest.TestCase):
|
||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||
# with loadBalanced=True.
|
||||
@classmethod
|
||||
@client_context.require_no_load_balancer
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@client_context.require_no_load_balancer
|
||||
def _setup_class(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
@ -1051,3 +1111,38 @@ def print_running_clients():
|
||||
processed.add(obj._topology_id)
|
||||
except ReferenceError:
|
||||
pass
|
||||
|
||||
|
||||
def _all_users(db):
|
||||
return {u["user"] for u in (db.command("usersInfo")).get("users", [])}
|
||||
|
||||
|
||||
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
|
||||
cmd = SON([("createUser", user)])
|
||||
# X509 doesn't use a password
|
||||
if pwd:
|
||||
cmd["pwd"] = pwd
|
||||
cmd["roles"] = roles or ["root"]
|
||||
cmd.update(**kwargs)
|
||||
return authdb.command(cmd)
|
||||
|
||||
|
||||
def connected(client):
|
||||
"""Convenience to wait for a newly-constructed client to connect."""
|
||||
with warnings.catch_warnings():
|
||||
# Ignore warning that ping is always routed to primary even
|
||||
# if client's read preference isn't PRIMARY.
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
client.admin.command("ping") # Force connection.
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def drop_collections(db: Database):
|
||||
# Drop all non-system collections in this database.
|
||||
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
|
||||
db.drop_collection(coll)
|
||||
|
||||
|
||||
def remove_all_users(db: Database):
|
||||
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import gc
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -39,8 +40,6 @@ from test.helpers import (
|
||||
TEST_SERVERLESS,
|
||||
TLS_OPTIONS,
|
||||
SystemCertsPatcher,
|
||||
_all_users,
|
||||
_create_user,
|
||||
client_knobs,
|
||||
db_pwd,
|
||||
db_user,
|
||||
@ -62,9 +61,9 @@ try:
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from functools import wraps
|
||||
from functools import partial, wraps
|
||||
from test.version import Version
|
||||
from typing import Any, Callable, Dict, Generator
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
@ -184,7 +183,7 @@ class AsyncClientContext:
|
||||
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
|
||||
return None
|
||||
finally:
|
||||
await client.close()
|
||||
await client.aclose()
|
||||
|
||||
async def _init_client(self):
|
||||
self.client = await self._connect(host, port)
|
||||
@ -229,7 +228,7 @@ class AsyncClientContext:
|
||||
if not self.serverless and not IS_SRV:
|
||||
# See if db_user already exists.
|
||||
if not await self._check_user_provided():
|
||||
_create_user(self.client.admin, db_user, db_pwd)
|
||||
await _create_user(self.client.admin, db_user, db_pwd)
|
||||
|
||||
self.client = await self._connect(
|
||||
host,
|
||||
@ -304,7 +303,7 @@ class AsyncClientContext:
|
||||
params = self.cmd_line["parsed"].get("setParameter", {})
|
||||
if params.get("enableTestCommands") == "1":
|
||||
self.test_commands_enabled = True
|
||||
self.has_ipv6 = self._server_started_with_ipv6()
|
||||
self.has_ipv6 = await self._server_started_with_ipv6()
|
||||
|
||||
self.is_mongos = (await self.hello).get("msg") == "isdbgrid"
|
||||
if self.is_mongos:
|
||||
@ -390,7 +389,7 @@ class AsyncClientContext:
|
||||
)
|
||||
|
||||
try:
|
||||
return db_user in _all_users(client.admin)
|
||||
return db_user in await _all_users(client.admin)
|
||||
except pymongo.errors.OperationFailure as e:
|
||||
assert e.details is not None
|
||||
msg = e.details.get("errmsg", "")
|
||||
@ -400,7 +399,7 @@ class AsyncClientContext:
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
await client.close()
|
||||
await client.aclose()
|
||||
|
||||
def _server_started_with_auth(self):
|
||||
# MongoDB >= 2.0
|
||||
@ -482,9 +481,9 @@ class AsyncClientContext:
|
||||
return decorate
|
||||
return make_wrapper(func)
|
||||
|
||||
def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
|
||||
async def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
|
||||
kwargs["writeConcern"] = {"w": self.w}
|
||||
return _create_user(self.client[dbname], user, pwd, roles, **kwargs)
|
||||
return await _create_user(self.client[dbname], user, pwd, roles, **kwargs)
|
||||
|
||||
async def drop_user(self, dbname, user):
|
||||
await self.client[dbname].command("dropUser", user, writeConcern={"w": self.w})
|
||||
@ -814,6 +813,12 @@ class AsyncClientContext:
|
||||
func=func,
|
||||
)
|
||||
|
||||
def require_sync(self, func):
|
||||
"""Run a test only if using the synchronous API."""
|
||||
return self._require(
|
||||
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
|
||||
)
|
||||
|
||||
def mongos_seeds(self):
|
||||
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
||||
|
||||
@ -921,6 +926,32 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(proc.exitcode, 0)
|
||||
|
||||
|
||||
class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await cls._setup_class()
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls._tearDown_class()
|
||||
|
||||
|
||||
class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
||||
|
||||
@ -935,6 +966,13 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
@ -949,6 +987,10 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
else:
|
||||
cls.credentials = {}
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
async def cleanup_colls(self, *collections):
|
||||
"""Cleanup collections faster than drop_collection."""
|
||||
for c in collections:
|
||||
@ -961,7 +1003,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
self.addCleanup(patcher.disable)
|
||||
|
||||
|
||||
class AsyncMockClientTest(unittest.TestCase):
|
||||
class AsyncMockClientTest(AsyncUnitTest):
|
||||
"""Base class for TestCases that use MockClient.
|
||||
|
||||
This class is *not* an IntegrationTest: if properly written, MockClient
|
||||
@ -974,8 +1016,26 @@ class AsyncMockClientTest(unittest.TestCase):
|
||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||
# with loadBalanced=True.
|
||||
@classmethod
|
||||
@async_client_context.require_no_load_balancer
|
||||
def setUpClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._setup_class()
|
||||
else:
|
||||
asyncio.run(cls._setup_class())
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if _IS_SYNC:
|
||||
cls._tearDown_class()
|
||||
else:
|
||||
asyncio.run(cls._tearDown_class())
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_no_load_balancer
|
||||
async def _setup_class(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
@ -1015,7 +1075,7 @@ async def async_teardown():
|
||||
await c.drop_database("pymongo_test2")
|
||||
await c.drop_database("pymongo_test_mike")
|
||||
await c.drop_database("pymongo_test_bernie")
|
||||
await c.close()
|
||||
await c.aclose()
|
||||
|
||||
print_running_clients()
|
||||
|
||||
@ -1053,3 +1113,38 @@ def print_running_clients():
|
||||
processed.add(obj._topology_id)
|
||||
except ReferenceError:
|
||||
pass
|
||||
|
||||
|
||||
async def _all_users(db):
|
||||
return {u["user"] for u in (await db.command("usersInfo")).get("users", [])}
|
||||
|
||||
|
||||
async def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
|
||||
cmd = SON([("createUser", user)])
|
||||
# X509 doesn't use a password
|
||||
if pwd:
|
||||
cmd["pwd"] = pwd
|
||||
cmd["roles"] = roles or ["root"]
|
||||
cmd.update(**kwargs)
|
||||
return await authdb.command(cmd)
|
||||
|
||||
|
||||
async def connected(client):
|
||||
"""Convenience to wait for a newly-constructed client to connect."""
|
||||
with warnings.catch_warnings():
|
||||
# Ignore warning that ping is always routed to primary even
|
||||
# if client's read preference isn't PRIMARY.
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
await client.admin.command("ping") # Force connection.
|
||||
|
||||
return client
|
||||
|
||||
|
||||
async def drop_collections(db: AsyncDatabase):
|
||||
# Drop all non-system collections in this database.
|
||||
for coll in await db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
|
||||
await db.drop_collection(coll)
|
||||
|
||||
|
||||
async def remove_all_users(db: AsyncDatabase):
|
||||
await db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": async_client_context.w})
|
||||
|
||||
252
test/asynchronous/pymongo_mocks.py
Normal file
252
test/asynchronous/pymongo_mocks.py
Normal file
@ -0,0 +1,252 @@
|
||||
# Copyright 2013-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tools for mocking parts of PyMongo to test other parts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import weakref
|
||||
from functools import partial
|
||||
from test import client_context
|
||||
from test.asynchronous import async_client_context
|
||||
|
||||
from pymongo import AsyncMongoClient, common
|
||||
from pymongo.asynchronous.monitor import Monitor
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
from pymongo.errors import AutoReconnect, NetworkTimeout
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.server_description import ServerDescription
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class MockPool(Pool):
|
||||
def __init__(self, client, pair, *args, **kwargs):
|
||||
# MockPool gets a 'client' arg, regular pools don't. Weakref it to
|
||||
# avoid cycle with __del__, causing ResourceWarnings in Python 3.3.
|
||||
self.client = weakref.proxy(client)
|
||||
self.mock_host, self.mock_port = pair
|
||||
|
||||
# Actually connect to the default server.
|
||||
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def checkout(self, handler=None):
|
||||
client = self.client
|
||||
host_and_port = f"{self.mock_host}:{self.mock_port}"
|
||||
if host_and_port in client.mock_down_hosts:
|
||||
raise AutoReconnect("mock error")
|
||||
|
||||
assert host_and_port in (
|
||||
client.mock_standalones + client.mock_members + client.mock_mongoses
|
||||
), "bad host: %s" % host_and_port
|
||||
|
||||
async with Pool.checkout(self, handler) as conn:
|
||||
conn.mock_host = self.mock_host
|
||||
conn.mock_port = self.mock_port
|
||||
yield conn
|
||||
|
||||
|
||||
class DummyMonitor:
|
||||
def __init__(self, server_description, topology, pool, topology_settings):
|
||||
self._server_description = server_description
|
||||
self.opened = False
|
||||
|
||||
def cancel_check(self):
|
||||
pass
|
||||
|
||||
def join(self):
|
||||
pass
|
||||
|
||||
def open(self):
|
||||
self.opened = True
|
||||
|
||||
def request_check(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.opened = False
|
||||
|
||||
|
||||
class AsyncMockMonitor(Monitor):
|
||||
def __init__(self, client, server_description, topology, pool, topology_settings):
|
||||
# MockMonitor gets a 'client' arg, regular monitors don't. Weakref it
|
||||
# to avoid cycles.
|
||||
self.client = weakref.proxy(client)
|
||||
Monitor.__init__(self, server_description, topology, pool, topology_settings)
|
||||
|
||||
async def _check_once(self):
|
||||
client = self.client
|
||||
address = self._server_description.address
|
||||
response, rtt = client.mock_hello("%s:%d" % address) # type: ignore[str-format]
|
||||
return ServerDescription(address, Hello(response), rtt)
|
||||
|
||||
|
||||
class AsyncMockClient(AsyncMongoClient):
|
||||
def __init__(
|
||||
self,
|
||||
standalones,
|
||||
members,
|
||||
mongoses,
|
||||
hello_hosts=None,
|
||||
arbiters=None,
|
||||
down_hosts=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""An AsyncMongoClient connected to the default server, with a mock topology.
|
||||
|
||||
standalones, members, mongoses, arbiters, and down_hosts determine the
|
||||
configuration of the topology. They are formatted like ['a:1', 'b:2'].
|
||||
hello_hosts provides an alternative host list for the server's
|
||||
mocked hello response; see test_connect_with_internal_ips.
|
||||
"""
|
||||
self.mock_standalones = standalones[:]
|
||||
self.mock_members = members[:]
|
||||
|
||||
if self.mock_members:
|
||||
self.mock_primary = self.mock_members[0]
|
||||
else:
|
||||
self.mock_primary = None
|
||||
|
||||
# Hosts that should be considered an arbiter.
|
||||
self.mock_arbiters = arbiters[:] if arbiters else []
|
||||
|
||||
if hello_hosts is not None:
|
||||
self.mock_hello_hosts = hello_hosts
|
||||
else:
|
||||
self.mock_hello_hosts = members[:]
|
||||
|
||||
self.mock_mongoses = mongoses[:]
|
||||
|
||||
# Hosts that should raise socket errors.
|
||||
self.mock_down_hosts = down_hosts[:] if down_hosts else []
|
||||
|
||||
# Hostname -> (min wire version, max wire version)
|
||||
self.mock_wire_versions = {}
|
||||
|
||||
# Hostname -> max write batch size
|
||||
self.mock_max_write_batch_sizes = {}
|
||||
|
||||
# Hostname -> round trip time
|
||||
self.mock_rtts = {}
|
||||
|
||||
kwargs["_pool_class"] = partial(MockPool, self)
|
||||
kwargs["_monitor_class"] = partial(AsyncMockMonitor, self)
|
||||
|
||||
client_options = async_client_context.default_client_options.copy()
|
||||
client_options.update(kwargs)
|
||||
|
||||
super().__init__(*args, **client_options)
|
||||
|
||||
@classmethod
|
||||
async def get_async_mock_client(
|
||||
cls,
|
||||
standalones,
|
||||
members,
|
||||
mongoses,
|
||||
hello_hosts=None,
|
||||
arbiters=None,
|
||||
down_hosts=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
c = AsyncMockClient(
|
||||
standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs
|
||||
)
|
||||
|
||||
await c.aconnect()
|
||||
return c
|
||||
|
||||
def kill_host(self, host):
|
||||
"""Host is like 'a:1'."""
|
||||
self.mock_down_hosts.append(host)
|
||||
|
||||
def revive_host(self, host):
|
||||
"""Host is like 'a:1'."""
|
||||
self.mock_down_hosts.remove(host)
|
||||
|
||||
def set_wire_version_range(self, host, min_version, max_version):
|
||||
self.mock_wire_versions[host] = (min_version, max_version)
|
||||
|
||||
def set_max_write_batch_size(self, host, size):
|
||||
self.mock_max_write_batch_sizes[host] = size
|
||||
|
||||
def mock_hello(self, host):
|
||||
"""Return mock hello response (a dict) and round trip time."""
|
||||
if host in self.mock_wire_versions:
|
||||
min_wire_version, max_wire_version = self.mock_wire_versions[host]
|
||||
else:
|
||||
min_wire_version = common.MIN_SUPPORTED_WIRE_VERSION
|
||||
max_wire_version = common.MAX_SUPPORTED_WIRE_VERSION
|
||||
|
||||
max_write_batch_size = self.mock_max_write_batch_sizes.get(
|
||||
host, common.MAX_WRITE_BATCH_SIZE
|
||||
)
|
||||
|
||||
rtt = self.mock_rtts.get(host, 0)
|
||||
|
||||
# host is like 'a:1'.
|
||||
if host in self.mock_down_hosts:
|
||||
raise NetworkTimeout("mock timeout")
|
||||
|
||||
elif host in self.mock_standalones:
|
||||
response = {
|
||||
"ok": 1,
|
||||
HelloCompat.LEGACY_CMD: True,
|
||||
"minWireVersion": min_wire_version,
|
||||
"maxWireVersion": max_wire_version,
|
||||
"maxWriteBatchSize": max_write_batch_size,
|
||||
}
|
||||
elif host in self.mock_members:
|
||||
primary = host == self.mock_primary
|
||||
|
||||
# Simulate a replica set member.
|
||||
response = {
|
||||
"ok": 1,
|
||||
HelloCompat.LEGACY_CMD: primary,
|
||||
"secondary": not primary,
|
||||
"setName": "rs",
|
||||
"hosts": self.mock_hello_hosts,
|
||||
"minWireVersion": min_wire_version,
|
||||
"maxWireVersion": max_wire_version,
|
||||
"maxWriteBatchSize": max_write_batch_size,
|
||||
}
|
||||
|
||||
if self.mock_primary:
|
||||
response["primary"] = self.mock_primary
|
||||
|
||||
if host in self.mock_arbiters:
|
||||
response["arbiterOnly"] = True
|
||||
response["secondary"] = False
|
||||
elif host in self.mock_mongoses:
|
||||
response = {
|
||||
"ok": 1,
|
||||
HelloCompat.LEGACY_CMD: True,
|
||||
"minWireVersion": min_wire_version,
|
||||
"maxWireVersion": max_wire_version,
|
||||
"msg": "isdbgrid",
|
||||
"maxWriteBatchSize": max_write_batch_size,
|
||||
}
|
||||
else:
|
||||
# In test_internal_ips(), we try to connect to a host listed
|
||||
# in hello['hosts'] but not publicly accessible.
|
||||
raise AutoReconnect("Unknown host: %s" % host)
|
||||
|
||||
return response, rtt
|
||||
|
||||
def _process_periodic_tasks(self):
|
||||
# Avoid the background thread causing races, e.g. a surprising
|
||||
# reconnect while we're trying to test a disconnected client.
|
||||
pass
|
||||
2500
test/asynchronous/test_client.py
Normal file
2500
test/asynchronous/test_client.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1825,7 +1825,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
await self.db.test.insert_many([{"i": i} for i in range(150)])
|
||||
|
||||
client = await async_rs_or_single_client(maxPoolSize=1)
|
||||
self.addAsyncCleanup(client.close)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
pool = await async_get_pool(client)
|
||||
|
||||
# Make sure the socket is returned after exhaustion.
|
||||
|
||||
@ -236,7 +236,7 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
async def test_check_exists(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.close)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
db = client[self.db.name]
|
||||
await db.drop_collection("unique")
|
||||
await db.create_collection("unique", check_exists=True)
|
||||
|
||||
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared constants and helper method for pymongo, bson, and gridfs test suites."""
|
||||
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
@ -20,14 +20,15 @@ import weakref
|
||||
from functools import partial
|
||||
from test import client_context
|
||||
|
||||
from pymongo import common
|
||||
from pymongo import MongoClient, common
|
||||
from pymongo.errors import AutoReconnect, NetworkTimeout
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.monitor import Monitor
|
||||
from pymongo.synchronous.pool import Pool
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class MockPool(Pool):
|
||||
def __init__(self, client, pair, *args, **kwargs):
|
||||
@ -77,7 +78,7 @@ class DummyMonitor:
|
||||
self.opened = False
|
||||
|
||||
|
||||
class MockMonitor(Monitor):
|
||||
class SyncMockMonitor(Monitor):
|
||||
def __init__(self, client, server_description, topology, pool, topology_settings):
|
||||
# MockMonitor gets a 'client' arg, regular monitors don't. Weakref it
|
||||
# to avoid cycles.
|
||||
@ -141,13 +142,32 @@ class MockClient(MongoClient):
|
||||
self.mock_rtts = {}
|
||||
|
||||
kwargs["_pool_class"] = partial(MockPool, self)
|
||||
kwargs["_monitor_class"] = partial(MockMonitor, self)
|
||||
kwargs["_monitor_class"] = partial(SyncMockMonitor, self)
|
||||
|
||||
client_options = client_context.default_client_options.copy()
|
||||
client_options.update(kwargs)
|
||||
|
||||
super().__init__(*args, **client_options)
|
||||
|
||||
@classmethod
|
||||
def get_mock_client(
|
||||
cls,
|
||||
standalones,
|
||||
members,
|
||||
mongoses,
|
||||
hello_hosts=None,
|
||||
arbiters=None,
|
||||
down_hosts=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
c = MockClient(
|
||||
standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs
|
||||
)
|
||||
|
||||
c._connect()
|
||||
return c
|
||||
|
||||
def kill_host(self, host):
|
||||
"""Host is like 'a:1'."""
|
||||
self.mock_down_hosts.append(host)
|
||||
|
||||
@ -23,9 +23,8 @@ from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test import IntegrationTest, client_context, remove_all_users, unittest
|
||||
from test.utils import (
|
||||
remove_all_users,
|
||||
rs_or_single_client_noauth,
|
||||
single_client,
|
||||
wait_until,
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import _thread as thread
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import datetime
|
||||
@ -45,10 +46,13 @@ from test import (
|
||||
IntegrationTest,
|
||||
MockClientTest,
|
||||
SkipTest,
|
||||
UnitTest,
|
||||
client_context,
|
||||
client_knobs,
|
||||
connected,
|
||||
db_pwd,
|
||||
db_user,
|
||||
remove_all_users,
|
||||
unittest,
|
||||
)
|
||||
from test.pymongo_mocks import MockClient
|
||||
@ -57,14 +61,12 @@ from test.utils import (
|
||||
CMAPListener,
|
||||
FunctionCallRecorder,
|
||||
assertRaisesExactly,
|
||||
connected,
|
||||
delay,
|
||||
get_pool,
|
||||
gevent_monkey_patched,
|
||||
is_greenthread_patched,
|
||||
lazy_client_trial,
|
||||
one,
|
||||
remove_all_users,
|
||||
rs_client,
|
||||
rs_or_single_client,
|
||||
rs_or_single_client_noauth,
|
||||
@ -109,6 +111,7 @@ from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.cursor import Cursor, CursorType
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.pool import (
|
||||
Connection,
|
||||
@ -118,18 +121,20 @@ from pymongo.synchronous.topology import _ErrorContext
|
||||
from pymongo.topology_description import TopologyDescription
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
class ClientUnitTest(unittest.TestCase):
|
||||
|
||||
class ClientUnitTest(UnitTest):
|
||||
"""MongoClient tests that don't require a server."""
|
||||
|
||||
client: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def _setup_class(cls):
|
||||
cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.client.close()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -254,7 +259,8 @@ class ClientUnitTest(unittest.TestCase):
|
||||
|
||||
def test_get_default_database(self):
|
||||
c = rs_or_single_client(
|
||||
"mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False
|
||||
"mongodb://%s:%d/foo" % (client_context.host, client_context.port),
|
||||
connect=False,
|
||||
)
|
||||
self.assertEqual(Database(c, "foo"), c.get_default_database())
|
||||
# Test that default doesn't override the URI value.
|
||||
@ -269,39 +275,49 @@ class ClientUnitTest(unittest.TestCase):
|
||||
self.assertEqual(write_concern, db.write_concern)
|
||||
|
||||
c = rs_or_single_client(
|
||||
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
|
||||
"mongodb://%s:%d/" % (client_context.host, client_context.port),
|
||||
connect=False,
|
||||
)
|
||||
self.assertEqual(Database(c, "foo"), c.get_default_database("foo"))
|
||||
|
||||
def test_get_default_database_error(self):
|
||||
# URI with no database.
|
||||
c = rs_or_single_client(
|
||||
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
|
||||
"mongodb://%s:%d/" % (client_context.host, client_context.port),
|
||||
connect=False,
|
||||
)
|
||||
self.assertRaises(ConfigurationError, c.get_default_database)
|
||||
|
||||
def test_get_default_database_with_authsource(self):
|
||||
# Ensure we distinguish database name from authSource.
|
||||
uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port)
|
||||
uri = "mongodb://%s:%d/foo?authSource=src" % (
|
||||
client_context.host,
|
||||
client_context.port,
|
||||
)
|
||||
c = rs_or_single_client(uri, connect=False)
|
||||
self.assertEqual(Database(c, "foo"), c.get_default_database())
|
||||
|
||||
def test_get_database_default(self):
|
||||
c = rs_or_single_client(
|
||||
"mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False
|
||||
"mongodb://%s:%d/foo" % (client_context.host, client_context.port),
|
||||
connect=False,
|
||||
)
|
||||
self.assertEqual(Database(c, "foo"), c.get_database())
|
||||
|
||||
def test_get_database_default_error(self):
|
||||
# URI with no database.
|
||||
c = rs_or_single_client(
|
||||
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
|
||||
"mongodb://%s:%d/" % (client_context.host, client_context.port),
|
||||
connect=False,
|
||||
)
|
||||
self.assertRaises(ConfigurationError, c.get_database)
|
||||
|
||||
def test_get_database_default_with_authsource(self):
|
||||
# Ensure we distinguish database name from authSource.
|
||||
uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port)
|
||||
uri = "mongodb://%s:%d/foo?authSource=src" % (
|
||||
client_context.host,
|
||||
client_context.port,
|
||||
)
|
||||
c = rs_or_single_client(uri, connect=False)
|
||||
self.assertEqual(Database(c, "foo"), c.get_database())
|
||||
|
||||
@ -634,7 +650,7 @@ class TestClient(IntegrationTest):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
# Assert reaper doesn't remove connections when maxIdleTimeMS not set
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
@ -645,7 +661,7 @@ class TestClient(IntegrationTest):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
# Assert reaper removes idle socket and replaces it with a new one
|
||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket, two
|
||||
@ -659,7 +675,7 @@ class TestClient(IntegrationTest):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
# Assert reaper respects maxPoolSize when adding new connections.
|
||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
# When the reaper runs at the same time as the get_socket,
|
||||
@ -673,7 +689,7 @@ class TestClient(IntegrationTest):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
# Assert reaper has removed idle socket and NOT replaced it
|
||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
with server._pool.checkout() as conn_one:
|
||||
pass
|
||||
# Assert that the pool does not close connections prematurely.
|
||||
@ -690,12 +706,12 @@ class TestClient(IntegrationTest):
|
||||
def test_min_pool_size(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
self.assertEqual(0, len(server._pool.conns))
|
||||
|
||||
# Assert that pool started up at minPoolSize
|
||||
client = rs_or_single_client(minPoolSize=10)
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"pool initialized with 10 connections",
|
||||
@ -714,7 +730,7 @@ class TestClient(IntegrationTest):
|
||||
# Use high frequency to test _get_socket_no_auth.
|
||||
with client_knobs(kill_cursor_frequency=99999999):
|
||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
@ -728,7 +744,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
# Test that connections are reused if maxIdleTimeMS is not set.
|
||||
client = rs_or_single_client()
|
||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
||||
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||
with server._pool.checkout() as conn:
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
@ -793,12 +809,14 @@ class TestClient(IntegrationTest):
|
||||
|
||||
bad_host = "somedomainthatdoesntexist.org"
|
||||
c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
|
||||
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
c.pymongo_test.test.find_one()
|
||||
|
||||
def test_init_disconnected_with_auth(self):
|
||||
uri = "mongodb://user:pass@somedomainthatdoesntexist"
|
||||
c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
|
||||
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
c.pymongo_test.test.find_one()
|
||||
|
||||
def test_equality(self):
|
||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||
@ -816,7 +834,8 @@ class TestClient(IntegrationTest):
|
||||
self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False))
|
||||
# Same seeds but out of order still compares equal:
|
||||
self.assertEqual(
|
||||
MongoClient(["a", "b", "c"], connect=False), MongoClient(["c", "a", "b"], connect=False)
|
||||
MongoClient(["a", "b", "c"], connect=False),
|
||||
MongoClient(["c", "a", "b"], connect=False),
|
||||
)
|
||||
|
||||
def test_hashable(self):
|
||||
@ -830,9 +849,10 @@ class TestClient(IntegrationTest):
|
||||
|
||||
def test_host_w_port(self):
|
||||
with self.assertRaises(ValueError):
|
||||
host = client_context.host
|
||||
connected(
|
||||
MongoClient(
|
||||
f"{client_context.host}:1234567",
|
||||
f"{host}:1234567",
|
||||
connectTimeoutMS=1,
|
||||
serverSelectionTimeoutMS=10,
|
||||
)
|
||||
@ -883,10 +903,10 @@ class TestClient(IntegrationTest):
|
||||
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
|
||||
|
||||
def test_list_databases(self):
|
||||
cmd_docs = self.client.admin.command("listDatabases")["databases"]
|
||||
cmd_docs = (self.client.admin.command("listDatabases"))["databases"]
|
||||
cursor = self.client.list_databases()
|
||||
self.assertIsInstance(cursor, CommandCursor)
|
||||
helper_docs = list(cursor)
|
||||
helper_docs = cursor.to_list()
|
||||
self.assertTrue(len(helper_docs) > 0)
|
||||
self.assertEqual(len(helper_docs), len(cmd_docs))
|
||||
# PYTHON-3529 Some fields may change between calls, just compare names.
|
||||
@ -900,7 +920,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
self.client.pymongo_test.test.insert_one({})
|
||||
cursor = self.client.list_databases(filter={"name": "admin"})
|
||||
docs = list(cursor)
|
||||
docs = cursor.to_list()
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(docs[0]["name"], "admin")
|
||||
|
||||
@ -911,7 +931,7 @@ class TestClient(IntegrationTest):
|
||||
def test_list_database_names(self):
|
||||
self.client.pymongo_test.test.insert_one({"dummy": "object"})
|
||||
self.client.pymongo_test_mike.test.insert_one({"dummy": "object"})
|
||||
cmd_docs = self.client.admin.command("listDatabases")["databases"]
|
||||
cmd_docs = (self.client.admin.command("listDatabases"))["databases"]
|
||||
cmd_names = [doc["name"] for doc in cmd_docs]
|
||||
|
||||
db_names = self.client.list_database_names()
|
||||
@ -920,8 +940,10 @@ class TestClient(IntegrationTest):
|
||||
self.assertEqual(db_names, cmd_names)
|
||||
|
||||
def test_drop_database(self):
|
||||
self.assertRaises(TypeError, self.client.drop_database, 5)
|
||||
self.assertRaises(TypeError, self.client.drop_database, None)
|
||||
with self.assertRaises(TypeError):
|
||||
self.client.drop_database(5) # type: ignore[arg-type]
|
||||
with self.assertRaises(TypeError):
|
||||
self.client.drop_database(None) # type: ignore[arg-type]
|
||||
|
||||
self.client.pymongo_test.test.insert_one({"dummy": "object"})
|
||||
self.client.pymongo_test2.test.insert_one({"dummy": "object"})
|
||||
@ -944,7 +966,8 @@ class TestClient(IntegrationTest):
|
||||
test_client = rs_or_single_client()
|
||||
coll = test_client.pymongo_test.bar
|
||||
test_client.close()
|
||||
self.assertRaises(InvalidOperation, coll.count_documents, {})
|
||||
with self.assertRaises(InvalidOperation):
|
||||
coll.count_documents({})
|
||||
|
||||
def test_close_kills_cursors(self):
|
||||
if sys.platform.startswith("java"):
|
||||
@ -961,7 +984,7 @@ class TestClient(IntegrationTest):
|
||||
coll.insert_many([{"i": i} for i in range(docs_inserted)])
|
||||
|
||||
# Open a cursor and leave it open on the server.
|
||||
cursor = coll.find().batch_size(10)
|
||||
cursor = (coll.find()).batch_size(10)
|
||||
self.assertTrue(bool(next(cursor)))
|
||||
self.assertLess(cursor.retrieved, docs_inserted)
|
||||
|
||||
@ -992,7 +1015,8 @@ class TestClient(IntegrationTest):
|
||||
self.assertTrue(client._kill_cursors_executor._stopped)
|
||||
|
||||
# Reusing the closed client should raise an InvalidOperation error.
|
||||
self.assertRaises(InvalidOperation, client.admin.command, "ping")
|
||||
with self.assertRaises(InvalidOperation):
|
||||
client.admin.command("ping")
|
||||
# Thread is still stopped.
|
||||
self.assertTrue(client._kill_cursors_executor._stopped)
|
||||
|
||||
@ -1065,8 +1089,10 @@ class TestClient(IntegrationTest):
|
||||
)
|
||||
|
||||
# Auth with lazy connection.
|
||||
rs_or_single_client_noauth(
|
||||
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False
|
||||
(
|
||||
rs_or_single_client_noauth(
|
||||
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False
|
||||
)
|
||||
).pymongo_test.test.find_one()
|
||||
|
||||
# Wrong password.
|
||||
@ -1074,7 +1100,8 @@ class TestClient(IntegrationTest):
|
||||
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False
|
||||
)
|
||||
|
||||
self.assertRaises(OperationFailure, bad_client.pymongo_test.test.find_one)
|
||||
with self.assertRaises(OperationFailure):
|
||||
bad_client.pymongo_test.test.find_one()
|
||||
|
||||
@client_context.require_auth
|
||||
def test_username_and_password(self):
|
||||
@ -1093,13 +1120,14 @@ class TestClient(IntegrationTest):
|
||||
c.server_info()
|
||||
|
||||
with self.assertRaises(OperationFailure):
|
||||
rs_or_single_client_noauth(username="ad min", password="foo").server_info()
|
||||
(rs_or_single_client_noauth(username="ad min", password="foo")).server_info()
|
||||
|
||||
@client_context.require_auth
|
||||
@client_context.require_no_fips
|
||||
def test_lazy_auth_raises_operation_failure(self):
|
||||
host = client_context.host
|
||||
lazy_client = rs_or_single_client_noauth(
|
||||
f"mongodb://user:wrong@{client_context.host}/pymongo_test", connect=False
|
||||
f"mongodb://user:wrong@{host}/pymongo_test", connect=False
|
||||
)
|
||||
|
||||
assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one)
|
||||
@ -1125,11 +1153,10 @@ class TestClient(IntegrationTest):
|
||||
self.assertTrue(mongodb_socket in repr(client))
|
||||
|
||||
# Confirm it fails with a missing socket.
|
||||
self.assertRaises(
|
||||
ConnectionFailure,
|
||||
connected,
|
||||
MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100),
|
||||
)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
connected(
|
||||
MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100),
|
||||
)
|
||||
|
||||
def test_document_class(self):
|
||||
c = self.client
|
||||
@ -1154,27 +1181,30 @@ class TestClient(IntegrationTest):
|
||||
maxIdleTimeMS=10500,
|
||||
serverSelectionTimeoutMS=10500,
|
||||
)
|
||||
self.assertEqual(10.5, get_pool(client).opts.connect_timeout)
|
||||
self.assertEqual(10.5, get_pool(client).opts.socket_timeout)
|
||||
self.assertEqual(10.5, get_pool(client).opts.max_idle_time_seconds)
|
||||
self.assertEqual(10.5, (get_pool(client)).opts.connect_timeout)
|
||||
self.assertEqual(10.5, (get_pool(client)).opts.socket_timeout)
|
||||
self.assertEqual(10.5, (get_pool(client)).opts.max_idle_time_seconds)
|
||||
self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds)
|
||||
self.assertEqual(10.5, client.options.server_selection_timeout)
|
||||
|
||||
def test_socket_timeout_ms_validation(self):
|
||||
c = rs_or_single_client(socketTimeoutMS=10 * 1000)
|
||||
self.assertEqual(10, get_pool(c).opts.socket_timeout)
|
||||
self.assertEqual(10, (get_pool(c)).opts.socket_timeout)
|
||||
|
||||
c = connected(rs_or_single_client(socketTimeoutMS=None))
|
||||
self.assertEqual(None, get_pool(c).opts.socket_timeout)
|
||||
self.assertEqual(None, (get_pool(c)).opts.socket_timeout)
|
||||
|
||||
c = connected(rs_or_single_client(socketTimeoutMS=0))
|
||||
self.assertEqual(None, get_pool(c).opts.socket_timeout)
|
||||
self.assertEqual(None, (get_pool(c)).opts.socket_timeout)
|
||||
|
||||
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=-1)
|
||||
with self.assertRaises(ValueError):
|
||||
rs_or_single_client(socketTimeoutMS=-1)
|
||||
|
||||
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=1e10)
|
||||
with self.assertRaises(ValueError):
|
||||
rs_or_single_client(socketTimeoutMS=1e10)
|
||||
|
||||
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS="foo")
|
||||
with self.assertRaises(ValueError):
|
||||
rs_or_single_client(socketTimeoutMS="foo")
|
||||
|
||||
def test_socket_timeout(self):
|
||||
no_timeout = self.client
|
||||
@ -1189,11 +1219,12 @@ class TestClient(IntegrationTest):
|
||||
where_func = delay(timeout_sec + 1)
|
||||
|
||||
def get_x(db):
|
||||
doc = next(db.test.find().where(where_func))
|
||||
doc = next((db.test.find()).where(where_func))
|
||||
return doc["x"]
|
||||
|
||||
self.assertEqual(1, get_x(no_timeout.pymongo_test))
|
||||
self.assertRaises(NetworkTimeout, get_x, timeout.pymongo_test)
|
||||
with self.assertRaises(NetworkTimeout):
|
||||
get_x(timeout.pymongo_test)
|
||||
|
||||
def test_server_selection_timeout(self):
|
||||
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||
@ -1224,7 +1255,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
def test_waitQueueTimeoutMS(self):
|
||||
client = rs_or_single_client(waitQueueTimeoutMS=2000)
|
||||
self.assertEqual(get_pool(client).opts.wait_queue_timeout, 2)
|
||||
self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2)
|
||||
|
||||
def test_socketKeepAlive(self):
|
||||
pool = get_pool(self.client)
|
||||
@ -1244,11 +1275,11 @@ class TestClient(IntegrationTest):
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
aware.pymongo_test.test.insert_one({"x": now})
|
||||
|
||||
self.assertEqual(None, naive.pymongo_test.test.find_one()["x"].tzinfo)
|
||||
self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo)
|
||||
self.assertEqual(None, (naive.pymongo_test.test.find_one())["x"].tzinfo)
|
||||
self.assertEqual(utc, (aware.pymongo_test.test.find_one())["x"].tzinfo)
|
||||
self.assertEqual(
|
||||
aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None),
|
||||
naive.pymongo_test.test.find_one()["x"],
|
||||
(aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None),
|
||||
(naive.pymongo_test.test.find_one())["x"],
|
||||
)
|
||||
|
||||
@client_context.require_ipv6
|
||||
@ -1282,18 +1313,21 @@ class TestClient(IntegrationTest):
|
||||
|
||||
# The socket used for the previous commands has been returned to the
|
||||
# pool
|
||||
self.assertEqual(1, len(get_pool(client).conns))
|
||||
self.assertEqual(1, len((get_pool(client)).conns))
|
||||
|
||||
with contextlib.closing(client):
|
||||
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
client.pymongo_test.test.find_one()
|
||||
client = rs_or_single_client()
|
||||
with client as client:
|
||||
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
client.pymongo_test.test.find_one()
|
||||
# contextlib async support was added in Python 3.10
|
||||
if _IS_SYNC or sys.version_info >= (3, 10):
|
||||
with contextlib.closing(client):
|
||||
self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
client.pymongo_test.test.find_one()
|
||||
client = rs_or_single_client()
|
||||
with client as client:
|
||||
self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
client.pymongo_test.test.find_one()
|
||||
|
||||
@client_context.require_sync
|
||||
def test_interrupt_signal(self):
|
||||
if sys.platform.startswith("java"):
|
||||
# We can't figure out how to raise an exception on a thread that's
|
||||
@ -1344,7 +1378,7 @@ class TestClient(IntegrationTest):
|
||||
raised = False
|
||||
try:
|
||||
# Will be interrupted by a KeyboardInterrupt.
|
||||
next(db.foo.find({"$where": where}))
|
||||
next(db.foo.find({"$where": where})) # type: ignore[call-overload]
|
||||
except KeyboardInterrupt:
|
||||
raised = True
|
||||
|
||||
@ -1355,7 +1389,7 @@ class TestClient(IntegrationTest):
|
||||
# Raises AssertionError due to PYTHON-294 -- Mongo's response to
|
||||
# the previous find() is still waiting to be read on the socket,
|
||||
# so the request id's don't match.
|
||||
self.assertEqual({"_id": 1}, next(db.foo.find()))
|
||||
self.assertEqual({"_id": 1}, next(db.foo.find())) # type: ignore[call-overload]
|
||||
finally:
|
||||
if old_signal_handler:
|
||||
signal.signal(signal.SIGALRM, old_signal_handler)
|
||||
@ -1374,7 +1408,8 @@ class TestClient(IntegrationTest):
|
||||
old_conn = next(iter(pool.conns))
|
||||
client.pymongo_test.test.drop()
|
||||
client.pymongo_test.test.insert_one({"_id": "foo"})
|
||||
self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"})
|
||||
with self.assertRaises(OperationFailure):
|
||||
client.pymongo_test.test.insert_one({"_id": "foo"})
|
||||
|
||||
self.assertEqual(socket_count, len(pool.conns))
|
||||
new_con = next(iter(pool.conns))
|
||||
@ -1392,23 +1427,29 @@ class TestClient(IntegrationTest):
|
||||
client = rs_or_single_client(connect=False, w=0)
|
||||
self.addCleanup(client.close)
|
||||
client.test_lazy_connect_w0.test.insert_one({})
|
||||
wait_until(
|
||||
lambda: client.test_lazy_connect_w0.test.count_documents({}) == 1, "find one document"
|
||||
)
|
||||
|
||||
def predicate():
|
||||
return client.test_lazy_connect_w0.test.count_documents({}) == 1
|
||||
|
||||
wait_until(predicate, "find one document")
|
||||
|
||||
client = rs_or_single_client(connect=False, w=0)
|
||||
self.addCleanup(client.close)
|
||||
client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}})
|
||||
wait_until(
|
||||
lambda: client.test_lazy_connect_w0.test.find_one().get("x") == 1, "update one document"
|
||||
)
|
||||
|
||||
def predicate():
|
||||
return (client.test_lazy_connect_w0.test.find_one()).get("x") == 1
|
||||
|
||||
wait_until(predicate, "update one document")
|
||||
|
||||
client = rs_or_single_client(connect=False, w=0)
|
||||
self.addCleanup(client.close)
|
||||
client.test_lazy_connect_w0.test.delete_one({})
|
||||
wait_until(
|
||||
lambda: client.test_lazy_connect_w0.test.count_documents({}) == 0, "delete one document"
|
||||
)
|
||||
|
||||
def predicate():
|
||||
return client.test_lazy_connect_w0.test.count_documents({}) == 0
|
||||
|
||||
wait_until(predicate, "delete one document")
|
||||
|
||||
@client_context.require_no_mongos
|
||||
def test_exhaust_network_error(self):
|
||||
@ -1445,12 +1486,13 @@ class TestClient(IntegrationTest):
|
||||
|
||||
# Cause a network error on the actual socket.
|
||||
pool = get_pool(c)
|
||||
socket_info = one(pool.conns)
|
||||
socket_info.conn.close()
|
||||
conn = one(pool.conns)
|
||||
conn.conn.close()
|
||||
|
||||
# Connection.authenticate logs, but gets a socket.error. Should be
|
||||
# reraised as AutoReconnect.
|
||||
self.assertRaises(AutoReconnect, c.test.collection.find_one)
|
||||
with self.assertRaises(AutoReconnect):
|
||||
c.test.collection.find_one()
|
||||
|
||||
# No semaphore leak, the pool is allowed to make a new socket.
|
||||
c.test.collection.find_one()
|
||||
@ -1638,13 +1680,19 @@ class TestClient(IntegrationTest):
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
def run(self):
|
||||
def _run(self):
|
||||
while self.running:
|
||||
exc = AutoReconnect("mock pool error")
|
||||
ctx = _ErrorContext(exc, 0, pool.gen.get_overall(), False, None)
|
||||
client._topology.handle_error(pool.address, ctx)
|
||||
time.sleep(0.001)
|
||||
|
||||
def run(self):
|
||||
if _IS_SYNC:
|
||||
self._run()
|
||||
else:
|
||||
asyncio.run(self._run())
|
||||
|
||||
t = ResetPoolThread(pool)
|
||||
t.start()
|
||||
|
||||
@ -1755,7 +1803,7 @@ class TestClient(IntegrationTest):
|
||||
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
|
||||
):
|
||||
assert client.address is not None
|
||||
expected = "{}:{}: ".format(*client.address)
|
||||
expected = "{}:{}: ".format(*(client.address))
|
||||
with self.assertRaisesRegex(AutoReconnect, expected):
|
||||
client.pymongo_test.test.find_one({})
|
||||
|
||||
@ -1814,6 +1862,7 @@ class TestClient(IntegrationTest):
|
||||
"loadBalanced clients do not run SDAM",
|
||||
)
|
||||
@unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP")
|
||||
@client_context.require_sync
|
||||
def test_sigstop_sigcont(self):
|
||||
test_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
script = os.path.join(test_dir, "sigstop_sigcont.py")
|
||||
@ -1961,7 +2010,8 @@ class TestExhaustCursor(IntegrationTest):
|
||||
SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST
|
||||
)
|
||||
|
||||
self.assertRaises(OperationFailure, cursor.next)
|
||||
with self.assertRaises(OperationFailure):
|
||||
cursor.next()
|
||||
self.assertFalse(conn.closed)
|
||||
|
||||
# The socket was checked in and the semaphore was decremented.
|
||||
@ -1998,7 +2048,8 @@ class TestExhaustCursor(IntegrationTest):
|
||||
return message._OpReply.unpack(msg)
|
||||
|
||||
conn.receive_message = receive_message
|
||||
self.assertRaises(OperationFailure, list, cursor)
|
||||
with self.assertRaises(OperationFailure):
|
||||
cursor.to_list()
|
||||
# Unpatch the instance.
|
||||
del conn.receive_message
|
||||
|
||||
@ -2019,7 +2070,8 @@ class TestExhaustCursor(IntegrationTest):
|
||||
conn.conn.close()
|
||||
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
self.assertRaises(ConnectionFailure, cursor.next)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
cursor.next()
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
# The socket was closed and the semaphore was decremented.
|
||||
@ -2046,7 +2098,8 @@ class TestExhaustCursor(IntegrationTest):
|
||||
conn.conn.close()
|
||||
|
||||
# A getmore fails.
|
||||
self.assertRaises(ConnectionFailure, list, cursor)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
cursor.to_list()
|
||||
self.assertTrue(conn.closed)
|
||||
|
||||
wait_until(
|
||||
@ -2057,6 +2110,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
self.assertNotIn(conn, pool.conns)
|
||||
self.assertEqual(0, pool.requests)
|
||||
|
||||
@client_context.require_sync
|
||||
def test_gevent_task(self):
|
||||
if not gevent_monkey_patched():
|
||||
raise SkipTest("Must be running monkey patched by gevent")
|
||||
@ -2070,6 +2124,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
task.kill()
|
||||
self.assertTrue(task.dead)
|
||||
|
||||
@client_context.require_sync
|
||||
def test_gevent_timeout(self):
|
||||
if not gevent_monkey_patched():
|
||||
raise SkipTest("Must be running monkey patched by gevent")
|
||||
@ -2101,6 +2156,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
self.assertIsNone(tt.get())
|
||||
self.assertIsNone(ct.get())
|
||||
|
||||
@client_context.require_sync
|
||||
def test_gevent_timeout_when_creating_connection(self):
|
||||
if not gevent_monkey_patched():
|
||||
raise SkipTest("Must be running monkey patched by gevent")
|
||||
@ -2145,6 +2201,7 @@ class TestClientLazyConnect(IntegrationTest):
|
||||
def _get_client(self):
|
||||
return rs_or_single_client(connect=False)
|
||||
|
||||
@client_context.require_sync
|
||||
def test_insert_one(self):
|
||||
def reset(collection):
|
||||
collection.drop()
|
||||
@ -2157,6 +2214,7 @@ class TestClientLazyConnect(IntegrationTest):
|
||||
|
||||
lazy_client_trial(reset, insert_one, test, self._get_client)
|
||||
|
||||
@client_context.require_sync
|
||||
def test_update_one(self):
|
||||
def reset(collection):
|
||||
collection.drop()
|
||||
@ -2171,6 +2229,7 @@ class TestClientLazyConnect(IntegrationTest):
|
||||
|
||||
lazy_client_trial(reset, update_one, test, self._get_client)
|
||||
|
||||
@client_context.require_sync
|
||||
def test_delete_one(self):
|
||||
def reset(collection):
|
||||
collection.drop()
|
||||
@ -2184,6 +2243,7 @@ class TestClientLazyConnect(IntegrationTest):
|
||||
|
||||
lazy_client_trial(reset, delete_one, test, self._get_client)
|
||||
|
||||
@client_context.require_sync
|
||||
def test_find_one(self):
|
||||
results: list = []
|
||||
|
||||
@ -2203,7 +2263,7 @@ class TestClientLazyConnect(IntegrationTest):
|
||||
|
||||
class TestMongoClientFailover(MockClientTest):
|
||||
def test_discover_primary(self):
|
||||
c = MockClient(
|
||||
c = MockClient.get_mock_client(
|
||||
standalones=[],
|
||||
members=["a:1", "b:2", "c:3"],
|
||||
mongoses=[],
|
||||
@ -2219,13 +2279,17 @@ class TestMongoClientFailover(MockClientTest):
|
||||
# Fail over.
|
||||
c.kill_host("a:1")
|
||||
c.mock_primary = "b:2"
|
||||
wait_until(lambda: c.address == ("b", 2), "wait for server address to be updated")
|
||||
|
||||
def predicate():
|
||||
return (c.address) == ("b", 2)
|
||||
|
||||
wait_until(predicate, "wait for server address to be updated")
|
||||
# a:1 not longer in nodes.
|
||||
self.assertLess(len(c.nodes), 3)
|
||||
|
||||
def test_reconnect(self):
|
||||
# Verify the node list isn't forgotten during a network failure.
|
||||
c = MockClient(
|
||||
c = MockClient.get_mock_client(
|
||||
standalones=[],
|
||||
members=["a:1", "b:2", "c:3"],
|
||||
mongoses=[],
|
||||
@ -2245,12 +2309,13 @@ class TestMongoClientFailover(MockClientTest):
|
||||
|
||||
# MongoClient discovers it's alone. The first attempt raises either
|
||||
# ServerSelectionTimeoutError or AutoReconnect (from
|
||||
# MockPool.get_socket).
|
||||
self.assertRaises(AutoReconnect, c.db.collection.find_one)
|
||||
# AsyncMockPool.get_socket).
|
||||
with self.assertRaises(AutoReconnect):
|
||||
c.db.collection.find_one()
|
||||
|
||||
# But it can reconnect.
|
||||
c.revive_host("a:1")
|
||||
c._get_topology().select_servers(writable_server_selector, _Op.TEST)
|
||||
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
self.assertEqual(c.address, ("a", 1))
|
||||
|
||||
def _test_network_error(self, operation_callback):
|
||||
@ -2273,7 +2338,7 @@ class TestMongoClientFailover(MockClientTest):
|
||||
# Set host-specific information so we can test whether it is reset.
|
||||
c.set_wire_version_range("a:1", 2, 6)
|
||||
c.set_wire_version_range("b:2", 2, 7)
|
||||
c._get_topology().select_servers(writable_server_selector, _Op.TEST)
|
||||
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||
wait_until(lambda: len(c.nodes) == 2, "connect")
|
||||
|
||||
c.kill_host("a:1")
|
||||
@ -2281,17 +2346,18 @@ class TestMongoClientFailover(MockClientTest):
|
||||
# MongoClient is disconnected from the primary. This raises either
|
||||
# ServerSelectionTimeoutError or AutoReconnect (from
|
||||
# MockPool.get_socket).
|
||||
self.assertRaises(AutoReconnect, operation_callback, c)
|
||||
with self.assertRaises(AutoReconnect):
|
||||
operation_callback(c)
|
||||
|
||||
# The primary's description is reset.
|
||||
server_a = c._get_topology().get_server_by_address(("a", 1))
|
||||
server_a = (c._get_topology()).get_server_by_address(("a", 1))
|
||||
sd_a = server_a.description
|
||||
self.assertEqual(SERVER_TYPE.Unknown, sd_a.server_type)
|
||||
self.assertEqual(0, sd_a.min_wire_version)
|
||||
self.assertEqual(0, sd_a.max_wire_version)
|
||||
|
||||
# ...but not the secondary's.
|
||||
server_b = c._get_topology().get_server_by_address(("b", 2))
|
||||
server_b = (c._get_topology()).get_server_by_address(("b", 2))
|
||||
sd_b = server_b.description
|
||||
self.assertEqual(SERVER_TYPE.RSSecondary, sd_b.server_type)
|
||||
self.assertEqual(2, sd_b.min_wire_version)
|
||||
@ -2332,7 +2398,7 @@ class TestClientPool(MockClientTest):
|
||||
@client_context.require_connection
|
||||
def test_rs_client_does_not_maintain_pool_to_arbiters(self):
|
||||
listener = CMAPListener()
|
||||
c = MockClient(
|
||||
c = MockClient.get_mock_client(
|
||||
standalones=[],
|
||||
members=["a:1", "b:2", "c:3", "d:4"],
|
||||
mongoses=[],
|
||||
@ -2363,7 +2429,7 @@ class TestClientPool(MockClientTest):
|
||||
@client_context.require_connection
|
||||
def test_direct_client_maintains_pool_to_arbiter(self):
|
||||
listener = CMAPListener()
|
||||
c = MockClient(
|
||||
c = MockClient.get_mock_client(
|
||||
standalones=[],
|
||||
members=["a:1", "b:2", "c:3"],
|
||||
mongoses=[],
|
||||
|
||||
@ -20,8 +20,8 @@ import uuid
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import connected, rs_or_single_client, single_client
|
||||
from test import IntegrationTest, client_context, connected, unittest
|
||||
from test.utils import rs_or_single_client, single_client
|
||||
|
||||
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
|
||||
@ -20,13 +20,12 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, unittest
|
||||
from test import IntegrationTest, drop_collections, unittest
|
||||
from test.utils import (
|
||||
SpecTestCreator,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
camel_to_upper_camel,
|
||||
drop_collections,
|
||||
)
|
||||
|
||||
from pymongo import WriteConcern, operations
|
||||
|
||||
@ -22,9 +22,9 @@ from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import MockClientTest, client_context, unittest
|
||||
from test import MockClientTest, client_context, connected, unittest
|
||||
from test.pymongo_mocks import MockClient
|
||||
from test.utils import connected, wait_until
|
||||
from test.utils import wait_until
|
||||
|
||||
from pymongo.errors import AutoReconnect, InvalidOperation
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
|
||||
@ -22,10 +22,9 @@ from functools import partial
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, unittest
|
||||
from test import IntegrationTest, connected, unittest
|
||||
from test.utils import (
|
||||
ServerAndTopologyEventListener,
|
||||
connected,
|
||||
single_client,
|
||||
wait_until,
|
||||
)
|
||||
|
||||
@ -26,10 +26,9 @@ from pymongo.operations import _Op
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, SkipTest, client_context, unittest
|
||||
from test import IntegrationTest, SkipTest, client_context, connected, unittest
|
||||
from test.utils import (
|
||||
OvertCommandListener,
|
||||
connected,
|
||||
one,
|
||||
rs_client,
|
||||
single_client,
|
||||
|
||||
@ -21,13 +21,19 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import HAVE_IPADDRESS, IntegrationTest, SkipTest, client_context, unittest
|
||||
from test import (
|
||||
HAVE_IPADDRESS,
|
||||
IntegrationTest,
|
||||
SkipTest,
|
||||
client_context,
|
||||
connected,
|
||||
remove_all_users,
|
||||
unittest,
|
||||
)
|
||||
from test.utils import (
|
||||
EventListener,
|
||||
cat_files,
|
||||
connected,
|
||||
ignore_deprecations,
|
||||
remove_all_users,
|
||||
)
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
|
||||
@ -621,7 +621,10 @@ async def _async_mongo_client(host, port, authenticate=True, directConnection=No
|
||||
):
|
||||
client_options["username"] = db_user
|
||||
client_options["password"] = db_pwd
|
||||
return AsyncMongoClient(uri, port, **client_options)
|
||||
client = AsyncMongoClient(uri, port, **client_options)
|
||||
if client._options.connect:
|
||||
await client.aconnect()
|
||||
return client
|
||||
|
||||
|
||||
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||
@ -843,16 +846,6 @@ def server_started_with_auth(client):
|
||||
return "--auth" in argv or "--keyFile" in argv
|
||||
|
||||
|
||||
def drop_collections(db):
|
||||
# Drop all non-system collections in this database.
|
||||
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
|
||||
db.drop_collection(coll)
|
||||
|
||||
|
||||
def remove_all_users(db):
|
||||
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})
|
||||
|
||||
|
||||
def joinall(threads):
|
||||
"""Join threads with a 5-minute timeout, assert joins succeeded"""
|
||||
for t in threads:
|
||||
@ -860,17 +853,6 @@ def joinall(threads):
|
||||
assert not t.is_alive(), "Thread %s hung" % t
|
||||
|
||||
|
||||
def connected(client):
|
||||
"""Convenience to wait for a newly-constructed client to connect."""
|
||||
with warnings.catch_warnings():
|
||||
# Ignore warning that ping is always routed to primary even
|
||||
# if client's read preference isn't PRIMARY.
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
client.admin.command("ping") # Force connection.
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def wait_until(predicate, success_description, timeout=10):
|
||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||
|
||||
@ -957,6 +939,20 @@ def assertRaisesExactly(cls, fn, *args, **kwargs):
|
||||
raise AssertionError("%s not raised" % cls)
|
||||
|
||||
|
||||
async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs):
|
||||
"""
|
||||
Unlike the standard assertRaises, this checks that a function raises a
|
||||
specific class of exception, and not a subclass. E.g., check that
|
||||
MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect.
|
||||
"""
|
||||
try:
|
||||
await fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}"
|
||||
else:
|
||||
raise AssertionError("%s not raised" % cls)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _ignore_deprecations():
|
||||
with warnings.catch_warnings():
|
||||
|
||||
@ -72,11 +72,20 @@ replacements = {
|
||||
"addAsyncCleanup": "addCleanup",
|
||||
"async_setup_class": "setup_class",
|
||||
"IsolatedAsyncioTestCase": "TestCase",
|
||||
"AsyncUnitTest": "UnitTest",
|
||||
"AsyncMockClient": "MockClient",
|
||||
"async_get_pool": "get_pool",
|
||||
"async_is_mongos": "is_mongos",
|
||||
"async_rs_or_single_client": "rs_or_single_client",
|
||||
"async_rs_or_single_client_noauth": "rs_or_single_client_noauth",
|
||||
"async_rs_client": "rs_client",
|
||||
"async_single_client": "single_client",
|
||||
"async_from_client": "from_client",
|
||||
"aclosing": "closing",
|
||||
"asyncAssertRaisesExactly": "assertRaisesExactly",
|
||||
"get_async_mock_client": "get_mock_client",
|
||||
"aconnect": "_connect",
|
||||
"aclose": "close",
|
||||
}
|
||||
|
||||
docstring_replacements: dict[tuple[str, str], str] = {
|
||||
@ -131,6 +140,8 @@ sync_gridfs_files = [
|
||||
converted_tests = [
|
||||
"__init__.py",
|
||||
"conftest.py",
|
||||
"pymongo_mocks.py",
|
||||
"test_client.py",
|
||||
"test_collection.py",
|
||||
"test_database.py",
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user