PYTHON-4533 - Convert test/test_client.py to async (#1730)
This commit is contained in:
parent
554ce7d984
commit
d0193eb045
@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
|||||||
self.client_ref = None
|
self.client_ref = None
|
||||||
self.key_vault_coll = None
|
self.key_vault_coll = None
|
||||||
if self.mongocryptd_client:
|
if self.mongocryptd_client:
|
||||||
await self.mongocryptd_client.close()
|
await self.mongocryptd_client.aclose()
|
||||||
self.mongocryptd_client = None
|
self.mongocryptd_client = None
|
||||||
|
|
||||||
|
|
||||||
@ -439,7 +439,7 @@ class _Encrypter:
|
|||||||
self._closed = True
|
self._closed = True
|
||||||
await self._auto_encrypter.close()
|
await self._auto_encrypter.close()
|
||||||
if self._internal_client:
|
if self._internal_client:
|
||||||
await self._internal_client.close()
|
await self._internal_client.aclose()
|
||||||
self._internal_client = None
|
self._internal_client = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -861,6 +861,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
# This will be used later if we fork.
|
# This will be used later if we fork.
|
||||||
AsyncMongoClient._clients[self._topology._topology_id] = self
|
AsyncMongoClient._clients[self._topology._topology_id] = self
|
||||||
|
|
||||||
|
async def aconnect(self) -> None:
|
||||||
|
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
|
||||||
|
await self._get_topology()
|
||||||
|
|
||||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||||
self._topology = Topology(self._topology_settings)
|
self._topology = Topology(self._topology_settings)
|
||||||
# Seed the topology with the old one's pid so we can detect clients
|
# Seed the topology with the old one's pid so we can detect clients
|
||||||
@ -1354,13 +1358,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
await self.close()
|
await self.aclose()
|
||||||
|
|
||||||
# See PYTHON-3084.
|
# See PYTHON-3084.
|
||||||
__iter__ = None
|
__iter__ = None
|
||||||
|
|
||||||
def __next__(self) -> NoReturn:
|
def __next__(self) -> NoReturn:
|
||||||
raise TypeError("'MongoClient' object is not iterable")
|
raise TypeError("'AsyncMongoClient' object is not iterable")
|
||||||
|
|
||||||
next = __next__
|
next = __next__
|
||||||
|
|
||||||
@ -1490,7 +1494,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
# command.
|
# command.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def aclose(self) -> None:
|
||||||
"""Cleanup client resources and disconnect from MongoDB.
|
"""Cleanup client resources and disconnect from MongoDB.
|
||||||
|
|
||||||
End all server sessions created by this client by sending one or more
|
End all server sessions created by this client by sending one or more
|
||||||
|
|||||||
@ -860,6 +860,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
# This will be used later if we fork.
|
# This will be used later if we fork.
|
||||||
MongoClient._clients[self._topology._topology_id] = self
|
MongoClient._clients[self._topology._topology_id] = self
|
||||||
|
|
||||||
|
def _connect(self) -> None:
|
||||||
|
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
|
||||||
|
self._get_topology()
|
||||||
|
|
||||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||||
self._topology = Topology(self._topology_settings)
|
self._topology = Topology(self._topology_settings)
|
||||||
# Seed the topology with the old one's pid so we can detect clients
|
# Seed the topology with the old one's pid so we can detect clients
|
||||||
|
|||||||
107
test/__init__.py
107
test/__init__.py
@ -17,6 +17,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import contextlib
|
||||||
import gc
|
import gc
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
@ -39,8 +40,6 @@ from test.helpers import (
|
|||||||
TEST_SERVERLESS,
|
TEST_SERVERLESS,
|
||||||
TLS_OPTIONS,
|
TLS_OPTIONS,
|
||||||
SystemCertsPatcher,
|
SystemCertsPatcher,
|
||||||
_all_users,
|
|
||||||
_create_user,
|
|
||||||
client_knobs,
|
client_knobs,
|
||||||
db_pwd,
|
db_pwd,
|
||||||
db_user,
|
db_user,
|
||||||
@ -62,9 +61,9 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
HAVE_IPADDRESS = False
|
HAVE_IPADDRESS = False
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import partial, wraps
|
||||||
from test.version import Version
|
from test.version import Version
|
||||||
from typing import Any, Callable, Dict, Generator
|
from typing import Any, Callable, Dict, Generator, overload
|
||||||
from unittest import SkipTest
|
from unittest import SkipTest
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
@ -812,6 +811,12 @@ class ClientContext:
|
|||||||
func=func,
|
func=func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def require_sync(self, func):
|
||||||
|
"""Run a test only if using the synchronous API."""
|
||||||
|
return self._require(
|
||||||
|
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
|
||||||
|
)
|
||||||
|
|
||||||
def mongos_seeds(self):
|
def mongos_seeds(self):
|
||||||
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
||||||
|
|
||||||
@ -919,6 +924,32 @@ class PyMongoTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(proc.exitcode, 0)
|
self.assertEqual(proc.exitcode, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class UnitTest(PyMongoTestCase):
|
||||||
|
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._setup_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._setup_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._tearDown_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._tearDown_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _setup_class(cls):
|
||||||
|
cls._setup_class()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _tearDown_class(cls):
|
||||||
|
cls._tearDown_class()
|
||||||
|
|
||||||
|
|
||||||
class IntegrationTest(PyMongoTestCase):
|
class IntegrationTest(PyMongoTestCase):
|
||||||
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
||||||
|
|
||||||
@ -933,6 +964,13 @@ class IntegrationTest(PyMongoTestCase):
|
|||||||
else:
|
else:
|
||||||
asyncio.run(cls._setup_class())
|
asyncio.run(cls._setup_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._tearDown_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._tearDown_class())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@client_context.require_connection
|
@client_context.require_connection
|
||||||
def _setup_class(cls):
|
def _setup_class(cls):
|
||||||
@ -947,6 +985,10 @@ class IntegrationTest(PyMongoTestCase):
|
|||||||
else:
|
else:
|
||||||
cls.credentials = {}
|
cls.credentials = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _tearDown_class(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
def cleanup_colls(self, *collections):
|
def cleanup_colls(self, *collections):
|
||||||
"""Cleanup collections faster than drop_collection."""
|
"""Cleanup collections faster than drop_collection."""
|
||||||
for c in collections:
|
for c in collections:
|
||||||
@ -959,7 +1001,7 @@ class IntegrationTest(PyMongoTestCase):
|
|||||||
self.addCleanup(patcher.disable)
|
self.addCleanup(patcher.disable)
|
||||||
|
|
||||||
|
|
||||||
class MockClientTest(unittest.TestCase):
|
class MockClientTest(UnitTest):
|
||||||
"""Base class for TestCases that use MockClient.
|
"""Base class for TestCases that use MockClient.
|
||||||
|
|
||||||
This class is *not* an IntegrationTest: if properly written, MockClient
|
This class is *not* an IntegrationTest: if properly written, MockClient
|
||||||
@ -972,8 +1014,26 @@ class MockClientTest(unittest.TestCase):
|
|||||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||||
# with loadBalanced=True.
|
# with loadBalanced=True.
|
||||||
@classmethod
|
@classmethod
|
||||||
@client_context.require_no_load_balancer
|
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._setup_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._setup_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._tearDown_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._tearDown_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@client_context.require_no_load_balancer
|
||||||
|
def _setup_class(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _tearDown_class(cls):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -1051,3 +1111,38 @@ def print_running_clients():
|
|||||||
processed.add(obj._topology_id)
|
processed.add(obj._topology_id)
|
||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _all_users(db):
|
||||||
|
return {u["user"] for u in (db.command("usersInfo")).get("users", [])}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
|
||||||
|
cmd = SON([("createUser", user)])
|
||||||
|
# X509 doesn't use a password
|
||||||
|
if pwd:
|
||||||
|
cmd["pwd"] = pwd
|
||||||
|
cmd["roles"] = roles or ["root"]
|
||||||
|
cmd.update(**kwargs)
|
||||||
|
return authdb.command(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def connected(client):
|
||||||
|
"""Convenience to wait for a newly-constructed client to connect."""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
# Ignore warning that ping is always routed to primary even
|
||||||
|
# if client's read preference isn't PRIMARY.
|
||||||
|
warnings.simplefilter("ignore", UserWarning)
|
||||||
|
client.admin.command("ping") # Force connection.
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def drop_collections(db: Database):
|
||||||
|
# Drop all non-system collections in this database.
|
||||||
|
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
|
||||||
|
db.drop_collection(coll)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_all_users(db: Database):
|
||||||
|
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import contextlib
|
||||||
import gc
|
import gc
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
@ -39,8 +40,6 @@ from test.helpers import (
|
|||||||
TEST_SERVERLESS,
|
TEST_SERVERLESS,
|
||||||
TLS_OPTIONS,
|
TLS_OPTIONS,
|
||||||
SystemCertsPatcher,
|
SystemCertsPatcher,
|
||||||
_all_users,
|
|
||||||
_create_user,
|
|
||||||
client_knobs,
|
client_knobs,
|
||||||
db_pwd,
|
db_pwd,
|
||||||
db_user,
|
db_user,
|
||||||
@ -62,9 +61,9 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
HAVE_IPADDRESS = False
|
HAVE_IPADDRESS = False
|
||||||
from contextlib import asynccontextmanager, contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from functools import wraps
|
from functools import partial, wraps
|
||||||
from test.version import Version
|
from test.version import Version
|
||||||
from typing import Any, Callable, Dict, Generator
|
from typing import Any, Callable, Dict, Generator, overload
|
||||||
from unittest import SkipTest
|
from unittest import SkipTest
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
@ -184,7 +183,7 @@ class AsyncClientContext:
|
|||||||
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
|
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
|
||||||
return None
|
return None
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.aclose()
|
||||||
|
|
||||||
async def _init_client(self):
|
async def _init_client(self):
|
||||||
self.client = await self._connect(host, port)
|
self.client = await self._connect(host, port)
|
||||||
@ -229,7 +228,7 @@ class AsyncClientContext:
|
|||||||
if not self.serverless and not IS_SRV:
|
if not self.serverless and not IS_SRV:
|
||||||
# See if db_user already exists.
|
# See if db_user already exists.
|
||||||
if not await self._check_user_provided():
|
if not await self._check_user_provided():
|
||||||
_create_user(self.client.admin, db_user, db_pwd)
|
await _create_user(self.client.admin, db_user, db_pwd)
|
||||||
|
|
||||||
self.client = await self._connect(
|
self.client = await self._connect(
|
||||||
host,
|
host,
|
||||||
@ -304,7 +303,7 @@ class AsyncClientContext:
|
|||||||
params = self.cmd_line["parsed"].get("setParameter", {})
|
params = self.cmd_line["parsed"].get("setParameter", {})
|
||||||
if params.get("enableTestCommands") == "1":
|
if params.get("enableTestCommands") == "1":
|
||||||
self.test_commands_enabled = True
|
self.test_commands_enabled = True
|
||||||
self.has_ipv6 = self._server_started_with_ipv6()
|
self.has_ipv6 = await self._server_started_with_ipv6()
|
||||||
|
|
||||||
self.is_mongos = (await self.hello).get("msg") == "isdbgrid"
|
self.is_mongos = (await self.hello).get("msg") == "isdbgrid"
|
||||||
if self.is_mongos:
|
if self.is_mongos:
|
||||||
@ -390,7 +389,7 @@ class AsyncClientContext:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return db_user in _all_users(client.admin)
|
return db_user in await _all_users(client.admin)
|
||||||
except pymongo.errors.OperationFailure as e:
|
except pymongo.errors.OperationFailure as e:
|
||||||
assert e.details is not None
|
assert e.details is not None
|
||||||
msg = e.details.get("errmsg", "")
|
msg = e.details.get("errmsg", "")
|
||||||
@ -400,7 +399,7 @@ class AsyncClientContext:
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.aclose()
|
||||||
|
|
||||||
def _server_started_with_auth(self):
|
def _server_started_with_auth(self):
|
||||||
# MongoDB >= 2.0
|
# MongoDB >= 2.0
|
||||||
@ -482,9 +481,9 @@ class AsyncClientContext:
|
|||||||
return decorate
|
return decorate
|
||||||
return make_wrapper(func)
|
return make_wrapper(func)
|
||||||
|
|
||||||
def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
|
async def create_user(self, dbname, user, pwd=None, roles=None, **kwargs):
|
||||||
kwargs["writeConcern"] = {"w": self.w}
|
kwargs["writeConcern"] = {"w": self.w}
|
||||||
return _create_user(self.client[dbname], user, pwd, roles, **kwargs)
|
return await _create_user(self.client[dbname], user, pwd, roles, **kwargs)
|
||||||
|
|
||||||
async def drop_user(self, dbname, user):
|
async def drop_user(self, dbname, user):
|
||||||
await self.client[dbname].command("dropUser", user, writeConcern={"w": self.w})
|
await self.client[dbname].command("dropUser", user, writeConcern={"w": self.w})
|
||||||
@ -814,6 +813,12 @@ class AsyncClientContext:
|
|||||||
func=func,
|
func=func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def require_sync(self, func):
|
||||||
|
"""Run a test only if using the synchronous API."""
|
||||||
|
return self._require(
|
||||||
|
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
|
||||||
|
)
|
||||||
|
|
||||||
def mongos_seeds(self):
|
def mongos_seeds(self):
|
||||||
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
||||||
|
|
||||||
@ -921,6 +926,32 @@ class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(proc.exitcode, 0)
|
self.assertEqual(proc.exitcode, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||||
|
"""Async base class for TestCases that don't require a connection to MongoDB."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._setup_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._setup_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._tearDown_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._tearDown_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _setup_class(cls):
|
||||||
|
await cls._setup_class()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _tearDown_class(cls):
|
||||||
|
await cls._tearDown_class()
|
||||||
|
|
||||||
|
|
||||||
class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||||
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
"""Async base class for TestCases that need a connection to MongoDB to pass."""
|
||||||
|
|
||||||
@ -935,6 +966,13 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
|||||||
else:
|
else:
|
||||||
asyncio.run(cls._setup_class())
|
asyncio.run(cls._setup_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._tearDown_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._tearDown_class())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@async_client_context.require_connection
|
@async_client_context.require_connection
|
||||||
async def _setup_class(cls):
|
async def _setup_class(cls):
|
||||||
@ -949,6 +987,10 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
|||||||
else:
|
else:
|
||||||
cls.credentials = {}
|
cls.credentials = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _tearDown_class(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
async def cleanup_colls(self, *collections):
|
async def cleanup_colls(self, *collections):
|
||||||
"""Cleanup collections faster than drop_collection."""
|
"""Cleanup collections faster than drop_collection."""
|
||||||
for c in collections:
|
for c in collections:
|
||||||
@ -961,7 +1003,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
|||||||
self.addCleanup(patcher.disable)
|
self.addCleanup(patcher.disable)
|
||||||
|
|
||||||
|
|
||||||
class AsyncMockClientTest(unittest.TestCase):
|
class AsyncMockClientTest(AsyncUnitTest):
|
||||||
"""Base class for TestCases that use MockClient.
|
"""Base class for TestCases that use MockClient.
|
||||||
|
|
||||||
This class is *not* an IntegrationTest: if properly written, MockClient
|
This class is *not* an IntegrationTest: if properly written, MockClient
|
||||||
@ -974,8 +1016,26 @@ class AsyncMockClientTest(unittest.TestCase):
|
|||||||
# multiple seed addresses, or wait for heartbeat events are incompatible
|
# multiple seed addresses, or wait for heartbeat events are incompatible
|
||||||
# with loadBalanced=True.
|
# with loadBalanced=True.
|
||||||
@classmethod
|
@classmethod
|
||||||
@async_client_context.require_no_load_balancer
|
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._setup_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._setup_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
if _IS_SYNC:
|
||||||
|
cls._tearDown_class()
|
||||||
|
else:
|
||||||
|
asyncio.run(cls._tearDown_class())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@async_client_context.require_no_load_balancer
|
||||||
|
async def _setup_class(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _tearDown_class(cls):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -1015,7 +1075,7 @@ async def async_teardown():
|
|||||||
await c.drop_database("pymongo_test2")
|
await c.drop_database("pymongo_test2")
|
||||||
await c.drop_database("pymongo_test_mike")
|
await c.drop_database("pymongo_test_mike")
|
||||||
await c.drop_database("pymongo_test_bernie")
|
await c.drop_database("pymongo_test_bernie")
|
||||||
await c.close()
|
await c.aclose()
|
||||||
|
|
||||||
print_running_clients()
|
print_running_clients()
|
||||||
|
|
||||||
@ -1053,3 +1113,38 @@ def print_running_clients():
|
|||||||
processed.add(obj._topology_id)
|
processed.add(obj._topology_id)
|
||||||
except ReferenceError:
|
except ReferenceError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _all_users(db):
|
||||||
|
return {u["user"] for u in (await db.command("usersInfo")).get("users", [])}
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
|
||||||
|
cmd = SON([("createUser", user)])
|
||||||
|
# X509 doesn't use a password
|
||||||
|
if pwd:
|
||||||
|
cmd["pwd"] = pwd
|
||||||
|
cmd["roles"] = roles or ["root"]
|
||||||
|
cmd.update(**kwargs)
|
||||||
|
return await authdb.command(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
async def connected(client):
|
||||||
|
"""Convenience to wait for a newly-constructed client to connect."""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
# Ignore warning that ping is always routed to primary even
|
||||||
|
# if client's read preference isn't PRIMARY.
|
||||||
|
warnings.simplefilter("ignore", UserWarning)
|
||||||
|
await client.admin.command("ping") # Force connection.
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
async def drop_collections(db: AsyncDatabase):
|
||||||
|
# Drop all non-system collections in this database.
|
||||||
|
for coll in await db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
|
||||||
|
await db.drop_collection(coll)
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_all_users(db: AsyncDatabase):
|
||||||
|
await db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": async_client_context.w})
|
||||||
|
|||||||
252
test/asynchronous/pymongo_mocks.py
Normal file
252
test/asynchronous/pymongo_mocks.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
# Copyright 2013-present MongoDB, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tools for mocking parts of PyMongo to test other parts."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import weakref
|
||||||
|
from functools import partial
|
||||||
|
from test import client_context
|
||||||
|
from test.asynchronous import async_client_context
|
||||||
|
|
||||||
|
from pymongo import AsyncMongoClient, common
|
||||||
|
from pymongo.asynchronous.monitor import Monitor
|
||||||
|
from pymongo.asynchronous.pool import Pool
|
||||||
|
from pymongo.errors import AutoReconnect, NetworkTimeout
|
||||||
|
from pymongo.hello import Hello, HelloCompat
|
||||||
|
from pymongo.server_description import ServerDescription
|
||||||
|
|
||||||
|
_IS_SYNC = False
|
||||||
|
|
||||||
|
|
||||||
|
class MockPool(Pool):
|
||||||
|
def __init__(self, client, pair, *args, **kwargs):
|
||||||
|
# MockPool gets a 'client' arg, regular pools don't. Weakref it to
|
||||||
|
# avoid cycle with __del__, causing ResourceWarnings in Python 3.3.
|
||||||
|
self.client = weakref.proxy(client)
|
||||||
|
self.mock_host, self.mock_port = pair
|
||||||
|
|
||||||
|
# Actually connect to the default server.
|
||||||
|
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def checkout(self, handler=None):
|
||||||
|
client = self.client
|
||||||
|
host_and_port = f"{self.mock_host}:{self.mock_port}"
|
||||||
|
if host_and_port in client.mock_down_hosts:
|
||||||
|
raise AutoReconnect("mock error")
|
||||||
|
|
||||||
|
assert host_and_port in (
|
||||||
|
client.mock_standalones + client.mock_members + client.mock_mongoses
|
||||||
|
), "bad host: %s" % host_and_port
|
||||||
|
|
||||||
|
async with Pool.checkout(self, handler) as conn:
|
||||||
|
conn.mock_host = self.mock_host
|
||||||
|
conn.mock_port = self.mock_port
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
|
class DummyMonitor:
|
||||||
|
def __init__(self, server_description, topology, pool, topology_settings):
|
||||||
|
self._server_description = server_description
|
||||||
|
self.opened = False
|
||||||
|
|
||||||
|
def cancel_check(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def join(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
self.opened = True
|
||||||
|
|
||||||
|
def request_check(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.opened = False
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMockMonitor(Monitor):
|
||||||
|
def __init__(self, client, server_description, topology, pool, topology_settings):
|
||||||
|
# MockMonitor gets a 'client' arg, regular monitors don't. Weakref it
|
||||||
|
# to avoid cycles.
|
||||||
|
self.client = weakref.proxy(client)
|
||||||
|
Monitor.__init__(self, server_description, topology, pool, topology_settings)
|
||||||
|
|
||||||
|
async def _check_once(self):
|
||||||
|
client = self.client
|
||||||
|
address = self._server_description.address
|
||||||
|
response, rtt = client.mock_hello("%s:%d" % address) # type: ignore[str-format]
|
||||||
|
return ServerDescription(address, Hello(response), rtt)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMockClient(AsyncMongoClient):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
standalones,
|
||||||
|
members,
|
||||||
|
mongoses,
|
||||||
|
hello_hosts=None,
|
||||||
|
arbiters=None,
|
||||||
|
down_hosts=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""An AsyncMongoClient connected to the default server, with a mock topology.
|
||||||
|
|
||||||
|
standalones, members, mongoses, arbiters, and down_hosts determine the
|
||||||
|
configuration of the topology. They are formatted like ['a:1', 'b:2'].
|
||||||
|
hello_hosts provides an alternative host list for the server's
|
||||||
|
mocked hello response; see test_connect_with_internal_ips.
|
||||||
|
"""
|
||||||
|
self.mock_standalones = standalones[:]
|
||||||
|
self.mock_members = members[:]
|
||||||
|
|
||||||
|
if self.mock_members:
|
||||||
|
self.mock_primary = self.mock_members[0]
|
||||||
|
else:
|
||||||
|
self.mock_primary = None
|
||||||
|
|
||||||
|
# Hosts that should be considered an arbiter.
|
||||||
|
self.mock_arbiters = arbiters[:] if arbiters else []
|
||||||
|
|
||||||
|
if hello_hosts is not None:
|
||||||
|
self.mock_hello_hosts = hello_hosts
|
||||||
|
else:
|
||||||
|
self.mock_hello_hosts = members[:]
|
||||||
|
|
||||||
|
self.mock_mongoses = mongoses[:]
|
||||||
|
|
||||||
|
# Hosts that should raise socket errors.
|
||||||
|
self.mock_down_hosts = down_hosts[:] if down_hosts else []
|
||||||
|
|
||||||
|
# Hostname -> (min wire version, max wire version)
|
||||||
|
self.mock_wire_versions = {}
|
||||||
|
|
||||||
|
# Hostname -> max write batch size
|
||||||
|
self.mock_max_write_batch_sizes = {}
|
||||||
|
|
||||||
|
# Hostname -> round trip time
|
||||||
|
self.mock_rtts = {}
|
||||||
|
|
||||||
|
kwargs["_pool_class"] = partial(MockPool, self)
|
||||||
|
kwargs["_monitor_class"] = partial(AsyncMockMonitor, self)
|
||||||
|
|
||||||
|
client_options = async_client_context.default_client_options.copy()
|
||||||
|
client_options.update(kwargs)
|
||||||
|
|
||||||
|
super().__init__(*args, **client_options)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_async_mock_client(
|
||||||
|
cls,
|
||||||
|
standalones,
|
||||||
|
members,
|
||||||
|
mongoses,
|
||||||
|
hello_hosts=None,
|
||||||
|
arbiters=None,
|
||||||
|
down_hosts=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
c = AsyncMockClient(
|
||||||
|
standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
await c.aconnect()
|
||||||
|
return c
|
||||||
|
|
||||||
|
def kill_host(self, host):
|
||||||
|
"""Host is like 'a:1'."""
|
||||||
|
self.mock_down_hosts.append(host)
|
||||||
|
|
||||||
|
def revive_host(self, host):
|
||||||
|
"""Host is like 'a:1'."""
|
||||||
|
self.mock_down_hosts.remove(host)
|
||||||
|
|
||||||
|
def set_wire_version_range(self, host, min_version, max_version):
|
||||||
|
self.mock_wire_versions[host] = (min_version, max_version)
|
||||||
|
|
||||||
|
def set_max_write_batch_size(self, host, size):
|
||||||
|
self.mock_max_write_batch_sizes[host] = size
|
||||||
|
|
||||||
|
def mock_hello(self, host):
|
||||||
|
"""Return mock hello response (a dict) and round trip time."""
|
||||||
|
if host in self.mock_wire_versions:
|
||||||
|
min_wire_version, max_wire_version = self.mock_wire_versions[host]
|
||||||
|
else:
|
||||||
|
min_wire_version = common.MIN_SUPPORTED_WIRE_VERSION
|
||||||
|
max_wire_version = common.MAX_SUPPORTED_WIRE_VERSION
|
||||||
|
|
||||||
|
max_write_batch_size = self.mock_max_write_batch_sizes.get(
|
||||||
|
host, common.MAX_WRITE_BATCH_SIZE
|
||||||
|
)
|
||||||
|
|
||||||
|
rtt = self.mock_rtts.get(host, 0)
|
||||||
|
|
||||||
|
# host is like 'a:1'.
|
||||||
|
if host in self.mock_down_hosts:
|
||||||
|
raise NetworkTimeout("mock timeout")
|
||||||
|
|
||||||
|
elif host in self.mock_standalones:
|
||||||
|
response = {
|
||||||
|
"ok": 1,
|
||||||
|
HelloCompat.LEGACY_CMD: True,
|
||||||
|
"minWireVersion": min_wire_version,
|
||||||
|
"maxWireVersion": max_wire_version,
|
||||||
|
"maxWriteBatchSize": max_write_batch_size,
|
||||||
|
}
|
||||||
|
elif host in self.mock_members:
|
||||||
|
primary = host == self.mock_primary
|
||||||
|
|
||||||
|
# Simulate a replica set member.
|
||||||
|
response = {
|
||||||
|
"ok": 1,
|
||||||
|
HelloCompat.LEGACY_CMD: primary,
|
||||||
|
"secondary": not primary,
|
||||||
|
"setName": "rs",
|
||||||
|
"hosts": self.mock_hello_hosts,
|
||||||
|
"minWireVersion": min_wire_version,
|
||||||
|
"maxWireVersion": max_wire_version,
|
||||||
|
"maxWriteBatchSize": max_write_batch_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.mock_primary:
|
||||||
|
response["primary"] = self.mock_primary
|
||||||
|
|
||||||
|
if host in self.mock_arbiters:
|
||||||
|
response["arbiterOnly"] = True
|
||||||
|
response["secondary"] = False
|
||||||
|
elif host in self.mock_mongoses:
|
||||||
|
response = {
|
||||||
|
"ok": 1,
|
||||||
|
HelloCompat.LEGACY_CMD: True,
|
||||||
|
"minWireVersion": min_wire_version,
|
||||||
|
"maxWireVersion": max_wire_version,
|
||||||
|
"msg": "isdbgrid",
|
||||||
|
"maxWriteBatchSize": max_write_batch_size,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# In test_internal_ips(), we try to connect to a host listed
|
||||||
|
# in hello['hosts'] but not publicly accessible.
|
||||||
|
raise AutoReconnect("Unknown host: %s" % host)
|
||||||
|
|
||||||
|
return response, rtt
|
||||||
|
|
||||||
|
def _process_periodic_tasks(self):
|
||||||
|
# Avoid the background thread causing races, e.g. a surprising
|
||||||
|
# reconnect while we're trying to test a disconnected client.
|
||||||
|
pass
|
||||||
2500
test/asynchronous/test_client.py
Normal file
2500
test/asynchronous/test_client.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1825,7 +1825,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
|||||||
await self.db.test.insert_many([{"i": i} for i in range(150)])
|
await self.db.test.insert_many([{"i": i} for i in range(150)])
|
||||||
|
|
||||||
client = await async_rs_or_single_client(maxPoolSize=1)
|
client = await async_rs_or_single_client(maxPoolSize=1)
|
||||||
self.addAsyncCleanup(client.close)
|
self.addAsyncCleanup(client.aclose)
|
||||||
pool = await async_get_pool(client)
|
pool = await async_get_pool(client)
|
||||||
|
|
||||||
# Make sure the socket is returned after exhaustion.
|
# Make sure the socket is returned after exhaustion.
|
||||||
|
|||||||
@ -236,7 +236,7 @@ class TestDatabase(AsyncIntegrationTest):
|
|||||||
async def test_check_exists(self):
|
async def test_check_exists(self):
|
||||||
listener = OvertCommandListener()
|
listener = OvertCommandListener()
|
||||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||||
self.addAsyncCleanup(client.close)
|
self.addAsyncCleanup(client.aclose)
|
||||||
db = client[self.db.name]
|
db = client[self.db.name]
|
||||||
await db.drop_collection("unique")
|
await db.drop_collection("unique")
|
||||||
await db.create_collection("unique", check_exists=True)
|
await db.create_collection("unique", check_exists=True)
|
||||||
|
|||||||
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Shared constants and helper method for pymongo, bson, and gridfs test suites."""
|
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
|||||||
@ -20,14 +20,15 @@ import weakref
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from test import client_context
|
from test import client_context
|
||||||
|
|
||||||
from pymongo import common
|
from pymongo import MongoClient, common
|
||||||
from pymongo.errors import AutoReconnect, NetworkTimeout
|
from pymongo.errors import AutoReconnect, NetworkTimeout
|
||||||
from pymongo.hello import Hello, HelloCompat
|
from pymongo.hello import Hello, HelloCompat
|
||||||
from pymongo.server_description import ServerDescription
|
from pymongo.server_description import ServerDescription
|
||||||
from pymongo.synchronous.mongo_client import MongoClient
|
|
||||||
from pymongo.synchronous.monitor import Monitor
|
from pymongo.synchronous.monitor import Monitor
|
||||||
from pymongo.synchronous.pool import Pool
|
from pymongo.synchronous.pool import Pool
|
||||||
|
|
||||||
|
_IS_SYNC = True
|
||||||
|
|
||||||
|
|
||||||
class MockPool(Pool):
|
class MockPool(Pool):
|
||||||
def __init__(self, client, pair, *args, **kwargs):
|
def __init__(self, client, pair, *args, **kwargs):
|
||||||
@ -77,7 +78,7 @@ class DummyMonitor:
|
|||||||
self.opened = False
|
self.opened = False
|
||||||
|
|
||||||
|
|
||||||
class MockMonitor(Monitor):
|
class SyncMockMonitor(Monitor):
|
||||||
def __init__(self, client, server_description, topology, pool, topology_settings):
|
def __init__(self, client, server_description, topology, pool, topology_settings):
|
||||||
# MockMonitor gets a 'client' arg, regular monitors don't. Weakref it
|
# MockMonitor gets a 'client' arg, regular monitors don't. Weakref it
|
||||||
# to avoid cycles.
|
# to avoid cycles.
|
||||||
@ -141,13 +142,32 @@ class MockClient(MongoClient):
|
|||||||
self.mock_rtts = {}
|
self.mock_rtts = {}
|
||||||
|
|
||||||
kwargs["_pool_class"] = partial(MockPool, self)
|
kwargs["_pool_class"] = partial(MockPool, self)
|
||||||
kwargs["_monitor_class"] = partial(MockMonitor, self)
|
kwargs["_monitor_class"] = partial(SyncMockMonitor, self)
|
||||||
|
|
||||||
client_options = client_context.default_client_options.copy()
|
client_options = client_context.default_client_options.copy()
|
||||||
client_options.update(kwargs)
|
client_options.update(kwargs)
|
||||||
|
|
||||||
super().__init__(*args, **client_options)
|
super().__init__(*args, **client_options)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mock_client(
|
||||||
|
cls,
|
||||||
|
standalones,
|
||||||
|
members,
|
||||||
|
mongoses,
|
||||||
|
hello_hosts=None,
|
||||||
|
arbiters=None,
|
||||||
|
down_hosts=None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
c = MockClient(
|
||||||
|
standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
c._connect()
|
||||||
|
return c
|
||||||
|
|
||||||
def kill_host(self, host):
|
def kill_host(self, host):
|
||||||
"""Host is like 'a:1'."""
|
"""Host is like 'a:1'."""
|
||||||
self.mock_down_hosts.append(host)
|
self.mock_down_hosts.append(host)
|
||||||
|
|||||||
@ -23,9 +23,8 @@ from pymongo.synchronous.mongo_client import MongoClient
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import IntegrationTest, client_context, unittest
|
from test import IntegrationTest, client_context, remove_all_users, unittest
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
remove_all_users,
|
|
||||||
rs_or_single_client_noauth,
|
rs_or_single_client_noauth,
|
||||||
single_client,
|
single_client,
|
||||||
wait_until,
|
wait_until,
|
||||||
|
|||||||
@ -16,6 +16,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import _thread as thread
|
import _thread as thread
|
||||||
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
@ -45,10 +46,13 @@ from test import (
|
|||||||
IntegrationTest,
|
IntegrationTest,
|
||||||
MockClientTest,
|
MockClientTest,
|
||||||
SkipTest,
|
SkipTest,
|
||||||
|
UnitTest,
|
||||||
client_context,
|
client_context,
|
||||||
client_knobs,
|
client_knobs,
|
||||||
|
connected,
|
||||||
db_pwd,
|
db_pwd,
|
||||||
db_user,
|
db_user,
|
||||||
|
remove_all_users,
|
||||||
unittest,
|
unittest,
|
||||||
)
|
)
|
||||||
from test.pymongo_mocks import MockClient
|
from test.pymongo_mocks import MockClient
|
||||||
@ -57,14 +61,12 @@ from test.utils import (
|
|||||||
CMAPListener,
|
CMAPListener,
|
||||||
FunctionCallRecorder,
|
FunctionCallRecorder,
|
||||||
assertRaisesExactly,
|
assertRaisesExactly,
|
||||||
connected,
|
|
||||||
delay,
|
delay,
|
||||||
get_pool,
|
get_pool,
|
||||||
gevent_monkey_patched,
|
gevent_monkey_patched,
|
||||||
is_greenthread_patched,
|
is_greenthread_patched,
|
||||||
lazy_client_trial,
|
lazy_client_trial,
|
||||||
one,
|
one,
|
||||||
remove_all_users,
|
|
||||||
rs_client,
|
rs_client,
|
||||||
rs_or_single_client,
|
rs_or_single_client,
|
||||||
rs_or_single_client_noauth,
|
rs_or_single_client_noauth,
|
||||||
@ -109,6 +111,7 @@ from pymongo.server_type import SERVER_TYPE
|
|||||||
from pymongo.synchronous.command_cursor import CommandCursor
|
from pymongo.synchronous.command_cursor import CommandCursor
|
||||||
from pymongo.synchronous.cursor import Cursor, CursorType
|
from pymongo.synchronous.cursor import Cursor, CursorType
|
||||||
from pymongo.synchronous.database import Database
|
from pymongo.synchronous.database import Database
|
||||||
|
from pymongo.synchronous.helpers import next
|
||||||
from pymongo.synchronous.mongo_client import MongoClient
|
from pymongo.synchronous.mongo_client import MongoClient
|
||||||
from pymongo.synchronous.pool import (
|
from pymongo.synchronous.pool import (
|
||||||
Connection,
|
Connection,
|
||||||
@ -118,18 +121,20 @@ from pymongo.synchronous.topology import _ErrorContext
|
|||||||
from pymongo.topology_description import TopologyDescription
|
from pymongo.topology_description import TopologyDescription
|
||||||
from pymongo.write_concern import WriteConcern
|
from pymongo.write_concern import WriteConcern
|
||||||
|
|
||||||
|
_IS_SYNC = True
|
||||||
|
|
||||||
class ClientUnitTest(unittest.TestCase):
|
|
||||||
|
class ClientUnitTest(UnitTest):
|
||||||
"""MongoClient tests that don't require a server."""
|
"""MongoClient tests that don't require a server."""
|
||||||
|
|
||||||
client: MongoClient
|
client: MongoClient
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def _setup_class(cls):
|
||||||
cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def _tearDown_class(cls):
|
||||||
cls.client.close()
|
cls.client.close()
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -254,7 +259,8 @@ class ClientUnitTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_get_default_database(self):
|
def test_get_default_database(self):
|
||||||
c = rs_or_single_client(
|
c = rs_or_single_client(
|
||||||
"mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False
|
"mongodb://%s:%d/foo" % (client_context.host, client_context.port),
|
||||||
|
connect=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(Database(c, "foo"), c.get_default_database())
|
self.assertEqual(Database(c, "foo"), c.get_default_database())
|
||||||
# Test that default doesn't override the URI value.
|
# Test that default doesn't override the URI value.
|
||||||
@ -269,39 +275,49 @@ class ClientUnitTest(unittest.TestCase):
|
|||||||
self.assertEqual(write_concern, db.write_concern)
|
self.assertEqual(write_concern, db.write_concern)
|
||||||
|
|
||||||
c = rs_or_single_client(
|
c = rs_or_single_client(
|
||||||
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
|
"mongodb://%s:%d/" % (client_context.host, client_context.port),
|
||||||
|
connect=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(Database(c, "foo"), c.get_default_database("foo"))
|
self.assertEqual(Database(c, "foo"), c.get_default_database("foo"))
|
||||||
|
|
||||||
def test_get_default_database_error(self):
|
def test_get_default_database_error(self):
|
||||||
# URI with no database.
|
# URI with no database.
|
||||||
c = rs_or_single_client(
|
c = rs_or_single_client(
|
||||||
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
|
"mongodb://%s:%d/" % (client_context.host, client_context.port),
|
||||||
|
connect=False,
|
||||||
)
|
)
|
||||||
self.assertRaises(ConfigurationError, c.get_default_database)
|
self.assertRaises(ConfigurationError, c.get_default_database)
|
||||||
|
|
||||||
def test_get_default_database_with_authsource(self):
|
def test_get_default_database_with_authsource(self):
|
||||||
# Ensure we distinguish database name from authSource.
|
# Ensure we distinguish database name from authSource.
|
||||||
uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port)
|
uri = "mongodb://%s:%d/foo?authSource=src" % (
|
||||||
|
client_context.host,
|
||||||
|
client_context.port,
|
||||||
|
)
|
||||||
c = rs_or_single_client(uri, connect=False)
|
c = rs_or_single_client(uri, connect=False)
|
||||||
self.assertEqual(Database(c, "foo"), c.get_default_database())
|
self.assertEqual(Database(c, "foo"), c.get_default_database())
|
||||||
|
|
||||||
def test_get_database_default(self):
|
def test_get_database_default(self):
|
||||||
c = rs_or_single_client(
|
c = rs_or_single_client(
|
||||||
"mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False
|
"mongodb://%s:%d/foo" % (client_context.host, client_context.port),
|
||||||
|
connect=False,
|
||||||
)
|
)
|
||||||
self.assertEqual(Database(c, "foo"), c.get_database())
|
self.assertEqual(Database(c, "foo"), c.get_database())
|
||||||
|
|
||||||
def test_get_database_default_error(self):
|
def test_get_database_default_error(self):
|
||||||
# URI with no database.
|
# URI with no database.
|
||||||
c = rs_or_single_client(
|
c = rs_or_single_client(
|
||||||
"mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False
|
"mongodb://%s:%d/" % (client_context.host, client_context.port),
|
||||||
|
connect=False,
|
||||||
)
|
)
|
||||||
self.assertRaises(ConfigurationError, c.get_database)
|
self.assertRaises(ConfigurationError, c.get_database)
|
||||||
|
|
||||||
def test_get_database_default_with_authsource(self):
|
def test_get_database_default_with_authsource(self):
|
||||||
# Ensure we distinguish database name from authSource.
|
# Ensure we distinguish database name from authSource.
|
||||||
uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port)
|
uri = "mongodb://%s:%d/foo?authSource=src" % (
|
||||||
|
client_context.host,
|
||||||
|
client_context.port,
|
||||||
|
)
|
||||||
c = rs_or_single_client(uri, connect=False)
|
c = rs_or_single_client(uri, connect=False)
|
||||||
self.assertEqual(Database(c, "foo"), c.get_database())
|
self.assertEqual(Database(c, "foo"), c.get_database())
|
||||||
|
|
||||||
@ -634,7 +650,7 @@ class TestClient(IntegrationTest):
|
|||||||
with client_knobs(kill_cursor_frequency=0.1):
|
with client_knobs(kill_cursor_frequency=0.1):
|
||||||
# Assert reaper doesn't remove connections when maxIdleTimeMS not set
|
# Assert reaper doesn't remove connections when maxIdleTimeMS not set
|
||||||
client = rs_or_single_client()
|
client = rs_or_single_client()
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
with server._pool.checkout() as conn:
|
with server._pool.checkout() as conn:
|
||||||
pass
|
pass
|
||||||
self.assertEqual(1, len(server._pool.conns))
|
self.assertEqual(1, len(server._pool.conns))
|
||||||
@ -645,7 +661,7 @@ class TestClient(IntegrationTest):
|
|||||||
with client_knobs(kill_cursor_frequency=0.1):
|
with client_knobs(kill_cursor_frequency=0.1):
|
||||||
# Assert reaper removes idle socket and replaces it with a new one
|
# Assert reaper removes idle socket and replaces it with a new one
|
||||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
|
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1)
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
with server._pool.checkout() as conn:
|
with server._pool.checkout() as conn:
|
||||||
pass
|
pass
|
||||||
# When the reaper runs at the same time as the get_socket, two
|
# When the reaper runs at the same time as the get_socket, two
|
||||||
@ -659,7 +675,7 @@ class TestClient(IntegrationTest):
|
|||||||
with client_knobs(kill_cursor_frequency=0.1):
|
with client_knobs(kill_cursor_frequency=0.1):
|
||||||
# Assert reaper respects maxPoolSize when adding new connections.
|
# Assert reaper respects maxPoolSize when adding new connections.
|
||||||
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
|
client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1)
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
with server._pool.checkout() as conn:
|
with server._pool.checkout() as conn:
|
||||||
pass
|
pass
|
||||||
# When the reaper runs at the same time as the get_socket,
|
# When the reaper runs at the same time as the get_socket,
|
||||||
@ -673,7 +689,7 @@ class TestClient(IntegrationTest):
|
|||||||
with client_knobs(kill_cursor_frequency=0.1):
|
with client_knobs(kill_cursor_frequency=0.1):
|
||||||
# Assert reaper has removed idle socket and NOT replaced it
|
# Assert reaper has removed idle socket and NOT replaced it
|
||||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
with server._pool.checkout() as conn_one:
|
with server._pool.checkout() as conn_one:
|
||||||
pass
|
pass
|
||||||
# Assert that the pool does not close connections prematurely.
|
# Assert that the pool does not close connections prematurely.
|
||||||
@ -690,12 +706,12 @@ class TestClient(IntegrationTest):
|
|||||||
def test_min_pool_size(self):
|
def test_min_pool_size(self):
|
||||||
with client_knobs(kill_cursor_frequency=0.1):
|
with client_knobs(kill_cursor_frequency=0.1):
|
||||||
client = rs_or_single_client()
|
client = rs_or_single_client()
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
self.assertEqual(0, len(server._pool.conns))
|
self.assertEqual(0, len(server._pool.conns))
|
||||||
|
|
||||||
# Assert that pool started up at minPoolSize
|
# Assert that pool started up at minPoolSize
|
||||||
client = rs_or_single_client(minPoolSize=10)
|
client = rs_or_single_client(minPoolSize=10)
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
wait_until(
|
wait_until(
|
||||||
lambda: len(server._pool.conns) == 10,
|
lambda: len(server._pool.conns) == 10,
|
||||||
"pool initialized with 10 connections",
|
"pool initialized with 10 connections",
|
||||||
@ -714,7 +730,7 @@ class TestClient(IntegrationTest):
|
|||||||
# Use high frequency to test _get_socket_no_auth.
|
# Use high frequency to test _get_socket_no_auth.
|
||||||
with client_knobs(kill_cursor_frequency=99999999):
|
with client_knobs(kill_cursor_frequency=99999999):
|
||||||
client = rs_or_single_client(maxIdleTimeMS=500)
|
client = rs_or_single_client(maxIdleTimeMS=500)
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
with server._pool.checkout() as conn:
|
with server._pool.checkout() as conn:
|
||||||
pass
|
pass
|
||||||
self.assertEqual(1, len(server._pool.conns))
|
self.assertEqual(1, len(server._pool.conns))
|
||||||
@ -728,7 +744,7 @@ class TestClient(IntegrationTest):
|
|||||||
|
|
||||||
# Test that connections are reused if maxIdleTimeMS is not set.
|
# Test that connections are reused if maxIdleTimeMS is not set.
|
||||||
client = rs_or_single_client()
|
client = rs_or_single_client()
|
||||||
server = client._get_topology().select_server(readable_server_selector, _Op.TEST)
|
server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST)
|
||||||
with server._pool.checkout() as conn:
|
with server._pool.checkout() as conn:
|
||||||
pass
|
pass
|
||||||
self.assertEqual(1, len(server._pool.conns))
|
self.assertEqual(1, len(server._pool.conns))
|
||||||
@ -793,12 +809,14 @@ class TestClient(IntegrationTest):
|
|||||||
|
|
||||||
bad_host = "somedomainthatdoesntexist.org"
|
bad_host = "somedomainthatdoesntexist.org"
|
||||||
c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
|
c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
|
||||||
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
|
with self.assertRaises(ConnectionFailure):
|
||||||
|
c.pymongo_test.test.find_one()
|
||||||
|
|
||||||
def test_init_disconnected_with_auth(self):
|
def test_init_disconnected_with_auth(self):
|
||||||
uri = "mongodb://user:pass@somedomainthatdoesntexist"
|
uri = "mongodb://user:pass@somedomainthatdoesntexist"
|
||||||
c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
|
c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10)
|
||||||
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
|
with self.assertRaises(ConnectionFailure):
|
||||||
|
c.pymongo_test.test.find_one()
|
||||||
|
|
||||||
def test_equality(self):
|
def test_equality(self):
|
||||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||||
@ -816,7 +834,8 @@ class TestClient(IntegrationTest):
|
|||||||
self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False))
|
self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False))
|
||||||
# Same seeds but out of order still compares equal:
|
# Same seeds but out of order still compares equal:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
MongoClient(["a", "b", "c"], connect=False), MongoClient(["c", "a", "b"], connect=False)
|
MongoClient(["a", "b", "c"], connect=False),
|
||||||
|
MongoClient(["c", "a", "b"], connect=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hashable(self):
|
def test_hashable(self):
|
||||||
@ -830,9 +849,10 @@ class TestClient(IntegrationTest):
|
|||||||
|
|
||||||
def test_host_w_port(self):
|
def test_host_w_port(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
host = client_context.host
|
||||||
connected(
|
connected(
|
||||||
MongoClient(
|
MongoClient(
|
||||||
f"{client_context.host}:1234567",
|
f"{host}:1234567",
|
||||||
connectTimeoutMS=1,
|
connectTimeoutMS=1,
|
||||||
serverSelectionTimeoutMS=10,
|
serverSelectionTimeoutMS=10,
|
||||||
)
|
)
|
||||||
@ -883,10 +903,10 @@ class TestClient(IntegrationTest):
|
|||||||
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
|
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
|
||||||
|
|
||||||
def test_list_databases(self):
|
def test_list_databases(self):
|
||||||
cmd_docs = self.client.admin.command("listDatabases")["databases"]
|
cmd_docs = (self.client.admin.command("listDatabases"))["databases"]
|
||||||
cursor = self.client.list_databases()
|
cursor = self.client.list_databases()
|
||||||
self.assertIsInstance(cursor, CommandCursor)
|
self.assertIsInstance(cursor, CommandCursor)
|
||||||
helper_docs = list(cursor)
|
helper_docs = cursor.to_list()
|
||||||
self.assertTrue(len(helper_docs) > 0)
|
self.assertTrue(len(helper_docs) > 0)
|
||||||
self.assertEqual(len(helper_docs), len(cmd_docs))
|
self.assertEqual(len(helper_docs), len(cmd_docs))
|
||||||
# PYTHON-3529 Some fields may change between calls, just compare names.
|
# PYTHON-3529 Some fields may change between calls, just compare names.
|
||||||
@ -900,7 +920,7 @@ class TestClient(IntegrationTest):
|
|||||||
|
|
||||||
self.client.pymongo_test.test.insert_one({})
|
self.client.pymongo_test.test.insert_one({})
|
||||||
cursor = self.client.list_databases(filter={"name": "admin"})
|
cursor = self.client.list_databases(filter={"name": "admin"})
|
||||||
docs = list(cursor)
|
docs = cursor.to_list()
|
||||||
self.assertEqual(1, len(docs))
|
self.assertEqual(1, len(docs))
|
||||||
self.assertEqual(docs[0]["name"], "admin")
|
self.assertEqual(docs[0]["name"], "admin")
|
||||||
|
|
||||||
@ -911,7 +931,7 @@ class TestClient(IntegrationTest):
|
|||||||
def test_list_database_names(self):
|
def test_list_database_names(self):
|
||||||
self.client.pymongo_test.test.insert_one({"dummy": "object"})
|
self.client.pymongo_test.test.insert_one({"dummy": "object"})
|
||||||
self.client.pymongo_test_mike.test.insert_one({"dummy": "object"})
|
self.client.pymongo_test_mike.test.insert_one({"dummy": "object"})
|
||||||
cmd_docs = self.client.admin.command("listDatabases")["databases"]
|
cmd_docs = (self.client.admin.command("listDatabases"))["databases"]
|
||||||
cmd_names = [doc["name"] for doc in cmd_docs]
|
cmd_names = [doc["name"] for doc in cmd_docs]
|
||||||
|
|
||||||
db_names = self.client.list_database_names()
|
db_names = self.client.list_database_names()
|
||||||
@ -920,8 +940,10 @@ class TestClient(IntegrationTest):
|
|||||||
self.assertEqual(db_names, cmd_names)
|
self.assertEqual(db_names, cmd_names)
|
||||||
|
|
||||||
def test_drop_database(self):
|
def test_drop_database(self):
|
||||||
self.assertRaises(TypeError, self.client.drop_database, 5)
|
with self.assertRaises(TypeError):
|
||||||
self.assertRaises(TypeError, self.client.drop_database, None)
|
self.client.drop_database(5) # type: ignore[arg-type]
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
self.client.drop_database(None) # type: ignore[arg-type]
|
||||||
|
|
||||||
self.client.pymongo_test.test.insert_one({"dummy": "object"})
|
self.client.pymongo_test.test.insert_one({"dummy": "object"})
|
||||||
self.client.pymongo_test2.test.insert_one({"dummy": "object"})
|
self.client.pymongo_test2.test.insert_one({"dummy": "object"})
|
||||||
@ -944,7 +966,8 @@ class TestClient(IntegrationTest):
|
|||||||
test_client = rs_or_single_client()
|
test_client = rs_or_single_client()
|
||||||
coll = test_client.pymongo_test.bar
|
coll = test_client.pymongo_test.bar
|
||||||
test_client.close()
|
test_client.close()
|
||||||
self.assertRaises(InvalidOperation, coll.count_documents, {})
|
with self.assertRaises(InvalidOperation):
|
||||||
|
coll.count_documents({})
|
||||||
|
|
||||||
def test_close_kills_cursors(self):
|
def test_close_kills_cursors(self):
|
||||||
if sys.platform.startswith("java"):
|
if sys.platform.startswith("java"):
|
||||||
@ -961,7 +984,7 @@ class TestClient(IntegrationTest):
|
|||||||
coll.insert_many([{"i": i} for i in range(docs_inserted)])
|
coll.insert_many([{"i": i} for i in range(docs_inserted)])
|
||||||
|
|
||||||
# Open a cursor and leave it open on the server.
|
# Open a cursor and leave it open on the server.
|
||||||
cursor = coll.find().batch_size(10)
|
cursor = (coll.find()).batch_size(10)
|
||||||
self.assertTrue(bool(next(cursor)))
|
self.assertTrue(bool(next(cursor)))
|
||||||
self.assertLess(cursor.retrieved, docs_inserted)
|
self.assertLess(cursor.retrieved, docs_inserted)
|
||||||
|
|
||||||
@ -992,7 +1015,8 @@ class TestClient(IntegrationTest):
|
|||||||
self.assertTrue(client._kill_cursors_executor._stopped)
|
self.assertTrue(client._kill_cursors_executor._stopped)
|
||||||
|
|
||||||
# Reusing the closed client should raise an InvalidOperation error.
|
# Reusing the closed client should raise an InvalidOperation error.
|
||||||
self.assertRaises(InvalidOperation, client.admin.command, "ping")
|
with self.assertRaises(InvalidOperation):
|
||||||
|
client.admin.command("ping")
|
||||||
# Thread is still stopped.
|
# Thread is still stopped.
|
||||||
self.assertTrue(client._kill_cursors_executor._stopped)
|
self.assertTrue(client._kill_cursors_executor._stopped)
|
||||||
|
|
||||||
@ -1065,8 +1089,10 @@ class TestClient(IntegrationTest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Auth with lazy connection.
|
# Auth with lazy connection.
|
||||||
rs_or_single_client_noauth(
|
(
|
||||||
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False
|
rs_or_single_client_noauth(
|
||||||
|
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False
|
||||||
|
)
|
||||||
).pymongo_test.test.find_one()
|
).pymongo_test.test.find_one()
|
||||||
|
|
||||||
# Wrong password.
|
# Wrong password.
|
||||||
@ -1074,7 +1100,8 @@ class TestClient(IntegrationTest):
|
|||||||
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False
|
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertRaises(OperationFailure, bad_client.pymongo_test.test.find_one)
|
with self.assertRaises(OperationFailure):
|
||||||
|
bad_client.pymongo_test.test.find_one()
|
||||||
|
|
||||||
@client_context.require_auth
|
@client_context.require_auth
|
||||||
def test_username_and_password(self):
|
def test_username_and_password(self):
|
||||||
@ -1093,13 +1120,14 @@ class TestClient(IntegrationTest):
|
|||||||
c.server_info()
|
c.server_info()
|
||||||
|
|
||||||
with self.assertRaises(OperationFailure):
|
with self.assertRaises(OperationFailure):
|
||||||
rs_or_single_client_noauth(username="ad min", password="foo").server_info()
|
(rs_or_single_client_noauth(username="ad min", password="foo")).server_info()
|
||||||
|
|
||||||
@client_context.require_auth
|
@client_context.require_auth
|
||||||
@client_context.require_no_fips
|
@client_context.require_no_fips
|
||||||
def test_lazy_auth_raises_operation_failure(self):
|
def test_lazy_auth_raises_operation_failure(self):
|
||||||
|
host = client_context.host
|
||||||
lazy_client = rs_or_single_client_noauth(
|
lazy_client = rs_or_single_client_noauth(
|
||||||
f"mongodb://user:wrong@{client_context.host}/pymongo_test", connect=False
|
f"mongodb://user:wrong@{host}/pymongo_test", connect=False
|
||||||
)
|
)
|
||||||
|
|
||||||
assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one)
|
assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one)
|
||||||
@ -1125,11 +1153,10 @@ class TestClient(IntegrationTest):
|
|||||||
self.assertTrue(mongodb_socket in repr(client))
|
self.assertTrue(mongodb_socket in repr(client))
|
||||||
|
|
||||||
# Confirm it fails with a missing socket.
|
# Confirm it fails with a missing socket.
|
||||||
self.assertRaises(
|
with self.assertRaises(ConnectionFailure):
|
||||||
ConnectionFailure,
|
connected(
|
||||||
connected,
|
MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100),
|
||||||
MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def test_document_class(self):
|
def test_document_class(self):
|
||||||
c = self.client
|
c = self.client
|
||||||
@ -1154,27 +1181,30 @@ class TestClient(IntegrationTest):
|
|||||||
maxIdleTimeMS=10500,
|
maxIdleTimeMS=10500,
|
||||||
serverSelectionTimeoutMS=10500,
|
serverSelectionTimeoutMS=10500,
|
||||||
)
|
)
|
||||||
self.assertEqual(10.5, get_pool(client).opts.connect_timeout)
|
self.assertEqual(10.5, (get_pool(client)).opts.connect_timeout)
|
||||||
self.assertEqual(10.5, get_pool(client).opts.socket_timeout)
|
self.assertEqual(10.5, (get_pool(client)).opts.socket_timeout)
|
||||||
self.assertEqual(10.5, get_pool(client).opts.max_idle_time_seconds)
|
self.assertEqual(10.5, (get_pool(client)).opts.max_idle_time_seconds)
|
||||||
self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds)
|
self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds)
|
||||||
self.assertEqual(10.5, client.options.server_selection_timeout)
|
self.assertEqual(10.5, client.options.server_selection_timeout)
|
||||||
|
|
||||||
def test_socket_timeout_ms_validation(self):
|
def test_socket_timeout_ms_validation(self):
|
||||||
c = rs_or_single_client(socketTimeoutMS=10 * 1000)
|
c = rs_or_single_client(socketTimeoutMS=10 * 1000)
|
||||||
self.assertEqual(10, get_pool(c).opts.socket_timeout)
|
self.assertEqual(10, (get_pool(c)).opts.socket_timeout)
|
||||||
|
|
||||||
c = connected(rs_or_single_client(socketTimeoutMS=None))
|
c = connected(rs_or_single_client(socketTimeoutMS=None))
|
||||||
self.assertEqual(None, get_pool(c).opts.socket_timeout)
|
self.assertEqual(None, (get_pool(c)).opts.socket_timeout)
|
||||||
|
|
||||||
c = connected(rs_or_single_client(socketTimeoutMS=0))
|
c = connected(rs_or_single_client(socketTimeoutMS=0))
|
||||||
self.assertEqual(None, get_pool(c).opts.socket_timeout)
|
self.assertEqual(None, (get_pool(c)).opts.socket_timeout)
|
||||||
|
|
||||||
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=-1)
|
with self.assertRaises(ValueError):
|
||||||
|
rs_or_single_client(socketTimeoutMS=-1)
|
||||||
|
|
||||||
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=1e10)
|
with self.assertRaises(ValueError):
|
||||||
|
rs_or_single_client(socketTimeoutMS=1e10)
|
||||||
|
|
||||||
self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS="foo")
|
with self.assertRaises(ValueError):
|
||||||
|
rs_or_single_client(socketTimeoutMS="foo")
|
||||||
|
|
||||||
def test_socket_timeout(self):
|
def test_socket_timeout(self):
|
||||||
no_timeout = self.client
|
no_timeout = self.client
|
||||||
@ -1189,11 +1219,12 @@ class TestClient(IntegrationTest):
|
|||||||
where_func = delay(timeout_sec + 1)
|
where_func = delay(timeout_sec + 1)
|
||||||
|
|
||||||
def get_x(db):
|
def get_x(db):
|
||||||
doc = next(db.test.find().where(where_func))
|
doc = next((db.test.find()).where(where_func))
|
||||||
return doc["x"]
|
return doc["x"]
|
||||||
|
|
||||||
self.assertEqual(1, get_x(no_timeout.pymongo_test))
|
self.assertEqual(1, get_x(no_timeout.pymongo_test))
|
||||||
self.assertRaises(NetworkTimeout, get_x, timeout.pymongo_test)
|
with self.assertRaises(NetworkTimeout):
|
||||||
|
get_x(timeout.pymongo_test)
|
||||||
|
|
||||||
def test_server_selection_timeout(self):
|
def test_server_selection_timeout(self):
|
||||||
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
|
client = MongoClient(serverSelectionTimeoutMS=100, connect=False)
|
||||||
@ -1224,7 +1255,7 @@ class TestClient(IntegrationTest):
|
|||||||
|
|
||||||
def test_waitQueueTimeoutMS(self):
|
def test_waitQueueTimeoutMS(self):
|
||||||
client = rs_or_single_client(waitQueueTimeoutMS=2000)
|
client = rs_or_single_client(waitQueueTimeoutMS=2000)
|
||||||
self.assertEqual(get_pool(client).opts.wait_queue_timeout, 2)
|
self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2)
|
||||||
|
|
||||||
def test_socketKeepAlive(self):
|
def test_socketKeepAlive(self):
|
||||||
pool = get_pool(self.client)
|
pool = get_pool(self.client)
|
||||||
@ -1244,11 +1275,11 @@ class TestClient(IntegrationTest):
|
|||||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||||
aware.pymongo_test.test.insert_one({"x": now})
|
aware.pymongo_test.test.insert_one({"x": now})
|
||||||
|
|
||||||
self.assertEqual(None, naive.pymongo_test.test.find_one()["x"].tzinfo)
|
self.assertEqual(None, (naive.pymongo_test.test.find_one())["x"].tzinfo)
|
||||||
self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo)
|
self.assertEqual(utc, (aware.pymongo_test.test.find_one())["x"].tzinfo)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None),
|
(aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None),
|
||||||
naive.pymongo_test.test.find_one()["x"],
|
(naive.pymongo_test.test.find_one())["x"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@client_context.require_ipv6
|
@client_context.require_ipv6
|
||||||
@ -1282,18 +1313,21 @@ class TestClient(IntegrationTest):
|
|||||||
|
|
||||||
# The socket used for the previous commands has been returned to the
|
# The socket used for the previous commands has been returned to the
|
||||||
# pool
|
# pool
|
||||||
self.assertEqual(1, len(get_pool(client).conns))
|
self.assertEqual(1, len((get_pool(client)).conns))
|
||||||
|
|
||||||
with contextlib.closing(client):
|
# contextlib async support was added in Python 3.10
|
||||||
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
|
if _IS_SYNC or sys.version_info >= (3, 10):
|
||||||
with self.assertRaises(InvalidOperation):
|
with contextlib.closing(client):
|
||||||
client.pymongo_test.test.find_one()
|
self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"])
|
||||||
client = rs_or_single_client()
|
with self.assertRaises(InvalidOperation):
|
||||||
with client as client:
|
client.pymongo_test.test.find_one()
|
||||||
self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"])
|
client = rs_or_single_client()
|
||||||
with self.assertRaises(InvalidOperation):
|
with client as client:
|
||||||
client.pymongo_test.test.find_one()
|
self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"])
|
||||||
|
with self.assertRaises(InvalidOperation):
|
||||||
|
client.pymongo_test.test.find_one()
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_interrupt_signal(self):
|
def test_interrupt_signal(self):
|
||||||
if sys.platform.startswith("java"):
|
if sys.platform.startswith("java"):
|
||||||
# We can't figure out how to raise an exception on a thread that's
|
# We can't figure out how to raise an exception on a thread that's
|
||||||
@ -1344,7 +1378,7 @@ class TestClient(IntegrationTest):
|
|||||||
raised = False
|
raised = False
|
||||||
try:
|
try:
|
||||||
# Will be interrupted by a KeyboardInterrupt.
|
# Will be interrupted by a KeyboardInterrupt.
|
||||||
next(db.foo.find({"$where": where}))
|
next(db.foo.find({"$where": where})) # type: ignore[call-overload]
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raised = True
|
raised = True
|
||||||
|
|
||||||
@ -1355,7 +1389,7 @@ class TestClient(IntegrationTest):
|
|||||||
# Raises AssertionError due to PYTHON-294 -- Mongo's response to
|
# Raises AssertionError due to PYTHON-294 -- Mongo's response to
|
||||||
# the previous find() is still waiting to be read on the socket,
|
# the previous find() is still waiting to be read on the socket,
|
||||||
# so the request id's don't match.
|
# so the request id's don't match.
|
||||||
self.assertEqual({"_id": 1}, next(db.foo.find()))
|
self.assertEqual({"_id": 1}, next(db.foo.find())) # type: ignore[call-overload]
|
||||||
finally:
|
finally:
|
||||||
if old_signal_handler:
|
if old_signal_handler:
|
||||||
signal.signal(signal.SIGALRM, old_signal_handler)
|
signal.signal(signal.SIGALRM, old_signal_handler)
|
||||||
@ -1374,7 +1408,8 @@ class TestClient(IntegrationTest):
|
|||||||
old_conn = next(iter(pool.conns))
|
old_conn = next(iter(pool.conns))
|
||||||
client.pymongo_test.test.drop()
|
client.pymongo_test.test.drop()
|
||||||
client.pymongo_test.test.insert_one({"_id": "foo"})
|
client.pymongo_test.test.insert_one({"_id": "foo"})
|
||||||
self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"})
|
with self.assertRaises(OperationFailure):
|
||||||
|
client.pymongo_test.test.insert_one({"_id": "foo"})
|
||||||
|
|
||||||
self.assertEqual(socket_count, len(pool.conns))
|
self.assertEqual(socket_count, len(pool.conns))
|
||||||
new_con = next(iter(pool.conns))
|
new_con = next(iter(pool.conns))
|
||||||
@ -1392,23 +1427,29 @@ class TestClient(IntegrationTest):
|
|||||||
client = rs_or_single_client(connect=False, w=0)
|
client = rs_or_single_client(connect=False, w=0)
|
||||||
self.addCleanup(client.close)
|
self.addCleanup(client.close)
|
||||||
client.test_lazy_connect_w0.test.insert_one({})
|
client.test_lazy_connect_w0.test.insert_one({})
|
||||||
wait_until(
|
|
||||||
lambda: client.test_lazy_connect_w0.test.count_documents({}) == 1, "find one document"
|
def predicate():
|
||||||
)
|
return client.test_lazy_connect_w0.test.count_documents({}) == 1
|
||||||
|
|
||||||
|
wait_until(predicate, "find one document")
|
||||||
|
|
||||||
client = rs_or_single_client(connect=False, w=0)
|
client = rs_or_single_client(connect=False, w=0)
|
||||||
self.addCleanup(client.close)
|
self.addCleanup(client.close)
|
||||||
client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}})
|
client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}})
|
||||||
wait_until(
|
|
||||||
lambda: client.test_lazy_connect_w0.test.find_one().get("x") == 1, "update one document"
|
def predicate():
|
||||||
)
|
return (client.test_lazy_connect_w0.test.find_one()).get("x") == 1
|
||||||
|
|
||||||
|
wait_until(predicate, "update one document")
|
||||||
|
|
||||||
client = rs_or_single_client(connect=False, w=0)
|
client = rs_or_single_client(connect=False, w=0)
|
||||||
self.addCleanup(client.close)
|
self.addCleanup(client.close)
|
||||||
client.test_lazy_connect_w0.test.delete_one({})
|
client.test_lazy_connect_w0.test.delete_one({})
|
||||||
wait_until(
|
|
||||||
lambda: client.test_lazy_connect_w0.test.count_documents({}) == 0, "delete one document"
|
def predicate():
|
||||||
)
|
return client.test_lazy_connect_w0.test.count_documents({}) == 0
|
||||||
|
|
||||||
|
wait_until(predicate, "delete one document")
|
||||||
|
|
||||||
@client_context.require_no_mongos
|
@client_context.require_no_mongos
|
||||||
def test_exhaust_network_error(self):
|
def test_exhaust_network_error(self):
|
||||||
@ -1445,12 +1486,13 @@ class TestClient(IntegrationTest):
|
|||||||
|
|
||||||
# Cause a network error on the actual socket.
|
# Cause a network error on the actual socket.
|
||||||
pool = get_pool(c)
|
pool = get_pool(c)
|
||||||
socket_info = one(pool.conns)
|
conn = one(pool.conns)
|
||||||
socket_info.conn.close()
|
conn.conn.close()
|
||||||
|
|
||||||
# Connection.authenticate logs, but gets a socket.error. Should be
|
# Connection.authenticate logs, but gets a socket.error. Should be
|
||||||
# reraised as AutoReconnect.
|
# reraised as AutoReconnect.
|
||||||
self.assertRaises(AutoReconnect, c.test.collection.find_one)
|
with self.assertRaises(AutoReconnect):
|
||||||
|
c.test.collection.find_one()
|
||||||
|
|
||||||
# No semaphore leak, the pool is allowed to make a new socket.
|
# No semaphore leak, the pool is allowed to make a new socket.
|
||||||
c.test.collection.find_one()
|
c.test.collection.find_one()
|
||||||
@ -1638,13 +1680,19 @@ class TestClient(IntegrationTest):
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def run(self):
|
def _run(self):
|
||||||
while self.running:
|
while self.running:
|
||||||
exc = AutoReconnect("mock pool error")
|
exc = AutoReconnect("mock pool error")
|
||||||
ctx = _ErrorContext(exc, 0, pool.gen.get_overall(), False, None)
|
ctx = _ErrorContext(exc, 0, pool.gen.get_overall(), False, None)
|
||||||
client._topology.handle_error(pool.address, ctx)
|
client._topology.handle_error(pool.address, ctx)
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
if _IS_SYNC:
|
||||||
|
self._run()
|
||||||
|
else:
|
||||||
|
asyncio.run(self._run())
|
||||||
|
|
||||||
t = ResetPoolThread(pool)
|
t = ResetPoolThread(pool)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
@ -1755,7 +1803,7 @@ class TestClient(IntegrationTest):
|
|||||||
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
|
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
|
||||||
):
|
):
|
||||||
assert client.address is not None
|
assert client.address is not None
|
||||||
expected = "{}:{}: ".format(*client.address)
|
expected = "{}:{}: ".format(*(client.address))
|
||||||
with self.assertRaisesRegex(AutoReconnect, expected):
|
with self.assertRaisesRegex(AutoReconnect, expected):
|
||||||
client.pymongo_test.test.find_one({})
|
client.pymongo_test.test.find_one({})
|
||||||
|
|
||||||
@ -1814,6 +1862,7 @@ class TestClient(IntegrationTest):
|
|||||||
"loadBalanced clients do not run SDAM",
|
"loadBalanced clients do not run SDAM",
|
||||||
)
|
)
|
||||||
@unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP")
|
@unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP")
|
||||||
|
@client_context.require_sync
|
||||||
def test_sigstop_sigcont(self):
|
def test_sigstop_sigcont(self):
|
||||||
test_dir = os.path.dirname(os.path.realpath(__file__))
|
test_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
script = os.path.join(test_dir, "sigstop_sigcont.py")
|
script = os.path.join(test_dir, "sigstop_sigcont.py")
|
||||||
@ -1961,7 +2010,8 @@ class TestExhaustCursor(IntegrationTest):
|
|||||||
SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST
|
SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertRaises(OperationFailure, cursor.next)
|
with self.assertRaises(OperationFailure):
|
||||||
|
cursor.next()
|
||||||
self.assertFalse(conn.closed)
|
self.assertFalse(conn.closed)
|
||||||
|
|
||||||
# The socket was checked in and the semaphore was decremented.
|
# The socket was checked in and the semaphore was decremented.
|
||||||
@ -1998,7 +2048,8 @@ class TestExhaustCursor(IntegrationTest):
|
|||||||
return message._OpReply.unpack(msg)
|
return message._OpReply.unpack(msg)
|
||||||
|
|
||||||
conn.receive_message = receive_message
|
conn.receive_message = receive_message
|
||||||
self.assertRaises(OperationFailure, list, cursor)
|
with self.assertRaises(OperationFailure):
|
||||||
|
cursor.to_list()
|
||||||
# Unpatch the instance.
|
# Unpatch the instance.
|
||||||
del conn.receive_message
|
del conn.receive_message
|
||||||
|
|
||||||
@ -2019,7 +2070,8 @@ class TestExhaustCursor(IntegrationTest):
|
|||||||
conn.conn.close()
|
conn.conn.close()
|
||||||
|
|
||||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||||
self.assertRaises(ConnectionFailure, cursor.next)
|
with self.assertRaises(ConnectionFailure):
|
||||||
|
cursor.next()
|
||||||
self.assertTrue(conn.closed)
|
self.assertTrue(conn.closed)
|
||||||
|
|
||||||
# The socket was closed and the semaphore was decremented.
|
# The socket was closed and the semaphore was decremented.
|
||||||
@ -2046,7 +2098,8 @@ class TestExhaustCursor(IntegrationTest):
|
|||||||
conn.conn.close()
|
conn.conn.close()
|
||||||
|
|
||||||
# A getmore fails.
|
# A getmore fails.
|
||||||
self.assertRaises(ConnectionFailure, list, cursor)
|
with self.assertRaises(ConnectionFailure):
|
||||||
|
cursor.to_list()
|
||||||
self.assertTrue(conn.closed)
|
self.assertTrue(conn.closed)
|
||||||
|
|
||||||
wait_until(
|
wait_until(
|
||||||
@ -2057,6 +2110,7 @@ class TestExhaustCursor(IntegrationTest):
|
|||||||
self.assertNotIn(conn, pool.conns)
|
self.assertNotIn(conn, pool.conns)
|
||||||
self.assertEqual(0, pool.requests)
|
self.assertEqual(0, pool.requests)
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_gevent_task(self):
|
def test_gevent_task(self):
|
||||||
if not gevent_monkey_patched():
|
if not gevent_monkey_patched():
|
||||||
raise SkipTest("Must be running monkey patched by gevent")
|
raise SkipTest("Must be running monkey patched by gevent")
|
||||||
@ -2070,6 +2124,7 @@ class TestExhaustCursor(IntegrationTest):
|
|||||||
task.kill()
|
task.kill()
|
||||||
self.assertTrue(task.dead)
|
self.assertTrue(task.dead)
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_gevent_timeout(self):
|
def test_gevent_timeout(self):
|
||||||
if not gevent_monkey_patched():
|
if not gevent_monkey_patched():
|
||||||
raise SkipTest("Must be running monkey patched by gevent")
|
raise SkipTest("Must be running monkey patched by gevent")
|
||||||
@ -2101,6 +2156,7 @@ class TestExhaustCursor(IntegrationTest):
|
|||||||
self.assertIsNone(tt.get())
|
self.assertIsNone(tt.get())
|
||||||
self.assertIsNone(ct.get())
|
self.assertIsNone(ct.get())
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_gevent_timeout_when_creating_connection(self):
|
def test_gevent_timeout_when_creating_connection(self):
|
||||||
if not gevent_monkey_patched():
|
if not gevent_monkey_patched():
|
||||||
raise SkipTest("Must be running monkey patched by gevent")
|
raise SkipTest("Must be running monkey patched by gevent")
|
||||||
@ -2145,6 +2201,7 @@ class TestClientLazyConnect(IntegrationTest):
|
|||||||
def _get_client(self):
|
def _get_client(self):
|
||||||
return rs_or_single_client(connect=False)
|
return rs_or_single_client(connect=False)
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_insert_one(self):
|
def test_insert_one(self):
|
||||||
def reset(collection):
|
def reset(collection):
|
||||||
collection.drop()
|
collection.drop()
|
||||||
@ -2157,6 +2214,7 @@ class TestClientLazyConnect(IntegrationTest):
|
|||||||
|
|
||||||
lazy_client_trial(reset, insert_one, test, self._get_client)
|
lazy_client_trial(reset, insert_one, test, self._get_client)
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_update_one(self):
|
def test_update_one(self):
|
||||||
def reset(collection):
|
def reset(collection):
|
||||||
collection.drop()
|
collection.drop()
|
||||||
@ -2171,6 +2229,7 @@ class TestClientLazyConnect(IntegrationTest):
|
|||||||
|
|
||||||
lazy_client_trial(reset, update_one, test, self._get_client)
|
lazy_client_trial(reset, update_one, test, self._get_client)
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_delete_one(self):
|
def test_delete_one(self):
|
||||||
def reset(collection):
|
def reset(collection):
|
||||||
collection.drop()
|
collection.drop()
|
||||||
@ -2184,6 +2243,7 @@ class TestClientLazyConnect(IntegrationTest):
|
|||||||
|
|
||||||
lazy_client_trial(reset, delete_one, test, self._get_client)
|
lazy_client_trial(reset, delete_one, test, self._get_client)
|
||||||
|
|
||||||
|
@client_context.require_sync
|
||||||
def test_find_one(self):
|
def test_find_one(self):
|
||||||
results: list = []
|
results: list = []
|
||||||
|
|
||||||
@ -2203,7 +2263,7 @@ class TestClientLazyConnect(IntegrationTest):
|
|||||||
|
|
||||||
class TestMongoClientFailover(MockClientTest):
|
class TestMongoClientFailover(MockClientTest):
|
||||||
def test_discover_primary(self):
|
def test_discover_primary(self):
|
||||||
c = MockClient(
|
c = MockClient.get_mock_client(
|
||||||
standalones=[],
|
standalones=[],
|
||||||
members=["a:1", "b:2", "c:3"],
|
members=["a:1", "b:2", "c:3"],
|
||||||
mongoses=[],
|
mongoses=[],
|
||||||
@ -2219,13 +2279,17 @@ class TestMongoClientFailover(MockClientTest):
|
|||||||
# Fail over.
|
# Fail over.
|
||||||
c.kill_host("a:1")
|
c.kill_host("a:1")
|
||||||
c.mock_primary = "b:2"
|
c.mock_primary = "b:2"
|
||||||
wait_until(lambda: c.address == ("b", 2), "wait for server address to be updated")
|
|
||||||
|
def predicate():
|
||||||
|
return (c.address) == ("b", 2)
|
||||||
|
|
||||||
|
wait_until(predicate, "wait for server address to be updated")
|
||||||
# a:1 not longer in nodes.
|
# a:1 not longer in nodes.
|
||||||
self.assertLess(len(c.nodes), 3)
|
self.assertLess(len(c.nodes), 3)
|
||||||
|
|
||||||
def test_reconnect(self):
|
def test_reconnect(self):
|
||||||
# Verify the node list isn't forgotten during a network failure.
|
# Verify the node list isn't forgotten during a network failure.
|
||||||
c = MockClient(
|
c = MockClient.get_mock_client(
|
||||||
standalones=[],
|
standalones=[],
|
||||||
members=["a:1", "b:2", "c:3"],
|
members=["a:1", "b:2", "c:3"],
|
||||||
mongoses=[],
|
mongoses=[],
|
||||||
@ -2245,12 +2309,13 @@ class TestMongoClientFailover(MockClientTest):
|
|||||||
|
|
||||||
# MongoClient discovers it's alone. The first attempt raises either
|
# MongoClient discovers it's alone. The first attempt raises either
|
||||||
# ServerSelectionTimeoutError or AutoReconnect (from
|
# ServerSelectionTimeoutError or AutoReconnect (from
|
||||||
# MockPool.get_socket).
|
# AsyncMockPool.get_socket).
|
||||||
self.assertRaises(AutoReconnect, c.db.collection.find_one)
|
with self.assertRaises(AutoReconnect):
|
||||||
|
c.db.collection.find_one()
|
||||||
|
|
||||||
# But it can reconnect.
|
# But it can reconnect.
|
||||||
c.revive_host("a:1")
|
c.revive_host("a:1")
|
||||||
c._get_topology().select_servers(writable_server_selector, _Op.TEST)
|
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||||
self.assertEqual(c.address, ("a", 1))
|
self.assertEqual(c.address, ("a", 1))
|
||||||
|
|
||||||
def _test_network_error(self, operation_callback):
|
def _test_network_error(self, operation_callback):
|
||||||
@ -2273,7 +2338,7 @@ class TestMongoClientFailover(MockClientTest):
|
|||||||
# Set host-specific information so we can test whether it is reset.
|
# Set host-specific information so we can test whether it is reset.
|
||||||
c.set_wire_version_range("a:1", 2, 6)
|
c.set_wire_version_range("a:1", 2, 6)
|
||||||
c.set_wire_version_range("b:2", 2, 7)
|
c.set_wire_version_range("b:2", 2, 7)
|
||||||
c._get_topology().select_servers(writable_server_selector, _Op.TEST)
|
(c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
|
||||||
wait_until(lambda: len(c.nodes) == 2, "connect")
|
wait_until(lambda: len(c.nodes) == 2, "connect")
|
||||||
|
|
||||||
c.kill_host("a:1")
|
c.kill_host("a:1")
|
||||||
@ -2281,17 +2346,18 @@ class TestMongoClientFailover(MockClientTest):
|
|||||||
# MongoClient is disconnected from the primary. This raises either
|
# MongoClient is disconnected from the primary. This raises either
|
||||||
# ServerSelectionTimeoutError or AutoReconnect (from
|
# ServerSelectionTimeoutError or AutoReconnect (from
|
||||||
# MockPool.get_socket).
|
# MockPool.get_socket).
|
||||||
self.assertRaises(AutoReconnect, operation_callback, c)
|
with self.assertRaises(AutoReconnect):
|
||||||
|
operation_callback(c)
|
||||||
|
|
||||||
# The primary's description is reset.
|
# The primary's description is reset.
|
||||||
server_a = c._get_topology().get_server_by_address(("a", 1))
|
server_a = (c._get_topology()).get_server_by_address(("a", 1))
|
||||||
sd_a = server_a.description
|
sd_a = server_a.description
|
||||||
self.assertEqual(SERVER_TYPE.Unknown, sd_a.server_type)
|
self.assertEqual(SERVER_TYPE.Unknown, sd_a.server_type)
|
||||||
self.assertEqual(0, sd_a.min_wire_version)
|
self.assertEqual(0, sd_a.min_wire_version)
|
||||||
self.assertEqual(0, sd_a.max_wire_version)
|
self.assertEqual(0, sd_a.max_wire_version)
|
||||||
|
|
||||||
# ...but not the secondary's.
|
# ...but not the secondary's.
|
||||||
server_b = c._get_topology().get_server_by_address(("b", 2))
|
server_b = (c._get_topology()).get_server_by_address(("b", 2))
|
||||||
sd_b = server_b.description
|
sd_b = server_b.description
|
||||||
self.assertEqual(SERVER_TYPE.RSSecondary, sd_b.server_type)
|
self.assertEqual(SERVER_TYPE.RSSecondary, sd_b.server_type)
|
||||||
self.assertEqual(2, sd_b.min_wire_version)
|
self.assertEqual(2, sd_b.min_wire_version)
|
||||||
@ -2332,7 +2398,7 @@ class TestClientPool(MockClientTest):
|
|||||||
@client_context.require_connection
|
@client_context.require_connection
|
||||||
def test_rs_client_does_not_maintain_pool_to_arbiters(self):
|
def test_rs_client_does_not_maintain_pool_to_arbiters(self):
|
||||||
listener = CMAPListener()
|
listener = CMAPListener()
|
||||||
c = MockClient(
|
c = MockClient.get_mock_client(
|
||||||
standalones=[],
|
standalones=[],
|
||||||
members=["a:1", "b:2", "c:3", "d:4"],
|
members=["a:1", "b:2", "c:3", "d:4"],
|
||||||
mongoses=[],
|
mongoses=[],
|
||||||
@ -2363,7 +2429,7 @@ class TestClientPool(MockClientTest):
|
|||||||
@client_context.require_connection
|
@client_context.require_connection
|
||||||
def test_direct_client_maintains_pool_to_arbiter(self):
|
def test_direct_client_maintains_pool_to_arbiter(self):
|
||||||
listener = CMAPListener()
|
listener = CMAPListener()
|
||||||
c = MockClient(
|
c = MockClient.get_mock_client(
|
||||||
standalones=[],
|
standalones=[],
|
||||||
members=["a:1", "b:2", "c:3"],
|
members=["a:1", "b:2", "c:3"],
|
||||||
mongoses=[],
|
mongoses=[],
|
||||||
|
|||||||
@ -20,8 +20,8 @@ import uuid
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import IntegrationTest, client_context, unittest
|
from test import IntegrationTest, client_context, connected, unittest
|
||||||
from test.utils import connected, rs_or_single_client, single_client
|
from test.utils import rs_or_single_client, single_client
|
||||||
|
|
||||||
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
|
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
|
||||||
from bson.codec_options import CodecOptions
|
from bson.codec_options import CodecOptions
|
||||||
|
|||||||
@ -20,13 +20,12 @@ import sys
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import IntegrationTest, unittest
|
from test import IntegrationTest, drop_collections, unittest
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
SpecTestCreator,
|
SpecTestCreator,
|
||||||
camel_to_snake,
|
camel_to_snake,
|
||||||
camel_to_snake_args,
|
camel_to_snake_args,
|
||||||
camel_to_upper_camel,
|
camel_to_upper_camel,
|
||||||
drop_collections,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from pymongo import WriteConcern, operations
|
from pymongo import WriteConcern, operations
|
||||||
|
|||||||
@ -22,9 +22,9 @@ from pymongo.operations import _Op
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import MockClientTest, client_context, unittest
|
from test import MockClientTest, client_context, connected, unittest
|
||||||
from test.pymongo_mocks import MockClient
|
from test.pymongo_mocks import MockClient
|
||||||
from test.utils import connected, wait_until
|
from test.utils import wait_until
|
||||||
|
|
||||||
from pymongo.errors import AutoReconnect, InvalidOperation
|
from pymongo.errors import AutoReconnect, InvalidOperation
|
||||||
from pymongo.server_selectors import writable_server_selector
|
from pymongo.server_selectors import writable_server_selector
|
||||||
|
|||||||
@ -22,10 +22,9 @@ from functools import partial
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import IntegrationTest, unittest
|
from test import IntegrationTest, connected, unittest
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
ServerAndTopologyEventListener,
|
ServerAndTopologyEventListener,
|
||||||
connected,
|
|
||||||
single_client,
|
single_client,
|
||||||
wait_until,
|
wait_until,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -26,10 +26,9 @@ from pymongo.operations import _Op
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import IntegrationTest, SkipTest, client_context, unittest
|
from test import IntegrationTest, SkipTest, client_context, connected, unittest
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
OvertCommandListener,
|
OvertCommandListener,
|
||||||
connected,
|
|
||||||
one,
|
one,
|
||||||
rs_client,
|
rs_client,
|
||||||
single_client,
|
single_client,
|
||||||
|
|||||||
@ -21,13 +21,19 @@ import sys
|
|||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
from test import HAVE_IPADDRESS, IntegrationTest, SkipTest, client_context, unittest
|
from test import (
|
||||||
|
HAVE_IPADDRESS,
|
||||||
|
IntegrationTest,
|
||||||
|
SkipTest,
|
||||||
|
client_context,
|
||||||
|
connected,
|
||||||
|
remove_all_users,
|
||||||
|
unittest,
|
||||||
|
)
|
||||||
from test.utils import (
|
from test.utils import (
|
||||||
EventListener,
|
EventListener,
|
||||||
cat_files,
|
cat_files,
|
||||||
connected,
|
|
||||||
ignore_deprecations,
|
ignore_deprecations,
|
||||||
remove_all_users,
|
|
||||||
)
|
)
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
|||||||
@ -621,7 +621,10 @@ async def _async_mongo_client(host, port, authenticate=True, directConnection=No
|
|||||||
):
|
):
|
||||||
client_options["username"] = db_user
|
client_options["username"] = db_user
|
||||||
client_options["password"] = db_pwd
|
client_options["password"] = db_pwd
|
||||||
return AsyncMongoClient(uri, port, **client_options)
|
client = AsyncMongoClient(uri, port, **client_options)
|
||||||
|
if client._options.connect:
|
||||||
|
await client.aconnect()
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
|
||||||
@ -843,16 +846,6 @@ def server_started_with_auth(client):
|
|||||||
return "--auth" in argv or "--keyFile" in argv
|
return "--auth" in argv or "--keyFile" in argv
|
||||||
|
|
||||||
|
|
||||||
def drop_collections(db):
|
|
||||||
# Drop all non-system collections in this database.
|
|
||||||
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
|
|
||||||
db.drop_collection(coll)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_all_users(db):
|
|
||||||
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})
|
|
||||||
|
|
||||||
|
|
||||||
def joinall(threads):
|
def joinall(threads):
|
||||||
"""Join threads with a 5-minute timeout, assert joins succeeded"""
|
"""Join threads with a 5-minute timeout, assert joins succeeded"""
|
||||||
for t in threads:
|
for t in threads:
|
||||||
@ -860,17 +853,6 @@ def joinall(threads):
|
|||||||
assert not t.is_alive(), "Thread %s hung" % t
|
assert not t.is_alive(), "Thread %s hung" % t
|
||||||
|
|
||||||
|
|
||||||
def connected(client):
|
|
||||||
"""Convenience to wait for a newly-constructed client to connect."""
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
# Ignore warning that ping is always routed to primary even
|
|
||||||
# if client's read preference isn't PRIMARY.
|
|
||||||
warnings.simplefilter("ignore", UserWarning)
|
|
||||||
client.admin.command("ping") # Force connection.
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
def wait_until(predicate, success_description, timeout=10):
|
def wait_until(predicate, success_description, timeout=10):
|
||||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||||
|
|
||||||
@ -957,6 +939,20 @@ def assertRaisesExactly(cls, fn, *args, **kwargs):
|
|||||||
raise AssertionError("%s not raised" % cls)
|
raise AssertionError("%s not raised" % cls)
|
||||||
|
|
||||||
|
|
||||||
|
async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Unlike the standard assertRaises, this checks that a function raises a
|
||||||
|
specific class of exception, and not a subclass. E.g., check that
|
||||||
|
MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await fn(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}"
|
||||||
|
else:
|
||||||
|
raise AssertionError("%s not raised" % cls)
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _ignore_deprecations():
|
def _ignore_deprecations():
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
|
|||||||
@ -72,11 +72,20 @@ replacements = {
|
|||||||
"addAsyncCleanup": "addCleanup",
|
"addAsyncCleanup": "addCleanup",
|
||||||
"async_setup_class": "setup_class",
|
"async_setup_class": "setup_class",
|
||||||
"IsolatedAsyncioTestCase": "TestCase",
|
"IsolatedAsyncioTestCase": "TestCase",
|
||||||
|
"AsyncUnitTest": "UnitTest",
|
||||||
|
"AsyncMockClient": "MockClient",
|
||||||
"async_get_pool": "get_pool",
|
"async_get_pool": "get_pool",
|
||||||
"async_is_mongos": "is_mongos",
|
"async_is_mongos": "is_mongos",
|
||||||
"async_rs_or_single_client": "rs_or_single_client",
|
"async_rs_or_single_client": "rs_or_single_client",
|
||||||
|
"async_rs_or_single_client_noauth": "rs_or_single_client_noauth",
|
||||||
|
"async_rs_client": "rs_client",
|
||||||
"async_single_client": "single_client",
|
"async_single_client": "single_client",
|
||||||
"async_from_client": "from_client",
|
"async_from_client": "from_client",
|
||||||
|
"aclosing": "closing",
|
||||||
|
"asyncAssertRaisesExactly": "assertRaisesExactly",
|
||||||
|
"get_async_mock_client": "get_mock_client",
|
||||||
|
"aconnect": "_connect",
|
||||||
|
"aclose": "close",
|
||||||
}
|
}
|
||||||
|
|
||||||
docstring_replacements: dict[tuple[str, str], str] = {
|
docstring_replacements: dict[tuple[str, str], str] = {
|
||||||
@ -131,6 +140,8 @@ sync_gridfs_files = [
|
|||||||
converted_tests = [
|
converted_tests = [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
"conftest.py",
|
"conftest.py",
|
||||||
|
"pymongo_mocks.py",
|
||||||
|
"test_client.py",
|
||||||
"test_collection.py",
|
"test_collection.py",
|
||||||
"test_database.py",
|
"test_database.py",
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user