diff --git a/.evergreen/generated_configs/tasks.yml b/.evergreen/generated_configs/tasks.yml index c666c6901..37b7a622d 100644 --- a/.evergreen/generated_configs/tasks.yml +++ b/.evergreen/generated_configs/tasks.yml @@ -227,6 +227,186 @@ tasks: - noauth - nossl - sync_async + - name: test-4.2-standalone-auth-ssl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - standalone + - auth + - ssl + - sync + - name: test-4.2-standalone-auth-ssl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - standalone + - auth + - ssl + - async + - name: test-4.2-standalone-auth-ssl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - standalone + - auth + - ssl + - sync_async + - name: test-4.2-standalone-noauth-ssl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - standalone + - noauth + - ssl + - sync + - name: test-4.2-standalone-noauth-ssl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - standalone + - noauth + - ssl + - async + - name: test-4.2-standalone-noauth-ssl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + 
VERSION: "4.2" + TOPOLOGY: server + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - standalone + - noauth + - ssl + - sync_async + - name: test-4.2-standalone-noauth-nossl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - standalone + - noauth + - nossl + - sync + - name: test-4.2-standalone-noauth-nossl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - standalone + - noauth + - nossl + - async + - name: test-4.2-standalone-noauth-nossl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: server + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - standalone + - noauth + - nossl + - sync_async - name: test-4.4-standalone-auth-ssl-sync commands: - func: bootstrap mongo-orchestration @@ -1667,6 +1847,186 @@ tasks: - noauth - nossl - sync_async + - name: test-4.2-replica_set-auth-ssl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - replica_set + - auth + - ssl + - sync + - name: test-4.2-replica_set-auth-ssl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - replica_set + - auth 
+ - ssl + - async + - name: test-4.2-replica_set-auth-ssl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - replica_set + - auth + - ssl + - sync_async + - name: test-4.2-replica_set-noauth-ssl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - replica_set + - noauth + - ssl + - sync + - name: test-4.2-replica_set-noauth-ssl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - replica_set + - noauth + - ssl + - async + - name: test-4.2-replica_set-noauth-ssl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - replica_set + - noauth + - ssl + - sync_async + - name: test-4.2-replica_set-noauth-nossl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - replica_set + - noauth + - nossl + - sync + - name: test-4.2-replica_set-noauth-nossl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - replica_set + - 
noauth + - nossl + - async + - name: test-4.2-replica_set-noauth-nossl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: replica_set + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - replica_set + - noauth + - nossl + - sync_async - name: test-4.4-replica_set-auth-ssl-sync commands: - func: bootstrap mongo-orchestration @@ -3107,6 +3467,186 @@ tasks: - noauth - nossl - sync_async + - name: test-4.2-sharded_cluster-auth-ssl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - sharded_cluster + - auth + - ssl + - sync + - name: test-4.2-sharded_cluster-auth-ssl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - sharded_cluster + - auth + - ssl + - async + - name: test-4.2-sharded_cluster-auth-ssl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: auth + SSL: ssl + - func: run tests + vars: + AUTH: auth + SSL: ssl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - sharded_cluster + - auth + - ssl + - sync_async + - name: test-4.2-sharded_cluster-noauth-ssl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - sharded_cluster + - noauth + - ssl + - sync + - name: test-4.2-sharded_cluster-noauth-ssl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: 
"4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - sharded_cluster + - noauth + - ssl + - async + - name: test-4.2-sharded_cluster-noauth-ssl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: ssl + - func: run tests + vars: + AUTH: noauth + SSL: ssl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - sharded_cluster + - noauth + - ssl + - sync_async + - name: test-4.2-sharded_cluster-noauth-nossl-sync + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync + TEST_SUITES: default + tags: + - "4.2" + - sharded_cluster + - noauth + - nossl + - sync + - name: test-4.2-sharded_cluster-noauth-nossl-async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: async + TEST_SUITES: default_async + tags: + - "4.2" + - sharded_cluster + - noauth + - nossl + - async + - name: test-4.2-sharded_cluster-noauth-nossl-sync_async + commands: + - func: bootstrap mongo-orchestration + vars: + VERSION: "4.2" + TOPOLOGY: sharded_cluster + AUTH: noauth + SSL: nossl + - func: run tests + vars: + AUTH: noauth + SSL: nossl + SYNC: sync_async + TEST_SUITES: "" + tags: + - "4.2" + - sharded_cluster + - noauth + - nossl + - sync_async - name: test-4.4-sharded_cluster-auth-ssl-sync commands: - func: bootstrap mongo-orchestration diff --git a/.evergreen/generated_configs/variants.yml b/.evergreen/generated_configs/variants.yml index 79c9b22c9..20b89f7e6 100644 --- a/.evergreen/generated_configs/variants.yml +++ b/.evergreen/generated_configs/variants.yml @@ -817,10 +817,23 @@ buildvariants: 
PYTHON_BINARY: /opt/python/3.13/bin/python3 # Ocsp tests - - name: ocsp-rhel8-v4.4-python3.9 + - name: ocsp-rhel8-v4.2-python3.9 tasks: - name: .ocsp - display_name: OCSP RHEL8 v4.4 Python3.9 + display_name: OCSP RHEL8 v4.2 Python3.9 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "4.2" + PYTHON_BINARY: /opt/python/3.9/bin/python3 + - name: ocsp-rhel8-v4.4-python3.10 + tasks: + - name: .ocsp + display_name: OCSP RHEL8 v4.4 Python3.10 run_on: - rhel87-small batchtime: 20160 @@ -829,11 +842,11 @@ buildvariants: SSL: ssl TOPOLOGY: server VERSION: "4.4" - PYTHON_BINARY: /opt/python/3.9/bin/python3 - - name: ocsp-rhel8-v5.0-python3.10 + PYTHON_BINARY: /opt/python/3.10/bin/python3 + - name: ocsp-rhel8-v5.0-python3.11 tasks: - name: .ocsp - display_name: OCSP RHEL8 v5.0 Python3.10 + display_name: OCSP RHEL8 v5.0 Python3.11 run_on: - rhel87-small batchtime: 20160 @@ -842,11 +855,11 @@ buildvariants: SSL: ssl TOPOLOGY: server VERSION: "5.0" - PYTHON_BINARY: /opt/python/3.10/bin/python3 - - name: ocsp-rhel8-v6.0-python3.11 + PYTHON_BINARY: /opt/python/3.11/bin/python3 + - name: ocsp-rhel8-v6.0-python3.12 tasks: - name: .ocsp - display_name: OCSP RHEL8 v6.0 Python3.11 + display_name: OCSP RHEL8 v6.0 Python3.12 run_on: - rhel87-small batchtime: 20160 @@ -855,11 +868,11 @@ buildvariants: SSL: ssl TOPOLOGY: server VERSION: "6.0" - PYTHON_BINARY: /opt/python/3.11/bin/python3 - - name: ocsp-rhel8-v7.0-python3.12 + PYTHON_BINARY: /opt/python/3.12/bin/python3 + - name: ocsp-rhel8-v7.0-python3.13 tasks: - name: .ocsp - display_name: OCSP RHEL8 v7.0 Python3.12 + display_name: OCSP RHEL8 v7.0 Python3.13 run_on: - rhel87-small batchtime: 20160 @@ -868,11 +881,11 @@ buildvariants: SSL: ssl TOPOLOGY: server VERSION: "7.0" - PYTHON_BINARY: /opt/python/3.12/bin/python3 - - name: ocsp-rhel8-v8.0-python3.13 + PYTHON_BINARY: /opt/python/3.13/bin/python3 + - name: ocsp-rhel8-v8.0-pypy3.10 tasks: - name: .ocsp - display_name: 
OCSP RHEL8 v8.0 Python3.13 + display_name: OCSP RHEL8 v8.0 PyPy3.10 run_on: - rhel87-small batchtime: 20160 @@ -881,11 +894,11 @@ buildvariants: SSL: ssl TOPOLOGY: server VERSION: "8.0" - PYTHON_BINARY: /opt/python/3.13/bin/python3 - - name: ocsp-rhel8-rapid-pypy3.10 + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + - name: ocsp-rhel8-rapid-python3.9 tasks: - name: .ocsp - display_name: OCSP RHEL8 rapid PyPy3.10 + display_name: OCSP RHEL8 rapid Python3.9 run_on: - rhel87-small batchtime: 20160 @@ -894,11 +907,11 @@ buildvariants: SSL: ssl TOPOLOGY: server VERSION: rapid - PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 - - name: ocsp-rhel8-latest-python3.9 + PYTHON_BINARY: /opt/python/3.9/bin/python3 + - name: ocsp-rhel8-latest-python3.10 tasks: - name: .ocsp - display_name: OCSP RHEL8 latest Python3.9 + display_name: OCSP RHEL8 latest Python3.10 run_on: - rhel87-small batchtime: 20160 @@ -907,7 +920,7 @@ buildvariants: SSL: ssl TOPOLOGY: server VERSION: latest - PYTHON_BINARY: /opt/python/3.9/bin/python3 + PYTHON_BINARY: /opt/python/3.10/bin/python3 - name: ocsp-win64-v4.4-python3.9 tasks: - name: .ocsp-rsa !.ocsp-staple @@ -1066,6 +1079,19 @@ buildvariants: PYTHON_BINARY: /opt/python/3.9/bin/python3 # Server tests + - name: test-rhel8-python3.9-cov-no-c + tasks: + - name: .standalone .sync_async + - name: .replica_set .sync_async + - name: .sharded_cluster .sync_async + display_name: "* Test RHEL8 Python3.9 cov No C" + run_on: + - rhel87-small + expansions: + COVERAGE: coverage + NO_EXT: "1" + PYTHON_BINARY: /opt/python/3.9/bin/python3 + tags: [coverage_tag] - name: test-rhel8-python3.9-cov tasks: - name: .standalone .sync_async @@ -1078,6 +1104,19 @@ buildvariants: COVERAGE: coverage PYTHON_BINARY: /opt/python/3.9/bin/python3 tags: [coverage_tag] + - name: test-rhel8-python3.13-cov-no-c + tasks: + - name: .standalone .sync_async + - name: .replica_set .sync_async + - name: .sharded_cluster .sync_async + display_name: "* Test RHEL8 Python3.13 cov No C" + 
run_on: + - rhel87-small + expansions: + COVERAGE: coverage + NO_EXT: "1" + PYTHON_BINARY: /opt/python/3.13/bin/python3 + tags: [coverage_tag] - name: test-rhel8-python3.13-cov tasks: - name: .standalone .sync_async @@ -1090,6 +1129,19 @@ buildvariants: COVERAGE: coverage PYTHON_BINARY: /opt/python/3.13/bin/python3 tags: [coverage_tag] + - name: test-rhel8-pypy3.10-cov-no-c + tasks: + - name: .standalone .sync_async + - name: .replica_set .sync_async + - name: .sharded_cluster .sync_async + display_name: "* Test RHEL8 PyPy3.10 cov No C" + run_on: + - rhel87-small + expansions: + COVERAGE: coverage + NO_EXT: "1" + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 + tags: [coverage_tag] - name: test-rhel8-pypy3.10-cov tasks: - name: .standalone .sync_async @@ -1338,6 +1390,7 @@ buildvariants: - name: storage-inmemory-rhel8-python3.9 tasks: - name: .standalone .noauth .nossl .4.0 .sync_async + - name: .standalone .noauth .nossl .4.2 .sync_async - name: .standalone .noauth .nossl .4.4 .sync_async - name: .standalone .noauth .nossl .5.0 .sync_async - name: .standalone .noauth .nossl .6.0 .sync_async diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py index e9624ab10..133783637 100644 --- a/.evergreen/scripts/generate_config.py +++ b/.evergreen/scripts/generate_config.py @@ -26,7 +26,7 @@ from shrub.v3.shrub_service import ShrubService # Globals ############## -ALL_VERSIONS = ["4.0", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"] +ALL_VERSIONS = ["4.0", "4.2", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"] CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] PYPYS = ["pypy3.10"] ALL_PYTHONS = CPYTHONS + PYPYS @@ -279,8 +279,9 @@ def create_server_variants() -> list[BuildVariant]: host = DEFAULT_HOST # Prefix the display name with an asterisk so it is sorted first. 
base_display_name = "* Test" - for python in [*MIN_MAX_PYTHON, PYPYS[-1]]: + for python, c_ext in product([*MIN_MAX_PYTHON, PYPYS[-1]], C_EXTS): expansions = dict(COVERAGE="coverage") + handle_c_ext(c_ext, expansions) display_name = get_display_name(base_display_name, host, python=python, **expansions) variant = create_variant( [f".{t} .sync_async" for t in TOPOLOGIES], diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 536110fcf..f67077e57 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -178,7 +178,7 @@ documentation including narrative docs, and the [Sphinx docstring format](https: You can build the documentation locally by running: ```bash -just docs-build +just docs ``` When updating docs, it can be helpful to run the live docs server as: @@ -261,6 +261,11 @@ To prevent the `synchro` hook from accidentally overwriting code, it first check of a file is changing and not its async counterpart, and will fail. In the unlikely scenario that you want to override this behavior, first export `OVERRIDE_SYNCHRO_CHECK=1`. +Sometimes, the `synchro` hook will fail and introduce changes to many previously unmodified files. This is due to static +Python errors, such as missing imports, incorrect syntax, or other fatal typos. To resolve these issues, +run `pre-commit run --all-files --hook-stage manual ruff` and fix all reported errors before running the `synchro` +hook again. + ## Converting a test to async The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a partially-converted asynchronous version of the same name to the `test/asynchronous` directory. diff --git a/doc/async-tutorial.rst b/doc/async-tutorial.rst index 2ccf011d8..7a3a98711 100644 --- a/doc/async-tutorial.rst +++ b/doc/async-tutorial.rst @@ -420,3 +420,10 @@ the collection: DuplicateKeyError: E11000 duplicate key error index: test_database.profiles.$user_id_1 dup key: { : 212 } .. 
seealso:: The MongoDB documentation on `indexes <https://mongodb.com/docs/manual/indexes/>`_ + +Task Cancellation +----------------- +`Cancelling <https://docs.python.org/3/library/asyncio-task.html#task-cancellation>`_ an asyncio Task +that is running a PyMongo operation is treated as a fatal interrupt. Any connections, cursors, and transactions +involved in a cancelled Task will be safely closed and cleaned up as part of the cancellation. If those resources are +also used elsewhere, attempting to utilize them after the cancellation will result in an error. diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index d15713c51..baa88d480 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -231,7 +231,7 @@ class AsyncGridFS: try: doc = await anext(cursor) return AsyncGridOut(self._collection, file_document=doc, session=session) - except StopIteration: + except StopAsyncIteration: raise NoFile("no version %d for filename %r" % (version, filename)) from None async def get_last_version( diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index 719020c40..f405e9116 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -391,7 +391,8 @@ class AsyncChangeStream(Generic[_DocumentType]): if not _resumable(exc) and not exc.timeout: await self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 4c5171a35..e9548b0ec 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -21,7 +21,7 @@ Causally Consistent Reads .. 
code-block:: python - with client.start_session(causal_consistency=True) as session: + async with client.start_session(causal_consistency=True) as session: collection = client.db.collection await collection.update_one({"_id": 1}, {"$set": {"x": 10}}, session=session) secondary_c = collection.with_options(read_preference=ReadPreference.SECONDARY) @@ -53,8 +53,8 @@ operation: orders = client.db.orders inventory = client.db.inventory - with client.start_session() as session: - async with session.start_transaction(): + async with client.start_session() as session: + async with await session.start_transaction(): await orders.insert_one({"sku": "abc123", "qty": 100}, session=session) await inventory.update_one( {"sku": "abc123", "qty": {"$gte": 100}}, @@ -62,7 +62,7 @@ operation: session=session, ) -Upon normal completion of ``async with session.start_transaction()`` block, the +Upon normal completion of ``async with await session.start_transaction()`` block, the transaction automatically calls :meth:`AsyncClientSession.commit_transaction`. If the block exits with an exception, the transaction automatically calls :meth:`AsyncClientSession.abort_transaction`. @@ -113,7 +113,7 @@ replica set secondaries. .. code-block:: python # Each read using this session reads data from the same point in time. 
- with client.start_session(snapshot=True) as session: + async with client.start_session(snapshot=True) as session: order = await orders.find_one({"sku": "abc123"}, session=session) inventory = await inventory.find_one({"sku": "abc123"}, session=session) @@ -619,7 +619,7 @@ class AsyncClientSession: await inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}}, {"$inc": {"qty": -100}}, session=session) - with client.start_session() as session: + async with client.start_session() as session: await session.with_transaction(callback) To pass arbitrary arguments to the ``callback``, wrap your callable @@ -628,7 +628,7 @@ class AsyncClientSession: async def callback(session, custom_arg, custom_kwarg=None): # Transaction operations... - with client.start_session() as session: + async with client.start_session() as session: await session.with_transaction( lambda s: callback(s, "custom_arg", custom_kwarg=1)) @@ -697,7 +697,8 @@ class AsyncClientSession: ) try: ret = await callback(self) - except Exception as exc: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException as exc: if self.in_transaction: await self.abort_transaction() if ( diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 9101197ce..1b25bf4ee 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1126,7 +1126,8 @@ class AsyncCursor(Generic[_DocumentType]): self._killed = True await self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index f777104cf..9d3ea6719 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -127,8 +127,6 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. 
raise - except asyncio.CancelledError: - raise except Exception as exc: raise EncryptionError(exc) from exc @@ -766,8 +764,6 @@ class AsyncClientEncryption(Generic[_DocumentType]): await database.create_collection(name=name, **kwargs), encrypted_fields, ) - except asyncio.CancelledError: - raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 365fc6210..37be9a194 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -276,7 +276,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): :param type_registry: instance of :class:`~bson.codec_options.TypeRegistry` to enable encoding and decoding of custom types. - :param datetime_conversion: Specifies how UTC datetimes should be decoded + :param kwargs: **Additional optional parameters available as keyword arguments:** + + - `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded within BSON. Valid options include 'datetime_ms' to return as a DatetimeMS, 'datetime' to return as a datetime.datetime and raising a ValueError for out-of-range values, 'datetime_auto' to @@ -284,9 +286,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): out-of-range and 'datetime_clamp' to clamp to the minimum and maximum possible datetimes. Defaults to 'datetime'. See :ref:`handling-out-of-range-datetimes` for details. - - | **Other optional parameters can be passed as keyword arguments:** - - `directConnection` (optional): if ``True``, forces this client to connect directly to the specified MongoDB host as a standalone. 
If ``false``, the client connects to the entire replica set of @@ -2044,8 +2043,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_id, conn_mgr in pinned_cursors: try: await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2060,8 +2057,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_ids in address_to_cursor_ids.items(): try: await self._kill_cursors(cursor_ids, address, topology, session=None) - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2076,8 +2071,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): try: await self._process_kill_cursors() await self._topology.update_pool() - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index abde7a905..15289af4d 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -262,8 +262,6 @@ class Monitor(MonitorBase): details = cast(Mapping[str, Any], exc.details) await self._topology.receive_cluster_time(details.get("$clusterTime")) raise - except asyncio.CancelledError: - raise except ReferenceError: raise except Exception as error: @@ -429,8 +427,6 @@ class SrvMonitor(MonitorBase): if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception - except asyncio.CancelledError: - raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -494,8 +490,6 @@ class _RttMonitor(MonitorBase): except ReferenceError: # Topology was garbage-collected. 
await self.close() - except asyncio.CancelledError: - raise except Exception: await self._pool.reset() diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index bf2f2b494..1da695c5c 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -559,7 +559,7 @@ class AsyncConnection: ) except (OperationFailure, NotPrimaryError): raise - # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) @@ -576,6 +576,7 @@ class AsyncConnection: try: await async_sendall(self.conn, message) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -586,6 +587,7 @@ class AsyncConnection: """ try: return await receive_message(self, request_id, self.max_message_size) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -704,8 +706,6 @@ class AsyncConnection: # shutdown. try: self.conn.close() - except asyncio.CancelledError: - raise except Exception: # noqa: S110 pass @@ -1269,6 +1269,7 @@ class Pool: try: sock = await _configured_socket(self.address, self.opts) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: async with self.lock: self.active_contexts.discard(tmp_context) @@ -1308,6 +1309,7 @@ class Pool: handler.contribute_socket(conn, completed_handshake=False) await conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: async with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1369,6 +1371,7 @@ class Pool: async with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: # Exception in caller. 
Ensure the connection gets returned. # Note that when pinned is True, the session owns the @@ -1515,6 +1518,7 @@ class Pool: async with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: if conn: # We checked out a socket but authentication failed. diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index f51a98872..5f54b243e 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -100,6 +100,7 @@ class AsyncPeriodicExecutor: if not await self._target(): self._stopped = True break + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: self._stopped = True raise @@ -232,6 +233,7 @@ class PeriodicExecutor: if not self._target(): self._stopped = True break + # Catch KeyboardInterrupt, etc. and cleanup. except BaseException: with self._lock: self._stopped = True diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index a971ad08c..43aab39ee 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -389,7 +389,8 @@ class ChangeStream(Generic[_DocumentType]): if not _resumable(exc) and not exc.timeout: self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: self.close() raise diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 298dd7b35..af7ff59b3 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -694,7 +694,8 @@ class ClientSession: self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) - except Exception as exc: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. 
+ except BaseException as exc: if self.in_transaction: self.abort_transaction() if ( diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index cda093ee1..31c4604f8 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1124,7 +1124,8 @@ class Cursor(Generic[_DocumentType]): self._killed = True self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: self.close() raise diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 59f38e191..7cbac1c50 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -15,7 +15,6 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations -import asyncio import contextlib import enum import socket @@ -127,8 +126,6 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise - except asyncio.CancelledError: - raise except Exception as exc: raise EncryptionError(exc) from exc @@ -760,8 +757,6 @@ class ClientEncryption(Generic[_DocumentType]): database.create_collection(name=name, **kwargs), encrypted_fields, ) - except asyncio.CancelledError: - raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 8cd08ab72..373deabd4 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -274,7 +274,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): :param type_registry: instance of :class:`~bson.codec_options.TypeRegistry` to enable encoding and decoding of custom types. 
- :param datetime_conversion: Specifies how UTC datetimes should be decoded + :param kwargs: **Additional optional parameters available as keyword arguments:** + + - `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded within BSON. Valid options include 'datetime_ms' to return as a DatetimeMS, 'datetime' to return as a datetime.datetime and raising a ValueError for out-of-range values, 'datetime_auto' to @@ -282,9 +284,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): out-of-range and 'datetime_clamp' to clamp to the minimum and maximum possible datetimes. Defaults to 'datetime'. See :ref:`handling-out-of-range-datetimes` for details. - - | **Other optional parameters can be passed as keyword arguments:** - - `directConnection` (optional): if ``True``, forces this client to connect directly to the specified MongoDB host as a standalone. If ``false``, the client connects to the entire replica set of @@ -2038,8 +2037,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_id, conn_mgr in pinned_cursors: try: self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2054,8 +2051,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_ids in address_to_cursor_ids.items(): try: self._kill_cursors(cursor_ids, address, topology, session=None) - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2070,8 +2065,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): try: self._process_kill_cursors() self._topology.update_pool() - except asyncio.CancelledError: - raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git 
a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 211635d8b..802ba4742 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -260,8 +260,6 @@ class Monitor(MonitorBase): details = cast(Mapping[str, Any], exc.details) self._topology.receive_cluster_time(details.get("$clusterTime")) raise - except asyncio.CancelledError: - raise except ReferenceError: raise except Exception as error: @@ -427,8 +425,6 @@ class SrvMonitor(MonitorBase): if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception - except asyncio.CancelledError: - raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -492,8 +488,6 @@ class _RttMonitor(MonitorBase): except ReferenceError: # Topology was garbage-collected. self.close() - except asyncio.CancelledError: - raise except Exception: self._pool.reset() diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 05f930d48..978f0ae39 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -559,7 +559,7 @@ class Connection: ) except (OperationFailure, NotPrimaryError): raise - # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) @@ -576,6 +576,7 @@ class Connection: try: sendall(self.conn, message) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -586,6 +587,7 @@ class Connection: """ try: return receive_message(self, request_id, self.max_message_size) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -702,8 +704,6 @@ class Connection: # shutdown. 
try: self.conn.close() - except asyncio.CancelledError: - raise except Exception: # noqa: S110 pass @@ -1263,6 +1263,7 @@ class Pool: try: sock = _configured_socket(self.address, self.opts) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: with self.lock: self.active_contexts.discard(tmp_context) @@ -1302,6 +1303,7 @@ class Pool: handler.contribute_socket(conn, completed_handshake=False) conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1363,6 +1365,7 @@ class Pool: with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the @@ -1509,6 +1512,7 @@ class Pool: with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: if conn: # We checked out a socket but authentication failed. diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index ed2395bc9..40beb3c0d 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -66,7 +66,7 @@ class DummyMonitor: def cancel_check(self): pass - def join(self): + async def join(self): pass def open(self): @@ -75,7 +75,7 @@ class DummyMonitor: def request_check(self): pass - def close(self): + async def close(self): self.opened = False diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py new file mode 100644 index 000000000..b73c7a808 --- /dev/null +++ b/test/asynchronous/test_async_cancellation.py @@ -0,0 +1,126 @@ +# Copyright 2025-present MongoDB, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that async cancellation performed by users clean up resources correctly.""" +from __future__ import annotations + +import asyncio +import sys +from test.utils import async_get_pool, delay, one + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, connected + + +class TestAsyncCancellation(AsyncIntegrationTest): + async def test_async_cancellation_closes_connection(self): + pool = await async_get_pool(self.client) + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + conn = one(pool.conns) + + async def task(): + await self.client.db.test.find_one({"$where": delay(0.2)}) + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(conn.closed) + + @async_client_context.require_transactions + async def test_async_cancellation_aborts_transaction(self): + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + session = self.client.start_session() + + async def callback(session): + await self.client.db.test.find_one({"$where": delay(0.2)}, session=session) + + async def task(): + await session.with_transaction(callback) + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with 
self.assertRaises(asyncio.CancelledError): + await task + + self.assertFalse(session.in_transaction) + + @async_client_context.require_failCommand_blockConnection + async def test_async_cancellation_closes_cursor(self): + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + cursor = self.client.db.test.find({}, batch_size=1) + await cursor.next() + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200}, + } + + async def task(): + async with self.fail_point(fail_command): + await cursor.next() + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(cursor._killed) + + @async_client_context.require_change_streams + @async_client_context.require_failCommand_blockConnection + async def test_async_cancellation_closes_change_stream(self): + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + change_stream = await self.client.db.test.watch(batch_size=2) + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200}, + } + + async def task(): + async with self.fail_point(fail_command): + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) + await change_stream.next() + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(change_stream._closed) diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 7191a412c..86568b666 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -961,7 +961,6 @@ class 
AsyncTestBulkWriteConcern(AsyncBulkTestBase): @async_client_context.require_replica_set @async_client_context.require_secondaries_count(1) async def test_write_concern_failure_ordered(self): - self.skipTest("Skipping until PYTHON-4865 is resolved.") details = None # Ensure we don't raise on wnote. diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py new file mode 100644 index 000000000..a68b2a90c --- /dev/null +++ b/test/asynchronous/test_connection_monitoring.py @@ -0,0 +1,479 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Execute Transactions Spec tests.""" +from __future__ import annotations + +import asyncio +import os +import sys +import time +from pathlib import Path + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.asynchronous.pymongo_mocks import DummyMonitor +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask +from test.utils import ( + CMAPListener, + async_client_context, + async_get_pool, + async_get_pools, + async_wait_until, + camel_to_snake, +) + +from bson.objectid import ObjectId +from bson.son import SON +from pymongo.asynchronous.pool import PoolState, _PoolClosedError +from pymongo.errors import ( + ConnectionFailure, + OperationFailure, + PyMongoError, + WaitQueueTimeoutError, +) +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionClosedReason, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.read_preferences import ReadPreference +from pymongo.topology_description import updated_topology_description + +_IS_SYNC = False + +OBJECT_TYPES = { + # Event types. + "ConnectionCheckedIn": ConnectionCheckedInEvent, + "ConnectionCheckedOut": ConnectionCheckedOutEvent, + "ConnectionCheckOutFailed": ConnectionCheckOutFailedEvent, + "ConnectionClosed": ConnectionClosedEvent, + "ConnectionCreated": ConnectionCreatedEvent, + "ConnectionReady": ConnectionReadyEvent, + "ConnectionCheckOutStarted": ConnectionCheckOutStartedEvent, + "ConnectionPoolCreated": PoolCreatedEvent, + "ConnectionPoolReady": PoolReadyEvent, + "ConnectionPoolCleared": PoolClearedEvent, + "ConnectionPoolClosed": PoolClosedEvent, + # Error types. 
+ "PoolClosedError": _PoolClosedError, + "WaitQueueTimeoutError": WaitQueueTimeoutError, +} + + +class AsyncTestCMAP(AsyncIntegrationTest): + # Location of JSON test specifications. + if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") + else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") + + # Test operations: + + async def start(self, op): + """Run the 'start' thread operation.""" + target = op["target"] + thread = SpecRunnerTask(target) + await thread.start() + self.targets[target] = thread + + async def wait(self, op): + """Run the 'wait' operation.""" + await asyncio.sleep(op["ms"] / 1000.0) + + async def wait_for_thread(self, op): + """Run the 'waitForThread' operation.""" + target = op["target"] + thread = self.targets[target] + await thread.stop() + await thread.join() + if thread.exc: + raise thread.exc + self.assertFalse(thread.ops) + + async def wait_for_event(self, op): + """Run the 'waitForEvent' operation.""" + event = OBJECT_TYPES[op["event"]] + count = op["count"] + timeout = op.get("timeout", 10000) / 1000.0 + await async_wait_until( + lambda: self.listener.event_count(event) >= count, + f"find {count} {event} event(s)", + timeout=timeout, + ) + + async def check_out(self, op): + """Run the 'checkOut' operation.""" + label = op["label"] + async with self.pool.checkout() as conn: + # Call 'pin_cursor' so we can hold the socket. 
+ conn.pin_cursor() + if label: + self.labels[label] = conn + else: + self.addAsyncCleanup(conn.close_conn, None) + + async def check_in(self, op): + """Run the 'checkIn' operation.""" + label = op["connection"] + conn = self.labels[label] + await self.pool.checkin(conn) + + async def ready(self, op): + """Run the 'ready' operation.""" + await self.pool.ready() + + async def clear(self, op): + """Run the 'clear' operation.""" + if "interruptInUseConnections" in op: + await self.pool.reset(interrupt_connections=op["interruptInUseConnections"]) + else: + await self.pool.reset() + + async def close(self, op): + """Run the 'close' operation.""" + await self.pool.close() + + async def run_operation(self, op): + """Run a single operation in a test.""" + op_name = camel_to_snake(op["name"]) + thread = op["thread"] + meth = getattr(self, op_name) + if thread: + await self.targets[thread].schedule(lambda: meth(op)) + else: + await meth(op) + + async def run_operations(self, ops): + """Run a test's operations.""" + for op in ops: + self._ops.append(op) + await self.run_operation(op) + + def check_object(self, actual, expected): + """Assert that the actual object matches the expected object.""" + self.assertEqual(type(actual), OBJECT_TYPES[expected["type"]]) + for attr, expected_val in expected.items(): + if attr == "type": + continue + c2s = camel_to_snake(attr) + if c2s == "interrupt_in_use_connections": + c2s = "interrupt_connections" + actual_val = getattr(actual, c2s) + if expected_val == 42: + self.assertIsNotNone(actual_val) + else: + self.assertEqual(actual_val, expected_val) + + def check_event(self, actual, expected): + """Assert that the actual event matches the expected event.""" + self.check_object(actual, expected) + + def actual_events(self, ignore): + """Return all the non-ignored events.""" + ignore = tuple(OBJECT_TYPES[name] for name in ignore) + return [event for event in self.listener.events if not isinstance(event, ignore)] + + def check_events(self, 
events, ignore): + """Check the events of a test.""" + actual_events = self.actual_events(ignore) + for actual, expected in zip(actual_events, events): + self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}") + self.check_event(actual, expected) + + if len(events) > len(actual_events): + self.fail(f"missing events: {events[len(actual_events) :]!r}") + + def check_error(self, actual, expected): + message = expected.pop("message") + self.check_object(actual, expected) + self.assertIn(message, str(actual)) + + async def _set_fail_point(self, client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) + + async def set_fail_point(self, command_args): + if not async_client_context.supports_failCommand_fail_point: + self.skipTest("failCommand fail point must be supported") + await self._set_fail_point(self.client, command_args) + + async def run_scenario(self, scenario_def, test): + """Run a CMAP spec test.""" + self.logs: list = [] + self.assertEqual(scenario_def["version"], 1) + self.assertIn(scenario_def["style"], ["unit", "integration"]) + self.listener = CMAPListener() + self._ops: list = [] + + # Configure the fail point before creating the client. + if "failPoint" in test: + fp = test["failPoint"] + await self.set_fail_point(fp) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + opts = test["poolOptions"].copy() + opts["event_listeners"] = [self.listener] + opts["_monitor_class"] = DummyMonitor + opts["connect"] = False + # Support backgroundThreadIntervalMS, default to 50ms. 
+ interval = opts.pop("backgroundThreadIntervalMS", 50) + if interval < 0: + kill_cursor_frequency = 99999999 + else: + kill_cursor_frequency = interval / 1000.0 + with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): + client = await self.async_single_client(**opts) + # Update the SD to a known type because the DummyMonitor will not. + # Note we cannot simply call topology.on_change because that would + # internally call pool.ready() which introduces unexpected + # PoolReadyEvents. Instead, update the initial state before + # opening the Topology. + td = async_client_context.client._topology.description + sd = td.server_descriptions()[ + (await async_client_context.host, await async_client_context.port) + ] + client._topology._description = updated_topology_description( + client._topology._description, sd + ) + # When backgroundThreadIntervalMS is negative we do not start the + # background thread to ensure it never runs. + if interval < 0: + await client._topology.open() + else: + await client._get_topology() + self.pool = list(client._topology._servers.values())[0].pool + + # Map of target names to Thread objects. + self.targets: dict = {} + # Map of label names to AsyncConnection objects + self.labels: dict = {} + + async def cleanup(): + for t in self.targets.values(): + await t.stop() + for t in self.targets.values(): + await t.join(5) + for conn in self.labels.values(): + conn.close_conn(None) + + self.addAsyncCleanup(cleanup) + + try: + if test["error"]: + with self.assertRaises(PyMongoError) as ctx: + await self.run_operations(test["operations"]) + self.check_error(ctx.exception, test["error"]) + else: + await self.run_operations(test["operations"]) + + self.check_events(test["events"], test["ignore"]) + except Exception: + # Print the events after a test failure. 
+ print("\nFailed test: {!r}".format(test["description"])) + print("Operations:") + for op in self._ops: + print(op) + print("Threads:") + print(self.targets) + print("AsyncConnections:") + print(self.labels) + print("Events:") + for event in self.listener.events: + print(event) + print("Log:") + for log in self.logs: + print(log) + raise + + POOL_OPTIONS = { + "maxPoolSize": 50, + "minPoolSize": 1, + "maxIdleTimeMS": 10000, + "waitQueueTimeoutMS": 10000, + } + + # + # Prose tests. Numbers correspond to the prose test number in the spec. + # + async def test_1_client_connection_pool_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_2_all_client_pools_have_same_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + await client.admin.command("ping") + # Discover at least one secondary. + if await async_client_context.has_secondaries: + await client.admin.command("ping", read_preference=ReadPreference.SECONDARY) + pools = await async_get_pools(client) + pool_opts = pools[0].opts + + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + for pool in pools[1:]: + self.assertEqual(pool.opts, pool_opts) + + async def test_3_uri_connection_pool_options(self): + opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) + uri = f"mongodb://{await async_client_context.pair}/?{opts}" + client = await self.async_rs_or_single_client(uri) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_4_subscribe_to_events(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + self.assertEqual(listener.event_count(PoolCreatedEvent), 1) + + # Creates a new connection. 
+ await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 1) + self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1) + self.assertEqual(listener.event_count(ConnectionReadyEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1) + + # Uses the existing connection. + await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2) + + await client.close() + self.assertEqual(listener.event_count(PoolClosedEvent), 1) + self.assertEqual(listener.event_count(ConnectionClosedEvent), 1) + + async def test_5_check_out_fails_connection_error(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + pool = await async_get_pool(client) + + def mock_connect(*args, **kwargs): + raise ConnectionFailure("connect failed") + + pool.connect = mock_connect + # Un-patch Pool.connect to break the cyclic reference. + self.addCleanup(delattr, pool, "connect") + + # Attempt to create a new connection. 
+ with self.assertRaisesRegex(ConnectionFailure, "connect failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCheckOutFailedEvent) + self.assertIsInstance(listener.events[4], PoolClearedEvent) + + failed_event = listener.events[3] + self.assertEqual(failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + @async_client_context.require_no_fips + async def test_5_check_out_fails_auth_error(self): + listener = CMAPListener() + client = await self.async_single_client_noauth( + username="notauser", password="fail", event_listeners=[listener] + ) + + # Attempt to create a new connection. + with self.assertRaisesRegex(OperationFailure, "failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCreatedEvent) + # Error happens here. 
+ self.assertIsInstance(listener.events[4], ConnectionClosedEvent) + self.assertIsInstance(listener.events[5], ConnectionCheckOutFailedEvent) + self.assertEqual(listener.events[5].reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + # + # Extra non-spec tests + # + def assertRepr(self, obj): + new_obj = eval(repr(obj)) + self.assertEqual(type(new_obj), type(obj)) + self.assertEqual(repr(new_obj), repr(obj)) + + async def test_events_repr(self): + host = ("localhost", 27017) + self.assertRepr(ConnectionCheckedInEvent(host, 1)) + self.assertRepr(ConnectionCheckedOutEvent(host, 1, time.monotonic())) + self.assertRepr( + ConnectionCheckOutFailedEvent( + host, ConnectionCheckOutFailedReason.POOL_CLOSED, time.monotonic() + ) + ) + self.assertRepr(ConnectionClosedEvent(host, 1, ConnectionClosedReason.POOL_CLOSED)) + self.assertRepr(ConnectionCreatedEvent(host, 1)) + self.assertRepr(ConnectionReadyEvent(host, 1, time.monotonic())) + self.assertRepr(ConnectionCheckOutStartedEvent(host)) + self.assertRepr(PoolCreatedEvent(host, {})) + self.assertRepr(PoolClearedEvent(host)) + self.assertRepr(PoolClearedEvent(host, service_id=ObjectId())) + self.assertRepr(PoolClosedEvent(host)) + + async def test_close_leaves_pool_unpaused(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + await client.admin.command("ping") + pool = await async_get_pool(client) + await client.close() + self.assertEqual(1, listener.event_count(PoolClosedEvent)) + self.assertEqual(PoolState.CLOSED, pool.state) + # Checking out a connection should fail + with self.assertRaises(_PoolClosedError): + async with pool.checkout(): + pass + + +def create_test(scenario_def, test, name): + async def run_scenario(self): + await self.run_scenario(scenario_def, test) + + return run_scenario + + +class CMAPSpecTestCreator(AsyncSpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + CMAP tests do not have a 'tests' field. 
The whole file represents + a single test case. + """ + return [scenario_def] + + +test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) +test_creator.create_tests() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 2b22bd8b7..335aa9d81 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -739,7 +739,7 @@ class AsyncTestSpec(AsyncSpecRunner): return errors -async def create_test(scenario_def, test, name): +def create_test(scenario_def, test, name): @async_client_context.require_test_commands async def run_scenario(self): await self.run_scenario(scenario_def, test) diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py new file mode 100644 index 000000000..b1c1e754f --- /dev/null +++ b/test/asynchronous/test_gridfs.py @@ -0,0 +1,602 @@ +# +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the gridfs package.""" +from __future__ import annotations + +import asyncio +import datetime +import sys +import threading +import time +from io import BytesIO +from test.asynchronous.helpers import ConcurrentRunner +from unittest.mock import patch + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import async_joinall, one + +import gridfs +from bson.binary import Binary +from gridfs.asynchronous.grid_file import DEFAULT_CHUNK_SIZE, AsyncGridOutCursor +from gridfs.errors import CorruptGridFile, FileExists, NoFile +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + ConfigurationError, + NotPrimaryError, + ServerSelectionTimeoutError, +) +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False + + +class JustWrite(ConcurrentRunner): + def __init__(self, fs, n): + super().__init__() + self.fs = fs + self.n = n + self.daemon = True + + async def run(self): + for _ in range(self.n): + file = self.fs.new_file(filename="test") + await file.write(b"hello") + await file.close() + + +class JustRead(ConcurrentRunner): + def __init__(self, fs, n, results): + super().__init__() + self.fs = fs + self.n = n + self.results = results + self.daemon = True + + async def run(self): + for _ in range(self.n): + file = await self.fs.get("test") + data = await file.read() + self.results.append(data) + assert data == b"hello" + + +class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase): + db: AsyncDatabase + + async def asyncSetUp(self): + await super().asyncSetUp() + self.db = AsyncMongoClient(connect=False).pymongo_test + + async def test_gridfs(self): + self.assertRaises(TypeError, gridfs.AsyncGridFS, "foo") + self.assertRaises(TypeError, gridfs.AsyncGridFS, self.db, 5) + + +class TestGridfs(AsyncIntegrationTest): + fs: gridfs.AsyncGridFS + alt: gridfs.AsyncGridFS + + async 
def asyncSetUp(self): + await super().asyncSetUp() + self.fs = gridfs.AsyncGridFS(self.db) + self.alt = gridfs.AsyncGridFS(self.db, "alt") + await self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks + ) + + async def test_basic(self): + oid = await self.fs.put(b"hello world") + self.assertEqual(b"hello world", await (await self.fs.get(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + await self.fs.delete(oid) + with self.assertRaises(NoFile): + await self.fs.get(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.fs.get("foo") + oid = await self.fs.put(b"hello world", _id="foo") + self.assertEqual("foo", oid) + self.assertEqual(b"hello world", await (await self.fs.get("foo")).read()) + + async def test_multi_chunk_delete(self): + await self.db.fs.drop() + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + gfs = gridfs.AsyncGridFS(self.db) + oid = await gfs.put(b"hello", chunkSize=1) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(5, await self.db.fs.chunks.count_documents({})) + await gfs.delete(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_list(self): + self.assertEqual([], await self.fs.list()) + await self.fs.put(b"hello world") + self.assertEqual([], await self.fs.list()) + + # PYTHON-598: in server versions before 2.5.x, creating an index on + # filename, uploadDate causes list() to include None. 
+ await self.fs.get_last_version() + self.assertEqual([], await self.fs.list()) + + await self.fs.put(b"", filename="mike") + await self.fs.put(b"foo", filename="test") + await self.fs.put(b"", filename="hello world") + + self.assertEqual({"mike", "test", "hello world"}, set(await self.fs.list())) + + async def test_empty_file(self): + oid = await self.fs.put(b"") + self.assertEqual(b"", await (await self.fs.get(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + raw = await self.db.fs.files.find_one() + assert raw is not None + self.assertEqual(0, raw["length"]) + self.assertEqual(oid, raw["_id"]) + self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) + self.assertEqual(255 * 1024, raw["chunkSize"]) + self.assertNotIn("md5", raw) + + async def test_corrupt_chunk(self): + files_id = await self.fs.put(b"foobar") + await self.db.fs.chunks.update_one( + {"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}} + ) + try: + out = await self.fs.get(files_id) + with self.assertRaises(CorruptGridFile): + await out.read() + + out = await self.fs.get(files_id) + with self.assertRaises(CorruptGridFile): + await out.readline() + finally: + await self.fs.delete(files_id) + + async def test_put_ensures_index(self): + chunks = self.db.fs.chunks + files = self.db.fs.files + # Ensure the collections are removed. 
+ await chunks.drop() + await files.drop() + await self.fs.put(b"junk") + + self.assertTrue( + any( + info.get("key") == [("files_id", 1), ("n", 1)] + for info in (await chunks.index_information()).values() + ) + ) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in (await files.index_information()).values() + ) + ) + + async def test_alt_collection(self): + oid = await self.alt.put(b"hello world") + self.assertEqual(b"hello world", await (await self.alt.get(oid)).read()) + self.assertEqual(1, await self.db.alt.files.count_documents({})) + self.assertEqual(1, await self.db.alt.chunks.count_documents({})) + + await self.alt.delete(oid) + with self.assertRaises(NoFile): + await self.alt.get(oid) + self.assertEqual(0, await self.db.alt.files.count_documents({})) + self.assertEqual(0, await self.db.alt.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.alt.get("foo") + oid = await self.alt.put(b"hello world", _id="foo") + self.assertEqual("foo", oid) + self.assertEqual(b"hello world", await (await self.alt.get("foo")).read()) + + await self.alt.put(b"", filename="mike") + await self.alt.put(b"foo", filename="test") + await self.alt.put(b"", filename="hello world") + + self.assertEqual({"mike", "test", "hello world"}, set(await self.alt.list())) + + async def test_threaded_reads(self): + await self.fs.put(b"hello", _id="test") + + tasks = [] + results: list = [] + for i in range(10): + tasks.append(JustRead(self.fs, 10, results)) + await tasks[i].start() + + await async_joinall(tasks) + + self.assertEqual(100 * [b"hello"], results) + + async def test_threaded_writes(self): + tasks = [] + for i in range(10): + tasks.append(JustWrite(self.fs, 10)) + await tasks[i].start() + + await async_joinall(tasks) + + f = await self.fs.get_last_version("test") + self.assertEqual(await f.read(), b"hello") + + # Should have created 100 versions of 'test' file + self.assertEqual(100, await 
self.db.fs.files.count_documents({"filename": "test"})) + + async def test_get_last_version(self): + one = await self.fs.put(b"foo", filename="test") + await asyncio.sleep(0.01) + two = self.fs.new_file(filename="test") + await two.write(b"bar") + await two.close() + await asyncio.sleep(0.01) + two = two._id + three = await self.fs.put(b"baz", filename="test") + + self.assertEqual(b"baz", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(three) + self.assertEqual(b"bar", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.get_last_version("test")).read()) + await self.fs.delete(one) + with self.assertRaises(NoFile): + await self.fs.get_last_version("test") + + async def test_get_last_version_with_metadata(self): + one = await self.fs.put(b"foo", filename="test", author="author") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author") + + self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author")).read()) + await self.fs.delete(one) + + one = await self.fs.put(b"foo", filename="test", author="author1") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author2") + + self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author1")).read()) + self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author2")).read()) + self.assertEqual(b"bar", await (await self.fs.get_last_version(filename="test")).read()) + + with self.assertRaises(NoFile): + await self.fs.get_last_version(author="author3") + with self.assertRaises(NoFile): + await self.fs.get_last_version(filename="nottest", author="author1") + + await self.fs.delete(one) + await self.fs.delete(two) + + async def test_get_version(self): + await self.fs.put(b"foo", 
filename="test") + await asyncio.sleep(0.01) + await self.fs.put(b"bar", filename="test") + await asyncio.sleep(0.01) + await self.fs.put(b"baz", filename="test") + await asyncio.sleep(0.01) + + self.assertEqual(b"foo", await (await self.fs.get_version("test", 0)).read()) + self.assertEqual(b"bar", await (await self.fs.get_version("test", 1)).read()) + self.assertEqual(b"baz", await (await self.fs.get_version("test", 2)).read()) + + self.assertEqual(b"baz", await (await self.fs.get_version("test", -1)).read()) + self.assertEqual(b"bar", await (await self.fs.get_version("test", -2)).read()) + self.assertEqual(b"foo", await (await self.fs.get_version("test", -3)).read()) + + with self.assertRaises(NoFile): + await self.fs.get_version("test", 3) + with self.assertRaises(NoFile): + await self.fs.get_version("test", -4) + + async def test_get_version_with_metadata(self): + one = await self.fs.put(b"foo", filename="test", author="author1") + await asyncio.sleep(0.01) + two = await self.fs.put(b"bar", filename="test", author="author1") + await asyncio.sleep(0.01) + three = await self.fs.put(b"baz", filename="test", author="author2") + + self.assertEqual( + b"foo", + await (await self.fs.get_version(filename="test", author="author1", version=-2)).read(), + ) + self.assertEqual( + b"bar", + await (await self.fs.get_version(filename="test", author="author1", version=-1)).read(), + ) + self.assertEqual( + b"foo", + await (await self.fs.get_version(filename="test", author="author1", version=0)).read(), + ) + self.assertEqual( + b"bar", + await (await self.fs.get_version(filename="test", author="author1", version=1)).read(), + ) + self.assertEqual( + b"baz", + await (await self.fs.get_version(filename="test", author="author2", version=0)).read(), + ) + self.assertEqual( + b"baz", await (await self.fs.get_version(filename="test", version=-1)).read() + ) + self.assertEqual( + b"baz", await (await self.fs.get_version(filename="test", version=2)).read() + ) + + with 
self.assertRaises(NoFile): + await self.fs.get_version(filename="test", author="author3") + with self.assertRaises(NoFile): + await self.fs.get_version(filename="test", author="author1", version=2) + + await self.fs.delete(one) + await self.fs.delete(two) + await self.fs.delete(three) + + async def test_put_filelike(self): + oid = await self.fs.put(BytesIO(b"hello world"), chunk_size=1) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + self.assertEqual(b"hello world", await (await self.fs.get(oid)).read()) + + async def test_file_exists(self): + oid = await self.fs.put(b"hello") + with self.assertRaises(FileExists): + await self.fs.put(b"world", _id=oid) + + one = self.fs.new_file(_id=123) + await one.write(b"some content") + await one.close() + + # Attempt to upload a file with more chunks to the same _id. + with patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE): + two = self.fs.new_file(_id=123) + with self.assertRaises(FileExists): + await two.write(b"x" * DEFAULT_CHUNK_SIZE * 3) + # Original file is still readable (no extra chunks were uploaded). + self.assertEqual(await (await self.fs.get(123)).read(), b"some content") + + two = self.fs.new_file(_id=123) + await two.write(b"some content") + with self.assertRaises(FileExists): + await two.close() + # Original file is still readable. 
+ self.assertEqual(await (await self.fs.get(123)).read(), b"some content") + + async def test_exists(self): + oid = await self.fs.put(b"hello") + self.assertTrue(await self.fs.exists(oid)) + self.assertTrue(await self.fs.exists({"_id": oid})) + self.assertTrue(await self.fs.exists(_id=oid)) + + self.assertFalse(await self.fs.exists(filename="mike")) + self.assertFalse(await self.fs.exists("mike")) + + oid = await self.fs.put(b"hello", filename="mike", foo=12) + self.assertTrue(await self.fs.exists(oid)) + self.assertTrue(await self.fs.exists({"_id": oid})) + self.assertTrue(await self.fs.exists(_id=oid)) + self.assertTrue(await self.fs.exists(filename="mike")) + self.assertTrue(await self.fs.exists({"filename": "mike"})) + self.assertTrue(await self.fs.exists(foo=12)) + self.assertTrue(await self.fs.exists({"foo": 12})) + self.assertTrue(await self.fs.exists(foo={"$gt": 11})) + self.assertTrue(await self.fs.exists({"foo": {"$gt": 11}})) + + self.assertFalse(await self.fs.exists(foo=13)) + self.assertFalse(await self.fs.exists({"foo": 13})) + self.assertFalse(await self.fs.exists(foo={"$gt": 12})) + self.assertFalse(await self.fs.exists({"foo": {"$gt": 12}})) + + async def test_put_unicode(self): + with self.assertRaises(TypeError): + await self.fs.put("hello") + + oid = await self.fs.put("hello", encoding="utf-8") + self.assertEqual(b"hello", await (await self.fs.get(oid)).read()) + self.assertEqual("utf-8", (await self.fs.get(oid)).encoding) + + oid = await self.fs.put("aé", encoding="iso-8859-1") + self.assertEqual("aé".encode("iso-8859-1"), await (await self.fs.get(oid)).read()) + self.assertEqual("iso-8859-1", (await self.fs.get(oid)).encoding) + + async def test_missing_length_iter(self): + # Test fix that guards against PHP-237 + await self.fs.put(b"", filename="empty") + doc = await self.db.fs.files.find_one({"filename": "empty"}) + assert doc is not None + doc.pop("length") + await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) + f = await 
self.fs.get_last_version(filename="empty") + + async def iterate_file(grid_file): + async for _chunk in grid_file: + pass + return True + + self.assertTrue(await iterate_file(f)) + + async def test_gridfs_lazy_connect(self): + client = await self.async_single_client( + "badhost", connect=False, serverSelectionTimeoutMS=10 + ) + db = client.db + gfs = gridfs.AsyncGridFS(db) + with self.assertRaises(ServerSelectionTimeoutError): + await gfs.list() + + fs = gridfs.AsyncGridFS(db) + f = fs.new_file() + with self.assertRaises(ServerSelectionTimeoutError): + await f.close() + + async def test_gridfs_find(self): + await self.fs.put(b"test2", filename="two") + await asyncio.sleep(0.01) + await self.fs.put(b"test2+", filename="two") + await asyncio.sleep(0.01) + await self.fs.put(b"test1", filename="one") + await asyncio.sleep(0.01) + await self.fs.put(b"test2++", filename="two") + files = self.db.fs.files + self.assertEqual(3, await files.count_documents({"filename": "two"})) + self.assertEqual(4, await files.count_documents({})) + cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + await cursor.rewind() + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + gout = await cursor.next() + self.assertEqual(b"test2+", await gout.read()) + with self.assertRaises(StopAsyncIteration): + await cursor.__anext__() + await cursor.rewind() + items = await cursor.to_list() + self.assertEqual(len(items), 2) + await cursor.rewind() + items = await cursor.to_list(1) + self.assertEqual(len(items), 1) + await cursor.close() + self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) + + async def test_delete_not_initialized(self): + # Creating a cursor with invalid arguments will not run __init__ + # but will still call __del__. 
+ cursor = AsyncGridOutCursor.__new__(AsyncGridOutCursor) # Skip calling __init__ + with self.assertRaises(TypeError): + cursor.__init__(self.db.fs.files, {}, {"_id": True}) # type: ignore + cursor.__del__() # no error + + async def test_gridfs_find_one(self): + self.assertEqual(None, await self.fs.find_one()) + + id1 = await self.fs.put(b"test1", filename="file1") + res = await self.fs.find_one() + assert res is not None + self.assertEqual(b"test1", await res.read()) + + id2 = await self.fs.put(b"test2", filename="file2", meta="data") + res1 = await self.fs.find_one(id1) + assert res1 is not None + self.assertEqual(b"test1", await res1.read()) + res2 = await self.fs.find_one(id2) + assert res2 is not None + self.assertEqual(b"test2", await res2.read()) + + res3 = await self.fs.find_one({"filename": "file1"}) + assert res3 is not None + self.assertEqual(b"test1", await res3.read()) + + res4 = await self.fs.find_one(id2) + assert res4 is not None + self.assertEqual("data", res4.meta) + + async def test_grid_in_non_int_chunksize(self): + # Lua, and perhaps other buggy AsyncGridFS clients, store size as a float. + data = b"data" + await self.fs.put(data, filename="f") + await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) + + self.assertEqual(data, await (await self.fs.get_version("f")).read()) + + async def test_unacknowledged(self): + # w=0 is prohibited. 
+ with self.assertRaises(ConfigurationError): + gridfs.AsyncGridFS((await self.async_rs_or_single_client(w=0)).pymongo_test) + + async def test_md5(self): + gin = self.fs.new_file() + await gin.write(b"no md5 sum") + await gin.close() + self.assertIsNone(gin.md5) + + gout = await self.fs.get(gin._id) + self.assertIsNone(gout.md5) + + _id = await self.fs.put(b"still no md5 sum") + gout = await self.fs.get(_id) + self.assertIsNone(gout.md5) + + +class TestGridfsReplicaSet(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + + @classmethod + @async_client_context.require_connection + async def asyncTearDownClass(cls): + await async_client_context.client.drop_database("gfsreplica") + + async def test_gridfs_replica_set(self): + rsc = await self.async_rs_client( + w=async_client_context.w, read_preference=ReadPreference.SECONDARY + ) + + fs = gridfs.AsyncGridFS(rsc.gfsreplica, "gfsreplicatest") + + gin = fs.new_file() + self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) + + oid = await fs.put(b"foo") + content = await (await fs.get(oid)).read() + self.assertEqual(b"foo", content) + + async def test_gridfs_secondary(self): + secondary_host, secondary_port = one(await self.client.secondaries) + secondary_connection = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY + ) + + # Should detect it's connected to secondary and not attempt to + # create index + fs = gridfs.AsyncGridFS(secondary_connection.gfsreplica, "gfssecondarytest") + + # This won't detect secondary, raises error + with self.assertRaises(NotPrimaryError): + await fs.put(b"foo") + + async def test_gridfs_secondary_lazy(self): + # Should detect it's connected to secondary and not attempt to + # create index. 
+ secondary_host, secondary_port = one(await self.client.secondaries) + client = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False + ) + + # Still no connection. + fs = gridfs.AsyncGridFS(client.gfsreplica, "gfssecondarylazytest") + + # Connects, doesn't create index. + with self.assertRaises(NoFile): + await fs.get_last_version() + with self.assertRaises(NotPrimaryError): + await fs.put("data", encoding="utf-8") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_gridfs_bucket.py b/test/asynchronous/test_gridfs_bucket.py new file mode 100644 index 000000000..5d1cf5bef --- /dev/null +++ b/test/asynchronous/test_gridfs_bucket.py @@ -0,0 +1,574 @@ +# +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the gridfs package.""" +from __future__ import annotations + +import asyncio +import datetime +import itertools +import sys +import threading +import time +from io import BytesIO +from test.asynchronous.helpers import ConcurrentRunner +from unittest.mock import patch + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import async_joinall, joinall, one + +import gridfs +from bson.binary import Binary +from bson.int64 import Int64 +from bson.objectid import ObjectId +from bson.son import SON +from gridfs.errors import CorruptGridFile, NoFile +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + ConfigurationError, + NotPrimaryError, + ServerSelectionTimeoutError, + WriteConcernError, +) +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False + + +class JustWrite(ConcurrentRunner): + def __init__(self, gfs, num): + super().__init__() + self.gfs = gfs + self.num = num + self.daemon = True + + async def run(self): + for _ in range(self.num): + file = self.gfs.open_upload_stream("test") + await file.write(b"hello") + await file.close() + + +class JustRead(ConcurrentRunner): + def __init__(self, gfs, num, results): + super().__init__() + self.gfs = gfs + self.num = num + self.results = results + self.daemon = True + + async def run(self): + for _ in range(self.num): + file = await self.gfs.open_download_stream_by_name("test") + data = await file.read() + self.results.append(data) + assert data == b"hello" + + +class TestGridfs(AsyncIntegrationTest): + fs: gridfs.AsyncGridFSBucket + alt: gridfs.AsyncGridFSBucket + + async def asyncSetUp(self): + await super().asyncSetUp() + self.fs = gridfs.AsyncGridFSBucket(self.db) + self.alt = gridfs.AsyncGridFSBucket(self.db, bucket_name="alt") + await self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks + ) + + async def test_basic(self): + 
oid = await self.fs.upload_from_stream("test_filename", b"hello world") + self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + await self.fs.delete(oid) + with self.assertRaises(NoFile): + await self.fs.open_download_stream(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_multi_chunk_delete(self): + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + gfs = gridfs.AsyncGridFSBucket(self.db) + oid = await gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(5, await self.db.fs.chunks.count_documents({})) + await gfs.delete(oid) + self.assertEqual(0, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + async def test_empty_file(self): + oid = await self.fs.upload_from_stream("test_filename", b"") + self.assertEqual(b"", await (await self.fs.open_download_stream(oid)).read()) + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(0, await self.db.fs.chunks.count_documents({})) + + raw = await self.db.fs.files.find_one() + assert raw is not None + self.assertEqual(0, raw["length"]) + self.assertEqual(oid, raw["_id"]) + self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) + self.assertEqual(255 * 1024, raw["chunkSize"]) + self.assertNotIn("md5", raw) + + async def test_corrupt_chunk(self): + files_id = await self.fs.upload_from_stream("test_filename", b"foobar") + await self.db.fs.chunks.update_one( + {"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}} + ) + try: + out = await 
self.fs.open_download_stream(files_id) + with self.assertRaises(CorruptGridFile): + await out.read() + + out = await self.fs.open_download_stream(files_id) + with self.assertRaises(CorruptGridFile): + await out.readline() + finally: + await self.fs.delete(files_id) + + async def test_upload_ensures_index(self): + chunks = self.db.fs.chunks + files = self.db.fs.files + # Ensure the collections are removed. + await chunks.drop() + await files.drop() + await self.fs.upload_from_stream("filename", b"junk") + + self.assertTrue( + any( + info.get("key") == [("files_id", 1), ("n", 1)] + for info in (await chunks.index_information()).values() + ) + ) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in (await files.index_information()).values() + ) + ) + + async def test_ensure_index_shell_compat(self): + files = self.db.fs.files + for i, j in itertools.combinations_with_replacement([1, 1.0, Int64(1)], 2): + # Create the index with different numeric types (as might be done + # from the mongo shell). + shell_index = [("filename", i), ("uploadDate", j)] + await self.db.command( + "createIndexes", + files.name, + indexes=[{"key": SON(shell_index), "name": "filename_1.0_uploadDate_1.0"}], + ) + + # No error. 
+ await self.fs.upload_from_stream("filename", b"data") + + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in (await files.index_information()).values() + ) + ) + await files.drop() + + async def test_alt_collection(self): + oid = await self.alt.upload_from_stream("test_filename", b"hello world") + self.assertEqual(b"hello world", await (await self.alt.open_download_stream(oid)).read()) + self.assertEqual(1, await self.db.alt.files.count_documents({})) + self.assertEqual(1, await self.db.alt.chunks.count_documents({})) + + await self.alt.delete(oid) + with self.assertRaises(NoFile): + await self.alt.open_download_stream(oid) + self.assertEqual(0, await self.db.alt.files.count_documents({})) + self.assertEqual(0, await self.db.alt.chunks.count_documents({})) + + with self.assertRaises(NoFile): + await self.alt.open_download_stream("foo") + await self.alt.upload_from_stream("foo", b"hello world") + self.assertEqual( + b"hello world", await (await self.alt.open_download_stream_by_name("foo")).read() + ) + + await self.alt.upload_from_stream("mike", b"") + await self.alt.upload_from_stream("test", b"foo") + await self.alt.upload_from_stream("hello world", b"") + + self.assertEqual( + {"mike", "test", "hello world", "foo"}, + {k["filename"] for k in await self.db.alt.files.find().to_list()}, + ) + + async def test_threaded_reads(self): + await self.fs.upload_from_stream("test", b"hello") + + threads = [] + results: list = [] + for i in range(10): + threads.append(JustRead(self.fs, 10, results)) + await threads[i].start() + + await async_joinall(threads) + + self.assertEqual(100 * [b"hello"], results) + + async def test_threaded_writes(self): + threads = [] + for i in range(10): + threads.append(JustWrite(self.fs, 10)) + await threads[i].start() + + await async_joinall(threads) + + fstr = await self.fs.open_download_stream_by_name("test") + self.assertEqual(await fstr.read(), b"hello") + + # Should have created 100 versions 
of 'test' file + self.assertEqual(100, await self.db.fs.files.count_documents({"filename": "test"})) + + async def test_get_last_version(self): + one = await self.fs.upload_from_stream("test", b"foo") + await asyncio.sleep(0.01) + two = self.fs.open_upload_stream("test") + await two.write(b"bar") + await two.close() + await asyncio.sleep(0.01) + two = two._id + three = await self.fs.upload_from_stream("test", b"baz") + + self.assertEqual(b"baz", await (await self.fs.open_download_stream_by_name("test")).read()) + await self.fs.delete(three) + self.assertEqual(b"bar", await (await self.fs.open_download_stream_by_name("test")).read()) + await self.fs.delete(two) + self.assertEqual(b"foo", await (await self.fs.open_download_stream_by_name("test")).read()) + await self.fs.delete(one) + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("test") + + async def test_get_version(self): + await self.fs.upload_from_stream("test", b"foo") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("test", b"bar") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("test", b"baz") + await asyncio.sleep(0.01) + + self.assertEqual( + b"foo", await (await self.fs.open_download_stream_by_name("test", revision=0)).read() + ) + self.assertEqual( + b"bar", await (await self.fs.open_download_stream_by_name("test", revision=1)).read() + ) + self.assertEqual( + b"baz", await (await self.fs.open_download_stream_by_name("test", revision=2)).read() + ) + + self.assertEqual( + b"baz", await (await self.fs.open_download_stream_by_name("test", revision=-1)).read() + ) + self.assertEqual( + b"bar", await (await self.fs.open_download_stream_by_name("test", revision=-2)).read() + ) + self.assertEqual( + b"foo", await (await self.fs.open_download_stream_by_name("test", revision=-3)).read() + ) + + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("test", revision=3) + with self.assertRaises(NoFile): + await 
self.fs.open_download_stream_by_name("test", revision=-4) + + async def test_upload_from_stream(self): + oid = await self.fs.upload_from_stream( + "test_file", BytesIO(b"hello world"), chunk_size_bytes=1 + ) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read()) + + async def test_upload_from_stream_with_id(self): + oid = ObjectId() + await self.fs.upload_from_stream_with_id( + oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1 + ) + self.assertEqual(b"custom id", await (await self.fs.open_download_stream(oid)).read()) + + @patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3) + @async_client_context.require_failCommand_fail_point + async def test_upload_bulk_write_error(self): + # Test BulkWriteError from insert_many is converted to an insert_one style error. + expected_wce = { + "code": 100, + "codeName": "UnsatisfiableWriteConcern", + "errmsg": "Not enough data-bearing nodes", + } + cause_wce = { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, + } + gin = self.fs.open_upload_stream("test_file", chunk_size_bytes=1) + async with self.fail_point(cause_wce): + # Assert we raise WriteConcernError, not BulkWriteError. + with self.assertRaises(WriteConcernError): + await gin.write(b"hello world") + # 3 chunks were uploaded. + self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + await gin.abort() + + @patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 10) + async def test_upload_batching(self): + async with self.fs.open_upload_stream("test_file", chunk_size_bytes=1) as gin: + await gin.write(b"s" * (10 - 1)) + # No chunks were uploaded yet. 
+ self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + await gin.write(b"s") + # All chunks were uploaded since we hit the _UPLOAD_BUFFER_CHUNKS limit. + self.assertEqual(10, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + + async def test_open_upload_stream(self): + gin = self.fs.open_upload_stream("from_stream") + await gin.write(b"from stream") + await gin.close() + self.assertEqual(b"from stream", await (await self.fs.open_download_stream(gin._id)).read()) + + async def test_open_upload_stream_with_id(self): + oid = ObjectId() + gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id") + await gin.write(b"from stream with custom id") + await gin.close() + self.assertEqual( + b"from stream with custom id", await (await self.fs.open_download_stream(oid)).read() + ) + + async def test_missing_length_iter(self): + # Test fix that guards against PHP-237 + await self.fs.upload_from_stream("empty", b"") + doc = await self.db.fs.files.find_one({"filename": "empty"}) + assert doc is not None + doc.pop("length") + await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) + fstr = await self.fs.open_download_stream_by_name("empty") + + async def iterate_file(grid_file): + async for _ in grid_file: + pass + return True + + self.assertTrue(await iterate_file(fstr)) + + async def test_gridfs_lazy_connect(self): + client = await self.async_single_client( + "badhost", connect=False, serverSelectionTimeoutMS=0 + ) + cdb = client.db + gfs = gridfs.AsyncGridFSBucket(cdb) + with self.assertRaises(ServerSelectionTimeoutError): + await gfs.delete(0) + + gfs = gridfs.AsyncGridFSBucket(cdb) + with self.assertRaises(ServerSelectionTimeoutError): + await gfs.upload_from_stream("test", b"") # Still no connection. 
+ + async def test_gridfs_find(self): + await self.fs.upload_from_stream("two", b"test2") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("two", b"test2+") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("one", b"test1") + await asyncio.sleep(0.01) + await self.fs.upload_from_stream("two", b"test2++") + files = self.db.fs.files + self.assertEqual(3, await files.count_documents({"filename": "two"})) + self.assertEqual(4, await files.count_documents({})) + cursor = self.fs.find( + {}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2 + ) + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + await cursor.rewind() + gout = await cursor.next() + self.assertEqual(b"test1", await gout.read()) + gout = await cursor.next() + self.assertEqual(b"test2+", await gout.read()) + with self.assertRaises(StopAsyncIteration): + await cursor.next() + await cursor.close() + self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) + + async def test_grid_in_non_int_chunksize(self): + # Lua, and perhaps other buggy AsyncGridFS clients, store size as a float. + data = b"data" + await self.fs.upload_from_stream("f", data) + await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) + + self.assertEqual(data, await (await self.fs.open_download_stream_by_name("f")).read()) + + async def test_unacknowledged(self): + # w=0 is prohibited. 
+ with self.assertRaises(ConfigurationError): + gridfs.AsyncGridFSBucket((await self.async_rs_or_single_client(w=0)).pymongo_test) + + async def test_rename(self): + _id = await self.fs.upload_from_stream("first_name", b"testing") + self.assertEqual( + b"testing", await (await self.fs.open_download_stream_by_name("first_name")).read() + ) + + await self.fs.rename(_id, "second_name") + with self.assertRaises(NoFile): + await self.fs.open_download_stream_by_name("first_name") + self.assertEqual( + b"testing", await (await self.fs.open_download_stream_by_name("second_name")).read() + ) + + @patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", 5) + async def test_abort(self): + gin = self.fs.open_upload_stream("test_filename", chunk_size_bytes=5) + await gin.write(b"test1") + await gin.write(b"test2") + await gin.write(b"test3") + self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + await gin.abort() + self.assertTrue(gin.closed) + with self.assertRaises(ValueError): + await gin.write(b"test4") + self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id})) + + async def test_download_to_stream(self): + file1 = BytesIO(b"hello world") + # Test with one chunk. + oid = await self.fs.upload_from_stream("one_chunk", file1) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + file2 = BytesIO() + await self.fs.download_to_stream(oid, file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + # Test with many chunks. 
+ await self.db.drop_collection("fs.files") + await self.db.drop_collection("fs.chunks") + file1.seek(0) + oid = await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + file2 = BytesIO() + await self.fs.download_to_stream(oid, file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + async def test_download_to_stream_by_name(self): + file1 = BytesIO(b"hello world") + # Test with one chunk. + _ = await self.fs.upload_from_stream("one_chunk", file1) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + file2 = BytesIO() + await self.fs.download_to_stream_by_name("one_chunk", file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + # Test with many chunks. + await self.db.drop_collection("fs.files") + await self.db.drop_collection("fs.chunks") + file1.seek(0) + await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) + self.assertEqual(11, await self.db.fs.chunks.count_documents({})) + + file2 = BytesIO() + await self.fs.download_to_stream_by_name("many_chunks", file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + async def test_md5(self): + gin = self.fs.open_upload_stream("no md5") + await gin.write(b"no md5 sum") + await gin.close() + self.assertIsNone(gin.md5) + + gout = await self.fs.open_download_stream(gin._id) + self.assertIsNone(gout.md5) + + gin = self.fs.open_upload_stream_with_id(ObjectId(), "also no md5") + await gin.write(b"also no md5 sum") + await gin.close() + self.assertIsNone(gin.md5) + + gout = await self.fs.open_download_stream(gin._id) + self.assertIsNone(gout.md5) + + +class TestGridfsBucketReplicaSet(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + + @classmethod + @async_client_context.require_connection + async def 
asyncTearDownClass(cls): + await async_client_context.client.drop_database("gfsbucketreplica") + + async def test_gridfs_replica_set(self): + rsc = await self.async_rs_client( + w=async_client_context.w, read_preference=ReadPreference.SECONDARY + ) + + gfs = gridfs.AsyncGridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") + oid = await gfs.upload_from_stream("test_filename", b"foo") + content = await (await gfs.open_download_stream(oid)).read() + self.assertEqual(b"foo", content) + + async def test_gridfs_secondary(self): + secondary_host, secondary_port = one(await self.client.secondaries) + secondary_connection = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY + ) + + # Should detect it's connected to secondary and not attempt to + # create index + gfs = gridfs.AsyncGridFSBucket( + secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest" + ) + + # This won't detect secondary, raises error + with self.assertRaises(NotPrimaryError): + await gfs.upload_from_stream("test_filename", b"foo") + + async def test_gridfs_secondary_lazy(self): + # Should detect it's connected to secondary and not attempt to + # create index. + secondary_host, secondary_port = one(await self.client.secondaries) + client = await self.async_single_client( + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False + ) + + # Still no connection. + gfs = gridfs.AsyncGridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest") + + # Connects, doesn't create index. 
+ with self.assertRaises(NoFile): + await gfs.open_download_stream_by_name("test_filename") + with self.assertRaises(NotPrimaryError): + await gfs.upload_from_stream("test_filename", b"data") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_server_selection.py b/test/asynchronous/test_server_selection.py new file mode 100644 index 000000000..f0451841c --- /dev/null +++ b/test/asynchronous/test_server_selection.py @@ -0,0 +1,211 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations + +import os +import sys +from pathlib import Path + +from pymongo import AsyncMongoClient, ReadPreference +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology +from pymongo.errors import ServerSelectionTimeoutError +from pymongo.hello import HelloCompat +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector +from pymongo.typings import strip_optional + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.utils_selection_tests import ( + create_selection_tests, + get_addresses, + get_topology_settings_dict, + make_server_description, +) +from test.utils import ( + EventListener, + FunctionCallRecorder, + OvertCommandListener, + async_wait_until, +) + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent, "server_selection", "server_selection" + ) +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "server_selection" + ) + + +class SelectionStoreSelector: + """No-op selector that keeps track of what was passed to it.""" + + def __init__(self): + self.selection = None + + def __call__(self, selection): + self.selection = selection + return selection + + +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore + pass + + +class TestCustomServerSelectorFunction(AsyncIntegrationTest): + @async_client_context.require_replica_set + async def test_functional_select_max_port_number_host(self): + # Selector that returns server with highest port number. + def custom_selector(servers): + ports = [s.address[1] for s in servers] + idx = ports.index(max(ports)) + return [servers[idx]] + + # Initialize client with appropriate listeners. 
+ listener = OvertCommandListener() + client = await self.async_rs_or_single_client( + server_selector=custom_selector, event_listeners=[listener] + ) + coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll + self.addAsyncCleanup(client.drop_database, "testdb") + + # Wait the node list to be fully populated. + async def all_hosts_started(): + return len((await client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len( + client._topology._description.readable_servers + ) + + await async_wait_until(all_hosts_started, "receive heartbeat from all hosts") + + expected_port = max( + [strip_optional(n.address[1]) for n in client._topology._description.readable_servers] + ) + + # Insert 1 record and access it 10 times. + await coll.insert_one({"name": "John Doe"}) + for _ in range(10): + await coll.find_one({"name": "John Doe"}) + + # Confirm all find commands are run against appropriate host. + for command in listener.started_events: + if command.command_name == "find": + self.assertEqual(command.connection_id[1], expected_port) + + async def test_invalid_server_selector(self): + # Client initialization must fail if server_selector is not callable. + for selector_candidate in [[], 10, "string", {}]: + with self.assertRaisesRegex(ValueError, "must be a callable"): + AsyncMongoClient(connect=False, server_selector=selector_candidate) + + # None value for server_selector is OK. + AsyncMongoClient(connect=False, server_selector=None) + + @async_client_context.require_replica_set + async def test_selector_called(self): + selector = FunctionCallRecorder(lambda x: x) + + # Client setup. + mongo_client = await self.async_rs_or_single_client(server_selector=selector) + test_collection = mongo_client.testdb.test_collection + self.addAsyncCleanup(mongo_client.drop_database, "testdb") + + # Do N operations and test selector is called at least N times. 
+ await test_collection.insert_one({"age": 20, "name": "John"}) + await test_collection.insert_one({"age": 31, "name": "Jane"}) + await test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}}) + await test_collection.find_one({"name": "Roe"}) + self.assertGreaterEqual(selector.call_count, 4) + + @async_client_context.require_replica_set + async def test_latency_threshold_application(self): + selector = SelectionStoreSelector() + + scenario_def: dict = { + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSPrimary", "tag": {}}, + ], + } + } + + # Create & populate Topology such that all but one server is too slow. + rtt_times = [srv["avg_rtt_ms"] for srv in scenario_def["topology_description"]["servers"]] + min_rtt_idx = rtt_times.index(min(rtt_times)) + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + settings = get_topology_settings_dict( + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector + ) + topology = Topology(TopologySettings(**settings)) + await topology.open() + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Invoke server selection and assert no filtering based on latency + # prior to custom server selection logic kicking in. + server = await topology.select_server(ReadPreference.NEAREST, _Op.TEST) + assert selector.selection is not None + self.assertEqual(len(selector.selection), len(topology.description.server_descriptions())) + + # Ensure proper filtering based on latency after custom selection. 
+ self.assertEqual(server.description.address, seeds[min_rtt_idx]) + + @async_client_context.require_replica_set + async def test_server_selector_bypassed(self): + selector = FunctionCallRecorder(lambda x: x) + + scenario_def = { + "topology_description": { + "type": "ReplicaSetNoPrimary", + "servers": [ + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSSecondary", "tag": {}}, + ], + } + } + + # Create & populate Topology such that no server is writeable. + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + settings = get_topology_settings_dict( + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector + ) + topology = Topology(TopologySettings(**settings)) + await topology.open() + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Invoke server selection and assert no calls to our custom selector. + with self.assertRaisesRegex(ServerSelectionTimeoutError, "No primary available for writes"): + await topology.select_server( + writable_server_selector, _Op.TEST, server_selection_timeout=0.1 + ) + self.assertEqual(selector.call_count, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py new file mode 100644 index 000000000..e2ae92a27 --- /dev/null +++ b/test/asynchronous/test_server_selection_in_window.py @@ -0,0 +1,179 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations + +import asyncio +import os +import threading +from pathlib import Path +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.helpers import ConcurrentRunner +from test.asynchronous.utils_selection_tests import create_topology +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator +from test.utils import ( + CMAPListener, + OvertCommandListener, + async_get_pool, + async_wait_until, +) + +from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False +# Location of JSON test specifications. 
+if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) + + +class TestAllScenarios(unittest.IsolatedAsyncioTestCase): + async def run_scenario(self, scenario_def): + topology = await create_topology(scenario_def) + + # Update mock operation_count state: + for mock in scenario_def["mocked_topology_state"]: + address = clean_node(mock["address"]) + server = topology.get_server_by_address(address) + server.pool.operation_count = mock["operation_count"] + + pref = ReadPreference.NEAREST + counts = {address: 0 for address in topology._description.server_descriptions()} + + # Number of times to repeat server selection + iterations = scenario_def["iterations"] + for _ in range(iterations): + server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0) + counts[server.description.address] += 1 + + # Verify expected_frequencies + outcome = scenario_def["outcome"] + tolerance = outcome["tolerance"] + expected_frequencies = outcome["expected_frequencies"] + for host_str, freq in expected_frequencies.items(): + address = clean_node(host_str) + actual_freq = float(counts[address]) / iterations + if freq == 0: + # Should be exactly 0. + self.assertEqual(actual_freq, 0) + else: + # Should be within 'tolerance'. + self.assertAlmostEqual(actual_freq, freq, delta=tolerance) + + +def create_test(scenario_def, test, name): + async def run_scenario(self): + await self.run_scenario(scenario_def) + + return run_scenario + + +class CustomSpecTestCreator(AsyncSpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + Server selection in_window tests do not have a 'tests' field. + The whole file represents a single test case. 
+ """ + return [scenario_def] + + +CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() + + +class FinderTask(ConcurrentRunner): + def __init__(self, collection, iterations): + super().__init__() + self.daemon = True + self.collection = collection + self.iterations = iterations + self.passed = False + + async def run(self): + for _ in range(self.iterations): + await self.collection.find_one({}) + self.passed = True + + +class TestProse(AsyncIntegrationTest): + async def frequencies(self, client, listener, n_finds=10): + coll = client.test.test + N_TASKS = 10 + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + await task.start() + for task in tasks: + await task.join() + for task in tasks: + self.assertTrue(task.passed) + + events = listener.started_events + self.assertEqual(len(events), n_finds * N_TASKS) + nodes = client.nodes + self.assertEqual(len(nodes), 2) + freqs = {address: 0.0 for address in nodes} + for event in events: + freqs[event.connection_id] += 1 + for address in freqs: + freqs[address] = freqs[address] / float(len(events)) + return freqs + + @async_client_context.require_failCommand_appName + @async_client_context.require_multiple_mongoses + async def test_load_balancing(self): + listener = OvertCommandListener() + cmap_listener = CMAPListener() + # PYTHON-2584: Use a large localThresholdMS to avoid the impact of + # varying RTTs. + client = await self.async_rs_client( + async_client_context.mongos_seeds(), + appName="loadBalancingTest", + event_listeners=[listener, cmap_listener], + localThresholdMS=30000, + minPoolSize=10, + ) + await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes") + # Wait for both pools to be populated. + await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. 
+ delay_finds = { + "configureFailPoint": "failCommand", + "mode": {"times": 10000}, + "data": { + "failCommands": ["find"], + "blockConnection": True, + "blockTimeMS": 500, + "appName": "loadBalancingTest", + }, + } + async with self.fail_point(delay_finds): + nodes = async_client_context.client.nodes + self.assertEqual(len(nodes), 1) + delayed_server = next(iter(nodes)) + freqs = await self.frequencies(client, listener) + self.assertLessEqual(freqs[delayed_server], 0.25) + listener.reset() + freqs = await self.frequencies(client, listener, n_finds=150) + self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py new file mode 100644 index 000000000..71e287569 --- /dev/null +++ b/test/asynchronous/utils_selection_tests.py @@ -0,0 +1,203 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys +from test.asynchronous import AsyncPyMongoTestCase + +sys.path[0:0] = [""] + +from test import unittest +from test.pymongo_mocks import DummyMonitor +from test.utils import AsyncMockPool, parse_read_preference +from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) + +from bson import json_util +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology +from pymongo.common import HEARTBEAT_FREQUENCY +from pymongo.errors import AutoReconnect, ConfigurationError +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector + +_IS_SYNC = False + + +def get_topology_settings_dict(**kwargs): + settings = { + "monitor_class": DummyMonitor, + "heartbeat_frequency": HEARTBEAT_FREQUENCY, + "pool_class": AsyncMockPool, + } + settings.update(kwargs) + return settings + + +async def create_topology(scenario_def, **kwargs): + # Initialize topologies. + if "heartbeatFrequencyMS" in scenario_def: + frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0 + else: + frequency = HEARTBEAT_FREQUENCY + + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + + topology_type = get_topology_type_name(scenario_def) + if topology_type == "LoadBalanced": + kwargs.setdefault("load_balanced", True) + # Force topology description to ReplicaSet + elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]: + kwargs.setdefault("replica_set_name", "rs") + settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs) + + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. 
+ topology = Topology(TopologySettings(**settings)) + await topology.open() + + # Update topologies with server descriptions. + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Assert that descriptions match + assert ( + scenario_def["topology_description"]["type"] == topology.description.topology_type_name + ), topology.description.topology_type_name + + return topology + + +def create_test(scenario_def): + async def run_scenario(self): + _, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. + top_latency = await create_topology(scenario_def) + + # "In latency window" is defined in the server selection + # spec as the subset of suitable_servers that falls within the + # allowable latency window. + top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000) + + # Create server selector. + if scenario_def.get("operation") == "write": + pref = writable_server_selector + else: + # Make first letter lowercase to match read_pref's modes. + pref_def = scenario_def["read_preference"] + if scenario_def.get("error"): + with self.assertRaises((ConfigurationError, ValueError)): + # Error can be raised when making Read Pref or selecting. + pref = parse_read_preference(pref_def) + await top_latency.select_server(pref, _Op.TEST) + return + + pref = parse_read_preference(pref_def) + + # Select servers. 
+ if not scenario_def.get("suitable_servers"): + with self.assertRaises(AutoReconnect): + await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + if not scenario_def["in_latency_window"]: + with self.assertRaises(AutoReconnect): + await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + actual_suitable_s = await top_suitable.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + actual_latency_s = await top_latency.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + + expected_suitable_servers = {} + for server in scenario_def["suitable_servers"]: + server_description = make_server_description(server, hosts) + expected_suitable_servers[server["address"]] = server_description + + actual_suitable_servers = {} + for s in actual_suitable_s: + actual_suitable_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers)) + for k, actual in actual_suitable_servers.items(): + expected = expected_suitable_servers[k] + self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + expected_latency_servers = {} + for server in scenario_def["in_latency_window"]: + server_description = make_server_description(server, hosts) + expected_latency_servers[server["address"]] = server_description + + actual_latency_servers = {} + for s in actual_latency_s: + actual_latency_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_latency_servers), len(expected_latency_servers)) + for k, actual in actual_latency_servers.items(): + expected = expected_latency_servers[k] + 
self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + return run_scenario + + +def create_selection_tests(test_dir): + class TestAllScenarios(AsyncPyMongoTestCase): + pass + + for dirpath, _, filenames in os.walk(test_dir): + dirname = os.path.split(dirpath) + dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1] + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + return TestAllScenarios diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index d433f1a7e..11d88850f 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -229,7 +229,7 @@ class AsyncSpecTestCreator: str(test_def["description"].replace(" ", "_").replace(".", "_")), ) - new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._create_test(scenario_def, test_def, test_name) new_test = self._ensure_min_max_server_version(scenario_def, new_test) new_test = self.ensure_run_on(scenario_def, new_test) diff --git a/test/test_bulk.py b/test/test_bulk.py index 6d29ff510..6a72bddfc 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -959,7 +959,6 @@ class TestBulkWriteConcern(BulkTestBase): @client_context.require_replica_set @client_context.require_secondaries_count(1) def test_write_concern_failure_ordered(self): - self.skipTest("Skipping until PYTHON-4865 is 
resolved.") details = None # Ensure we don't raise on wnote. diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 05411d17b..810d44093 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -15,9 +15,11 @@ """Execute Transactions Spec tests.""" from __future__ import annotations +import asyncio import os import sys import time +from pathlib import Path sys.path[0:0] = [""] @@ -60,6 +62,8 @@ from pymongo.read_preferences import ReadPreference from pymongo.synchronous.pool import PoolState, _PoolClosedError from pymongo.topology_description import updated_topology_description +_IS_SYNC = True + OBJECT_TYPES = { # Event types. "ConnectionCheckedIn": ConnectionCheckedInEvent, @@ -81,7 +85,10 @@ OBJECT_TYPES = { class TestCMAP(IntegrationTest): # Location of JSON test specifications. - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring") + if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") + else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") # Test operations: @@ -258,7 +265,6 @@ class TestCMAP(IntegrationTest): client._topology.open() else: client._get_topology() - self.addCleanup(client.close) self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. @@ -315,13 +321,11 @@ class TestCMAP(IntegrationTest): # def test_1_client_connection_pool_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) client.admin.command("ping") # Discover at least one secondary. 
if client_context.has_secondaries: @@ -337,14 +341,12 @@ class TestCMAP(IntegrationTest): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{client_context.pair}/?{opts}" client = self.rs_or_single_client(uri) - self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) # Creates a new connection. @@ -368,7 +370,6 @@ class TestCMAP(IntegrationTest): def test_5_check_out_fails_connection_error(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) pool = get_pool(client) def mock_connect(*args, **kwargs): @@ -397,7 +398,6 @@ class TestCMAP(IntegrationTest): client = self.single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) - self.addCleanup(client.close) # Attempt to create a new connection. 
with self.assertRaisesRegex(OperationFailure, "failed"): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index ab8950250..47e38141b 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -16,11 +16,13 @@ """Tests for the gridfs package.""" from __future__ import annotations +import asyncio import datetime import sys import threading import time from io import BytesIO +from test.helpers import ConcurrentRunner from unittest.mock import patch sys.path[0:0] = [""] @@ -41,10 +43,12 @@ from pymongo.read_preferences import ReadPreference from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient +_IS_SYNC = True -class JustWrite(threading.Thread): + +class JustWrite(ConcurrentRunner): def __init__(self, fs, n): - threading.Thread.__init__(self) + super().__init__() self.fs = fs self.n = n self.daemon = True @@ -56,9 +60,9 @@ class JustWrite(threading.Thread): file.close() -class JustRead(threading.Thread): +class JustRead(ConcurrentRunner): def __init__(self, fs, n, results): - threading.Thread.__init__(self) + super().__init__() self.fs = fs self.n = n self.results = results @@ -98,19 +102,21 @@ class TestGridfs(IntegrationTest): def test_basic(self): oid = self.fs.put(b"hello world") - self.assertEqual(b"hello world", self.fs.get(oid).read()) + self.assertEqual(b"hello world", (self.fs.get(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) self.fs.delete(oid) - self.assertRaises(NoFile, self.fs.get, oid) + with self.assertRaises(NoFile): + self.fs.get(oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) - self.assertRaises(NoFile, self.fs.get, "foo") + with self.assertRaises(NoFile): + self.fs.get("foo") oid = self.fs.put(b"hello world", _id="foo") self.assertEqual("foo", oid) - self.assertEqual(b"hello world", self.fs.get("foo").read()) + 
self.assertEqual(b"hello world", (self.fs.get("foo")).read()) def test_multi_chunk_delete(self): self.db.fs.drop() @@ -142,7 +148,7 @@ class TestGridfs(IntegrationTest): def test_empty_file(self): oid = self.fs.put(b"") - self.assertEqual(b"", self.fs.get(oid).read()) + self.assertEqual(b"", (self.fs.get(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -159,10 +165,12 @@ class TestGridfs(IntegrationTest): self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}) try: out = self.fs.get(files_id) - self.assertRaises(CorruptGridFile, out.read) + with self.assertRaises(CorruptGridFile): + out.read() out = self.fs.get(files_id) - self.assertRaises(CorruptGridFile, out.readline) + with self.assertRaises(CorruptGridFile): + out.readline() finally: self.fs.delete(files_id) @@ -177,31 +185,33 @@ class TestGridfs(IntegrationTest): self.assertTrue( any( info.get("key") == [("files_id", 1), ("n", 1)] - for info in chunks.index_information().values() + for info in (chunks.index_information()).values() ) ) self.assertTrue( any( info.get("key") == [("filename", 1), ("uploadDate", 1)] - for info in files.index_information().values() + for info in (files.index_information()).values() ) ) def test_alt_collection(self): oid = self.alt.put(b"hello world") - self.assertEqual(b"hello world", self.alt.get(oid).read()) + self.assertEqual(b"hello world", (self.alt.get(oid)).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) self.alt.delete(oid) - self.assertRaises(NoFile, self.alt.get, oid) + with self.assertRaises(NoFile): + self.alt.get(oid) self.assertEqual(0, self.db.alt.files.count_documents({})) self.assertEqual(0, self.db.alt.chunks.count_documents({})) - self.assertRaises(NoFile, self.alt.get, "foo") + with self.assertRaises(NoFile): + self.alt.get("foo") oid = self.alt.put(b"hello world", 
_id="foo") self.assertEqual("foo", oid) - self.assertEqual(b"hello world", self.alt.get("foo").read()) + self.assertEqual(b"hello world", (self.alt.get("foo")).read()) self.alt.put(b"", filename="mike") self.alt.put(b"foo", filename="test") @@ -212,23 +222,23 @@ class TestGridfs(IntegrationTest): def test_threaded_reads(self): self.fs.put(b"hello", _id="test") - threads = [] + tasks = [] results: list = [] for i in range(10): - threads.append(JustRead(self.fs, 10, results)) - threads[i].start() + tasks.append(JustRead(self.fs, 10, results)) + tasks[i].start() - joinall(threads) + joinall(tasks) self.assertEqual(100 * [b"hello"], results) def test_threaded_writes(self): - threads = [] + tasks = [] for i in range(10): - threads.append(JustWrite(self.fs, 10)) - threads[i].start() + tasks.append(JustWrite(self.fs, 10)) + tasks[i].start() - joinall(threads) + joinall(tasks) f = self.fs.get_last_version("test") self.assertEqual(f.read(), b"hello") @@ -246,34 +256,37 @@ class TestGridfs(IntegrationTest): two = two._id three = self.fs.put(b"baz", filename="test") - self.assertEqual(b"baz", self.fs.get_last_version("test").read()) + self.assertEqual(b"baz", (self.fs.get_last_version("test")).read()) self.fs.delete(three) - self.assertEqual(b"bar", self.fs.get_last_version("test").read()) + self.assertEqual(b"bar", (self.fs.get_last_version("test")).read()) self.fs.delete(two) - self.assertEqual(b"foo", self.fs.get_last_version("test").read()) + self.assertEqual(b"foo", (self.fs.get_last_version("test")).read()) self.fs.delete(one) - self.assertRaises(NoFile, self.fs.get_last_version, "test") + with self.assertRaises(NoFile): + self.fs.get_last_version("test") def test_get_last_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author") - self.assertEqual(b"bar", self.fs.get_last_version(author="author").read()) + self.assertEqual(b"bar", 
(self.fs.get_last_version(author="author")).read()) self.fs.delete(two) - self.assertEqual(b"foo", self.fs.get_last_version(author="author").read()) + self.assertEqual(b"foo", (self.fs.get_last_version(author="author")).read()) self.fs.delete(one) one = self.fs.put(b"foo", filename="test", author="author1") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author2") - self.assertEqual(b"foo", self.fs.get_last_version(author="author1").read()) - self.assertEqual(b"bar", self.fs.get_last_version(author="author2").read()) - self.assertEqual(b"bar", self.fs.get_last_version(filename="test").read()) + self.assertEqual(b"foo", (self.fs.get_last_version(author="author1")).read()) + self.assertEqual(b"bar", (self.fs.get_last_version(author="author2")).read()) + self.assertEqual(b"bar", (self.fs.get_last_version(filename="test")).read()) - self.assertRaises(NoFile, self.fs.get_last_version, author="author3") - self.assertRaises(NoFile, self.fs.get_last_version, filename="nottest", author="author1") + with self.assertRaises(NoFile): + self.fs.get_last_version(author="author3") + with self.assertRaises(NoFile): + self.fs.get_last_version(filename="nottest", author="author1") self.fs.delete(one) self.fs.delete(two) @@ -286,16 +299,18 @@ class TestGridfs(IntegrationTest): self.fs.put(b"baz", filename="test") time.sleep(0.01) - self.assertEqual(b"foo", self.fs.get_version("test", 0).read()) - self.assertEqual(b"bar", self.fs.get_version("test", 1).read()) - self.assertEqual(b"baz", self.fs.get_version("test", 2).read()) + self.assertEqual(b"foo", (self.fs.get_version("test", 0)).read()) + self.assertEqual(b"bar", (self.fs.get_version("test", 1)).read()) + self.assertEqual(b"baz", (self.fs.get_version("test", 2)).read()) - self.assertEqual(b"baz", self.fs.get_version("test", -1).read()) - self.assertEqual(b"bar", self.fs.get_version("test", -2).read()) - self.assertEqual(b"foo", self.fs.get_version("test", -3).read()) + self.assertEqual(b"baz", 
(self.fs.get_version("test", -1)).read()) + self.assertEqual(b"bar", (self.fs.get_version("test", -2)).read()) + self.assertEqual(b"foo", (self.fs.get_version("test", -3)).read()) - self.assertRaises(NoFile, self.fs.get_version, "test", 3) - self.assertRaises(NoFile, self.fs.get_version, "test", -4) + with self.assertRaises(NoFile): + self.fs.get_version("test", 3) + with self.assertRaises(NoFile): + self.fs.get_version("test", -4) def test_get_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author1") @@ -305,25 +320,32 @@ class TestGridfs(IntegrationTest): three = self.fs.put(b"baz", filename="test", author="author2") self.assertEqual( - b"foo", self.fs.get_version(filename="test", author="author1", version=-2).read() + b"foo", + (self.fs.get_version(filename="test", author="author1", version=-2)).read(), ) self.assertEqual( - b"bar", self.fs.get_version(filename="test", author="author1", version=-1).read() + b"bar", + (self.fs.get_version(filename="test", author="author1", version=-1)).read(), ) self.assertEqual( - b"foo", self.fs.get_version(filename="test", author="author1", version=0).read() + b"foo", + (self.fs.get_version(filename="test", author="author1", version=0)).read(), ) self.assertEqual( - b"bar", self.fs.get_version(filename="test", author="author1", version=1).read() + b"bar", + (self.fs.get_version(filename="test", author="author1", version=1)).read(), ) self.assertEqual( - b"baz", self.fs.get_version(filename="test", author="author2", version=0).read() + b"baz", + (self.fs.get_version(filename="test", author="author2", version=0)).read(), ) - self.assertEqual(b"baz", self.fs.get_version(filename="test", version=-1).read()) - self.assertEqual(b"baz", self.fs.get_version(filename="test", version=2).read()) + self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=-1)).read()) + self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=2)).read()) - self.assertRaises(NoFile, 
self.fs.get_version, filename="test", author="author3") - self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author1", version=2) + with self.assertRaises(NoFile): + self.fs.get_version(filename="test", author="author3") + with self.assertRaises(NoFile): + self.fs.get_version(filename="test", author="author1", version=2) self.fs.delete(one) self.fs.delete(two) @@ -332,11 +354,12 @@ class TestGridfs(IntegrationTest): def test_put_filelike(self): oid = self.fs.put(BytesIO(b"hello world"), chunk_size=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) - self.assertEqual(b"hello world", self.fs.get(oid).read()) + self.assertEqual(b"hello world", (self.fs.get(oid)).read()) def test_file_exists(self): oid = self.fs.put(b"hello") - self.assertRaises(FileExists, self.fs.put, b"world", _id=oid) + with self.assertRaises(FileExists): + self.fs.put(b"world", _id=oid) one = self.fs.new_file(_id=123) one.write(b"some content") @@ -345,15 +368,17 @@ class TestGridfs(IntegrationTest): # Attempt to upload a file with more chunks to the same _id. with patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE): two = self.fs.new_file(_id=123) - self.assertRaises(FileExists, two.write, b"x" * DEFAULT_CHUNK_SIZE * 3) + with self.assertRaises(FileExists): + two.write(b"x" * DEFAULT_CHUNK_SIZE * 3) # Original file is still readable (no extra chunks were uploaded). - self.assertEqual(self.fs.get(123).read(), b"some content") + self.assertEqual((self.fs.get(123)).read(), b"some content") two = self.fs.new_file(_id=123) two.write(b"some content") - self.assertRaises(FileExists, two.close) + with self.assertRaises(FileExists): + two.close() # Original file is still readable. 
- self.assertEqual(self.fs.get(123).read(), b"some content") + self.assertEqual((self.fs.get(123)).read(), b"some content") def test_exists(self): oid = self.fs.put(b"hello") @@ -381,15 +406,16 @@ class TestGridfs(IntegrationTest): self.assertFalse(self.fs.exists({"foo": {"$gt": 12}})) def test_put_unicode(self): - self.assertRaises(TypeError, self.fs.put, "hello") + with self.assertRaises(TypeError): + self.fs.put("hello") oid = self.fs.put("hello", encoding="utf-8") - self.assertEqual(b"hello", self.fs.get(oid).read()) - self.assertEqual("utf-8", self.fs.get(oid).encoding) + self.assertEqual(b"hello", (self.fs.get(oid)).read()) + self.assertEqual("utf-8", (self.fs.get(oid)).encoding) oid = self.fs.put("aé", encoding="iso-8859-1") - self.assertEqual("aé".encode("iso-8859-1"), self.fs.get(oid).read()) - self.assertEqual("iso-8859-1", self.fs.get(oid).encoding) + self.assertEqual("aé".encode("iso-8859-1"), (self.fs.get(oid)).read()) + self.assertEqual("iso-8859-1", (self.fs.get(oid)).encoding) def test_missing_length_iter(self): # Test fix that guards against PHP-237 @@ -411,11 +437,13 @@ class TestGridfs(IntegrationTest): client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10) db = client.db gfs = gridfs.GridFS(db) - self.assertRaises(ServerSelectionTimeoutError, gfs.list) + with self.assertRaises(ServerSelectionTimeoutError): + gfs.list() fs = gridfs.GridFS(db) f = fs.new_file() - self.assertRaises(ServerSelectionTimeoutError, f.close) + with self.assertRaises(ServerSelectionTimeoutError): + f.close() def test_gridfs_find(self): self.fs.put(b"test2", filename="two") @@ -429,14 +457,15 @@ class TestGridfs(IntegrationTest): self.assertEqual(3, files.count_documents({"filename": "two"})) self.assertEqual(4, files.count_documents({})) cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) cursor.rewind() - gout = next(cursor) 
+ gout = cursor.next() self.assertEqual(b"test1", gout.read()) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test2+", gout.read()) - self.assertRaises(StopIteration, cursor.__next__) + with self.assertRaises(StopIteration): + cursor.__next__() cursor.rewind() items = cursor.to_list() self.assertEqual(len(items), 2) @@ -484,12 +513,12 @@ class TestGridfs(IntegrationTest): self.fs.put(data, filename="f") self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) - self.assertEqual(data, self.fs.get_version("f").read()) + self.assertEqual(data, (self.fs.get_version("f")).read()) def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test) + gridfs.GridFS((self.rs_or_single_client(w=0)).pymongo_test) def test_md5(self): gin = self.fs.new_file() @@ -524,7 +553,7 @@ class TestGridfsReplicaSet(IntegrationTest): self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) oid = fs.put(b"foo") - content = fs.get(oid).read() + content = (fs.get(oid)).read() self.assertEqual(b"foo", content) def test_gridfs_secondary(self): @@ -538,7 +567,8 @@ class TestGridfsReplicaSet(IntegrationTest): fs = gridfs.GridFS(secondary_connection.gfsreplica, "gfssecondarytest") # This won't detect secondary, raises error - self.assertRaises(NotPrimaryError, fs.put, b"foo") + with self.assertRaises(NotPrimaryError): + fs.put(b"foo") def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to @@ -552,8 +582,10 @@ class TestGridfsReplicaSet(IntegrationTest): fs = gridfs.GridFS(client.gfsreplica, "gfssecondarylazytest") # Connects, doesn't create index. 
- self.assertRaises(NoFile, fs.get_last_version) - self.assertRaises(NotPrimaryError, fs.put, "data", encoding="utf-8") + with self.assertRaises(NoFile): + fs.get_last_version() + with self.assertRaises(NotPrimaryError): + fs.put("data", encoding="utf-8") if __name__ == "__main__": diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 0af4dce81..e7486cb23 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -16,12 +16,14 @@ """Tests for the gridfs package.""" from __future__ import annotations +import asyncio import datetime import itertools import sys import threading import time from io import BytesIO +from test.helpers import ConcurrentRunner from unittest.mock import patch sys.path[0:0] = [""] @@ -44,10 +46,12 @@ from pymongo.errors import ( from pymongo.read_preferences import ReadPreference from pymongo.synchronous.mongo_client import MongoClient +_IS_SYNC = True -class JustWrite(threading.Thread): + +class JustWrite(ConcurrentRunner): def __init__(self, gfs, num): - threading.Thread.__init__(self) + super().__init__() self.gfs = gfs self.num = num self.daemon = True @@ -59,9 +63,9 @@ class JustWrite(threading.Thread): file.close() -class JustRead(threading.Thread): +class JustRead(ConcurrentRunner): def __init__(self, gfs, num, results): - threading.Thread.__init__(self) + super().__init__() self.gfs = gfs self.num = num self.results = results @@ -89,12 +93,13 @@ class TestGridfs(IntegrationTest): def test_basic(self): oid = self.fs.upload_from_stream("test_filename", b"hello world") - self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) self.fs.delete(oid) - self.assertRaises(NoFile, self.fs.open_download_stream, oid) + with self.assertRaises(NoFile): + self.fs.open_download_stream(oid) 
self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -111,7 +116,7 @@ class TestGridfs(IntegrationTest): def test_empty_file(self): oid = self.fs.upload_from_stream("test_filename", b"") - self.assertEqual(b"", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"", (self.fs.open_download_stream(oid)).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -128,10 +133,12 @@ class TestGridfs(IntegrationTest): self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}) try: out = self.fs.open_download_stream(files_id) - self.assertRaises(CorruptGridFile, out.read) + with self.assertRaises(CorruptGridFile): + out.read() out = self.fs.open_download_stream(files_id) - self.assertRaises(CorruptGridFile, out.readline) + with self.assertRaises(CorruptGridFile): + out.readline() finally: self.fs.delete(files_id) @@ -146,13 +153,13 @@ class TestGridfs(IntegrationTest): self.assertTrue( any( info.get("key") == [("files_id", 1), ("n", 1)] - for info in chunks.index_information().values() + for info in (chunks.index_information()).values() ) ) self.assertTrue( any( info.get("key") == [("filename", 1), ("uploadDate", 1)] - for info in files.index_information().values() + for info in (files.index_information()).values() ) ) @@ -174,25 +181,27 @@ class TestGridfs(IntegrationTest): self.assertTrue( any( info.get("key") == [("filename", 1), ("uploadDate", 1)] - for info in files.index_information().values() + for info in (files.index_information()).values() ) ) files.drop() def test_alt_collection(self): oid = self.alt.upload_from_stream("test_filename", b"hello world") - self.assertEqual(b"hello world", self.alt.open_download_stream(oid).read()) + self.assertEqual(b"hello world", (self.alt.open_download_stream(oid)).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, 
self.db.alt.chunks.count_documents({})) self.alt.delete(oid) - self.assertRaises(NoFile, self.alt.open_download_stream, oid) + with self.assertRaises(NoFile): + self.alt.open_download_stream(oid) self.assertEqual(0, self.db.alt.files.count_documents({})) self.assertEqual(0, self.db.alt.chunks.count_documents({})) - self.assertRaises(NoFile, self.alt.open_download_stream, "foo") + with self.assertRaises(NoFile): + self.alt.open_download_stream("foo") self.alt.upload_from_stream("foo", b"hello world") - self.assertEqual(b"hello world", self.alt.open_download_stream_by_name("foo").read()) + self.assertEqual(b"hello world", (self.alt.open_download_stream_by_name("foo")).read()) self.alt.upload_from_stream("mike", b"") self.alt.upload_from_stream("test", b"foo") @@ -200,7 +209,7 @@ class TestGridfs(IntegrationTest): self.assertEqual( {"mike", "test", "hello world", "foo"}, - {k["filename"] for k in list(self.db.alt.files.find())}, + {k["filename"] for k in self.db.alt.files.find().to_list()}, ) def test_threaded_reads(self): @@ -240,13 +249,14 @@ class TestGridfs(IntegrationTest): two = two._id three = self.fs.upload_from_stream("test", b"baz") - self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test")).read()) self.fs.delete(three) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test")).read()) self.fs.delete(two) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test")).read()) self.fs.delete(one) - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test") + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("test") def test_get_version(self): self.fs.upload_from_stream("test", b"foo") @@ -256,28 +266,30 @@ class TestGridfs(IntegrationTest): 
self.fs.upload_from_stream("test", b"baz") time.sleep(0.01) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=0).read()) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=1).read()) - self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=2).read()) + self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=0)).read()) + self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=1)).read()) + self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=2)).read()) - self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=-1).read()) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=-2).read()) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=-3).read()) + self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=-1)).read()) + self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=-2)).read()) + self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=-3)).read()) - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=3) - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=-4) + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("test", revision=3) + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("test", revision=-4) def test_upload_from_stream(self): oid = self.fs.upload_from_stream("test_file", BytesIO(b"hello world"), chunk_size_bytes=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) - self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read()) def test_upload_from_stream_with_id(self): oid = ObjectId() self.fs.upload_from_stream_with_id( oid, 
"test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1 ) - self.assertEqual(b"custom id", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"custom id", (self.fs.open_download_stream(oid)).read()) @patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3) @client_context.require_failCommand_fail_point @@ -316,14 +328,14 @@ class TestGridfs(IntegrationTest): gin = self.fs.open_upload_stream("from_stream") gin.write(b"from stream") gin.close() - self.assertEqual(b"from stream", self.fs.open_download_stream(gin._id).read()) + self.assertEqual(b"from stream", (self.fs.open_download_stream(gin._id)).read()) def test_open_upload_stream_with_id(self): oid = ObjectId() gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id") gin.write(b"from stream with custom id") gin.close() - self.assertEqual(b"from stream with custom id", self.fs.open_download_stream(oid).read()) + self.assertEqual(b"from stream with custom id", (self.fs.open_download_stream(oid)).read()) def test_missing_length_iter(self): # Test fix that guards against PHP-237 @@ -345,12 +357,12 @@ class TestGridfs(IntegrationTest): client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0) cdb = client.db gfs = gridfs.GridFSBucket(cdb) - self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0) + with self.assertRaises(ServerSelectionTimeoutError): + gfs.delete(0) gfs = gridfs.GridFSBucket(cdb) - self.assertRaises( - ServerSelectionTimeoutError, gfs.upload_from_stream, "test", b"" - ) # Still no connection. + with self.assertRaises(ServerSelectionTimeoutError): + gfs.upload_from_stream("test", b"") # Still no connection. 
def test_gridfs_find(self): self.fs.upload_from_stream("two", b"test2") @@ -366,14 +378,15 @@ class TestGridfs(IntegrationTest): cursor = self.fs.find( {}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2 ) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) cursor.rewind() - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test1", gout.read()) - gout = next(cursor) + gout = cursor.next() self.assertEqual(b"test2+", gout.read()) - self.assertRaises(StopIteration, cursor.__next__) + with self.assertRaises(StopIteration): + cursor.next() cursor.close() self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) @@ -383,20 +396,21 @@ class TestGridfs(IntegrationTest): self.fs.upload_from_stream("f", data) self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) - self.assertEqual(data, self.fs.open_download_stream_by_name("f").read()) + self.assertEqual(data, (self.fs.open_download_stream_by_name("f")).read()) def test_unacknowledged(self): # w=0 is prohibited. 
with self.assertRaises(ConfigurationError): - gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test) + gridfs.GridFSBucket((self.rs_or_single_client(w=0)).pymongo_test) def test_rename(self): _id = self.fs.upload_from_stream("first_name", b"testing") - self.assertEqual(b"testing", self.fs.open_download_stream_by_name("first_name").read()) + self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("first_name")).read()) self.fs.rename(_id, "second_name") - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "first_name") - self.assertEqual(b"testing", self.fs.open_download_stream_by_name("second_name").read()) + with self.assertRaises(NoFile): + self.fs.open_download_stream_by_name("first_name") + self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("second_name")).read()) @patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", 5) def test_abort(self): @@ -407,7 +421,8 @@ class TestGridfs(IntegrationTest): self.assertEqual(3, self.db.fs.chunks.count_documents({"files_id": gin._id})) gin.abort() self.assertTrue(gin.closed) - self.assertRaises(ValueError, gin.write, b"test4") + with self.assertRaises(ValueError): + gin.write(b"test4") self.assertEqual(0, self.db.fs.chunks.count_documents({"files_id": gin._id})) def test_download_to_stream(self): @@ -490,7 +505,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest): gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") oid = gfs.upload_from_stream("test_filename", b"foo") - content = gfs.open_download_stream(oid).read() + content = (gfs.open_download_stream(oid)).read() self.assertEqual(b"foo", content) def test_gridfs_secondary(self): @@ -504,7 +519,8 @@ class TestGridfsBucketReplicaSet(IntegrationTest): gfs = gridfs.GridFSBucket(secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest") # This won't detect secondary, raises error - self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"foo") + with 
self.assertRaises(NotPrimaryError): + gfs.upload_from_stream("test_filename", b"foo") def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to @@ -518,8 +534,10 @@ class TestGridfsBucketReplicaSet(IntegrationTest): gfs = gridfs.GridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest") # Connects, doesn't create index. - self.assertRaises(NoFile, gfs.open_download_stream_by_name, "test_filename") - self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"data") + with self.assertRaises(NoFile): + gfs.open_download_stream_by_name("test_filename") + with self.assertRaises(NotPrimaryError): + gfs.upload_from_stream("test_filename", b"data") if __name__ == "__main__": diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 984b967f5..3e7f9a867 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -17,6 +17,7 @@ from __future__ import annotations import os import sys +from pathlib import Path from pymongo import MongoClient, ReadPreference from pymongo.errors import ServerSelectionTimeoutError @@ -43,11 +44,17 @@ from test.utils_selection_tests import ( make_server_description, ) +_IS_SYNC = True + # Location of JSON test specifications. 
-_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.path.join("server_selection", "server_selection"), -) +if _IS_SYNC: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent, "server_selection", "server_selection" + ) +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "server_selection" + ) class SelectionStoreSelector: @@ -61,7 +68,7 @@ class SelectionStoreSelector: return selection -class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore pass @@ -79,13 +86,12 @@ class TestCustomServerSelectorFunction(IntegrationTest): client = self.rs_or_single_client( server_selector=custom_selector, event_listeners=[listener] ) - self.addCleanup(client.close) coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll self.addCleanup(client.drop_database, "testdb") # Wait the node list to be fully populated. def all_hosts_started(): - return len(client.admin.command(HelloCompat.LEGACY_CMD)["hosts"]) == len( + return len((client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len( client._topology._description.readable_servers ) @@ -121,7 +127,6 @@ class TestCustomServerSelectorFunction(IntegrationTest): # Client setup. mongo_client = self.rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection - self.addCleanup(mongo_client.close) self.addCleanup(mongo_client.drop_database, "testdb") # Do N operations and test selector is called at least N times. 
diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 05772fa38..7ccd4b529 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -15,9 +15,12 @@ """Test the topology module's Server Selection Spec implementation.""" from __future__ import annotations +import asyncio import os import threading +from pathlib import Path from test import IntegrationTest, client_context, unittest +from test.helpers import ConcurrentRunner from test.utils import ( CMAPListener, OvertCommandListener, @@ -32,10 +35,14 @@ from pymongo.monitoring import ConnectionReadyEvent from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference +_IS_SYNC = True # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window") -) +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) class TestAllScenarios(unittest.TestCase): @@ -92,7 +99,7 @@ class CustomSpecTestCreator(SpecTestCreator): CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() -class FinderThread(threading.Thread): +class FinderTask(ConcurrentRunner): def __init__(self, collection, iterations): super().__init__() self.daemon = True @@ -109,17 +116,17 @@ class FinderThread(threading.Thread): class TestProse(IntegrationTest): def frequencies(self, client, listener, n_finds=10): coll = client.test.test - N_THREADS = 10 - threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - for thread in threads: - self.assertTrue(thread.passed) + N_TASKS = 10 + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + task.start() + for 
task in tasks: + task.join() + for task in tasks: + self.assertTrue(task.passed) events = listener.started_events - self.assertEqual(len(events), n_finds * N_THREADS) + self.assertEqual(len(events), n_finds * N_TASKS) nodes = client.nodes self.assertEqual(len(nodes), 2) freqs = {address: 0.0 for address in nodes} diff --git a/test/utils.py b/test/utils.py index 5c1e0bfb7..e089b3fc2 100644 --- a/test/utils.py +++ b/test/utils.py @@ -666,6 +666,11 @@ def joinall(threads): assert not t.is_alive(), "Thread %s hung" % t +async def async_joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + await asyncio.wait([t.task for t in tasks if t is not None], timeout=300) + + def wait_until(predicate, success_description, timeout=10): """Wait up to 10 seconds (by default) for predicate to be true. @@ -827,7 +832,7 @@ async def async_get_pools(client): """Get all pools.""" return [ server.pool - async for server in await (await client._get_topology()).select_servers( + for server in await (await client._get_topology()).select_servers( any_server_selector, _Op.TEST ) ] diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 2d21888e2..9667ea701 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -18,96 +18,28 @@ from __future__ import annotations import datetime import os import sys +from test import PyMongoTestCase sys.path[0:0] = [""] from test import unittest from test.pymongo_mocks import DummyMonitor from test.utils import MockPool, parse_read_preference +from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) from bson import json_util -from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.common import HEARTBEAT_FREQUENCY from pymongo.errors import AutoReconnect, ConfigurationError -from pymongo.hello import Hello, HelloCompat from pymongo.operations import _Op -from 
pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology - -def get_addresses(server_list): - seeds = [] - hosts = [] - for server in server_list: - seeds.append(clean_node(server["address"])) - hosts.append(server["address"]) - return seeds, hosts - - -def make_last_write_date(server): - epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) - millis = server.get("lastWrite", {}).get("lastWriteDate") - if millis: - diff = ((millis % 1000) + 1000) % 1000 - seconds = (millis - diff) / 1000 - micros = diff * 1000 - return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) - else: - # "Unknown" server. - return epoch - - -def make_server_description(server, hosts): - """Make a ServerDescription from server info in a JSON test.""" - server_type = server["type"] - if server_type in ("Unknown", "PossiblePrimary"): - return ServerDescription(clean_node(server["address"]), Hello({})) - - hello_response = {"ok": True, "hosts": hosts} - if server_type not in ("Standalone", "Mongos", "RSGhost"): - hello_response["setName"] = "rs" - - if server_type == "RSPrimary": - hello_response[HelloCompat.LEGACY_CMD] = True - elif server_type == "RSSecondary": - hello_response["secondary"] = True - elif server_type == "Mongos": - hello_response["msg"] = "isdbgrid" - elif server_type == "RSGhost": - hello_response["isreplicaset"] = True - elif server_type == "RSArbiter": - hello_response["arbiterOnly"] = True - - hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} - - for field in "maxWireVersion", "tags", "idleWritePeriodMillis": - if field in server: - hello_response[field] = server[field] - - hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) - - # Sets _last_update_time to now. 
- sd = ServerDescription( - clean_node(server["address"]), - Hello(hello_response), - round_trip_time=server["avg_rtt_ms"] / 1000.0, - ) - - if "lastUpdateTime" in server: - sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. - - return sd - - -def get_topology_type_name(scenario_def): - td = scenario_def["topology_description"] - name = td["type"] - if name == "Unknown": - # PyMongo never starts a topology in type Unknown. - return "Sharded" if len(td["servers"]) > 1 else "Single" - else: - return name +_IS_SYNC = True def get_topology_settings_dict(**kwargs): @@ -244,7 +176,7 @@ def create_test(scenario_def): def create_selection_tests(test_dir): - class TestAllScenarios(unittest.TestCase): + class TestAllScenarios(PyMongoTestCase): pass for dirpath, _, filenames in os.walk(test_dir): diff --git a/test/utils_selection_tests_shared.py b/test/utils_selection_tests_shared.py new file mode 100644 index 000000000..dbaed1034 --- /dev/null +++ b/test/utils_selection_tests_shared.py @@ -0,0 +1,100 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys + +sys.path[0:0] = [""] + +from pymongo.common import MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.hello import Hello, HelloCompat +from pymongo.server_description import ServerDescription + + +def get_addresses(server_list): + seeds = [] + hosts = [] + for server in server_list: + seeds.append(clean_node(server["address"])) + hosts.append(server["address"]) + return seeds, hosts + + +def make_last_write_date(server): + epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) + millis = server.get("lastWrite", {}).get("lastWriteDate") + if millis: + diff = ((millis % 1000) + 1000) % 1000 + seconds = (millis - diff) / 1000 + micros = diff * 1000 + return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) + else: + # "Unknown" server. + return epoch + + +def make_server_description(server, hosts): + """Make a ServerDescription from server info in a JSON test.""" + server_type = server["type"] + if server_type in ("Unknown", "PossiblePrimary"): + return ServerDescription(clean_node(server["address"]), Hello({})) + + hello_response = {"ok": True, "hosts": hosts} + if server_type not in ("Standalone", "Mongos", "RSGhost"): + hello_response["setName"] = "rs" + + if server_type == "RSPrimary": + hello_response[HelloCompat.LEGACY_CMD] = True + elif server_type == "RSSecondary": + hello_response["secondary"] = True + elif server_type == "Mongos": + hello_response["msg"] = "isdbgrid" + elif server_type == "RSGhost": + hello_response["isreplicaset"] = True + elif server_type == "RSArbiter": + hello_response["arbiterOnly"] = True + + hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} + + for field in "maxWireVersion", "tags", "idleWritePeriodMillis": + if field in server: + hello_response[field] = server[field] + + 
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) + + # Sets _last_update_time to now. + sd = ServerDescription( + clean_node(server["address"]), + Hello(hello_response), + round_trip_time=server["avg_rtt_ms"] / 1000.0, + ) + + if "lastUpdateTime" in server: + sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. + + return sd + + +def get_topology_type_name(scenario_def): + td = scenario_def["topology_description"] + name = td["type"] + if name == "Unknown": + # PyMongo never starts a topology in type Unknown. + return "Sharded" if len(td["servers"]) > 1 else "Single" + else: + return name diff --git a/tools/synchro.py b/tools/synchro.py index 7e7aeec3a..69a2f07ba 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -122,7 +122,9 @@ replacements = { "SpecRunnerTask": "SpecRunnerThread", "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool", + "StopAsyncIteration": "StopIteration", "create_async_event": "create_event", + "async_joinall": "joinall", } docstring_replacements: dict[tuple[str, str], str] = { @@ -169,7 +171,7 @@ gridfs_files = [ def async_only_test(f: str) -> bool: """Return True for async tests that should not be converted to sync.""" - return f in ["test_locks.py", "test_concurrency.py"] + return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"] test_files = [ @@ -202,6 +204,7 @@ converted_tests = [ "test_comment.py", "test_common.py", "test_connection_logging.py", + "test_connection_monitoring.py", "test_connections_survive_primary_stepdown_spec.py", "test_create_entities.py", "test_crud_unified.py", @@ -212,12 +215,14 @@ converted_tests = [ "test_dns.py", "test_encryption.py", "test_examples.py", + "test_grid_file.py", + "test_gridfs.py", + "test_gridfs_bucket.py", + "test_gridfs_spec.py", "test_heartbeat_monitoring.py", "test_index_management.py", - "test_grid_file.py", - "test_load_balancer.py", "test_json_util_integration.py", - "test_gridfs_spec.py", + 
"test_load_balancer.py", "test_logger.py", "test_max_staleness.py", "test_monitoring.py", @@ -233,9 +238,11 @@ converted_tests = [ "test_retryable_writes_unified.py", "test_run_command.py", "test_sdam_monitoring_spec.py", + "test_server_selection.py", + "test_server_selection_in_window.py", "test_server_selection_logging.py", - "test_session.py", "test_server_selection_rtt.py", + "test_session.py", "test_sessions_unified.py", "test_srv_polling.py", "test_ssl.py", @@ -245,6 +252,7 @@ converted_tests = [ "test_unified_format.py", "test_versioned_api_integration.py", "unified_format.py", + "utils_selection_tests.py", ] diff --git a/uv.lock b/uv.lock index e7f09f66f..a2e951e76 100644 --- a/uv.lock +++ b/uv.lock @@ -997,7 +997,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/07/e9/ae44ea7d7605df9e5 [[package]] name = "pymongo" -version = "4.11.0.dev0" +version = "4.12.0.dev0" source = { editable = "." } dependencies = [ { name = "dnspython" },