From 3840d9dd0fe292e08e510310a03f3be4a8b5bbee Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 3 Sep 2024 13:26:11 -0400 Subject: [PATCH 1/6] Add script to help convert sync tests to async tests (#1825) --- CONTRIBUTING.md | 7 ++ tools/convert_test_to_async.py | 141 +++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 tools/convert_test_to_async.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c84447013..42cc8dc1b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -248,3 +248,10 @@ you are attempting to validate new spec tests in PyMongo. ## Making a Release Follow the [Python Driver Release Process Wiki](https://wiki.corp.mongodb.com/display/DRIVERS/Python+Driver+Release+Process). + +## Converting a test to async +The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a +partially-converted asynchronous version of the same name to the `test/asynchronous` directory. +Use this generated file as a starting point for the completed conversion. 
+ +The script is used like so: `python tools/convert_test_to_async.py [test_file.py]` diff --git a/tools/convert_test_to_async.py b/tools/convert_test_to_async.py new file mode 100644 index 000000000..dbdb217c8 --- /dev/null +++ b/tools/convert_test_to_async.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import asyncio +import sys + +from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.cursor import AsyncCursor +from pymongo.asynchronous.database import AsyncDatabase + +replacements = { + "Collection": "AsyncCollection", + "Database": "AsyncDatabase", + "Cursor": "AsyncCursor", + "MongoClient": "AsyncMongoClient", + "CommandCursor": "AsyncCommandCursor", + "RawBatchCursor": "AsyncRawBatchCursor", + "RawBatchCommandCursor": "AsyncRawBatchCommandCursor", + "ClientSession": "AsyncClientSession", + "ChangeStream": "AsyncChangeStream", + "CollectionChangeStream": "AsyncCollectionChangeStream", + "DatabaseChangeStream": "AsyncDatabaseChangeStream", + "ClusterChangeStream": "AsyncClusterChangeStream", + "_Bulk": "_AsyncBulk", + "_ClientBulk": "_AsyncClientBulk", + "Connection": "AsyncConnection", + "synchronous": "asynchronous", + "Synchronous": "Asynchronous", + "next": "await anext", + "_Lock": "_ALock", + "_Condition": "_ACondition", + "GridFS": "AsyncGridFS", + "GridFSBucket": "AsyncGridFSBucket", + "GridIn": "AsyncGridIn", + "GridOut": "AsyncGridOut", + "GridOutCursor": "AsyncGridOutCursor", + "GridOutIterator": "AsyncGridOutIterator", + "GridOutChunkIterator": "_AsyncGridOutChunkIterator", + "_grid_in_property": "_a_grid_in_property", + "_grid_out_property": "_a_grid_out_property", + "ClientEncryption": "AsyncClientEncryption", + "MongoCryptCallback": "AsyncMongoCryptCallback", + "ExplicitEncrypter": "AsyncExplicitEncrypter", + "AutoEncrypter": "AsyncAutoEncrypter", + "ContextManager": "AsyncContextManager", + 
"ClientContext": "AsyncClientContext", + "TestCollection": "AsyncTestCollection", + "IntegrationTest": "AsyncIntegrationTest", + "PyMongoTestCase": "AsyncPyMongoTestCase", + "MockClientTest": "AsyncMockClientTest", + "client_context": "async_client_context", + "setUp": "asyncSetUp", + "tearDown": "asyncTearDown", + "wait_until": "await async_wait_until", + "addCleanup": "addAsyncCleanup", + "TestCase": "IsolatedAsyncioTestCase", + "UnitTest": "AsyncUnitTest", + "MockClient": "AsyncMockClient", + "SpecRunner": "AsyncSpecRunner", + "TransactionsBase": "AsyncTransactionsBase", + "get_pool": "await async_get_pool", + "is_mongos": "await async_is_mongos", + "rs_or_single_client": "await async_rs_or_single_client", + "rs_or_single_client_noauth": "await async_rs_or_single_client_noauth", + "rs_client": "await async_rs_client", + "single_client": "await async_single_client", + "from_client": "await async_from_client", + "closing": "aclosing", + "assertRaisesExactly": "asyncAssertRaisesExactly", + "get_mock_client": "await get_async_mock_client", + "close": "await aclose", +} + +async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor] + + +def get_async_methods() -> set[str]: + result: set[str] = set() + for x in async_classes: + methods = { + k + for k, v in vars(x).items() + if callable(v) + and not isinstance(v, classmethod) + and asyncio.iscoroutinefunction(v) + and v.__name__[0] != "_" + } + result = result | methods + return result + + +async_methods = get_async_methods() + + +def apply_replacements(lines: list[str]) -> list[str]: + for i in range(len(lines)): + if "_IS_SYNC = True" in lines[i]: + lines[i] = "_IS_SYNC = False" + if "def test" in lines[i]: + lines[i] = lines[i].replace("def test", "async def test") + for k in replacements: + if k in lines[i]: + lines[i] = lines[i].replace(k, replacements[k]) + for k in async_methods: + if k + "(" in lines[i]: + tokens = lines[i].split(" ") + for j in range(len(tokens)): + if 
k + "(" in tokens[j]: + if j < 2: + tokens.insert(0, "await") + else: + tokens.insert(j, "await") + break + new_line = " ".join(tokens) + + lines[i] = new_line + + return lines + + +def process_file(input_file: str, output_file: str) -> None: + with open(input_file, "r+") as f: + lines = f.readlines() + lines = apply_replacements(lines) + + with open(output_file, "w+") as f2: + f2.seek(0) + f2.writelines(lines) + f2.truncate() + + +def main() -> None: + args = sys.argv[1:] + sync_file = "./test/" + args[0] + async_file = "./" + args[0] + + process_file(sync_file, async_file) + + +main() From ba8a139e7220a342cddbd4efcb7e937254345f5a Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:18:58 -0700 Subject: [PATCH 2/6] PYTHON-4651: Migrate test_client_context.py to async (#1819) --- .github/workflows/test-python.yml | 2 +- test/__init__.py | 4 +- test/asynchronous/__init__.py | 4 +- test/asynchronous/test_client_context.py | 66 ++++++++++++++++++++++++ test/test_client_context.py | 6 ++- tools/synchro.py | 1 + 6 files changed, 76 insertions(+), 7 deletions(-) create mode 100644 test/asynchronous/test_client_context.py diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 036b2c4b7..921168c13 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -209,4 +209,4 @@ jobs: ls which python pip install -e ".[test]" - PYMONGO_MUST_CONNECT=1 pytest -v test/test_client_context.py + PYMONGO_MUST_CONNECT=1 pytest -v -k client_context diff --git a/test/__init__.py b/test/__init__.py index 2a23ae0fd..d978d7da3 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -947,11 +947,11 @@ class UnitTest(PyMongoTestCase): @classmethod def _setup_class(cls): - cls._setup_class() + pass @classmethod def _tearDown_class(cls): - cls._tearDown_class() + pass class IntegrationTest(PyMongoTestCase): diff --git a/test/asynchronous/__init__.py 
b/test/asynchronous/__init__.py index 3d22b5ff7..def4bc1b8 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -949,11 +949,11 @@ class AsyncUnitTest(AsyncPyMongoTestCase): @classmethod async def _setup_class(cls): - await cls._setup_class() + pass @classmethod async def _tearDown_class(cls): - await cls._tearDown_class() + pass class AsyncIntegrationTest(AsyncPyMongoTestCase): diff --git a/test/asynchronous/test_client_context.py b/test/asynchronous/test_client_context.py new file mode 100644 index 000000000..a0cb53a14 --- /dev/null +++ b/test/asynchronous/test_client_context.py @@ -0,0 +1,66 @@ +# Copyright 2018-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncUnitTest, SkipTest, async_client_context, unittest + +_IS_SYNC = False + + +class TestAsyncClientContext(AsyncUnitTest): + def test_must_connect(self): + if "PYMONGO_MUST_CONNECT" not in os.environ: + raise SkipTest("PYMONGO_MUST_CONNECT is not set") + + self.assertTrue( + async_client_context.connected, + "client context must be connected when " + "PYMONGO_MUST_CONNECT is set. 
Failed attempts:\n{}".format( + async_client_context.connection_attempt_info() + ), + ) + + def test_serverless(self): + if "TEST_SERVERLESS" not in os.environ: + raise SkipTest("TEST_SERVERLESS is not set") + + self.assertTrue( + async_client_context.connected and async_client_context.serverless, + "client context must be connected to serverless when " + f"TEST_SERVERLESS is set. Failed attempts:\n{async_client_context.connection_attempt_info()}", + ) + + def test_enableTestCommands_is_disabled(self): + if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ: + raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set") + + self.assertFalse( + async_client_context.test_commands_enabled, + "enableTestCommands must be disabled when PYMONGO_DISABLE_TEST_COMMANDS is set.", + ) + + def test_setdefaultencoding_worked(self): + if "SETDEFAULTENCODING" not in os.environ: + raise SkipTest("SETDEFAULTENCODING is not set") + + self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_client_context.py b/test/test_client_context.py index 196647cb0..be8a56214 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -18,10 +18,12 @@ import sys sys.path[0:0] = [""] -from test import SkipTest, client_context, unittest +from test import SkipTest, UnitTest, client_context, unittest + +_IS_SYNC = True -class TestClientContext(unittest.TestCase): +class TestClientContext(UnitTest): def test_must_connect(self): if "PYMONGO_MUST_CONNECT" not in os.environ: raise SkipTest("PYMONGO_MUST_CONNECT is not set") diff --git a/tools/synchro.py b/tools/synchro.py index 6fb711674..adc0de297 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -163,6 +163,7 @@ converted_tests = [ "test_logger.py", "test_session.py", "test_transactions.py", + "test_client_context.py", ] sync_test_files = [ From 5a70039ad20d7f10fb324aec6cd1661cf62f720c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 3 
Sep 2024 16:57:41 -0400 Subject: [PATCH 3/6] PYTHON-4701 - Topology logging should use suppress_event (#1826) --- pymongo/asynchronous/topology.py | 4 ++-- pymongo/synchronous/topology.py | 4 ++-- test/unified_format.py | 6 +++++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 2df30d244..8a46e7fec 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -475,7 +475,7 @@ class Topology: if server: await server.pool.ready() - suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description + suppress_event = sd_old == server_description if self._publish_server and not suppress_event: assert self._events is not None self._events.put( @@ -497,7 +497,7 @@ class Topology: (td_old, self._description, self._topology_id), ) ) - if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): + if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event: _debug_log( _SDAM_LOGGER, topologyId=self._topology_id, diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 54a9d8a69..9932d2cbd 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -475,7 +475,7 @@ class Topology: if server: server.pool.ready() - suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description + suppress_event = sd_old == server_description if self._publish_server and not suppress_event: assert self._events is not None self._events.put( @@ -497,7 +497,7 @@ class Topology: (td_old, self._description, self._topology_id), ) ) - if _SDAM_LOGGER.isEnabledFor(logging.DEBUG): + if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event: _debug_log( _SDAM_LOGGER, topologyId=self._topology_id, diff --git a/test/unified_format.py b/test/unified_format.py index 99fe0b169..e4ebf677e 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1954,7 +1954,11 @@ class 
UnifiedSpecTestMixinV1(IntegrationTest): if client.get("ignoreExtraMessages", False): actual_logs = actual_logs[: len(client["messages"])] - self.assertEqual(len(client["messages"]), len(actual_logs)) + self.assertEqual( + len(client["messages"]), + len(actual_logs), + f"expected {client['messages']} but got {actual_logs}", + ) for expected_msg, actual_msg in zip(client["messages"], actual_logs): expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") From 5a49ccc759665825c32f1cba9f780f195daf890f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 4 Sep 2024 08:57:59 -0400 Subject: [PATCH 4/6] PYTHON-4590 - Add type guards to async API methods (#1820) --- pymongo/asynchronous/collection.py | 4 ++++ pymongo/asynchronous/database.py | 5 +++++ pymongo/asynchronous/encryption.py | 12 +++++++++--- pymongo/asynchronous/mongo_client.py | 3 +++ pymongo/synchronous/collection.py | 4 ++++ pymongo/synchronous/database.py | 5 +++++ pymongo/synchronous/encryption.py | 12 +++++++++--- pymongo/synchronous/mongo_client.py | 3 +++ test/helpers.py | 10 ---------- 9 files changed, 42 insertions(+), 16 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index e5a54c090..6d8dfaf89 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -228,6 +228,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): ) if not isinstance(name, str): raise TypeError("name must be an instance of str") + from pymongo.asynchronous.database import AsyncDatabase + + if not isinstance(database, AsyncDatabase): + raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given") if not name or ".." 
in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index b61d58183..d5eec0134 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -119,9 +119,14 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]): read_concern or client.read_concern, ) + from pymongo.asynchronous.mongo_client import AsyncMongoClient + if not isinstance(name, str): raise TypeError("name must be an instance of str") + if not isinstance(client, AsyncMongoClient): + raise TypeError(f"AsyncMongoClient required but given {type(client)}") + if name != "$external": _check_name(name) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 3fb00c6ca..c4cb886df 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -194,9 +194,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc] # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) - async def collection_info( - self, database: AsyncDatabase[Mapping[str, Any]], filter: bytes - ) -> Optional[bytes]: + async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: """Get the collection info for a namespace. 
The returned collection info is passed to libmongocrypt which reads @@ -598,6 +596,9 @@ class AsyncClientEncryption(Generic[_DocumentType]): if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + if not isinstance(key_vault_client, AsyncMongoClient): + raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}") + self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace self._key_vault_client = key_vault_client @@ -683,6 +684,11 @@ class AsyncClientEncryption(Generic[_DocumentType]): https://mongodb.com/docs/manual/reference/command/create """ + if not isinstance(database, AsyncDatabase): + raise TypeError( + f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given" + ) + encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): if isinstance(field, dict) and field.get("keyId") is None: diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 05e4e80f1..2af773c44 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2419,6 +2419,9 @@ class _MongoClientErrorHandler: def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): + if not isinstance(client, AsyncMongoClient): + raise TypeError(f"AsyncMongoClient required but given {type(client)}") + self.client = client self.server_address = server.description.address self.session = session diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 54db3a56b..93e24432e 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -231,6 +231,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]): ) if not isinstance(name, str): raise TypeError("name must be an instance of str") + from pymongo.synchronous.database import Database + + 
if not isinstance(database, Database): + raise TypeError(f"Collection requires a Database but {type(database)} given") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 93a998528..1cd8ee643 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -119,9 +119,14 @@ class Database(common.BaseObject, Generic[_DocumentType]): read_concern or client.read_concern, ) + from pymongo.synchronous.mongo_client import MongoClient + if not isinstance(name, str): raise TypeError("name must be an instance of str") + if not isinstance(client, MongoClient): + raise TypeError(f"MongoClient required but given {type(client)}") + if name != "$external": _check_name(name) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index e06ddad93..2efa99597 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -194,9 +194,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc] # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) - def collection_info( - self, database: Database[Mapping[str, Any]], filter: bytes - ) -> Optional[bytes]: + def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: """Get the collection info for a namespace. 
The returned collection info is passed to libmongocrypt which reads @@ -596,6 +594,9 @@ class ClientEncryption(Generic[_DocumentType]): if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + if not isinstance(key_vault_client, MongoClient): + raise TypeError(f"MongoClient required but given {type(key_vault_client)}") + self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace self._key_vault_client = key_vault_client @@ -681,6 +682,11 @@ class ClientEncryption(Generic[_DocumentType]): https://mongodb.com/docs/manual/reference/command/create """ + if not isinstance(database, Database): + raise TypeError( + f"create_encrypted_collection() requires a Database but {type(database)} given" + ) + encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): if isinstance(field, dict) and field.get("keyId") is None: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 77e029a7c..6c5f68b7e 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2406,6 +2406,9 @@ class _MongoClientErrorHandler: ) def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): + if not isinstance(client, MongoClient): + raise TypeError(f"MongoClient required but given {type(client)}") + self.client = client self.server_address = server.description.address self.session = session diff --git a/test/helpers.py b/test/helpers.py index d136e5b8d..b38b2e298 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -35,23 +35,13 @@ try: HAVE_IPADDRESS = True except ImportError: HAVE_IPADDRESS = False -from contextlib import contextmanager from functools import wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, no_type_check from unittest import SkipTest -from urllib.parse import quote_plus -import pymongo -import 
pymongo.errors from bson.son import SON from pymongo import common, message -from pymongo.common import partition_node -from pymongo.hello import HelloCompat -from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous.database import Database -from pymongo.synchronous.mongo_client import MongoClient from pymongo.uri_parser import parse_uri if HAVE_SSL: From 4e74c8274e7c4cb7658d445d526dcc33ced1750b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 4 Sep 2024 08:58:14 -0400 Subject: [PATCH 5/6] PYTHON-4669 - Update Async GridFS APIs for Motor Compatibility (#1821) --- gridfs/asynchronous/grid_file.py | 101 ++-- gridfs/grid_file_shared.py | 19 + gridfs/synchronous/grid_file.py | 97 ++-- pymongo/asynchronous/helpers.py | 5 + pymongo/asynchronous/topology.py | 2 +- pymongo/synchronous/helpers.py | 5 + pymongo/synchronous/topology.py | 2 +- test/asynchronous/test_grid_file.py | 871 ++++++++++++++++++++++++++++ test/test_grid_file.py | 114 +++- tools/synchro.py | 4 + 10 files changed, 1115 insertions(+), 105 deletions(-) create mode 100644 test/asynchronous/test_grid_file.py diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index afc1a0f75..a49d51d30 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -1176,24 +1176,6 @@ class AsyncGridIn: raise AttributeError("GridIn object has no attribute '%s'" % name) def __setattr__(self, name: str, value: Any) -> None: - # For properties of this instance like _buffer, or descriptors set on - # the class like filename, use regular __setattr__ - if name in self.__dict__ or name in self.__class__.__dict__: - object.__setattr__(self, name, value) - else: - if _IS_SYNC: - # All other attributes are part of the document in db.fs.files. - # Store them to be sent to server on close() or if closed, send - # them now. 
- self._file[name] = value - if self._closed: - self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) - else: - raise AttributeError( - "AsyncGridIn does not support __setattr__. Use AsyncGridIn.set() instead" - ) - - async def set(self, name: str, value: Any) -> None: # For properties of this instance like _buffer, or descriptors set on # the class like filename, use regular __setattr__ if name in self.__dict__ or name in self.__class__.__dict__: @@ -1204,9 +1186,17 @@ class AsyncGridIn: # them now. self._file[name] = value if self._closed: - await self._coll.files.update_one( - {"_id": self._file["_id"]}, {"$set": {name: value}} - ) + if _IS_SYNC: + self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) + else: + raise AttributeError( + "AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead" + ) + + async def set(self, name: str, value: Any) -> None: + self._file[name] = value + if self._closed: + await self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) async def _flush_data(self, data: Any, force: bool = False) -> None: """Flush `data` to a chunk.""" @@ -1400,7 +1390,11 @@ class AsyncGridIn: return False -class AsyncGridOut(io.IOBase): +GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any + + +class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore + """Class to read data out of GridFS.""" def __init__( @@ -1460,6 +1454,8 @@ class AsyncGridOut(io.IOBase): self._position = 0 self._file = file_document self._session = session + if not _IS_SYNC: + self.closed = False _id: Any = _a_grid_out_property("_id", "The ``'_id'`` value for this file.") filename: str = _a_grid_out_property("filename", "Name of this file.") @@ -1486,16 +1482,43 @@ class AsyncGridOut(io.IOBase): _file: Any _chunk_iter: Any - async def __anext__(self) -> bytes: - return super().__next__() + if not _IS_SYNC: + closed: 
bool - def __next__(self) -> bytes: # noqa: F811, RUF100 - if _IS_SYNC: - return super().__next__() - else: - raise TypeError( - "AsyncGridOut does not support synchronous iteration. Use `async for` instead" - ) + async def __anext__(self) -> bytes: + line = await self.readline() + if line: + return line + raise StopAsyncIteration() + + async def to_list(self) -> list[bytes]: + return [x async for x in self] # noqa: C416, RUF100 + + async def readline(self, size: int = -1) -> bytes: + """Read one line or up to `size` bytes from the file. + + :param size: the maximum number of bytes to read + """ + return await self._read_size_or_line(size=size, line=True) + + async def readlines(self, size: int = -1) -> list[bytes]: + """Read one line or up to `size` bytes from the file. + + :param size: the maximum number of bytes to read + """ + await self.open() + lines = [] + remainder = int(self.length) - self._position + bytes_read = 0 + while remainder > 0: + line = await self._read_size_or_line(line=True) + bytes_read += len(line) + lines.append(line) + remainder = int(self.length) - self._position + if 0 < size < bytes_read: + break + + return lines async def open(self) -> None: if not self._file: @@ -1616,18 +1639,11 @@ class AsyncGridOut(io.IOBase): """ return await self._read_size_or_line(size=size) - async def readline(self, size: int = -1) -> bytes: # type: ignore[override] - """Read one line or up to `size` bytes from the file. - - :param size: the maximum number of bytes to read - """ - return await self._read_size_or_line(size=size, line=True) - def tell(self) -> int: """Return the current position of this file.""" return self._position - async def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override] + async def seek(self, pos: int, whence: int = _SEEK_SET) -> int: """Set the current position of this file. 
:param pos: the position (or offset if using relative @@ -1690,12 +1706,15 @@ class AsyncGridOut(io.IOBase): """ return self - async def close(self) -> None: # type: ignore[override] + async def close(self) -> None: """Make GridOut more generically file-like.""" if self._chunk_iter: await self._chunk_iter.close() self._chunk_iter = None - super().close() + if _IS_SYNC: + super().close() + else: + self.closed = True def write(self, value: Any) -> NoReturn: raise io.UnsupportedOperation("write") diff --git a/gridfs/grid_file_shared.py b/gridfs/grid_file_shared.py index b6f02a53d..79a0ad7f8 100644 --- a/gridfs/grid_file_shared.py +++ b/gridfs/grid_file_shared.py @@ -38,7 +38,15 @@ def _a_grid_in_property( ) -> Any: """Create a GridIn property.""" + warn_str = "" + if docstring.startswith("DEPRECATED,"): + warn_str = ( + f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0" + ) + def getter(self: Any) -> Any: + if warn_str: + warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning) if closed_only and not self._closed: raise AttributeError("can only get %r on a closed file" % field_name) # Protect against PHP-237 @@ -46,6 +54,15 @@ def _a_grid_in_property( return self._file.get(field_name, 0) return self._file.get(field_name, None) + def setter(self: Any, value: Any) -> Any: + if warn_str: + warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning) + if self._closed: + raise InvalidOperation( + "AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead" + ) + self._file[field_name] = value + if read_only: docstring += "\n\nThis attribute is read-only." 
elif closed_only: @@ -56,6 +73,8 @@ def _a_grid_in_property( "has been called.", ) + if not read_only and not closed_only: + return property(getter, setter, doc=docstring) return property(getter, doc=docstring) diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index 80015f96e..655f05f57 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -1166,24 +1166,6 @@ class GridIn: raise AttributeError("GridIn object has no attribute '%s'" % name) def __setattr__(self, name: str, value: Any) -> None: - # For properties of this instance like _buffer, or descriptors set on - # the class like filename, use regular __setattr__ - if name in self.__dict__ or name in self.__class__.__dict__: - object.__setattr__(self, name, value) - else: - if _IS_SYNC: - # All other attributes are part of the document in db.fs.files. - # Store them to be sent to server on close() or if closed, send - # them now. - self._file[name] = value - if self._closed: - self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) - else: - raise AttributeError( - "GridIn does not support __setattr__. Use GridIn.set() instead" - ) - - def set(self, name: str, value: Any) -> None: # For properties of this instance like _buffer, or descriptors set on # the class like filename, use regular __setattr__ if name in self.__dict__ or name in self.__class__.__dict__: @@ -1194,7 +1176,17 @@ class GridIn: # them now. self._file[name] = value if self._closed: - self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) + if _IS_SYNC: + self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) + else: + raise AttributeError( + "GridIn does not support __setattr__ after being closed(). 
Set the attribute before closing the file or use GridIn.set() instead" + ) + + def set(self, name: str, value: Any) -> None: + self._file[name] = value + if self._closed: + self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) def _flush_data(self, data: Any, force: bool = False) -> None: """Flush `data` to a chunk.""" @@ -1388,7 +1380,11 @@ class GridIn: return False -class GridOut(io.IOBase): +GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any + + +class GridOut(GRIDOUT_BASE_CLASS): # type: ignore + """Class to read data out of GridFS.""" def __init__( @@ -1448,6 +1444,8 @@ class GridOut(io.IOBase): self._position = 0 self._file = file_document self._session = session + if not _IS_SYNC: + self.closed = False _id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.") filename: str = _grid_out_property("filename", "Name of this file.") @@ -1474,14 +1472,43 @@ class GridOut(io.IOBase): _file: Any _chunk_iter: Any - def __next__(self) -> bytes: - return super().__next__() + if not _IS_SYNC: + closed: bool - def __next__(self) -> bytes: # noqa: F811, RUF100 - if _IS_SYNC: - return super().__next__() - else: - raise TypeError("GridOut does not support synchronous iteration. Use `for` instead") + def __next__(self) -> bytes: + line = self.readline() + if line: + return line + raise StopIteration() + + def to_list(self) -> list[bytes]: + return [x for x in self] # noqa: C416, RUF100 + + def readline(self, size: int = -1) -> bytes: + """Read one line or up to `size` bytes from the file. + + :param size: the maximum number of bytes to read + """ + return self._read_size_or_line(size=size, line=True) + + def readlines(self, size: int = -1) -> list[bytes]: + """Read one line or up to `size` bytes from the file. 
+ + :param size: the maximum number of bytes to read + """ + self.open() + lines = [] + remainder = int(self.length) - self._position + bytes_read = 0 + while remainder > 0: + line = self._read_size_or_line(line=True) + bytes_read += len(line) + lines.append(line) + remainder = int(self.length) - self._position + if 0 < size < bytes_read: + break + + return lines def open(self) -> None: if not self._file: @@ -1602,18 +1629,11 @@ class GridOut(io.IOBase): """ return self._read_size_or_line(size=size) - def readline(self, size: int = -1) -> bytes: # type: ignore[override] - """Read one line or up to `size` bytes from the file. - - :param size: the maximum number of bytes to read - """ - return self._read_size_or_line(size=size, line=True) - def tell(self) -> int: """Return the current position of this file.""" return self._position - def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override] + def seek(self, pos: int, whence: int = _SEEK_SET) -> int: """Set the current position of this file. 
:param pos: the position (or offset if using relative @@ -1676,12 +1696,15 @@ class GridOut(io.IOBase): """ return self - def close(self) -> None: # type: ignore[override] + def close(self) -> None: """Make GridOut more generically file-like.""" if self._chunk_iter: self._chunk_iter.close() self._chunk_iter = None - super().close() + if _IS_SYNC: + super().close() + else: + self.closed = True def write(self, value: Any) -> NoReturn: raise io.UnsupportedOperation("write") diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 8a85135c1..1ac8b6630 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -70,8 +70,13 @@ def _handle_reauth(func: F) -> F: if sys.version_info >= (3, 10): anext = builtins.anext + aiter = builtins.aiter else: async def anext(cls: Any) -> Any: """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext.""" return await cls.__anext__() + + def aiter(cls: Any) -> Any: + """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#aiter.""" + return cls.__aiter__() diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 8a46e7fec..9dd1a1c76 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -521,7 +521,7 @@ class Topology: if server: await server.pool.reset(interrupt_connections=interrupt_connections) - # Wake waiters in select_servers(). + # Wake anything waiting in select_servers(). 
self._condition.notify_all() async def on_change( diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index e6bbf5d51..064583dad 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -70,8 +70,13 @@ def _handle_reauth(func: F) -> F: if sys.version_info >= (3, 10): next = builtins.next + iter = builtins.iter else: def next(cls: Any) -> Any: """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next.""" return cls.__next__() + + def iter(cls: Any) -> Any: + """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#iter.""" + return cls.__iter__() diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 9932d2cbd..414865154 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -521,7 +521,7 @@ class Topology: if server: server.pool.reset(interrupt_connections=interrupt_connections) - # Wake waiters in select_servers(). + # Wake anything waiting in select_servers(). self._condition.notify_all() def on_change( diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py new file mode 100644 index 000000000..7071fc76f --- /dev/null +++ b/test/asynchronous/test_grid_file.py @@ -0,0 +1,871 @@ +# +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the grid_file module.""" +from __future__ import annotations + +import datetime +import io +import sys +import zipfile +from io import BytesIO +from test.asynchronous import AsyncIntegrationTest, AsyncUnitTest, async_client_context + +from pymongo.asynchronous.database import AsyncDatabase + +sys.path[0:0] = [""] + +from test import IntegrationTest, qcheck, unittest +from test.utils import EventListener, async_rs_or_single_client, rs_or_single_client + +from bson.objectid import ObjectId +from gridfs import GridFS +from gridfs.asynchronous.grid_file import ( + _SEEK_CUR, + _SEEK_END, + DEFAULT_CHUNK_SIZE, + AsyncGridFS, + AsyncGridIn, + AsyncGridOut, + AsyncGridOutCursor, +) +from gridfs.errors import NoFile +from pymongo import AsyncMongoClient +from pymongo.asynchronous.helpers import aiter, anext +from pymongo.errors import ConfigurationError, InvalidOperation, ServerSelectionTimeoutError +from pymongo.message import _CursorAddress + +_IS_SYNC = False + + +class AsyncTestGridFileNoConnect(AsyncUnitTest): + """Test GridFile features on a client that does not connect.""" + + db: AsyncDatabase + + @classmethod + def setUpClass(cls): + cls.db = AsyncMongoClient(connect=False).pymongo_test + + def test_grid_in_custom_opts(self): + self.assertRaises(TypeError, AsyncGridIn, "foo") + + a = AsyncGridIn( + self.db.fs, + _id=5, + filename="my_file", + contentType="text/html", + chunkSize=1000, + aliases=["foo"], + metadata={"foo": 1, "bar": 2}, + bar=3, + baz="hello", + ) + + self.assertEqual(5, a._id) + self.assertEqual("my_file", a.filename) + self.assertEqual("my_file", a.name) + self.assertEqual("text/html", a.content_type) + self.assertEqual(1000, a.chunk_size) + self.assertEqual(["foo"], a.aliases) + self.assertEqual({"foo": 1, "bar": 2}, a.metadata) + self.assertEqual(3, a.bar) + self.assertEqual("hello", a.baz) + self.assertRaises(AttributeError, getattr, a, "mike") + + b = AsyncGridIn(self.db.fs, content_type="text/html", chunk_size=1000, baz=100) 
+ self.assertEqual("text/html", b.content_type) + self.assertEqual(1000, b.chunk_size) + self.assertEqual(100, b.baz) + + +class AsyncTestGridFile(AsyncIntegrationTest): + async def asyncSetUp(self): + await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) + + async def test_basic(self): + f = AsyncGridIn(self.db.fs, filename="test") + await f.write(b"hello world") + await f.close() + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"hello world", await g.read()) + + # make sure it's still there... + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"hello world", await g.read()) + + f = AsyncGridIn(self.db.fs, filename="test") + await f.close() + self.assertEqual(2, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"", await g.read()) + + # test that reading 0 returns proper type + self.assertEqual(b"", await g.read(0)) + + async def test_md5(self): + f = AsyncGridIn(self.db.fs) + await f.write(b"hello world\n") + await f.close() + self.assertEqual(None, f.md5) + + async def test_alternate_collection(self): + await self.db.alt.files.delete_many({}) + await self.db.alt.chunks.delete_many({}) + + f = AsyncGridIn(self.db.alt) + await f.write(b"hello world") + await f.close() + + self.assertEqual(1, await self.db.alt.files.count_documents({})) + self.assertEqual(1, await self.db.alt.chunks.count_documents({})) + + g = AsyncGridOut(self.db.alt, f._id) + self.assertEqual(b"hello world", await g.read()) + + async def test_grid_in_default_opts(self): + self.assertRaises(TypeError, AsyncGridIn, "foo") + + a = AsyncGridIn(self.db.fs) + + self.assertTrue(isinstance(a._id, ObjectId)) + self.assertRaises(AttributeError, setattr, a, "_id", 5) + + self.assertEqual(None, a.filename) + self.assertEqual(None, 
a.name) + a.filename = "my_file" + self.assertEqual("my_file", a.filename) + self.assertEqual("my_file", a.name) + + self.assertEqual(None, a.content_type) + a.content_type = "text/html" + + self.assertEqual("text/html", a.content_type) + + self.assertRaises(AttributeError, getattr, a, "length") + self.assertRaises(AttributeError, setattr, a, "length", 5) + + self.assertEqual(255 * 1024, a.chunk_size) + self.assertRaises(AttributeError, setattr, a, "chunk_size", 5) + + self.assertRaises(AttributeError, getattr, a, "upload_date") + self.assertRaises(AttributeError, setattr, a, "upload_date", 5) + + self.assertRaises(AttributeError, getattr, a, "aliases") + a.aliases = ["foo"] + + self.assertEqual(["foo"], a.aliases) + + self.assertRaises(AttributeError, getattr, a, "metadata") + a.metadata = {"foo": 1} + + self.assertEqual({"foo": 1}, a.metadata) + + self.assertRaises(AttributeError, setattr, a, "md5", 5) + + await a.close() + + if _IS_SYNC: + a.forty_two = 42 + else: + self.assertRaises(AttributeError, setattr, a, "forty_two", 42) + await a.set("forty_two", 42) + + self.assertEqual(42, a.forty_two) + + self.assertTrue(isinstance(a._id, ObjectId)) + self.assertRaises(AttributeError, setattr, a, "_id", 5) + + self.assertEqual("my_file", a.filename) + self.assertEqual("my_file", a.name) + + self.assertEqual("text/html", a.content_type) + + self.assertEqual(0, a.length) + self.assertRaises(AttributeError, setattr, a, "length", 5) + + self.assertEqual(255 * 1024, a.chunk_size) + self.assertRaises(AttributeError, setattr, a, "chunk_size", 5) + + self.assertTrue(isinstance(a.upload_date, datetime.datetime)) + self.assertRaises(AttributeError, setattr, a, "upload_date", 5) + + self.assertEqual(["foo"], a.aliases) + + self.assertEqual({"foo": 1}, a.metadata) + + self.assertEqual(None, a.md5) + self.assertRaises(AttributeError, setattr, a, "md5", 5) + + # Make sure custom attributes that were set both before and after + # a.close() are reflected in b. PYTHON-411. 
+ b = await AsyncGridFS(self.db).get_last_version(filename=a.filename) + self.assertEqual(a.metadata, b.metadata) + self.assertEqual(a.aliases, b.aliases) + self.assertEqual(a.forty_two, b.forty_two) + + async def test_grid_out_default_opts(self): + self.assertRaises(TypeError, AsyncGridOut, "foo") + + gout = AsyncGridOut(self.db.fs, 5) + with self.assertRaises(NoFile): + if not _IS_SYNC: + await gout.open() + gout.name + + a = AsyncGridIn(self.db.fs) + await a.close() + + b = AsyncGridOut(self.db.fs, a._id) + if not _IS_SYNC: + await b.open() + + self.assertEqual(a._id, b._id) + self.assertEqual(0, b.length) + self.assertEqual(None, b.content_type) + self.assertEqual(None, b.name) + self.assertEqual(None, b.filename) + self.assertEqual(255 * 1024, b.chunk_size) + self.assertTrue(isinstance(b.upload_date, datetime.datetime)) + self.assertEqual(None, b.aliases) + self.assertEqual(None, b.metadata) + self.assertEqual(None, b.md5) + + for attr in [ + "_id", + "name", + "content_type", + "length", + "chunk_size", + "upload_date", + "aliases", + "metadata", + "md5", + ]: + self.assertRaises(AttributeError, setattr, b, attr, 5) + + async def test_grid_out_cursor_options(self): + self.assertRaises( + TypeError, AsyncGridOutCursor.__init__, self.db.fs, {}, projection={"filename": 1} + ) + + cursor = AsyncGridOutCursor(self.db.fs, {}) + cursor_clone = cursor.clone() + + cursor_dict = cursor.__dict__.copy() + cursor_dict.pop("_session") + cursor_clone_dict = cursor_clone.__dict__.copy() + cursor_clone_dict.pop("_session") + self.assertDictEqual(cursor_dict, cursor_clone_dict) + + self.assertRaises(NotImplementedError, cursor.add_option, 0) + self.assertRaises(NotImplementedError, cursor.remove_option, 0) + + async def test_grid_out_custom_opts(self): + one = AsyncGridIn( + self.db.fs, + _id=5, + filename="my_file", + contentType="text/html", + chunkSize=1000, + aliases=["foo"], + metadata={"foo": 1, "bar": 2}, + bar=3, + baz="hello", + ) + await one.write(b"hello world") + 
await one.close() + + two = AsyncGridOut(self.db.fs, 5) + + if not _IS_SYNC: + await two.open() + + self.assertEqual("my_file", two.name) + self.assertEqual("my_file", two.filename) + self.assertEqual(5, two._id) + self.assertEqual(11, two.length) + self.assertEqual("text/html", two.content_type) + self.assertEqual(1000, two.chunk_size) + self.assertTrue(isinstance(two.upload_date, datetime.datetime)) + self.assertEqual(["foo"], two.aliases) + self.assertEqual({"foo": 1, "bar": 2}, two.metadata) + self.assertEqual(3, two.bar) + self.assertEqual(None, two.md5) + + for attr in [ + "_id", + "name", + "content_type", + "length", + "chunk_size", + "upload_date", + "aliases", + "metadata", + "md5", + ]: + self.assertRaises(AttributeError, setattr, two, attr, 5) + + async def test_grid_out_file_document(self): + one = AsyncGridIn(self.db.fs) + await one.write(b"foo bar") + await one.close() + + two = AsyncGridOut(self.db.fs, file_document=await self.db.fs.files.find_one()) + self.assertEqual(b"foo bar", await two.read()) + + three = AsyncGridOut(self.db.fs, 5, file_document=await self.db.fs.files.find_one()) + self.assertEqual(b"foo bar", await three.read()) + + four = AsyncGridOut(self.db.fs, file_document={}) + with self.assertRaises(NoFile): + if not _IS_SYNC: + await four.open() + four.name + + async def test_write_file_like(self): + one = AsyncGridIn(self.db.fs) + await one.write(b"hello world") + await one.close() + + two = AsyncGridOut(self.db.fs, one._id) + + three = AsyncGridIn(self.db.fs) + await three.write(two) + await three.close() + + four = AsyncGridOut(self.db.fs, three._id) + self.assertEqual(b"hello world", await four.read()) + + five = AsyncGridIn(self.db.fs, chunk_size=2) + await five.write(b"hello") + buffer = BytesIO(b" world") + await five.write(buffer) + await five.write(b" and mongodb") + await five.close() + self.assertEqual( + b"hello world and mongodb", await AsyncGridOut(self.db.fs, five._id).read() + ) + + async def test_write_lines(self): + 
a = AsyncGridIn(self.db.fs) + await a.writelines([b"hello ", b"world"]) + await a.close() + + self.assertEqual(b"hello world", await AsyncGridOut(self.db.fs, a._id).read()) + + async def test_close(self): + f = AsyncGridIn(self.db.fs) + await f.close() + with self.assertRaises(ValueError): + await f.write("test") + await f.close() + + async def test_closed(self): + f = AsyncGridIn(self.db.fs, chunkSize=5) + await f.write(b"Hello world.\nHow are you?") + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + if not _IS_SYNC: + await g.open() + self.assertFalse(g.closed) + await g.read(1) + self.assertFalse(g.closed) + await g.read(100) + self.assertFalse(g.closed) + await g.close() + self.assertTrue(g.closed) + + async def test_multi_chunk_file(self): + random_string = b"a" * (DEFAULT_CHUNK_SIZE + 1000) + + f = AsyncGridIn(self.db.fs) + await f.write(random_string) + await f.close() + + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(2, await self.db.fs.chunks.count_documents({})) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(random_string, await g.read()) + + # TODO: https://jira.mongodb.org/browse/PYTHON-4708 + @async_client_context.require_sync + async def test_small_chunks(self): + self.files = 0 + self.chunks = 0 + + async def helper(data): + f = AsyncGridIn(self.db.fs, chunkSize=1) + await f.write(data) + await f.close() + + self.files += 1 + self.chunks += len(data) + + self.assertEqual(self.files, await self.db.fs.files.count_documents({})) + self.assertEqual(self.chunks, await self.db.fs.chunks.count_documents({})) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(data, await g.read()) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(data, await g.read(10) + await g.read(10)) + return True + + qcheck.check_unittest(self, helper, qcheck.gen_string(qcheck.gen_range(0, 20))) + + async def test_seek(self): + f = AsyncGridIn(self.db.fs, chunkSize=3) + await f.write(b"hello world") + await 
f.close() + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"hello world", await g.read()) + await g.seek(0) + self.assertEqual(b"hello world", await g.read()) + await g.seek(1) + self.assertEqual(b"ello world", await g.read()) + with self.assertRaises(IOError): + await g.seek(-1) + + await g.seek(-3, _SEEK_END) + self.assertEqual(b"rld", await g.read()) + await g.seek(0, _SEEK_END) + self.assertEqual(b"", await g.read()) + with self.assertRaises(IOError): + await g.seek(-100, _SEEK_END) + + await g.seek(3) + await g.seek(3, _SEEK_CUR) + self.assertEqual(b"world", await g.read()) + with self.assertRaises(IOError): + await g.seek(-100, _SEEK_CUR) + + async def test_tell(self): + f = AsyncGridIn(self.db.fs, chunkSize=3) + await f.write(b"hello world") + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(0, g.tell()) + await g.read(0) + self.assertEqual(0, g.tell()) + await g.read(1) + self.assertEqual(1, g.tell()) + await g.read(2) + self.assertEqual(3, g.tell()) + await g.read() + self.assertEqual(g.length, g.tell()) + + async def test_multiple_reads(self): + f = AsyncGridIn(self.db.fs, chunkSize=3) + await f.write(b"hello world") + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"he", await g.read(2)) + self.assertEqual(b"ll", await g.read(2)) + self.assertEqual(b"o ", await g.read(2)) + self.assertEqual(b"wo", await g.read(2)) + self.assertEqual(b"rl", await g.read(2)) + self.assertEqual(b"d", await g.read(2)) + self.assertEqual(b"", await g.read(2)) + + async def test_readline(self): + f = AsyncGridIn(self.db.fs, chunkSize=5) + await f.write( + b"""Hello world, +How are you? +Hope all is well. +Bye""" + ) + await f.close() + + # Try read(), then readline(). 
+ g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"H", await g.read(1)) + self.assertEqual(b"ello world,\n", await g.readline()) + self.assertEqual(b"How a", await g.readline(5)) + self.assertEqual(b"", await g.readline(0)) + self.assertEqual(b"re you?\n", await g.readline()) + self.assertEqual(b"Hope all is well.\n", await g.readline(1000)) + self.assertEqual(b"Bye", await g.readline()) + self.assertEqual(b"", await g.readline()) + + # Try readline() first, then read(). + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"He", await g.readline(2)) + self.assertEqual(b"l", await g.read(1)) + self.assertEqual(b"lo", await g.readline(2)) + self.assertEqual(b" world,\n", await g.readline()) + + # Only readline(). + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"H", await g.readline(1)) + self.assertEqual(b"e", await g.readline(1)) + self.assertEqual(b"llo world,\n", await g.readline()) + + async def test_readlines(self): + f = AsyncGridIn(self.db.fs, chunkSize=5) + await f.write( + b"""Hello world, +How are you? +Hope all is well. +Bye""" + ) + await f.close() + + # Try read(), then readlines(). + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"He", await g.read(2)) + self.assertEqual([b"llo world,\n", b"How are you?\n"], await g.readlines(11)) + self.assertEqual([b"Hope all is well.\n", b"Bye"], await g.readlines()) + self.assertEqual([], await g.readlines()) + + # Try readline(), then readlines(). + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"Hello world,\n", await g.readline()) + self.assertEqual([b"How are you?\n", b"Hope all is well.\n"], await g.readlines(13)) + self.assertEqual(b"Bye", await g.readline()) + self.assertEqual([], await g.readlines()) + + # Only readlines(). 
+ g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual( + [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], + await g.readlines(), + ) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual( + [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], + await g.readlines(0), + ) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual([b"Hello world,\n"], await g.readlines(1)) + self.assertEqual([b"How are you?\n"], await g.readlines(12)) + self.assertEqual([b"Hope all is well.\n", b"Bye"], await g.readlines(18)) + + # Try readlines() first, then read(). + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual([b"Hello world,\n"], await g.readlines(1)) + self.assertEqual(b"H", await g.read(1)) + self.assertEqual([b"ow are you?\n", b"Hope all is well.\n"], await g.readlines(29)) + self.assertEqual([b"Bye"], await g.readlines(1)) + + # Try readlines() first, then readline(). + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual([b"Hello world,\n"], await g.readlines(1)) + self.assertEqual(b"How are you?\n", await g.readline()) + self.assertEqual([b"Hope all is well.\n"], await g.readlines(17)) + self.assertEqual(b"Bye", await g.readline()) + + async def test_iterator(self): + f = AsyncGridIn(self.db.fs) + await f.close() + g = AsyncGridOut(self.db.fs, f._id) + if _IS_SYNC: + self.assertEqual([], list(g)) + else: + self.assertEqual([], await g.to_list()) + + f = AsyncGridIn(self.db.fs) + await f.write(b"hello world\nhere are\nsome lines.") + await f.close() + g = AsyncGridOut(self.db.fs, f._id) + if _IS_SYNC: + self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g)) + else: + self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], await g.to_list()) + + self.assertEqual(b"", await g.read(5)) + if _IS_SYNC: + self.assertEqual([], list(g)) + else: + self.assertEqual([], await g.to_list()) + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"hello world\n", await anext(aiter(g))) + 
self.assertEqual(b"here", await g.read(4)) + self.assertEqual(b" are\n", await anext(aiter(g))) + self.assertEqual(b"some lines", await g.read(10)) + self.assertEqual(b".", await anext(aiter(g))) + with self.assertRaises(StopAsyncIteration): + await aiter(g).__anext__() + + f = AsyncGridIn(self.db.fs, chunk_size=2) + await f.write(b"hello world") + await f.close() + g = AsyncGridOut(self.db.fs, f._id) + if _IS_SYNC: + self.assertEqual([b"hello world"], list(g)) + else: + self.assertEqual([b"hello world"], await g.to_list()) + + async def test_read_unaligned_buffer_size(self): + in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk." + f = AsyncGridIn(self.db.fs, chunkSize=16) + await f.write(in_data) + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + out_data = b"" + while 1: + s = await g.read(13) + if not s: + break + out_data += s + + self.assertEqual(in_data, out_data) + + async def test_readchunk(self): + in_data = b"a" * 10 + f = AsyncGridIn(self.db.fs, chunkSize=3) + await f.write(in_data) + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(3, len(await g.readchunk())) + + self.assertEqual(2, len(await g.read(2))) + self.assertEqual(1, len(await g.readchunk())) + + self.assertEqual(3, len(await g.read(3))) + + self.assertEqual(1, len(await g.readchunk())) + + self.assertEqual(0, len(await g.readchunk())) + + async def test_write_unicode(self): + f = AsyncGridIn(self.db.fs) + with self.assertRaises(TypeError): + await f.write("foo") + + f = AsyncGridIn(self.db.fs, encoding="utf-8") + await f.write("foo") + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual(b"foo", await g.read()) + + f = AsyncGridIn(self.db.fs, encoding="iso-8859-1") + await f.write("aé") + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + self.assertEqual("aé".encode("iso-8859-1"), await g.read()) + + async def test_set_after_close(self): + f = AsyncGridIn(self.db.fs, _id="foo", bar="baz") + + 
self.assertEqual("foo", f._id) + self.assertEqual("baz", f.bar) + self.assertRaises(AttributeError, getattr, f, "baz") + self.assertRaises(AttributeError, getattr, f, "uploadDate") + + self.assertRaises(AttributeError, setattr, f, "_id", 5) + if _IS_SYNC: + f.bar = "foo" + f.baz = 5 + else: + await f.set("bar", "foo") + await f.set("baz", 5) + + self.assertEqual("foo", f._id) + self.assertEqual("foo", f.bar) + self.assertEqual(5, f.baz) + self.assertRaises(AttributeError, getattr, f, "uploadDate") + + await f.close() + + self.assertEqual("foo", f._id) + self.assertEqual("foo", f.bar) + self.assertEqual(5, f.baz) + self.assertTrue(f.uploadDate) + + self.assertRaises(AttributeError, setattr, f, "_id", 5) + if _IS_SYNC: + f.bar = "a" + f.baz = "b" + else: + await f.set("bar", "a") + await f.set("baz", "b") + self.assertRaises(AttributeError, setattr, f, "upload_date", 5) + + g = AsyncGridOut(self.db.fs, f._id) + if not _IS_SYNC: + await g.open() + self.assertEqual("a", g.bar) + self.assertEqual("b", g.baz) + # Versions 2.0.1 and older saved a _closed field for some reason. + self.assertRaises(AttributeError, getattr, g, "_closed") + + async def test_context_manager(self): + contents = b"Imagine this is some important data..." + + async with AsyncGridIn(self.db.fs, filename="important") as infile: + await infile.write(contents) + + async with AsyncGridOut(self.db.fs, infile._id) as outfile: + self.assertEqual(contents, await outfile.read()) + + async def test_exception_file_non_existence(self): + contents = b"Imagine this is some important data..." + + with self.assertRaises(ConnectionError): + async with AsyncGridIn(self.db.fs, filename="important") as infile: + await infile.write(contents) + raise ConnectionError("Test exception") + + # Expectation: File chunks are written, entry in files doesn't appear. 
+ self.assertEqual( + await self.db.fs.chunks.count_documents({"files_id": infile._id}), infile._chunk_number + ) + + self.assertIsNone(await self.db.fs.files.find_one({"_id": infile._id})) + self.assertTrue(infile.closed) + + async def test_prechunked_string(self): + async def write_me(s, chunk_size): + buf = BytesIO(s) + infile = AsyncGridIn(self.db.fs) + while True: + to_write = buf.read(chunk_size) + if to_write == b"": + break + await infile.write(to_write) + await infile.close() + buf.close() + + outfile = AsyncGridOut(self.db.fs, infile._id) + data = await outfile.read() + self.assertEqual(s, data) + + s = b"x" * DEFAULT_CHUNK_SIZE * 4 + # Test with default chunk size + await write_me(s, DEFAULT_CHUNK_SIZE) + # Multiple + await write_me(s, DEFAULT_CHUNK_SIZE * 3) + # Custom + await write_me(s, 262300) + + async def test_grid_out_lazy_connect(self): + fs = self.db.fs + outfile = AsyncGridOut(fs, file_id=-1) + with self.assertRaises(NoFile): + await outfile.read() + with self.assertRaises(NoFile): + if not _IS_SYNC: + await outfile.open() + outfile.filename + + infile = AsyncGridIn(fs, filename=1) + await infile.close() + + outfile = AsyncGridOut(fs, infile._id) + await outfile.read() + outfile.filename + + outfile = AsyncGridOut(fs, infile._id) + await outfile.readchunk() + + async def test_grid_in_lazy_connect(self): + client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + fs = client.db.fs + infile = AsyncGridIn(fs, file_id=-1, chunk_size=1) + with self.assertRaises(ServerSelectionTimeoutError): + await infile.write(b"data") + with self.assertRaises(ServerSelectionTimeoutError): + await infile.close() + + async def test_unacknowledged(self): + # w=0 is prohibited. + with self.assertRaises(ConfigurationError): + AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs) + + async def test_survive_cursor_not_found(self): + # By default the find command returns 101 documents in the first batch. 
+ # Use 102 batches to cause a single getMore. + chunk_size = 1024 + data = b"d" * (102 * chunk_size) + listener = EventListener() + client = await async_rs_or_single_client(event_listeners=[listener]) + db = client.pymongo_test + async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile: + await infile.write(data) + + async with AsyncGridOut(db.fs, infile._id) as outfile: + self.assertEqual(len(await outfile.readchunk()), chunk_size) + + # Kill the cursor to simulate the cursor timing out on the server + # when an application spends a long time between two calls to + # readchunk(). + assert await client.address is not None + await client._close_cursor_now( + outfile._chunk_iter._cursor.cursor_id, + _CursorAddress(await client.address, db.fs.chunks.full_name), # type: ignore[arg-type] + ) + + # Read the rest of the file without error. + self.assertEqual(len(await outfile.read()), len(data) - chunk_size) + + # Paranoid, ensure that a getMore was actually sent. + self.assertIn("getMore", listener.started_command_names()) + + @async_client_context.require_sync + async def test_zip(self): + zf = BytesIO() + z = zipfile.ZipFile(zf, "w") + z.writestr("test.txt", b"hello world") + z.close() + zf.seek(0) + + f = AsyncGridIn(self.db.fs, filename="test.zip") + await f.write(zf) + await f.close() + self.assertEqual(1, await self.db.fs.files.count_documents({})) + self.assertEqual(1, await self.db.fs.chunks.count_documents({})) + + g = AsyncGridOut(self.db.fs, f._id) + z = zipfile.ZipFile(g) + self.assertSequenceEqual(z.namelist(), ["test.txt"]) + self.assertEqual(z.read("test.txt"), b"hello world") + + async def test_grid_out_unsupported_operations(self): + f = AsyncGridIn(self.db.fs, chunkSize=3) + await f.write(b"hello world") + await f.close() + + g = AsyncGridOut(self.db.fs, f._id) + + self.assertRaises(io.UnsupportedOperation, g.writelines, [b"some", b"lines"]) + self.assertRaises(io.UnsupportedOperation, g.write, b"some text") + 
self.assertRaises(io.UnsupportedOperation, g.fileno) + self.assertRaises(io.UnsupportedOperation, g.truncate) + + self.assertFalse(g.writable()) + self.assertFalse(g.isatty()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_grid_file.py b/test/test_grid_file.py index f663f1365..0e806eb5c 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -21,6 +21,7 @@ import io import sys import zipfile from io import BytesIO +from test import IntegrationTest, UnitTest, client_context from pymongo.synchronous.database import Database @@ -36,16 +37,20 @@ from gridfs.synchronous.grid_file import ( _SEEK_CUR, _SEEK_END, DEFAULT_CHUNK_SIZE, + GridFS, GridIn, GridOut, GridOutCursor, ) from pymongo import MongoClient -from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError +from pymongo.errors import ConfigurationError, InvalidOperation, ServerSelectionTimeoutError from pymongo.message import _CursorAddress +from pymongo.synchronous.helpers import iter, next + +_IS_SYNC = True -class TestGridFileNoConnect(unittest.TestCase): +class TestGridFileNoConnect(UnitTest): """Test GridFile features on a client that does not connect.""" db: Database @@ -151,6 +156,7 @@ class TestGridFile(IntegrationTest): self.assertEqual(None, a.content_type) a.content_type = "text/html" + self.assertEqual("text/html", a.content_type) self.assertRaises(AttributeError, getattr, a, "length") @@ -164,17 +170,24 @@ class TestGridFile(IntegrationTest): self.assertRaises(AttributeError, getattr, a, "aliases") a.aliases = ["foo"] + self.assertEqual(["foo"], a.aliases) self.assertRaises(AttributeError, getattr, a, "metadata") a.metadata = {"foo": 1} + self.assertEqual({"foo": 1}, a.metadata) self.assertRaises(AttributeError, setattr, a, "md5", 5) a.close() - a.forty_two = 42 + if _IS_SYNC: + a.forty_two = 42 + else: + self.assertRaises(AttributeError, setattr, a, "forty_two", 42) + a.set("forty_two", 42) + self.assertEqual(42, a.forty_two) 
self.assertTrue(isinstance(a._id, ObjectId)) @@ -213,12 +226,16 @@ class TestGridFile(IntegrationTest): gout = GridOut(self.db.fs, 5) with self.assertRaises(NoFile): + if not _IS_SYNC: + gout.open() gout.name a = GridIn(self.db.fs) a.close() b = GridOut(self.db.fs, a._id) + if not _IS_SYNC: + b.open() self.assertEqual(a._id, b._id) self.assertEqual(0, b.length) @@ -278,6 +295,9 @@ class TestGridFile(IntegrationTest): two = GridOut(self.db.fs, 5) + if not _IS_SYNC: + two.open() + self.assertEqual("my_file", two.name) self.assertEqual("my_file", two.filename) self.assertEqual(5, two._id) @@ -316,6 +336,8 @@ class TestGridFile(IntegrationTest): four = GridOut(self.db.fs, file_document={}) with self.assertRaises(NoFile): + if not _IS_SYNC: + four.open() four.name def test_write_file_like(self): @@ -350,7 +372,8 @@ class TestGridFile(IntegrationTest): def test_close(self): f = GridIn(self.db.fs) f.close() - self.assertRaises(ValueError, f.write, "test") + with self.assertRaises(ValueError): + f.write("test") f.close() def test_closed(self): @@ -359,6 +382,8 @@ class TestGridFile(IntegrationTest): f.close() g = GridOut(self.db.fs, f._id) + if not _IS_SYNC: + g.open() self.assertFalse(g.closed) g.read(1) self.assertFalse(g.closed) @@ -380,6 +405,8 @@ class TestGridFile(IntegrationTest): g = GridOut(self.db.fs, f._id) self.assertEqual(random_string, g.read()) + # TODO: https://jira.mongodb.org/browse/PYTHON-4708 + @client_context.require_sync def test_small_chunks(self): self.files = 0 self.chunks = 0 @@ -415,18 +442,21 @@ class TestGridFile(IntegrationTest): self.assertEqual(b"hello world", g.read()) g.seek(1) self.assertEqual(b"ello world", g.read()) - self.assertRaises(IOError, g.seek, -1) + with self.assertRaises(IOError): + g.seek(-1) g.seek(-3, _SEEK_END) self.assertEqual(b"rld", g.read()) g.seek(0, _SEEK_END) self.assertEqual(b"", g.read()) - self.assertRaises(IOError, g.seek, -100, _SEEK_END) + with self.assertRaises(IOError): + g.seek(-100, _SEEK_END) g.seek(3) 
g.seek(3, _SEEK_CUR) self.assertEqual(b"world", g.read()) - self.assertRaises(IOError, g.seek, -100, _SEEK_CUR) + with self.assertRaises(IOError): + g.seek(-100, _SEEK_CUR) def test_tell(self): f = GridIn(self.db.fs, chunkSize=3) @@ -519,12 +549,14 @@ Bye""" # Only readlines(). g = GridOut(self.db.fs, f._id) self.assertEqual( - [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines() + [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], + g.readlines(), ) g = GridOut(self.db.fs, f._id) self.assertEqual( - [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines(0) + [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], + g.readlines(0), ) g = GridOut(self.db.fs, f._id) @@ -550,15 +582,25 @@ Bye""" f = GridIn(self.db.fs) f.close() g = GridOut(self.db.fs, f._id) - self.assertEqual([], list(g)) + if _IS_SYNC: + self.assertEqual([], list(g)) + else: + self.assertEqual([], g.to_list()) f = GridIn(self.db.fs) f.write(b"hello world\nhere are\nsome lines.") f.close() g = GridOut(self.db.fs, f._id) - self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g)) + if _IS_SYNC: + self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g)) + else: + self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], g.to_list()) + self.assertEqual(b"", g.read(5)) - self.assertEqual([], list(g)) + if _IS_SYNC: + self.assertEqual([], list(g)) + else: + self.assertEqual([], g.to_list()) g = GridOut(self.db.fs, f._id) self.assertEqual(b"hello world\n", next(iter(g))) @@ -566,13 +608,17 @@ Bye""" self.assertEqual(b" are\n", next(iter(g))) self.assertEqual(b"some lines", g.read(10)) self.assertEqual(b".", next(iter(g))) - self.assertRaises(StopIteration, iter(g).__next__) + with self.assertRaises(StopIteration): + iter(g).__next__() f = GridIn(self.db.fs, chunk_size=2) f.write(b"hello world") f.close() g = GridOut(self.db.fs, f._id) - 
self.assertEqual([b"hello world"], list(g)) + if _IS_SYNC: + self.assertEqual([b"hello world"], list(g)) + else: + self.assertEqual([b"hello world"], g.to_list()) def test_read_unaligned_buffer_size(self): in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk." @@ -610,7 +656,8 @@ Bye""" def test_write_unicode(self): f = GridIn(self.db.fs) - self.assertRaises(TypeError, f.write, "foo") + with self.assertRaises(TypeError): + f.write("foo") f = GridIn(self.db.fs, encoding="utf-8") f.write("foo") @@ -635,8 +682,12 @@ Bye""" self.assertRaises(AttributeError, getattr, f, "uploadDate") self.assertRaises(AttributeError, setattr, f, "_id", 5) - f.bar = "foo" - f.baz = 5 + if _IS_SYNC: + f.bar = "foo" + f.baz = 5 + else: + f.set("bar", "foo") + f.set("baz", 5) self.assertEqual("foo", f._id) self.assertEqual("foo", f.bar) @@ -651,11 +702,17 @@ Bye""" self.assertTrue(f.uploadDate) self.assertRaises(AttributeError, setattr, f, "_id", 5) - f.bar = "a" - f.baz = "b" + if _IS_SYNC: + f.bar = "a" + f.baz = "b" + else: + f.set("bar", "a") + f.set("baz", "b") self.assertRaises(AttributeError, setattr, f, "upload_date", 5) g = GridOut(self.db.fs, f._id) + if not _IS_SYNC: + g.open() self.assertEqual("a", g.bar) self.assertEqual("b", g.baz) # Versions 2.0.1 and older saved a _closed field for some reason. 
@@ -713,8 +770,12 @@ Bye""" def test_grid_out_lazy_connect(self): fs = self.db.fs outfile = GridOut(fs, file_id=-1) - self.assertRaises(NoFile, outfile.read) - self.assertRaises(NoFile, getattr, outfile, "filename") + with self.assertRaises(NoFile): + outfile.read() + with self.assertRaises(NoFile): + if not _IS_SYNC: + outfile.open() + outfile.filename infile = GridIn(fs, filename=1) infile.close() @@ -730,13 +791,15 @@ Bye""" client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = GridIn(fs, file_id=-1, chunk_size=1) - self.assertRaises(ServerSelectionTimeoutError, infile.write, b"data") - self.assertRaises(ServerSelectionTimeoutError, infile.close) + with self.assertRaises(ServerSelectionTimeoutError): + infile.write(b"data") + with self.assertRaises(ServerSelectionTimeoutError): + infile.close() def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - GridIn(rs_or_single_client(w=0).pymongo_test.fs) + GridIn((rs_or_single_client(w=0)).pymongo_test.fs) def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. @@ -758,7 +821,7 @@ Bye""" assert client.address is not None client._close_cursor_now( outfile._chunk_iter._cursor.cursor_id, - _CursorAddress(client.address, db.fs.chunks.full_name), + _CursorAddress(client.address, db.fs.chunks.full_name), # type: ignore[arg-type] ) # Read the rest of the file without error. @@ -767,6 +830,7 @@ Bye""" # Paranoid, ensure that a getMore was actually sent. 
self.assertIn("getMore", listener.started_command_names()) + @client_context.require_sync def test_zip(self): zf = BytesIO() z = zipfile.ZipFile(zf, "w") diff --git a/tools/synchro.py b/tools/synchro.py index adc0de297..b8fc9f33c 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -47,6 +47,7 @@ replacements = { "asynchronous": "synchronous", "Asynchronous": "Synchronous", "anext": "next", + "aiter": "iter", "_ALock": "_Lock", "_ACondition": "_Condition", "AsyncGridFS": "GridFS", @@ -98,6 +99,8 @@ replacements = { "default_async": "default", "aclose": "close", "PyMongo|async": "PyMongo", + "AsyncTestGridFile": "TestGridFile", + "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", } docstring_replacements: dict[tuple[str, str], str] = { @@ -160,6 +163,7 @@ converted_tests = [ "test_cursor.py", "test_database.py", "test_encryption.py", + "test_grid_file.py", "test_logger.py", "test_session.py", "test_transactions.py", From b37fb918964222625428ea66ba8154ace65759c4 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Wed, 4 Sep 2024 10:36:35 -0700 Subject: [PATCH 6/6] PYTHON-4704 Migrate test_bulk.py to async (#1827) --- test/__init__.py | 5 +- test/asynchronous/__init__.py | 5 +- test/asynchronous/test_bulk.py | 1134 ++++++++++++++++++++++++++++++++ test/test_bulk.py | 83 ++- tools/synchro.py | 5 + 5 files changed, 1207 insertions(+), 25 deletions(-) create mode 100644 test/asynchronous/test_bulk.py diff --git a/test/__init__.py b/test/__init__.py index d978d7da3..41af81f97 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -569,7 +569,10 @@ class ClientContext: def sec_count(): return 0 if not self.client else len(self.client.secondaries) - return self._require(lambda: sec_count() >= count, "Not enough secondaries available") + def check(): + return sec_count() >= count + + return self._require(check, "Not enough secondaries available") @property def supports_secondary_read_pref(self): diff --git 
a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index def4bc1b8..d1af89c18 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -571,7 +571,10 @@ class AsyncClientContext: async def sec_count(): return 0 if not self.client else len(await self.client.secondaries) - return self._require(lambda: sec_count() >= count, "Not enough secondaries available") + async def check(): + return await sec_count() >= count + + return self._require(check, "Not enough secondaries available") @property async def supports_secondary_read_pref(self): diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py new file mode 100644 index 000000000..24111ad7c --- /dev/null +++ b/test/asynchronous/test_bulk.py @@ -0,0 +1,1134 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the bulk API.""" +from __future__ import annotations + +import sys +import uuid +from typing import Any, Optional + +from pymongo.asynchronous.mongo_client import AsyncMongoClient + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest +from test.utils import ( + async_rs_or_single_client_noauth, + async_wait_until, + single_client, +) + +from bson.binary import Binary, UuidRepresentation +from bson.codec_options import CodecOptions +from bson.objectid import ObjectId +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.common import partition_node +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + InvalidOperation, + OperationFailure, +) +from pymongo.operations import * +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class AsyncBulkTestBase(AsyncIntegrationTest): + coll: AsyncCollection + coll_w0: AsyncCollection + + @classmethod + async def _setup_class(cls): + await super()._setup_class() + cls.coll = cls.db.test + cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) + + async def asyncSetUp(self): + super().setUp() + await self.coll.drop() + + def assertEqualResponse(self, expected, actual): + """Compare response from bulk.execute() to expected response.""" + for key, value in expected.items(): + if key == "nModified": + self.assertEqual(value, actual["nModified"]) + elif key == "upserted": + expected_upserts = value + actual_upserts = actual["upserted"] + self.assertEqual( + len(expected_upserts), + len(actual_upserts), + 'Expected %d elements in "upserted", got %d' + % (len(expected_upserts), len(actual_upserts)), + ) + + for e, a in zip(expected_upserts, actual_upserts): + self.assertEqualUpsert(e, a) + + elif key == "writeErrors": + expected_errors = value + actual_errors = actual["writeErrors"] + self.assertEqual( + len(expected_errors), + len(actual_errors), + 'Expected %d elements in 
"writeErrors", got %d' + % (len(expected_errors), len(actual_errors)), + ) + + for e, a in zip(expected_errors, actual_errors): + self.assertEqualWriteError(e, a) + + else: + self.assertEqual( + actual.get(key), + value, + f"{key!r} value of {actual.get(key)!r} does not match expected {value!r}", + ) + + def assertEqualUpsert(self, expected, actual): + """Compare bulk.execute()['upserts'] to expected value. + + Like: {'index': 0, '_id': ObjectId()} + """ + self.assertEqual(expected["index"], actual["index"]) + if expected["_id"] == "...": + # Unspecified value. + self.assertTrue("_id" in actual) + else: + self.assertEqual(expected["_id"], actual["_id"]) + + def assertEqualWriteError(self, expected, actual): + """Compare bulk.execute()['writeErrors'] to expected value. + + Like: {'index': 0, 'code': 123, 'errmsg': '...', 'op': { ... }} + """ + self.assertEqual(expected["index"], actual["index"]) + self.assertEqual(expected["code"], actual["code"]) + if expected["errmsg"] == "...": + # Unspecified value. + self.assertTrue("errmsg" in actual) + else: + self.assertEqual(expected["errmsg"], actual["errmsg"]) + + expected_op = expected["op"].copy() + actual_op = actual["op"].copy() + if expected_op.get("_id") == "...": + # Unspecified _id. 
+ self.assertTrue("_id" in actual_op) + actual_op.pop("_id") + expected_op.pop("_id") + + self.assertEqual(expected_op, actual_op) + + +class AsyncTestBulk(AsyncBulkTestBase): + async def test_empty(self): + with self.assertRaises(InvalidOperation): + await self.coll.bulk_write([]) + + async def test_insert(self): + expected = { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 1, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], + } + + result = await self.coll.bulk_write([InsertOne({})]) + self.assertEqualResponse(expected, result.bulk_api_result) + self.assertEqual(1, result.inserted_count) + self.assertEqual(1, await self.coll.count_documents({})) + + async def _test_update_many(self, update): + expected = { + "nMatched": 2, + "nModified": 2, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], + } + await self.coll.insert_many([{}, {}]) + + result = await self.coll.bulk_write([UpdateMany({}, update)]) + self.assertEqualResponse(expected, result.bulk_api_result) + self.assertEqual(2, result.matched_count) + self.assertTrue(result.modified_count in (2, None)) + + async def test_update_many(self): + await self._test_update_many({"$set": {"foo": "bar"}}) + + @async_client_context.require_version_min(4, 1, 11) + async def test_update_many_pipeline(self): + await self._test_update_many([{"$set": {"foo": "bar"}}]) + + async def test_array_filters_validation(self): + with self.assertRaises(TypeError): + UpdateMany({}, {}, array_filters={}) # type: ignore[arg-type] + with self.assertRaises(TypeError): + UpdateOne({}, {}, array_filters={}) # type: ignore[arg-type] + + async def test_array_filters_unacknowledged(self): + coll = self.coll_w0 + update_one = UpdateOne({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) + update_many = UpdateMany({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) + with
self.assertRaises(ConfigurationError): + await coll.bulk_write([update_one]) + with self.assertRaises(ConfigurationError): + await coll.bulk_write([update_many]) + + async def _test_update_one(self, update): + expected = { + "nMatched": 1, + "nModified": 1, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], + } + + await self.coll.insert_many([{}, {}]) + + result = await self.coll.bulk_write([UpdateOne({}, update)]) + self.assertEqualResponse(expected, result.bulk_api_result) + self.assertEqual(1, result.matched_count) + self.assertTrue(result.modified_count in (1, None)) + + async def test_update_one(self): + await self._test_update_one({"$set": {"foo": "bar"}}) + + @async_client_context.require_version_min(4, 1, 11) + async def test_update_one_pipeline(self): + await self._test_update_one([{"$set": {"foo": "bar"}}]) + + async def test_replace_one(self): + expected = { + "nMatched": 1, + "nModified": 1, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], + } + + await self.coll.insert_many([{}, {}]) + + result = await self.coll.bulk_write([ReplaceOne({}, {"foo": "bar"})]) + self.assertEqualResponse(expected, result.bulk_api_result) + self.assertEqual(1, result.matched_count) + self.assertTrue(result.modified_count in (1, None)) + + async def test_remove(self): + # Test removing all documents, ordered. + expected = { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 2, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], + } + await self.coll.insert_many([{}, {}]) + + result = await self.coll.bulk_write([DeleteMany({})]) + self.assertEqualResponse(expected, result.bulk_api_result) + self.assertEqual(2, result.deleted_count) + + async def test_remove_one(self): + # Test removing one document, empty selector. 
+ await self.coll.insert_many([{}, {}]) + expected = { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 1, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], + } + + result = await self.coll.bulk_write([DeleteOne({})]) + self.assertEqualResponse(expected, result.bulk_api_result) + self.assertEqual(1, result.deleted_count) + self.assertEqual(await self.coll.count_documents({}), 1) + + async def test_upsert(self): + expected = { + "nMatched": 0, + "nModified": 0, + "nUpserted": 1, + "nInserted": 0, + "nRemoved": 0, + "upserted": [{"index": 0, "_id": "..."}], + } + + result = await self.coll.bulk_write([ReplaceOne({}, {"foo": "bar"}, upsert=True)]) + self.assertEqualResponse(expected, result.bulk_api_result) + self.assertEqual(1, result.upserted_count) + assert result.upserted_ids is not None + self.assertEqual(1, len(result.upserted_ids)) + self.assertTrue(isinstance(result.upserted_ids.get(0), ObjectId)) + + self.assertEqual(await self.coll.count_documents({"foo": "bar"}), 1) + + async def test_numerous_inserts(self): + # Ensure we don't exceed server's maxWriteBatchSize size limit. + n_docs = await async_client_context.max_write_batch_size + 100 + requests = [InsertOne[dict]({}) for _ in range(n_docs)] + result = await self.coll.bulk_write(requests, ordered=False) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, await self.coll.count_documents({})) + + # Same with ordered bulk. + await self.coll.drop() + result = await self.coll.bulk_write(requests) + self.assertEqual(n_docs, result.inserted_count) + self.assertEqual(n_docs, await self.coll.count_documents({})) + + async def test_bulk_max_message_size(self): + await self.coll.delete_many({}) + self.addCleanup(self.coll.delete_many, {}) + _16_MB = 16 * 1000 * 1000 + # Generate a list of documents such that the first batched OP_MSG is + # as close as possible to the 48MB limit. 
+ docs = [ + {"_id": 1, "l": "s" * _16_MB}, + {"_id": 2, "l": "s" * _16_MB}, + {"_id": 3, "l": "s" * (_16_MB - 10000)}, + ] + # Fill in the remaining ~10000 bytes with small documents. + for i in range(4, 10000): + docs.append({"_id": i}) + result = await self.coll.insert_many(docs) + self.assertEqual(len(docs), len(result.inserted_ids)) + + async def test_generator_insert(self): + def gen(): + yield {"a": 1, "b": 1} + yield {"a": 1, "b": 2} + yield {"a": 2, "b": 3} + yield {"a": 3, "b": 5} + yield {"a": 5, "b": 8} + + result = await self.coll.insert_many(gen()) + self.assertEqual(5, len(result.inserted_ids)) + + async def test_bulk_write_no_results(self): + result = await self.coll_w0.bulk_write([InsertOne({})]) + self.assertFalse(result.acknowledged) + self.assertRaises(InvalidOperation, lambda: result.inserted_count) + self.assertRaises(InvalidOperation, lambda: result.matched_count) + self.assertRaises(InvalidOperation, lambda: result.modified_count) + self.assertRaises(InvalidOperation, lambda: result.deleted_count) + self.assertRaises(InvalidOperation, lambda: result.upserted_count) + self.assertRaises(InvalidOperation, lambda: result.upserted_ids) + + async def test_bulk_write_invalid_arguments(self): + # The requests argument must be a list. + generator = (InsertOne[dict]({}) for _ in range(10)) + with self.assertRaises(TypeError): + await self.coll.bulk_write(generator) # type: ignore[arg-type] + + # Document is not wrapped in a bulk write operation. 
+ with self.assertRaises(TypeError): + await self.coll.bulk_write([{}]) # type: ignore[list-item] + + async def test_upsert_large(self): + big = "a" * (await async_client_context.max_bson_size - 37) + result = await self.coll.bulk_write( + [UpdateOne({"x": 1}, {"$set": {"s": big}}, upsert=True)] + ) + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 1, + "nInserted": 0, + "nRemoved": 0, + "upserted": [{"index": 0, "_id": "..."}], + }, + result.bulk_api_result, + ) + + self.assertEqual(1, await self.coll.count_documents({"x": 1})) + + async def test_client_generated_upsert_id(self): + result = await self.coll.bulk_write( + [ + UpdateOne({"_id": 0}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": 1}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({"_id": 2}, {"_id": 2}, upsert=True), + ] + ) + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": 0}, + {"index": 1, "_id": 1}, + {"index": 2, "_id": 2}, + ], + }, + result.bulk_api_result, + ) + + async def test_upsert_uuid_standard(self): + options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + coll = self.coll.with_options(codec_options=options) + uuids = [uuid.uuid4() for _ in range(3)] + result = await coll.bulk_write( + [ + UpdateOne({"_id": uuids[0]}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": uuids[1]}, upsert=True), + # This is just here to make the counts right in all cases. 
+ ReplaceOne({"_id": uuids[2]}, {"_id": uuids[2]}, upsert=True), + ] + ) + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": uuids[0]}, + {"index": 1, "_id": uuids[1]}, + {"index": 2, "_id": uuids[2]}, + ], + }, + result.bulk_api_result, + ) + + async def test_upsert_uuid_unspecified(self): + options = CodecOptions(uuid_representation=UuidRepresentation.UNSPECIFIED) + coll = self.coll.with_options(codec_options=options) + uuids = [Binary.from_uuid(uuid.uuid4()) for _ in range(3)] + result = await coll.bulk_write( + [ + UpdateOne({"_id": uuids[0]}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": uuids[1]}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({"_id": uuids[2]}, {"_id": uuids[2]}, upsert=True), + ] + ) + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": uuids[0]}, + {"index": 1, "_id": uuids[1]}, + {"index": 2, "_id": uuids[2]}, + ], + }, + result.bulk_api_result, + ) + + async def test_upsert_uuid_standard_subdocuments(self): + options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + coll = self.coll.with_options(codec_options=options) + ids: list = [{"f": Binary(bytes(i)), "f2": uuid.uuid4()} for i in range(3)] + + result = await coll.bulk_write( + [ + UpdateOne({"_id": ids[0]}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": ids[1]}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({"_id": ids[2]}, {"_id": ids[2]}, upsert=True), + ] + ) + + # The `Binary` values are returned as `bytes` objects. 
+ for _id in ids: + _id["f"] = bytes(_id["f"]) + + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": ids[0]}, + {"index": 1, "_id": ids[1]}, + {"index": 2, "_id": ids[2]}, + ], + }, + result.bulk_api_result, + ) + + async def test_single_ordered_batch(self): + result = await self.coll.bulk_write( + [ + InsertOne({"a": 1}), + UpdateOne({"a": 1}, {"$set": {"b": 1}}), + UpdateOne({"a": 2}, {"$set": {"b": 2}}, upsert=True), + InsertOne({"a": 3}), + DeleteOne({"a": 3}), + ] + ) + self.assertEqualResponse( + { + "nMatched": 1, + "nModified": 1, + "nUpserted": 1, + "nInserted": 2, + "nRemoved": 1, + "upserted": [{"index": 2, "_id": "..."}], + }, + result.bulk_api_result, + ) + + async def test_single_error_ordered_batch(self): + await self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) + requests: list = [ + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + InsertOne({"b": 3, "a": 2}), + ] + try: + await self.coll.bulk_write(requests) + except BulkWriteError as exc: + result = exc.details + self.assertEqual(exc.code, 65) + else: + self.fail("Error not raised") + + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 1, + "nRemoved": 0, + "upserted": [], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 1, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 2}, + "u": {"$set": {"a": 1}}, + "multi": False, + "upsert": True, + }, + } + ], + }, + result, + ) + + async def test_multiple_error_ordered_batch(self): + await self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) + requests: list = [ + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + UpdateOne({"b": 3}, {"$set": {"a": 2}}, upsert=True), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + 
InsertOne({"b": 4, "a": 3}), + InsertOne({"b": 5, "a": 1}), + ] + + try: + await self.coll.bulk_write(requests) + except BulkWriteError as exc: + result = exc.details + self.assertEqual(exc.code, 65) + else: + self.fail("Error not raised") + + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 1, + "nRemoved": 0, + "upserted": [], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 1, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 2}, + "u": {"$set": {"a": 1}}, + "multi": False, + "upsert": True, + }, + } + ], + }, + result, + ) + + async def test_single_unordered_batch(self): + requests: list = [ + InsertOne({"a": 1}), + UpdateOne({"a": 1}, {"$set": {"b": 1}}), + UpdateOne({"a": 2}, {"$set": {"b": 2}}, upsert=True), + InsertOne({"a": 3}), + DeleteOne({"a": 3}), + ] + result = await self.coll.bulk_write(requests, ordered=False) + self.assertEqualResponse( + { + "nMatched": 1, + "nModified": 1, + "nUpserted": 1, + "nInserted": 2, + "nRemoved": 1, + "upserted": [{"index": 2, "_id": "..."}], + "writeErrors": [], + "writeConcernErrors": [], + }, + result.bulk_api_result, + ) + + async def test_single_error_unordered_batch(self): + await self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) + requests: list = [ + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + InsertOne({"b": 3, "a": 2}), + ] + + try: + await self.coll.bulk_write(requests, ordered=False) + except BulkWriteError as exc: + result = exc.details + self.assertEqual(exc.code, 65) + else: + self.fail("Error not raised") + + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 2, + "nRemoved": 0, + "upserted": [], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 1, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 2}, + "u": {"$set": {"a": 1}}, + "multi": False, + "upsert": True, + }, + } + ], + }, + 
result, + ) + + async def test_multiple_error_unordered_batch(self): + await self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) + requests: list = [ + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 3}}, upsert=True), + UpdateOne({"b": 3}, {"$set": {"a": 4}}, upsert=True), + UpdateOne({"b": 4}, {"$set": {"a": 3}}, upsert=True), + InsertOne({"b": 5, "a": 2}), + InsertOne({"b": 6, "a": 1}), + ] + + try: + await self.coll.bulk_write(requests, ordered=False) + except BulkWriteError as exc: + result = exc.details + self.assertEqual(exc.code, 65) + else: + self.fail("Error not raised") + # Assume the update at index 1 runs before the update at index 3, + # although the spec does not require it. Same for inserts. + self.assertEqualResponse( + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 2, + "nInserted": 2, + "nRemoved": 0, + "upserted": [{"index": 1, "_id": "..."}, {"index": 2, "_id": "..."}], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 3, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 4}, + "u": {"$set": {"a": 3}}, + "multi": False, + "upsert": True, + }, + }, + { + "index": 5, + "code": 11000, + "errmsg": "...", + "op": {"_id": "...", "b": 6, "a": 1}, + }, + ], + }, + result, + ) + + async def test_large_inserts_ordered(self): + big = "x" * await async_client_context.max_bson_size + requests = [ + InsertOne({"b": 1, "a": 1}), + InsertOne({"big": big}), + InsertOne({"b": 2, "a": 2}), + ] + + try: + await self.coll.bulk_write(requests) + except BulkWriteError as exc: + result = exc.details + self.assertEqual(exc.code, 65) + else: + self.fail("Error not raised") + + self.assertEqual(1, result["nInserted"]) + + await self.coll.delete_many({}) + + big = "x" * (1024 * 1024 * 4) + write_result = await self.coll.bulk_write( + [ + InsertOne({"a": 1, "big": big}), + InsertOne({"a": 2, "big": big}), + InsertOne({"a": 3, "big": big}), + InsertOne({"a": 4, "big": big}), + 
InsertOne({"a": 5, "big": big}), + InsertOne({"a": 6, "big": big}), + ] + ) + + self.assertEqual(6, write_result.inserted_count) + self.assertEqual(6, await self.coll.count_documents({})) + + async def test_large_inserts_unordered(self): + big = "x" * await async_client_context.max_bson_size + requests = [ + InsertOne({"b": 1, "a": 1}), + InsertOne({"big": big}), + InsertOne({"b": 2, "a": 2}), + ] + + try: + await self.coll.bulk_write(requests, ordered=False) + except BulkWriteError as exc: + details = exc.details + self.assertEqual(exc.code, 65) + else: + self.fail("Error not raised") + + self.assertEqual(2, details["nInserted"]) + + await self.coll.delete_many({}) + + big = "x" * (1024 * 1024 * 4) + result = await self.coll.bulk_write( + [ + InsertOne({"a": 1, "big": big}), + InsertOne({"a": 2, "big": big}), + InsertOne({"a": 3, "big": big}), + InsertOne({"a": 4, "big": big}), + InsertOne({"a": 5, "big": big}), + InsertOne({"a": 6, "big": big}), + ], + ordered=False, + ) + + self.assertEqual(6, result.inserted_count) + self.assertEqual(6, await self.coll.count_documents({})) + + +class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase): + @classmethod + @async_client_context.require_auth + @async_client_context.require_no_api_version + async def _setup_class(cls): + await super()._setup_class() + + async def asyncSetUp(self): + super().setUp() + await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"]) + await self.db.command( + "createRole", + "noremove", + privileges=[ + { + "actions": ["insert", "update", "find"], + "resource": {"db": "pymongo_test", "collection": "test"}, + } + ], + roles=[], + ) + + await async_client_context.create_user(self.db.name, "noremove", "pw", ["noremove"]) + + async def asyncTearDown(self): + await self.db.command("dropRole", "noremove") + await remove_all_users(self.db) + + +class AsyncTestBulkUnacknowledged(AsyncBulkTestBase): + async def asyncTearDown(self): + await self.coll.delete_many({}) + + async def
test_no_results_ordered_success(self): + requests: list = [ + InsertOne({"a": 1}), + UpdateOne({"a": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"a": 2}), + DeleteOne({"a": 1}), + ] + result = await self.coll_w0.bulk_write(requests) + self.assertFalse(result.acknowledged) + + async def predicate(): + return await self.coll.count_documents({}) == 2 + + await async_wait_until(predicate, "insert 2 documents") + + async def predicate(): + return await self.coll.find_one({"_id": 1}) is None + + await async_wait_until(predicate, 'removed {"_id": 1}') + + async def test_no_results_ordered_failure(self): + requests: list = [ + InsertOne({"_id": 1}), + UpdateOne({"_id": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"_id": 2}), + # Fails with duplicate key error. + InsertOne({"_id": 1}), + # Should not be executed since the batch is ordered. + DeleteOne({"_id": 1}), + ] + result = await self.coll_w0.bulk_write(requests) + self.assertFalse(result.acknowledged) + + async def predicate(): + return await self.coll.count_documents({}) == 3 + + await async_wait_until(predicate, "insert 3 documents") + self.assertEqual({"_id": 1}, await self.coll.find_one({"_id": 1})) + + async def test_no_results_unordered_success(self): + requests: list = [ + InsertOne({"a": 1}), + UpdateOne({"a": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"a": 2}), + DeleteOne({"a": 1}), + ] + result = await self.coll_w0.bulk_write(requests, ordered=False) + self.assertFalse(result.acknowledged) + + async def predicate(): + return await self.coll.count_documents({}) == 2 + + await async_wait_until(predicate, "insert 2 documents") + + async def predicate(): + return await self.coll.find_one({"_id": 1}) is None + + await async_wait_until(predicate, 'removed {"_id": 1}') + + async def test_no_results_unordered_failure(self): + requests: list = [ + InsertOne({"_id": 1}), + UpdateOne({"_id": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"_id": 2}), + # Fails with duplicate key error. 
+            InsertOne({"_id": 1}),
+            # Should be executed since the batch is unordered.
+            DeleteOne({"_id": 1}),
+        ]
+        result = await self.coll_w0.bulk_write(requests, ordered=False)
+        self.assertFalse(result.acknowledged)
+
+        async def predicate():
+            return await self.coll.count_documents({}) == 2
+
+        await async_wait_until(predicate, "insert 2 documents")
+
+        async def predicate():
+            return await self.coll.find_one({"_id": 1}) is None
+
+        await async_wait_until(predicate, 'removed {"_id": 1}')
+
+
+class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase):
+    async def test_readonly(self):
+        # We test that an authorization failure aborts the batch and is raised
+        # as OperationFailure.
+        cli = await async_rs_or_single_client_noauth(
+            username="readonly", password="pw", authSource="pymongo_test"
+        )
+        coll = cli.pymongo_test.test
+        # find_one is a coroutine on the async collection; without the await
+        # the read is never sent and only a "never awaited" warning results.
+        await coll.find_one()
+        with self.assertRaises(OperationFailure):
+            await coll.bulk_write([InsertOne({"x": 1})])
+
+    async def test_no_remove(self):
+        # We test that an authorization failure aborts the batch and is raised
+        # as OperationFailure.
+        cli = await async_rs_or_single_client_noauth(
+            username="noremove", password="pw", authSource="pymongo_test"
+        )
+        coll = cli.pymongo_test.test
+        # Awaited for the same reason as in test_readonly.
+        await coll.find_one()
+        requests = [
+            InsertOne({"x": 1}),
+            ReplaceOne({"x": 2}, {"x": 2}, upsert=True),
+            DeleteMany({}),  # Prohibited.
+            InsertOne({"x": 3}),  # Never attempted.
+        ]
+        with self.assertRaises(OperationFailure):
+            await coll.bulk_write(requests)  # type: ignore[arg-type]
+        self.assertEqual({1, 2}, set(await self.coll.distinct("x")))
+
+
+class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
+    w: Optional[int]
+    secondary: AsyncMongoClient
+
+    @classmethod
+    async def _setup_class(cls):
+        await super()._setup_class()
+        cls.w = async_client_context.w
+        cls.secondary = None
+        # Only look for a secondary when more than one node must acknowledge
+        # writes; the first host that is not the primary is used.
+        if cls.w is not None and cls.w > 1:
+            for member in (await async_client_context.hello)["hosts"]:
+                if member != (await async_client_context.hello)["primary"]:
+                    # NOTE(review): single_client looks like the synchronous
+                    # helper, yet cls.secondary is awaited below (close() and
+                    # admin.command()) - confirm whether the async client
+                    # helper should be awaited here instead.
+                    cls.secondary = single_client(*partition_node(member))
+                    break
+
+    @classmethod
+    async def async_tearDownClass(cls):
+        if cls.secondary:
+            await cls.secondary.close()
+
+    async def cause_wtimeout(self, requests, ordered):
+        # Runs `requests` as a bulk write with w=self.w and wtimeout=1 while
+        # the rsSyncApplyStop failpoint is on; the failpoint is always turned
+        # off again in the finally block, whatever bulk_write does.
+        if not async_client_context.test_commands_enabled:
+            self.skipTest("Test commands must be enabled.")
+
+        # Use the rsSyncApplyStop failpoint to pause replication on a
+        # secondary which will cause a wtimeout error.
+        await self.secondary.admin.command("configureFailPoint", "rsSyncApplyStop", mode="alwaysOn")
+
+        try:
+            coll = self.coll.with_options(write_concern=WriteConcern(w=self.w, wtimeout=1))
+            return await coll.bulk_write(requests, ordered=ordered)
+        finally:
+            await self.secondary.admin.command("configureFailPoint", "rsSyncApplyStop", mode="off")
+
+    @async_client_context.require_replica_set
+    @async_client_context.require_secondaries_count(1)
+    async def test_write_concern_failure_ordered(self):
+        # Ensure we don't raise on wnote.
+        coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w))
+        result = await coll_ww.bulk_write([DeleteOne({"something": "that does no exist"})])
+        self.assertTrue(result.acknowledged)
+
+        requests: list[Any] = [InsertOne({"a": 1}), InsertOne({"a": 2})]
+        # Replication wtimeout is a 'soft' error.
+        # It shouldn't stop batch processing.
+        try:
+            await self.cause_wtimeout(requests, ordered=True)
+        except BulkWriteError as exc:
+            details = exc.details
+            self.assertEqual(exc.code, 65)
+        else:
+            self.fail("Error not raised")
+
+        # Both inserts are expected to be applied and no write errors
+        # reported; only write concern errors are expected (checked below).
+        self.assertEqualResponse(
+            {
+                "nMatched": 0,
+                "nModified": 0,
+                "nUpserted": 0,
+                "nInserted": 2,
+                "nRemoved": 0,
+                "upserted": [],
+                "writeErrors": [],
+            },
+            details,
+        )
+
+        # When talking to legacy servers there will be a
+        # write concern error for each operation.
+        self.assertTrue(len(details["writeConcernErrors"]) > 0)
+
+        failed = details["writeConcernErrors"][0]
+        self.assertEqual(64, failed["code"])
+        self.assertTrue(isinstance(failed["errmsg"], str))
+
+        # The unique index on "a" makes the repeated InsertOne({"a": 1}) below
+        # a duplicate key error; drop_index is registered as cleanup.
+        await self.coll.delete_many({})
+        await self.coll.create_index("a", unique=True)
+        self.addCleanup(self.coll.drop_index, [("a", 1)])
+
+        # Fail due to write concern support as well
+        # as duplicate key error on ordered batch.
+        requests = [
+            InsertOne({"a": 1}),
+            ReplaceOne({"a": 3}, {"b": 1}, upsert=True),
+            InsertOne({"a": 1}),
+            InsertOne({"a": 2}),
+        ]
+        try:
+            await self.cause_wtimeout(requests, ordered=True)
+        except BulkWriteError as exc:
+            details = exc.details
+            self.assertEqual(exc.code, 65)
+        else:
+            self.fail("Error not raised")
+
+        # Ordered batch: the expected response has one insert and one upsert,
+        # a write error at index 2, and nothing counted for the final insert.
+        self.assertEqualResponse(
+            {
+                "nMatched": 0,
+                "nModified": 0,
+                "nUpserted": 1,
+                "nInserted": 1,
+                "nRemoved": 0,
+                "upserted": [{"index": 1, "_id": "..."}],
+                "writeErrors": [
+                    {"index": 2, "code": 11000, "errmsg": "...", "op": {"_id": "...", "a": 1}}
+                ],
+            },
+            details,
+        )
+
+        self.assertTrue(len(details["writeConcernErrors"]) > 1)
+        failed = details["writeErrors"][0]
+        self.assertTrue("duplicate" in failed["errmsg"])
+
+    @async_client_context.require_replica_set
+    @async_client_context.require_secondaries_count(1)
+    async def test_write_concern_failure_unordered(self):
+        # Ensure we don't raise on wnote.
+        coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w))
+        result = await coll_ww.bulk_write(
+            [DeleteOne({"something": "that does no exist"})], ordered=False
+        )
+        self.assertTrue(result.acknowledged)
+
+        requests = [
+            InsertOne({"a": 1}),
+            UpdateOne({"a": 3}, {"$set": {"a": 3, "b": 1}}, upsert=True),
+            InsertOne({"a": 2}),
+        ]
+        # Replication wtimeout is a 'soft' error.
+        # It shouldn't stop batch processing.
+        try:
+            await self.cause_wtimeout(requests, ordered=False)
+        except BulkWriteError as exc:
+            details = exc.details
+            self.assertEqual(exc.code, 65)
+        else:
+            self.fail("Error not raised")
+
+        # All three requests are expected to take effect (2 inserts, 1 upsert)
+        # with no write errors.
+        self.assertEqual(2, details["nInserted"])
+        self.assertEqual(1, details["nUpserted"])
+        self.assertEqual(0, len(details["writeErrors"]))
+        # When talking to legacy servers there will be a
+        # write concern error for each operation.
+        self.assertTrue(len(details["writeConcernErrors"]) > 1)
+
+        # The unique index on "a" makes the repeated InsertOne({"a": 1}) below
+        # fail; drop_index is registered as cleanup.
+        await self.coll.delete_many({})
+        await self.coll.create_index("a", unique=True)
+        self.addCleanup(self.coll.drop_index, [("a", 1)])
+
+        # Fail due to write concern support as well
+        # as duplicate key error on unordered batch.
+        requests: list = [
+            InsertOne({"a": 1}),
+            UpdateOne({"a": 3}, {"$set": {"a": 3, "b": 1}}, upsert=True),
+            InsertOne({"a": 1}),
+            InsertOne({"a": 2}),
+        ]
+        try:
+            await self.cause_wtimeout(requests, ordered=False)
+        except BulkWriteError as exc:
+            details = exc.details
+            self.assertEqual(exc.code, 65)
+        else:
+            self.fail("Error not raised")
+
+        # Unordered batch: the insert after the failing one is still counted,
+        # so 2 inserts and 1 upsert are expected alongside 1 write error.
+        self.assertEqual(2, details["nInserted"])
+        self.assertEqual(1, details["nUpserted"])
+        self.assertEqual(1, len(details["writeErrors"]))
+        # When talking to legacy servers there will be a
+        # write concern error for each operation.
+ self.assertTrue(len(details["writeConcernErrors"]) > 1) + + failed = details["writeErrors"][0] + self.assertEqual(2, failed["index"]) + self.assertEqual(11000, failed["code"]) + self.assertTrue(isinstance(failed["errmsg"], str)) + self.assertEqual(1, failed["op"]["a"]) + + failed = details["writeConcernErrors"][0] + self.assertEqual(64, failed["code"]) + self.assertTrue(isinstance(failed["errmsg"], str)) + + upserts = details["upserted"] + self.assertEqual(1, len(upserts)) + self.assertEqual(1, upserts[0]["index"]) + self.assertTrue(upserts[0].get("_id")) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_bulk.py b/test/test_bulk.py index 663dfaf19..9069109cf 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -44,14 +44,16 @@ from pymongo.operations import * from pymongo.synchronous.collection import Collection from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class BulkTestBase(IntegrationTest): coll: Collection coll_w0: Collection @classmethod - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.coll = cls.db.test cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) @@ -135,7 +137,8 @@ class BulkTestBase(IntegrationTest): class TestBulk(BulkTestBase): def test_empty(self): - self.assertRaises(InvalidOperation, self.coll.bulk_write, []) + with self.assertRaises(InvalidOperation): + self.coll.bulk_write([]) def test_insert(self): expected = { @@ -180,15 +183,19 @@ class TestBulk(BulkTestBase): self._test_update_many([{"$set": {"foo": "bar"}}]) def test_array_filters_validation(self): - self.assertRaises(TypeError, UpdateMany, {}, {}, array_filters={}) - self.assertRaises(TypeError, UpdateOne, {}, {}, array_filters={}) + with self.assertRaises(TypeError): + UpdateMany({}, {}, array_filters={}) # type: ignore[arg-type] + with self.assertRaises(TypeError): + UpdateOne({}, {}, array_filters={}) # type: ignore[arg-type] def 
test_array_filters_unacknowledged(self): coll = self.coll_w0 update_one = UpdateOne({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) update_many = UpdateMany({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) - self.assertRaises(ConfigurationError, coll.bulk_write, [update_one]) - self.assertRaises(ConfigurationError, coll.bulk_write, [update_many]) + with self.assertRaises(ConfigurationError): + coll.bulk_write([update_one]) + with self.assertRaises(ConfigurationError): + coll.bulk_write([update_many]) def _test_update_one(self, update): expected = { @@ -790,8 +797,8 @@ class BulkAuthorizationTestBase(BulkTestBase): @classmethod @client_context.require_auth @client_context.require_no_api_version - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() def setUp(self): super().setUp() @@ -828,8 +835,16 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = self.coll_w0.bulk_write(requests) self.assertFalse(result.acknowledged) - wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents") - wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') + + def predicate(): + return self.coll.count_documents({}) == 2 + + wait_until(predicate, "insert 2 documents") + + def predicate(): + return self.coll.find_one({"_id": 1}) is None + + wait_until(predicate, 'removed {"_id": 1}') def test_no_results_ordered_failure(self): requests: list = [ @@ -843,7 +858,11 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = self.coll_w0.bulk_write(requests) self.assertFalse(result.acknowledged) - wait_until(lambda: self.coll.count_documents({}) == 3, "insert 3 documents") + + def predicate(): + return self.coll.count_documents({}) == 3 + + wait_until(predicate, "insert 3 documents") self.assertEqual({"_id": 1}, self.coll.find_one({"_id": 1})) def test_no_results_unordered_success(self): @@ -855,8 +874,16 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = 
self.coll_w0.bulk_write(requests, ordered=False) self.assertFalse(result.acknowledged) - wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents") - wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') + + def predicate(): + return self.coll.count_documents({}) == 2 + + wait_until(predicate, "insert 2 documents") + + def predicate(): + return self.coll.find_one({"_id": 1}) is None + + wait_until(predicate, 'removed {"_id": 1}') def test_no_results_unordered_failure(self): requests: list = [ @@ -870,8 +897,16 @@ class TestBulkUnacknowledged(BulkTestBase): ] result = self.coll_w0.bulk_write(requests, ordered=False) self.assertFalse(result.acknowledged) - wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents") - wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') + + def predicate(): + return self.coll.count_documents({}) == 2 + + wait_until(predicate, "insert 2 documents") + + def predicate(): + return self.coll.find_one({"_id": 1}) is None + + wait_until(predicate, 'removed {"_id": 1}') class TestBulkAuthorization(BulkAuthorizationTestBase): @@ -883,7 +918,8 @@ class TestBulkAuthorization(BulkAuthorizationTestBase): ) coll = cli.pymongo_test.test coll.find_one() - self.assertRaises(OperationFailure, coll.bulk_write, [InsertOne({"x": 1})]) + with self.assertRaises(OperationFailure): + coll.bulk_write([InsertOne({"x": 1})]) def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised @@ -899,7 +935,8 @@ class TestBulkAuthorization(BulkAuthorizationTestBase): DeleteMany({}), # Prohibited. InsertOne({"x": 3}), # Never attempted. 
] - self.assertRaises(OperationFailure, coll.bulk_write, requests) + with self.assertRaises(OperationFailure): + coll.bulk_write(requests) # type: ignore[arg-type] self.assertEqual({1, 2}, set(self.coll.distinct("x"))) @@ -908,18 +945,18 @@ class TestBulkWriteConcern(BulkTestBase): secondary: MongoClient @classmethod - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.w = client_context.w cls.secondary = None if cls.w is not None and cls.w > 1: - for member in client_context.hello["hosts"]: - if member != client_context.hello["primary"]: + for member in (client_context.hello)["hosts"]: + if member != (client_context.hello)["primary"]: cls.secondary = single_client(*partition_node(member)) break @classmethod - def tearDownClass(cls): + def async_tearDownClass(cls): if cls.secondary: cls.secondary.close() diff --git a/tools/synchro.py b/tools/synchro.py index b8fc9f33c..f4019f0bb 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -46,6 +46,8 @@ replacements = { "async_sendall": "sendall", "asynchronous": "synchronous", "Asynchronous": "Synchronous", + "AsyncBulkTestBase": "BulkTestBase", + "AsyncBulkAuthorizationTestBase": "BulkAuthorizationTestBase", "anext": "next", "aiter": "iter", "_ALock": "_Lock", @@ -157,6 +159,7 @@ converted_tests = [ "conftest.py", "pymongo_mocks.py", "utils_spec_runner.py", + "test_bulk.py", "test_client.py", "test_client_bulk_write.py", "test_collection.py", @@ -299,6 +302,8 @@ def translate_docstrings(lines: list[str]) -> list[str]: lines[i] = lines[i].replace(k, replacements[k]) if "Sync" in lines[i] and "Synchronous" not in lines[i] and replacements[k] in lines[i]: lines[i] = lines[i].replace("Sync", "") + if "rsApplyStop" in lines[i]: + lines[i] = lines[i].replace("rsApplyStop", "rsSyncApplyStop") if "async for" in lines[i] or "async with" in lines[i] or "async def" in lines[i]: lines[i] = lines[i].replace("async ", "") if "await " in lines[i] and "tailable" not in lines[i]: