Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-02-12 10:40:39 -06:00
commit e296cf9f1b
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
43 changed files with 3399 additions and 333 deletions

View File

@ -227,6 +227,186 @@ tasks:
- noauth
- nossl
- sync_async
- name: test-4.2-standalone-auth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- standalone
- auth
- ssl
- sync
- name: test-4.2-standalone-auth-ssl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- standalone
- auth
- ssl
- async
- name: test-4.2-standalone-auth-ssl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- standalone
- auth
- ssl
- sync_async
- name: test-4.2-standalone-noauth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- standalone
- noauth
- ssl
- sync
- name: test-4.2-standalone-noauth-ssl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- standalone
- noauth
- ssl
- async
- name: test-4.2-standalone-noauth-ssl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- standalone
- noauth
- ssl
- sync_async
- name: test-4.2-standalone-noauth-nossl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- standalone
- noauth
- nossl
- sync
- name: test-4.2-standalone-noauth-nossl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- standalone
- noauth
- nossl
- async
- name: test-4.2-standalone-noauth-nossl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: server
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- standalone
- noauth
- nossl
- sync_async
- name: test-4.4-standalone-auth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
@ -1667,6 +1847,186 @@ tasks:
- noauth
- nossl
- sync_async
- name: test-4.2-replica_set-auth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- replica_set
- auth
- ssl
- sync
- name: test-4.2-replica_set-auth-ssl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- replica_set
- auth
- ssl
- async
- name: test-4.2-replica_set-auth-ssl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- replica_set
- auth
- ssl
- sync_async
- name: test-4.2-replica_set-noauth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- replica_set
- noauth
- ssl
- sync
- name: test-4.2-replica_set-noauth-ssl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- replica_set
- noauth
- ssl
- async
- name: test-4.2-replica_set-noauth-ssl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- replica_set
- noauth
- ssl
- sync_async
- name: test-4.2-replica_set-noauth-nossl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- replica_set
- noauth
- nossl
- sync
- name: test-4.2-replica_set-noauth-nossl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- replica_set
- noauth
- nossl
- async
- name: test-4.2-replica_set-noauth-nossl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: replica_set
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- replica_set
- noauth
- nossl
- sync_async
- name: test-4.4-replica_set-auth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
@ -3107,6 +3467,186 @@ tasks:
- noauth
- nossl
- sync_async
- name: test-4.2-sharded_cluster-auth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- sharded_cluster
- auth
- ssl
- sync
- name: test-4.2-sharded_cluster-auth-ssl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- sharded_cluster
- auth
- ssl
- async
- name: test-4.2-sharded_cluster-auth-ssl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: auth
SSL: ssl
- func: run tests
vars:
AUTH: auth
SSL: ssl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- sharded_cluster
- auth
- ssl
- sync_async
- name: test-4.2-sharded_cluster-noauth-ssl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- sharded_cluster
- noauth
- ssl
- sync
- name: test-4.2-sharded_cluster-noauth-ssl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- sharded_cluster
- noauth
- ssl
- async
- name: test-4.2-sharded_cluster-noauth-ssl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: noauth
SSL: ssl
- func: run tests
vars:
AUTH: noauth
SSL: ssl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- sharded_cluster
- noauth
- ssl
- sync_async
- name: test-4.2-sharded_cluster-noauth-nossl-sync
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: sync
TEST_SUITES: default
tags:
- "4.2"
- sharded_cluster
- noauth
- nossl
- sync
- name: test-4.2-sharded_cluster-noauth-nossl-async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: async
TEST_SUITES: default_async
tags:
- "4.2"
- sharded_cluster
- noauth
- nossl
- async
- name: test-4.2-sharded_cluster-noauth-nossl-sync_async
commands:
- func: bootstrap mongo-orchestration
vars:
VERSION: "4.2"
TOPOLOGY: sharded_cluster
AUTH: noauth
SSL: nossl
- func: run tests
vars:
AUTH: noauth
SSL: nossl
SYNC: sync_async
TEST_SUITES: ""
tags:
- "4.2"
- sharded_cluster
- noauth
- nossl
- sync_async
- name: test-4.4-sharded_cluster-auth-ssl-sync
commands:
- func: bootstrap mongo-orchestration

View File

@ -817,10 +817,23 @@ buildvariants:
PYTHON_BINARY: /opt/python/3.13/bin/python3
# Ocsp tests
- name: ocsp-rhel8-v4.4-python3.9
- name: ocsp-rhel8-v4.2-python3.9
tasks:
- name: .ocsp
display_name: OCSP RHEL8 v4.4 Python3.9
display_name: OCSP RHEL8 v4.2 Python3.9
run_on:
- rhel87-small
batchtime: 20160
expansions:
AUTH: noauth
SSL: ssl
TOPOLOGY: server
VERSION: "4.2"
PYTHON_BINARY: /opt/python/3.9/bin/python3
- name: ocsp-rhel8-v4.4-python3.10
tasks:
- name: .ocsp
display_name: OCSP RHEL8 v4.4 Python3.10
run_on:
- rhel87-small
batchtime: 20160
@ -829,11 +842,11 @@ buildvariants:
SSL: ssl
TOPOLOGY: server
VERSION: "4.4"
PYTHON_BINARY: /opt/python/3.9/bin/python3
- name: ocsp-rhel8-v5.0-python3.10
PYTHON_BINARY: /opt/python/3.10/bin/python3
- name: ocsp-rhel8-v5.0-python3.11
tasks:
- name: .ocsp
display_name: OCSP RHEL8 v5.0 Python3.10
display_name: OCSP RHEL8 v5.0 Python3.11
run_on:
- rhel87-small
batchtime: 20160
@ -842,11 +855,11 @@ buildvariants:
SSL: ssl
TOPOLOGY: server
VERSION: "5.0"
PYTHON_BINARY: /opt/python/3.10/bin/python3
- name: ocsp-rhel8-v6.0-python3.11
PYTHON_BINARY: /opt/python/3.11/bin/python3
- name: ocsp-rhel8-v6.0-python3.12
tasks:
- name: .ocsp
display_name: OCSP RHEL8 v6.0 Python3.11
display_name: OCSP RHEL8 v6.0 Python3.12
run_on:
- rhel87-small
batchtime: 20160
@ -855,11 +868,11 @@ buildvariants:
SSL: ssl
TOPOLOGY: server
VERSION: "6.0"
PYTHON_BINARY: /opt/python/3.11/bin/python3
- name: ocsp-rhel8-v7.0-python3.12
PYTHON_BINARY: /opt/python/3.12/bin/python3
- name: ocsp-rhel8-v7.0-python3.13
tasks:
- name: .ocsp
display_name: OCSP RHEL8 v7.0 Python3.12
display_name: OCSP RHEL8 v7.0 Python3.13
run_on:
- rhel87-small
batchtime: 20160
@ -868,11 +881,11 @@ buildvariants:
SSL: ssl
TOPOLOGY: server
VERSION: "7.0"
PYTHON_BINARY: /opt/python/3.12/bin/python3
- name: ocsp-rhel8-v8.0-python3.13
PYTHON_BINARY: /opt/python/3.13/bin/python3
- name: ocsp-rhel8-v8.0-pypy3.10
tasks:
- name: .ocsp
display_name: OCSP RHEL8 v8.0 Python3.13
display_name: OCSP RHEL8 v8.0 PyPy3.10
run_on:
- rhel87-small
batchtime: 20160
@ -881,11 +894,11 @@ buildvariants:
SSL: ssl
TOPOLOGY: server
VERSION: "8.0"
PYTHON_BINARY: /opt/python/3.13/bin/python3
- name: ocsp-rhel8-rapid-pypy3.10
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
- name: ocsp-rhel8-rapid-python3.9
tasks:
- name: .ocsp
display_name: OCSP RHEL8 rapid PyPy3.10
display_name: OCSP RHEL8 rapid Python3.9
run_on:
- rhel87-small
batchtime: 20160
@ -894,11 +907,11 @@ buildvariants:
SSL: ssl
TOPOLOGY: server
VERSION: rapid
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
- name: ocsp-rhel8-latest-python3.9
PYTHON_BINARY: /opt/python/3.9/bin/python3
- name: ocsp-rhel8-latest-python3.10
tasks:
- name: .ocsp
display_name: OCSP RHEL8 latest Python3.9
display_name: OCSP RHEL8 latest Python3.10
run_on:
- rhel87-small
batchtime: 20160
@ -907,7 +920,7 @@ buildvariants:
SSL: ssl
TOPOLOGY: server
VERSION: latest
PYTHON_BINARY: /opt/python/3.9/bin/python3
PYTHON_BINARY: /opt/python/3.10/bin/python3
- name: ocsp-win64-v4.4-python3.9
tasks:
- name: .ocsp-rsa !.ocsp-staple
@ -1066,6 +1079,19 @@ buildvariants:
PYTHON_BINARY: /opt/python/3.9/bin/python3
# Server tests
- name: test-rhel8-python3.9-cov-no-c
tasks:
- name: .standalone .sync_async
- name: .replica_set .sync_async
- name: .sharded_cluster .sync_async
display_name: "* Test RHEL8 Python3.9 cov No C"
run_on:
- rhel87-small
expansions:
COVERAGE: coverage
NO_EXT: "1"
PYTHON_BINARY: /opt/python/3.9/bin/python3
tags: [coverage_tag]
- name: test-rhel8-python3.9-cov
tasks:
- name: .standalone .sync_async
@ -1078,6 +1104,19 @@ buildvariants:
COVERAGE: coverage
PYTHON_BINARY: /opt/python/3.9/bin/python3
tags: [coverage_tag]
- name: test-rhel8-python3.13-cov-no-c
tasks:
- name: .standalone .sync_async
- name: .replica_set .sync_async
- name: .sharded_cluster .sync_async
display_name: "* Test RHEL8 Python3.13 cov No C"
run_on:
- rhel87-small
expansions:
COVERAGE: coverage
NO_EXT: "1"
PYTHON_BINARY: /opt/python/3.13/bin/python3
tags: [coverage_tag]
- name: test-rhel8-python3.13-cov
tasks:
- name: .standalone .sync_async
@ -1090,6 +1129,19 @@ buildvariants:
COVERAGE: coverage
PYTHON_BINARY: /opt/python/3.13/bin/python3
tags: [coverage_tag]
- name: test-rhel8-pypy3.10-cov-no-c
tasks:
- name: .standalone .sync_async
- name: .replica_set .sync_async
- name: .sharded_cluster .sync_async
display_name: "* Test RHEL8 PyPy3.10 cov No C"
run_on:
- rhel87-small
expansions:
COVERAGE: coverage
NO_EXT: "1"
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
tags: [coverage_tag]
- name: test-rhel8-pypy3.10-cov
tasks:
- name: .standalone .sync_async
@ -1338,6 +1390,7 @@ buildvariants:
- name: storage-inmemory-rhel8-python3.9
tasks:
- name: .standalone .noauth .nossl .4.0 .sync_async
- name: .standalone .noauth .nossl .4.2 .sync_async
- name: .standalone .noauth .nossl .4.4 .sync_async
- name: .standalone .noauth .nossl .5.0 .sync_async
- name: .standalone .noauth .nossl .6.0 .sync_async

View File

@ -26,7 +26,7 @@ from shrub.v3.shrub_service import ShrubService
# Globals
##############
ALL_VERSIONS = ["4.0", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"]
ALL_VERSIONS = ["4.0", "4.2", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"]
CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"]
PYPYS = ["pypy3.10"]
ALL_PYTHONS = CPYTHONS + PYPYS
@ -279,8 +279,9 @@ def create_server_variants() -> list[BuildVariant]:
host = DEFAULT_HOST
# Prefix the display name with an asterisk so it is sorted first.
base_display_name = "* Test"
for python in [*MIN_MAX_PYTHON, PYPYS[-1]]:
for python, c_ext in product([*MIN_MAX_PYTHON, PYPYS[-1]], C_EXTS):
expansions = dict(COVERAGE="coverage")
handle_c_ext(c_ext, expansions)
display_name = get_display_name(base_display_name, host, python=python, **expansions)
variant = create_variant(
[f".{t} .sync_async" for t in TOPOLOGIES],

View File

@ -178,7 +178,7 @@ documentation including narrative docs, and the [Sphinx docstring format](https:
You can build the documentation locally by running:
```bash
just docs-build
just docs
```
When updating docs, it can be helpful to run the live docs server as:
@ -261,6 +261,11 @@ To prevent the `synchro` hook from accidentally overwriting code, it first check
of a file is changing and not its async counterpart, and will fail.
In the unlikely scenario that you want to override this behavior, first export `OVERRIDE_SYNCHRO_CHECK=1`.
Sometimes, the `synchro` hook will fail and introduce changes many previously unmodified files. This is due to static
Python errors, such as missing imports, incorrect syntax, or other fatal typos. To resolve these issues,
run `pre-commit run --all-files --hook-stage manual ruff` and fix all reported errors before running the `synchro`
hook again.
## Converting a test to async
The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a
partially-converted asynchronous version of the same name to the `test/asynchronous` directory.

View File

@ -420,3 +420,10 @@ the collection:
DuplicateKeyError: E11000 duplicate key error index: test_database.profiles.$user_id_1 dup key: { : 212 }
.. seealso:: The MongoDB documentation on `indexes <https://www.mongodb.com/docs/manual/indexes/>`_
Task Cancellation
-----------------
`Cancelling <https://docs.python.org/3/library/asyncio-task.html#task-cancellation>`_ an asyncio Task
that is running a PyMongo operation is treated as a fatal interrupt. Any connections, cursors, and transactions
involved in a cancelled Task will be safely closed and cleaned up as part of the cancellation. If those resources are
also used elsewhere, attempting to utilize them after the cancellation will result in an error.

View File

@ -231,7 +231,7 @@ class AsyncGridFS:
try:
doc = await anext(cursor)
return AsyncGridOut(self._collection, file_document=doc, session=session)
except StopIteration:
except StopAsyncIteration:
raise NoFile("no version %d for filename %r" % (version, filename)) from None
async def get_last_version(

View File

@ -391,7 +391,8 @@ class AsyncChangeStream(Generic[_DocumentType]):
if not _resumable(exc) and not exc.timeout:
await self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
await self.close()
raise

View File

@ -21,7 +21,7 @@ Causally Consistent Reads
.. code-block:: python
with client.start_session(causal_consistency=True) as session:
async with client.start_session(causal_consistency=True) as session:
collection = client.db.collection
await collection.update_one({"_id": 1}, {"$set": {"x": 10}}, session=session)
secondary_c = collection.with_options(read_preference=ReadPreference.SECONDARY)
@ -53,8 +53,8 @@ operation:
orders = client.db.orders
inventory = client.db.inventory
with client.start_session() as session:
async with session.start_transaction():
async with client.start_session() as session:
async with await session.start_transaction():
await orders.insert_one({"sku": "abc123", "qty": 100}, session=session)
await inventory.update_one(
{"sku": "abc123", "qty": {"$gte": 100}},
@ -62,7 +62,7 @@ operation:
session=session,
)
Upon normal completion of ``async with session.start_transaction()`` block, the
Upon normal completion of ``async with await session.start_transaction()`` block, the
transaction automatically calls :meth:`AsyncClientSession.commit_transaction`.
If the block exits with an exception, the transaction automatically calls
:meth:`AsyncClientSession.abort_transaction`.
@ -113,7 +113,7 @@ replica set secondaries.
.. code-block:: python
# Each read using this session reads data from the same point in time.
with client.start_session(snapshot=True) as session:
async with client.start_session(snapshot=True) as session:
order = await orders.find_one({"sku": "abc123"}, session=session)
inventory = await inventory.find_one({"sku": "abc123"}, session=session)
@ -619,7 +619,7 @@ class AsyncClientSession:
await inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}},
{"$inc": {"qty": -100}}, session=session)
with client.start_session() as session:
async with client.start_session() as session:
await session.with_transaction(callback)
To pass arbitrary arguments to the ``callback``, wrap your callable
@ -628,7 +628,7 @@ class AsyncClientSession:
async def callback(session, custom_arg, custom_kwarg=None):
# Transaction operations...
with client.start_session() as session:
async with client.start_session() as session:
await session.with_transaction(
lambda s: callback(s, "custom_arg", custom_kwarg=1))
@ -697,7 +697,8 @@ class AsyncClientSession:
)
try:
ret = await callback(self)
except Exception as exc:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
if self.in_transaction:
await self.abort_transaction()
if (

View File

@ -1126,7 +1126,8 @@ class AsyncCursor(Generic[_DocumentType]):
self._killed = True
await self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
await self.close()
raise

View File

@ -127,8 +127,6 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged.
raise
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptionError(exc) from exc
@ -766,8 +764,6 @@ class AsyncClientEncryption(Generic[_DocumentType]):
await database.create_collection(name=name, **kwargs),
encrypted_fields,
)
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -276,7 +276,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
:param type_registry: instance of
:class:`~bson.codec_options.TypeRegistry` to enable encoding
and decoding of custom types.
:param datetime_conversion: Specifies how UTC datetimes should be decoded
:param kwargs: **Additional optional parameters available as keyword arguments:**
- `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded
within BSON. Valid options include 'datetime_ms' to return as a
DatetimeMS, 'datetime' to return as a datetime.datetime and
raising a ValueError for out-of-range values, 'datetime_auto' to
@ -284,9 +286,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
out-of-range and 'datetime_clamp' to clamp to the minimum and
maximum possible datetimes. Defaults to 'datetime'. See
:ref:`handling-out-of-range-datetimes` for details.
| **Other optional parameters can be passed as keyword arguments:**
- `directConnection` (optional): if ``True``, forces this client to
connect directly to the specified MongoDB host as a standalone.
If ``false``, the client connects to the entire replica set of
@ -2044,8 +2043,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2060,8 +2057,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items():
try:
await self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
@ -2076,8 +2071,6 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
try:
await self._process_kill_cursors()
await self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return

View File

@ -262,8 +262,6 @@ class Monitor(MonitorBase):
details = cast(Mapping[str, Any], exc.details)
await self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except asyncio.CancelledError:
raise
except ReferenceError:
raise
except Exception as error:
@ -429,8 +427,6 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
@ -494,8 +490,6 @@ class _RttMonitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
await self.close()
except asyncio.CancelledError:
raise
except Exception:
await self._pool.reset()

View File

@ -559,7 +559,7 @@ class AsyncConnection:
)
except (OperationFailure, NotPrimaryError):
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
except BaseException as error:
self._raise_connection_failure(error)
@ -576,6 +576,7 @@ class AsyncConnection:
try:
await async_sendall(self.conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
@ -586,6 +587,7 @@ class AsyncConnection:
"""
try:
return await receive_message(self, request_id, self.max_message_size)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
@ -704,8 +706,6 @@ class AsyncConnection:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass
@ -1269,6 +1269,7 @@ class Pool:
try:
sock = await _configured_socket(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
async with self.lock:
self.active_contexts.discard(tmp_context)
@ -1308,6 +1309,7 @@ class Pool:
handler.contribute_socket(conn, completed_handshake=False)
await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
@ -1369,6 +1371,7 @@ class Pool:
async with self.lock:
self.active_contexts.add(conn.cancel_context)
yield conn
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
# Exception in caller. Ensure the connection gets returned.
# Note that when pinned is True, the session owns the
@ -1515,6 +1518,7 @@ class Pool:
async with self._max_connecting_cond:
self._pending -= 1
self._max_connecting_cond.notify()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
if conn:
# We checked out a socket but authentication failed.

View File

@ -100,6 +100,7 @@ class AsyncPeriodicExecutor:
if not await self._target():
self._stopped = True
break
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
self._stopped = True
raise
@ -232,6 +233,7 @@ class PeriodicExecutor:
if not self._target():
self._stopped = True
break
# Catch KeyboardInterrupt, etc. and cleanup.
except BaseException:
with self._lock:
self._stopped = True

View File

@ -389,7 +389,8 @@ class ChangeStream(Generic[_DocumentType]):
if not _resumable(exc) and not exc.timeout:
self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
self.close()
raise

View File

@ -694,7 +694,8 @@ class ClientSession:
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
try:
ret = callback(self)
except Exception as exc:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
if self.in_transaction:
self.abort_transaction()
if (

View File

@ -1124,7 +1124,8 @@ class Cursor(Generic[_DocumentType]):
self._killed = True
self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
self.close()
raise

View File

@ -15,7 +15,6 @@
"""Support for explicit client-side field level encryption."""
from __future__ import annotations
import asyncio
import contextlib
import enum
import socket
@ -127,8 +126,6 @@ def _wrap_encryption_errors() -> Iterator[None]:
# BSON encoding/decoding errors are unrelated to encryption so
# we should propagate them unchanged.
raise
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptionError(exc) from exc
@ -760,8 +757,6 @@ class ClientEncryption(Generic[_DocumentType]):
database.create_collection(name=name, **kwargs),
encrypted_fields,
)
except asyncio.CancelledError:
raise
except Exception as exc:
raise EncryptedCollectionError(exc, encrypted_fields) from exc

View File

@ -274,7 +274,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
:param type_registry: instance of
:class:`~bson.codec_options.TypeRegistry` to enable encoding
and decoding of custom types.
:param datetime_conversion: Specifies how UTC datetimes should be decoded
:param kwargs: **Additional optional parameters available as keyword arguments:**
- `datetime_conversion` (optional): Specifies how UTC datetimes should be decoded
within BSON. Valid options include 'datetime_ms' to return as a
DatetimeMS, 'datetime' to return as a datetime.datetime and
raising a ValueError for out-of-range values, 'datetime_auto' to
@ -282,9 +284,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
out-of-range and 'datetime_clamp' to clamp to the minimum and
maximum possible datetimes. Defaults to 'datetime'. See
:ref:`handling-out-of-range-datetimes` for details.
| **Other optional parameters can be passed as keyword arguments:**
- `directConnection` (optional): if ``True``, forces this client to
connect directly to the specified MongoDB host as a standalone.
If ``false``, the client connects to the entire replica set of
@ -2038,8 +2037,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_id, conn_mgr in pinned_cursors:
try:
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
@ -2054,8 +2051,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
for address, cursor_ids in address_to_cursor_ids.items():
try:
self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
@ -2070,8 +2065,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
try:
self._process_kill_cursors()
self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return

View File

@ -260,8 +260,6 @@ class Monitor(MonitorBase):
details = cast(Mapping[str, Any], exc.details)
self._topology.receive_cluster_time(details.get("$clusterTime"))
raise
except asyncio.CancelledError:
raise
except ReferenceError:
raise
except Exception as error:
@ -427,8 +425,6 @@ class SrvMonitor(MonitorBase):
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
@ -492,8 +488,6 @@ class _RttMonitor(MonitorBase):
except ReferenceError:
# Topology was garbage-collected.
self.close()
except asyncio.CancelledError:
raise
except Exception:
self._pool.reset()

View File

@ -559,7 +559,7 @@ class Connection:
)
except (OperationFailure, NotPrimaryError):
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
except BaseException as error:
self._raise_connection_failure(error)
@ -576,6 +576,7 @@ class Connection:
try:
sendall(self.conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
@ -586,6 +587,7 @@ class Connection:
"""
try:
return receive_message(self, request_id, self.max_message_size)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
@ -702,8 +704,6 @@ class Connection:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass
@ -1263,6 +1263,7 @@ class Pool:
try:
sock = _configured_socket(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
with self.lock:
self.active_contexts.discard(tmp_context)
@ -1302,6 +1303,7 @@ class Pool:
handler.contribute_socket(conn, completed_handshake=False)
conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
@ -1363,6 +1365,7 @@ class Pool:
with self.lock:
self.active_contexts.add(conn.cancel_context)
yield conn
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
# Exception in caller. Ensure the connection gets returned.
# Note that when pinned is True, the session owns the
@ -1509,6 +1512,7 @@ class Pool:
with self._max_connecting_cond:
self._pending -= 1
self._max_connecting_cond.notify()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
if conn:
# We checked out a socket but authentication failed.

View File

@ -66,7 +66,7 @@ class DummyMonitor:
def cancel_check(self):
pass
def join(self):
async def join(self):
pass
def open(self):
@ -75,7 +75,7 @@ class DummyMonitor:
def request_check(self):
pass
def close(self):
async def close(self):
self.opened = False

View File

@ -0,0 +1,126 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that async cancellation performed by users clean up resources correctly."""
from __future__ import annotations
import asyncio
import sys
from test.utils import async_get_pool, delay, one
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, connected
class TestAsyncCancellation(AsyncIntegrationTest):
async def test_async_cancellation_closes_connection(self):
pool = await async_get_pool(self.client)
await self.client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(self.client.db.test.delete_many, {})
conn = one(pool.conns)
async def task():
await self.client.db.test.find_one({"$where": delay(0.2)})
task = asyncio.create_task(task())
await asyncio.sleep(0.1)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(conn.closed)
@async_client_context.require_transactions
async def test_async_cancellation_aborts_transaction(self):
await self.client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(self.client.db.test.delete_many, {})
session = self.client.start_session()
async def callback(session):
await self.client.db.test.find_one({"$where": delay(0.2)}, session=session)
async def task():
await session.with_transaction(callback)
task = asyncio.create_task(task())
await asyncio.sleep(0.1)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertFalse(session.in_transaction)
@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_cursor(self):
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
self.addAsyncCleanup(self.client.db.test.delete_many, {})
cursor = self.client.db.test.find({}, batch_size=1)
await cursor.next()
# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}
async def task():
async with self.fail_point(fail_command):
await cursor.next()
task = asyncio.create_task(task())
await asyncio.sleep(0.1)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(cursor._killed)
@async_client_context.require_change_streams
@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_change_stream(self):
self.addAsyncCleanup(self.client.db.test.delete_many, {})
change_stream = await self.client.db.test.watch(batch_size=2)
# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}
async def task():
async with self.fail_point(fail_command):
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
await change_stream.next()
task = asyncio.create_task(task())
await asyncio.sleep(0.1)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(change_stream._closed)

View File

@ -961,7 +961,6 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
@async_client_context.require_replica_set
@async_client_context.require_secondaries_count(1)
async def test_write_concern_failure_ordered(self):
self.skipTest("Skipping until PYTHON-4865 is resolved.")
details = None
# Ensure we don't raise on wnote.

View File

@ -0,0 +1,479 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Execute Transactions Spec tests."""
from __future__ import annotations
import asyncio
import os
import sys
import time
from pathlib import Path
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest
from test.asynchronous.pymongo_mocks import DummyMonitor
from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask
from test.utils import (
CMAPListener,
async_client_context,
async_get_pool,
async_get_pools,
async_wait_until,
camel_to_snake,
)
from bson.objectid import ObjectId
from bson.son import SON
from pymongo.asynchronous.pool import PoolState, _PoolClosedError
from pymongo.errors import (
ConnectionFailure,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
)
from pymongo.monitoring import (
ConnectionCheckedInEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
ConnectionCheckOutFailedReason,
ConnectionCheckOutStartedEvent,
ConnectionClosedEvent,
ConnectionClosedReason,
ConnectionCreatedEvent,
ConnectionReadyEvent,
PoolClearedEvent,
PoolClosedEvent,
PoolCreatedEvent,
PoolReadyEvent,
)
from pymongo.read_preferences import ReadPreference
from pymongo.topology_description import updated_topology_description
_IS_SYNC = False
OBJECT_TYPES = {
# Event types.
"ConnectionCheckedIn": ConnectionCheckedInEvent,
"ConnectionCheckedOut": ConnectionCheckedOutEvent,
"ConnectionCheckOutFailed": ConnectionCheckOutFailedEvent,
"ConnectionClosed": ConnectionClosedEvent,
"ConnectionCreated": ConnectionCreatedEvent,
"ConnectionReady": ConnectionReadyEvent,
"ConnectionCheckOutStarted": ConnectionCheckOutStartedEvent,
"ConnectionPoolCreated": PoolCreatedEvent,
"ConnectionPoolReady": PoolReadyEvent,
"ConnectionPoolCleared": PoolClearedEvent,
"ConnectionPoolClosed": PoolClosedEvent,
# Error types.
"PoolClosedError": _PoolClosedError,
"WaitQueueTimeoutError": WaitQueueTimeoutError,
}
class AsyncTestCMAP(AsyncIntegrationTest):
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring")
# Test operations:
async def start(self, op):
"""Run the 'start' thread operation."""
target = op["target"]
thread = SpecRunnerTask(target)
await thread.start()
self.targets[target] = thread
async def wait(self, op):
"""Run the 'wait' operation."""
await asyncio.sleep(op["ms"] / 1000.0)
async def wait_for_thread(self, op):
"""Run the 'waitForThread' operation."""
target = op["target"]
thread = self.targets[target]
await thread.stop()
await thread.join()
if thread.exc:
raise thread.exc
self.assertFalse(thread.ops)
async def wait_for_event(self, op):
"""Run the 'waitForEvent' operation."""
event = OBJECT_TYPES[op["event"]]
count = op["count"]
timeout = op.get("timeout", 10000) / 1000.0
await async_wait_until(
lambda: self.listener.event_count(event) >= count,
f"find {count} {event} event(s)",
timeout=timeout,
)
async def check_out(self, op):
"""Run the 'checkOut' operation."""
label = op["label"]
async with self.pool.checkout() as conn:
# Call 'pin_cursor' so we can hold the socket.
conn.pin_cursor()
if label:
self.labels[label] = conn
else:
self.addAsyncCleanup(conn.close_conn, None)
async def check_in(self, op):
"""Run the 'checkIn' operation."""
label = op["connection"]
conn = self.labels[label]
await self.pool.checkin(conn)
async def ready(self, op):
"""Run the 'ready' operation."""
await self.pool.ready()
async def clear(self, op):
"""Run the 'clear' operation."""
if "interruptInUseConnections" in op:
await self.pool.reset(interrupt_connections=op["interruptInUseConnections"])
else:
await self.pool.reset()
async def close(self, op):
"""Run the 'close' operation."""
await self.pool.close()
async def run_operation(self, op):
"""Run a single operation in a test."""
op_name = camel_to_snake(op["name"])
thread = op["thread"]
meth = getattr(self, op_name)
if thread:
await self.targets[thread].schedule(lambda: meth(op))
else:
await meth(op)
async def run_operations(self, ops):
"""Run a test's operations."""
for op in ops:
self._ops.append(op)
await self.run_operation(op)
def check_object(self, actual, expected):
"""Assert that the actual object matches the expected object."""
self.assertEqual(type(actual), OBJECT_TYPES[expected["type"]])
for attr, expected_val in expected.items():
if attr == "type":
continue
c2s = camel_to_snake(attr)
if c2s == "interrupt_in_use_connections":
c2s = "interrupt_connections"
actual_val = getattr(actual, c2s)
if expected_val == 42:
self.assertIsNotNone(actual_val)
else:
self.assertEqual(actual_val, expected_val)
def check_event(self, actual, expected):
"""Assert that the actual event matches the expected event."""
self.check_object(actual, expected)
def actual_events(self, ignore):
"""Return all the non-ignored events."""
ignore = tuple(OBJECT_TYPES[name] for name in ignore)
return [event for event in self.listener.events if not isinstance(event, ignore)]
def check_events(self, events, ignore):
"""Check the events of a test."""
actual_events = self.actual_events(ignore)
for actual, expected in zip(actual_events, events):
self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}")
self.check_event(actual, expected)
if len(events) > len(actual_events):
self.fail(f"missing events: {events[len(actual_events) :]!r}")
def check_error(self, actual, expected):
message = expected.pop("message")
self.check_object(actual, expected)
self.assertIn(message, str(actual))
async def _set_fail_point(self, client, command_args):
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
await client.admin.command(cmd)
async def set_fail_point(self, command_args):
if not async_client_context.supports_failCommand_fail_point:
self.skipTest("failCommand fail point must be supported")
await self._set_fail_point(self.client, command_args)
async def run_scenario(self, scenario_def, test):
"""Run a CMAP spec test."""
self.logs: list = []
self.assertEqual(scenario_def["version"], 1)
self.assertIn(scenario_def["style"], ["unit", "integration"])
self.listener = CMAPListener()
self._ops: list = []
# Configure the fail point before creating the client.
if "failPoint" in test:
fp = test["failPoint"]
await self.set_fail_point(fp)
self.addAsyncCleanup(
self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"}
)
opts = test["poolOptions"].copy()
opts["event_listeners"] = [self.listener]
opts["_monitor_class"] = DummyMonitor
opts["connect"] = False
# Support backgroundThreadIntervalMS, default to 50ms.
interval = opts.pop("backgroundThreadIntervalMS", 50)
if interval < 0:
kill_cursor_frequency = 99999999
else:
kill_cursor_frequency = interval / 1000.0
with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05):
client = await self.async_single_client(**opts)
# Update the SD to a known type because the DummyMonitor will not.
# Note we cannot simply call topology.on_change because that would
# internally call pool.ready() which introduces unexpected
# PoolReadyEvents. Instead, update the initial state before
# opening the Topology.
td = async_client_context.client._topology.description
sd = td.server_descriptions()[
(await async_client_context.host, await async_client_context.port)
]
client._topology._description = updated_topology_description(
client._topology._description, sd
)
# When backgroundThreadIntervalMS is negative we do not start the
# background thread to ensure it never runs.
if interval < 0:
await client._topology.open()
else:
await client._get_topology()
self.pool = list(client._topology._servers.values())[0].pool
# Map of target names to Thread objects.
self.targets: dict = {}
# Map of label names to AsyncConnection objects
self.labels: dict = {}
async def cleanup():
for t in self.targets.values():
await t.stop()
for t in self.targets.values():
await t.join(5)
for conn in self.labels.values():
conn.close_conn(None)
self.addAsyncCleanup(cleanup)
try:
if test["error"]:
with self.assertRaises(PyMongoError) as ctx:
await self.run_operations(test["operations"])
self.check_error(ctx.exception, test["error"])
else:
await self.run_operations(test["operations"])
self.check_events(test["events"], test["ignore"])
except Exception:
# Print the events after a test failure.
print("\nFailed test: {!r}".format(test["description"]))
print("Operations:")
for op in self._ops:
print(op)
print("Threads:")
print(self.targets)
print("AsyncConnections:")
print(self.labels)
print("Events:")
for event in self.listener.events:
print(event)
print("Log:")
for log in self.logs:
print(log)
raise
POOL_OPTIONS = {
"maxPoolSize": 50,
"minPoolSize": 1,
"maxIdleTimeMS": 10000,
"waitQueueTimeoutMS": 10000,
}
#
# Prose tests. Numbers correspond to the prose test number in the spec.
#
async def test_1_client_connection_pool_options(self):
client = await self.async_rs_or_single_client(**self.POOL_OPTIONS)
pool_opts = (await async_get_pool(client)).opts
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
async def test_2_all_client_pools_have_same_options(self):
client = await self.async_rs_or_single_client(**self.POOL_OPTIONS)
await client.admin.command("ping")
# Discover at least one secondary.
if await async_client_context.has_secondaries:
await client.admin.command("ping", read_preference=ReadPreference.SECONDARY)
pools = await async_get_pools(client)
pool_opts = pools[0].opts
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
for pool in pools[1:]:
self.assertEqual(pool.opts, pool_opts)
async def test_3_uri_connection_pool_options(self):
opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()])
uri = f"mongodb://{await async_client_context.pair}/?{opts}"
client = await self.async_rs_or_single_client(uri)
pool_opts = (await async_get_pool(client)).opts
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
async def test_4_subscribe_to_events(self):
listener = CMAPListener()
client = await self.async_single_client(event_listeners=[listener])
self.assertEqual(listener.event_count(PoolCreatedEvent), 1)
# Creates a new connection.
await client.admin.command("ping")
self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 1)
self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1)
self.assertEqual(listener.event_count(ConnectionReadyEvent), 1)
self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1)
self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1)
# Uses the existing connection.
await client.admin.command("ping")
self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 2)
self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2)
self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2)
await client.close()
self.assertEqual(listener.event_count(PoolClosedEvent), 1)
self.assertEqual(listener.event_count(ConnectionClosedEvent), 1)
async def test_5_check_out_fails_connection_error(self):
listener = CMAPListener()
client = await self.async_single_client(event_listeners=[listener])
pool = await async_get_pool(client)
def mock_connect(*args, **kwargs):
raise ConnectionFailure("connect failed")
pool.connect = mock_connect
# Un-patch Pool.connect to break the cyclic reference.
self.addCleanup(delattr, pool, "connect")
# Attempt to create a new connection.
with self.assertRaisesRegex(ConnectionFailure, "connect failed"):
await client.admin.command("ping")
self.assertIsInstance(listener.events[0], PoolCreatedEvent)
self.assertIsInstance(listener.events[1], PoolReadyEvent)
self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent)
self.assertIsInstance(listener.events[3], ConnectionCheckOutFailedEvent)
self.assertIsInstance(listener.events[4], PoolClearedEvent)
failed_event = listener.events[3]
self.assertEqual(failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR)
@async_client_context.require_no_fips
async def test_5_check_out_fails_auth_error(self):
listener = CMAPListener()
client = await self.async_single_client_noauth(
username="notauser", password="fail", event_listeners=[listener]
)
# Attempt to create a new connection.
with self.assertRaisesRegex(OperationFailure, "failed"):
await client.admin.command("ping")
self.assertIsInstance(listener.events[0], PoolCreatedEvent)
self.assertIsInstance(listener.events[1], PoolReadyEvent)
self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent)
self.assertIsInstance(listener.events[3], ConnectionCreatedEvent)
# Error happens here.
self.assertIsInstance(listener.events[4], ConnectionClosedEvent)
self.assertIsInstance(listener.events[5], ConnectionCheckOutFailedEvent)
self.assertEqual(listener.events[5].reason, ConnectionCheckOutFailedReason.CONN_ERROR)
#
# Extra non-spec tests
#
def assertRepr(self, obj):
new_obj = eval(repr(obj))
self.assertEqual(type(new_obj), type(obj))
self.assertEqual(repr(new_obj), repr(obj))
async def test_events_repr(self):
host = ("localhost", 27017)
self.assertRepr(ConnectionCheckedInEvent(host, 1))
self.assertRepr(ConnectionCheckedOutEvent(host, 1, time.monotonic()))
self.assertRepr(
ConnectionCheckOutFailedEvent(
host, ConnectionCheckOutFailedReason.POOL_CLOSED, time.monotonic()
)
)
self.assertRepr(ConnectionClosedEvent(host, 1, ConnectionClosedReason.POOL_CLOSED))
self.assertRepr(ConnectionCreatedEvent(host, 1))
self.assertRepr(ConnectionReadyEvent(host, 1, time.monotonic()))
self.assertRepr(ConnectionCheckOutStartedEvent(host))
self.assertRepr(PoolCreatedEvent(host, {}))
self.assertRepr(PoolClearedEvent(host))
self.assertRepr(PoolClearedEvent(host, service_id=ObjectId()))
self.assertRepr(PoolClosedEvent(host))
async def test_close_leaves_pool_unpaused(self):
listener = CMAPListener()
client = await self.async_single_client(event_listeners=[listener])
await client.admin.command("ping")
pool = await async_get_pool(client)
await client.close()
self.assertEqual(1, listener.event_count(PoolClosedEvent))
self.assertEqual(PoolState.CLOSED, pool.state)
# Checking out a connection should fail
with self.assertRaises(_PoolClosedError):
async with pool.checkout():
pass
def create_test(scenario_def, test, name):
async def run_scenario(self):
await self.run_scenario(scenario_def, test)
return run_scenario
class CMAPSpecTestCreator(AsyncSpecTestCreator):
def tests(self, scenario_def):
"""Extract the tests from a spec file.
CMAP tests do not have a 'tests' field. The whole file represents
a single test case.
"""
return [scenario_def]
test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH)
test_creator.create_tests()
if __name__ == "__main__":
unittest.main()

View File

@ -739,7 +739,7 @@ class AsyncTestSpec(AsyncSpecRunner):
return errors
async def create_test(scenario_def, test, name):
def create_test(scenario_def, test, name):
@async_client_context.require_test_commands
async def run_scenario(self):
await self.run_scenario(scenario_def, test)

View File

@ -0,0 +1,602 @@
#
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the gridfs package."""
from __future__ import annotations
import asyncio
import datetime
import sys
import threading
import time
from io import BytesIO
from test.asynchronous.helpers import ConcurrentRunner
from unittest.mock import patch
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import async_joinall, one
import gridfs
from bson.binary import Binary
from gridfs.asynchronous.grid_file import DEFAULT_CHUNK_SIZE, AsyncGridOutCursor
from gridfs.errors import CorruptGridFile, FileExists, NoFile
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.errors import (
ConfigurationError,
NotPrimaryError,
ServerSelectionTimeoutError,
)
from pymongo.read_preferences import ReadPreference
_IS_SYNC = False
class JustWrite(ConcurrentRunner):
def __init__(self, fs, n):
super().__init__()
self.fs = fs
self.n = n
self.daemon = True
async def run(self):
for _ in range(self.n):
file = self.fs.new_file(filename="test")
await file.write(b"hello")
await file.close()
class JustRead(ConcurrentRunner):
def __init__(self, fs, n, results):
super().__init__()
self.fs = fs
self.n = n
self.results = results
self.daemon = True
async def run(self):
for _ in range(self.n):
file = await self.fs.get("test")
data = await file.read()
self.results.append(data)
assert data == b"hello"
class TestGridfsNoConnect(unittest.IsolatedAsyncioTestCase):
db: AsyncDatabase
async def asyncSetUp(self):
await super().asyncSetUp()
self.db = AsyncMongoClient(connect=False).pymongo_test
async def test_gridfs(self):
self.assertRaises(TypeError, gridfs.AsyncGridFS, "foo")
self.assertRaises(TypeError, gridfs.AsyncGridFS, self.db, 5)
class TestGridfs(AsyncIntegrationTest):
fs: gridfs.AsyncGridFS
alt: gridfs.AsyncGridFS
async def asyncSetUp(self):
await super().asyncSetUp()
self.fs = gridfs.AsyncGridFS(self.db)
self.alt = gridfs.AsyncGridFS(self.db, "alt")
await self.cleanup_colls(
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
)
async def test_basic(self):
oid = await self.fs.put(b"hello world")
self.assertEqual(b"hello world", await (await self.fs.get(oid)).read())
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
await self.fs.delete(oid)
with self.assertRaises(NoFile):
await self.fs.get(oid)
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
with self.assertRaises(NoFile):
await self.fs.get("foo")
oid = await self.fs.put(b"hello world", _id="foo")
self.assertEqual("foo", oid)
self.assertEqual(b"hello world", await (await self.fs.get("foo")).read())
async def test_multi_chunk_delete(self):
await self.db.fs.drop()
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
gfs = gridfs.AsyncGridFS(self.db)
oid = await gfs.put(b"hello", chunkSize=1)
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(5, await self.db.fs.chunks.count_documents({}))
await gfs.delete(oid)
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
async def test_list(self):
self.assertEqual([], await self.fs.list())
await self.fs.put(b"hello world")
self.assertEqual([], await self.fs.list())
# PYTHON-598: in server versions before 2.5.x, creating an index on
# filename, uploadDate causes list() to include None.
await self.fs.get_last_version()
self.assertEqual([], await self.fs.list())
await self.fs.put(b"", filename="mike")
await self.fs.put(b"foo", filename="test")
await self.fs.put(b"", filename="hello world")
self.assertEqual({"mike", "test", "hello world"}, set(await self.fs.list()))
async def test_empty_file(self):
oid = await self.fs.put(b"")
self.assertEqual(b"", await (await self.fs.get(oid)).read())
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
raw = await self.db.fs.files.find_one()
assert raw is not None
self.assertEqual(0, raw["length"])
self.assertEqual(oid, raw["_id"])
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
self.assertEqual(255 * 1024, raw["chunkSize"])
self.assertNotIn("md5", raw)
async def test_corrupt_chunk(self):
files_id = await self.fs.put(b"foobar")
await self.db.fs.chunks.update_one(
{"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}
)
try:
out = await self.fs.get(files_id)
with self.assertRaises(CorruptGridFile):
await out.read()
out = await self.fs.get(files_id)
with self.assertRaises(CorruptGridFile):
await out.readline()
finally:
await self.fs.delete(files_id)
async def test_put_ensures_index(self):
chunks = self.db.fs.chunks
files = self.db.fs.files
# Ensure the collections are removed.
await chunks.drop()
await files.drop()
await self.fs.put(b"junk")
self.assertTrue(
any(
info.get("key") == [("files_id", 1), ("n", 1)]
for info in (await chunks.index_information()).values()
)
)
self.assertTrue(
any(
info.get("key") == [("filename", 1), ("uploadDate", 1)]
for info in (await files.index_information()).values()
)
)
async def test_alt_collection(self):
oid = await self.alt.put(b"hello world")
self.assertEqual(b"hello world", await (await self.alt.get(oid)).read())
self.assertEqual(1, await self.db.alt.files.count_documents({}))
self.assertEqual(1, await self.db.alt.chunks.count_documents({}))
await self.alt.delete(oid)
with self.assertRaises(NoFile):
await self.alt.get(oid)
self.assertEqual(0, await self.db.alt.files.count_documents({}))
self.assertEqual(0, await self.db.alt.chunks.count_documents({}))
with self.assertRaises(NoFile):
await self.alt.get("foo")
oid = await self.alt.put(b"hello world", _id="foo")
self.assertEqual("foo", oid)
self.assertEqual(b"hello world", await (await self.alt.get("foo")).read())
await self.alt.put(b"", filename="mike")
await self.alt.put(b"foo", filename="test")
await self.alt.put(b"", filename="hello world")
self.assertEqual({"mike", "test", "hello world"}, set(await self.alt.list()))
async def test_threaded_reads(self):
await self.fs.put(b"hello", _id="test")
tasks = []
results: list = []
for i in range(10):
tasks.append(JustRead(self.fs, 10, results))
await tasks[i].start()
await async_joinall(tasks)
self.assertEqual(100 * [b"hello"], results)
async def test_threaded_writes(self):
tasks = []
for i in range(10):
tasks.append(JustWrite(self.fs, 10))
await tasks[i].start()
await async_joinall(tasks)
f = await self.fs.get_last_version("test")
self.assertEqual(await f.read(), b"hello")
# Should have created 100 versions of 'test' file
self.assertEqual(100, await self.db.fs.files.count_documents({"filename": "test"}))
async def test_get_last_version(self):
one = await self.fs.put(b"foo", filename="test")
await asyncio.sleep(0.01)
two = self.fs.new_file(filename="test")
await two.write(b"bar")
await two.close()
await asyncio.sleep(0.01)
two = two._id
three = await self.fs.put(b"baz", filename="test")
self.assertEqual(b"baz", await (await self.fs.get_last_version("test")).read())
await self.fs.delete(three)
self.assertEqual(b"bar", await (await self.fs.get_last_version("test")).read())
await self.fs.delete(two)
self.assertEqual(b"foo", await (await self.fs.get_last_version("test")).read())
await self.fs.delete(one)
with self.assertRaises(NoFile):
await self.fs.get_last_version("test")
async def test_get_last_version_with_metadata(self):
one = await self.fs.put(b"foo", filename="test", author="author")
await asyncio.sleep(0.01)
two = await self.fs.put(b"bar", filename="test", author="author")
self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author")).read())
await self.fs.delete(two)
self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author")).read())
await self.fs.delete(one)
one = await self.fs.put(b"foo", filename="test", author="author1")
await asyncio.sleep(0.01)
two = await self.fs.put(b"bar", filename="test", author="author2")
self.assertEqual(b"foo", await (await self.fs.get_last_version(author="author1")).read())
self.assertEqual(b"bar", await (await self.fs.get_last_version(author="author2")).read())
self.assertEqual(b"bar", await (await self.fs.get_last_version(filename="test")).read())
with self.assertRaises(NoFile):
await self.fs.get_last_version(author="author3")
with self.assertRaises(NoFile):
await self.fs.get_last_version(filename="nottest", author="author1")
await self.fs.delete(one)
await self.fs.delete(two)
async def test_get_version(self):
await self.fs.put(b"foo", filename="test")
await asyncio.sleep(0.01)
await self.fs.put(b"bar", filename="test")
await asyncio.sleep(0.01)
await self.fs.put(b"baz", filename="test")
await asyncio.sleep(0.01)
self.assertEqual(b"foo", await (await self.fs.get_version("test", 0)).read())
self.assertEqual(b"bar", await (await self.fs.get_version("test", 1)).read())
self.assertEqual(b"baz", await (await self.fs.get_version("test", 2)).read())
self.assertEqual(b"baz", await (await self.fs.get_version("test", -1)).read())
self.assertEqual(b"bar", await (await self.fs.get_version("test", -2)).read())
self.assertEqual(b"foo", await (await self.fs.get_version("test", -3)).read())
with self.assertRaises(NoFile):
await self.fs.get_version("test", 3)
with self.assertRaises(NoFile):
await self.fs.get_version("test", -4)
async def test_get_version_with_metadata(self):
one = await self.fs.put(b"foo", filename="test", author="author1")
await asyncio.sleep(0.01)
two = await self.fs.put(b"bar", filename="test", author="author1")
await asyncio.sleep(0.01)
three = await self.fs.put(b"baz", filename="test", author="author2")
self.assertEqual(
b"foo",
await (await self.fs.get_version(filename="test", author="author1", version=-2)).read(),
)
self.assertEqual(
b"bar",
await (await self.fs.get_version(filename="test", author="author1", version=-1)).read(),
)
self.assertEqual(
b"foo",
await (await self.fs.get_version(filename="test", author="author1", version=0)).read(),
)
self.assertEqual(
b"bar",
await (await self.fs.get_version(filename="test", author="author1", version=1)).read(),
)
self.assertEqual(
b"baz",
await (await self.fs.get_version(filename="test", author="author2", version=0)).read(),
)
self.assertEqual(
b"baz", await (await self.fs.get_version(filename="test", version=-1)).read()
)
self.assertEqual(
b"baz", await (await self.fs.get_version(filename="test", version=2)).read()
)
with self.assertRaises(NoFile):
await self.fs.get_version(filename="test", author="author3")
with self.assertRaises(NoFile):
await self.fs.get_version(filename="test", author="author1", version=2)
await self.fs.delete(one)
await self.fs.delete(two)
await self.fs.delete(three)
async def test_put_filelike(self):
oid = await self.fs.put(BytesIO(b"hello world"), chunk_size=1)
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
self.assertEqual(b"hello world", await (await self.fs.get(oid)).read())
async def test_file_exists(self):
oid = await self.fs.put(b"hello")
with self.assertRaises(FileExists):
await self.fs.put(b"world", _id=oid)
one = self.fs.new_file(_id=123)
await one.write(b"some content")
await one.close()
# Attempt to upload a file with more chunks to the same _id.
with patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE):
two = self.fs.new_file(_id=123)
with self.assertRaises(FileExists):
await two.write(b"x" * DEFAULT_CHUNK_SIZE * 3)
# Original file is still readable (no extra chunks were uploaded).
self.assertEqual(await (await self.fs.get(123)).read(), b"some content")
two = self.fs.new_file(_id=123)
await two.write(b"some content")
with self.assertRaises(FileExists):
await two.close()
# Original file is still readable.
self.assertEqual(await (await self.fs.get(123)).read(), b"some content")
async def test_exists(self):
oid = await self.fs.put(b"hello")
self.assertTrue(await self.fs.exists(oid))
self.assertTrue(await self.fs.exists({"_id": oid}))
self.assertTrue(await self.fs.exists(_id=oid))
self.assertFalse(await self.fs.exists(filename="mike"))
self.assertFalse(await self.fs.exists("mike"))
oid = await self.fs.put(b"hello", filename="mike", foo=12)
self.assertTrue(await self.fs.exists(oid))
self.assertTrue(await self.fs.exists({"_id": oid}))
self.assertTrue(await self.fs.exists(_id=oid))
self.assertTrue(await self.fs.exists(filename="mike"))
self.assertTrue(await self.fs.exists({"filename": "mike"}))
self.assertTrue(await self.fs.exists(foo=12))
self.assertTrue(await self.fs.exists({"foo": 12}))
self.assertTrue(await self.fs.exists(foo={"$gt": 11}))
self.assertTrue(await self.fs.exists({"foo": {"$gt": 11}}))
self.assertFalse(await self.fs.exists(foo=13))
self.assertFalse(await self.fs.exists({"foo": 13}))
self.assertFalse(await self.fs.exists(foo={"$gt": 12}))
self.assertFalse(await self.fs.exists({"foo": {"$gt": 12}}))
async def test_put_unicode(self):
with self.assertRaises(TypeError):
await self.fs.put("hello")
oid = await self.fs.put("hello", encoding="utf-8")
self.assertEqual(b"hello", await (await self.fs.get(oid)).read())
self.assertEqual("utf-8", (await self.fs.get(oid)).encoding)
oid = await self.fs.put("", encoding="iso-8859-1")
self.assertEqual("".encode("iso-8859-1"), await (await self.fs.get(oid)).read())
self.assertEqual("iso-8859-1", (await self.fs.get(oid)).encoding)
async def test_missing_length_iter(self):
# Test fix that guards against PHP-237
await self.fs.put(b"", filename="empty")
doc = await self.db.fs.files.find_one({"filename": "empty"})
assert doc is not None
doc.pop("length")
await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
f = await self.fs.get_last_version(filename="empty")
async def iterate_file(grid_file):
async for _chunk in grid_file:
pass
return True
self.assertTrue(await iterate_file(f))
async def test_gridfs_lazy_connect(self):
client = await self.async_single_client(
"badhost", connect=False, serverSelectionTimeoutMS=10
)
db = client.db
gfs = gridfs.AsyncGridFS(db)
with self.assertRaises(ServerSelectionTimeoutError):
await gfs.list()
fs = gridfs.AsyncGridFS(db)
f = fs.new_file()
with self.assertRaises(ServerSelectionTimeoutError):
await f.close()
async def test_gridfs_find(self):
await self.fs.put(b"test2", filename="two")
await asyncio.sleep(0.01)
await self.fs.put(b"test2+", filename="two")
await asyncio.sleep(0.01)
await self.fs.put(b"test1", filename="one")
await asyncio.sleep(0.01)
await self.fs.put(b"test2++", filename="two")
files = self.db.fs.files
self.assertEqual(3, await files.count_documents({"filename": "two"}))
self.assertEqual(4, await files.count_documents({}))
cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2)
gout = await cursor.next()
self.assertEqual(b"test1", await gout.read())
await cursor.rewind()
gout = await cursor.next()
self.assertEqual(b"test1", await gout.read())
gout = await cursor.next()
self.assertEqual(b"test2+", await gout.read())
with self.assertRaises(StopAsyncIteration):
await cursor.__anext__()
await cursor.rewind()
items = await cursor.to_list()
self.assertEqual(len(items), 2)
await cursor.rewind()
items = await cursor.to_list(1)
self.assertEqual(len(items), 1)
await cursor.close()
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
async def test_delete_not_initialized(self):
# Creating a cursor with invalid arguments will not run __init__
# but will still call __del__.
cursor = AsyncGridOutCursor.__new__(AsyncGridOutCursor) # Skip calling __init__
with self.assertRaises(TypeError):
cursor.__init__(self.db.fs.files, {}, {"_id": True}) # type: ignore
cursor.__del__() # no error
async def test_gridfs_find_one(self):
self.assertEqual(None, await self.fs.find_one())
id1 = await self.fs.put(b"test1", filename="file1")
res = await self.fs.find_one()
assert res is not None
self.assertEqual(b"test1", await res.read())
id2 = await self.fs.put(b"test2", filename="file2", meta="data")
res1 = await self.fs.find_one(id1)
assert res1 is not None
self.assertEqual(b"test1", await res1.read())
res2 = await self.fs.find_one(id2)
assert res2 is not None
self.assertEqual(b"test2", await res2.read())
res3 = await self.fs.find_one({"filename": "file1"})
assert res3 is not None
self.assertEqual(b"test1", await res3.read())
res4 = await self.fs.find_one(id2)
assert res4 is not None
self.assertEqual("data", res4.meta)
async def test_grid_in_non_int_chunksize(self):
# Lua, and perhaps other buggy AsyncGridFS clients, store size as a float.
data = b"data"
await self.fs.put(data, filename="f")
await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
self.assertEqual(data, await (await self.fs.get_version("f")).read())
async def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
gridfs.AsyncGridFS((await self.async_rs_or_single_client(w=0)).pymongo_test)
async def test_md5(self):
gin = self.fs.new_file()
await gin.write(b"no md5 sum")
await gin.close()
self.assertIsNone(gin.md5)
gout = await self.fs.get(gin._id)
self.assertIsNone(gout.md5)
_id = await self.fs.put(b"still no md5 sum")
gout = await self.fs.get(_id)
self.assertIsNone(gout.md5)
class TestGridfsReplicaSet(AsyncIntegrationTest):
@async_client_context.require_secondaries_count(1)
async def asyncSetUp(self):
await super().asyncSetUp()
@classmethod
@async_client_context.require_connection
async def asyncTearDownClass(cls):
await async_client_context.client.drop_database("gfsreplica")
async def test_gridfs_replica_set(self):
rsc = await self.async_rs_client(
w=async_client_context.w, read_preference=ReadPreference.SECONDARY
)
fs = gridfs.AsyncGridFS(rsc.gfsreplica, "gfsreplicatest")
gin = fs.new_file()
self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY)
oid = await fs.put(b"foo")
content = await (await fs.get(oid)).read()
self.assertEqual(b"foo", content)
async def test_gridfs_secondary(self):
secondary_host, secondary_port = one(await self.client.secondaries)
secondary_connection = await self.async_single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
)
# Should detect it's connected to secondary and not attempt to
# create index
fs = gridfs.AsyncGridFS(secondary_connection.gfsreplica, "gfssecondarytest")
# This won't detect secondary, raises error
with self.assertRaises(NotPrimaryError):
await fs.put(b"foo")
async def test_gridfs_secondary_lazy(self):
# Should detect it's connected to secondary and not attempt to
# create index.
secondary_host, secondary_port = one(await self.client.secondaries)
client = await self.async_single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
)
# Still no connection.
fs = gridfs.AsyncGridFS(client.gfsreplica, "gfssecondarylazytest")
# Connects, doesn't create index.
with self.assertRaises(NoFile):
await fs.get_last_version()
with self.assertRaises(NotPrimaryError):
await fs.put("data", encoding="utf-8")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,574 @@
#
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the gridfs package."""
from __future__ import annotations
import asyncio
import datetime
import itertools
import sys
import threading
import time
from io import BytesIO
from test.asynchronous.helpers import ConcurrentRunner
from unittest.mock import patch
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import async_joinall, joinall, one
import gridfs
from bson.binary import Binary
from bson.int64 import Int64
from bson.objectid import ObjectId
from bson.son import SON
from gridfs.errors import CorruptGridFile, NoFile
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.errors import (
ConfigurationError,
NotPrimaryError,
ServerSelectionTimeoutError,
WriteConcernError,
)
from pymongo.read_preferences import ReadPreference
_IS_SYNC = False
class JustWrite(ConcurrentRunner):
def __init__(self, gfs, num):
super().__init__()
self.gfs = gfs
self.num = num
self.daemon = True
async def run(self):
for _ in range(self.num):
file = self.gfs.open_upload_stream("test")
await file.write(b"hello")
await file.close()
class JustRead(ConcurrentRunner):
def __init__(self, gfs, num, results):
super().__init__()
self.gfs = gfs
self.num = num
self.results = results
self.daemon = True
async def run(self):
for _ in range(self.num):
file = await self.gfs.open_download_stream_by_name("test")
data = await file.read()
self.results.append(data)
assert data == b"hello"
class TestGridfs(AsyncIntegrationTest):
fs: gridfs.AsyncGridFSBucket
alt: gridfs.AsyncGridFSBucket
async def asyncSetUp(self):
await super().asyncSetUp()
self.fs = gridfs.AsyncGridFSBucket(self.db)
self.alt = gridfs.AsyncGridFSBucket(self.db, bucket_name="alt")
await self.cleanup_colls(
self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks
)
async def test_basic(self):
oid = await self.fs.upload_from_stream("test_filename", b"hello world")
self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read())
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
await self.fs.delete(oid)
with self.assertRaises(NoFile):
await self.fs.open_download_stream(oid)
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
async def test_multi_chunk_delete(self):
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
gfs = gridfs.AsyncGridFSBucket(self.db)
oid = await gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1)
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(5, await self.db.fs.chunks.count_documents({}))
await gfs.delete(oid)
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
async def test_empty_file(self):
oid = await self.fs.upload_from_stream("test_filename", b"")
self.assertEqual(b"", await (await self.fs.open_download_stream(oid)).read())
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
raw = await self.db.fs.files.find_one()
assert raw is not None
self.assertEqual(0, raw["length"])
self.assertEqual(oid, raw["_id"])
self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime))
self.assertEqual(255 * 1024, raw["chunkSize"])
self.assertNotIn("md5", raw)
async def test_corrupt_chunk(self):
files_id = await self.fs.upload_from_stream("test_filename", b"foobar")
await self.db.fs.chunks.update_one(
{"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}
)
try:
out = await self.fs.open_download_stream(files_id)
with self.assertRaises(CorruptGridFile):
await out.read()
out = await self.fs.open_download_stream(files_id)
with self.assertRaises(CorruptGridFile):
await out.readline()
finally:
await self.fs.delete(files_id)
async def test_upload_ensures_index(self):
chunks = self.db.fs.chunks
files = self.db.fs.files
# Ensure the collections are removed.
await chunks.drop()
await files.drop()
await self.fs.upload_from_stream("filename", b"junk")
self.assertTrue(
any(
info.get("key") == [("files_id", 1), ("n", 1)]
for info in (await chunks.index_information()).values()
)
)
self.assertTrue(
any(
info.get("key") == [("filename", 1), ("uploadDate", 1)]
for info in (await files.index_information()).values()
)
)
async def test_ensure_index_shell_compat(self):
files = self.db.fs.files
for i, j in itertools.combinations_with_replacement([1, 1.0, Int64(1)], 2):
# Create the index with different numeric types (as might be done
# from the mongo shell).
shell_index = [("filename", i), ("uploadDate", j)]
await self.db.command(
"createIndexes",
files.name,
indexes=[{"key": SON(shell_index), "name": "filename_1.0_uploadDate_1.0"}],
)
# No error.
await self.fs.upload_from_stream("filename", b"data")
self.assertTrue(
any(
info.get("key") == [("filename", 1), ("uploadDate", 1)]
for info in (await files.index_information()).values()
)
)
await files.drop()
async def test_alt_collection(self):
oid = await self.alt.upload_from_stream("test_filename", b"hello world")
self.assertEqual(b"hello world", await (await self.alt.open_download_stream(oid)).read())
self.assertEqual(1, await self.db.alt.files.count_documents({}))
self.assertEqual(1, await self.db.alt.chunks.count_documents({}))
await self.alt.delete(oid)
with self.assertRaises(NoFile):
await self.alt.open_download_stream(oid)
self.assertEqual(0, await self.db.alt.files.count_documents({}))
self.assertEqual(0, await self.db.alt.chunks.count_documents({}))
with self.assertRaises(NoFile):
await self.alt.open_download_stream("foo")
await self.alt.upload_from_stream("foo", b"hello world")
self.assertEqual(
b"hello world", await (await self.alt.open_download_stream_by_name("foo")).read()
)
await self.alt.upload_from_stream("mike", b"")
await self.alt.upload_from_stream("test", b"foo")
await self.alt.upload_from_stream("hello world", b"")
self.assertEqual(
{"mike", "test", "hello world", "foo"},
{k["filename"] for k in await self.db.alt.files.find().to_list()},
)
async def test_threaded_reads(self):
await self.fs.upload_from_stream("test", b"hello")
threads = []
results: list = []
for i in range(10):
threads.append(JustRead(self.fs, 10, results))
await threads[i].start()
await async_joinall(threads)
self.assertEqual(100 * [b"hello"], results)
async def test_threaded_writes(self):
threads = []
for i in range(10):
threads.append(JustWrite(self.fs, 10))
await threads[i].start()
await async_joinall(threads)
fstr = await self.fs.open_download_stream_by_name("test")
self.assertEqual(await fstr.read(), b"hello")
# Should have created 100 versions of 'test' file
self.assertEqual(100, await self.db.fs.files.count_documents({"filename": "test"}))
async def test_get_last_version(self):
one = await self.fs.upload_from_stream("test", b"foo")
await asyncio.sleep(0.01)
two = self.fs.open_upload_stream("test")
await two.write(b"bar")
await two.close()
await asyncio.sleep(0.01)
two = two._id
three = await self.fs.upload_from_stream("test", b"baz")
self.assertEqual(b"baz", await (await self.fs.open_download_stream_by_name("test")).read())
await self.fs.delete(three)
self.assertEqual(b"bar", await (await self.fs.open_download_stream_by_name("test")).read())
await self.fs.delete(two)
self.assertEqual(b"foo", await (await self.fs.open_download_stream_by_name("test")).read())
await self.fs.delete(one)
with self.assertRaises(NoFile):
await self.fs.open_download_stream_by_name("test")
async def test_get_version(self):
await self.fs.upload_from_stream("test", b"foo")
await asyncio.sleep(0.01)
await self.fs.upload_from_stream("test", b"bar")
await asyncio.sleep(0.01)
await self.fs.upload_from_stream("test", b"baz")
await asyncio.sleep(0.01)
self.assertEqual(
b"foo", await (await self.fs.open_download_stream_by_name("test", revision=0)).read()
)
self.assertEqual(
b"bar", await (await self.fs.open_download_stream_by_name("test", revision=1)).read()
)
self.assertEqual(
b"baz", await (await self.fs.open_download_stream_by_name("test", revision=2)).read()
)
self.assertEqual(
b"baz", await (await self.fs.open_download_stream_by_name("test", revision=-1)).read()
)
self.assertEqual(
b"bar", await (await self.fs.open_download_stream_by_name("test", revision=-2)).read()
)
self.assertEqual(
b"foo", await (await self.fs.open_download_stream_by_name("test", revision=-3)).read()
)
with self.assertRaises(NoFile):
await self.fs.open_download_stream_by_name("test", revision=3)
with self.assertRaises(NoFile):
await self.fs.open_download_stream_by_name("test", revision=-4)
async def test_upload_from_stream(self):
oid = await self.fs.upload_from_stream(
"test_file", BytesIO(b"hello world"), chunk_size_bytes=1
)
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
self.assertEqual(b"hello world", await (await self.fs.open_download_stream(oid)).read())
async def test_upload_from_stream_with_id(self):
oid = ObjectId()
await self.fs.upload_from_stream_with_id(
oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1
)
self.assertEqual(b"custom id", await (await self.fs.open_download_stream(oid)).read())
@patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3)
@async_client_context.require_failCommand_fail_point
async def test_upload_bulk_write_error(self):
# Test BulkWriteError from insert_many is converted to an insert_one style error.
expected_wce = {
"code": 100,
"codeName": "UnsatisfiableWriteConcern",
"errmsg": "Not enough data-bearing nodes",
}
cause_wce = {
"configureFailPoint": "failCommand",
"mode": {"times": 2},
"data": {"failCommands": ["insert"], "writeConcernError": expected_wce},
}
gin = self.fs.open_upload_stream("test_file", chunk_size_bytes=1)
async with self.fail_point(cause_wce):
# Assert we raise WriteConcernError, not BulkWriteError.
with self.assertRaises(WriteConcernError):
await gin.write(b"hello world")
# 3 chunks were uploaded.
self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
await gin.abort()
@patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 10)
async def test_upload_batching(self):
async with self.fs.open_upload_stream("test_file", chunk_size_bytes=1) as gin:
await gin.write(b"s" * (10 - 1))
# No chunks were uploaded yet.
self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
await gin.write(b"s")
# All chunks were uploaded since we hit the _UPLOAD_BUFFER_CHUNKS limit.
self.assertEqual(10, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
async def test_open_upload_stream(self):
gin = self.fs.open_upload_stream("from_stream")
await gin.write(b"from stream")
await gin.close()
self.assertEqual(b"from stream", await (await self.fs.open_download_stream(gin._id)).read())
async def test_open_upload_stream_with_id(self):
oid = ObjectId()
gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id")
await gin.write(b"from stream with custom id")
await gin.close()
self.assertEqual(
b"from stream with custom id", await (await self.fs.open_download_stream(oid)).read()
)
async def test_missing_length_iter(self):
# Test fix that guards against PHP-237
await self.fs.upload_from_stream("empty", b"")
doc = await self.db.fs.files.find_one({"filename": "empty"})
assert doc is not None
doc.pop("length")
await self.db.fs.files.replace_one({"_id": doc["_id"]}, doc)
fstr = await self.fs.open_download_stream_by_name("empty")
async def iterate_file(grid_file):
async for _ in grid_file:
pass
return True
self.assertTrue(await iterate_file(fstr))
async def test_gridfs_lazy_connect(self):
client = await self.async_single_client(
"badhost", connect=False, serverSelectionTimeoutMS=0
)
cdb = client.db
gfs = gridfs.AsyncGridFSBucket(cdb)
with self.assertRaises(ServerSelectionTimeoutError):
await gfs.delete(0)
gfs = gridfs.AsyncGridFSBucket(cdb)
with self.assertRaises(ServerSelectionTimeoutError):
await gfs.upload_from_stream("test", b"") # Still no connection.
async def test_gridfs_find(self):
await self.fs.upload_from_stream("two", b"test2")
await asyncio.sleep(0.01)
await self.fs.upload_from_stream("two", b"test2+")
await asyncio.sleep(0.01)
await self.fs.upload_from_stream("one", b"test1")
await asyncio.sleep(0.01)
await self.fs.upload_from_stream("two", b"test2++")
files = self.db.fs.files
self.assertEqual(3, await files.count_documents({"filename": "two"}))
self.assertEqual(4, await files.count_documents({}))
cursor = self.fs.find(
{}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2
)
gout = await cursor.next()
self.assertEqual(b"test1", await gout.read())
await cursor.rewind()
gout = await cursor.next()
self.assertEqual(b"test1", await gout.read())
gout = await cursor.next()
self.assertEqual(b"test2+", await gout.read())
with self.assertRaises(StopAsyncIteration):
await cursor.next()
await cursor.close()
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
async def test_grid_in_non_int_chunksize(self):
# Lua, and perhaps other buggy AsyncGridFS clients, store size as a float.
data = b"data"
await self.fs.upload_from_stream("f", data)
await self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
self.assertEqual(data, await (await self.fs.open_download_stream_by_name("f")).read())
async def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
gridfs.AsyncGridFSBucket((await self.async_rs_or_single_client(w=0)).pymongo_test)
async def test_rename(self):
_id = await self.fs.upload_from_stream("first_name", b"testing")
self.assertEqual(
b"testing", await (await self.fs.open_download_stream_by_name("first_name")).read()
)
await self.fs.rename(_id, "second_name")
with self.assertRaises(NoFile):
await self.fs.open_download_stream_by_name("first_name")
self.assertEqual(
b"testing", await (await self.fs.open_download_stream_by_name("second_name")).read()
)
@patch("gridfs.asynchronous.grid_file._UPLOAD_BUFFER_SIZE", 5)
async def test_abort(self):
gin = self.fs.open_upload_stream("test_filename", chunk_size_bytes=5)
await gin.write(b"test1")
await gin.write(b"test2")
await gin.write(b"test3")
self.assertEqual(3, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
await gin.abort()
self.assertTrue(gin.closed)
with self.assertRaises(ValueError):
await gin.write(b"test4")
self.assertEqual(0, await self.db.fs.chunks.count_documents({"files_id": gin._id}))
async def test_download_to_stream(self):
file1 = BytesIO(b"hello world")
# Test with one chunk.
oid = await self.fs.upload_from_stream("one_chunk", file1)
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
file2 = BytesIO()
await self.fs.download_to_stream(oid, file2)
file1.seek(0)
file2.seek(0)
self.assertEqual(file1.read(), file2.read())
# Test with many chunks.
await self.db.drop_collection("fs.files")
await self.db.drop_collection("fs.chunks")
file1.seek(0)
oid = await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1)
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
file2 = BytesIO()
await self.fs.download_to_stream(oid, file2)
file1.seek(0)
file2.seek(0)
self.assertEqual(file1.read(), file2.read())
async def test_download_to_stream_by_name(self):
file1 = BytesIO(b"hello world")
# Test with one chunk.
_ = await self.fs.upload_from_stream("one_chunk", file1)
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
file2 = BytesIO()
await self.fs.download_to_stream_by_name("one_chunk", file2)
file1.seek(0)
file2.seek(0)
self.assertEqual(file1.read(), file2.read())
# Test with many chunks.
await self.db.drop_collection("fs.files")
await self.db.drop_collection("fs.chunks")
file1.seek(0)
await self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1)
self.assertEqual(11, await self.db.fs.chunks.count_documents({}))
file2 = BytesIO()
await self.fs.download_to_stream_by_name("many_chunks", file2)
file1.seek(0)
file2.seek(0)
self.assertEqual(file1.read(), file2.read())
async def test_md5(self):
gin = self.fs.open_upload_stream("no md5")
await gin.write(b"no md5 sum")
await gin.close()
self.assertIsNone(gin.md5)
gout = await self.fs.open_download_stream(gin._id)
self.assertIsNone(gout.md5)
gin = self.fs.open_upload_stream_with_id(ObjectId(), "also no md5")
await gin.write(b"also no md5 sum")
await gin.close()
self.assertIsNone(gin.md5)
gout = await self.fs.open_download_stream(gin._id)
self.assertIsNone(gout.md5)
class TestGridfsBucketReplicaSet(AsyncIntegrationTest):
@async_client_context.require_secondaries_count(1)
async def asyncSetUp(self):
await super().asyncSetUp()
@classmethod
@async_client_context.require_connection
async def asyncTearDownClass(cls):
await async_client_context.client.drop_database("gfsbucketreplica")
async def test_gridfs_replica_set(self):
rsc = await self.async_rs_client(
w=async_client_context.w, read_preference=ReadPreference.SECONDARY
)
gfs = gridfs.AsyncGridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest")
oid = await gfs.upload_from_stream("test_filename", b"foo")
content = await (await gfs.open_download_stream(oid)).read()
self.assertEqual(b"foo", content)
async def test_gridfs_secondary(self):
secondary_host, secondary_port = one(await self.client.secondaries)
secondary_connection = await self.async_single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY
)
# Should detect it's connected to secondary and not attempt to
# create index
gfs = gridfs.AsyncGridFSBucket(
secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest"
)
# This won't detect secondary, raises error
with self.assertRaises(NotPrimaryError):
await gfs.upload_from_stream("test_filename", b"foo")
async def test_gridfs_secondary_lazy(self):
# Should detect it's connected to secondary and not attempt to
# create index.
secondary_host, secondary_port = one(await self.client.secondaries)
client = await self.async_single_client(
secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False
)
# Still no connection.
gfs = gridfs.AsyncGridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest")
# Connects, doesn't create index.
with self.assertRaises(NoFile):
await gfs.open_download_stream_by_name("test_filename")
with self.assertRaises(NotPrimaryError):
await gfs.upload_from_stream("test_filename", b"data")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,211 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the topology module's Server Selection Spec implementation."""
from __future__ import annotations
import os
import sys
from pathlib import Path
from pymongo import AsyncMongoClient, ReadPreference
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology
from pymongo.errors import ServerSelectionTimeoutError
from pymongo.hello import HelloCompat
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
from pymongo.typings import strip_optional
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous.utils_selection_tests import (
create_selection_tests,
get_addresses,
get_topology_settings_dict,
make_server_description,
)
from test.utils import (
EventListener,
FunctionCallRecorder,
OvertCommandListener,
async_wait_until,
)
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent, "server_selection", "server_selection"
)
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "server_selection", "server_selection"
)
class SelectionStoreSelector:
"""No-op selector that keeps track of what was passed to it."""
def __init__(self):
self.selection = None
def __call__(self, selection):
self.selection = selection
return selection
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
pass
class TestCustomServerSelectorFunction(AsyncIntegrationTest):
@async_client_context.require_replica_set
async def test_functional_select_max_port_number_host(self):
# Selector that returns server with highest port number.
def custom_selector(servers):
ports = [s.address[1] for s in servers]
idx = ports.index(max(ports))
return [servers[idx]]
# Initialize client with appropriate listeners.
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
server_selector=custom_selector, event_listeners=[listener]
)
coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll
self.addAsyncCleanup(client.drop_database, "testdb")
# Wait the node list to be fully populated.
async def all_hosts_started():
return len((await client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len(
client._topology._description.readable_servers
)
await async_wait_until(all_hosts_started, "receive heartbeat from all hosts")
expected_port = max(
[strip_optional(n.address[1]) for n in client._topology._description.readable_servers]
)
# Insert 1 record and access it 10 times.
await coll.insert_one({"name": "John Doe"})
for _ in range(10):
await coll.find_one({"name": "John Doe"})
# Confirm all find commands are run against appropriate host.
for command in listener.started_events:
if command.command_name == "find":
self.assertEqual(command.connection_id[1], expected_port)
async def test_invalid_server_selector(self):
# Client initialization must fail if server_selector is not callable.
for selector_candidate in [[], 10, "string", {}]:
with self.assertRaisesRegex(ValueError, "must be a callable"):
AsyncMongoClient(connect=False, server_selector=selector_candidate)
# None value for server_selector is OK.
AsyncMongoClient(connect=False, server_selector=None)
@async_client_context.require_replica_set
async def test_selector_called(self):
selector = FunctionCallRecorder(lambda x: x)
# Client setup.
mongo_client = await self.async_rs_or_single_client(server_selector=selector)
test_collection = mongo_client.testdb.test_collection
self.addAsyncCleanup(mongo_client.drop_database, "testdb")
# Do N operations and test selector is called at least N times.
await test_collection.insert_one({"age": 20, "name": "John"})
await test_collection.insert_one({"age": 31, "name": "Jane"})
await test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}})
await test_collection.find_one({"name": "Roe"})
self.assertGreaterEqual(selector.call_count, 4)
@async_client_context.require_replica_set
async def test_latency_threshold_application(self):
selector = SelectionStoreSelector()
scenario_def: dict = {
"topology_description": {
"type": "ReplicaSetWithPrimary",
"servers": [
{"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}},
{"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}},
{"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSPrimary", "tag": {}},
],
}
}
# Create & populate Topology such that all but one server is too slow.
rtt_times = [srv["avg_rtt_ms"] for srv in scenario_def["topology_description"]["servers"]]
min_rtt_idx = rtt_times.index(min(rtt_times))
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
settings = get_topology_settings_dict(
heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector
)
topology = Topology(TopologySettings(**settings))
await topology.open()
for server in scenario_def["topology_description"]["servers"]:
server_description = make_server_description(server, hosts)
await topology.on_change(server_description)
# Invoke server selection and assert no filtering based on latency
# prior to custom server selection logic kicking in.
server = await topology.select_server(ReadPreference.NEAREST, _Op.TEST)
assert selector.selection is not None
self.assertEqual(len(selector.selection), len(topology.description.server_descriptions()))
# Ensure proper filtering based on latency after custom selection.
self.assertEqual(server.description.address, seeds[min_rtt_idx])
@async_client_context.require_replica_set
async def test_server_selector_bypassed(self):
selector = FunctionCallRecorder(lambda x: x)
scenario_def = {
"topology_description": {
"type": "ReplicaSetNoPrimary",
"servers": [
{"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}},
{"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}},
{"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSSecondary", "tag": {}},
],
}
}
# Create & populate Topology such that no server is writeable.
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
settings = get_topology_settings_dict(
heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector
)
topology = Topology(TopologySettings(**settings))
await topology.open()
for server in scenario_def["topology_description"]["servers"]:
server_description = make_server_description(server, hosts)
await topology.on_change(server_description)
# Invoke server selection and assert no calls to our custom selector.
with self.assertRaisesRegex(ServerSelectionTimeoutError, "No primary available for writes"):
await topology.select_server(
writable_server_selector, _Op.TEST, server_selection_timeout=0.1
)
self.assertEqual(selector.call_count, 0)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,179 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the topology module's Server Selection Spec implementation."""
from __future__ import annotations
import asyncio
import os
import threading
from pathlib import Path
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous.helpers import ConcurrentRunner
from test.asynchronous.utils_selection_tests import create_topology
from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator
from test.utils import (
CMAPListener,
OvertCommandListener,
async_get_pool,
async_wait_until,
)
from pymongo.common import clean_node
from pymongo.monitoring import ConnectionReadyEvent
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
)
class TestAllScenarios(unittest.IsolatedAsyncioTestCase):
async def run_scenario(self, scenario_def):
topology = await create_topology(scenario_def)
# Update mock operation_count state:
for mock in scenario_def["mocked_topology_state"]:
address = clean_node(mock["address"])
server = topology.get_server_by_address(address)
server.pool.operation_count = mock["operation_count"]
pref = ReadPreference.NEAREST
counts = {address: 0 for address in topology._description.server_descriptions()}
# Number of times to repeat server selection
iterations = scenario_def["iterations"]
for _ in range(iterations):
server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
counts[server.description.address] += 1
# Verify expected_frequencies
outcome = scenario_def["outcome"]
tolerance = outcome["tolerance"]
expected_frequencies = outcome["expected_frequencies"]
for host_str, freq in expected_frequencies.items():
address = clean_node(host_str)
actual_freq = float(counts[address]) / iterations
if freq == 0:
# Should be exactly 0.
self.assertEqual(actual_freq, 0)
else:
# Should be within 'tolerance'.
self.assertAlmostEqual(actual_freq, freq, delta=tolerance)
def create_test(scenario_def, test, name):
async def run_scenario(self):
await self.run_scenario(scenario_def)
return run_scenario
class CustomSpecTestCreator(AsyncSpecTestCreator):
def tests(self, scenario_def):
"""Extract the tests from a spec file.
Server selection in_window tests do not have a 'tests' field.
The whole file represents a single test case.
"""
return [scenario_def]
CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
class FinderTask(ConcurrentRunner):
def __init__(self, collection, iterations):
super().__init__()
self.daemon = True
self.collection = collection
self.iterations = iterations
self.passed = False
async def run(self):
for _ in range(self.iterations):
await self.collection.find_one({})
self.passed = True
class TestProse(AsyncIntegrationTest):
async def frequencies(self, client, listener, n_finds=10):
coll = client.test.test
N_TASKS = 10
tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)]
for task in tasks:
await task.start()
for task in tasks:
await task.join()
for task in tasks:
self.assertTrue(task.passed)
events = listener.started_events
self.assertEqual(len(events), n_finds * N_TASKS)
nodes = client.nodes
self.assertEqual(len(nodes), 2)
freqs = {address: 0.0 for address in nodes}
for event in events:
freqs[event.connection_id] += 1
for address in freqs:
freqs[address] = freqs[address] / float(len(events))
return freqs
@async_client_context.require_failCommand_appName
@async_client_context.require_multiple_mongoses
async def test_load_balancing(self):
listener = OvertCommandListener()
cmap_listener = CMAPListener()
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
# varying RTTs.
client = await self.async_rs_client(
async_client_context.mongos_seeds(),
appName="loadBalancingTest",
event_listeners=[listener, cmap_listener],
localThresholdMS=30000,
minPoolSize=10,
)
await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
# Wait for both pools to be populated.
await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20)
# Delay find commands on only one mongos.
delay_finds = {
"configureFailPoint": "failCommand",
"mode": {"times": 10000},
"data": {
"failCommands": ["find"],
"blockConnection": True,
"blockTimeMS": 500,
"appName": "loadBalancingTest",
},
}
async with self.fail_point(delay_finds):
nodes = async_client_context.client.nodes
self.assertEqual(len(nodes), 1)
delayed_server = next(iter(nodes))
freqs = await self.frequencies(client, listener)
self.assertLessEqual(freqs[delayed_server], 0.25)
listener.reset()
freqs = await self.frequencies(client, listener, n_finds=150)
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,203 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for testing Server Selection and Max Staleness."""
from __future__ import annotations
import datetime
import os
import sys
from test.asynchronous import AsyncPyMongoTestCase
sys.path[0:0] = [""]
from test import unittest
from test.pymongo_mocks import DummyMonitor
from test.utils import AsyncMockPool, parse_read_preference
from test.utils_selection_tests_shared import (
get_addresses,
get_topology_type_name,
make_server_description,
)
from bson import json_util
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology
from pymongo.common import HEARTBEAT_FREQUENCY
from pymongo.errors import AutoReconnect, ConfigurationError
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
_IS_SYNC = False
def get_topology_settings_dict(**kwargs):
settings = {
"monitor_class": DummyMonitor,
"heartbeat_frequency": HEARTBEAT_FREQUENCY,
"pool_class": AsyncMockPool,
}
settings.update(kwargs)
return settings
async def create_topology(scenario_def, **kwargs):
# Initialize topologies.
if "heartbeatFrequencyMS" in scenario_def:
frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0
else:
frequency = HEARTBEAT_FREQUENCY
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
topology_type = get_topology_type_name(scenario_def)
if topology_type == "LoadBalanced":
kwargs.setdefault("load_balanced", True)
# Force topology description to ReplicaSet
elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]:
kwargs.setdefault("replica_set_name", "rs")
settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs)
# "Eligible servers" is defined in the server selection spec as
# the set of servers matching both the ReadPreference's mode
# and tag sets.
topology = Topology(TopologySettings(**settings))
await topology.open()
# Update topologies with server descriptions.
for server in scenario_def["topology_description"]["servers"]:
server_description = make_server_description(server, hosts)
await topology.on_change(server_description)
# Assert that descriptions match
assert (
scenario_def["topology_description"]["type"] == topology.description.topology_type_name
), topology.description.topology_type_name
return topology
def create_test(scenario_def):
async def run_scenario(self):
_, hosts = get_addresses(scenario_def["topology_description"]["servers"])
# "Eligible servers" is defined in the server selection spec as
# the set of servers matching both the ReadPreference's mode
# and tag sets.
top_latency = await create_topology(scenario_def)
# "In latency window" is defined in the server selection
# spec as the subset of suitable_servers that falls within the
# allowable latency window.
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)
# Create server selector.
if scenario_def.get("operation") == "write":
pref = writable_server_selector
else:
# Make first letter lowercase to match read_pref's modes.
pref_def = scenario_def["read_preference"]
if scenario_def.get("error"):
with self.assertRaises((ConfigurationError, ValueError)):
# Error can be raised when making Read Pref or selecting.
pref = parse_read_preference(pref_def)
await top_latency.select_server(pref, _Op.TEST)
return
pref = parse_read_preference(pref_def)
# Select servers.
if not scenario_def.get("suitable_servers"):
with self.assertRaises(AutoReconnect):
await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0)
return
if not scenario_def["in_latency_window"]:
with self.assertRaises(AutoReconnect):
await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0)
return
actual_suitable_s = await top_suitable.select_servers(
pref, _Op.TEST, server_selection_timeout=0
)
actual_latency_s = await top_latency.select_servers(
pref, _Op.TEST, server_selection_timeout=0
)
expected_suitable_servers = {}
for server in scenario_def["suitable_servers"]:
server_description = make_server_description(server, hosts)
expected_suitable_servers[server["address"]] = server_description
actual_suitable_servers = {}
for s in actual_suitable_s:
actual_suitable_servers[
"%s:%d" % (s.description.address[0], s.description.address[1])
] = s.description
self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers))
for k, actual in actual_suitable_servers.items():
expected = expected_suitable_servers[k]
self.assertEqual(expected.address, actual.address)
self.assertEqual(expected.server_type, actual.server_type)
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
self.assertEqual(expected.tags, actual.tags)
self.assertEqual(expected.all_hosts, actual.all_hosts)
expected_latency_servers = {}
for server in scenario_def["in_latency_window"]:
server_description = make_server_description(server, hosts)
expected_latency_servers[server["address"]] = server_description
actual_latency_servers = {}
for s in actual_latency_s:
actual_latency_servers[
"%s:%d" % (s.description.address[0], s.description.address[1])
] = s.description
self.assertEqual(len(actual_latency_servers), len(expected_latency_servers))
for k, actual in actual_latency_servers.items():
expected = expected_latency_servers[k]
self.assertEqual(expected.address, actual.address)
self.assertEqual(expected.server_type, actual.server_type)
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
self.assertEqual(expected.tags, actual.tags)
self.assertEqual(expected.all_hosts, actual.all_hosts)
return run_scenario
def create_selection_tests(test_dir):
class TestAllScenarios(AsyncPyMongoTestCase):
pass
for dirpath, _, filenames in os.walk(test_dir):
dirname = os.path.split(dirpath)
dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1]
for filename in filenames:
if os.path.splitext(filename)[1] != ".json":
continue
with open(os.path.join(dirpath, filename)) as scenario_stream:
scenario_def = json_util.loads(scenario_stream.read())
# Construct test from scenario.
new_test = create_test(scenario_def)
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
new_test.__name__ = test_name
setattr(TestAllScenarios, new_test.__name__, new_test)
return TestAllScenarios

View File

@ -229,7 +229,7 @@ class AsyncSpecTestCreator:
str(test_def["description"].replace(" ", "_").replace(".", "_")),
)
new_test = await self._create_test(scenario_def, test_def, test_name)
new_test = self._create_test(scenario_def, test_def, test_name)
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
new_test = self.ensure_run_on(scenario_def, new_test)

View File

@ -959,7 +959,6 @@ class TestBulkWriteConcern(BulkTestBase):
@client_context.require_replica_set
@client_context.require_secondaries_count(1)
def test_write_concern_failure_ordered(self):
self.skipTest("Skipping until PYTHON-4865 is resolved.")
details = None
# Ensure we don't raise on wnote.

View File

@ -15,9 +15,11 @@
"""Execute Transactions Spec tests."""
from __future__ import annotations
import asyncio
import os
import sys
import time
from pathlib import Path
sys.path[0:0] = [""]
@ -60,6 +62,8 @@ from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.pool import PoolState, _PoolClosedError
from pymongo.topology_description import updated_topology_description
_IS_SYNC = True
OBJECT_TYPES = {
# Event types.
"ConnectionCheckedIn": ConnectionCheckedInEvent,
@ -81,7 +85,10 @@ OBJECT_TYPES = {
class TestCMAP(IntegrationTest):
# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring")
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring")
else:
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring")
# Test operations:
@ -258,7 +265,6 @@ class TestCMAP(IntegrationTest):
client._topology.open()
else:
client._get_topology()
self.addCleanup(client.close)
self.pool = list(client._topology._servers.values())[0].pool
# Map of target names to Thread objects.
@ -315,13 +321,11 @@ class TestCMAP(IntegrationTest):
#
def test_1_client_connection_pool_options(self):
client = self.rs_or_single_client(**self.POOL_OPTIONS)
self.addCleanup(client.close)
pool_opts = get_pool(client).opts
pool_opts = (get_pool(client)).opts
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
def test_2_all_client_pools_have_same_options(self):
client = self.rs_or_single_client(**self.POOL_OPTIONS)
self.addCleanup(client.close)
client.admin.command("ping")
# Discover at least one secondary.
if client_context.has_secondaries:
@ -337,14 +341,12 @@ class TestCMAP(IntegrationTest):
opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()])
uri = f"mongodb://{client_context.pair}/?{opts}"
client = self.rs_or_single_client(uri)
self.addCleanup(client.close)
pool_opts = get_pool(client).opts
pool_opts = (get_pool(client)).opts
self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS)
def test_4_subscribe_to_events(self):
listener = CMAPListener()
client = self.single_client(event_listeners=[listener])
self.addCleanup(client.close)
self.assertEqual(listener.event_count(PoolCreatedEvent), 1)
# Creates a new connection.
@ -368,7 +370,6 @@ class TestCMAP(IntegrationTest):
def test_5_check_out_fails_connection_error(self):
listener = CMAPListener()
client = self.single_client(event_listeners=[listener])
self.addCleanup(client.close)
pool = get_pool(client)
def mock_connect(*args, **kwargs):
@ -397,7 +398,6 @@ class TestCMAP(IntegrationTest):
client = self.single_client_noauth(
username="notauser", password="fail", event_listeners=[listener]
)
self.addCleanup(client.close)
# Attempt to create a new connection.
with self.assertRaisesRegex(OperationFailure, "failed"):

View File

@ -16,11 +16,13 @@
"""Tests for the gridfs package."""
from __future__ import annotations
import asyncio
import datetime
import sys
import threading
import time
from io import BytesIO
from test.helpers import ConcurrentRunner
from unittest.mock import patch
sys.path[0:0] = [""]
@ -41,10 +43,12 @@ from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
_IS_SYNC = True
class JustWrite(threading.Thread):
class JustWrite(ConcurrentRunner):
def __init__(self, fs, n):
threading.Thread.__init__(self)
super().__init__()
self.fs = fs
self.n = n
self.daemon = True
@ -56,9 +60,9 @@ class JustWrite(threading.Thread):
file.close()
class JustRead(threading.Thread):
class JustRead(ConcurrentRunner):
def __init__(self, fs, n, results):
threading.Thread.__init__(self)
super().__init__()
self.fs = fs
self.n = n
self.results = results
@ -98,19 +102,21 @@ class TestGridfs(IntegrationTest):
def test_basic(self):
oid = self.fs.put(b"hello world")
self.assertEqual(b"hello world", self.fs.get(oid).read())
self.assertEqual(b"hello world", (self.fs.get(oid)).read())
self.assertEqual(1, self.db.fs.files.count_documents({}))
self.assertEqual(1, self.db.fs.chunks.count_documents({}))
self.fs.delete(oid)
self.assertRaises(NoFile, self.fs.get, oid)
with self.assertRaises(NoFile):
self.fs.get(oid)
self.assertEqual(0, self.db.fs.files.count_documents({}))
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
self.assertRaises(NoFile, self.fs.get, "foo")
with self.assertRaises(NoFile):
self.fs.get("foo")
oid = self.fs.put(b"hello world", _id="foo")
self.assertEqual("foo", oid)
self.assertEqual(b"hello world", self.fs.get("foo").read())
self.assertEqual(b"hello world", (self.fs.get("foo")).read())
def test_multi_chunk_delete(self):
self.db.fs.drop()
@ -142,7 +148,7 @@ class TestGridfs(IntegrationTest):
def test_empty_file(self):
oid = self.fs.put(b"")
self.assertEqual(b"", self.fs.get(oid).read())
self.assertEqual(b"", (self.fs.get(oid)).read())
self.assertEqual(1, self.db.fs.files.count_documents({}))
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
@ -159,10 +165,12 @@ class TestGridfs(IntegrationTest):
self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}})
try:
out = self.fs.get(files_id)
self.assertRaises(CorruptGridFile, out.read)
with self.assertRaises(CorruptGridFile):
out.read()
out = self.fs.get(files_id)
self.assertRaises(CorruptGridFile, out.readline)
with self.assertRaises(CorruptGridFile):
out.readline()
finally:
self.fs.delete(files_id)
@ -177,31 +185,33 @@ class TestGridfs(IntegrationTest):
self.assertTrue(
any(
info.get("key") == [("files_id", 1), ("n", 1)]
for info in chunks.index_information().values()
for info in (chunks.index_information()).values()
)
)
self.assertTrue(
any(
info.get("key") == [("filename", 1), ("uploadDate", 1)]
for info in files.index_information().values()
for info in (files.index_information()).values()
)
)
def test_alt_collection(self):
oid = self.alt.put(b"hello world")
self.assertEqual(b"hello world", self.alt.get(oid).read())
self.assertEqual(b"hello world", (self.alt.get(oid)).read())
self.assertEqual(1, self.db.alt.files.count_documents({}))
self.assertEqual(1, self.db.alt.chunks.count_documents({}))
self.alt.delete(oid)
self.assertRaises(NoFile, self.alt.get, oid)
with self.assertRaises(NoFile):
self.alt.get(oid)
self.assertEqual(0, self.db.alt.files.count_documents({}))
self.assertEqual(0, self.db.alt.chunks.count_documents({}))
self.assertRaises(NoFile, self.alt.get, "foo")
with self.assertRaises(NoFile):
self.alt.get("foo")
oid = self.alt.put(b"hello world", _id="foo")
self.assertEqual("foo", oid)
self.assertEqual(b"hello world", self.alt.get("foo").read())
self.assertEqual(b"hello world", (self.alt.get("foo")).read())
self.alt.put(b"", filename="mike")
self.alt.put(b"foo", filename="test")
@ -212,23 +222,23 @@ class TestGridfs(IntegrationTest):
def test_threaded_reads(self):
self.fs.put(b"hello", _id="test")
threads = []
tasks = []
results: list = []
for i in range(10):
threads.append(JustRead(self.fs, 10, results))
threads[i].start()
tasks.append(JustRead(self.fs, 10, results))
tasks[i].start()
joinall(threads)
joinall(tasks)
self.assertEqual(100 * [b"hello"], results)
def test_threaded_writes(self):
threads = []
tasks = []
for i in range(10):
threads.append(JustWrite(self.fs, 10))
threads[i].start()
tasks.append(JustWrite(self.fs, 10))
tasks[i].start()
joinall(threads)
joinall(tasks)
f = self.fs.get_last_version("test")
self.assertEqual(f.read(), b"hello")
@ -246,34 +256,37 @@ class TestGridfs(IntegrationTest):
two = two._id
three = self.fs.put(b"baz", filename="test")
self.assertEqual(b"baz", self.fs.get_last_version("test").read())
self.assertEqual(b"baz", (self.fs.get_last_version("test")).read())
self.fs.delete(three)
self.assertEqual(b"bar", self.fs.get_last_version("test").read())
self.assertEqual(b"bar", (self.fs.get_last_version("test")).read())
self.fs.delete(two)
self.assertEqual(b"foo", self.fs.get_last_version("test").read())
self.assertEqual(b"foo", (self.fs.get_last_version("test")).read())
self.fs.delete(one)
self.assertRaises(NoFile, self.fs.get_last_version, "test")
with self.assertRaises(NoFile):
self.fs.get_last_version("test")
def test_get_last_version_with_metadata(self):
one = self.fs.put(b"foo", filename="test", author="author")
time.sleep(0.01)
two = self.fs.put(b"bar", filename="test", author="author")
self.assertEqual(b"bar", self.fs.get_last_version(author="author").read())
self.assertEqual(b"bar", (self.fs.get_last_version(author="author")).read())
self.fs.delete(two)
self.assertEqual(b"foo", self.fs.get_last_version(author="author").read())
self.assertEqual(b"foo", (self.fs.get_last_version(author="author")).read())
self.fs.delete(one)
one = self.fs.put(b"foo", filename="test", author="author1")
time.sleep(0.01)
two = self.fs.put(b"bar", filename="test", author="author2")
self.assertEqual(b"foo", self.fs.get_last_version(author="author1").read())
self.assertEqual(b"bar", self.fs.get_last_version(author="author2").read())
self.assertEqual(b"bar", self.fs.get_last_version(filename="test").read())
self.assertEqual(b"foo", (self.fs.get_last_version(author="author1")).read())
self.assertEqual(b"bar", (self.fs.get_last_version(author="author2")).read())
self.assertEqual(b"bar", (self.fs.get_last_version(filename="test")).read())
self.assertRaises(NoFile, self.fs.get_last_version, author="author3")
self.assertRaises(NoFile, self.fs.get_last_version, filename="nottest", author="author1")
with self.assertRaises(NoFile):
self.fs.get_last_version(author="author3")
with self.assertRaises(NoFile):
self.fs.get_last_version(filename="nottest", author="author1")
self.fs.delete(one)
self.fs.delete(two)
@ -286,16 +299,18 @@ class TestGridfs(IntegrationTest):
self.fs.put(b"baz", filename="test")
time.sleep(0.01)
self.assertEqual(b"foo", self.fs.get_version("test", 0).read())
self.assertEqual(b"bar", self.fs.get_version("test", 1).read())
self.assertEqual(b"baz", self.fs.get_version("test", 2).read())
self.assertEqual(b"foo", (self.fs.get_version("test", 0)).read())
self.assertEqual(b"bar", (self.fs.get_version("test", 1)).read())
self.assertEqual(b"baz", (self.fs.get_version("test", 2)).read())
self.assertEqual(b"baz", self.fs.get_version("test", -1).read())
self.assertEqual(b"bar", self.fs.get_version("test", -2).read())
self.assertEqual(b"foo", self.fs.get_version("test", -3).read())
self.assertEqual(b"baz", (self.fs.get_version("test", -1)).read())
self.assertEqual(b"bar", (self.fs.get_version("test", -2)).read())
self.assertEqual(b"foo", (self.fs.get_version("test", -3)).read())
self.assertRaises(NoFile, self.fs.get_version, "test", 3)
self.assertRaises(NoFile, self.fs.get_version, "test", -4)
with self.assertRaises(NoFile):
self.fs.get_version("test", 3)
with self.assertRaises(NoFile):
self.fs.get_version("test", -4)
def test_get_version_with_metadata(self):
one = self.fs.put(b"foo", filename="test", author="author1")
@ -305,25 +320,32 @@ class TestGridfs(IntegrationTest):
three = self.fs.put(b"baz", filename="test", author="author2")
self.assertEqual(
b"foo", self.fs.get_version(filename="test", author="author1", version=-2).read()
b"foo",
(self.fs.get_version(filename="test", author="author1", version=-2)).read(),
)
self.assertEqual(
b"bar", self.fs.get_version(filename="test", author="author1", version=-1).read()
b"bar",
(self.fs.get_version(filename="test", author="author1", version=-1)).read(),
)
self.assertEqual(
b"foo", self.fs.get_version(filename="test", author="author1", version=0).read()
b"foo",
(self.fs.get_version(filename="test", author="author1", version=0)).read(),
)
self.assertEqual(
b"bar", self.fs.get_version(filename="test", author="author1", version=1).read()
b"bar",
(self.fs.get_version(filename="test", author="author1", version=1)).read(),
)
self.assertEqual(
b"baz", self.fs.get_version(filename="test", author="author2", version=0).read()
b"baz",
(self.fs.get_version(filename="test", author="author2", version=0)).read(),
)
self.assertEqual(b"baz", self.fs.get_version(filename="test", version=-1).read())
self.assertEqual(b"baz", self.fs.get_version(filename="test", version=2).read())
self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=-1)).read())
self.assertEqual(b"baz", (self.fs.get_version(filename="test", version=2)).read())
self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author3")
self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author1", version=2)
with self.assertRaises(NoFile):
self.fs.get_version(filename="test", author="author3")
with self.assertRaises(NoFile):
self.fs.get_version(filename="test", author="author1", version=2)
self.fs.delete(one)
self.fs.delete(two)
@ -332,11 +354,12 @@ class TestGridfs(IntegrationTest):
def test_put_filelike(self):
oid = self.fs.put(BytesIO(b"hello world"), chunk_size=1)
self.assertEqual(11, self.db.fs.chunks.count_documents({}))
self.assertEqual(b"hello world", self.fs.get(oid).read())
self.assertEqual(b"hello world", (self.fs.get(oid)).read())
def test_file_exists(self):
oid = self.fs.put(b"hello")
self.assertRaises(FileExists, self.fs.put, b"world", _id=oid)
with self.assertRaises(FileExists):
self.fs.put(b"world", _id=oid)
one = self.fs.new_file(_id=123)
one.write(b"some content")
@ -345,15 +368,17 @@ class TestGridfs(IntegrationTest):
# Attempt to upload a file with more chunks to the same _id.
with patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", DEFAULT_CHUNK_SIZE):
two = self.fs.new_file(_id=123)
self.assertRaises(FileExists, two.write, b"x" * DEFAULT_CHUNK_SIZE * 3)
with self.assertRaises(FileExists):
two.write(b"x" * DEFAULT_CHUNK_SIZE * 3)
# Original file is still readable (no extra chunks were uploaded).
self.assertEqual(self.fs.get(123).read(), b"some content")
self.assertEqual((self.fs.get(123)).read(), b"some content")
two = self.fs.new_file(_id=123)
two.write(b"some content")
self.assertRaises(FileExists, two.close)
with self.assertRaises(FileExists):
two.close()
# Original file is still readable.
self.assertEqual(self.fs.get(123).read(), b"some content")
self.assertEqual((self.fs.get(123)).read(), b"some content")
def test_exists(self):
oid = self.fs.put(b"hello")
@ -381,15 +406,16 @@ class TestGridfs(IntegrationTest):
self.assertFalse(self.fs.exists({"foo": {"$gt": 12}}))
def test_put_unicode(self):
self.assertRaises(TypeError, self.fs.put, "hello")
with self.assertRaises(TypeError):
self.fs.put("hello")
oid = self.fs.put("hello", encoding="utf-8")
self.assertEqual(b"hello", self.fs.get(oid).read())
self.assertEqual("utf-8", self.fs.get(oid).encoding)
self.assertEqual(b"hello", (self.fs.get(oid)).read())
self.assertEqual("utf-8", (self.fs.get(oid)).encoding)
oid = self.fs.put("", encoding="iso-8859-1")
self.assertEqual("".encode("iso-8859-1"), self.fs.get(oid).read())
self.assertEqual("iso-8859-1", self.fs.get(oid).encoding)
self.assertEqual("".encode("iso-8859-1"), (self.fs.get(oid)).read())
self.assertEqual("iso-8859-1", (self.fs.get(oid)).encoding)
def test_missing_length_iter(self):
# Test fix that guards against PHP-237
@ -411,11 +437,13 @@ class TestGridfs(IntegrationTest):
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10)
db = client.db
gfs = gridfs.GridFS(db)
self.assertRaises(ServerSelectionTimeoutError, gfs.list)
with self.assertRaises(ServerSelectionTimeoutError):
gfs.list()
fs = gridfs.GridFS(db)
f = fs.new_file()
self.assertRaises(ServerSelectionTimeoutError, f.close)
with self.assertRaises(ServerSelectionTimeoutError):
f.close()
def test_gridfs_find(self):
self.fs.put(b"test2", filename="two")
@ -429,14 +457,15 @@ class TestGridfs(IntegrationTest):
self.assertEqual(3, files.count_documents({"filename": "two"}))
self.assertEqual(4, files.count_documents({}))
cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2)
gout = next(cursor)
gout = cursor.next()
self.assertEqual(b"test1", gout.read())
cursor.rewind()
gout = next(cursor)
gout = cursor.next()
self.assertEqual(b"test1", gout.read())
gout = next(cursor)
gout = cursor.next()
self.assertEqual(b"test2+", gout.read())
self.assertRaises(StopIteration, cursor.__next__)
with self.assertRaises(StopIteration):
cursor.__next__()
cursor.rewind()
items = cursor.to_list()
self.assertEqual(len(items), 2)
@ -484,12 +513,12 @@ class TestGridfs(IntegrationTest):
self.fs.put(data, filename="f")
self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
self.assertEqual(data, self.fs.get_version("f").read())
self.assertEqual(data, (self.fs.get_version("f")).read())
def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test)
gridfs.GridFS((self.rs_or_single_client(w=0)).pymongo_test)
def test_md5(self):
gin = self.fs.new_file()
@ -524,7 +553,7 @@ class TestGridfsReplicaSet(IntegrationTest):
self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY)
oid = fs.put(b"foo")
content = fs.get(oid).read()
content = (fs.get(oid)).read()
self.assertEqual(b"foo", content)
def test_gridfs_secondary(self):
@ -538,7 +567,8 @@ class TestGridfsReplicaSet(IntegrationTest):
fs = gridfs.GridFS(secondary_connection.gfsreplica, "gfssecondarytest")
# This won't detect secondary, raises error
self.assertRaises(NotPrimaryError, fs.put, b"foo")
with self.assertRaises(NotPrimaryError):
fs.put(b"foo")
def test_gridfs_secondary_lazy(self):
# Should detect it's connected to secondary and not attempt to
@ -552,8 +582,10 @@ class TestGridfsReplicaSet(IntegrationTest):
fs = gridfs.GridFS(client.gfsreplica, "gfssecondarylazytest")
# Connects, doesn't create index.
self.assertRaises(NoFile, fs.get_last_version)
self.assertRaises(NotPrimaryError, fs.put, "data", encoding="utf-8")
with self.assertRaises(NoFile):
fs.get_last_version()
with self.assertRaises(NotPrimaryError):
fs.put("data", encoding="utf-8")
if __name__ == "__main__":

View File

@ -16,12 +16,14 @@
"""Tests for the gridfs package."""
from __future__ import annotations
import asyncio
import datetime
import itertools
import sys
import threading
import time
from io import BytesIO
from test.helpers import ConcurrentRunner
from unittest.mock import patch
sys.path[0:0] = [""]
@ -44,10 +46,12 @@ from pymongo.errors import (
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.mongo_client import MongoClient
_IS_SYNC = True
class JustWrite(threading.Thread):
class JustWrite(ConcurrentRunner):
def __init__(self, gfs, num):
threading.Thread.__init__(self)
super().__init__()
self.gfs = gfs
self.num = num
self.daemon = True
@ -59,9 +63,9 @@ class JustWrite(threading.Thread):
file.close()
class JustRead(threading.Thread):
class JustRead(ConcurrentRunner):
def __init__(self, gfs, num, results):
threading.Thread.__init__(self)
super().__init__()
self.gfs = gfs
self.num = num
self.results = results
@ -89,12 +93,13 @@ class TestGridfs(IntegrationTest):
def test_basic(self):
oid = self.fs.upload_from_stream("test_filename", b"hello world")
self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read())
self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read())
self.assertEqual(1, self.db.fs.files.count_documents({}))
self.assertEqual(1, self.db.fs.chunks.count_documents({}))
self.fs.delete(oid)
self.assertRaises(NoFile, self.fs.open_download_stream, oid)
with self.assertRaises(NoFile):
self.fs.open_download_stream(oid)
self.assertEqual(0, self.db.fs.files.count_documents({}))
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
@ -111,7 +116,7 @@ class TestGridfs(IntegrationTest):
def test_empty_file(self):
oid = self.fs.upload_from_stream("test_filename", b"")
self.assertEqual(b"", self.fs.open_download_stream(oid).read())
self.assertEqual(b"", (self.fs.open_download_stream(oid)).read())
self.assertEqual(1, self.db.fs.files.count_documents({}))
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
@ -128,10 +133,12 @@ class TestGridfs(IntegrationTest):
self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}})
try:
out = self.fs.open_download_stream(files_id)
self.assertRaises(CorruptGridFile, out.read)
with self.assertRaises(CorruptGridFile):
out.read()
out = self.fs.open_download_stream(files_id)
self.assertRaises(CorruptGridFile, out.readline)
with self.assertRaises(CorruptGridFile):
out.readline()
finally:
self.fs.delete(files_id)
@ -146,13 +153,13 @@ class TestGridfs(IntegrationTest):
self.assertTrue(
any(
info.get("key") == [("files_id", 1), ("n", 1)]
for info in chunks.index_information().values()
for info in (chunks.index_information()).values()
)
)
self.assertTrue(
any(
info.get("key") == [("filename", 1), ("uploadDate", 1)]
for info in files.index_information().values()
for info in (files.index_information()).values()
)
)
@ -174,25 +181,27 @@ class TestGridfs(IntegrationTest):
self.assertTrue(
any(
info.get("key") == [("filename", 1), ("uploadDate", 1)]
for info in files.index_information().values()
for info in (files.index_information()).values()
)
)
files.drop()
def test_alt_collection(self):
oid = self.alt.upload_from_stream("test_filename", b"hello world")
self.assertEqual(b"hello world", self.alt.open_download_stream(oid).read())
self.assertEqual(b"hello world", (self.alt.open_download_stream(oid)).read())
self.assertEqual(1, self.db.alt.files.count_documents({}))
self.assertEqual(1, self.db.alt.chunks.count_documents({}))
self.alt.delete(oid)
self.assertRaises(NoFile, self.alt.open_download_stream, oid)
with self.assertRaises(NoFile):
self.alt.open_download_stream(oid)
self.assertEqual(0, self.db.alt.files.count_documents({}))
self.assertEqual(0, self.db.alt.chunks.count_documents({}))
self.assertRaises(NoFile, self.alt.open_download_stream, "foo")
with self.assertRaises(NoFile):
self.alt.open_download_stream("foo")
self.alt.upload_from_stream("foo", b"hello world")
self.assertEqual(b"hello world", self.alt.open_download_stream_by_name("foo").read())
self.assertEqual(b"hello world", (self.alt.open_download_stream_by_name("foo")).read())
self.alt.upload_from_stream("mike", b"")
self.alt.upload_from_stream("test", b"foo")
@ -200,7 +209,7 @@ class TestGridfs(IntegrationTest):
self.assertEqual(
{"mike", "test", "hello world", "foo"},
{k["filename"] for k in list(self.db.alt.files.find())},
{k["filename"] for k in self.db.alt.files.find().to_list()},
)
def test_threaded_reads(self):
@ -240,13 +249,14 @@ class TestGridfs(IntegrationTest):
two = two._id
three = self.fs.upload_from_stream("test", b"baz")
self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test").read())
self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test")).read())
self.fs.delete(three)
self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test").read())
self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test")).read())
self.fs.delete(two)
self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test").read())
self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test")).read())
self.fs.delete(one)
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test")
with self.assertRaises(NoFile):
self.fs.open_download_stream_by_name("test")
def test_get_version(self):
self.fs.upload_from_stream("test", b"foo")
@ -256,28 +266,30 @@ class TestGridfs(IntegrationTest):
self.fs.upload_from_stream("test", b"baz")
time.sleep(0.01)
self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=0).read())
self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=1).read())
self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=2).read())
self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=0)).read())
self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=1)).read())
self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=2)).read())
self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=-1).read())
self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=-2).read())
self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=-3).read())
self.assertEqual(b"baz", (self.fs.open_download_stream_by_name("test", revision=-1)).read())
self.assertEqual(b"bar", (self.fs.open_download_stream_by_name("test", revision=-2)).read())
self.assertEqual(b"foo", (self.fs.open_download_stream_by_name("test", revision=-3)).read())
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=3)
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=-4)
with self.assertRaises(NoFile):
self.fs.open_download_stream_by_name("test", revision=3)
with self.assertRaises(NoFile):
self.fs.open_download_stream_by_name("test", revision=-4)
def test_upload_from_stream(self):
oid = self.fs.upload_from_stream("test_file", BytesIO(b"hello world"), chunk_size_bytes=1)
self.assertEqual(11, self.db.fs.chunks.count_documents({}))
self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read())
self.assertEqual(b"hello world", (self.fs.open_download_stream(oid)).read())
def test_upload_from_stream_with_id(self):
oid = ObjectId()
self.fs.upload_from_stream_with_id(
oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1
)
self.assertEqual(b"custom id", self.fs.open_download_stream(oid).read())
self.assertEqual(b"custom id", (self.fs.open_download_stream(oid)).read())
@patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_CHUNKS", 3)
@client_context.require_failCommand_fail_point
@ -316,14 +328,14 @@ class TestGridfs(IntegrationTest):
gin = self.fs.open_upload_stream("from_stream")
gin.write(b"from stream")
gin.close()
self.assertEqual(b"from stream", self.fs.open_download_stream(gin._id).read())
self.assertEqual(b"from stream", (self.fs.open_download_stream(gin._id)).read())
def test_open_upload_stream_with_id(self):
oid = ObjectId()
gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id")
gin.write(b"from stream with custom id")
gin.close()
self.assertEqual(b"from stream with custom id", self.fs.open_download_stream(oid).read())
self.assertEqual(b"from stream with custom id", (self.fs.open_download_stream(oid)).read())
def test_missing_length_iter(self):
# Test fix that guards against PHP-237
@ -345,12 +357,12 @@ class TestGridfs(IntegrationTest):
client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0)
cdb = client.db
gfs = gridfs.GridFSBucket(cdb)
self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0)
with self.assertRaises(ServerSelectionTimeoutError):
gfs.delete(0)
gfs = gridfs.GridFSBucket(cdb)
self.assertRaises(
ServerSelectionTimeoutError, gfs.upload_from_stream, "test", b""
) # Still no connection.
with self.assertRaises(ServerSelectionTimeoutError):
gfs.upload_from_stream("test", b"") # Still no connection.
def test_gridfs_find(self):
self.fs.upload_from_stream("two", b"test2")
@ -366,14 +378,15 @@ class TestGridfs(IntegrationTest):
cursor = self.fs.find(
{}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2
)
gout = next(cursor)
gout = cursor.next()
self.assertEqual(b"test1", gout.read())
cursor.rewind()
gout = next(cursor)
gout = cursor.next()
self.assertEqual(b"test1", gout.read())
gout = next(cursor)
gout = cursor.next()
self.assertEqual(b"test2+", gout.read())
self.assertRaises(StopIteration, cursor.__next__)
with self.assertRaises(StopIteration):
cursor.next()
cursor.close()
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
@ -383,20 +396,21 @@ class TestGridfs(IntegrationTest):
self.fs.upload_from_stream("f", data)
self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}})
self.assertEqual(data, self.fs.open_download_stream_by_name("f").read())
self.assertEqual(data, (self.fs.open_download_stream_by_name("f")).read())
def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test)
gridfs.GridFSBucket((self.rs_or_single_client(w=0)).pymongo_test)
def test_rename(self):
_id = self.fs.upload_from_stream("first_name", b"testing")
self.assertEqual(b"testing", self.fs.open_download_stream_by_name("first_name").read())
self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("first_name")).read())
self.fs.rename(_id, "second_name")
self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "first_name")
self.assertEqual(b"testing", self.fs.open_download_stream_by_name("second_name").read())
with self.assertRaises(NoFile):
self.fs.open_download_stream_by_name("first_name")
self.assertEqual(b"testing", (self.fs.open_download_stream_by_name("second_name")).read())
@patch("gridfs.synchronous.grid_file._UPLOAD_BUFFER_SIZE", 5)
def test_abort(self):
@ -407,7 +421,8 @@ class TestGridfs(IntegrationTest):
self.assertEqual(3, self.db.fs.chunks.count_documents({"files_id": gin._id}))
gin.abort()
self.assertTrue(gin.closed)
self.assertRaises(ValueError, gin.write, b"test4")
with self.assertRaises(ValueError):
gin.write(b"test4")
self.assertEqual(0, self.db.fs.chunks.count_documents({"files_id": gin._id}))
def test_download_to_stream(self):
@ -490,7 +505,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest")
oid = gfs.upload_from_stream("test_filename", b"foo")
content = gfs.open_download_stream(oid).read()
content = (gfs.open_download_stream(oid)).read()
self.assertEqual(b"foo", content)
def test_gridfs_secondary(self):
@ -504,7 +519,8 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
gfs = gridfs.GridFSBucket(secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest")
# This won't detect secondary, raises error
self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"foo")
with self.assertRaises(NotPrimaryError):
gfs.upload_from_stream("test_filename", b"foo")
def test_gridfs_secondary_lazy(self):
# Should detect it's connected to secondary and not attempt to
@ -518,8 +534,10 @@ class TestGridfsBucketReplicaSet(IntegrationTest):
gfs = gridfs.GridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest")
# Connects, doesn't create index.
self.assertRaises(NoFile, gfs.open_download_stream_by_name, "test_filename")
self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"data")
with self.assertRaises(NoFile):
gfs.open_download_stream_by_name("test_filename")
with self.assertRaises(NotPrimaryError):
gfs.upload_from_stream("test_filename", b"data")
if __name__ == "__main__":

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import os
import sys
from pathlib import Path
from pymongo import MongoClient, ReadPreference
from pymongo.errors import ServerSelectionTimeoutError
@ -43,11 +44,17 @@ from test.utils_selection_tests import (
make_server_description,
)
_IS_SYNC = True
# Location of JSON test specifications.
_TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.path.join("server_selection", "server_selection"),
)
if _IS_SYNC:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent, "server_selection", "server_selection"
)
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "server_selection", "server_selection"
)
class SelectionStoreSelector:
@ -61,7 +68,7 @@ class SelectionStoreSelector:
return selection
class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore
class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore
pass
@ -79,13 +86,12 @@ class TestCustomServerSelectorFunction(IntegrationTest):
client = self.rs_or_single_client(
server_selector=custom_selector, event_listeners=[listener]
)
self.addCleanup(client.close)
coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll
self.addCleanup(client.drop_database, "testdb")
# Wait the node list to be fully populated.
def all_hosts_started():
return len(client.admin.command(HelloCompat.LEGACY_CMD)["hosts"]) == len(
return len((client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len(
client._topology._description.readable_servers
)
@ -121,7 +127,6 @@ class TestCustomServerSelectorFunction(IntegrationTest):
# Client setup.
mongo_client = self.rs_or_single_client(server_selector=selector)
test_collection = mongo_client.testdb.test_collection
self.addCleanup(mongo_client.close)
self.addCleanup(mongo_client.drop_database, "testdb")
# Do N operations and test selector is called at least N times.

View File

@ -15,9 +15,12 @@
"""Test the topology module's Server Selection Spec implementation."""
from __future__ import annotations
import asyncio
import os
import threading
from pathlib import Path
from test import IntegrationTest, client_context, unittest
from test.helpers import ConcurrentRunner
from test.utils import (
CMAPListener,
OvertCommandListener,
@ -32,10 +35,14 @@ from pymongo.monitoring import ConnectionReadyEvent
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference
_IS_SYNC = True
# Location of JSON test specifications.
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window")
)
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
)
class TestAllScenarios(unittest.TestCase):
@ -92,7 +99,7 @@ class CustomSpecTestCreator(SpecTestCreator):
CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
class FinderThread(threading.Thread):
class FinderTask(ConcurrentRunner):
def __init__(self, collection, iterations):
super().__init__()
self.daemon = True
@ -109,17 +116,17 @@ class FinderThread(threading.Thread):
class TestProse(IntegrationTest):
def frequencies(self, client, listener, n_finds=10):
coll = client.test.test
N_THREADS = 10
threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
for thread in threads:
self.assertTrue(thread.passed)
N_TASKS = 10
tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)]
for task in tasks:
task.start()
for task in tasks:
task.join()
for task in tasks:
self.assertTrue(task.passed)
events = listener.started_events
self.assertEqual(len(events), n_finds * N_THREADS)
self.assertEqual(len(events), n_finds * N_TASKS)
nodes = client.nodes
self.assertEqual(len(nodes), 2)
freqs = {address: 0.0 for address in nodes}

View File

@ -666,6 +666,11 @@ def joinall(threads):
assert not t.is_alive(), "Thread %s hung" % t
async def async_joinall(tasks):
"""Join threads with a 5-minute timeout, assert joins succeeded"""
await asyncio.wait([t.task for t in tasks if t is not None], timeout=300)
def wait_until(predicate, success_description, timeout=10):
"""Wait up to 10 seconds (by default) for predicate to be true.
@ -827,7 +832,7 @@ async def async_get_pools(client):
"""Get all pools."""
return [
server.pool
async for server in await (await client._get_topology()).select_servers(
for server in await (await client._get_topology()).select_servers(
any_server_selector, _Op.TEST
)
]

View File

@ -18,96 +18,28 @@ from __future__ import annotations
import datetime
import os
import sys
from test import PyMongoTestCase
sys.path[0:0] = [""]
from test import unittest
from test.pymongo_mocks import DummyMonitor
from test.utils import MockPool, parse_read_preference
from test.utils_selection_tests_shared import (
get_addresses,
get_topology_type_name,
make_server_description,
)
from bson import json_util
from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node
from pymongo.common import HEARTBEAT_FREQUENCY
from pymongo.errors import AutoReconnect, ConfigurationError
from pymongo.hello import Hello, HelloCompat
from pymongo.operations import _Op
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology
def get_addresses(server_list):
seeds = []
hosts = []
for server in server_list:
seeds.append(clean_node(server["address"]))
hosts.append(server["address"])
return seeds, hosts
def make_last_write_date(server):
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
millis = server.get("lastWrite", {}).get("lastWriteDate")
if millis:
diff = ((millis % 1000) + 1000) % 1000
seconds = (millis - diff) / 1000
micros = diff * 1000
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
else:
# "Unknown" server.
return epoch
def make_server_description(server, hosts):
"""Make a ServerDescription from server info in a JSON test."""
server_type = server["type"]
if server_type in ("Unknown", "PossiblePrimary"):
return ServerDescription(clean_node(server["address"]), Hello({}))
hello_response = {"ok": True, "hosts": hosts}
if server_type not in ("Standalone", "Mongos", "RSGhost"):
hello_response["setName"] = "rs"
if server_type == "RSPrimary":
hello_response[HelloCompat.LEGACY_CMD] = True
elif server_type == "RSSecondary":
hello_response["secondary"] = True
elif server_type == "Mongos":
hello_response["msg"] = "isdbgrid"
elif server_type == "RSGhost":
hello_response["isreplicaset"] = True
elif server_type == "RSArbiter":
hello_response["arbiterOnly"] = True
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
if field in server:
hello_response[field] = server[field]
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
# Sets _last_update_time to now.
sd = ServerDescription(
clean_node(server["address"]),
Hello(hello_response),
round_trip_time=server["avg_rtt_ms"] / 1000.0,
)
if "lastUpdateTime" in server:
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
return sd
def get_topology_type_name(scenario_def):
td = scenario_def["topology_description"]
name = td["type"]
if name == "Unknown":
# PyMongo never starts a topology in type Unknown.
return "Sharded" if len(td["servers"]) > 1 else "Single"
else:
return name
_IS_SYNC = True
def get_topology_settings_dict(**kwargs):
@ -244,7 +176,7 @@ def create_test(scenario_def):
def create_selection_tests(test_dir):
class TestAllScenarios(unittest.TestCase):
class TestAllScenarios(PyMongoTestCase):
pass
for dirpath, _, filenames in os.walk(test_dir):

View File

@ -0,0 +1,100 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for testing Server Selection and Max Staleness."""
from __future__ import annotations
import datetime
import os
import sys
sys.path[0:0] = [""]
from pymongo.common import MIN_SUPPORTED_WIRE_VERSION, clean_node
from pymongo.hello import Hello, HelloCompat
from pymongo.server_description import ServerDescription
def get_addresses(server_list):
seeds = []
hosts = []
for server in server_list:
seeds.append(clean_node(server["address"]))
hosts.append(server["address"])
return seeds, hosts
def make_last_write_date(server):
epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
millis = server.get("lastWrite", {}).get("lastWriteDate")
if millis:
diff = ((millis % 1000) + 1000) % 1000
seconds = (millis - diff) / 1000
micros = diff * 1000
return epoch + datetime.timedelta(seconds=seconds, microseconds=micros)
else:
# "Unknown" server.
return epoch
def make_server_description(server, hosts):
"""Make a ServerDescription from server info in a JSON test."""
server_type = server["type"]
if server_type in ("Unknown", "PossiblePrimary"):
return ServerDescription(clean_node(server["address"]), Hello({}))
hello_response = {"ok": True, "hosts": hosts}
if server_type not in ("Standalone", "Mongos", "RSGhost"):
hello_response["setName"] = "rs"
if server_type == "RSPrimary":
hello_response[HelloCompat.LEGACY_CMD] = True
elif server_type == "RSSecondary":
hello_response["secondary"] = True
elif server_type == "Mongos":
hello_response["msg"] = "isdbgrid"
elif server_type == "RSGhost":
hello_response["isreplicaset"] = True
elif server_type == "RSArbiter":
hello_response["arbiterOnly"] = True
hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)}
for field in "maxWireVersion", "tags", "idleWritePeriodMillis":
if field in server:
hello_response[field] = server[field]
hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION)
# Sets _last_update_time to now.
sd = ServerDescription(
clean_node(server["address"]),
Hello(hello_response),
round_trip_time=server["avg_rtt_ms"] / 1000.0,
)
if "lastUpdateTime" in server:
sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec.
return sd
def get_topology_type_name(scenario_def):
td = scenario_def["topology_description"]
name = td["type"]
if name == "Unknown":
# PyMongo never starts a topology in type Unknown.
return "Sharded" if len(td["servers"]) > 1 else "Single"
else:
return name

View File

@ -122,7 +122,9 @@ replacements = {
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
"StopAsyncIteration": "StopIteration",
"create_async_event": "create_event",
"async_joinall": "joinall",
}
docstring_replacements: dict[tuple[str, str], str] = {
@ -169,7 +171,7 @@ gridfs_files = [
def async_only_test(f: str) -> bool:
"""Return True for async tests that should not be converted to sync."""
return f in ["test_locks.py", "test_concurrency.py"]
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]
test_files = [
@ -202,6 +204,7 @@ converted_tests = [
"test_comment.py",
"test_common.py",
"test_connection_logging.py",
"test_connection_monitoring.py",
"test_connections_survive_primary_stepdown_spec.py",
"test_create_entities.py",
"test_crud_unified.py",
@ -212,12 +215,14 @@ converted_tests = [
"test_dns.py",
"test_encryption.py",
"test_examples.py",
"test_grid_file.py",
"test_gridfs.py",
"test_gridfs_bucket.py",
"test_gridfs_spec.py",
"test_heartbeat_monitoring.py",
"test_index_management.py",
"test_grid_file.py",
"test_load_balancer.py",
"test_json_util_integration.py",
"test_gridfs_spec.py",
"test_load_balancer.py",
"test_logger.py",
"test_max_staleness.py",
"test_monitoring.py",
@ -233,9 +238,11 @@ converted_tests = [
"test_retryable_writes_unified.py",
"test_run_command.py",
"test_sdam_monitoring_spec.py",
"test_server_selection.py",
"test_server_selection_in_window.py",
"test_server_selection_logging.py",
"test_session.py",
"test_server_selection_rtt.py",
"test_session.py",
"test_sessions_unified.py",
"test_srv_polling.py",
"test_ssl.py",
@ -245,6 +252,7 @@ converted_tests = [
"test_unified_format.py",
"test_versioned_api_integration.py",
"unified_format.py",
"utils_selection_tests.py",
]

2
uv.lock generated
View File

@ -997,7 +997,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/07/e9/ae44ea7d7605df9e5
[[package]]
name = "pymongo"
version = "4.11.0.dev0"
version = "4.12.0.dev0"
source = { editable = "." }
dependencies = [
{ name = "dnspython" },