diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 1f54717a1..5e482b3d1 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -258,9 +258,6 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS - if [ -z "$TEST_ARGS" ]; then # TODO: remove this in PYTHON-4528 - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/synchronous/ $TEST_ARGS - fi python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 test/asynchronous/ $TEST_ARGS else python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index cbac42f54..0715bcd2a 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -206,5 +206,4 @@ jobs: which python pip install -e ".[test]" pytest -v - pytest -v test/synchronous/ pytest -v test/asynchronous/ diff --git a/mypy_test.ini b/mypy_test.ini index 08e9b301a..9fdc664e3 100644 --- a/mypy_test.ini +++ b/mypy_test.ini @@ -7,7 +7,7 @@ exclude = (?x)( | ^test/conftest.py$ ) -[mypy-pymongo.synchronous.*,gridfs.synchronous.*,test.synchronous.*] +[mypy-pymongo.synchronous.*,gridfs.synchronous.*,test.*] warn_unused_ignores = false disable_error_code = unused-coroutine diff --git a/test/__init__.py b/test/__init__.py index 19218f01a..ef0a32882 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test suite for pymongo, bson, and gridfs.""" +"""Synchronous test suite for pymongo, bson, and gridfs.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -28,6 +29,31 @@ import time import traceback import unittest import warnings +from asyncio import iscoroutinefunction +from test.helpers import ( + COMPRESSORS, + IS_SRV, + MONGODB_API_VERSION, + MULTI_MONGOS_LB_URI, + TEST_LOADBALANCER, + TEST_SERVERLESS, + TLS_OPTIONS, + SystemCertsPatcher, + _all_users, + _create_user, + client_knobs, + db_pwd, + db_user, + global_knobs, + host, + is_server_resolvable, + port, + print_running_topology, + print_thread_stacks, + print_thread_tracebacks, + sanitize_cmd, + sanitize_reply, +) try: import ipaddress @@ -38,210 +64,21 @@ except ImportError: from contextlib import contextmanager from functools import wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator from unittest import SkipTest from urllib.parse import quote_plus import pymongo import pymongo.errors from bson.son import SON -from pymongo import common, message from pymongo.common import partition_node from pymongo.hello import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.uri_parser import parse_uri -if HAVE_SSL: - import ssl - -# Enable debug output for uncollectable objects. PyPy does not have set_debug. -if hasattr(gc, "set_debug"): - gc.set_debug( - gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0) - ) - -# The host and port of a single mongod or mongos, or the seed host -# for a replica set. -host = os.environ.get("DB_IP", "localhost") -port = int(os.environ.get("DB_PORT", 27017)) -IS_SRV = "mongodb+srv" in host - -db_user = os.environ.get("DB_USER", "user") -db_pwd = os.environ.get("DB_PASSWORD", "password") - -CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") -CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem")) -CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem")) - -TLS_OPTIONS: Dict = {"tls": True} -if CLIENT_PEM: - TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM -if CA_PEM: - TLS_OPTIONS["tlsCAFile"] = CA_PEM - -COMPRESSORS = os.environ.get("COMPRESSORS") -MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") -TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) -TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS")) -SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") -MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") - -if TEST_LOADBALANCER: - res = parse_uri(SINGLE_MONGOS_LB_URI or "") - host, port = res["nodelist"][0] - db_user = res["username"] or db_user - db_pwd = res["password"] or db_pwd -elif TEST_SERVERLESS: - TEST_LOADBALANCER = True - res = parse_uri(SINGLE_MONGOS_LB_URI or "") - host, port = res["nodelist"][0] - db_user = res["username"] or db_user - db_pwd = res["password"] or db_pwd - TLS_OPTIONS = {"tls": True} - # Spec says serverless tests must be run with compression. - COMPRESSORS = COMPRESSORS or "zlib" - - -# Shared KMS data. -LOCAL_MASTER_KEY = base64.b64decode( - b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ" - b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" -) -AWS_CREDS = { - "accessKeyId": os.environ.get("FLE_AWS_KEY", ""), - "secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""), -} -AWS_CREDS_2 = { - "accessKeyId": os.environ.get("FLE_AWS_KEY2", ""), - "secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""), -} -AZURE_CREDS = { - "tenantId": os.environ.get("FLE_AZURE_TENANTID", ""), - "clientId": os.environ.get("FLE_AZURE_CLIENTID", ""), - "clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""), -} -GCP_CREDS = { - "email": os.environ.get("FLE_GCP_EMAIL", ""), - "privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""), -} -KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")} - -# Ensure Evergreen metadata doesn't result in truncation -os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000") - - -def is_server_resolvable(): - """Returns True if 'server' is resolvable.""" - socket_timeout = socket.getdefaulttimeout() - socket.setdefaulttimeout(1) - try: - try: - socket.gethostbyname("server") - return True - except OSError: - return False - finally: - socket.setdefaulttimeout(socket_timeout) - - -def _create_user(authdb, user, pwd=None, roles=None, **kwargs): - cmd = SON([("createUser", user)]) - # X509 doesn't use a password - if pwd: - cmd["pwd"] = pwd - cmd["roles"] = roles or ["root"] - cmd.update(**kwargs) - return authdb.command(cmd) - - -class client_knobs: - def __init__( - self, - heartbeat_frequency=None, - min_heartbeat_interval=None, - kill_cursor_frequency=None, - events_queue_frequency=None, - ): - self.heartbeat_frequency = heartbeat_frequency - self.min_heartbeat_interval = min_heartbeat_interval - self.kill_cursor_frequency = kill_cursor_frequency - self.events_queue_frequency = events_queue_frequency - - self.old_heartbeat_frequency = None - self.old_min_heartbeat_interval = None - self.old_kill_cursor_frequency = None - self.old_events_queue_frequency = None - self._enabled = False - self._stack = None - - def enable(self): - self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY - self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL - self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY - self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY - - if self.heartbeat_frequency is not None: - common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency - - if self.min_heartbeat_interval is not None: - common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval - - if self.kill_cursor_frequency is not None: - common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency - - if self.events_queue_frequency is not None: - common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency - self._enabled = True - # Store the allocation traceback to catch non-disabled client_knobs. - self._stack = "".join(traceback.format_stack()) - - def __enter__(self): - self.enable() - - @no_type_check - def disable(self): - common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency - common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval - common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency - common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency - self._enabled = False - - def __exit__(self, exc_type, exc_val, exc_tb): - self.disable() - - def __call__(self, func): - def make_wrapper(f): - @wraps(f) - def wrap(*args, **kwargs): - with self: - return f(*args, **kwargs) - - return wrap - - return make_wrapper(func) - - def __del__(self): - if self._enabled: - msg = ( - "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, " - "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, " - "EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format( - common.HEARTBEAT_FREQUENCY, - common.MIN_HEARTBEAT_INTERVAL, - common.KILL_CURSOR_FREQUENCY, - common.EVENTS_QUEUE_FREQUENCY, - self._stack, - ) - ) - self.disable() - raise Exception(msg) - - -def _all_users(db): - return {u["user"] for u in db.command("usersInfo").get("users", [])} +_IS_SYNC = True class ClientContext: @@ -314,7 +151,8 @@ class ClientContext: auth_part = "" if client_context.auth_enabled: auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@" - return f"mongodb://{auth_part}{self.pair}/?{opts_part}" + pair = self.pair + return f"mongodb://{auth_part}{pair}/?{opts_part}" @property def hello(self): @@ -468,7 +306,7 @@ class ClientContext: self.test_commands_enabled = True self.has_ipv6 = self._server_started_with_ipv6() - self.is_mongos = self.hello.get("msg") == "isdbgrid" + self.is_mongos = (self.hello).get("msg") == "isdbgrid" if self.is_mongos: address = self.client.address self.mongoses.append(address) @@ -604,15 +442,33 @@ class ClientContext: def _require(self, condition, msg, func=None): def make_wrapper(f): + if iscoroutinefunction(f): + wraps_async = True + else: + wraps_async = False + @wraps(f) def wrap(*args, **kwargs): self.init() # Always raise SkipTest if we can't connect to MongoDB if not self.connected: - raise SkipTest(f"Cannot connect to MongoDB on {self.pair}") - if condition(): - return f(*args, **kwargs) - raise SkipTest(msg) + pair = self.pair + raise SkipTest(f"Cannot connect to MongoDB on {pair}") + if iscoroutinefunction(condition) and condition(): + if wraps_async: + return f(*args, **kwargs) + else: + return f(*args, **kwargs) + elif condition(): + if wraps_async: + return f(*args, **kwargs) + else: + return f(*args, **kwargs) + if "self.pair" in msg: + new_msg = msg.replace("self.pair", self.pair) + else: + new_msg = msg + raise SkipTest(new_msg) return wrap @@ -635,7 +491,7 @@ class ClientContext: """Run a test only if we can connect to MongoDB.""" return self._require( lambda: True, # _require checks if we're connected - f"Cannot connect to MongoDB on {self.pair}", + "Cannot connect to MongoDB on self.pair", func=func, ) @@ -643,7 +499,7 @@ class ClientContext: """Run a test only if we are connected to Atlas Data Lake.""" return self._require( lambda: self.is_data_lake, - f"Not connected to Atlas Data Lake on {self.pair}", + "Not connected to Atlas Data Lake on self.pair", func=func, ) @@ -816,7 +672,7 @@ class ClientContext: if "sharded" in topologies and self.is_mongos: return True if "sharded-replicaset" in topologies and self.is_mongos: - shards = list(client_context.client.config.shards.find()) + shards = (client_context.client.config.shards.find()).to_list() for shard in shards: # For a 3-member RS-backed sharded cluster, shard['host'] # will be 'replicaName/ip1:port1,ip2:port2,ip3:port3' @@ -969,45 +825,17 @@ class ClientContext: @property def max_bson_size(self): - return self.hello["maxBsonObjectSize"] + return (self.hello)["maxBsonObjectSize"] @property def max_write_batch_size(self): - return self.hello["maxWriteBatchSize"] + return (self.hello)["maxWriteBatchSize"] # Reusable client context client_context = ClientContext() -def sanitize_cmd(cmd): - cp = cmd.copy() - cp.pop("$clusterTime", None) - cp.pop("$db", None) - cp.pop("$readPreference", None) - cp.pop("lsid", None) - if MONGODB_API_VERSION: - # Stable API parameters - cp.pop("apiVersion", None) - # OP_MSG encoding may move the payload type one field to the - # end of the command. Do the same here. - name = next(iter(cp)) - try: - identifier = message._FIELD_MAP[name] - docs = cp.pop(identifier) - cp[identifier] = docs - except KeyError: - pass - return cp - - -def sanitize_reply(reply): - cp = reply.copy() - cp.pop("$clusterTime", None) - cp.pop("operationTime", None) - return cp - - class PyMongoTestCase(unittest.TestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1085,40 +913,23 @@ class PyMongoTestCase(unittest.TestCase): self.assertEqual(proc.exitcode, 0) -def print_thread_tracebacks() -> None: - """Print all Python thread tracebacks.""" - for thread_id, frame in sys._current_frames().items(): - sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n") - traceback.print_stack(frame, file=sys.stderr) - - -def print_thread_stacks(pid: int) -> None: - """Print all C-level thread stacks for a given process id.""" - if sys.platform == "darwin": - cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"'] - else: - cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"'] - - try: - res = subprocess.run( - cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8" - ) - except Exception as exc: - sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}") - else: - sys.stderr.write(res.stdout) - - class IntegrationTest(PyMongoTestCase): - """Base class for TestCases that need a connection to MongoDB to pass.""" + """Async base class for TestCases that need a connection to MongoDB to pass.""" client: MongoClient[dict] db: Database credentials: Dict[str, str] @classmethod - @client_context.require_connection def setUpClass(cls): + if _IS_SYNC: + cls._setup_class() + else: + asyncio.run(cls._setup_class()) + + @classmethod + @client_context.require_connection + def _setup_class(cls): if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): @@ -1171,10 +982,6 @@ class MockClientTest(unittest.TestCase): super().tearDown() -# Global knobs to speed up the test suite. -global_knobs = client_knobs(events_queue_frequency=0.05) - - def setup(): client_context.init() warnings.resetwarnings() @@ -1182,56 +989,6 @@ def setup(): global_knobs.enable() -def _get_executors(topology): - executors = [] - for server in topology._servers.values(): - # Some MockMonitor do not have an _executor. - if hasattr(server._monitor, "_executor"): - executors.append(server._monitor._executor) - if hasattr(server._monitor, "_rtt_monitor"): - executors.append(server._monitor._rtt_monitor._executor) - executors.append(topology._Topology__events_executor) - if topology._srv_monitor: - executors.append(topology._srv_monitor._executor) - - return [e for e in executors if e is not None] - - -def print_running_topology(topology): - running = [e for e in _get_executors(topology) if not e._stopped] - if running: - print( - "WARNING: found Topology with running threads:\n" - f" Threads: {running}\n" - f" Topology: {topology}\n" - f" Creation traceback:\n{topology._settings._stack}" - ) - - -def print_running_clients(): - from pymongo.synchronous.topology import Topology - - processed = set() - # Avoid false positives on the main test client. - # XXX: Can be removed after PYTHON-1634 or PYTHON-1896. - c = client_context.client - if c: - processed.add(c._topology._topology_id) - # Call collect to manually cleanup any would-be gc'd clients to avoid - # false positives. - gc.collect() - for obj in gc.get_objects(): - try: - if isinstance(obj, Topology): - # Avoid printing the same Topology multiple times. - if obj._topology_id in processed: - continue - print_running_topology(obj) - processed.add(obj._topology_id) - except ReferenceError: - pass - - def teardown(): global_knobs.disable() garbage = [] @@ -1266,31 +1023,25 @@ def test_cases(suite): yield from test_cases(suite_or_case) -# Helper method to workaround https://bugs.python.org/issue21724 -def clear_warning_registry(): - """Clear the __warningregistry__ for all modules.""" - for _, module in list(sys.modules.items()): - if hasattr(module, "__warningregistry__"): - module.__warningregistry__ = {} # type:ignore[attr-defined] +def print_running_clients(): + from pymongo.synchronous.topology import Topology - -class SystemCertsPatcher: - def __init__(self, ca_certs): - if ( - ssl.OPENSSL_VERSION.lower().startswith("libressl") - and sys.platform == "darwin" - and not _ssl.IS_PYOPENSSL - ): - raise SkipTest( - "LibreSSL on OSX doesn't support setting CA certificates " - "using SSL_CERT_FILE environment variable." - ) - self.original_certs = os.environ.get("SSL_CERT_FILE") - # Tell OpenSSL where CA certificates live. - os.environ["SSL_CERT_FILE"] = ca_certs - - def disable(self): - if self.original_certs is None: - os.environ.pop("SSL_CERT_FILE") - else: - os.environ["SSL_CERT_FILE"] = self.original_certs + processed = set() + # Avoid false positives on the main test client. + # XXX: Can be removed after PYTHON-1634 or PYTHON-1896. + c = client_context.client + if c: + processed.add(c._topology._topology_id) + # Call collect to manually cleanup any would-be gc'd clients to avoid + # false positives. + gc.collect() + for obj in gc.get_objects(): + try: + if isinstance(obj, Topology): + # Avoid printing the same Topology multiple times. + if obj._topology_id in processed: + continue + print_running_topology(obj) + processed.add(obj._topology_id) + except ReferenceError: + pass diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 0a74366ae..8e68f6dae 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -30,7 +30,7 @@ import traceback import unittest import warnings from asyncio import iscoroutinefunction -from test import ( +from test.helpers import ( COMPRESSORS, IS_SRV, MONGODB_API_VERSION, @@ -41,13 +41,14 @@ from test import ( SystemCertsPatcher, _all_users, _create_user, + client_knobs, db_pwd, db_user, global_knobs, host, is_server_resolvable, port, - print_running_clients, + print_running_topology, print_thread_stacks, print_thread_tracebacks, sanitize_cmd, @@ -113,6 +114,7 @@ class AsyncClientContext: self.is_data_lake = False self.load_balancer = TEST_LOADBALANCER self.serverless = TEST_SERVERLESS + self._fips_enabled = None if self.load_balancer or self.serverless: self.default_client_options["loadBalanced"] = True if COMPRESSORS: @@ -189,8 +191,7 @@ class AsyncClientContext: if self.client is not None: # Return early when connected to dataLake as mongohoused does not # support the getCmdLineOpts command and is tested without TLS. - build_info: Any = await self.client.admin.command("buildInfo") - if "dataLake" in build_info: + if os.environ.get("TEST_DATA_LAKE"): self.is_data_lake = True self.auth_enabled = True self.client = await self._connect(host, port, username=db_user, password=db_pwd) @@ -363,6 +364,17 @@ class AsyncClientContext: # Raised if self.server_status is None. return None + @property + def fips_enabled(self): + if self._fips_enabled is not None: + return self._fips_enabled + try: + subprocess.check_call(["fips-mode-setup", "--is-enabled"]) + self._fips_enabled = True + except (subprocess.SubprocessError, FileNotFoundError): + self._fips_enabled = False + return self._fips_enabled + def check_auth_type(self, auth_type): auth_mechs = self.server_parameters.get("authenticationMechanisms", []) return auth_type in auth_mechs @@ -528,6 +540,12 @@ class AsyncClientContext: lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func ) + def require_no_fips(self, func): + """Run a test only if the host does not have FIPS enabled.""" + return self._require( + lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func + ) + def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" return self._require( @@ -937,6 +955,35 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): self.addCleanup(patcher.disable) +class AsyncMockClientTest(unittest.TestCase): + """Base class for TestCases that use MockClient. + + This class is *not* an IntegrationTest: if properly written, MockClient + tests do not require a running server. + + The class temporarily overrides HEARTBEAT_FREQUENCY to speed up tests. + """ + + # MockClients tests that use replicaSet, directConnection=True, pass + # multiple seed addresses, or wait for heartbeat events are incompatible + # with loadBalanced=True. + @classmethod + @async_client_context.require_no_load_balancer + def setUpClass(cls): + pass + + def setUp(self): + super().setUp() + + self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) + + self.client_knobs.enable() + + def tearDown(self): + self.client_knobs.disable() + super().tearDown() + + async def async_setup(): await async_client_context.init() warnings.resetwarnings() @@ -976,3 +1023,27 @@ def test_cases(suite): else: # unittest.TestSuite yield from test_cases(suite_or_case) + + +def print_running_clients(): + from pymongo.asynchronous.topology import Topology + + processed = set() + # Avoid false positives on the main test client. + # XXX: Can be removed after PYTHON-1634 or PYTHON-1896. + c = async_client_context.client + if c: + processed.add(c._topology._topology_id) + # Call collect to manually cleanup any would-be gc'd clients to avoid + # false positives. + gc.collect() + for obj in gc.get_objects(): + try: + if isinstance(obj, Topology): + # Avoid printing the same Topology multiple times. + if obj._topology_id in processed: + continue + print_running_topology(obj) + processed.add(obj._topology_id) + except ReferenceError: + pass diff --git a/test/conftest.py b/test/conftest.py index b65c64186..58f04ea7c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,6 +4,8 @@ from test import setup, teardown import pytest +_IS_SYNC = True + @pytest.fixture(scope="session", autouse=True) def test_setup_and_teardown(): diff --git a/test/helpers.py b/test/helpers.py new file mode 100644 index 000000000..af644b9cf --- /dev/null +++ b/test/helpers.py @@ -0,0 +1,367 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared constants and helper method for pymongo, bson, and gridfs test suites.""" +from __future__ import annotations + +import base64 +import gc +import multiprocessing +import os +import signal +import socket +import subprocess +import sys +import threading +import time +import traceback +import unittest +import warnings + +try: + import ipaddress + + HAVE_IPADDRESS = True +except ImportError: + HAVE_IPADDRESS = False +from contextlib import contextmanager +from functools import wraps +from test.version import Version +from typing import Any, Callable, Dict, Generator, no_type_check +from unittest import SkipTest +from urllib.parse import quote_plus + +import pymongo +import pymongo.errors +from bson.son import SON +from pymongo import common, message +from pymongo.common import partition_node +from pymongo.hello import HelloCompat +from pymongo.server_api import ServerApi +from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +from pymongo.synchronous.database import Database +from pymongo.synchronous.mongo_client import MongoClient +from pymongo.uri_parser import parse_uri + +if HAVE_SSL: + import ssl + +# Enable debug output for uncollectable objects. PyPy does not have set_debug. +if hasattr(gc, "set_debug"): + gc.set_debug( + gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0) + ) + +# The host and port of a single mongod or mongos, or the seed host +# for a replica set. +host = os.environ.get("DB_IP", "localhost") +port = int(os.environ.get("DB_PORT", 27017)) +IS_SRV = "mongodb+srv" in host + +db_user = os.environ.get("DB_USER", "user") +db_pwd = os.environ.get("DB_PASSWORD", "password") + +CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") +CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem")) +CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem")) + +TLS_OPTIONS: Dict = {"tls": True} +if CLIENT_PEM: + TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM +if CA_PEM: + TLS_OPTIONS["tlsCAFile"] = CA_PEM + +COMPRESSORS = os.environ.get("COMPRESSORS") +MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") +TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) +TEST_SERVERLESS = bool(os.environ.get("TEST_SERVERLESS")) +SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") +MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") + +if TEST_LOADBALANCER: + res = parse_uri(SINGLE_MONGOS_LB_URI or "") + host, port = res["nodelist"][0] + db_user = res["username"] or db_user + db_pwd = res["password"] or db_pwd +elif TEST_SERVERLESS: + TEST_LOADBALANCER = True + res = parse_uri(SINGLE_MONGOS_LB_URI or "") + host, port = res["nodelist"][0] + db_user = res["username"] or db_user + db_pwd = res["password"] or db_pwd + TLS_OPTIONS = {"tls": True} + # Spec says serverless tests must be run with compression. + COMPRESSORS = COMPRESSORS or "zlib" + + +# Shared KMS data. +LOCAL_MASTER_KEY = base64.b64decode( + b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ" + b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" +) +AWS_CREDS = { + "accessKeyId": os.environ.get("FLE_AWS_KEY", ""), + "secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""), +} +AWS_CREDS_2 = { + "accessKeyId": os.environ.get("FLE_AWS_KEY2", ""), + "secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""), +} +AZURE_CREDS = { + "tenantId": os.environ.get("FLE_AZURE_TENANTID", ""), + "clientId": os.environ.get("FLE_AZURE_CLIENTID", ""), + "clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""), +} +GCP_CREDS = { + "email": os.environ.get("FLE_GCP_EMAIL", ""), + "privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""), +} +KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")} + +# Ensure Evergreen metadata doesn't result in truncation +os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000") + + +def is_server_resolvable(): + """Returns True if 'server' is resolvable.""" + socket_timeout = socket.getdefaulttimeout() + socket.setdefaulttimeout(1) + try: + try: + socket.gethostbyname("server") + return True + except OSError: + return False + finally: + socket.setdefaulttimeout(socket_timeout) + + +def _create_user(authdb, user, pwd=None, roles=None, **kwargs): + cmd = SON([("createUser", user)]) + # X509 doesn't use a password + if pwd: + cmd["pwd"] = pwd + cmd["roles"] = roles or ["root"] + cmd.update(**kwargs) + return authdb.command(cmd) + + +class client_knobs: + def __init__( + self, + heartbeat_frequency=None, + min_heartbeat_interval=None, + kill_cursor_frequency=None, + events_queue_frequency=None, + ): + self.heartbeat_frequency = heartbeat_frequency + self.min_heartbeat_interval = min_heartbeat_interval + self.kill_cursor_frequency = kill_cursor_frequency + self.events_queue_frequency = events_queue_frequency + + self.old_heartbeat_frequency = None + self.old_min_heartbeat_interval = None + self.old_kill_cursor_frequency = None + self.old_events_queue_frequency = None + self._enabled = False + self._stack = None + + def enable(self): + self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY + self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL + self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY + self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY + + if self.heartbeat_frequency is not None: + common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency + + if self.min_heartbeat_interval is not None: + common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval + + if self.kill_cursor_frequency is not None: + common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency + + if self.events_queue_frequency is not None: + common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency + self._enabled = True + # Store the allocation traceback to catch non-disabled client_knobs. + self._stack = "".join(traceback.format_stack()) + + def __enter__(self): + self.enable() + + @no_type_check + def disable(self): + common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency + common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval + common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency + common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency + self._enabled = False + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + def __call__(self, func): + def make_wrapper(f): + @wraps(f) + def wrap(*args, **kwargs): + with self: + return f(*args, **kwargs) + + return wrap + + return make_wrapper(func) + + def __del__(self): + if self._enabled: + msg = ( + "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, " + "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, " + "EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format( + common.HEARTBEAT_FREQUENCY, + common.MIN_HEARTBEAT_INTERVAL, + common.KILL_CURSOR_FREQUENCY, + common.EVENTS_QUEUE_FREQUENCY, + self._stack, + ) + ) + self.disable() + raise Exception(msg) + + +def _all_users(db): + return {u["user"] for u in db.command("usersInfo").get("users", [])} + + +def sanitize_cmd(cmd): + cp = cmd.copy() + cp.pop("$clusterTime", None) + cp.pop("$db", None) + cp.pop("$readPreference", None) + cp.pop("lsid", None) + if MONGODB_API_VERSION: + # Stable API parameters + cp.pop("apiVersion", None) + # OP_MSG encoding may move the payload type one field to the + # end of the command. Do the same here. + name = next(iter(cp)) + try: + identifier = message._FIELD_MAP[name] + docs = cp.pop(identifier) + cp[identifier] = docs + except KeyError: + pass + return cp + + +def sanitize_reply(reply): + cp = reply.copy() + cp.pop("$clusterTime", None) + cp.pop("operationTime", None) + return cp + + +def print_thread_tracebacks() -> None: + """Print all Python thread tracebacks.""" + for thread_id, frame in sys._current_frames().items(): + sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n") + traceback.print_stack(frame, file=sys.stderr) + + +def print_thread_stacks(pid: int) -> None: + """Print all C-level thread stacks for a given process id.""" + if sys.platform == "darwin": + cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"'] + else: + cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"'] + + try: + res = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8" + ) + except Exception as exc: + sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}") + else: + sys.stderr.write(res.stdout) + + +# Global knobs to speed up the test suite. +global_knobs = client_knobs(events_queue_frequency=0.05) + + +def _get_executors(topology): + executors = [] + for server in topology._servers.values(): + # Some MockMonitor do not have an _executor. + if hasattr(server._monitor, "_executor"): + executors.append(server._monitor._executor) + if hasattr(server._monitor, "_rtt_monitor"): + executors.append(server._monitor._rtt_monitor._executor) + executors.append(topology._Topology__events_executor) + if topology._srv_monitor: + executors.append(topology._srv_monitor._executor) + + return [e for e in executors if e is not None] + + +def print_running_topology(topology): + running = [e for e in _get_executors(topology) if not e._stopped] + if running: + print( + "WARNING: found Topology with running threads:\n" + f" Threads: {running}\n" + f" Topology: {topology}\n" + f" Creation traceback:\n{topology._settings._stack}" + ) + + +def test_cases(suite): + """Iterator over all TestCases within a TestSuite.""" + for suite_or_case in suite._tests: + if isinstance(suite_or_case, unittest.TestCase): + # unittest.TestCase + yield suite_or_case + else: + # unittest.TestSuite + yield from test_cases(suite_or_case) + + +# Helper method to workaround https://bugs.python.org/issue21724 +def clear_warning_registry(): + """Clear the __warningregistry__ for all modules.""" + for _, module in list(sys.modules.items()): + if hasattr(module, "__warningregistry__"): + module.__warningregistry__ = {} # type:ignore[attr-defined] + + +class SystemCertsPatcher: + def __init__(self, ca_certs): + if ( + ssl.OPENSSL_VERSION.lower().startswith("libressl") + and sys.platform == "darwin" + and not _ssl.IS_PYOPENSSL + ): + raise SkipTest( + "LibreSSL on OSX doesn't support setting CA certificates " + "using SSL_CERT_FILE environment variable." + ) + self.original_certs = os.environ.get("SSL_CERT_FILE") + # Tell OpenSSL where CA certificates live. + os.environ["SSL_CERT_FILE"] = ca_certs + + def disable(self): + if self.original_certs is None: + os.environ.pop("SSL_CERT_FILE") + else: + os.environ["SSL_CERT_FILE"] = self.original_certs diff --git a/test/synchronous/__init__.py b/test/synchronous/__init__.py deleted file mode 100644 index 9176b22d1..000000000 --- a/test/synchronous/__init__.py +++ /dev/null @@ -1,976 +0,0 @@ -# Copyright 2010-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Asynchronous test suite for pymongo, bson, and gridfs.""" -from __future__ import annotations - -import asyncio -import base64 -import gc -import multiprocessing -import os -import signal -import socket -import subprocess -import sys -import threading -import time -import traceback -import unittest -import warnings -from asyncio import iscoroutinefunction -from test import ( - COMPRESSORS, - IS_SRV, - MONGODB_API_VERSION, - MULTI_MONGOS_LB_URI, - TEST_LOADBALANCER, - TEST_SERVERLESS, - TLS_OPTIONS, - SystemCertsPatcher, - _all_users, - _create_user, - db_pwd, - db_user, - global_knobs, - host, - is_server_resolvable, - port, - print_running_clients, - print_thread_stacks, - print_thread_tracebacks, - sanitize_cmd, - sanitize_reply, -) - -try: - import ipaddress - - HAVE_IPADDRESS = True -except ImportError: - HAVE_IPADDRESS = False -from contextlib import contextmanager -from functools import wraps -from test.version import Version -from typing import Any, Callable, Dict, Generator -from unittest import SkipTest -from urllib.parse import quote_plus - -import pymongo -import pymongo.errors -from bson.son import SON -from pymongo.common import partition_node -from pymongo.hello import HelloCompat -from pymongo.server_api import ServerApi -from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous.database import Database -from pymongo.synchronous.mongo_client import MongoClient - -_IS_SYNC = True - - -class ClientContext: - client: MongoClient - - MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI - - def __init__(self): - """Create a client and grab essential information from the server.""" - self.connection_attempts = [] - self.connected = False - self.w = None - self.nodes = set() - self.replica_set_name = None - self.cmd_line = None - self.server_status = None - self.version = Version(-1) # Needs to be comparable with Version - self.auth_enabled = False - self.test_commands_enabled = False - self.server_parameters = {} - self._hello = None - self.is_mongos = False - self.mongoses = [] - self.is_rs = False - self.has_ipv6 = False - self.tls = False - self.tlsCertificateKeyFile = False - self.server_is_resolvable = is_server_resolvable() - self.default_client_options: Dict = {} - self.sessions_enabled = False - self.client = None # type: ignore - self.conn_lock = threading.Lock() - self.is_data_lake = False - self.load_balancer = TEST_LOADBALANCER - self.serverless = TEST_SERVERLESS - if self.load_balancer or self.serverless: - self.default_client_options["loadBalanced"] = True - if COMPRESSORS: - self.default_client_options["compressors"] = COMPRESSORS - if MONGODB_API_VERSION: - server_api = ServerApi(MONGODB_API_VERSION) - self.default_client_options["server_api"] = server_api - - @property - def client_options(self): - """Return the MongoClient options for creating a duplicate client.""" - opts = client_context.default_client_options.copy() - opts["host"] = host - opts["port"] = port - if client_context.auth_enabled: - opts["username"] = db_user - opts["password"] = db_pwd - if self.replica_set_name: - opts["replicaSet"] = self.replica_set_name - return opts - - @property - def uri(self): - """Return the MongoClient URI for creating a duplicate client.""" - opts = client_context.default_client_options.copy() - opts.pop("server_api", None) # Cannot be set from the URI - opts_parts = [] - for opt, val in opts.items(): - strval = str(val) - if isinstance(val, bool): - strval = strval.lower() - opts_parts.append(f"{opt}={quote_plus(strval)}") - opts_part = "&".join(opts_parts) - auth_part = "" - if client_context.auth_enabled: - auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@" - pair = self.pair - return f"mongodb://{auth_part}{pair}/?{opts_part}" - - @property - def hello(self): - if not self._hello: - if self.serverless or self.load_balancer: - self._hello = self.client.admin.command(HelloCompat.CMD) - else: - self._hello = self.client.admin.command(HelloCompat.LEGACY_CMD) - return self._hello - - def _connect(self, host, port, **kwargs): - kwargs.update(self.default_client_options) - client: MongoClient = pymongo.MongoClient( - host, port, serverSelectionTimeoutMS=5000, **kwargs - ) - try: - try: - client.admin.command("ping") # Can we connect? - except pymongo.errors.OperationFailure as exc: - # SERVER-32063 - self.connection_attempts.append( - f"connected client {client!r}, but legacy hello failed: {exc}" - ) - else: - self.connection_attempts.append(f"successfully connected client {client!r}") - # If connected, then return client with default timeout - return pymongo.MongoClient(host, port, **kwargs) - except pymongo.errors.ConnectionFailure as exc: - self.connection_attempts.append(f"failed to connect client {client!r}: {exc}") - return None - finally: - client.close() - - def _init_client(self): - self.client = self._connect(host, port) - if self.client is not None: - # Return early when connected to dataLake as mongohoused does not - # support the getCmdLineOpts command and is tested without TLS. - build_info: Any = self.client.admin.command("buildInfo") - if "dataLake" in build_info: - self.is_data_lake = True - self.auth_enabled = True - self.client = self._connect(host, port, username=db_user, password=db_pwd) - self.connected = True - return - - if HAVE_SSL and not self.client: - # Is MongoDB configured for SSL? - self.client = self._connect(host, port, **TLS_OPTIONS) - if self.client: - self.tls = True - self.default_client_options.update(TLS_OPTIONS) - self.tlsCertificateKeyFile = True - - if self.client: - self.connected = True - - if self.serverless: - self.auth_enabled = True - else: - try: - self.cmd_line = self.client.admin.command("getCmdLineOpts") - except pymongo.errors.OperationFailure as e: - assert e.details is not None - msg = e.details.get("errmsg", "") - if e.code == 13 or "unauthorized" in msg or "login" in msg: - # Unauthorized. - self.auth_enabled = True - else: - raise - else: - self.auth_enabled = self._server_started_with_auth() - - if self.auth_enabled: - if not self.serverless and not IS_SRV: - # See if db_user already exists. - if not self._check_user_provided(): - _create_user(self.client.admin, db_user, db_pwd) - - self.client = self._connect( - host, - port, - username=db_user, - password=db_pwd, - replicaSet=self.replica_set_name, - **self.default_client_options, - ) - - # May not have this if OperationFailure was raised earlier. - self.cmd_line = self.client.admin.command("getCmdLineOpts") - - if self.serverless: - self.server_status = {} - else: - self.server_status = self.client.admin.command("serverStatus") - if self.storage_engine == "mmapv1": - # MMAPv1 does not support retryWrites=True. - self.default_client_options["retryWrites"] = False - - hello = self.hello - self.sessions_enabled = "logicalSessionTimeoutMinutes" in hello - - if "setName" in hello: - self.replica_set_name = str(hello["setName"]) - self.is_rs = True - if self.auth_enabled: - # It doesn't matter which member we use as the seed here. - self.client = pymongo.MongoClient( - host, - port, - username=db_user, - password=db_pwd, - replicaSet=self.replica_set_name, - **self.default_client_options, - ) - else: - self.client = pymongo.MongoClient( - host, port, replicaSet=self.replica_set_name, **self.default_client_options - ) - - # Get the authoritative hello result from the primary. - self._hello = None - hello = self.hello - nodes = [partition_node(node.lower()) for node in hello.get("hosts", [])] - nodes.extend([partition_node(node.lower()) for node in hello.get("passives", [])]) - nodes.extend([partition_node(node.lower()) for node in hello.get("arbiters", [])]) - self.nodes = set(nodes) - else: - self.nodes = {(host, port)} - self.w = len(hello.get("hosts", [])) or 1 - self.version = Version.from_client(self.client) - - if self.serverless: - self.server_parameters = { - "requireApiVersion": False, - "enableTestCommands": True, - } - self.test_commands_enabled = True - self.has_ipv6 = False - else: - self.server_parameters = self.client.admin.command("getParameter", "*") - assert self.cmd_line is not None - if self.server_parameters["enableTestCommands"]: - self.test_commands_enabled = True - elif "parsed" in self.cmd_line: - params = self.cmd_line["parsed"].get("setParameter", []) - if "enableTestCommands=1" in params: - self.test_commands_enabled = True - else: - params = self.cmd_line["parsed"].get("setParameter", {}) - if params.get("enableTestCommands") == "1": - self.test_commands_enabled = True - self.has_ipv6 = self._server_started_with_ipv6() - - self.is_mongos = (self.hello).get("msg") == "isdbgrid" - if self.is_mongos: - address = self.client.address - self.mongoses.append(address) - if not self.serverless: - # Check for another mongos on the next port. - assert address is not None - next_address = address[0], address[1] + 1 - mongos_client = self._connect(*next_address, **self.default_client_options) - if mongos_client: - hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD) - if hello.get("msg") == "isdbgrid": - self.mongoses.append(next_address) - - def init(self): - with self.conn_lock: - if not self.client and not self.connection_attempts: - self._init_client() - - def connection_attempt_info(self): - return "\n".join(self.connection_attempts) - - @property - def host(self): - if self.is_rs and not IS_SRV: - primary = self.client.primary - return str(primary[0]) if primary is not None else host - return host - - @property - def port(self): - if self.is_rs and not IS_SRV: - primary = self.client.primary - return primary[1] if primary is not None else port - return port - - @property - def pair(self): - return "%s:%d" % (self.host, self.port) - - @property - def has_secondaries(self): - if not self.client: - return False - return bool(len(self.client.secondaries)) - - @property - def storage_engine(self): - try: - return self.server_status.get("storageEngine", {}).get( # type:ignore[union-attr] - "name" - ) - except AttributeError: - # Raised if self.server_status is None. - return None - - def check_auth_type(self, auth_type): - auth_mechs = self.server_parameters.get("authenticationMechanisms", []) - return auth_type in auth_mechs - - def _check_user_provided(self): - """Return True if db_user/db_password is already an admin user.""" - client: MongoClient = pymongo.MongoClient( - host, - port, - username=db_user, - password=db_pwd, - **self.default_client_options, - ) - - try: - return db_user in _all_users(client.admin) - except pymongo.errors.OperationFailure as e: - assert e.details is not None - msg = e.details.get("errmsg", "") - if e.code == 18 or "auth fails" in msg: - # Auth failed. - return False - else: - raise - finally: - client.close() - - def _server_started_with_auth(self): - # MongoDB >= 2.0 - assert self.cmd_line is not None - if "parsed" in self.cmd_line: - parsed = self.cmd_line["parsed"] - # MongoDB >= 2.6 - if "security" in parsed: - security = parsed["security"] - # >= rc3 - if "authorization" in security: - return security["authorization"] == "enabled" - # < rc3 - return security.get("auth", False) or bool(security.get("keyFile")) - return parsed.get("auth", False) or bool(parsed.get("keyFile")) - # Legacy - argv = self.cmd_line["argv"] - return "--auth" in argv or "--keyFile" in argv - - def _server_started_with_ipv6(self): - if not socket.has_ipv6: - return False - - assert self.cmd_line is not None - if "parsed" in self.cmd_line: - if not self.cmd_line["parsed"].get("net", {}).get("ipv6"): - return False - else: - if "--ipv6" not in self.cmd_line["argv"]: - return False - - # The server was started with --ipv6. Is there an IPv6 route to it? - try: - for info in socket.getaddrinfo(self.host, self.port): - if info[0] == socket.AF_INET6: - return True - except OSError: - pass - - return False - - def _require(self, condition, msg, func=None): - def make_wrapper(f): - if iscoroutinefunction(f): - wraps_async = True - else: - wraps_async = False - - @wraps(f) - def wrap(*args, **kwargs): - self.init() - # Always raise SkipTest if we can't connect to MongoDB - if not self.connected: - pair = self.pair - raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and condition(): - if wraps_async: - return f(*args, **kwargs) - else: - return f(*args, **kwargs) - elif condition(): - if wraps_async: - return f(*args, **kwargs) - else: - return f(*args, **kwargs) - if "self.pair" in msg: - new_msg = msg.replace("self.pair", self.pair) - else: - new_msg = msg - raise SkipTest(new_msg) - - return wrap - - if func is None: - - def decorate(f): - return make_wrapper(f) - - return decorate - return make_wrapper(func) - - def create_user(self, dbname, user, pwd=None, roles=None, **kwargs): - kwargs["writeConcern"] = {"w": self.w} - return _create_user(self.client[dbname], user, pwd, roles, **kwargs) - - def drop_user(self, dbname, user): - self.client[dbname].command("dropUser", user, writeConcern={"w": self.w}) - - def require_connection(self, func): - """Run a test only if we can connect to MongoDB.""" - return self._require( - lambda: True, # _require checks if we're connected - "Cannot connect to MongoDB on self.pair", - func=func, - ) - - def require_data_lake(self, func): - """Run a test only if we are connected to Atlas Data Lake.""" - return self._require( - lambda: self.is_data_lake, - "Not connected to Atlas Data Lake on self.pair", - func=func, - ) - - def require_no_mmap(self, func): - """Run a test only if the server is not using the MMAPv1 storage - engine. Only works for standalone and replica sets; tests are - run regardless of storage engine on sharded clusters. - """ - - def is_not_mmap(): - if self.is_mongos: - return True - return self.storage_engine != "mmapv1" - - return self._require(is_not_mmap, "Storage engine must not be MMAPv1", func=func) - - def require_version_min(self, *ver): - """Run a test only if the server version is at least ``version``.""" - other_version = Version(*ver) - return self._require( - lambda: self.version >= other_version, - "Server version must be at least %s" % str(other_version), - ) - - def require_version_max(self, *ver): - """Run a test only if the server version is at most ``version``.""" - other_version = Version(*ver) - return self._require( - lambda: self.version <= other_version, - "Server version must be at most %s" % str(other_version), - ) - - def require_auth(self, func): - """Run a test only if the server is running with auth enabled.""" - return self._require( - lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func - ) - - def require_no_auth(self, func): - """Run a test only if the server is running without auth enabled.""" - return self._require( - lambda: not self.auth_enabled, - "Authentication must not be enabled on the server", - func=func, - ) - - def require_replica_set(self, func): - """Run a test only if the client is connected to a replica set.""" - return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func) - - def require_secondaries_count(self, count): - """Run a test only if the client is connected to a replica set that has - `count` secondaries. - """ - - def sec_count(): - return 0 if not self.client else len(self.client.secondaries) - - return self._require(lambda: sec_count() >= count, "Not enough secondaries available") - - @property - def supports_secondary_read_pref(self): - if self.has_secondaries: - return True - if self.is_mongos: - shard = self.client.config.shards.find_one()["host"] # type:ignore[index] - num_members = shard.count(",") + 1 - return num_members > 1 - return False - - def require_secondary_read_pref(self): - """Run a test only if the client is connected to a cluster that - supports secondary read preference - """ - return self._require( - lambda: self.supports_secondary_read_pref, - "This cluster does not support secondary read preference", - ) - - def require_no_replica_set(self, func): - """Run a test if the client is *not* connected to a replica set.""" - return self._require( - lambda: not self.is_rs, "Connected to a replica set, not a standalone mongod", func=func - ) - - def require_ipv6(self, func): - """Run a test only if the client can connect to a server via IPv6.""" - return self._require(lambda: self.has_ipv6, "No IPv6", func=func) - - def require_no_mongos(self, func): - """Run a test only if the client is not connected to a mongos.""" - return self._require( - lambda: not self.is_mongos, "Must be connected to a mongod, not a mongos", func=func - ) - - def require_mongos(self, func): - """Run a test only if the client is connected to a mongos.""" - return self._require(lambda: self.is_mongos, "Must be connected to a mongos", func=func) - - def require_multiple_mongoses(self, func): - """Run a test only if the client is connected to a sharded cluster - that has 2 mongos nodes. - """ - return self._require( - lambda: len(self.mongoses) > 1, "Must have multiple mongoses available", func=func - ) - - def require_standalone(self, func): - """Run a test only if the client is connected to a standalone.""" - return self._require( - lambda: not (self.is_mongos or self.is_rs), - "Must be connected to a standalone", - func=func, - ) - - def require_no_standalone(self, func): - """Run a test only if the client is not connected to a standalone.""" - return self._require( - lambda: self.is_mongos or self.is_rs, - "Must be connected to a replica set or mongos", - func=func, - ) - - def require_load_balancer(self, func): - """Run a test only if the client is connected to a load balancer.""" - return self._require( - lambda: self.load_balancer, "Must be connected to a load balancer", func=func - ) - - def require_no_load_balancer(self, func): - """Run a test only if the client is not connected to a load balancer.""" - return self._require( - lambda: not self.load_balancer, "Must not be connected to a load balancer", func=func - ) - - def require_no_serverless(self, func): - """Run a test only if the client is not connected to serverless.""" - return self._require( - lambda: not self.serverless, "Must not be connected to serverless", func=func - ) - - def require_change_streams(self, func): - """Run a test only if the server supports change streams.""" - return self.require_no_mmap(self.require_no_standalone(self.require_no_serverless(func))) - - def is_topology_type(self, topologies): - unknown = set(topologies) - { - "single", - "replicaset", - "sharded", - "sharded-replicaset", - "load-balanced", - } - if unknown: - raise AssertionError(f"Unknown topologies: {unknown!r}") - if self.load_balancer: - if "load-balanced" in topologies: - return True - return False - if "single" in topologies and not (self.is_mongos or self.is_rs): - return True - if "replicaset" in topologies and self.is_rs: - return True - if "sharded" in topologies and self.is_mongos: - return True - if "sharded-replicaset" in topologies and self.is_mongos: - shards = (client_context.client.config.shards.find()).to_list() - for shard in shards: - # For a 3-member RS-backed sharded cluster, shard['host'] - # will be 'replicaName/ip1:port1,ip2:port2,ip3:port3' - # Otherwise it will be 'ip1:port1' - host_spec = shard["host"] - if not len(host_spec.split("/")) > 1: - return False - return True - return False - - def require_cluster_type(self, topologies=None): - """Run a test only if the client is connected to a cluster that - conforms to one of the specified topologies. Acceptable topologies - are 'single', 'replicaset', and 'sharded'. - """ - topologies = topologies or [] - - def _is_valid_topology(): - return self.is_topology_type(topologies) - - return self._require(_is_valid_topology, "Cluster type not in %s" % (topologies)) - - def require_test_commands(self, func): - """Run a test only if the server has test commands enabled.""" - return self._require( - lambda: self.test_commands_enabled, "Test commands must be enabled", func=func - ) - - def require_failCommand_fail_point(self, func): - """Run a test only if the server supports the failCommand fail - point. - """ - return self._require( - lambda: self.supports_failCommand_fail_point, - "failCommand fail point must be supported", - func=func, - ) - - def require_failCommand_appName(self, func): - """Run a test only if the server supports the failCommand appName.""" - # SERVER-47195 - return self._require( - lambda: (self.test_commands_enabled and self.version >= (4, 4, -1)), - "failCommand appName must be supported", - func=func, - ) - - def require_failCommand_blockConnection(self, func): - """Run a test only if the server supports failCommand blockConnection.""" - return self._require( - lambda: ( - self.test_commands_enabled - and ( - (not self.is_mongos and self.version >= (4, 2, 9)) - or (self.is_mongos and self.version >= (4, 4)) - ) - ), - "failCommand blockConnection is not supported", - func=func, - ) - - def require_tls(self, func): - """Run a test only if the client can connect over TLS.""" - return self._require(lambda: self.tls, "Must be able to connect via TLS", func=func) - - def require_no_tls(self, func): - """Run a test only if the client can connect over TLS.""" - return self._require(lambda: not self.tls, "Must be able to connect without TLS", func=func) - - def require_tlsCertificateKeyFile(self, func): - """Run a test only if the client can connect with tlsCertificateKeyFile.""" - return self._require( - lambda: self.tlsCertificateKeyFile, - "Must be able to connect with tlsCertificateKeyFile", - func=func, - ) - - def require_server_resolvable(self, func): - """Run a test only if the hostname 'server' is resolvable.""" - return self._require( - lambda: self.server_is_resolvable, - "No hosts entry for 'server'. Cannot validate hostname in the certificate", - func=func, - ) - - def require_sessions(self, func): - """Run a test only if the deployment supports sessions.""" - return self._require(lambda: self.sessions_enabled, "Sessions not supported", func=func) - - def supports_retryable_writes(self): - if self.storage_engine == "mmapv1": - return False - if not self.sessions_enabled: - return False - return self.is_mongos or self.is_rs - - def require_retryable_writes(self, func): - """Run a test only if the deployment supports retryable writes.""" - return self._require( - self.supports_retryable_writes, - "This server does not support retryable writes", - func=func, - ) - - def supports_transactions(self): - if self.storage_engine == "mmapv1": - return False - - if self.version.at_least(4, 1, 8): - return self.is_mongos or self.is_rs - - if self.version.at_least(4, 0): - return self.is_rs - - return False - - def require_transactions(self, func): - """Run a test only if the deployment might support transactions. - - *Might* because this does not test the storage engine or FCV. - """ - return self._require( - self.supports_transactions, "Transactions are not supported", func=func - ) - - def require_no_api_version(self, func): - """Skip this test when testing with requireApiVersion.""" - return self._require( - lambda: not MONGODB_API_VERSION, - "This test does not work with requireApiVersion", - func=func, - ) - - def mongos_seeds(self): - return ",".join("{}:{}".format(*address) for address in self.mongoses) - - @property - def supports_failCommand_fail_point(self): - """Does the server support the failCommand fail point?""" - if self.is_mongos: - return self.version.at_least(4, 1, 5) and self.test_commands_enabled - else: - return self.version.at_least(4, 0) and self.test_commands_enabled - - @property - def requires_hint_with_min_max_queries(self): - """Does the server require a hint with min/max queries.""" - # Changed in SERVER-39567. - return self.version.at_least(4, 1, 10) - - @property - def max_bson_size(self): - return (self.hello)["maxBsonObjectSize"] - - @property - def max_write_batch_size(self): - return (self.hello)["maxWriteBatchSize"] - - -# Reusable client context -client_context = ClientContext() - - -class PyMongoTestCase(unittest.TestCase): - def assertEqualCommand(self, expected, actual, msg=None): - self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) - - def assertEqualReply(self, expected, actual, msg=None): - self.assertEqual(sanitize_reply(expected), sanitize_reply(actual), msg) - - @contextmanager - def fail_point(self, command_args): - cmd_on = SON([("configureFailPoint", "failCommand")]) - cmd_on.update(command_args) - client_context.client.admin.command(cmd_on) - try: - yield - finally: - client_context.client.admin.command( - "configureFailPoint", cmd_on["configureFailPoint"], mode="off" - ) - - @contextmanager - def fork( - self, target: Callable, timeout: float = 60 - ) -> Generator[multiprocessing.Process, None, None]: - """Helper for tests that use os.fork() - - Use in a with statement: - - with self.fork(target=lambda: print('in child')) as proc: - self.assertTrue(proc.pid) # Child process was started - """ - - def _print_threads(*args: object) -> None: - if _print_threads.called: # type:ignore[attr-defined] - return - _print_threads.called = True # type:ignore[attr-defined] - print_thread_tracebacks() - - _print_threads.called = False # type:ignore[attr-defined] - - def _target() -> None: - signal.signal(signal.SIGUSR1, _print_threads) - try: - target() - except Exception as exc: - sys.stderr.write(f"Child process failed with: {exc}\n") - _print_threads() - # Sleep for a while to let the parent attach via GDB. - time.sleep(2 * timeout) - raise - - ctx = multiprocessing.get_context("fork") - proc = ctx.Process(target=_target) - proc.start() - try: - yield proc # type: ignore - finally: - proc.join(timeout) - pid = proc.pid - assert pid - if proc.exitcode is None: - # gdb to get C-level tracebacks - print_thread_stacks(pid) - # If it failed, SIGUSR1 to get thread tracebacks. - os.kill(pid, signal.SIGUSR1) - proc.join(5) - if proc.exitcode is None: - # SIGINT to get main thread traceback in case SIGUSR1 didn't work. - os.kill(pid, signal.SIGINT) - proc.join(5) - if proc.exitcode is None: - # SIGKILL in case SIGINT didn't work. - proc.kill() - proc.join(1) - self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?") - self.assertEqual(proc.exitcode, 0) - - -class IntegrationTest(PyMongoTestCase): - """Async base class for TestCases that need a connection to MongoDB to pass.""" - - client: MongoClient[dict] - db: Database - credentials: Dict[str, str] - - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - @client_context.require_connection - def _setup_class(cls): - if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): - raise SkipTest("this test does not support load balancers") - if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): - raise SkipTest("this test does not support serverless") - cls.client = client_context.client - cls.db = cls.client.pymongo_test - if client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} - else: - cls.credentials = {} - - def cleanup_colls(self, *collections): - """Cleanup collections faster than drop_collection.""" - for c in collections: - c = self.client[c.database.name][c.name] - c.delete_many({}) - c.drop_indexes() - - def patch_system_certs(self, ca_certs): - patcher = SystemCertsPatcher(ca_certs) - self.addCleanup(patcher.disable) - - -def setup(): - client_context.init() - warnings.resetwarnings() - warnings.simplefilter("always") - global_knobs.enable() - - -def teardown(): - global_knobs.disable() - garbage = [] - for g in gc.garbage: - garbage.append(f"GARBAGE: {g!r}") - garbage.append(f" gc.get_referents: {gc.get_referents(g)!r}") - garbage.append(f" gc.get_referrers: {gc.get_referrers(g)!r}") - if garbage: - raise AssertionError("\n".join(garbage)) - c = client_context.client - if c: - if not client_context.is_data_lake: - c.drop_database("pymongo-pooling-tests") - c.drop_database("pymongo_test") - c.drop_database("pymongo_test1") - c.drop_database("pymongo_test2") - c.drop_database("pymongo_test_mike") - c.drop_database("pymongo_test_bernie") - c.close() - - print_running_clients() - - -def test_cases(suite): - """Iterator over all TestCases within a TestSuite.""" - for suite_or_case in suite._tests: - if isinstance(suite_or_case, unittest.TestCase): - # unittest.TestCase - yield suite_or_case - else: - # unittest.TestSuite - yield from test_cases(suite_or_case) diff --git a/test/synchronous/conftest.py b/test/synchronous/conftest.py deleted file mode 100644 index 58f04ea7c..000000000 --- a/test/synchronous/conftest.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -from test import setup, teardown - -import pytest - -_IS_SYNC = True - - -@pytest.fixture(scope="session", autouse=True) -def test_setup_and_teardown(): - setup() - yield - teardown() diff --git a/test/synchronous/test_collection.py b/test/synchronous/test_collection.py deleted file mode 100644 index 7d105acb6..000000000 --- a/test/synchronous/test_collection.py +++ /dev/null @@ -1,2236 +0,0 @@ -# Copyright 2009-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test the collection module.""" -from __future__ import annotations - -import asyncio -import contextlib -import re -import sys -from codecs import utf_8_decode -from collections import defaultdict -from typing import Any, Iterable, no_type_check - -from pymongo.synchronous.database import Database - -sys.path[0:0] = [""] - -from test import ( # TODO: fix sync imports in PYTHON-4528 - IntegrationTest, - client_context, - unittest, -) -from test.utils import ( - IMPOSSIBLE_WRITE_CONCERN, - EventListener, - get_pool, - is_mongos, - rs_or_single_client, - single_client, - wait_until, -) - -from bson import encode -from bson.codec_options import CodecOptions -from bson.objectid import ObjectId -from bson.raw_bson import RawBSONDocument -from bson.regex import Regex -from bson.son import SON -from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT -from pymongo.bulk_shared import BulkWriteError -from pymongo.cursor_shared import CursorType -from pymongo.errors import ( - ConfigurationError, - DocumentTooLarge, - DuplicateKeyError, - ExecutionTimeout, - InvalidDocument, - InvalidName, - InvalidOperation, - OperationFailure, - WriteConcernError, -) -from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command -from pymongo.operations import * -from pymongo.read_concern import DEFAULT_READ_CONCERN -from pymongo.read_preferences import ReadPreference -from pymongo.results import ( - DeleteResult, - InsertManyResult, - InsertOneResult, - UpdateResult, -) -from pymongo.synchronous.collection import Collection, ReturnDocument -from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.helpers import next -from pymongo.synchronous.mongo_client import MongoClient -from pymongo.write_concern import WriteConcern - -_IS_SYNC = True - - -class TestCollectionNoConnect(unittest.TestCase): - """Test Collection features on a client that does not connect.""" - - db: Database - - @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test - - def test_collection(self): - self.assertRaises(TypeError, Collection, self.db, 5) - - def make_col(base, name): - return base[name] - - self.assertRaises(InvalidName, make_col, self.db, "") - self.assertRaises(InvalidName, make_col, self.db, "te$t") - self.assertRaises(InvalidName, make_col, self.db, ".test") - self.assertRaises(InvalidName, make_col, self.db, "test.") - self.assertRaises(InvalidName, make_col, self.db, "tes..t") - self.assertRaises(InvalidName, make_col, self.db.test, "") - self.assertRaises(InvalidName, make_col, self.db.test, "te$t") - self.assertRaises(InvalidName, make_col, self.db.test, ".test") - self.assertRaises(InvalidName, make_col, self.db.test, "test.") - self.assertRaises(InvalidName, make_col, self.db.test, "tes..t") - self.assertRaises(InvalidName, make_col, self.db.test, "tes\x00t") - - def test_getattr(self): - coll = self.db.test - self.assertTrue(isinstance(coll["_does_not_exist"], Collection)) - - with self.assertRaises(AttributeError) as context: - coll._does_not_exist - - # Message should be: - # "AttributeError: Collection has no attribute '_does_not_exist'. To - # access the test._does_not_exist collection, use - # database['test._does_not_exist']." - self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) - - coll2 = coll.with_options(write_concern=WriteConcern(w=0)) - self.assertEqual(coll2.write_concern, WriteConcern(w=0)) - self.assertNotEqual(coll.write_concern, coll2.write_concern) - coll3 = coll2.subcoll - self.assertEqual(coll2.write_concern, coll3.write_concern) - coll4 = coll2["subcoll"] - self.assertEqual(coll2.write_concern, coll4.write_concern) - - def test_iteration(self): - coll = self.db.coll - if "PyPy" in sys.version and sys.version_info < (3, 8, 15): - msg = "'NoneType' object is not callable" - else: - if _IS_SYNC: - msg = "'Collection' object is not iterable" - else: - msg = "'Collection' object is not iterable" - # Iteration fails - with self.assertRaisesRegex(TypeError, msg): - for _ in coll: # type: ignore[misc] # error: "None" not callable [misc] - break - # Non-string indices will start failing in PyMongo 5. - self.assertEqual(coll[0].name, "coll.0") - self.assertEqual(coll[{}].name, "coll.{}") - # next fails - with self.assertRaisesRegex(TypeError, msg): - _ = next(coll) - # .next() fails - with self.assertRaisesRegex(TypeError, msg): - _ = coll.next() - # Do not implement typing.Iterable. - self.assertNotIsInstance(coll, Iterable) - - -class TestCollection(IntegrationTest): - w: int - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - def async_tearDownClass(cls): - cls.db.drop_collection("test_large_limit") - - def setUp(self): - self.db.test.drop() - - def tearDown(self): - self.db.test.drop() - - @contextlib.contextmanager - def write_concern_collection(self): - if client_context.is_rs: - with self.assertRaises(WriteConcernError): - # Unsatisfiable write concern. - yield Collection( - self.db, - "test", - write_concern=WriteConcern(w=len(client_context.nodes) + 1), - ) - else: - yield self.db.test - - def test_equality(self): - self.assertTrue(isinstance(self.db.test, Collection)) - self.assertEqual(self.db.test, self.db["test"]) - self.assertEqual(self.db.test, Collection(self.db, "test")) - self.assertEqual(self.db.test.mike, self.db["test.mike"]) - self.assertEqual(self.db.test["mike"], self.db["test.mike"]) - - def test_hashable(self): - self.assertIn(self.db.test.mike, {self.db["test.mike"]}) - - def test_create(self): - # No Exception. - db = client_context.client.pymongo_test - db.create_test_no_wc.drop() - - def lambda_test(): - return "create_test_no_wc" not in db.list_collection_names() - - def lambda_test_2(): - return "create_test_no_wc" in db.list_collection_names() - - wait_until( - lambda_test, - "drop create_test_no_wc collection", - ) - db.create_collection("create_test_no_wc") - wait_until( - lambda_test_2, - "create create_test_no_wc collection", - ) - # SERVER-33317 - if not client_context.is_mongos or not client_context.version.at_least(3, 7, 0): - with self.assertRaises(OperationFailure): - db.create_collection("create-test-wc", write_concern=IMPOSSIBLE_WRITE_CONCERN) - - def test_drop_nonexistent_collection(self): - self.db.drop_collection("test") - self.assertFalse("test" in self.db.list_collection_names()) - - # No exception - self.db.drop_collection("test") - - def test_create_indexes(self): - db = self.db - - with self.assertRaises(TypeError): - db.test.create_indexes("foo") # type: ignore[arg-type] - with self.assertRaises(TypeError): - db.test.create_indexes(["foo"]) # type: ignore[list-item] - self.assertRaises(TypeError, IndexModel, 5) - self.assertRaises(ValueError, IndexModel, []) - - db.test.drop_indexes() - db.test.insert_one({}) - self.assertEqual(len(db.test.index_information()), 1) - - db.test.create_indexes([IndexModel("hello")]) - db.test.create_indexes([IndexModel([("hello", DESCENDING), ("world", ASCENDING)])]) - - # Tuple instead of list. - db.test.create_indexes([IndexModel((("world", ASCENDING),))]) - - self.assertEqual(len(db.test.index_information()), 4) - - db.test.drop_indexes() - names = db.test.create_indexes( - [IndexModel([("hello", DESCENDING), ("world", ASCENDING)], name="hello_world")] - ) - self.assertEqual(names, ["hello_world"]) - - db.test.drop_indexes() - self.assertEqual(len(db.test.index_information()), 1) - db.test.create_indexes([IndexModel("hello")]) - self.assertTrue("hello_1" in db.test.index_information()) - - db.test.drop_indexes() - self.assertEqual(len(db.test.index_information()), 1) - names = db.test.create_indexes( - [IndexModel([("hello", DESCENDING), ("world", ASCENDING)]), IndexModel("hello")] - ) - info = db.test.index_information() - for name in names: - self.assertTrue(name in info) - - db.test.drop() - db.test.insert_one({"a": 1}) - db.test.insert_one({"a": 1}) - with self.assertRaises(DuplicateKeyError): - db.test.create_indexes([IndexModel("a", unique=True)]) - - with self.write_concern_collection() as coll: - coll.create_indexes([IndexModel("hello")]) - - @client_context.require_version_max(4, 3, -1) - def test_create_indexes_commitQuorum_requires_44(self): - db = self.db - with self.assertRaisesRegex( - ConfigurationError, - r"Must be connected to MongoDB 4\.4\+ to use the commitQuorum option for createIndexes", - ): - db.coll.create_indexes([IndexModel("a")], commitQuorum="majority") - - @client_context.require_no_standalone - @client_context.require_version_min(4, 4, -1) - def test_create_indexes_commitQuorum(self): - self.db.coll.create_indexes([IndexModel("a")], commitQuorum="majority") - - def test_create_index(self): - db = self.db - - with self.assertRaises(TypeError): - db.test.create_index(5) # type: ignore[arg-type] - with self.assertRaises(ValueError): - db.test.create_index([]) - - db.test.drop_indexes() - db.test.insert_one({}) - self.assertEqual(len(db.test.index_information()), 1) - - db.test.create_index("hello") - db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)]) - - # Tuple instead of list. - db.test.create_index((("world", ASCENDING),)) - - self.assertEqual(len(db.test.index_information()), 4) - - db.test.drop_indexes() - ix = db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], name="hello_world") - self.assertEqual(ix, "hello_world") - - db.test.drop_indexes() - self.assertEqual(len(db.test.index_information()), 1) - db.test.create_index("hello") - self.assertTrue("hello_1" in db.test.index_information()) - - db.test.drop_indexes() - self.assertEqual(len(db.test.index_information()), 1) - db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)]) - self.assertTrue("hello_-1_world_1" in db.test.index_information()) - - db.test.drop_indexes() - db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], name=None) - self.assertTrue("hello_-1_world_1" in db.test.index_information()) - - db.test.drop() - db.test.insert_one({"a": 1}) - db.test.insert_one({"a": 1}) - with self.assertRaises(DuplicateKeyError): - db.test.create_index("a", unique=True) - - with self.write_concern_collection() as coll: - coll.create_index([("hello", DESCENDING)]) - - db.test.create_index(["hello", "world"]) - db.test.create_index(["hello", ("world", DESCENDING)]) - db.test.create_index({"hello": 1}.items()) # type:ignore[arg-type] - - def test_drop_index(self): - db = self.db - db.test.drop_indexes() - db.test.create_index("hello") - name = db.test.create_index("goodbye") - - self.assertEqual(len(db.test.index_information()), 3) - self.assertEqual(name, "goodbye_1") - db.test.drop_index(name) - - # Drop it again. - with self.assertRaises(OperationFailure): - db.test.drop_index(name) - self.assertEqual(len(db.test.index_information()), 2) - self.assertTrue("hello_1" in db.test.index_information()) - - db.test.drop_indexes() - db.test.create_index("hello") - name = db.test.create_index("goodbye") - - self.assertEqual(len(db.test.index_information()), 3) - self.assertEqual(name, "goodbye_1") - db.test.drop_index([("goodbye", ASCENDING)]) - self.assertEqual(len(db.test.index_information()), 2) - self.assertTrue("hello_1" in db.test.index_information()) - - with self.write_concern_collection() as coll: - coll.drop_index("hello_1") - - @client_context.require_no_mongos - @client_context.require_test_commands - def test_index_management_max_time_ms(self): - coll = self.db.test - self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") - try: - with self.assertRaises(ExecutionTimeout): - coll.create_index("foo", maxTimeMS=1) - with self.assertRaises(ExecutionTimeout): - coll.create_indexes([IndexModel("foo")], maxTimeMS=1) - with self.assertRaises(ExecutionTimeout): - coll.drop_index("foo", maxTimeMS=1) - with self.assertRaises(ExecutionTimeout): - coll.drop_indexes(maxTimeMS=1) - finally: - self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") - - def test_list_indexes(self): - db = self.db - db.test.drop() - db.test.insert_one({}) # create collection - - def map_indexes(indexes): - return {index["name"]: index for index in indexes} - - indexes = (db.test.list_indexes()).to_list() - self.assertEqual(len(indexes), 1) - self.assertTrue("_id_" in map_indexes(indexes)) - - db.test.create_index("hello") - indexes = (db.test.list_indexes()).to_list() - self.assertEqual(len(indexes), 2) - self.assertEqual(map_indexes(indexes)["hello_1"]["key"], SON([("hello", ASCENDING)])) - - db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) - indexes = (db.test.list_indexes()).to_list() - self.assertEqual(len(indexes), 3) - index_map = map_indexes(indexes) - self.assertEqual( - index_map["hello_-1_world_1"]["key"], SON([("hello", DESCENDING), ("world", ASCENDING)]) - ) - self.assertEqual(True, index_map["hello_-1_world_1"]["unique"]) - - # List indexes on a collection that does not exist. - indexes = (db.does_not_exist.list_indexes()).to_list() - self.assertEqual(len(indexes), 0) - - # List indexes on a database that does not exist. - indexes = (db.does_not_exist.list_indexes()).to_list() - self.assertEqual(len(indexes), 0) - - def test_index_info(self): - db = self.db - db.test.drop() - db.test.insert_one({}) # create collection - self.assertEqual(len(db.test.index_information()), 1) - self.assertTrue("_id_" in db.test.index_information()) - - db.test.create_index("hello") - self.assertEqual(len(db.test.index_information()), 2) - self.assertEqual((db.test.index_information())["hello_1"]["key"], [("hello", ASCENDING)]) - - db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) - self.assertEqual((db.test.index_information())["hello_1"]["key"], [("hello", ASCENDING)]) - self.assertEqual(len(db.test.index_information()), 3) - self.assertEqual( - [("hello", DESCENDING), ("world", ASCENDING)], - (db.test.index_information())["hello_-1_world_1"]["key"], - ) - self.assertEqual(True, (db.test.index_information())["hello_-1_world_1"]["unique"]) - - def test_index_geo2d(self): - db = self.db - db.test.drop_indexes() - self.assertEqual("loc_2d", db.test.create_index([("loc", GEO2D)])) - index_info = (db.test.index_information())["loc_2d"] - self.assertEqual([("loc", "2d")], index_info["key"]) - - # geoSearch was deprecated in 4.4 and removed in 5.0 - @client_context.require_version_max(4, 5) - @client_context.require_no_mongos - def test_index_haystack(self): - db = self.db - db.test.drop() - _id = ( - db.test.insert_one({"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"}) - ).inserted_id - db.test.insert_one({"pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant"}) - db.test.insert_one({"pos": {"long": 59.1, "lat": 87.2}, "type": "office"}) - db.test.create_index([("pos", "geoHaystack"), ("type", ASCENDING)], bucketSize=1) - - results = ( - db.command( - SON( - [ - ("geoSearch", "test"), - ("near", [33, 33]), - ("maxDistance", 6), - ("search", {"type": "restaurant"}), - ("limit", 30), - ] - ) - ) - )["results"] - - self.assertEqual(2, len(results)) - self.assertEqual( - {"_id": _id, "pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"}, results[0] - ) - - @client_context.require_no_mongos - def test_index_text(self): - db = self.db - db.test.drop_indexes() - self.assertEqual("t_text", db.test.create_index([("t", TEXT)])) - index_info = (db.test.index_information())["t_text"] - self.assertTrue("weights" in index_info) - - db.test.insert_many( - [{"t": "spam eggs and spam"}, {"t": "spam"}, {"t": "egg sausage and bacon"}] - ) - - # MongoDB 2.6 text search. Create 'score' field in projection. - cursor = db.test.find({"$text": {"$search": "spam"}}, {"score": {"$meta": "textScore"}}) - - # Sort by 'score' field. - cursor.sort([("score", {"$meta": "textScore"})]) - results = cursor.to_list() - self.assertTrue(results[0]["score"] >= results[1]["score"]) - - db.test.drop_indexes() - - def test_index_2dsphere(self): - db = self.db - db.test.drop_indexes() - self.assertEqual("geo_2dsphere", db.test.create_index([("geo", GEOSPHERE)])) - - for dummy, info in (db.test.index_information()).items(): - field, idx_type = info["key"][0] - if field == "geo" and idx_type == "2dsphere": - break - else: - self.fail("2dsphere index not found.") - - poly = {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} - query = {"geo": {"$within": {"$geometry": poly}}} - - # This query will error without a 2dsphere index. - db.test.find(query) - db.test.drop_indexes() - - def test_index_hashed(self): - db = self.db - db.test.drop_indexes() - self.assertEqual("a_hashed", db.test.create_index([("a", HASHED)])) - - for dummy, info in (db.test.index_information()).items(): - field, idx_type = info["key"][0] - if field == "a" and idx_type == "hashed": - break - else: - self.fail("hashed index not found.") - - db.test.drop_indexes() - - def test_index_sparse(self): - db = self.db - db.test.drop_indexes() - db.test.create_index([("key", ASCENDING)], sparse=True) - self.assertTrue((db.test.index_information())["key_1"]["sparse"]) - - def test_index_background(self): - db = self.db - db.test.drop_indexes() - db.test.create_index([("keya", ASCENDING)]) - db.test.create_index([("keyb", ASCENDING)], background=False) - db.test.create_index([("keyc", ASCENDING)], background=True) - self.assertFalse("background" in (db.test.index_information())["keya_1"]) - self.assertFalse((db.test.index_information())["keyb_1"]["background"]) - self.assertTrue((db.test.index_information())["keyc_1"]["background"]) - - def _drop_dups_setup(self, db): - db.drop_collection("test") - db.test.insert_one({"i": 1}) - db.test.insert_one({"i": 2}) - db.test.insert_one({"i": 2}) # duplicate - db.test.insert_one({"i": 3}) - - def test_index_dont_drop_dups(self): - # Try *not* dropping duplicates - db = self.db - self._drop_dups_setup(db) - - # There's a duplicate - def _test_create(): - db.test.create_index([("i", ASCENDING)], unique=True, dropDups=False) - - with self.assertRaises(DuplicateKeyError): - _test_create() - - # Duplicate wasn't dropped - self.assertEqual(4, db.test.count_documents({})) - - # Index wasn't created, only the default index on _id - self.assertEqual(1, len(db.test.index_information())) - - # Get the plan dynamically because the explain format will change. - def get_plan_stage(self, root, stage): - if root.get("stage") == stage: - return root - elif "inputStage" in root: - return self.get_plan_stage(root["inputStage"], stage) - elif "inputStages" in root: - for i in root["inputStages"]: - stage = self.get_plan_stage(i, stage) - if stage: - return stage - elif "queryPlan" in root: - # queryPlan (and slotBasedPlan) are new in 5.0. - return self.get_plan_stage(root["queryPlan"], stage) - elif "shards" in root: - for i in root["shards"]: - stage = self.get_plan_stage(i["winningPlan"], stage) - if stage: - return stage - return {} - - def test_index_filter(self): - db = self.db - db.drop_collection("test") - - # Test bad filter spec on create. - with self.assertRaises(OperationFailure): - db.test.create_index("x", partialFilterExpression=5) - with self.assertRaises(OperationFailure): - db.test.create_index("x", partialFilterExpression={"x": {"$asdasd": 3}}) - with self.assertRaises(OperationFailure): - db.test.create_index("x", partialFilterExpression={"$and": 5}) - - self.assertEqual( - "x_1", - db.test.create_index([("x", ASCENDING)], partialFilterExpression={"a": {"$lte": 1.5}}), - ) - db.test.insert_one({"x": 5, "a": 2}) - db.test.insert_one({"x": 6, "a": 1}) - - # Operations that use the partial index. - explain = (db.test.find({"x": 6, "a": 1})).explain() - stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") - self.assertEqual("x_1", stage.get("indexName")) - self.assertTrue(stage.get("isPartial")) - - explain = (db.test.find({"x": {"$gt": 1}, "a": 1})).explain() - stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") - self.assertEqual("x_1", stage.get("indexName")) - self.assertTrue(stage.get("isPartial")) - - explain = (db.test.find({"x": 6, "a": {"$lte": 1}})).explain() - stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") - self.assertEqual("x_1", stage.get("indexName")) - self.assertTrue(stage.get("isPartial")) - - # Operations that do not use the partial index. - explain = (db.test.find({"x": 6, "a": {"$lte": 1.6}})).explain() - stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") - self.assertNotEqual({}, stage) - explain = (db.test.find({"x": 6})).explain() - stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") - self.assertNotEqual({}, stage) - - # Test drop_indexes. - db.test.drop_index("x_1") - explain = (db.test.find({"x": 6, "a": 1})).explain() - stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") - self.assertNotEqual({}, stage) - - def test_field_selection(self): - db = self.db - db.drop_collection("test") - - doc = {"a": 1, "b": 5, "c": {"d": 5, "e": 10}} - db.test.insert_one(doc) - - # Test field inclusion - doc = next(db.test.find({}, ["_id"])) - self.assertEqual(list(doc), ["_id"]) - doc = next(db.test.find({}, ["a"])) - l = list(doc) - l.sort() - self.assertEqual(l, ["_id", "a"]) - doc = next(db.test.find({}, ["b"])) - l = list(doc) - l.sort() - self.assertEqual(l, ["_id", "b"]) - doc = next(db.test.find({}, ["c"])) - l = list(doc) - l.sort() - self.assertEqual(l, ["_id", "c"]) - doc = next(db.test.find({}, ["a"])) - self.assertEqual(doc["a"], 1) - doc = next(db.test.find({}, ["b"])) - self.assertEqual(doc["b"], 5) - doc = next(db.test.find({}, ["c"])) - self.assertEqual(doc["c"], {"d": 5, "e": 10}) - - # Test inclusion of fields with dots - doc = next(db.test.find({}, ["c.d"])) - self.assertEqual(doc["c"], {"d": 5}) - doc = next(db.test.find({}, ["c.e"])) - self.assertEqual(doc["c"], {"e": 10}) - doc = next(db.test.find({}, ["b", "c.e"])) - self.assertEqual(doc["c"], {"e": 10}) - - doc = next(db.test.find({}, ["b", "c.e"])) - l = list(doc) - l.sort() - self.assertEqual(l, ["_id", "b", "c"]) - doc = next(db.test.find({}, ["b", "c.e"])) - self.assertEqual(doc["b"], 5) - - # Test field exclusion - doc = next(db.test.find({}, {"a": False, "b": 0})) - l = list(doc) - l.sort() - self.assertEqual(l, ["_id", "c"]) - - doc = next(db.test.find({}, {"_id": False})) - l = list(doc) - self.assertFalse("_id" in l) - - def test_options(self): - db = self.db - db.drop_collection("test") - db.create_collection("test", capped=True, size=4096) - result = db.test.options() - self.assertEqual(result, {"capped": True, "size": 4096}) - db.drop_collection("test") - - def test_insert_one(self): - db = self.db - db.test.drop() - - document: dict[str, Any] = {"_id": 1000} - result = db.test.insert_one(document) - self.assertTrue(isinstance(result, InsertOneResult)) - self.assertTrue(isinstance(result.inserted_id, int)) - self.assertEqual(document["_id"], result.inserted_id) - self.assertTrue(result.acknowledged) - self.assertIsNotNone(db.test.find_one({"_id": document["_id"]})) - self.assertEqual(1, db.test.count_documents({})) - - document = {"foo": "bar"} - result = db.test.insert_one(document) - self.assertTrue(isinstance(result, InsertOneResult)) - self.assertTrue(isinstance(result.inserted_id, ObjectId)) - self.assertEqual(document["_id"], result.inserted_id) - self.assertTrue(result.acknowledged) - self.assertIsNotNone(db.test.find_one({"_id": document["_id"]})) - self.assertEqual(2, db.test.count_documents({})) - - db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) - result = db.test.insert_one(document) - self.assertTrue(isinstance(result, InsertOneResult)) - self.assertTrue(isinstance(result.inserted_id, ObjectId)) - self.assertEqual(document["_id"], result.inserted_id) - self.assertFalse(result.acknowledged) - # The insert failed duplicate key... - - def async_lambda(): - return db.test.count_documents({}) == 2 - - wait_until(async_lambda, "forcing duplicate key error") - - document = RawBSONDocument(encode({"_id": ObjectId(), "foo": "bar"})) - result = db.test.insert_one(document) - self.assertTrue(isinstance(result, InsertOneResult)) - self.assertEqual(result.inserted_id, None) - - def test_insert_many(self): - db = self.db - db.test.drop() - - docs: list = [{} for _ in range(5)] - result = db.test.insert_many(docs) - self.assertTrue(isinstance(result, InsertManyResult)) - self.assertTrue(isinstance(result.inserted_ids, list)) - self.assertEqual(5, len(result.inserted_ids)) - for doc in docs: - _id = doc["_id"] - self.assertTrue(isinstance(_id, ObjectId)) - self.assertTrue(_id in result.inserted_ids) - self.assertEqual(1, db.test.count_documents({"_id": _id})) - self.assertTrue(result.acknowledged) - - docs = [{"_id": i} for i in range(5)] - result = db.test.insert_many(docs) - self.assertTrue(isinstance(result, InsertManyResult)) - self.assertTrue(isinstance(result.inserted_ids, list)) - self.assertEqual(5, len(result.inserted_ids)) - for doc in docs: - _id = doc["_id"] - self.assertTrue(isinstance(_id, int)) - self.assertTrue(_id in result.inserted_ids) - self.assertEqual(1, db.test.count_documents({"_id": _id})) - self.assertTrue(result.acknowledged) - - docs = [RawBSONDocument(encode({"_id": i + 5})) for i in range(5)] - result = db.test.insert_many(docs) - self.assertTrue(isinstance(result, InsertManyResult)) - self.assertTrue(isinstance(result.inserted_ids, list)) - self.assertEqual([], result.inserted_ids) - - db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) - docs: list = [{} for _ in range(5)] - result = db.test.insert_many(docs) - self.assertTrue(isinstance(result, InsertManyResult)) - self.assertFalse(result.acknowledged) - self.assertEqual(20, db.test.count_documents({})) - - def test_insert_many_generator(self): - coll = self.db.test - coll.delete_many({}) - - def gen(): - yield {"a": 1, "b": 1} - yield {"a": 1, "b": 2} - yield {"a": 2, "b": 3} - yield {"a": 3, "b": 5} - yield {"a": 5, "b": 8} - - result = coll.insert_many(gen()) - self.assertEqual(5, len(result.inserted_ids)) - - def test_insert_many_invalid(self): - db = self.db - - with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): - db.test.insert_many({}) - - with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): - db.test.insert_many([]) - - with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): - db.test.insert_many(1) # type: ignore[arg-type] - - with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): - db.test.insert_many(RawBSONDocument(encode({"_id": 2}))) - - def test_delete_one(self): - self.db.test.drop() - - self.db.test.insert_one({"x": 1}) - self.db.test.insert_one({"y": 1}) - self.db.test.insert_one({"z": 1}) - - result = self.db.test.delete_one({"x": 1}) - self.assertTrue(isinstance(result, DeleteResult)) - self.assertEqual(1, result.deleted_count) - self.assertTrue(result.acknowledged) - self.assertEqual(2, self.db.test.count_documents({})) - - result = self.db.test.delete_one({"y": 1}) - self.assertTrue(isinstance(result, DeleteResult)) - self.assertEqual(1, result.deleted_count) - self.assertTrue(result.acknowledged) - self.assertEqual(1, self.db.test.count_documents({})) - - db = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) - result = db.test.delete_one({"z": 1}) - self.assertTrue(isinstance(result, DeleteResult)) - self.assertRaises(InvalidOperation, lambda: result.deleted_count) - self.assertFalse(result.acknowledged) - - def lambda_async(): - return db.test.count_documents({}) == 0 - - wait_until(lambda_async, "delete 1 documents") - - def test_delete_many(self): - self.db.test.drop() - - self.db.test.insert_one({"x": 1}) - self.db.test.insert_one({"x": 1}) - self.db.test.insert_one({"y": 1}) - self.db.test.insert_one({"y": 1}) - - result = self.db.test.delete_many({"x": 1}) - self.assertTrue(isinstance(result, DeleteResult)) - self.assertEqual(2, result.deleted_count) - self.assertTrue(result.acknowledged) - self.assertEqual(0, self.db.test.count_documents({"x": 1})) - - db = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) - result = db.test.delete_many({"y": 1}) - self.assertTrue(isinstance(result, DeleteResult)) - self.assertRaises(InvalidOperation, lambda: result.deleted_count) - self.assertFalse(result.acknowledged) - - def lambda_async(): - return db.test.count_documents({}) == 0 - - wait_until(lambda_async, "delete 2 documents") - - def test_command_document_too_large(self): - large = "*" * (client_context.max_bson_size + _COMMAND_OVERHEAD) - coll = self.db.test - with self.assertRaises(DocumentTooLarge): - coll.insert_one({"data": large}) - # update_one and update_many are the same - with self.assertRaises(DocumentTooLarge): - coll.replace_one({}, {"data": large}) - with self.assertRaises(DocumentTooLarge): - coll.delete_one({"data": large}) - - def test_write_large_document(self): - max_size = client_context.max_bson_size - half_size = int(max_size / 2) - max_str = "x" * max_size - half_str = "x" * half_size - self.assertEqual(max_size, 16777216) - - with self.assertRaises(OperationFailure): - self.db.test.insert_one({"foo": max_str}) - with self.assertRaises(OperationFailure): - self.db.test.replace_one({}, {"foo": max_str}, upsert=True) - with self.assertRaises(OperationFailure): - self.db.test.insert_many([{"x": 1}, {"foo": max_str}]) - self.db.test.insert_many([{"foo": half_str}, {"foo": half_str}]) - - self.db.test.insert_one({"bar": "x"}) - # Use w=0 here to test legacy doc size checking in all server versions - unack_coll = self.db.test.with_options(write_concern=WriteConcern(w=0)) - with self.assertRaises(DocumentTooLarge): - unack_coll.replace_one({"bar": "x"}, {"bar": "x" * (max_size - 14)}) - self.db.test.replace_one({"bar": "x"}, {"bar": "x" * (max_size - 32)}) - - def test_insert_bypass_document_validation(self): - db = self.db - db.test.drop() - db.create_collection("test", validator={"a": {"$exists": True}}) - db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) - - # Test insert_one - with self.assertRaises(OperationFailure): - db.test.insert_one({"_id": 1, "x": 100}) - result = db.test.insert_one({"_id": 1, "x": 100}, bypass_document_validation=True) - self.assertTrue(isinstance(result, InsertOneResult)) - self.assertEqual(1, result.inserted_id) - result = db.test.insert_one({"_id": 2, "a": 0}) - self.assertTrue(isinstance(result, InsertOneResult)) - self.assertEqual(2, result.inserted_id) - - db_w0.test.insert_one({"y": 1}, bypass_document_validation=True) - - def async_lambda(): - return db_w0.test.find_one({"y": 1}) - - wait_until(async_lambda, "find w:0 inserted document") - - # Test insert_many - docs = [{"_id": i, "x": 100 - i} for i in range(3, 100)] - with self.assertRaises(OperationFailure): - db.test.insert_many(docs) - result = db.test.insert_many(docs, bypass_document_validation=True) - self.assertTrue(isinstance(result, InsertManyResult)) - self.assertTrue(97, len(result.inserted_ids)) - for doc in docs: - _id = doc["_id"] - self.assertTrue(isinstance(_id, int)) - self.assertTrue(_id in result.inserted_ids) - self.assertEqual(1, db.test.count_documents({"x": doc["x"]})) - self.assertTrue(result.acknowledged) - docs = [{"_id": i, "a": 200 - i} for i in range(100, 200)] - result = db.test.insert_many(docs) - self.assertTrue(isinstance(result, InsertManyResult)) - self.assertTrue(97, len(result.inserted_ids)) - for doc in docs: - _id = doc["_id"] - self.assertTrue(isinstance(_id, int)) - self.assertTrue(_id in result.inserted_ids) - self.assertEqual(1, db.test.count_documents({"a": doc["a"]})) - self.assertTrue(result.acknowledged) - - with self.assertRaises(OperationFailure): - db_w0.test.insert_many( - [{"x": 1}, {"x": 2}], - bypass_document_validation=True, - ) - - def test_replace_bypass_document_validation(self): - db = self.db - db.test.drop() - db.create_collection("test", validator={"a": {"$exists": True}}) - db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) - - # Test replace_one - db.test.insert_one({"a": 101}) - with self.assertRaises(OperationFailure): - db.test.replace_one({"a": 101}, {"y": 1}) - self.assertEqual(0, db.test.count_documents({"y": 1})) - self.assertEqual(1, db.test.count_documents({"a": 101})) - db.test.replace_one({"a": 101}, {"y": 1}, bypass_document_validation=True) - self.assertEqual(0, db.test.count_documents({"a": 101})) - self.assertEqual(1, db.test.count_documents({"y": 1})) - db.test.replace_one({"y": 1}, {"a": 102}) - self.assertEqual(0, db.test.count_documents({"y": 1})) - self.assertEqual(0, db.test.count_documents({"a": 101})) - self.assertEqual(1, db.test.count_documents({"a": 102})) - - db.test.insert_one({"y": 1}, bypass_document_validation=True) - with self.assertRaises(OperationFailure): - db.test.replace_one({"y": 1}, {"x": 101}) - self.assertEqual(0, db.test.count_documents({"x": 101})) - self.assertEqual(1, db.test.count_documents({"y": 1})) - db.test.replace_one({"y": 1}, {"x": 101}, bypass_document_validation=True) - self.assertEqual(0, db.test.count_documents({"y": 1})) - self.assertEqual(1, db.test.count_documents({"x": 101})) - db.test.replace_one({"x": 101}, {"a": 103}, bypass_document_validation=False) - self.assertEqual(0, db.test.count_documents({"x": 101})) - self.assertEqual(1, db.test.count_documents({"a": 103})) - - db.test.insert_one({"y": 1}, bypass_document_validation=True) - db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - - wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") - - def test_update_bypass_document_validation(self): - db = self.db - db.test.drop() - db.test.insert_one({"z": 5}) - db.command(SON([("collMod", "test"), ("validator", {"z": {"$gte": 0}})])) - db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) - - # Test update_one - with self.assertRaises(OperationFailure): - db.test.update_one({"z": 5}, {"$inc": {"z": -10}}) - self.assertEqual(0, db.test.count_documents({"z": -5})) - self.assertEqual(1, db.test.count_documents({"z": 5})) - db.test.update_one({"z": 5}, {"$inc": {"z": -10}}, bypass_document_validation=True) - self.assertEqual(0, db.test.count_documents({"z": 5})) - self.assertEqual(1, db.test.count_documents({"z": -5})) - db.test.update_one({"z": -5}, {"$inc": {"z": 6}}, bypass_document_validation=False) - self.assertEqual(1, db.test.count_documents({"z": 1})) - self.assertEqual(0, db.test.count_documents({"z": -5})) - - db.test.insert_one({"z": -10}, bypass_document_validation=True) - with self.assertRaises(OperationFailure): - db.test.update_one({"z": -10}, {"$inc": {"z": 1}}) - self.assertEqual(0, db.test.count_documents({"z": -9})) - self.assertEqual(1, db.test.count_documents({"z": -10})) - db.test.update_one({"z": -10}, {"$inc": {"z": 1}}, bypass_document_validation=True) - self.assertEqual(1, db.test.count_documents({"z": -9})) - self.assertEqual(0, db.test.count_documents({"z": -10})) - db.test.update_one({"z": -9}, {"$inc": {"z": 9}}, bypass_document_validation=False) - self.assertEqual(0, db.test.count_documents({"z": -9})) - self.assertEqual(1, db.test.count_documents({"z": 0})) - - db.test.insert_one({"y": 1, "x": 0}, bypass_document_validation=True) - db_w0.test.update_one({"y": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) - - def async_lambda(): - return db_w0.test.find_one({"y": 1, "x": 1}) - - wait_until(async_lambda, "find w:0 updated document") - - # Test update_many - db.test.insert_many([{"z": i} for i in range(3, 101)]) - db.test.insert_one({"y": 0}, bypass_document_validation=True) - with self.assertRaises(OperationFailure): - db.test.update_many({}, {"$inc": {"z": -100}}) - self.assertEqual(100, db.test.count_documents({"z": {"$gte": 0}})) - self.assertEqual(0, db.test.count_documents({"z": {"$lt": 0}})) - self.assertEqual(0, db.test.count_documents({"y": 0, "z": -100})) - db.test.update_many( - {"z": {"$gte": 0}}, {"$inc": {"z": -100}}, bypass_document_validation=True - ) - self.assertEqual(0, db.test.count_documents({"z": {"$gt": 0}})) - self.assertEqual(100, db.test.count_documents({"z": {"$lte": 0}})) - db.test.update_many( - {"z": {"$gt": -50}}, {"$inc": {"z": 100}}, bypass_document_validation=False - ) - self.assertEqual(50, db.test.count_documents({"z": {"$gt": 0}})) - self.assertEqual(50, db.test.count_documents({"z": {"$lt": 0}})) - - db.test.insert_many([{"z": -i} for i in range(50)], bypass_document_validation=True) - with self.assertRaises(OperationFailure): - db.test.update_many({}, {"$inc": {"z": 1}}) - self.assertEqual(100, db.test.count_documents({"z": {"$lte": 0}})) - self.assertEqual(50, db.test.count_documents({"z": {"$gt": 1}})) - db.test.update_many( - {"z": {"$gte": 0}}, {"$inc": {"z": -100}}, bypass_document_validation=True - ) - self.assertEqual(0, db.test.count_documents({"z": {"$gt": 0}})) - self.assertEqual(150, db.test.count_documents({"z": {"$lte": 0}})) - db.test.update_many( - {"z": {"$lte": 0}}, {"$inc": {"z": 100}}, bypass_document_validation=False - ) - self.assertEqual(150, db.test.count_documents({"z": {"$gte": 0}})) - self.assertEqual(0, db.test.count_documents({"z": {"$lt": 0}})) - - db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) - db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) - db_w0.test.update_many({"m": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) - - def async_lambda(): - return db_w0.test.count_documents({"m": 1, "x": 1}) == 2 - - wait_until(async_lambda, "find w:0 updated documents") - - def test_bypass_document_validation_bulk_write(self): - db = self.db - db.test.drop() - db.create_collection("test", validator={"a": {"$gte": 0}}) - db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) - - ops: list = [ - InsertOne({"a": -10}), - InsertOne({"a": -11}), - InsertOne({"a": -12}), - UpdateOne({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), - UpdateMany({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), - ReplaceOne({"a": {"$lte": -10}}, {"a": -1}), - ] - db.test.bulk_write(ops, bypass_document_validation=True) - - self.assertEqual(3, db.test.count_documents({})) - self.assertEqual(1, db.test.count_documents({"a": -11})) - self.assertEqual(1, db.test.count_documents({"a": -1})) - self.assertEqual(1, db.test.count_documents({"a": -9})) - - # Assert that the operations would fail without bypass_doc_val - for op in ops: - with self.assertRaises(BulkWriteError): - db.test.bulk_write([op]) - - with self.assertRaises(OperationFailure): - db_w0.test.bulk_write(ops, bypass_document_validation=True) - - def test_find_by_default_dct(self): - db = self.db - db.test.insert_one({"foo": "bar"}) - dct = defaultdict(dict, [("foo", "bar")]) # type: ignore[arg-type] - self.assertIsNotNone(db.test.find_one(dct)) - self.assertEqual(dct, defaultdict(dict, [("foo", "bar")])) - - def test_find_w_fields(self): - db = self.db - db.test.delete_many({}) - - db.test.insert_one({"x": 1, "mike": "awesome", "extra thing": "abcdefghijklmnopqrstuvwxyz"}) - self.assertEqual(1, db.test.count_documents({})) - doc = next(db.test.find({})) - self.assertTrue("x" in doc) - doc = next(db.test.find({})) - self.assertTrue("mike" in doc) - doc = next(db.test.find({})) - self.assertTrue("extra thing" in doc) - doc = next(db.test.find({}, ["x", "mike"])) - self.assertTrue("x" in doc) - doc = next(db.test.find({}, ["x", "mike"])) - self.assertTrue("mike" in doc) - doc = next(db.test.find({}, ["x", "mike"])) - self.assertFalse("extra thing" in doc) - doc = next(db.test.find({}, ["mike"])) - self.assertFalse("x" in doc) - doc = next(db.test.find({}, ["mike"])) - self.assertTrue("mike" in doc) - doc = next(db.test.find({}, ["mike"])) - self.assertFalse("extra thing" in doc) - - @no_type_check - def test_fields_specifier_as_dict(self): - db = self.db - db.test.delete_many({}) - - db.test.insert_one({"x": [1, 2, 3], "mike": "awesome"}) - - self.assertEqual([1, 2, 3], (db.test.find_one())["x"]) - self.assertEqual([2, 3], (db.test.find_one(projection={"x": {"$slice": -2}}))["x"]) - self.assertTrue("x" not in db.test.find_one(projection={"x": 0})) - self.assertTrue("mike" in db.test.find_one(projection={"x": 0})) - - def test_find_w_regex(self): - db = self.db - db.test.delete_many({}) - - db.test.insert_one({"x": "hello_world"}) - db.test.insert_one({"x": "hello_mike"}) - db.test.insert_one({"x": "hello_mikey"}) - db.test.insert_one({"x": "hello_test"}) - - self.assertEqual(len((db.test.find()).to_list()), 4) - self.assertEqual(len((db.test.find({"x": re.compile("^hello.*")})).to_list()), 4) - self.assertEqual(len((db.test.find({"x": re.compile("ello")})).to_list()), 4) - self.assertEqual(len((db.test.find({"x": re.compile("^hello$")})).to_list()), 0) - self.assertEqual(len((db.test.find({"x": re.compile("^hello_mi.*$")})).to_list()), 2) - - def test_id_can_be_anything(self): - db = self.db - - db.test.delete_many({}) - auto_id = {"hello": "world"} - db.test.insert_one(auto_id) - self.assertTrue(isinstance(auto_id["_id"], ObjectId)) - - numeric = {"_id": 240, "hello": "world"} - db.test.insert_one(numeric) - self.assertEqual(numeric["_id"], 240) - - obj = {"_id": numeric, "hello": "world"} - db.test.insert_one(obj) - self.assertEqual(obj["_id"], numeric) - - for x in db.test.find(): - self.assertEqual(x["hello"], "world") - self.assertTrue("_id" in x) - - def test_unique_index(self): - db = self.db - db.drop_collection("test") - db.test.create_index("hello") - - # No error. - db.test.insert_one({"hello": "world"}) - db.test.insert_one({"hello": "world"}) - - db.drop_collection("test") - db.test.create_index("hello", unique=True) - - with self.assertRaises(DuplicateKeyError): - db.test.insert_one({"hello": "world"}) - db.test.insert_one({"hello": "world"}) - - def test_duplicate_key_error(self): - db = self.db - db.drop_collection("test") - - db.test.create_index("x", unique=True) - - db.test.insert_one({"_id": 1, "x": 1}) - - with self.assertRaises(DuplicateKeyError) as context: - db.test.insert_one({"x": 1}) - - self.assertIsNotNone(context.exception.details) - - with self.assertRaises(DuplicateKeyError) as context: - db.test.insert_one({"x": 1}) - - self.assertIsNotNone(context.exception.details) - self.assertEqual(1, db.test.count_documents({})) - - def test_write_error_text_handling(self): - db = self.db - db.drop_collection("test") - - db.test.create_index("text", unique=True) - - # Test workaround for SERVER-24007 - data = ( - b"a\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" - ) - - text = utf_8_decode(data, None, True) - db.test.insert_one({"text": text}) - - # Should raise DuplicateKeyError, not InvalidBSON - with self.assertRaises(DuplicateKeyError): - db.test.insert_one({"text": text}) - - with self.assertRaises(DuplicateKeyError): - db.test.replace_one({"_id": ObjectId()}, {"text": text}, upsert=True) - - # Should raise BulkWriteError, not InvalidBSON - with self.assertRaises(BulkWriteError): - db.test.insert_many([{"text": text}]) - - def test_write_error_unicode(self): - coll = self.db.test - self.addCleanup(coll.drop) - - coll.create_index("a", unique=True) - coll.insert_one({"a": "unicode \U0001f40d"}) - with self.assertRaisesRegex(DuplicateKeyError, "E11000 duplicate key error") as ctx: - coll.insert_one({"a": "unicode \U0001f40d"}) - - # Once more for good measure. - self.assertIn("E11000 duplicate key error", str(ctx.exception)) - - def test_wtimeout(self): - # Ensure setting wtimeout doesn't disable write concern altogether. - # See SERVER-12596. - collection = self.db.test - collection.drop() - collection.insert_one({"_id": 1}) - - coll = collection.with_options(write_concern=WriteConcern(w=1, wtimeout=1000)) - with self.assertRaises(DuplicateKeyError): - coll.insert_one({"_id": 1}) - - coll = collection.with_options(write_concern=WriteConcern(wtimeout=1000)) - with self.assertRaises(DuplicateKeyError): - coll.insert_one({"_id": 1}) - - def test_error_code(self): - try: - self.db.test.update_many({}, {"$thismodifierdoesntexist": 1}) - except OperationFailure as exc: - self.assertTrue(exc.code in (9, 10147, 16840, 17009)) - # Just check that we set the error document. Fields - # vary by MongoDB version. - self.assertTrue(exc.details is not None) - else: - self.fail("OperationFailure was not raised") - - def test_index_on_subfield(self): - db = self.db - db.drop_collection("test") - - db.test.insert_one({"hello": {"a": 4, "b": 5}}) - db.test.insert_one({"hello": {"a": 7, "b": 2}}) - db.test.insert_one({"hello": {"a": 4, "b": 10}}) - - db.drop_collection("test") - db.test.create_index("hello.a", unique=True) - - db.test.insert_one({"hello": {"a": 4, "b": 5}}) - db.test.insert_one({"hello": {"a": 7, "b": 2}}) - with self.assertRaises(DuplicateKeyError): - db.test.insert_one({"hello": {"a": 4, "b": 10}}) - - def test_replace_one(self): - db = self.db - db.drop_collection("test") - - with self.assertRaises(ValueError): - db.test.replace_one({}, {"$set": {"x": 1}}) - - id1 = (db.test.insert_one({"x": 1})).inserted_id - result = db.test.replace_one({"x": 1}, {"y": 1}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(1, result.matched_count) - self.assertTrue(result.modified_count in (None, 1)) - self.assertIsNone(result.upserted_id) - self.assertTrue(result.acknowledged) - self.assertEqual(1, db.test.count_documents({"y": 1})) - self.assertEqual(0, db.test.count_documents({"x": 1})) - self.assertEqual((db.test.find_one(id1))["y"], 1) # type: ignore - - replacement = RawBSONDocument(encode({"_id": id1, "z": 1})) - result = db.test.replace_one({"y": 1}, replacement, True) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(1, result.matched_count) - self.assertTrue(result.modified_count in (None, 1)) - self.assertIsNone(result.upserted_id) - self.assertTrue(result.acknowledged) - self.assertEqual(1, db.test.count_documents({"z": 1})) - self.assertEqual(0, db.test.count_documents({"y": 1})) - self.assertEqual((db.test.find_one(id1))["z"], 1) # type: ignore - - result = db.test.replace_one({"x": 2}, {"y": 2}, True) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(0, result.matched_count) - self.assertTrue(result.modified_count in (None, 0)) - self.assertTrue(isinstance(result.upserted_id, ObjectId)) - self.assertTrue(result.acknowledged) - self.assertEqual(1, db.test.count_documents({"y": 2})) - - db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) - result = db.test.replace_one({"x": 0}, {"y": 0}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertRaises(InvalidOperation, lambda: result.matched_count) - self.assertRaises(InvalidOperation, lambda: result.modified_count) - self.assertRaises(InvalidOperation, lambda: result.upserted_id) - self.assertFalse(result.acknowledged) - - def test_update_one(self): - db = self.db - db.drop_collection("test") - - with self.assertRaises(ValueError): - db.test.update_one({}, {"x": 1}) - - id1 = (db.test.insert_one({"x": 5})).inserted_id - result = db.test.update_one({}, {"$inc": {"x": 1}}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(1, result.matched_count) - self.assertTrue(result.modified_count in (None, 1)) - self.assertIsNone(result.upserted_id) - self.assertTrue(result.acknowledged) - self.assertEqual((db.test.find_one(id1))["x"], 6) # type: ignore - - id2 = (db.test.insert_one({"x": 1})).inserted_id - result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(1, result.matched_count) - self.assertTrue(result.modified_count in (None, 1)) - self.assertIsNone(result.upserted_id) - self.assertTrue(result.acknowledged) - self.assertEqual((db.test.find_one(id1))["x"], 7) # type: ignore - self.assertEqual((db.test.find_one(id2))["x"], 1) # type: ignore - - result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(0, result.matched_count) - self.assertTrue(result.modified_count in (None, 0)) - self.assertTrue(isinstance(result.upserted_id, ObjectId)) - self.assertTrue(result.acknowledged) - - db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) - result = db.test.update_one({"x": 0}, {"$inc": {"x": 1}}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertRaises(InvalidOperation, lambda: result.matched_count) - self.assertRaises(InvalidOperation, lambda: result.modified_count) - self.assertRaises(InvalidOperation, lambda: result.upserted_id) - self.assertFalse(result.acknowledged) - - def test_update_many(self): - db = self.db - db.drop_collection("test") - - with self.assertRaises(ValueError): - db.test.update_many({}, {"x": 1}) - - db.test.insert_one({"x": 4, "y": 3}) - db.test.insert_one({"x": 5, "y": 5}) - db.test.insert_one({"x": 4, "y": 4}) - - result = db.test.update_many({"x": 4}, {"$set": {"y": 5}}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(2, result.matched_count) - self.assertTrue(result.modified_count in (None, 2)) - self.assertIsNone(result.upserted_id) - self.assertTrue(result.acknowledged) - self.assertEqual(3, db.test.count_documents({"y": 5})) - - result = db.test.update_many({"x": 5}, {"$set": {"y": 6}}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(1, result.matched_count) - self.assertTrue(result.modified_count in (None, 1)) - self.assertIsNone(result.upserted_id) - self.assertTrue(result.acknowledged) - self.assertEqual(1, db.test.count_documents({"y": 6})) - - result = db.test.update_many({"x": 2}, {"$set": {"y": 1}}, True) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertEqual(0, result.matched_count) - self.assertTrue(result.modified_count in (None, 0)) - self.assertTrue(isinstance(result.upserted_id, ObjectId)) - self.assertTrue(result.acknowledged) - - db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) - result = db.test.update_many({"x": 0}, {"$inc": {"x": 1}}) - self.assertTrue(isinstance(result, UpdateResult)) - self.assertRaises(InvalidOperation, lambda: result.matched_count) - self.assertRaises(InvalidOperation, lambda: result.modified_count) - self.assertRaises(InvalidOperation, lambda: result.upserted_id) - self.assertFalse(result.acknowledged) - - def test_update_check_keys(self): - self.db.drop_collection("test") - self.assertTrue(self.db.test.insert_one({"hello": "world"})) - - # Modify shouldn't check keys... - self.assertTrue( - self.db.test.update_one({"hello": "world"}, {"$set": {"foo.bar": "baz"}}, upsert=True) - ) - - # I know this seems like testing the server but I'd like to be notified - # by CI if the server's behavior changes here. - doc = SON([("$set", {"foo.bar": "bim"}), ("hello", "world")]) - with self.assertRaises(OperationFailure): - self.db.test.update_one({"hello": "world"}, doc, upsert=True) - - # This is going to cause keys to be checked and raise InvalidDocument. - # That's OK assuming the server's behavior in the previous assert - # doesn't change. If the behavior changes checking the first key for - # '$' in update won't be good enough anymore. - doc = SON([("hello", "world"), ("$set", {"foo.bar": "bim"})]) - with self.assertRaises(OperationFailure): - self.db.test.replace_one({"hello": "world"}, doc, upsert=True) - - # Replace with empty document - self.assertNotEqual(0, (self.db.test.replace_one({"hello": "world"}, {})).matched_count) - - def test_acknowledged_delete(self): - db = self.db - db.drop_collection("test") - db.test.insert_many([{"x": 1}, {"x": 1}]) - self.assertEqual(2, (db.test.delete_many({})).deleted_count) - self.assertEqual(0, (db.test.delete_many({})).deleted_count) - - @client_context.require_version_max(4, 9) - def test_manual_last_error(self): - coll = self.db.get_collection("test", write_concern=WriteConcern(w=0)) - coll.insert_one({"x": 1}) - self.db.command("getlasterror", w=1, wtimeout=1) - - def test_count_documents(self): - db = self.db - db.drop_collection("test") - self.addCleanup(db.drop_collection, "test") - - self.assertEqual(db.test.count_documents({}), 0) - db.wrong.insert_many([{}, {}]) - self.assertEqual(db.test.count_documents({}), 0) - db.test.insert_many([{}, {}]) - self.assertEqual(db.test.count_documents({}), 2) - db.test.insert_many([{"foo": "bar"}, {"foo": "baz"}]) - self.assertEqual(db.test.count_documents({"foo": "bar"}), 1) - self.assertEqual(db.test.count_documents({"foo": re.compile(r"ba.*")}), 2) - - def test_estimated_document_count(self): - db = self.db - db.drop_collection("test") - self.addCleanup(db.drop_collection, "test") - - self.assertEqual(db.test.estimated_document_count(), 0) - db.wrong.insert_many([{}, {}]) - self.assertEqual(db.test.estimated_document_count(), 0) - db.test.insert_many([{}, {}]) - self.assertEqual(db.test.estimated_document_count(), 2) - - def test_aggregate(self): - db = self.db - db.drop_collection("test") - db.test.insert_one({"foo": [1, 2]}) - - with self.assertRaises(TypeError): - db.test.aggregate("wow") # type: ignore[arg-type] - - pipeline = {"$project": {"_id": False, "foo": True}} - result = db.test.aggregate([pipeline]) - self.assertTrue(isinstance(result, CommandCursor)) - self.assertEqual([{"foo": [1, 2]}], result.to_list()) - - # Test write concern. - with self.write_concern_collection() as coll: - coll.aggregate([{"$out": "output-collection"}]) - - def test_aggregate_raw_bson(self): - db = self.db - db.drop_collection("test") - db.test.insert_one({"foo": [1, 2]}) - - with self.assertRaises(TypeError): - db.test.aggregate("wow") # type: ignore[arg-type] - - pipeline = {"$project": {"_id": False, "foo": True}} - coll = db.get_collection("test", codec_options=CodecOptions(document_class=RawBSONDocument)) - result = coll.aggregate([pipeline]) - self.assertTrue(isinstance(result, CommandCursor)) - first_result = next(result) - self.assertIsInstance(first_result, RawBSONDocument) - self.assertEqual([1, 2], list(first_result["foo"])) - - def test_aggregation_cursor_validation(self): - db = self.db - projection = {"$project": {"_id": "$_id"}} - cursor = db.test.aggregate([projection], cursor={}) - self.assertTrue(isinstance(cursor, CommandCursor)) - - def test_aggregation_cursor(self): - db = self.db - if client_context.has_secondaries: - # Test that getMore messages are sent to the right server. - db = self.client.get_database( - db.name, - read_preference=ReadPreference.SECONDARY, - write_concern=WriteConcern(w=self.w), - ) - - for collection_size in (10, 1000): - db.drop_collection("test") - db.test.insert_many([{"_id": i} for i in range(collection_size)]) - expected_sum = sum(range(collection_size)) - # Use batchSize to ensure multiple getMore messages - cursor = db.test.aggregate([{"$project": {"_id": "$_id"}}], batchSize=5) - - self.assertEqual(expected_sum, sum(doc["_id"] for doc in cursor.to_list())) - - # Test that batchSize is handled properly. - cursor = db.test.aggregate([], batchSize=5) - self.assertEqual(5, len(cursor._data)) - # Force a getMore - cursor._data.clear() - next(cursor) - # batchSize - 1 - self.assertEqual(4, len(cursor._data)) - # Exhaust the cursor. There shouldn't be any errors. - for _doc in cursor: - pass - - def test_aggregation_cursor_alive(self): - self.db.test.delete_many({}) - self.db.test.insert_many([{} for _ in range(3)]) - self.addCleanup(self.db.test.delete_many, {}) - cursor = self.db.test.aggregate(pipeline=[], cursor={"batchSize": 2}) - n = 0 - while True: - cursor.next() - n += 1 - if n == 3: - self.assertFalse(cursor.alive) - break - - self.assertTrue(cursor.alive) - - def test_invalid_session_parameter(self): - def try_invalid_session(): - with self.db.test.aggregate([], {}): # type:ignore - pass - - with self.assertRaisesRegex(ValueError, "must be a ClientSession"): - try_invalid_session() - - def test_large_limit(self): - db = self.db - db.drop_collection("test_large_limit") - db.test_large_limit.create_index([("x", 1)]) - my_str = "mongomongo" * 1000 - - db.test_large_limit.insert_many({"x": i, "y": my_str} for i in range(2000)) - - i = 0 - y = 0 - for doc in (db.test_large_limit.find(limit=1900)).sort([("x", 1)]): - i += 1 - y += doc["x"] - - self.assertEqual(1900, i) - self.assertEqual((1900 * 1899) / 2, y) - - def test_find_kwargs(self): - db = self.db - db.drop_collection("test") - db.test.insert_many({"x": i} for i in range(10)) - - self.assertEqual(10, db.test.count_documents({})) - - total = 0 - for x in db.test.find({}, skip=4, limit=2): - total += x["x"] - - self.assertEqual(9, total) - - def test_rename(self): - db = self.db - db.drop_collection("test") - db.drop_collection("foo") - - with self.assertRaises(TypeError): - db.test.rename(5) # type: ignore[arg-type] - with self.assertRaises(InvalidName): - db.test.rename("") - with self.assertRaises(InvalidName): - db.test.rename("te$t") - with self.assertRaises(InvalidName): - db.test.rename(".test") - with self.assertRaises(InvalidName): - db.test.rename("test.") - with self.assertRaises(InvalidName): - db.test.rename("tes..t") - - self.assertEqual(0, db.test.count_documents({})) - self.assertEqual(0, db.foo.count_documents({})) - - db.test.insert_many({"x": i} for i in range(10)) - - self.assertEqual(10, db.test.count_documents({})) - - db.test.rename("foo") - - self.assertEqual(0, db.test.count_documents({})) - self.assertEqual(10, db.foo.count_documents({})) - - x = 0 - for doc in db.foo.find(): - self.assertEqual(x, doc["x"]) - x += 1 - - db.test.insert_one({}) - with self.assertRaises(OperationFailure): - db.foo.rename("test") - db.foo.rename("test", dropTarget=True) - - with self.write_concern_collection() as coll: - coll.rename("foo") - - @no_type_check - def test_find_one(self): - db = self.db - db.drop_collection("test") - - _id = (db.test.insert_one({"hello": "world", "foo": "bar"})).inserted_id - - self.assertEqual("world", (db.test.find_one())["hello"]) - self.assertEqual(db.test.find_one(_id), db.test.find_one()) - self.assertEqual(db.test.find_one(None), db.test.find_one()) - self.assertEqual(db.test.find_one({}), db.test.find_one()) - self.assertEqual(db.test.find_one({"hello": "world"}), db.test.find_one()) - - self.assertTrue("hello" in db.test.find_one(projection=["hello"])) - self.assertTrue("hello" not in db.test.find_one(projection=["foo"])) - - self.assertTrue("hello" in db.test.find_one(projection=("hello",))) - self.assertTrue("hello" not in db.test.find_one(projection=("foo",))) - - self.assertTrue("hello" in db.test.find_one(projection={"hello"})) - self.assertTrue("hello" not in db.test.find_one(projection={"foo"})) - - self.assertTrue("hello" in db.test.find_one(projection=frozenset(["hello"]))) - self.assertTrue("hello" not in db.test.find_one(projection=frozenset(["foo"]))) - - self.assertEqual(["_id"], list(db.test.find_one(projection={"_id": True}))) - self.assertTrue("hello" in list(db.test.find_one(projection={}))) - self.assertTrue("hello" in list(db.test.find_one(projection=[]))) - - self.assertEqual(None, db.test.find_one({"hello": "foo"})) - self.assertEqual(None, db.test.find_one(ObjectId())) - - def test_find_one_non_objectid(self): - db = self.db - db.drop_collection("test") - - db.test.insert_one({"_id": 5}) - - self.assertTrue(db.test.find_one(5)) - self.assertFalse(db.test.find_one(6)) - - def test_find_one_with_find_args(self): - db = self.db - db.drop_collection("test") - - db.test.insert_many([{"x": i} for i in range(1, 4)]) - - self.assertEqual(1, (db.test.find_one())["x"]) - self.assertEqual(2, (db.test.find_one(skip=1, limit=2))["x"]) - - def test_find_with_sort(self): - db = self.db - db.drop_collection("test") - - db.test.insert_many([{"x": 2}, {"x": 1}, {"x": 3}]) - - self.assertEqual(2, (db.test.find_one())["x"]) - self.assertEqual(1, (db.test.find_one(sort=[("x", 1)]))["x"]) - self.assertEqual(3, (db.test.find_one(sort=[("x", -1)]))["x"]) - - def to_list(things): - return [thing["x"] for thing in things] - - self.assertEqual([2, 1, 3], to_list(db.test.find())) - self.assertEqual([1, 2, 3], to_list(db.test.find(sort=[("x", 1)]))) - self.assertEqual([3, 2, 1], to_list(db.test.find(sort=[("x", -1)]))) - - with self.assertRaises(TypeError): - db.test.find(sort=5) - with self.assertRaises(TypeError): - db.test.find(sort="hello") - with self.assertRaises(TypeError): - db.test.find(sort=["hello", 1]) - - # TODO doesn't actually test functionality, just that it doesn't blow up - def test_cursor_timeout(self): - (self.db.test.find(no_cursor_timeout=True)).to_list() - (self.db.test.find(no_cursor_timeout=False)).to_list() - - def test_exhaust(self): - if is_mongos(self.db.client): - with self.assertRaises(InvalidOperation): - self.db.test.find(cursor_type=CursorType.EXHAUST) - return - - # Limit is incompatible with exhaust. - with self.assertRaises(InvalidOperation): - self.db.test.find(cursor_type=CursorType.EXHAUST, limit=5) - cur = self.db.test.find(cursor_type=CursorType.EXHAUST) - with self.assertRaises(InvalidOperation): - cur.limit(5) - cur = self.db.test.find(limit=5) - with self.assertRaises(InvalidOperation): - cur.add_option(64) - cur = self.db.test.find() - cur.add_option(64) - with self.assertRaises(InvalidOperation): - cur.limit(5) - - self.db.drop_collection("test") - # Insert enough documents to require more than one batch - self.db.test.insert_many([{"i": i} for i in range(150)]) - - client = rs_or_single_client(maxPoolSize=1) - self.addCleanup(client.close) - pool = get_pool(client) - - # Make sure the socket is returned after exhaustion. - cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST) - next(cur) - self.assertEqual(0, len(pool.conns)) - for _ in cur: - pass - self.assertEqual(1, len(pool.conns)) - - # Same as previous but don't call next() - for _ in client[self.db.name].test.find(cursor_type=CursorType.EXHAUST): - pass - self.assertEqual(1, len(pool.conns)) - - # If the Cursor instance is discarded before being completely iterated - # and the socket has pending data (more_to_come=True) we have to close - # and discard the socket. - cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST, batch_size=2) - if client_context.version.at_least(4, 2): - # On 4.2+ we use OP_MSG which only sets more_to_come=True after the - # first getMore. - for _ in range(3): - next(cur) - else: - next(cur) - self.assertEqual(0, len(pool.conns)) - # if sys.platform.startswith("java") or "PyPy" in sys.version: - # # Don't wait for GC or use gc.collect(), it's unreliable. - cur.close() - cur = None - # Wait until the background thread returns the socket. - wait_until(lambda: pool.active_sockets == 0, "return socket") - # The socket should be discarded. - self.assertEqual(0, len(pool.conns)) - - def test_distinct(self): - self.db.drop_collection("test") - - test = self.db.test - test.insert_many([{"a": 1}, {"a": 2}, {"a": 2}, {"a": 2}, {"a": 3}]) - - distinct = test.distinct("a") - distinct.sort() - - self.assertEqual([1, 2, 3], distinct) - - distinct = (test.find({"a": {"$gt": 1}})).distinct("a") - distinct.sort() - self.assertEqual([2, 3], distinct) - - distinct = test.distinct("a", {"a": {"$gt": 1}}) - distinct.sort() - self.assertEqual([2, 3], distinct) - - self.db.drop_collection("test") - - test.insert_one({"a": {"b": "a"}, "c": 12}) - test.insert_one({"a": {"b": "b"}, "c": 12}) - test.insert_one({"a": {"b": "c"}, "c": 12}) - test.insert_one({"a": {"b": "c"}, "c": 12}) - - distinct = test.distinct("a.b") - distinct.sort() - - self.assertEqual(["a", "b", "c"], distinct) - - def test_query_on_query_field(self): - self.db.drop_collection("test") - self.db.test.insert_one({"query": "foo"}) - self.db.test.insert_one({"bar": "foo"}) - - self.assertEqual(1, self.db.test.count_documents({"query": {"$ne": None}})) - self.assertEqual(1, len((self.db.test.find({"query": {"$ne": None}})).to_list())) - - def test_min_query(self): - self.db.drop_collection("test") - self.db.test.insert_many([{"x": 1}, {"x": 2}]) - self.db.test.create_index("x") - - cursor = self.db.test.find({"$min": {"x": 2}, "$query": {}}, hint="x_1") - - docs = cursor.to_list() - self.assertEqual(1, len(docs)) - self.assertEqual(2, docs[0]["x"]) - - def test_numerous_inserts(self): - # Ensure we don't exceed server's maxWriteBatchSize size limit. - self.db.test.drop() - n_docs = client_context.max_write_batch_size + 100 - self.db.test.insert_many([{} for _ in range(n_docs)]) - self.assertEqual(n_docs, self.db.test.count_documents({})) - self.db.test.drop() - - def test_insert_many_large_batch(self): - # Tests legacy insert. - db = self.client.test_insert_large_batch - self.addCleanup(self.client.drop_database, "test_insert_large_batch") - max_bson_size = client_context.max_bson_size - # Write commands are limited to 16MB + 16k per batch - big_string = "x" * int(max_bson_size / 2) - - # Batch insert that requires 2 batches. - successful_insert = [ - {"x": big_string}, - {"x": big_string}, - {"x": big_string}, - {"x": big_string}, - ] - db.collection_0.insert_many(successful_insert) - self.assertEqual(4, db.collection_0.count_documents({})) - - db.collection_0.drop() - - # Test that inserts fail after first error. - insert_second_fails = [ - {"_id": "id0", "x": big_string}, - {"_id": "id0", "x": big_string}, - {"_id": "id1", "x": big_string}, - {"_id": "id2", "x": big_string}, - ] - - with self.assertRaises(BulkWriteError): - db.collection_1.insert_many(insert_second_fails) - - self.assertEqual(1, db.collection_1.count_documents({})) - - db.collection_1.drop() - - # 2 batches, 2nd insert fails, unacknowledged, ordered. - unack_coll = db.collection_2.with_options(write_concern=WriteConcern(w=0)) - unack_coll.insert_many(insert_second_fails) - - def async_lambda(): - return db.collection_2.count_documents({}) == 1 - - wait_until(async_lambda, "insert 1 document", timeout=60) - - db.collection_2.drop() - - # 2 batches, ids of docs 0 and 1 are dupes, ids of docs 2 and 3 are - # dupes. Acknowledged, unordered. - insert_two_failures = [ - {"_id": "id0", "x": big_string}, - {"_id": "id0", "x": big_string}, - {"_id": "id1", "x": big_string}, - {"_id": "id1", "x": big_string}, - ] - - with self.assertRaises(OperationFailure) as context: - db.collection_3.insert_many(insert_two_failures, ordered=False) - - self.assertIn("id1", str(context.exception)) - - # Only the first and third documents should be inserted. - self.assertEqual(2, db.collection_3.count_documents({})) - - db.collection_3.drop() - - # 2 batches, 2 errors, unacknowledged, unordered. - unack_coll = db.collection_4.with_options(write_concern=WriteConcern(w=0)) - unack_coll.insert_many(insert_two_failures, ordered=False) - - def async_lambda(): - return db.collection_4.count_documents({}) == 2 - - # Only the first and third documents are inserted. - wait_until(async_lambda, "insert 2 documents", timeout=60) - - db.collection_4.drop() - - def test_messages_with_unicode_collection_names(self): - db = self.db - - db["Employés"].insert_one({"x": 1}) - db["Employés"].replace_one({"x": 1}, {"x": 2}) - db["Employés"].delete_many({}) - db["Employés"].find_one() - (db["Employés"].find()).to_list() - - def test_drop_indexes_non_existent(self): - self.db.drop_collection("test") - self.db.test.drop_indexes() - - # This is really a bson test but easier to just reproduce it here... - # (Shame on me) - def test_bad_encode(self): - c = self.db.test - c.drop() - with self.assertRaises(InvalidDocument): - c.insert_one({"x": c}) - - class BadGetAttr(dict): - def __getattr__(self, name): - pass - - bad = BadGetAttr([("foo", "bar")]) - c.insert_one({"bad": bad}) - self.assertEqual("bar", (c.find_one())["bad"]["foo"]) # type: ignore - - def test_array_filters_validation(self): - # array_filters must be a list. - c = self.db.test - with self.assertRaises(TypeError): - c.update_one({}, {"$set": {"a": 1}}, array_filters={}) # type: ignore[arg-type] - with self.assertRaises(TypeError): - c.update_many({}, {"$set": {"a": 1}}, array_filters={}) # type: ignore[arg-type] - with self.assertRaises(TypeError): - update = {"$set": {"a": 1}} - c.find_one_and_update({}, update, array_filters={}) # type: ignore[arg-type] - - def test_array_filters_unacknowledged(self): - c_w0 = self.db.test.with_options(write_concern=WriteConcern(w=0)) - with self.assertRaises(ConfigurationError): - c_w0.update_one({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) - with self.assertRaises(ConfigurationError): - c_w0.update_many({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) - with self.assertRaises(ConfigurationError): - c_w0.find_one_and_update({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) - - def test_find_one_and(self): - c = self.db.test - c.drop() - c.insert_one({"_id": 1, "i": 1}) - - self.assertEqual({"_id": 1, "i": 1}, c.find_one_and_update({"_id": 1}, {"$inc": {"i": 1}})) - self.assertEqual( - {"_id": 1, "i": 3}, - c.find_one_and_update( - {"_id": 1}, {"$inc": {"i": 1}}, return_document=ReturnDocument.AFTER - ), - ) - - self.assertEqual({"_id": 1, "i": 3}, c.find_one_and_delete({"_id": 1})) - self.assertEqual(None, c.find_one({"_id": 1})) - - self.assertEqual(None, c.find_one_and_update({"_id": 1}, {"$inc": {"i": 1}})) - self.assertEqual( - {"_id": 1, "i": 1}, - c.find_one_and_update( - {"_id": 1}, {"$inc": {"i": 1}}, return_document=ReturnDocument.AFTER, upsert=True - ), - ) - self.assertEqual( - {"_id": 1, "i": 2}, - c.find_one_and_update( - {"_id": 1}, {"$inc": {"i": 1}}, return_document=ReturnDocument.AFTER - ), - ) - - self.assertEqual( - {"_id": 1, "i": 3}, - c.find_one_and_replace( - {"_id": 1}, {"i": 3, "j": 1}, projection=["i"], return_document=ReturnDocument.AFTER - ), - ) - self.assertEqual( - {"i": 4}, - c.find_one_and_update( - {"_id": 1}, - {"$inc": {"i": 1}}, - projection={"i": 1, "_id": 0}, - return_document=ReturnDocument.AFTER, - ), - ) - - c.drop() - for j in range(5): - c.insert_one({"j": j, "i": 0}) - - sort = [("j", DESCENDING)] - self.assertEqual(4, (c.find_one_and_update({}, {"$inc": {"i": 1}}, sort=sort))["j"]) - - def test_find_one_and_write_concern(self): - listener = EventListener() - db = (single_client(event_listeners=[listener]))[self.db.name] - # non-default WriteConcern. - c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) - # default WriteConcern. - c_default = db.get_collection("test", write_concern=WriteConcern()) - # Authenticate the client and throw out auth commands from the listener. - db.command("ping") - listener.reset() - c_w0.find_one_and_update({"_id": 1}, {"$set": {"foo": "bar"}}) - self.assertEqual({"w": 0}, listener.started_events[0].command["writeConcern"]) - listener.reset() - - c_w0.find_one_and_replace({"_id": 1}, {"foo": "bar"}) - self.assertEqual({"w": 0}, listener.started_events[0].command["writeConcern"]) - listener.reset() - - c_w0.find_one_and_delete({"_id": 1}) - self.assertEqual({"w": 0}, listener.started_events[0].command["writeConcern"]) - listener.reset() - - # Test write concern errors. - if client_context.is_rs: - c_wc_error = db.get_collection( - "test", write_concern=WriteConcern(w=len(client_context.nodes) + 1) - ) - with self.assertRaises(WriteConcernError): - c_wc_error.find_one_and_update({"_id": 1}, {"$set": {"foo": "bar"}}) - with self.assertRaises(WriteConcernError): - c_wc_error.find_one_and_replace( - {"w": 0}, listener.started_events[0].command["writeConcern"] - ) - with self.assertRaises(WriteConcernError): - c_wc_error.find_one_and_delete( - {"w": 0}, listener.started_events[0].command["writeConcern"] - ) - listener.reset() - - c_default.find_one_and_update({"_id": 1}, {"$set": {"foo": "bar"}}) - self.assertNotIn("writeConcern", listener.started_events[0].command) - listener.reset() - - c_default.find_one_and_replace({"_id": 1}, {"foo": "bar"}) - self.assertNotIn("writeConcern", listener.started_events[0].command) - listener.reset() - - c_default.find_one_and_delete({"_id": 1}) - self.assertNotIn("writeConcern", listener.started_events[0].command) - listener.reset() - - def test_find_with_nested(self): - c = self.db.test - c.drop() - c.insert_many([{"i": i} for i in range(5)]) # [0, 1, 2, 3, 4] - self.assertEqual( - [2], - [ - i["i"] - for i in c.find( - { - "$and": [ - { - # This clause gives us [1,2,4] - "$or": [ - {"i": {"$lte": 2}}, - {"i": {"$gt": 3}}, - ], - }, - { - # This clause gives us [2,3] - "$or": [ - {"i": 2}, - {"i": 3}, - ] - }, - ] - } - ) - ], - ) - - self.assertEqual( - [0, 1, 2], - [ - i["i"] - for i in c.find( - { - "$or": [ - { - # This clause gives us [2] - "$and": [ - {"i": {"$gte": 2}}, - {"i": {"$lt": 3}}, - ], - }, - { - # This clause gives us [0,1] - "$and": [ - {"i": {"$gt": -100}}, - {"i": {"$lt": 2}}, - ] - }, - ] - } - ) - ], - ) - - def test_find_regex(self): - c = self.db.test - c.drop() - c.insert_one({"r": re.compile(".*")}) - - self.assertTrue(isinstance((c.find_one())["r"], Regex)) # type: ignore - for doc in c.find(): - self.assertTrue(isinstance(doc["r"], Regex)) - - def test_find_command_generation(self): - cmd = _gen_find_command( - "coll", - {"$query": {"foo": 1}, "$dumb": 2}, - None, - 0, - 0, - 0, - None, - DEFAULT_READ_CONCERN, - None, - None, - ) - self.assertEqual(cmd, {"find": "coll", "$dumb": 2, "filter": {"foo": 1}}) - - def test_bool(self): - with self.assertRaises(NotImplementedError): - bool(Collection(self.db, "test")) - - @client_context.require_version_min(5, 0, 0) - def test_helpers_with_let(self): - c = self.db.test - helpers = [ - (c.delete_many, ({}, {})), - (c.delete_one, ({}, {})), - (c.find, ({})), - (c.update_many, ({}, {"$inc": {"x": 3}})), - (c.update_one, ({}, {"$inc": {"x": 3}})), - (c.find_one_and_delete, ({}, {})), - (c.find_one_and_replace, ({}, {})), - (c.aggregate, ([],)), - ] - for let in [10, "str", [], False]: - for helper, args in helpers: - with self.assertRaisesRegex(TypeError, "let must be an instance of dict"): - helper(*args, let=let) # type: ignore - for helper, args in helpers: - helper(*args, let={}) # type: ignore - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_collection.py b/test/test_collection.py index 0de506e0f..7d105acb6 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -15,6 +15,7 @@ """Test the collection module.""" from __future__ import annotations +import asyncio import contextlib import re import sys @@ -26,8 +27,11 @@ from pymongo.synchronous.database import Database sys.path[0:0] = [""] -from test import client_context, unittest -from test.test_client import IntegrationTest +from test import ( # TODO: fix sync imports in PYTHON-4528 + IntegrationTest, + client_context, + unittest, +) from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, @@ -70,9 +74,12 @@ from pymongo.results import ( ) from pymongo.synchronous.collection import Collection, ReturnDocument from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import next from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestCollectionNoConnect(unittest.TestCase): """Test Collection features on a client that does not connect.""" @@ -127,7 +134,10 @@ class TestCollectionNoConnect(unittest.TestCase): if "PyPy" in sys.version and sys.version_info < (3, 8, 15): msg = "'NoneType' object is not callable" else: - msg = "'Collection' object is not iterable" + if _IS_SYNC: + msg = "'Collection' object is not iterable" + else: + msg = "'Collection' object is not iterable" # Iteration fails with self.assertRaisesRegex(TypeError, msg): for _ in coll: # type: ignore[misc] # error: "None" not callable [misc] @@ -136,10 +146,10 @@ class TestCollectionNoConnect(unittest.TestCase): self.assertEqual(coll[0].name, "coll.0") self.assertEqual(coll[{}].name, "coll.{}") # next fails - with self.assertRaisesRegex(TypeError, "'Collection' object is not iterable"): + with self.assertRaisesRegex(TypeError, msg): _ = next(coll) # .next() fails - with self.assertRaisesRegex(TypeError, "'Collection' object is not iterable"): + with self.assertRaisesRegex(TypeError, msg): _ = coll.next() # Do not implement typing.Iterable. self.assertNotIsInstance(coll, Iterable) @@ -155,6 +165,13 @@ class TestCollection(IntegrationTest): @classmethod def tearDownClass(cls): + if _IS_SYNC: + cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] + else: + asyncio.run(cls.async_tearDownClass()) + + @classmethod + def async_tearDownClass(cls): cls.db.drop_collection("test_large_limit") def setUp(self): @@ -169,7 +186,9 @@ class TestCollection(IntegrationTest): with self.assertRaises(WriteConcernError): # Unsatisfiable write concern. yield Collection( - self.db, "test", write_concern=WriteConcern(w=len(client_context.nodes) + 1) + self.db, + "test", + write_concern=WriteConcern(w=len(client_context.nodes) + 1), ) else: yield self.db.test @@ -188,26 +207,22 @@ class TestCollection(IntegrationTest): # No Exception. db = client_context.client.pymongo_test db.create_test_no_wc.drop() + + def lambda_test(): + return "create_test_no_wc" not in db.list_collection_names() + + def lambda_test_2(): + return "create_test_no_wc" in db.list_collection_names() + wait_until( - lambda: "create_test_no_wc" not in db.list_collection_names(), + lambda_test, "drop create_test_no_wc collection", ) db.create_collection("create_test_no_wc") wait_until( - lambda: "create_test_no_wc" in db.list_collection_names(), + lambda_test_2, "create create_test_no_wc collection", ) - db.create_test_no_wc.drop() - with self.assertWarns( - DeprecationWarning, - msg="The `create` and `kwargs` arguments to Collection are deprecated and will be removed in PyMongo 5.0", - ): - Collection(db, name="create_test_no_wc", create=True) - wait_until( - lambda: "create_test_no_wc" in db.list_collection_names(), - "create create_test_no_wc collection", - ) - # SERVER-33317 if not client_context.is_mongos or not client_context.version.at_least(3, 7, 0): with self.assertRaises(OperationFailure): @@ -223,8 +238,10 @@ class TestCollection(IntegrationTest): def test_create_indexes(self): db = self.db - self.assertRaises(TypeError, db.test.create_indexes, "foo") - self.assertRaises(TypeError, db.test.create_indexes, ["foo"]) + with self.assertRaises(TypeError): + db.test.create_indexes("foo") # type: ignore[arg-type] + with self.assertRaises(TypeError): + db.test.create_indexes(["foo"]) # type: ignore[list-item] self.assertRaises(TypeError, IndexModel, 5) self.assertRaises(ValueError, IndexModel, []) @@ -263,7 +280,8 @@ class TestCollection(IntegrationTest): db.test.drop() db.test.insert_one({"a": 1}) db.test.insert_one({"a": 1}) - self.assertRaises(DuplicateKeyError, db.test.create_indexes, [IndexModel("a", unique=True)]) + with self.assertRaises(DuplicateKeyError): + db.test.create_indexes([IndexModel("a", unique=True)]) with self.write_concern_collection() as coll: coll.create_indexes([IndexModel("hello")]) @@ -285,8 +303,10 @@ class TestCollection(IntegrationTest): def test_create_index(self): db = self.db - self.assertRaises(TypeError, db.test.create_index, 5) - self.assertRaises(ValueError, db.test.create_index, []) + with self.assertRaises(TypeError): + db.test.create_index(5) # type: ignore[arg-type] + with self.assertRaises(ValueError): + db.test.create_index([]) db.test.drop_indexes() db.test.insert_one({}) @@ -321,7 +341,8 @@ class TestCollection(IntegrationTest): db.test.drop() db.test.insert_one({"a": 1}) db.test.insert_one({"a": 1}) - self.assertRaises(DuplicateKeyError, db.test.create_index, "a", unique=True) + with self.assertRaises(DuplicateKeyError): + db.test.create_index("a", unique=True) with self.write_concern_collection() as coll: coll.create_index([("hello", DESCENDING)]) @@ -365,12 +386,14 @@ class TestCollection(IntegrationTest): coll = self.db.test self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: - self.assertRaises(ExecutionTimeout, coll.create_index, "foo", maxTimeMS=1) - self.assertRaises( - ExecutionTimeout, coll.create_indexes, [IndexModel("foo")], maxTimeMS=1 - ) - self.assertRaises(ExecutionTimeout, coll.drop_index, "foo", maxTimeMS=1) - self.assertRaises(ExecutionTimeout, coll.drop_indexes, maxTimeMS=1) + with self.assertRaises(ExecutionTimeout): + coll.create_index("foo", maxTimeMS=1) + with self.assertRaises(ExecutionTimeout): + coll.create_indexes([IndexModel("foo")], maxTimeMS=1) + with self.assertRaises(ExecutionTimeout): + coll.drop_index("foo", maxTimeMS=1) + with self.assertRaises(ExecutionTimeout): + coll.drop_indexes(maxTimeMS=1) finally: self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") @@ -382,17 +405,17 @@ class TestCollection(IntegrationTest): def map_indexes(indexes): return {index["name"]: index for index in indexes} - indexes = list(db.test.list_indexes()) + indexes = (db.test.list_indexes()).to_list() self.assertEqual(len(indexes), 1) self.assertTrue("_id_" in map_indexes(indexes)) db.test.create_index("hello") - indexes = list(db.test.list_indexes()) + indexes = (db.test.list_indexes()).to_list() self.assertEqual(len(indexes), 2) self.assertEqual(map_indexes(indexes)["hello_1"]["key"], SON([("hello", ASCENDING)])) db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) - indexes = list(db.test.list_indexes()) + indexes = (db.test.list_indexes()).to_list() self.assertEqual(len(indexes), 3) index_map = map_indexes(indexes) self.assertEqual( @@ -401,11 +424,11 @@ class TestCollection(IntegrationTest): self.assertEqual(True, index_map["hello_-1_world_1"]["unique"]) # List indexes on a collection that does not exist. - indexes = list(db.does_not_exist.list_indexes()) + indexes = (db.does_not_exist.list_indexes()).to_list() self.assertEqual(len(indexes), 0) # List indexes on a database that does not exist. - indexes = list(self.client.db_does_not_exist.coll.list_indexes()) + indexes = (db.does_not_exist.list_indexes()).to_list() self.assertEqual(len(indexes), 0) def test_index_info(self): @@ -417,22 +440,22 @@ class TestCollection(IntegrationTest): db.test.create_index("hello") self.assertEqual(len(db.test.index_information()), 2) - self.assertEqual(db.test.index_information()["hello_1"]["key"], [("hello", ASCENDING)]) + self.assertEqual((db.test.index_information())["hello_1"]["key"], [("hello", ASCENDING)]) db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) - self.assertEqual(db.test.index_information()["hello_1"]["key"], [("hello", ASCENDING)]) + self.assertEqual((db.test.index_information())["hello_1"]["key"], [("hello", ASCENDING)]) self.assertEqual(len(db.test.index_information()), 3) self.assertEqual( [("hello", DESCENDING), ("world", ASCENDING)], - db.test.index_information()["hello_-1_world_1"]["key"], + (db.test.index_information())["hello_-1_world_1"]["key"], ) - self.assertEqual(True, db.test.index_information()["hello_-1_world_1"]["unique"]) + self.assertEqual(True, (db.test.index_information())["hello_-1_world_1"]["unique"]) def test_index_geo2d(self): db = self.db db.test.drop_indexes() self.assertEqual("loc_2d", db.test.create_index([("loc", GEO2D)])) - index_info = db.test.index_information()["loc_2d"] + index_info = (db.test.index_information())["loc_2d"] self.assertEqual([("loc", "2d")], index_info["key"]) # geoSearch was deprecated in 4.4 and removed in 5.0 @@ -441,22 +464,24 @@ class TestCollection(IntegrationTest): def test_index_haystack(self): db = self.db db.test.drop() - _id = db.test.insert_one( - {"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"} + _id = ( + db.test.insert_one({"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"}) ).inserted_id db.test.insert_one({"pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant"}) db.test.insert_one({"pos": {"long": 59.1, "lat": 87.2}, "type": "office"}) db.test.create_index([("pos", "geoHaystack"), ("type", ASCENDING)], bucketSize=1) - results = db.command( - SON( - [ - ("geoSearch", "test"), - ("near", [33, 33]), - ("maxDistance", 6), - ("search", {"type": "restaurant"}), - ("limit", 30), - ] + results = ( + db.command( + SON( + [ + ("geoSearch", "test"), + ("near", [33, 33]), + ("maxDistance", 6), + ("search", {"type": "restaurant"}), + ("limit", 30), + ] + ) ) )["results"] @@ -470,7 +495,7 @@ class TestCollection(IntegrationTest): db = self.db db.test.drop_indexes() self.assertEqual("t_text", db.test.create_index([("t", TEXT)])) - index_info = db.test.index_information()["t_text"] + index_info = (db.test.index_information())["t_text"] self.assertTrue("weights" in index_info) db.test.insert_many( @@ -482,7 +507,7 @@ class TestCollection(IntegrationTest): # Sort by 'score' field. cursor.sort([("score", {"$meta": "textScore"})]) - results = list(cursor) + results = cursor.to_list() self.assertTrue(results[0]["score"] >= results[1]["score"]) db.test.drop_indexes() @@ -492,7 +517,7 @@ class TestCollection(IntegrationTest): db.test.drop_indexes() self.assertEqual("geo_2dsphere", db.test.create_index([("geo", GEOSPHERE)])) - for dummy, info in db.test.index_information().items(): + for dummy, info in (db.test.index_information()).items(): field, idx_type = info["key"][0] if field == "geo" and idx_type == "2dsphere": break @@ -511,7 +536,7 @@ class TestCollection(IntegrationTest): db.test.drop_indexes() self.assertEqual("a_hashed", db.test.create_index([("a", HASHED)])) - for dummy, info in db.test.index_information().items(): + for dummy, info in (db.test.index_information()).items(): field, idx_type = info["key"][0] if field == "a" and idx_type == "hashed": break @@ -524,7 +549,7 @@ class TestCollection(IntegrationTest): db = self.db db.test.drop_indexes() db.test.create_index([("key", ASCENDING)], sparse=True) - self.assertTrue(db.test.index_information()["key_1"]["sparse"]) + self.assertTrue((db.test.index_information())["key_1"]["sparse"]) def test_index_background(self): db = self.db @@ -532,9 +557,9 @@ class TestCollection(IntegrationTest): db.test.create_index([("keya", ASCENDING)]) db.test.create_index([("keyb", ASCENDING)], background=False) db.test.create_index([("keyc", ASCENDING)], background=True) - self.assertFalse("background" in db.test.index_information()["keya_1"]) - self.assertFalse(db.test.index_information()["keyb_1"]["background"]) - self.assertTrue(db.test.index_information()["keyc_1"]["background"]) + self.assertFalse("background" in (db.test.index_information())["keya_1"]) + self.assertFalse((db.test.index_information())["keyb_1"]["background"]) + self.assertTrue((db.test.index_information())["keyc_1"]["background"]) def _drop_dups_setup(self, db): db.drop_collection("test") @@ -549,10 +574,11 @@ class TestCollection(IntegrationTest): self._drop_dups_setup(db) # There's a duplicate - def test_create(): + def _test_create(): db.test.create_index([("i", ASCENDING)], unique=True, dropDups=False) - self.assertRaises(DuplicateKeyError, test_create) + with self.assertRaises(DuplicateKeyError): + _test_create() # Duplicate wasn't dropped self.assertEqual(4, db.test.count_documents({})) @@ -586,16 +612,12 @@ class TestCollection(IntegrationTest): db.drop_collection("test") # Test bad filter spec on create. - self.assertRaises(OperationFailure, db.test.create_index, "x", partialFilterExpression=5) - self.assertRaises( - OperationFailure, - db.test.create_index, - "x", - partialFilterExpression={"x": {"$asdasd": 3}}, - ) - self.assertRaises( - OperationFailure, db.test.create_index, "x", partialFilterExpression={"$and": 5} - ) + with self.assertRaises(OperationFailure): + db.test.create_index("x", partialFilterExpression=5) + with self.assertRaises(OperationFailure): + db.test.create_index("x", partialFilterExpression={"x": {"$asdasd": 3}}) + with self.assertRaises(OperationFailure): + db.test.create_index("x", partialFilterExpression={"$and": 5}) self.assertEqual( "x_1", @@ -605,32 +627,32 @@ class TestCollection(IntegrationTest): db.test.insert_one({"x": 6, "a": 1}) # Operations that use the partial index. - explain = db.test.find({"x": 6, "a": 1}).explain() + explain = (db.test.find({"x": 6, "a": 1})).explain() stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") self.assertEqual("x_1", stage.get("indexName")) self.assertTrue(stage.get("isPartial")) - explain = db.test.find({"x": {"$gt": 1}, "a": 1}).explain() + explain = (db.test.find({"x": {"$gt": 1}, "a": 1})).explain() stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") self.assertEqual("x_1", stage.get("indexName")) self.assertTrue(stage.get("isPartial")) - explain = db.test.find({"x": 6, "a": {"$lte": 1}}).explain() + explain = (db.test.find({"x": 6, "a": {"$lte": 1}})).explain() stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") self.assertEqual("x_1", stage.get("indexName")) self.assertTrue(stage.get("isPartial")) # Operations that do not use the partial index. - explain = db.test.find({"x": 6, "a": {"$lte": 1.6}}).explain() + explain = (db.test.find({"x": 6, "a": {"$lte": 1.6}})).explain() stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") self.assertNotEqual({}, stage) - explain = db.test.find({"x": 6}).explain() + explain = (db.test.find({"x": 6})).explain() stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") self.assertNotEqual({}, stage) # Test drop_indexes. db.test.drop_index("x_1") - explain = db.test.find({"x": 6, "a": 1}).explain() + explain = (db.test.find({"x": 6, "a": 1})).explain() stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") self.assertNotEqual({}, stage) @@ -725,7 +747,11 @@ class TestCollection(IntegrationTest): self.assertEqual(document["_id"], result.inserted_id) self.assertFalse(result.acknowledged) # The insert failed duplicate key... - wait_until(lambda: db.test.count_documents({}) == 2, "forcing duplicate key error") + + def async_lambda(): + return db.test.count_documents({}) == 2 + + wait_until(async_lambda, "forcing duplicate key error") document = RawBSONDocument(encode({"_id": ObjectId(), "foo": "bar"})) result = db.test.insert_one(document) @@ -826,7 +852,11 @@ class TestCollection(IntegrationTest): self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) - wait_until(lambda: db.test.count_documents({}) == 0, "delete 1 documents") + + def lambda_async(): + return db.test.count_documents({}) == 0 + + wait_until(lambda_async, "delete 1 documents") def test_delete_many(self): self.db.test.drop() @@ -847,15 +877,22 @@ class TestCollection(IntegrationTest): self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) - wait_until(lambda: db.test.count_documents({}) == 0, "delete 2 documents") + + def lambda_async(): + return db.test.count_documents({}) == 0 + + wait_until(lambda_async, "delete 2 documents") def test_command_document_too_large(self): large = "*" * (client_context.max_bson_size + _COMMAND_OVERHEAD) coll = self.db.test - self.assertRaises(DocumentTooLarge, coll.insert_one, {"data": large}) + with self.assertRaises(DocumentTooLarge): + coll.insert_one({"data": large}) # update_one and update_many are the same - self.assertRaises(DocumentTooLarge, coll.replace_one, {}, {"data": large}) - self.assertRaises(DocumentTooLarge, coll.delete_one, {"data": large}) + with self.assertRaises(DocumentTooLarge): + coll.replace_one({}, {"data": large}) + with self.assertRaises(DocumentTooLarge): + coll.delete_one({"data": large}) def test_write_large_document(self): max_size = client_context.max_bson_size @@ -864,19 +901,19 @@ class TestCollection(IntegrationTest): half_str = "x" * half_size self.assertEqual(max_size, 16777216) - self.assertRaises(OperationFailure, self.db.test.insert_one, {"foo": max_str}) - self.assertRaises( - OperationFailure, self.db.test.replace_one, {}, {"foo": max_str}, upsert=True - ) - self.assertRaises(OperationFailure, self.db.test.insert_many, [{"x": 1}, {"foo": max_str}]) + with self.assertRaises(OperationFailure): + self.db.test.insert_one({"foo": max_str}) + with self.assertRaises(OperationFailure): + self.db.test.replace_one({}, {"foo": max_str}, upsert=True) + with self.assertRaises(OperationFailure): + self.db.test.insert_many([{"x": 1}, {"foo": max_str}]) self.db.test.insert_many([{"foo": half_str}, {"foo": half_str}]) self.db.test.insert_one({"bar": "x"}) # Use w=0 here to test legacy doc size checking in all server versions unack_coll = self.db.test.with_options(write_concern=WriteConcern(w=0)) - self.assertRaises( - DocumentTooLarge, unack_coll.replace_one, {"bar": "x"}, {"bar": "x" * (max_size - 14)} - ) + with self.assertRaises(DocumentTooLarge): + unack_coll.replace_one({"bar": "x"}, {"bar": "x" * (max_size - 14)}) self.db.test.replace_one({"bar": "x"}, {"bar": "x" * (max_size - 32)}) def test_insert_bypass_document_validation(self): @@ -886,7 +923,8 @@ class TestCollection(IntegrationTest): db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) # Test insert_one - self.assertRaises(OperationFailure, db.test.insert_one, {"_id": 1, "x": 100}) + with self.assertRaises(OperationFailure): + db.test.insert_one({"_id": 1, "x": 100}) result = db.test.insert_one({"_id": 1, "x": 100}, bypass_document_validation=True) self.assertTrue(isinstance(result, InsertOneResult)) self.assertEqual(1, result.inserted_id) @@ -895,11 +933,16 @@ class TestCollection(IntegrationTest): self.assertEqual(2, result.inserted_id) db_w0.test.insert_one({"y": 1}, bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"y": 1}), "find w:0 inserted document") + + def async_lambda(): + return db_w0.test.find_one({"y": 1}) + + wait_until(async_lambda, "find w:0 inserted document") # Test insert_many docs = [{"_id": i, "x": 100 - i} for i in range(3, 100)] - self.assertRaises(OperationFailure, db.test.insert_many, docs) + with self.assertRaises(OperationFailure): + db.test.insert_many(docs) result = db.test.insert_many(docs, bypass_document_validation=True) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(97, len(result.inserted_ids)) @@ -920,12 +963,11 @@ class TestCollection(IntegrationTest): self.assertEqual(1, db.test.count_documents({"a": doc["a"]})) self.assertTrue(result.acknowledged) - self.assertRaises( - OperationFailure, - db_w0.test.insert_many, - [{"x": 1}, {"x": 2}], - bypass_document_validation=True, - ) + with self.assertRaises(OperationFailure): + db_w0.test.insert_many( + [{"x": 1}, {"x": 2}], + bypass_document_validation=True, + ) def test_replace_bypass_document_validation(self): db = self.db @@ -935,7 +977,8 @@ class TestCollection(IntegrationTest): # Test replace_one db.test.insert_one({"a": 101}) - self.assertRaises(OperationFailure, db.test.replace_one, {"a": 101}, {"y": 1}) + with self.assertRaises(OperationFailure): + db.test.replace_one({"a": 101}, {"y": 1}) self.assertEqual(0, db.test.count_documents({"y": 1})) self.assertEqual(1, db.test.count_documents({"a": 101})) db.test.replace_one({"a": 101}, {"y": 1}, bypass_document_validation=True) @@ -947,7 +990,8 @@ class TestCollection(IntegrationTest): self.assertEqual(1, db.test.count_documents({"a": 102})) db.test.insert_one({"y": 1}, bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.replace_one, {"y": 1}, {"x": 101}) + with self.assertRaises(OperationFailure): + db.test.replace_one({"y": 1}, {"x": 101}) self.assertEqual(0, db.test.count_documents({"x": 101})) self.assertEqual(1, db.test.count_documents({"y": 1})) db.test.replace_one({"y": 1}, {"x": 101}, bypass_document_validation=True) @@ -959,6 +1003,7 @@ class TestCollection(IntegrationTest): db.test.insert_one({"y": 1}, bypass_document_validation=True) db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) + wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") def test_update_bypass_document_validation(self): @@ -969,7 +1014,8 @@ class TestCollection(IntegrationTest): db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) # Test update_one - self.assertRaises(OperationFailure, db.test.update_one, {"z": 5}, {"$inc": {"z": -10}}) + with self.assertRaises(OperationFailure): + db.test.update_one({"z": 5}, {"$inc": {"z": -10}}) self.assertEqual(0, db.test.count_documents({"z": -5})) self.assertEqual(1, db.test.count_documents({"z": 5})) db.test.update_one({"z": 5}, {"$inc": {"z": -10}}, bypass_document_validation=True) @@ -980,7 +1026,8 @@ class TestCollection(IntegrationTest): self.assertEqual(0, db.test.count_documents({"z": -5})) db.test.insert_one({"z": -10}, bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.update_one, {"z": -10}, {"$inc": {"z": 1}}) + with self.assertRaises(OperationFailure): + db.test.update_one({"z": -10}, {"$inc": {"z": 1}}) self.assertEqual(0, db.test.count_documents({"z": -9})) self.assertEqual(1, db.test.count_documents({"z": -10})) db.test.update_one({"z": -10}, {"$inc": {"z": 1}}, bypass_document_validation=True) @@ -992,12 +1039,17 @@ class TestCollection(IntegrationTest): db.test.insert_one({"y": 1, "x": 0}, bypass_document_validation=True) db_w0.test.update_one({"y": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"y": 1, "x": 1}), "find w:0 updated document") + + def async_lambda(): + return db_w0.test.find_one({"y": 1, "x": 1}) + + wait_until(async_lambda, "find w:0 updated document") # Test update_many db.test.insert_many([{"z": i} for i in range(3, 101)]) db.test.insert_one({"y": 0}, bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.update_many, {}, {"$inc": {"z": -100}}) + with self.assertRaises(OperationFailure): + db.test.update_many({}, {"$inc": {"z": -100}}) self.assertEqual(100, db.test.count_documents({"z": {"$gte": 0}})) self.assertEqual(0, db.test.count_documents({"z": {"$lt": 0}})) self.assertEqual(0, db.test.count_documents({"y": 0, "z": -100})) @@ -1013,7 +1065,8 @@ class TestCollection(IntegrationTest): self.assertEqual(50, db.test.count_documents({"z": {"$lt": 0}})) db.test.insert_many([{"z": -i} for i in range(50)], bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.update_many, {}, {"$inc": {"z": 1}}) + with self.assertRaises(OperationFailure): + db.test.update_many({}, {"$inc": {"z": 1}}) self.assertEqual(100, db.test.count_documents({"z": {"$lte": 0}})) self.assertEqual(50, db.test.count_documents({"z": {"$gt": 1}})) db.test.update_many( @@ -1030,9 +1083,11 @@ class TestCollection(IntegrationTest): db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) db_w0.test.update_many({"m": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) - wait_until( - lambda: db_w0.test.count_documents({"m": 1, "x": 1}) == 2, "find w:0 updated documents" - ) + + def async_lambda(): + return db_w0.test.count_documents({"m": 1, "x": 1}) == 2 + + wait_until(async_lambda, "find w:0 updated documents") def test_bypass_document_validation_bulk_write(self): db = self.db @@ -1057,11 +1112,11 @@ class TestCollection(IntegrationTest): # Assert that the operations would fail without bypass_doc_val for op in ops: - self.assertRaises(BulkWriteError, db.test.bulk_write, [op]) + with self.assertRaises(BulkWriteError): + db.test.bulk_write([op]) - self.assertRaises( - OperationFailure, db_w0.test.bulk_write, ops, bypass_document_validation=True - ) + with self.assertRaises(OperationFailure): + db_w0.test.bulk_write(ops, bypass_document_validation=True) def test_find_by_default_dct(self): db = self.db @@ -1102,8 +1157,8 @@ class TestCollection(IntegrationTest): db.test.insert_one({"x": [1, 2, 3], "mike": "awesome"}) - self.assertEqual([1, 2, 3], db.test.find_one()["x"]) - self.assertEqual([2, 3], db.test.find_one(projection={"x": {"$slice": -2}})["x"]) + self.assertEqual([1, 2, 3], (db.test.find_one())["x"]) + self.assertEqual([2, 3], (db.test.find_one(projection={"x": {"$slice": -2}}))["x"]) self.assertTrue("x" not in db.test.find_one(projection={"x": 0})) self.assertTrue("mike" in db.test.find_one(projection={"x": 0})) @@ -1116,11 +1171,11 @@ class TestCollection(IntegrationTest): db.test.insert_one({"x": "hello_mikey"}) db.test.insert_one({"x": "hello_test"}) - self.assertEqual(len(list(db.test.find())), 4) - self.assertEqual(len(list(db.test.find({"x": re.compile("^hello.*")}))), 4) - self.assertEqual(len(list(db.test.find({"x": re.compile("ello")}))), 4) - self.assertEqual(len(list(db.test.find({"x": re.compile("^hello$")}))), 0) - self.assertEqual(len(list(db.test.find({"x": re.compile("^hello_mi.*$")}))), 2) + self.assertEqual(len((db.test.find()).to_list()), 4) + self.assertEqual(len((db.test.find({"x": re.compile("^hello.*")})).to_list()), 4) + self.assertEqual(len((db.test.find({"x": re.compile("ello")})).to_list()), 4) + self.assertEqual(len((db.test.find({"x": re.compile("^hello$")})).to_list()), 0) + self.assertEqual(len((db.test.find({"x": re.compile("^hello_mi.*$")})).to_list()), 2) def test_id_can_be_anything(self): db = self.db @@ -1219,14 +1274,15 @@ class TestCollection(IntegrationTest): db.test.insert_one({"text": text}) # Should raise DuplicateKeyError, not InvalidBSON - self.assertRaises(DuplicateKeyError, db.test.insert_one, {"text": text}) + with self.assertRaises(DuplicateKeyError): + db.test.insert_one({"text": text}) - self.assertRaises( - DuplicateKeyError, db.test.replace_one, {"_id": ObjectId()}, {"text": text}, upsert=True - ) + with self.assertRaises(DuplicateKeyError): + db.test.replace_one({"_id": ObjectId()}, {"text": text}, upsert=True) # Should raise BulkWriteError, not InvalidBSON - self.assertRaises(BulkWriteError, db.test.insert_many, [{"text": text}]) + with self.assertRaises(BulkWriteError): + db.test.insert_many([{"text": text}]) def test_write_error_unicode(self): coll = self.db.test @@ -1248,10 +1304,12 @@ class TestCollection(IntegrationTest): collection.insert_one({"_id": 1}) coll = collection.with_options(write_concern=WriteConcern(w=1, wtimeout=1000)) - self.assertRaises(DuplicateKeyError, coll.insert_one, {"_id": 1}) + with self.assertRaises(DuplicateKeyError): + coll.insert_one({"_id": 1}) coll = collection.with_options(write_concern=WriteConcern(wtimeout=1000)) - self.assertRaises(DuplicateKeyError, coll.insert_one, {"_id": 1}) + with self.assertRaises(DuplicateKeyError): + coll.insert_one({"_id": 1}) def test_error_code(self): try: @@ -1277,15 +1335,17 @@ class TestCollection(IntegrationTest): db.test.insert_one({"hello": {"a": 4, "b": 5}}) db.test.insert_one({"hello": {"a": 7, "b": 2}}) - self.assertRaises(DuplicateKeyError, db.test.insert_one, {"hello": {"a": 4, "b": 10}}) + with self.assertRaises(DuplicateKeyError): + db.test.insert_one({"hello": {"a": 4, "b": 10}}) def test_replace_one(self): db = self.db db.drop_collection("test") - self.assertRaises(ValueError, lambda: db.test.replace_one({}, {"$set": {"x": 1}})) + with self.assertRaises(ValueError): + db.test.replace_one({}, {"$set": {"x": 1}}) - id1 = db.test.insert_one({"x": 1}).inserted_id + id1 = (db.test.insert_one({"x": 1})).inserted_id result = db.test.replace_one({"x": 1}, {"y": 1}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) @@ -1294,7 +1354,7 @@ class TestCollection(IntegrationTest): self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"y": 1})) self.assertEqual(0, db.test.count_documents({"x": 1})) - self.assertEqual(db.test.find_one(id1)["y"], 1) # type: ignore + self.assertEqual((db.test.find_one(id1))["y"], 1) # type: ignore replacement = RawBSONDocument(encode({"_id": id1, "z": 1})) result = db.test.replace_one({"y": 1}, replacement, True) @@ -1305,7 +1365,7 @@ class TestCollection(IntegrationTest): self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"z": 1})) self.assertEqual(0, db.test.count_documents({"y": 1})) - self.assertEqual(db.test.find_one(id1)["z"], 1) # type: ignore + self.assertEqual((db.test.find_one(id1))["z"], 1) # type: ignore result = db.test.replace_one({"x": 2}, {"y": 2}, True) self.assertTrue(isinstance(result, UpdateResult)) @@ -1327,26 +1387,27 @@ class TestCollection(IntegrationTest): db = self.db db.drop_collection("test") - self.assertRaises(ValueError, lambda: db.test.update_one({}, {"x": 1})) + with self.assertRaises(ValueError): + db.test.update_one({}, {"x": 1}) - id1 = db.test.insert_one({"x": 5}).inserted_id + id1 = (db.test.insert_one({"x": 5})).inserted_id result = db.test.update_one({}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) - self.assertEqual(db.test.find_one(id1)["x"], 6) # type: ignore + self.assertEqual((db.test.find_one(id1))["x"], 6) # type: ignore - id2 = db.test.insert_one({"x": 1}).inserted_id + id2 = (db.test.insert_one({"x": 1})).inserted_id result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) - self.assertEqual(db.test.find_one(id1)["x"], 7) # type: ignore - self.assertEqual(db.test.find_one(id2)["x"], 1) # type: ignore + self.assertEqual((db.test.find_one(id1))["x"], 7) # type: ignore + self.assertEqual((db.test.find_one(id2))["x"], 1) # type: ignore result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True) self.assertTrue(isinstance(result, UpdateResult)) @@ -1367,7 +1428,8 @@ class TestCollection(IntegrationTest): db = self.db db.drop_collection("test") - self.assertRaises(ValueError, lambda: db.test.update_many({}, {"x": 1})) + with self.assertRaises(ValueError): + db.test.update_many({}, {"x": 1}) db.test.insert_one({"x": 4, "y": 3}) db.test.insert_one({"x": 5, "y": 5}) @@ -1416,28 +1478,26 @@ class TestCollection(IntegrationTest): # I know this seems like testing the server but I'd like to be notified # by CI if the server's behavior changes here. doc = SON([("$set", {"foo.bar": "bim"}), ("hello", "world")]) - self.assertRaises( - OperationFailure, self.db.test.update_one, {"hello": "world"}, doc, upsert=True - ) + with self.assertRaises(OperationFailure): + self.db.test.update_one({"hello": "world"}, doc, upsert=True) # This is going to cause keys to be checked and raise InvalidDocument. # That's OK assuming the server's behavior in the previous assert # doesn't change. If the behavior changes checking the first key for # '$' in update won't be good enough anymore. doc = SON([("hello", "world"), ("$set", {"foo.bar": "bim"})]) - self.assertRaises( - OperationFailure, self.db.test.replace_one, {"hello": "world"}, doc, upsert=True - ) + with self.assertRaises(OperationFailure): + self.db.test.replace_one({"hello": "world"}, doc, upsert=True) # Replace with empty document - self.assertNotEqual(0, self.db.test.replace_one({"hello": "world"}, {}).matched_count) + self.assertNotEqual(0, (self.db.test.replace_one({"hello": "world"}, {})).matched_count) def test_acknowledged_delete(self): db = self.db db.drop_collection("test") db.test.insert_many([{"x": 1}, {"x": 1}]) - self.assertEqual(2, db.test.delete_many({}).deleted_count) - self.assertEqual(0, db.test.delete_many({}).deleted_count) + self.assertEqual(2, (db.test.delete_many({})).deleted_count) + self.assertEqual(0, (db.test.delete_many({})).deleted_count) @client_context.require_version_max(4, 9) def test_manual_last_error(self): @@ -1475,12 +1535,13 @@ class TestCollection(IntegrationTest): db.drop_collection("test") db.test.insert_one({"foo": [1, 2]}) - self.assertRaises(TypeError, db.test.aggregate, "wow") + with self.assertRaises(TypeError): + db.test.aggregate("wow") # type: ignore[arg-type] pipeline = {"$project": {"_id": False, "foo": True}} result = db.test.aggregate([pipeline]) self.assertTrue(isinstance(result, CommandCursor)) - self.assertEqual([{"foo": [1, 2]}], list(result)) + self.assertEqual([{"foo": [1, 2]}], result.to_list()) # Test write concern. with self.write_concern_collection() as coll: @@ -1491,7 +1552,8 @@ class TestCollection(IntegrationTest): db.drop_collection("test") db.test.insert_one({"foo": [1, 2]}) - self.assertRaises(TypeError, db.test.aggregate, "wow") + with self.assertRaises(TypeError): + db.test.aggregate("wow") # type: ignore[arg-type] pipeline = {"$project": {"_id": False, "foo": True}} coll = db.get_collection("test", codec_options=CodecOptions(document_class=RawBSONDocument)) @@ -1524,7 +1586,7 @@ class TestCollection(IntegrationTest): # Use batchSize to ensure multiple getMore messages cursor = db.test.aggregate([{"$project": {"_id": "$_id"}}], batchSize=5) - self.assertEqual(expected_sum, sum(doc["_id"] for doc in cursor)) + self.assertEqual(expected_sum, sum(doc["_id"] for doc in cursor.to_list())) # Test that batchSize is handled properly. cursor = db.test.aggregate([], batchSize=5) @@ -1558,7 +1620,8 @@ class TestCollection(IntegrationTest): with self.db.test.aggregate([], {}): # type:ignore pass - self.assertRaisesRegex(ValueError, "must be a ClientSession", try_invalid_session) + with self.assertRaisesRegex(ValueError, "must be a ClientSession"): + try_invalid_session() def test_large_limit(self): db = self.db @@ -1570,7 +1633,7 @@ class TestCollection(IntegrationTest): i = 0 y = 0 - for doc in db.test_large_limit.find(limit=1900).sort([("x", 1)]): + for doc in (db.test_large_limit.find(limit=1900)).sort([("x", 1)]): i += 1 y += doc["x"] @@ -1595,12 +1658,18 @@ class TestCollection(IntegrationTest): db.drop_collection("test") db.drop_collection("foo") - self.assertRaises(TypeError, db.test.rename, 5) - self.assertRaises(InvalidName, db.test.rename, "") - self.assertRaises(InvalidName, db.test.rename, "te$t") - self.assertRaises(InvalidName, db.test.rename, ".test") - self.assertRaises(InvalidName, db.test.rename, "test.") - self.assertRaises(InvalidName, db.test.rename, "tes..t") + with self.assertRaises(TypeError): + db.test.rename(5) # type: ignore[arg-type] + with self.assertRaises(InvalidName): + db.test.rename("") + with self.assertRaises(InvalidName): + db.test.rename("te$t") + with self.assertRaises(InvalidName): + db.test.rename(".test") + with self.assertRaises(InvalidName): + db.test.rename("test.") + with self.assertRaises(InvalidName): + db.test.rename("tes..t") self.assertEqual(0, db.test.count_documents({})) self.assertEqual(0, db.foo.count_documents({})) @@ -1620,7 +1689,8 @@ class TestCollection(IntegrationTest): x += 1 db.test.insert_one({}) - self.assertRaises(OperationFailure, db.foo.rename, "test") + with self.assertRaises(OperationFailure): + db.foo.rename("test") db.foo.rename("test", dropTarget=True) with self.write_concern_collection() as coll: @@ -1631,9 +1701,9 @@ class TestCollection(IntegrationTest): db = self.db db.drop_collection("test") - _id = db.test.insert_one({"hello": "world", "foo": "bar"}).inserted_id + _id = (db.test.insert_one({"hello": "world", "foo": "bar"})).inserted_id - self.assertEqual("world", db.test.find_one()["hello"]) + self.assertEqual("world", (db.test.find_one())["hello"]) self.assertEqual(db.test.find_one(_id), db.test.find_one()) self.assertEqual(db.test.find_one(None), db.test.find_one()) self.assertEqual(db.test.find_one({}), db.test.find_one()) @@ -1673,8 +1743,8 @@ class TestCollection(IntegrationTest): db.test.insert_many([{"x": i} for i in range(1, 4)]) - self.assertEqual(1, db.test.find_one()["x"]) - self.assertEqual(2, db.test.find_one(skip=1, limit=2)["x"]) + self.assertEqual(1, (db.test.find_one())["x"]) + self.assertEqual(2, (db.test.find_one(skip=1, limit=2))["x"]) def test_find_with_sort(self): db = self.db @@ -1682,9 +1752,9 @@ class TestCollection(IntegrationTest): db.test.insert_many([{"x": 2}, {"x": 1}, {"x": 3}]) - self.assertEqual(2, db.test.find_one()["x"]) - self.assertEqual(1, db.test.find_one(sort=[("x", 1)])["x"]) - self.assertEqual(3, db.test.find_one(sort=[("x", -1)])["x"]) + self.assertEqual(2, (db.test.find_one())["x"]) + self.assertEqual(1, (db.test.find_one(sort=[("x", 1)]))["x"]) + self.assertEqual(3, (db.test.find_one(sort=[("x", -1)]))["x"]) def to_list(things): return [thing["x"] for thing in things] @@ -1693,31 +1763,37 @@ class TestCollection(IntegrationTest): self.assertEqual([1, 2, 3], to_list(db.test.find(sort=[("x", 1)]))) self.assertEqual([3, 2, 1], to_list(db.test.find(sort=[("x", -1)]))) - self.assertRaises(TypeError, db.test.find, sort=5) - self.assertRaises(TypeError, db.test.find, sort="hello") - self.assertRaises(TypeError, db.test.find, sort=["hello", 1]) + with self.assertRaises(TypeError): + db.test.find(sort=5) + with self.assertRaises(TypeError): + db.test.find(sort="hello") + with self.assertRaises(TypeError): + db.test.find(sort=["hello", 1]) # TODO doesn't actually test functionality, just that it doesn't blow up def test_cursor_timeout(self): - list(self.db.test.find(no_cursor_timeout=True)) - list(self.db.test.find(no_cursor_timeout=False)) + (self.db.test.find(no_cursor_timeout=True)).to_list() + (self.db.test.find(no_cursor_timeout=False)).to_list() def test_exhaust(self): if is_mongos(self.db.client): - self.assertRaises(InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST) + with self.assertRaises(InvalidOperation): + self.db.test.find(cursor_type=CursorType.EXHAUST) return # Limit is incompatible with exhaust. - self.assertRaises( - InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST, limit=5 - ) + with self.assertRaises(InvalidOperation): + self.db.test.find(cursor_type=CursorType.EXHAUST, limit=5) cur = self.db.test.find(cursor_type=CursorType.EXHAUST) - self.assertRaises(InvalidOperation, cur.limit, 5) + with self.assertRaises(InvalidOperation): + cur.limit(5) cur = self.db.test.find(limit=5) - self.assertRaises(InvalidOperation, cur.add_option, 64) + with self.assertRaises(InvalidOperation): + cur.add_option(64) cur = self.db.test.find() cur.add_option(64) - self.assertRaises(InvalidOperation, cur.limit, 5) + with self.assertRaises(InvalidOperation): + cur.limit(5) self.db.drop_collection("test") # Insert enough documents to require more than one batch @@ -1752,9 +1828,9 @@ class TestCollection(IntegrationTest): else: next(cur) self.assertEqual(0, len(pool.conns)) - if sys.platform.startswith("java") or "PyPy" in sys.version: - # Don't wait for GC or use gc.collect(), it's unreliable. - cur.close() + # if sys.platform.startswith("java") or "PyPy" in sys.version: + # # Don't wait for GC or use gc.collect(), it's unreliable. + cur.close() cur = None # Wait until the background thread returns the socket. wait_until(lambda: pool.active_sockets == 0, "return socket") @@ -1772,7 +1848,7 @@ class TestCollection(IntegrationTest): self.assertEqual([1, 2, 3], distinct) - distinct = test.find({"a": {"$gt": 1}}).distinct("a") + distinct = (test.find({"a": {"$gt": 1}})).distinct("a") distinct.sort() self.assertEqual([2, 3], distinct) @@ -1798,7 +1874,7 @@ class TestCollection(IntegrationTest): self.db.test.insert_one({"bar": "foo"}) self.assertEqual(1, self.db.test.count_documents({"query": {"$ne": None}})) - self.assertEqual(1, len(list(self.db.test.find({"query": {"$ne": None}})))) + self.assertEqual(1, len((self.db.test.find({"query": {"$ne": None}})).to_list())) def test_min_query(self): self.db.drop_collection("test") @@ -1807,7 +1883,7 @@ class TestCollection(IntegrationTest): cursor = self.db.test.find({"$min": {"x": 2}, "$query": {}}, hint="x_1") - docs = list(cursor) + docs = cursor.to_list() self.assertEqual(1, len(docs)) self.assertEqual(2, docs[0]["x"]) @@ -1857,9 +1933,11 @@ class TestCollection(IntegrationTest): # 2 batches, 2nd insert fails, unacknowledged, ordered. unack_coll = db.collection_2.with_options(write_concern=WriteConcern(w=0)) unack_coll.insert_many(insert_second_fails) - wait_until( - lambda: db.collection_2.count_documents({}) == 1, "insert 1 document", timeout=60 - ) + + def async_lambda(): + return db.collection_2.count_documents({}) == 1 + + wait_until(async_lambda, "insert 1 document", timeout=60) db.collection_2.drop() @@ -1886,10 +1964,11 @@ class TestCollection(IntegrationTest): unack_coll = db.collection_4.with_options(write_concern=WriteConcern(w=0)) unack_coll.insert_many(insert_two_failures, ordered=False) + def async_lambda(): + return db.collection_4.count_documents({}) == 2 + # Only the first and third documents are inserted. - wait_until( - lambda: db.collection_4.count_documents({}) == 2, "insert 2 documents", timeout=60 - ) + wait_until(async_lambda, "insert 2 documents", timeout=60) db.collection_4.drop() @@ -1900,7 +1979,7 @@ class TestCollection(IntegrationTest): db["Employés"].replace_one({"x": 1}, {"x": 2}) db["Employés"].delete_many({}) db["Employés"].find_one() - list(db["Employés"].find()) + (db["Employés"].find()).to_list() def test_drop_indexes_non_existent(self): self.db.drop_collection("test") @@ -1911,7 +1990,8 @@ class TestCollection(IntegrationTest): def test_bad_encode(self): c = self.db.test c.drop() - self.assertRaises(InvalidDocument, c.insert_one, {"x": c}) + with self.assertRaises(InvalidDocument): + c.insert_one({"x": c}) class BadGetAttr(dict): def __getattr__(self, name): @@ -1919,7 +1999,7 @@ class TestCollection(IntegrationTest): bad = BadGetAttr([("foo", "bar")]) c.insert_one({"bad": bad}) - self.assertEqual("bar", c.find_one()["bad"]["foo"]) # type: ignore + self.assertEqual("bar", (c.find_one())["bad"]["foo"]) # type: ignore def test_array_filters_validation(self): # array_filters must be a list. @@ -1992,11 +2072,11 @@ class TestCollection(IntegrationTest): c.insert_one({"j": j, "i": 0}) sort = [("j", DESCENDING)] - self.assertEqual(4, c.find_one_and_update({}, {"$inc": {"i": 1}}, sort=sort)["j"]) + self.assertEqual(4, (c.find_one_and_update({}, {"$inc": {"i": 1}}, sort=sort))["j"]) def test_find_one_and_write_concern(self): listener = EventListener() - db = single_client(event_listeners=[listener])[self.db.name] + db = (single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. @@ -2021,24 +2101,16 @@ class TestCollection(IntegrationTest): c_wc_error = db.get_collection( "test", write_concern=WriteConcern(w=len(client_context.nodes) + 1) ) - self.assertRaises( - WriteConcernError, - c_wc_error.find_one_and_update, - {"_id": 1}, - {"$set": {"foo": "bar"}}, - ) - self.assertRaises( - WriteConcernError, - c_wc_error.find_one_and_replace, - {"w": 0}, - listener.started_events[0].command["writeConcern"], - ) - self.assertRaises( - WriteConcernError, - c_wc_error.find_one_and_delete, - {"w": 0}, - listener.started_events[0].command["writeConcern"], - ) + with self.assertRaises(WriteConcernError): + c_wc_error.find_one_and_update({"_id": 1}, {"$set": {"foo": "bar"}}) + with self.assertRaises(WriteConcernError): + c_wc_error.find_one_and_replace( + {"w": 0}, listener.started_events[0].command["writeConcern"] + ) + with self.assertRaises(WriteConcernError): + c_wc_error.find_one_and_delete( + {"w": 0}, listener.started_events[0].command["writeConcern"] + ) listener.reset() c_default.find_one_and_update({"_id": 1}, {"$set": {"foo": "bar"}}) @@ -2116,7 +2188,7 @@ class TestCollection(IntegrationTest): c.drop() c.insert_one({"r": re.compile(".*")}) - self.assertTrue(isinstance(c.find_one()["r"], Regex)) # type: ignore + self.assertTrue(isinstance((c.find_one())["r"], Regex)) # type: ignore for doc in c.find(): self.assertTrue(isinstance(doc["r"], Regex)) diff --git a/test/test_encryption.py b/test/test_encryption.py index 9306876d1..0423f748c 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -36,6 +36,12 @@ from pymongo.synchronous.collection import Collection sys.path[0:0] = [""] from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) +from test.helpers import ( AWS_CREDS, AZURE_CREDS, CA_PEM, @@ -43,10 +49,6 @@ from test import ( GCP_CREDS, KMIP_CREDS, LOCAL_MASTER_KEY, - IntegrationTest, - PyMongoTestCase, - client_context, - unittest, ) from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 3a8bf6275..29cde7e07 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -24,7 +24,8 @@ import warnings sys.path[0:0] = [""] -from test import clear_warning_registry, unittest +from test import unittest +from test.helpers import clear_warning_registry from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate from pymongo.compression_support import _have_snappy diff --git a/test/unified_format.py b/test/unified_format.py index 50190982c..cb97653b0 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -31,6 +31,11 @@ import traceback import types from collections import abc, defaultdict from test import ( + IntegrationTest, + client_context, + unittest, +) +from test.helpers import ( AWS_CREDS, AWS_CREDS_2, AZURE_CREDS, @@ -39,9 +44,6 @@ from test import ( GCP_CREDS, KMIP_CREDS, LOCAL_MASTER_KEY, - IntegrationTest, - client_context, - unittest, ) from test.utils import ( CMAPListener, diff --git a/tools/synchro.py b/tools/synchro.py index 1c555748f..2a874bd18 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -40,6 +40,7 @@ replacements = { "async_receive_message": "receive_message", "async_sendall": "sendall", "asynchronous": "synchronous", + "Asynchronous": "Synchronous", "anext": "next", "_ALock": "_Lock", "_ACondition": "_Condition", @@ -60,6 +61,7 @@ replacements = { "AsyncTestCollection": "TestCollection", "AsyncIntegrationTest": "IntegrationTest", "AsyncPyMongoTestCase": "PyMongoTestCase", + "AsyncMockClientTest": "MockClientTest", "async_client_context": "client_context", "async_setup": "setup", "asyncSetUp": "setUp", @@ -100,7 +102,7 @@ _test_base = "./test/asynchronous/" _pymongo_dest_base = "./pymongo/synchronous/" _gridfs_dest_base = "./gridfs/synchronous/" -_test_dest_base = "./test/synchronous/" +_test_dest_base = "./test/" async_files = [ @@ -125,8 +127,15 @@ sync_gridfs_files = [ if (Path(_gridfs_dest_base) / f).is_file() ] +# Add each asynchronized test here as part of the converting PR +converted_tests = [ + "__init__.py", + "conftest.py", + "test_collection.py", +] + sync_test_files = [ - _test_dest_base + f for f in listdir(_test_dest_base) if (Path(_test_dest_base) / f).is_file() + _test_dest_base + f for f in converted_tests if (Path(_test_dest_base) / f).is_file() ] @@ -245,7 +254,7 @@ def translate_docstrings(lines: list[str]) -> list[str]: if "An asynchronous" in lines[i]: lines[i] = lines[i].replace("An asynchronous", "A") lines[i] = lines[i].replace(k, replacements[k]) - if "Sync" in lines[i] and replacements[k] in lines[i]: + if "Sync" in lines[i] and "Synchronous" not in lines[i] and replacements[k] in lines[i]: lines[i] = lines[i].replace("Sync", "") for i in range(len(lines)): for k in docstring_replacements: # type: ignore[assignment] diff --git a/tools/synchro.sh b/tools/synchro.sh index f5e7ab68c..2887509fe 100644 --- a/tools/synchro.sh +++ b/tools/synchro.sh @@ -1,5 +1,5 @@ #!/bin/bash -eu python ./tools/synchro.py -python -m ruff check pymongo/synchronous/ gridfs/synchronous/ test/synchronous --fix --silent -python -m ruff format pymongo/synchronous/ gridfs/synchronous/ test/synchronous --silent +python -m ruff check pymongo/synchronous/ gridfs/synchronous/ test/ --fix --silent +python -m ruff format pymongo/synchronous/ gridfs/synchronous/ test/ --silent diff --git a/tox.ini b/tox.ini index a154bf424..3295f4014 100644 --- a/tox.ini +++ b/tox.ini @@ -66,7 +66,6 @@ extras = test commands = pytest -v --durations=5 --maxfail=10 {posargs} - pytest -v --durations=5 --maxfail=10 test/synchronous/ {posargs} [testenv:test-async] description = run base set of async unit tests with no extra functionality