Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-09-04 15:29:57 -05:00
commit de3fed95ee
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
29 changed files with 2597 additions and 158 deletions

View File

@ -209,4 +209,4 @@ jobs:
ls
which python
pip install -e ".[test]"
PYMONGO_MUST_CONNECT=1 pytest -v test/test_client_context.py
PYMONGO_MUST_CONNECT=1 pytest -v -k client_context

View File

@ -248,3 +248,10 @@ you are attempting to validate new spec tests in PyMongo.
## Making a Release
Follow the [Python Driver Release Process Wiki](https://wiki.corp.mongodb.com/display/DRIVERS/Python+Driver+Release+Process).
## Converting a test to async
The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a
partially-converted asynchronous version of the same name to the `test/asynchronous` directory.
Use this generated file as a starting point for the completed conversion.
The script is used like so: `python tools/convert_test_to_async.py [test_file.py]`

View File

@ -1176,24 +1176,6 @@ class AsyncGridIn:
raise AttributeError("GridIn object has no attribute '%s'" % name)
def __setattr__(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
if _IS_SYNC:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"AsyncGridIn does not support __setattr__. Use AsyncGridIn.set() instead"
)
async def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
@ -1204,9 +1186,17 @@ class AsyncGridIn:
# them now.
self._file[name] = value
if self._closed:
await self._coll.files.update_one(
{"_id": self._file["_id"]}, {"$set": {name: value}}
)
if _IS_SYNC:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead"
)
async def set(self, name: str, value: Any) -> None:
self._file[name] = value
if self._closed:
await self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
async def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
@ -1400,7 +1390,11 @@ class AsyncGridIn:
return False
class AsyncGridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any
class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore
"""Class to read data out of GridFS."""
def __init__(
@ -1460,6 +1454,8 @@ class AsyncGridOut(io.IOBase):
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False
_id: Any = _a_grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _a_grid_out_property("filename", "Name of this file.")
@ -1486,16 +1482,43 @@ class AsyncGridOut(io.IOBase):
_file: Any
_chunk_iter: Any
async def __anext__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool
def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError(
"AsyncGridOut does not support synchronous iteration. Use `async for` instead"
)
async def __anext__(self) -> bytes:
line = await self.readline()
if line:
return line
raise StopAsyncIteration()
async def to_list(self) -> list[bytes]:
return [x async for x in self] # noqa: C416, RUF100
async def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return await self._read_size_or_line(size=size, line=True)
async def readlines(self, size: int = -1) -> list[bytes]:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
await self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = await self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break
return lines
async def open(self) -> None:
if not self._file:
@ -1616,18 +1639,11 @@ class AsyncGridOut(io.IOBase):
"""
return await self._read_size_or_line(size=size)
async def readline(self, size: int = -1) -> bytes: # type: ignore[override]
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return await self._read_size_or_line(size=size, line=True)
def tell(self) -> int:
"""Return the current position of this file."""
return self._position
async def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
async def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.
:param pos: the position (or offset if using relative
@ -1690,12 +1706,15 @@ class AsyncGridOut(io.IOBase):
"""
return self
async def close(self) -> None: # type: ignore[override]
async def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
await self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True
def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")

View File

@ -38,7 +38,15 @@ def _a_grid_in_property(
) -> Any:
"""Create a GridIn property."""
warn_str = ""
if docstring.startswith("DEPRECATED,"):
warn_str = (
f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
)
def getter(self: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if closed_only and not self._closed:
raise AttributeError("can only get %r on a closed file" % field_name)
# Protect against PHP-237
@ -46,6 +54,15 @@ def _a_grid_in_property(
return self._file.get(field_name, 0)
return self._file.get(field_name, None)
def setter(self: Any, value: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if self._closed:
raise InvalidOperation(
"AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead"
)
self._file[field_name] = value
if read_only:
docstring += "\n\nThis attribute is read-only."
elif closed_only:
@ -56,6 +73,8 @@ def _a_grid_in_property(
"has been called.",
)
if not read_only and not closed_only:
return property(getter, setter, doc=docstring)
return property(getter, doc=docstring)

View File

@ -1166,24 +1166,6 @@ class GridIn:
raise AttributeError("GridIn object has no attribute '%s'" % name)
def __setattr__(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
if _IS_SYNC:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"GridIn does not support __setattr__. Use GridIn.set() instead"
)
def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
@ -1194,7 +1176,17 @@ class GridIn:
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
if _IS_SYNC:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"GridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use GridIn.set() instead"
)
def set(self, name: str, value: Any) -> None:
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
@ -1388,7 +1380,11 @@ class GridIn:
return False
class GridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any
class GridOut(GRIDOUT_BASE_CLASS): # type: ignore
"""Class to read data out of GridFS."""
def __init__(
@ -1448,6 +1444,8 @@ class GridOut(io.IOBase):
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False
_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _grid_out_property("filename", "Name of this file.")
@ -1474,14 +1472,43 @@ class GridOut(io.IOBase):
_file: Any
_chunk_iter: Any
def __next__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool
def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError("GridOut does not support synchronous iteration. Use `for` instead")
def __next__(self) -> bytes:
line = self.readline()
if line:
return line
raise StopIteration()
def to_list(self) -> list[bytes]:
return [x for x in self] # noqa: C416, RUF100
def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)
def readlines(self, size: int = -1) -> list[bytes]:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break
return lines
def open(self) -> None:
if not self._file:
@ -1602,18 +1629,11 @@ class GridOut(io.IOBase):
"""
return self._read_size_or_line(size=size)
def readline(self, size: int = -1) -> bytes: # type: ignore[override]
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)
def tell(self) -> int:
"""Return the current position of this file."""
return self._position
def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.
:param pos: the position (or offset if using relative
@ -1676,12 +1696,15 @@ class GridOut(io.IOBase):
"""
return self
def close(self) -> None: # type: ignore[override]
def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True
def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")

View File

@ -228,6 +228,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
from pymongo.asynchronous.database import AsyncDatabase
if not isinstance(database, AsyncDatabase):
raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given")
if not name or ".." in name:
raise InvalidName("collection names cannot be empty")

View File

@ -119,9 +119,14 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
read_concern or client.read_concern,
)
from pymongo.asynchronous.mongo_client import AsyncMongoClient
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
if not isinstance(client, AsyncMongoClient):
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
if name != "$external":
_check_name(name)

View File

@ -194,9 +194,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
async def collection_info(
self, database: AsyncDatabase[Mapping[str, Any]], filter: bytes
) -> Optional[bytes]:
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.
The returned collection info is passed to libmongocrypt which reads
@ -598,6 +596,9 @@ class AsyncClientEncryption(Generic[_DocumentType]):
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
if not isinstance(key_vault_client, AsyncMongoClient):
raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}")
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
@ -683,6 +684,11 @@ class AsyncClientEncryption(Generic[_DocumentType]):
https://mongodb.com/docs/manual/reference/command/create
"""
if not isinstance(database, AsyncDatabase):
raise TypeError(
f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given"
)
encrypted_fields = deepcopy(encrypted_fields)
for i, field in enumerate(encrypted_fields["fields"]):
if isinstance(field, dict) and field.get("keyId") is None:

View File

@ -70,8 +70,13 @@ def _handle_reauth(func: F) -> F:
if sys.version_info >= (3, 10):
anext = builtins.anext
aiter = builtins.aiter
else:
async def anext(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return await cls.__anext__()
def aiter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return cls.__aiter__()

View File

@ -2419,6 +2419,9 @@ class _MongoClientErrorHandler:
def __init__(
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
):
if not isinstance(client, AsyncMongoClient):
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
self.client = client
self.server_address = server.description.address
self.session = session

View File

@ -475,7 +475,7 @@ class Topology:
if server:
await server.pool.ready()
suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description
suppress_event = sd_old == server_description
if self._publish_server and not suppress_event:
assert self._events is not None
self._events.put(
@ -497,7 +497,7 @@ class Topology:
(td_old, self._description, self._topology_id),
)
)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event:
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
@ -521,7 +521,7 @@ class Topology:
if server:
await server.pool.reset(interrupt_connections=interrupt_connections)
# Wake waiters in select_servers().
# Wake anything waiting in select_servers().
self._condition.notify_all()
async def on_change(

View File

@ -231,6 +231,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
from pymongo.synchronous.database import Database
if not isinstance(database, Database):
raise TypeError(f"Collection requires a Database but {type(database)} given")
if not name or ".." in name:
raise InvalidName("collection names cannot be empty")

View File

@ -119,9 +119,14 @@ class Database(common.BaseObject, Generic[_DocumentType]):
read_concern or client.read_concern,
)
from pymongo.synchronous.mongo_client import MongoClient
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
if not isinstance(client, MongoClient):
raise TypeError(f"MongoClient required but given {type(client)}")
if name != "$external":
_check_name(name)

View File

@ -194,9 +194,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
def collection_info(
self, database: Database[Mapping[str, Any]], filter: bytes
) -> Optional[bytes]:
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.
The returned collection info is passed to libmongocrypt which reads
@ -596,6 +594,9 @@ class ClientEncryption(Generic[_DocumentType]):
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
if not isinstance(key_vault_client, MongoClient):
raise TypeError(f"MongoClient required but given {type(key_vault_client)}")
self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
@ -681,6 +682,11 @@ class ClientEncryption(Generic[_DocumentType]):
https://mongodb.com/docs/manual/reference/command/create
"""
if not isinstance(database, Database):
raise TypeError(
f"create_encrypted_collection() requires a Database but {type(database)} given"
)
encrypted_fields = deepcopy(encrypted_fields)
for i, field in enumerate(encrypted_fields["fields"]):
if isinstance(field, dict) and field.get("keyId") is None:

View File

@ -70,8 +70,13 @@ def _handle_reauth(func: F) -> F:
if sys.version_info >= (3, 10):
next = builtins.next
iter = builtins.iter
else:
def next(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__next__()
def iter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__iter__()

View File

@ -2406,6 +2406,9 @@ class _MongoClientErrorHandler:
)
def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
if not isinstance(client, MongoClient):
raise TypeError(f"MongoClient required but given {type(client)}")
self.client = client
self.server_address = server.description.address
self.session = session

View File

@ -475,7 +475,7 @@ class Topology:
if server:
server.pool.ready()
suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description
suppress_event = sd_old == server_description
if self._publish_server and not suppress_event:
assert self._events is not None
self._events.put(
@ -497,7 +497,7 @@ class Topology:
(td_old, self._description, self._topology_id),
)
)
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event:
_debug_log(
_SDAM_LOGGER,
topologyId=self._topology_id,
@ -521,7 +521,7 @@ class Topology:
if server:
server.pool.reset(interrupt_connections=interrupt_connections)
# Wake waiters in select_servers().
# Wake anything waiting in select_servers().
self._condition.notify_all()
def on_change(

View File

@ -569,7 +569,10 @@ class ClientContext:
def sec_count():
return 0 if not self.client else len(self.client.secondaries)
return self._require(lambda: sec_count() >= count, "Not enough secondaries available")
def check():
return sec_count() >= count
return self._require(check, "Not enough secondaries available")
@property
def supports_secondary_read_pref(self):
@ -947,11 +950,11 @@ class UnitTest(PyMongoTestCase):
@classmethod
def _setup_class(cls):
cls._setup_class()
pass
@classmethod
def _tearDown_class(cls):
cls._tearDown_class()
pass
class IntegrationTest(PyMongoTestCase):

View File

@ -571,7 +571,10 @@ class AsyncClientContext:
async def sec_count():
return 0 if not self.client else len(await self.client.secondaries)
return self._require(lambda: sec_count() >= count, "Not enough secondaries available")
async def check():
return await sec_count() >= count
return self._require(check, "Not enough secondaries available")
@property
async def supports_secondary_read_pref(self):
@ -949,11 +952,11 @@ class AsyncUnitTest(AsyncPyMongoTestCase):
@classmethod
async def _setup_class(cls):
await cls._setup_class()
pass
@classmethod
async def _tearDown_class(cls):
await cls._tearDown_class()
pass
class AsyncIntegrationTest(AsyncPyMongoTestCase):

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,66 @@
# Copyright 2018-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import sys
sys.path[0:0] = [""]
from test.asynchronous import AsyncUnitTest, SkipTest, async_client_context, unittest
_IS_SYNC = False
class TestAsyncClientContext(AsyncUnitTest):
def test_must_connect(self):
if "PYMONGO_MUST_CONNECT" not in os.environ:
raise SkipTest("PYMONGO_MUST_CONNECT is not set")
self.assertTrue(
async_client_context.connected,
"client context must be connected when "
"PYMONGO_MUST_CONNECT is set. Failed attempts:\n{}".format(
async_client_context.connection_attempt_info()
),
)
def test_serverless(self):
if "TEST_SERVERLESS" not in os.environ:
raise SkipTest("TEST_SERVERLESS is not set")
self.assertTrue(
async_client_context.connected and async_client_context.serverless,
"client context must be connected to serverless when "
f"TEST_SERVERLESS is set. Failed attempts:\n{async_client_context.connection_attempt_info()}",
)
def test_enableTestCommands_is_disabled(self):
if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ:
raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set")
self.assertFalse(
async_client_context.test_commands_enabled,
"enableTestCommands must be disabled when PYMONGO_DISABLE_TEST_COMMANDS is set.",
)
def test_setdefaultencoding_worked(self):
if "SETDEFAULTENCODING" not in os.environ:
raise SkipTest("SETDEFAULTENCODING is not set")
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,871 @@
#
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the grid_file module."""
from __future__ import annotations
import datetime
import io
import sys
import zipfile
from io import BytesIO
from test.asynchronous import AsyncIntegrationTest, AsyncUnitTest, async_client_context
from pymongo.asynchronous.database import AsyncDatabase
sys.path[0:0] = [""]
from test import IntegrationTest, qcheck, unittest
from test.utils import EventListener, async_rs_or_single_client, rs_or_single_client
from bson.objectid import ObjectId
from gridfs import GridFS
from gridfs.asynchronous.grid_file import (
_SEEK_CUR,
_SEEK_END,
DEFAULT_CHUNK_SIZE,
AsyncGridFS,
AsyncGridIn,
AsyncGridOut,
AsyncGridOutCursor,
)
from gridfs.errors import NoFile
from pymongo import AsyncMongoClient
from pymongo.asynchronous.helpers import aiter, anext
from pymongo.errors import ConfigurationError, InvalidOperation, ServerSelectionTimeoutError
from pymongo.message import _CursorAddress
_IS_SYNC = False
class AsyncTestGridFileNoConnect(AsyncUnitTest):
"""Test GridFile features on a client that does not connect."""
db: AsyncDatabase
@classmethod
def setUpClass(cls):
cls.db = AsyncMongoClient(connect=False).pymongo_test
def test_grid_in_custom_opts(self):
self.assertRaises(TypeError, AsyncGridIn, "foo")
a = AsyncGridIn(
self.db.fs,
_id=5,
filename="my_file",
contentType="text/html",
chunkSize=1000,
aliases=["foo"],
metadata={"foo": 1, "bar": 2},
bar=3,
baz="hello",
)
self.assertEqual(5, a._id)
self.assertEqual("my_file", a.filename)
self.assertEqual("my_file", a.name)
self.assertEqual("text/html", a.content_type)
self.assertEqual(1000, a.chunk_size)
self.assertEqual(["foo"], a.aliases)
self.assertEqual({"foo": 1, "bar": 2}, a.metadata)
self.assertEqual(3, a.bar)
self.assertEqual("hello", a.baz)
self.assertRaises(AttributeError, getattr, a, "mike")
b = AsyncGridIn(self.db.fs, content_type="text/html", chunk_size=1000, baz=100)
self.assertEqual("text/html", b.content_type)
self.assertEqual(1000, b.chunk_size)
self.assertEqual(100, b.baz)
class AsyncTestGridFile(AsyncIntegrationTest):
async def asyncSetUp(self):
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
async def test_basic(self):
f = AsyncGridIn(self.db.fs, filename="test")
await f.write(b"hello world")
await f.close()
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"hello world", await g.read())
# make sure it's still there...
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"hello world", await g.read())
f = AsyncGridIn(self.db.fs, filename="test")
await f.close()
self.assertEqual(2, await self.db.fs.files.count_documents({}))
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"", await g.read())
# test that reading 0 returns proper type
self.assertEqual(b"", await g.read(0))
async def test_md5(self):
f = AsyncGridIn(self.db.fs)
await f.write(b"hello world\n")
await f.close()
self.assertEqual(None, f.md5)
async def test_alternate_collection(self):
await self.db.alt.files.delete_many({})
await self.db.alt.chunks.delete_many({})
f = AsyncGridIn(self.db.alt)
await f.write(b"hello world")
await f.close()
self.assertEqual(1, await self.db.alt.files.count_documents({}))
self.assertEqual(1, await self.db.alt.chunks.count_documents({}))
g = AsyncGridOut(self.db.alt, f._id)
self.assertEqual(b"hello world", await g.read())
async def test_grid_in_default_opts(self):
self.assertRaises(TypeError, AsyncGridIn, "foo")
a = AsyncGridIn(self.db.fs)
self.assertTrue(isinstance(a._id, ObjectId))
self.assertRaises(AttributeError, setattr, a, "_id", 5)
self.assertEqual(None, a.filename)
self.assertEqual(None, a.name)
a.filename = "my_file"
self.assertEqual("my_file", a.filename)
self.assertEqual("my_file", a.name)
self.assertEqual(None, a.content_type)
a.content_type = "text/html"
self.assertEqual("text/html", a.content_type)
self.assertRaises(AttributeError, getattr, a, "length")
self.assertRaises(AttributeError, setattr, a, "length", 5)
self.assertEqual(255 * 1024, a.chunk_size)
self.assertRaises(AttributeError, setattr, a, "chunk_size", 5)
self.assertRaises(AttributeError, getattr, a, "upload_date")
self.assertRaises(AttributeError, setattr, a, "upload_date", 5)
self.assertRaises(AttributeError, getattr, a, "aliases")
a.aliases = ["foo"]
self.assertEqual(["foo"], a.aliases)
self.assertRaises(AttributeError, getattr, a, "metadata")
a.metadata = {"foo": 1}
self.assertEqual({"foo": 1}, a.metadata)
self.assertRaises(AttributeError, setattr, a, "md5", 5)
await a.close()
if _IS_SYNC:
a.forty_two = 42
else:
self.assertRaises(AttributeError, setattr, a, "forty_two", 42)
await a.set("forty_two", 42)
self.assertEqual(42, a.forty_two)
self.assertTrue(isinstance(a._id, ObjectId))
self.assertRaises(AttributeError, setattr, a, "_id", 5)
self.assertEqual("my_file", a.filename)
self.assertEqual("my_file", a.name)
self.assertEqual("text/html", a.content_type)
self.assertEqual(0, a.length)
self.assertRaises(AttributeError, setattr, a, "length", 5)
self.assertEqual(255 * 1024, a.chunk_size)
self.assertRaises(AttributeError, setattr, a, "chunk_size", 5)
self.assertTrue(isinstance(a.upload_date, datetime.datetime))
self.assertRaises(AttributeError, setattr, a, "upload_date", 5)
self.assertEqual(["foo"], a.aliases)
self.assertEqual({"foo": 1}, a.metadata)
self.assertEqual(None, a.md5)
self.assertRaises(AttributeError, setattr, a, "md5", 5)
# Make sure custom attributes that were set both before and after
# a.close() are reflected in b. PYTHON-411.
b = await AsyncGridFS(self.db).get_last_version(filename=a.filename)
self.assertEqual(a.metadata, b.metadata)
self.assertEqual(a.aliases, b.aliases)
self.assertEqual(a.forty_two, b.forty_two)
async def test_grid_out_default_opts(self):
self.assertRaises(TypeError, AsyncGridOut, "foo")
gout = AsyncGridOut(self.db.fs, 5)
with self.assertRaises(NoFile):
if not _IS_SYNC:
await gout.open()
gout.name
a = AsyncGridIn(self.db.fs)
await a.close()
b = AsyncGridOut(self.db.fs, a._id)
if not _IS_SYNC:
await b.open()
self.assertEqual(a._id, b._id)
self.assertEqual(0, b.length)
self.assertEqual(None, b.content_type)
self.assertEqual(None, b.name)
self.assertEqual(None, b.filename)
self.assertEqual(255 * 1024, b.chunk_size)
self.assertTrue(isinstance(b.upload_date, datetime.datetime))
self.assertEqual(None, b.aliases)
self.assertEqual(None, b.metadata)
self.assertEqual(None, b.md5)
for attr in [
"_id",
"name",
"content_type",
"length",
"chunk_size",
"upload_date",
"aliases",
"metadata",
"md5",
]:
self.assertRaises(AttributeError, setattr, b, attr, 5)
async def test_grid_out_cursor_options(self):
self.assertRaises(
TypeError, AsyncGridOutCursor.__init__, self.db.fs, {}, projection={"filename": 1}
)
cursor = AsyncGridOutCursor(self.db.fs, {})
cursor_clone = cursor.clone()
cursor_dict = cursor.__dict__.copy()
cursor_dict.pop("_session")
cursor_clone_dict = cursor_clone.__dict__.copy()
cursor_clone_dict.pop("_session")
self.assertDictEqual(cursor_dict, cursor_clone_dict)
self.assertRaises(NotImplementedError, cursor.add_option, 0)
self.assertRaises(NotImplementedError, cursor.remove_option, 0)
async def test_grid_out_custom_opts(self):
one = AsyncGridIn(
self.db.fs,
_id=5,
filename="my_file",
contentType="text/html",
chunkSize=1000,
aliases=["foo"],
metadata={"foo": 1, "bar": 2},
bar=3,
baz="hello",
)
await one.write(b"hello world")
await one.close()
two = AsyncGridOut(self.db.fs, 5)
if not _IS_SYNC:
await two.open()
self.assertEqual("my_file", two.name)
self.assertEqual("my_file", two.filename)
self.assertEqual(5, two._id)
self.assertEqual(11, two.length)
self.assertEqual("text/html", two.content_type)
self.assertEqual(1000, two.chunk_size)
self.assertTrue(isinstance(two.upload_date, datetime.datetime))
self.assertEqual(["foo"], two.aliases)
self.assertEqual({"foo": 1, "bar": 2}, two.metadata)
self.assertEqual(3, two.bar)
self.assertEqual(None, two.md5)
for attr in [
"_id",
"name",
"content_type",
"length",
"chunk_size",
"upload_date",
"aliases",
"metadata",
"md5",
]:
self.assertRaises(AttributeError, setattr, two, attr, 5)
async def test_grid_out_file_document(self):
one = AsyncGridIn(self.db.fs)
await one.write(b"foo bar")
await one.close()
two = AsyncGridOut(self.db.fs, file_document=await self.db.fs.files.find_one())
self.assertEqual(b"foo bar", await two.read())
three = AsyncGridOut(self.db.fs, 5, file_document=await self.db.fs.files.find_one())
self.assertEqual(b"foo bar", await three.read())
four = AsyncGridOut(self.db.fs, file_document={})
with self.assertRaises(NoFile):
if not _IS_SYNC:
await four.open()
four.name
async def test_write_file_like(self):
one = AsyncGridIn(self.db.fs)
await one.write(b"hello world")
await one.close()
two = AsyncGridOut(self.db.fs, one._id)
three = AsyncGridIn(self.db.fs)
await three.write(two)
await three.close()
four = AsyncGridOut(self.db.fs, three._id)
self.assertEqual(b"hello world", await four.read())
five = AsyncGridIn(self.db.fs, chunk_size=2)
await five.write(b"hello")
buffer = BytesIO(b" world")
await five.write(buffer)
await five.write(b" and mongodb")
await five.close()
self.assertEqual(
b"hello world and mongodb", await AsyncGridOut(self.db.fs, five._id).read()
)
async def test_write_lines(self):
a = AsyncGridIn(self.db.fs)
await a.writelines([b"hello ", b"world"])
await a.close()
self.assertEqual(b"hello world", await AsyncGridOut(self.db.fs, a._id).read())
async def test_close(self):
f = AsyncGridIn(self.db.fs)
await f.close()
with self.assertRaises(ValueError):
await f.write("test")
await f.close()
async def test_closed(self):
f = AsyncGridIn(self.db.fs, chunkSize=5)
await f.write(b"Hello world.\nHow are you?")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
if not _IS_SYNC:
await g.open()
self.assertFalse(g.closed)
await g.read(1)
self.assertFalse(g.closed)
await g.read(100)
self.assertFalse(g.closed)
await g.close()
self.assertTrue(g.closed)
async def test_multi_chunk_file(self):
random_string = b"a" * (DEFAULT_CHUNK_SIZE + 1000)
f = AsyncGridIn(self.db.fs)
await f.write(random_string)
await f.close()
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(2, await self.db.fs.chunks.count_documents({}))
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(random_string, await g.read())
# TODO: https://jira.mongodb.org/browse/PYTHON-4708
@async_client_context.require_sync
async def test_small_chunks(self):
self.files = 0
self.chunks = 0
async def helper(data):
f = AsyncGridIn(self.db.fs, chunkSize=1)
await f.write(data)
await f.close()
self.files += 1
self.chunks += len(data)
self.assertEqual(self.files, await self.db.fs.files.count_documents({}))
self.assertEqual(self.chunks, await self.db.fs.chunks.count_documents({}))
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(data, await g.read())
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(data, await g.read(10) + await g.read(10))
return True
qcheck.check_unittest(self, helper, qcheck.gen_string(qcheck.gen_range(0, 20)))
async def test_seek(self):
f = AsyncGridIn(self.db.fs, chunkSize=3)
await f.write(b"hello world")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"hello world", await g.read())
await g.seek(0)
self.assertEqual(b"hello world", await g.read())
await g.seek(1)
self.assertEqual(b"ello world", await g.read())
with self.assertRaises(IOError):
await g.seek(-1)
await g.seek(-3, _SEEK_END)
self.assertEqual(b"rld", await g.read())
await g.seek(0, _SEEK_END)
self.assertEqual(b"", await g.read())
with self.assertRaises(IOError):
await g.seek(-100, _SEEK_END)
await g.seek(3)
await g.seek(3, _SEEK_CUR)
self.assertEqual(b"world", await g.read())
with self.assertRaises(IOError):
await g.seek(-100, _SEEK_CUR)
async def test_tell(self):
f = AsyncGridIn(self.db.fs, chunkSize=3)
await f.write(b"hello world")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(0, g.tell())
await g.read(0)
self.assertEqual(0, g.tell())
await g.read(1)
self.assertEqual(1, g.tell())
await g.read(2)
self.assertEqual(3, g.tell())
await g.read()
self.assertEqual(g.length, g.tell())
async def test_multiple_reads(self):
f = AsyncGridIn(self.db.fs, chunkSize=3)
await f.write(b"hello world")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"he", await g.read(2))
self.assertEqual(b"ll", await g.read(2))
self.assertEqual(b"o ", await g.read(2))
self.assertEqual(b"wo", await g.read(2))
self.assertEqual(b"rl", await g.read(2))
self.assertEqual(b"d", await g.read(2))
self.assertEqual(b"", await g.read(2))
async def test_readline(self):
f = AsyncGridIn(self.db.fs, chunkSize=5)
await f.write(
b"""Hello world,
How are you?
Hope all is well.
Bye"""
)
await f.close()
# Try read(), then readline().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"H", await g.read(1))
self.assertEqual(b"ello world,\n", await g.readline())
self.assertEqual(b"How a", await g.readline(5))
self.assertEqual(b"", await g.readline(0))
self.assertEqual(b"re you?\n", await g.readline())
self.assertEqual(b"Hope all is well.\n", await g.readline(1000))
self.assertEqual(b"Bye", await g.readline())
self.assertEqual(b"", await g.readline())
# Try readline() first, then read().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"He", await g.readline(2))
self.assertEqual(b"l", await g.read(1))
self.assertEqual(b"lo", await g.readline(2))
self.assertEqual(b" world,\n", await g.readline())
# Only readline().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"H", await g.readline(1))
self.assertEqual(b"e", await g.readline(1))
self.assertEqual(b"llo world,\n", await g.readline())
async def test_readlines(self):
f = AsyncGridIn(self.db.fs, chunkSize=5)
await f.write(
b"""Hello world,
How are you?
Hope all is well.
Bye"""
)
await f.close()
# Try read(), then readlines().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"He", await g.read(2))
self.assertEqual([b"llo world,\n", b"How are you?\n"], await g.readlines(11))
self.assertEqual([b"Hope all is well.\n", b"Bye"], await g.readlines())
self.assertEqual([], await g.readlines())
# Try readline(), then readlines().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"Hello world,\n", await g.readline())
self.assertEqual([b"How are you?\n", b"Hope all is well.\n"], await g.readlines(13))
self.assertEqual(b"Bye", await g.readline())
self.assertEqual([], await g.readlines())
# Only readlines().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
await g.readlines(),
)
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
await g.readlines(0),
)
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual([b"Hello world,\n"], await g.readlines(1))
self.assertEqual([b"How are you?\n"], await g.readlines(12))
self.assertEqual([b"Hope all is well.\n", b"Bye"], await g.readlines(18))
# Try readlines() first, then read().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual([b"Hello world,\n"], await g.readlines(1))
self.assertEqual(b"H", await g.read(1))
self.assertEqual([b"ow are you?\n", b"Hope all is well.\n"], await g.readlines(29))
self.assertEqual([b"Bye"], await g.readlines(1))
# Try readlines() first, then readline().
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual([b"Hello world,\n"], await g.readlines(1))
self.assertEqual(b"How are you?\n", await g.readline())
self.assertEqual([b"Hope all is well.\n"], await g.readlines(17))
self.assertEqual(b"Bye", await g.readline())
async def test_iterator(self):
f = AsyncGridIn(self.db.fs)
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
if _IS_SYNC:
self.assertEqual([], list(g))
else:
self.assertEqual([], await g.to_list())
f = AsyncGridIn(self.db.fs)
await f.write(b"hello world\nhere are\nsome lines.")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
if _IS_SYNC:
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
else:
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], await g.to_list())
self.assertEqual(b"", await g.read(5))
if _IS_SYNC:
self.assertEqual([], list(g))
else:
self.assertEqual([], await g.to_list())
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"hello world\n", await anext(aiter(g)))
self.assertEqual(b"here", await g.read(4))
self.assertEqual(b" are\n", await anext(aiter(g)))
self.assertEqual(b"some lines", await g.read(10))
self.assertEqual(b".", await anext(aiter(g)))
with self.assertRaises(StopAsyncIteration):
await aiter(g).__anext__()
f = AsyncGridIn(self.db.fs, chunk_size=2)
await f.write(b"hello world")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
if _IS_SYNC:
self.assertEqual([b"hello world"], list(g))
else:
self.assertEqual([b"hello world"], await g.to_list())
async def test_read_unaligned_buffer_size(self):
in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk."
f = AsyncGridIn(self.db.fs, chunkSize=16)
await f.write(in_data)
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
out_data = b""
while 1:
s = await g.read(13)
if not s:
break
out_data += s
self.assertEqual(in_data, out_data)
async def test_readchunk(self):
in_data = b"a" * 10
f = AsyncGridIn(self.db.fs, chunkSize=3)
await f.write(in_data)
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(3, len(await g.readchunk()))
self.assertEqual(2, len(await g.read(2)))
self.assertEqual(1, len(await g.readchunk()))
self.assertEqual(3, len(await g.read(3)))
self.assertEqual(1, len(await g.readchunk()))
self.assertEqual(0, len(await g.readchunk()))
async def test_write_unicode(self):
f = AsyncGridIn(self.db.fs)
with self.assertRaises(TypeError):
await f.write("foo")
f = AsyncGridIn(self.db.fs, encoding="utf-8")
await f.write("foo")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual(b"foo", await g.read())
f = AsyncGridIn(self.db.fs, encoding="iso-8859-1")
await f.write("")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
self.assertEqual("".encode("iso-8859-1"), await g.read())
async def test_set_after_close(self):
f = AsyncGridIn(self.db.fs, _id="foo", bar="baz")
self.assertEqual("foo", f._id)
self.assertEqual("baz", f.bar)
self.assertRaises(AttributeError, getattr, f, "baz")
self.assertRaises(AttributeError, getattr, f, "uploadDate")
self.assertRaises(AttributeError, setattr, f, "_id", 5)
if _IS_SYNC:
f.bar = "foo"
f.baz = 5
else:
await f.set("bar", "foo")
await f.set("baz", 5)
self.assertEqual("foo", f._id)
self.assertEqual("foo", f.bar)
self.assertEqual(5, f.baz)
self.assertRaises(AttributeError, getattr, f, "uploadDate")
await f.close()
self.assertEqual("foo", f._id)
self.assertEqual("foo", f.bar)
self.assertEqual(5, f.baz)
self.assertTrue(f.uploadDate)
self.assertRaises(AttributeError, setattr, f, "_id", 5)
if _IS_SYNC:
f.bar = "a"
f.baz = "b"
else:
await f.set("bar", "a")
await f.set("baz", "b")
self.assertRaises(AttributeError, setattr, f, "upload_date", 5)
g = AsyncGridOut(self.db.fs, f._id)
if not _IS_SYNC:
await g.open()
self.assertEqual("a", g.bar)
self.assertEqual("b", g.baz)
# Versions 2.0.1 and older saved a _closed field for some reason.
self.assertRaises(AttributeError, getattr, g, "_closed")
async def test_context_manager(self):
contents = b"Imagine this is some important data..."
async with AsyncGridIn(self.db.fs, filename="important") as infile:
await infile.write(contents)
async with AsyncGridOut(self.db.fs, infile._id) as outfile:
self.assertEqual(contents, await outfile.read())
async def test_exception_file_non_existence(self):
contents = b"Imagine this is some important data..."
with self.assertRaises(ConnectionError):
async with AsyncGridIn(self.db.fs, filename="important") as infile:
await infile.write(contents)
raise ConnectionError("Test exception")
# Expectation: File chunks are written, entry in files doesn't appear.
self.assertEqual(
await self.db.fs.chunks.count_documents({"files_id": infile._id}), infile._chunk_number
)
self.assertIsNone(await self.db.fs.files.find_one({"_id": infile._id}))
self.assertTrue(infile.closed)
async def test_prechunked_string(self):
async def write_me(s, chunk_size):
buf = BytesIO(s)
infile = AsyncGridIn(self.db.fs)
while True:
to_write = buf.read(chunk_size)
if to_write == b"":
break
await infile.write(to_write)
await infile.close()
buf.close()
outfile = AsyncGridOut(self.db.fs, infile._id)
data = await outfile.read()
self.assertEqual(s, data)
s = b"x" * DEFAULT_CHUNK_SIZE * 4
# Test with default chunk size
await write_me(s, DEFAULT_CHUNK_SIZE)
# Multiple
await write_me(s, DEFAULT_CHUNK_SIZE * 3)
# Custom
await write_me(s, 262300)
async def test_grid_out_lazy_connect(self):
fs = self.db.fs
outfile = AsyncGridOut(fs, file_id=-1)
with self.assertRaises(NoFile):
await outfile.read()
with self.assertRaises(NoFile):
if not _IS_SYNC:
await outfile.open()
outfile.filename
infile = AsyncGridIn(fs, filename=1)
await infile.close()
outfile = AsyncGridOut(fs, infile._id)
await outfile.read()
outfile.filename
outfile = AsyncGridOut(fs, infile._id)
await outfile.readchunk()
async def test_grid_in_lazy_connect(self):
client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
fs = client.db.fs
infile = AsyncGridIn(fs, file_id=-1, chunk_size=1)
with self.assertRaises(ServerSelectionTimeoutError):
await infile.write(b"data")
with self.assertRaises(ServerSelectionTimeoutError):
await infile.close()
async def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs)
async def test_survive_cursor_not_found(self):
# By default the find command returns 101 documents in the first batch.
# Use 102 batches to cause a single getMore.
chunk_size = 1024
data = b"d" * (102 * chunk_size)
listener = EventListener()
client = await async_rs_or_single_client(event_listeners=[listener])
db = client.pymongo_test
async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile:
await infile.write(data)
async with AsyncGridOut(db.fs, infile._id) as outfile:
self.assertEqual(len(await outfile.readchunk()), chunk_size)
# Kill the cursor to simulate the cursor timing out on the server
# when an application spends a long time between two calls to
# readchunk().
assert await client.address is not None
await client._close_cursor_now(
outfile._chunk_iter._cursor.cursor_id,
_CursorAddress(await client.address, db.fs.chunks.full_name), # type: ignore[arg-type]
)
# Read the rest of the file without error.
self.assertEqual(len(await outfile.read()), len(data) - chunk_size)
# Paranoid, ensure that a getMore was actually sent.
self.assertIn("getMore", listener.started_command_names())
@async_client_context.require_sync
async def test_zip(self):
zf = BytesIO()
z = zipfile.ZipFile(zf, "w")
z.writestr("test.txt", b"hello world")
z.close()
zf.seek(0)
f = AsyncGridIn(self.db.fs, filename="test.zip")
await f.write(zf)
await f.close()
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
g = AsyncGridOut(self.db.fs, f._id)
z = zipfile.ZipFile(g)
self.assertSequenceEqual(z.namelist(), ["test.txt"])
self.assertEqual(z.read("test.txt"), b"hello world")
async def test_grid_out_unsupported_operations(self):
f = AsyncGridIn(self.db.fs, chunkSize=3)
await f.write(b"hello world")
await f.close()
g = AsyncGridOut(self.db.fs, f._id)
self.assertRaises(io.UnsupportedOperation, g.writelines, [b"some", b"lines"])
self.assertRaises(io.UnsupportedOperation, g.write, b"some text")
self.assertRaises(io.UnsupportedOperation, g.fileno)
self.assertRaises(io.UnsupportedOperation, g.truncate)
self.assertFalse(g.writable())
self.assertFalse(g.isatty())
if __name__ == "__main__":
unittest.main()

View File

@ -35,23 +35,13 @@ try:
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, no_type_check
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo import common, message
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.uri_parser import parse_uri
if HAVE_SSL:

View File

@ -44,14 +44,16 @@ from pymongo.operations import *
from pymongo.synchronous.collection import Collection
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class BulkTestBase(IntegrationTest):
coll: Collection
coll_w0: Collection
@classmethod
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.coll = cls.db.test
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
@ -135,7 +137,8 @@ class BulkTestBase(IntegrationTest):
class TestBulk(BulkTestBase):
def test_empty(self):
self.assertRaises(InvalidOperation, self.coll.bulk_write, [])
with self.assertRaises(InvalidOperation):
self.coll.bulk_write([])
def test_insert(self):
expected = {
@ -180,15 +183,19 @@ class TestBulk(BulkTestBase):
self._test_update_many([{"$set": {"foo": "bar"}}])
def test_array_filters_validation(self):
self.assertRaises(TypeError, UpdateMany, {}, {}, array_filters={})
self.assertRaises(TypeError, UpdateOne, {}, {}, array_filters={})
with self.assertRaises(TypeError):
UpdateMany({}, {}, array_filters={}) # type: ignore[arg-type]
with self.assertRaises(TypeError):
UpdateOne({}, {}, array_filters={}) # type: ignore[arg-type]
def test_array_filters_unacknowledged(self):
coll = self.coll_w0
update_one = UpdateOne({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}])
update_many = UpdateMany({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}])
self.assertRaises(ConfigurationError, coll.bulk_write, [update_one])
self.assertRaises(ConfigurationError, coll.bulk_write, [update_many])
with self.assertRaises(ConfigurationError):
coll.bulk_write([update_one])
with self.assertRaises(ConfigurationError):
coll.bulk_write([update_many])
def _test_update_one(self, update):
expected = {
@ -790,8 +797,8 @@ class BulkAuthorizationTestBase(BulkTestBase):
@classmethod
@client_context.require_auth
@client_context.require_no_api_version
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
def setUp(self):
super().setUp()
@ -828,8 +835,16 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests)
self.assertFalse(result.acknowledged)
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
def predicate():
return self.coll.count_documents({}) == 2
wait_until(predicate, "insert 2 documents")
def predicate():
return self.coll.find_one({"_id": 1}) is None
wait_until(predicate, 'removed {"_id": 1}')
def test_no_results_ordered_failure(self):
requests: list = [
@ -843,7 +858,11 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests)
self.assertFalse(result.acknowledged)
wait_until(lambda: self.coll.count_documents({}) == 3, "insert 3 documents")
def predicate():
return self.coll.count_documents({}) == 3
wait_until(predicate, "insert 3 documents")
self.assertEqual({"_id": 1}, self.coll.find_one({"_id": 1}))
def test_no_results_unordered_success(self):
@ -855,8 +874,16 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests, ordered=False)
self.assertFalse(result.acknowledged)
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
def predicate():
return self.coll.count_documents({}) == 2
wait_until(predicate, "insert 2 documents")
def predicate():
return self.coll.find_one({"_id": 1}) is None
wait_until(predicate, 'removed {"_id": 1}')
def test_no_results_unordered_failure(self):
requests: list = [
@ -870,8 +897,16 @@ class TestBulkUnacknowledged(BulkTestBase):
]
result = self.coll_w0.bulk_write(requests, ordered=False)
self.assertFalse(result.acknowledged)
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
def predicate():
return self.coll.count_documents({}) == 2
wait_until(predicate, "insert 2 documents")
def predicate():
return self.coll.find_one({"_id": 1}) is None
wait_until(predicate, 'removed {"_id": 1}')
class TestBulkAuthorization(BulkAuthorizationTestBase):
@ -883,7 +918,8 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
)
coll = cli.pymongo_test.test
coll.find_one()
self.assertRaises(OperationFailure, coll.bulk_write, [InsertOne({"x": 1})])
with self.assertRaises(OperationFailure):
coll.bulk_write([InsertOne({"x": 1})])
def test_no_remove(self):
# We test that an authorization failure aborts the batch and is raised
@ -899,7 +935,8 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
DeleteMany({}), # Prohibited.
InsertOne({"x": 3}), # Never attempted.
]
self.assertRaises(OperationFailure, coll.bulk_write, requests)
with self.assertRaises(OperationFailure):
coll.bulk_write(requests) # type: ignore[arg-type]
self.assertEqual({1, 2}, set(self.coll.distinct("x")))
@ -908,18 +945,18 @@ class TestBulkWriteConcern(BulkTestBase):
secondary: MongoClient
@classmethod
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.w = client_context.w
cls.secondary = None
if cls.w is not None and cls.w > 1:
for member in client_context.hello["hosts"]:
if member != client_context.hello["primary"]:
for member in (client_context.hello)["hosts"]:
if member != (client_context.hello)["primary"]:
cls.secondary = single_client(*partition_node(member))
break
@classmethod
def tearDownClass(cls):
def async_tearDownClass(cls):
if cls.secondary:
cls.secondary.close()

View File

@ -18,10 +18,12 @@ import sys
sys.path[0:0] = [""]
from test import SkipTest, client_context, unittest
from test import SkipTest, UnitTest, client_context, unittest
_IS_SYNC = True
class TestClientContext(unittest.TestCase):
class TestClientContext(UnitTest):
def test_must_connect(self):
if "PYMONGO_MUST_CONNECT" not in os.environ:
raise SkipTest("PYMONGO_MUST_CONNECT is not set")

View File

@ -21,6 +21,7 @@ import io
import sys
import zipfile
from io import BytesIO
from test import IntegrationTest, UnitTest, client_context
from pymongo.synchronous.database import Database
@ -36,16 +37,20 @@ from gridfs.synchronous.grid_file import (
_SEEK_CUR,
_SEEK_END,
DEFAULT_CHUNK_SIZE,
GridFS,
GridIn,
GridOut,
GridOutCursor,
)
from pymongo import MongoClient
from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError
from pymongo.errors import ConfigurationError, InvalidOperation, ServerSelectionTimeoutError
from pymongo.message import _CursorAddress
from pymongo.synchronous.helpers import iter, next
_IS_SYNC = True
class TestGridFileNoConnect(unittest.TestCase):
class TestGridFileNoConnect(UnitTest):
"""Test GridFile features on a client that does not connect."""
db: Database
@ -151,6 +156,7 @@ class TestGridFile(IntegrationTest):
self.assertEqual(None, a.content_type)
a.content_type = "text/html"
self.assertEqual("text/html", a.content_type)
self.assertRaises(AttributeError, getattr, a, "length")
@ -164,17 +170,24 @@ class TestGridFile(IntegrationTest):
self.assertRaises(AttributeError, getattr, a, "aliases")
a.aliases = ["foo"]
self.assertEqual(["foo"], a.aliases)
self.assertRaises(AttributeError, getattr, a, "metadata")
a.metadata = {"foo": 1}
self.assertEqual({"foo": 1}, a.metadata)
self.assertRaises(AttributeError, setattr, a, "md5", 5)
a.close()
a.forty_two = 42
if _IS_SYNC:
a.forty_two = 42
else:
self.assertRaises(AttributeError, setattr, a, "forty_two", 42)
a.set("forty_two", 42)
self.assertEqual(42, a.forty_two)
self.assertTrue(isinstance(a._id, ObjectId))
@ -213,12 +226,16 @@ class TestGridFile(IntegrationTest):
gout = GridOut(self.db.fs, 5)
with self.assertRaises(NoFile):
if not _IS_SYNC:
gout.open()
gout.name
a = GridIn(self.db.fs)
a.close()
b = GridOut(self.db.fs, a._id)
if not _IS_SYNC:
b.open()
self.assertEqual(a._id, b._id)
self.assertEqual(0, b.length)
@ -278,6 +295,9 @@ class TestGridFile(IntegrationTest):
two = GridOut(self.db.fs, 5)
if not _IS_SYNC:
two.open()
self.assertEqual("my_file", two.name)
self.assertEqual("my_file", two.filename)
self.assertEqual(5, two._id)
@ -316,6 +336,8 @@ class TestGridFile(IntegrationTest):
four = GridOut(self.db.fs, file_document={})
with self.assertRaises(NoFile):
if not _IS_SYNC:
four.open()
four.name
def test_write_file_like(self):
@ -350,7 +372,8 @@ class TestGridFile(IntegrationTest):
def test_close(self):
f = GridIn(self.db.fs)
f.close()
self.assertRaises(ValueError, f.write, "test")
with self.assertRaises(ValueError):
f.write("test")
f.close()
def test_closed(self):
@ -359,6 +382,8 @@ class TestGridFile(IntegrationTest):
f.close()
g = GridOut(self.db.fs, f._id)
if not _IS_SYNC:
g.open()
self.assertFalse(g.closed)
g.read(1)
self.assertFalse(g.closed)
@ -380,6 +405,8 @@ class TestGridFile(IntegrationTest):
g = GridOut(self.db.fs, f._id)
self.assertEqual(random_string, g.read())
# TODO: https://jira.mongodb.org/browse/PYTHON-4708
@client_context.require_sync
def test_small_chunks(self):
self.files = 0
self.chunks = 0
@ -415,18 +442,21 @@ class TestGridFile(IntegrationTest):
self.assertEqual(b"hello world", g.read())
g.seek(1)
self.assertEqual(b"ello world", g.read())
self.assertRaises(IOError, g.seek, -1)
with self.assertRaises(IOError):
g.seek(-1)
g.seek(-3, _SEEK_END)
self.assertEqual(b"rld", g.read())
g.seek(0, _SEEK_END)
self.assertEqual(b"", g.read())
self.assertRaises(IOError, g.seek, -100, _SEEK_END)
with self.assertRaises(IOError):
g.seek(-100, _SEEK_END)
g.seek(3)
g.seek(3, _SEEK_CUR)
self.assertEqual(b"world", g.read())
self.assertRaises(IOError, g.seek, -100, _SEEK_CUR)
with self.assertRaises(IOError):
g.seek(-100, _SEEK_CUR)
def test_tell(self):
f = GridIn(self.db.fs, chunkSize=3)
@ -519,12 +549,14 @@ Bye"""
# Only readlines().
g = GridOut(self.db.fs, f._id)
self.assertEqual(
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines()
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
g.readlines(),
)
g = GridOut(self.db.fs, f._id)
self.assertEqual(
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines(0)
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
g.readlines(0),
)
g = GridOut(self.db.fs, f._id)
@ -550,15 +582,25 @@ Bye"""
f = GridIn(self.db.fs)
f.close()
g = GridOut(self.db.fs, f._id)
self.assertEqual([], list(g))
if _IS_SYNC:
self.assertEqual([], list(g))
else:
self.assertEqual([], g.to_list())
f = GridIn(self.db.fs)
f.write(b"hello world\nhere are\nsome lines.")
f.close()
g = GridOut(self.db.fs, f._id)
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
if _IS_SYNC:
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
else:
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], g.to_list())
self.assertEqual(b"", g.read(5))
self.assertEqual([], list(g))
if _IS_SYNC:
self.assertEqual([], list(g))
else:
self.assertEqual([], g.to_list())
g = GridOut(self.db.fs, f._id)
self.assertEqual(b"hello world\n", next(iter(g)))
@ -566,13 +608,17 @@ Bye"""
self.assertEqual(b" are\n", next(iter(g)))
self.assertEqual(b"some lines", g.read(10))
self.assertEqual(b".", next(iter(g)))
self.assertRaises(StopIteration, iter(g).__next__)
with self.assertRaises(StopIteration):
iter(g).__next__()
f = GridIn(self.db.fs, chunk_size=2)
f.write(b"hello world")
f.close()
g = GridOut(self.db.fs, f._id)
self.assertEqual([b"hello world"], list(g))
if _IS_SYNC:
self.assertEqual([b"hello world"], list(g))
else:
self.assertEqual([b"hello world"], g.to_list())
def test_read_unaligned_buffer_size(self):
in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk."
@ -610,7 +656,8 @@ Bye"""
def test_write_unicode(self):
f = GridIn(self.db.fs)
self.assertRaises(TypeError, f.write, "foo")
with self.assertRaises(TypeError):
f.write("foo")
f = GridIn(self.db.fs, encoding="utf-8")
f.write("foo")
@ -635,8 +682,12 @@ Bye"""
self.assertRaises(AttributeError, getattr, f, "uploadDate")
self.assertRaises(AttributeError, setattr, f, "_id", 5)
f.bar = "foo"
f.baz = 5
if _IS_SYNC:
f.bar = "foo"
f.baz = 5
else:
f.set("bar", "foo")
f.set("baz", 5)
self.assertEqual("foo", f._id)
self.assertEqual("foo", f.bar)
@ -651,11 +702,17 @@ Bye"""
self.assertTrue(f.uploadDate)
self.assertRaises(AttributeError, setattr, f, "_id", 5)
f.bar = "a"
f.baz = "b"
if _IS_SYNC:
f.bar = "a"
f.baz = "b"
else:
f.set("bar", "a")
f.set("baz", "b")
self.assertRaises(AttributeError, setattr, f, "upload_date", 5)
g = GridOut(self.db.fs, f._id)
if not _IS_SYNC:
g.open()
self.assertEqual("a", g.bar)
self.assertEqual("b", g.baz)
# Versions 2.0.1 and older saved a _closed field for some reason.
@ -713,8 +770,12 @@ Bye"""
def test_grid_out_lazy_connect(self):
fs = self.db.fs
outfile = GridOut(fs, file_id=-1)
self.assertRaises(NoFile, outfile.read)
self.assertRaises(NoFile, getattr, outfile, "filename")
with self.assertRaises(NoFile):
outfile.read()
with self.assertRaises(NoFile):
if not _IS_SYNC:
outfile.open()
outfile.filename
infile = GridIn(fs, filename=1)
infile.close()
@ -730,13 +791,15 @@ Bye"""
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
fs = client.db.fs
infile = GridIn(fs, file_id=-1, chunk_size=1)
self.assertRaises(ServerSelectionTimeoutError, infile.write, b"data")
self.assertRaises(ServerSelectionTimeoutError, infile.close)
with self.assertRaises(ServerSelectionTimeoutError):
infile.write(b"data")
with self.assertRaises(ServerSelectionTimeoutError):
infile.close()
def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):
GridIn(rs_or_single_client(w=0).pymongo_test.fs)
GridIn((rs_or_single_client(w=0)).pymongo_test.fs)
def test_survive_cursor_not_found(self):
# By default the find command returns 101 documents in the first batch.
@ -758,7 +821,7 @@ Bye"""
assert client.address is not None
client._close_cursor_now(
outfile._chunk_iter._cursor.cursor_id,
_CursorAddress(client.address, db.fs.chunks.full_name),
_CursorAddress(client.address, db.fs.chunks.full_name), # type: ignore[arg-type]
)
# Read the rest of the file without error.
@ -767,6 +830,7 @@ Bye"""
# Paranoid, ensure that a getMore was actually sent.
self.assertIn("getMore", listener.started_command_names())
@client_context.require_sync
def test_zip(self):
zf = BytesIO()
z = zipfile.ZipFile(zf, "w")

View File

@ -1954,7 +1954,11 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
if client.get("ignoreExtraMessages", False):
actual_logs = actual_logs[: len(client["messages"])]
self.assertEqual(len(client["messages"]), len(actual_logs))
self.assertEqual(
len(client["messages"]),
len(actual_logs),
f"expected {client['messages']} but got {actual_logs}",
)
for expected_msg, actual_msg in zip(client["messages"], actual_logs):
expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data")

View File

@ -0,0 +1,141 @@
from __future__ import annotations
import asyncio
import sys
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
replacements = {
"Collection": "AsyncCollection",
"Database": "AsyncDatabase",
"Cursor": "AsyncCursor",
"MongoClient": "AsyncMongoClient",
"CommandCursor": "AsyncCommandCursor",
"RawBatchCursor": "AsyncRawBatchCursor",
"RawBatchCommandCursor": "AsyncRawBatchCommandCursor",
"ClientSession": "AsyncClientSession",
"ChangeStream": "AsyncChangeStream",
"CollectionChangeStream": "AsyncCollectionChangeStream",
"DatabaseChangeStream": "AsyncDatabaseChangeStream",
"ClusterChangeStream": "AsyncClusterChangeStream",
"_Bulk": "_AsyncBulk",
"_ClientBulk": "_AsyncClientBulk",
"Connection": "AsyncConnection",
"synchronous": "asynchronous",
"Synchronous": "Asynchronous",
"next": "await anext",
"_Lock": "_ALock",
"_Condition": "_ACondition",
"GridFS": "AsyncGridFS",
"GridFSBucket": "AsyncGridFSBucket",
"GridIn": "AsyncGridIn",
"GridOut": "AsyncGridOut",
"GridOutCursor": "AsyncGridOutCursor",
"GridOutIterator": "AsyncGridOutIterator",
"GridOutChunkIterator": "_AsyncGridOutChunkIterator",
"_grid_in_property": "_a_grid_in_property",
"_grid_out_property": "_a_grid_out_property",
"ClientEncryption": "AsyncClientEncryption",
"MongoCryptCallback": "AsyncMongoCryptCallback",
"ExplicitEncrypter": "AsyncExplicitEncrypter",
"AutoEncrypter": "AsyncAutoEncrypter",
"ContextManager": "AsyncContextManager",
"ClientContext": "AsyncClientContext",
"TestCollection": "AsyncTestCollection",
"IntegrationTest": "AsyncIntegrationTest",
"PyMongoTestCase": "AsyncPyMongoTestCase",
"MockClientTest": "AsyncMockClientTest",
"client_context": "async_client_context",
"setUp": "asyncSetUp",
"tearDown": "asyncTearDown",
"wait_until": "await async_wait_until",
"addCleanup": "addAsyncCleanup",
"TestCase": "IsolatedAsyncioTestCase",
"UnitTest": "AsyncUnitTest",
"MockClient": "AsyncMockClient",
"SpecRunner": "AsyncSpecRunner",
"TransactionsBase": "AsyncTransactionsBase",
"get_pool": "await async_get_pool",
"is_mongos": "await async_is_mongos",
"rs_or_single_client": "await async_rs_or_single_client",
"rs_or_single_client_noauth": "await async_rs_or_single_client_noauth",
"rs_client": "await async_rs_client",
"single_client": "await async_single_client",
"from_client": "await async_from_client",
"closing": "aclosing",
"assertRaisesExactly": "asyncAssertRaisesExactly",
"get_mock_client": "await get_async_mock_client",
"close": "await aclose",
}
async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor]
def get_async_methods() -> set[str]:
result: set[str] = set()
for x in async_classes:
methods = {
k
for k, v in vars(x).items()
if callable(v)
and not isinstance(v, classmethod)
and asyncio.iscoroutinefunction(v)
and v.__name__[0] != "_"
}
result = result | methods
return result
async_methods = get_async_methods()
def apply_replacements(lines: list[str]) -> list[str]:
for i in range(len(lines)):
if "_IS_SYNC = True" in lines[i]:
lines[i] = "_IS_SYNC = False"
if "def test" in lines[i]:
lines[i] = lines[i].replace("def test", "async def test")
for k in replacements:
if k in lines[i]:
lines[i] = lines[i].replace(k, replacements[k])
for k in async_methods:
if k + "(" in lines[i]:
tokens = lines[i].split(" ")
for j in range(len(tokens)):
if k + "(" in tokens[j]:
if j < 2:
tokens.insert(0, "await")
else:
tokens.insert(j, "await")
break
new_line = " ".join(tokens)
lines[i] = new_line
return lines
def process_file(input_file: str, output_file: str) -> None:
with open(input_file, "r+") as f:
lines = f.readlines()
lines = apply_replacements(lines)
with open(output_file, "w+") as f2:
f2.seek(0)
f2.writelines(lines)
f2.truncate()
def main() -> None:
args = sys.argv[1:]
sync_file = "./test/" + args[0]
async_file = "./" + args[0]
process_file(sync_file, async_file)
main()

View File

@ -46,7 +46,10 @@ replacements = {
"async_sendall": "sendall",
"asynchronous": "synchronous",
"Asynchronous": "Synchronous",
"AsyncBulkTestBase": "BulkTestBase",
"AsyncBulkAuthorizationTestBase": "BulkAuthorizationTestBase",
"anext": "next",
"aiter": "iter",
"_ALock": "_Lock",
"_ACondition": "_Condition",
"AsyncGridFS": "GridFS",
@ -98,6 +101,8 @@ replacements = {
"default_async": "default",
"aclose": "close",
"PyMongo|async": "PyMongo",
"AsyncTestGridFile": "TestGridFile",
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
}
docstring_replacements: dict[tuple[str, str], str] = {
@ -154,15 +159,18 @@ converted_tests = [
"conftest.py",
"pymongo_mocks.py",
"utils_spec_runner.py",
"test_bulk.py",
"test_client.py",
"test_client_bulk_write.py",
"test_collection.py",
"test_cursor.py",
"test_database.py",
"test_encryption.py",
"test_grid_file.py",
"test_logger.py",
"test_session.py",
"test_transactions.py",
"test_client_context.py",
]
sync_test_files = [
@ -294,6 +302,8 @@ def translate_docstrings(lines: list[str]) -> list[str]:
lines[i] = lines[i].replace(k, replacements[k])
if "Sync" in lines[i] and "Synchronous" not in lines[i] and replacements[k] in lines[i]:
lines[i] = lines[i].replace("Sync", "")
if "rsApplyStop" in lines[i]:
lines[i] = lines[i].replace("rsApplyStop", "rsSyncApplyStop")
if "async for" in lines[i] or "async with" in lines[i] or "async def" in lines[i]:
lines[i] = lines[i].replace("async ", "")
if "await " in lines[i] and "tailable" not in lines[i]: