Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
de3fed95ee
2
.github/workflows/test-python.yml
vendored
2
.github/workflows/test-python.yml
vendored
@ -209,4 +209,4 @@ jobs:
|
||||
ls
|
||||
which python
|
||||
pip install -e ".[test]"
|
||||
PYMONGO_MUST_CONNECT=1 pytest -v test/test_client_context.py
|
||||
PYMONGO_MUST_CONNECT=1 pytest -v -k client_context
|
||||
|
||||
@ -248,3 +248,10 @@ you are attempting to validate new spec tests in PyMongo.
|
||||
## Making a Release
|
||||
|
||||
Follow the [Python Driver Release Process Wiki](https://wiki.corp.mongodb.com/display/DRIVERS/Python+Driver+Release+Process).
|
||||
|
||||
## Converting a test to async
|
||||
The `tools/convert_test_to_async.py` script takes in an existing synchronous test file and outputs a
|
||||
partially-converted asynchronous version of the same name to the `test/asynchronous` directory.
|
||||
Use this generated file as a starting point for the completed conversion.
|
||||
|
||||
The script is used like so: `python tools/convert_test_to_async.py [test_file.py]`
|
||||
|
||||
@ -1176,24 +1176,6 @@ class AsyncGridIn:
|
||||
raise AttributeError("GridIn object has no attribute '%s'" % name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
# For properties of this instance like _buffer, or descriptors set on
|
||||
# the class like filename, use regular __setattr__
|
||||
if name in self.__dict__ or name in self.__class__.__dict__:
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
if _IS_SYNC:
|
||||
# All other attributes are part of the document in db.fs.files.
|
||||
# Store them to be sent to server on close() or if closed, send
|
||||
# them now.
|
||||
self._file[name] = value
|
||||
if self._closed:
|
||||
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
|
||||
else:
|
||||
raise AttributeError(
|
||||
"AsyncGridIn does not support __setattr__. Use AsyncGridIn.set() instead"
|
||||
)
|
||||
|
||||
async def set(self, name: str, value: Any) -> None:
|
||||
# For properties of this instance like _buffer, or descriptors set on
|
||||
# the class like filename, use regular __setattr__
|
||||
if name in self.__dict__ or name in self.__class__.__dict__:
|
||||
@ -1204,9 +1186,17 @@ class AsyncGridIn:
|
||||
# them now.
|
||||
self._file[name] = value
|
||||
if self._closed:
|
||||
await self._coll.files.update_one(
|
||||
{"_id": self._file["_id"]}, {"$set": {name: value}}
|
||||
)
|
||||
if _IS_SYNC:
|
||||
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
|
||||
else:
|
||||
raise AttributeError(
|
||||
"AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead"
|
||||
)
|
||||
|
||||
async def set(self, name: str, value: Any) -> None:
|
||||
self._file[name] = value
|
||||
if self._closed:
|
||||
await self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
|
||||
|
||||
async def _flush_data(self, data: Any, force: bool = False) -> None:
|
||||
"""Flush `data` to a chunk."""
|
||||
@ -1400,7 +1390,11 @@ class AsyncGridIn:
|
||||
return False
|
||||
|
||||
|
||||
class AsyncGridOut(io.IOBase):
|
||||
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any
|
||||
|
||||
|
||||
class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore
|
||||
|
||||
"""Class to read data out of GridFS."""
|
||||
|
||||
def __init__(
|
||||
@ -1460,6 +1454,8 @@ class AsyncGridOut(io.IOBase):
|
||||
self._position = 0
|
||||
self._file = file_document
|
||||
self._session = session
|
||||
if not _IS_SYNC:
|
||||
self.closed = False
|
||||
|
||||
_id: Any = _a_grid_out_property("_id", "The ``'_id'`` value for this file.")
|
||||
filename: str = _a_grid_out_property("filename", "Name of this file.")
|
||||
@ -1486,16 +1482,43 @@ class AsyncGridOut(io.IOBase):
|
||||
_file: Any
|
||||
_chunk_iter: Any
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
return super().__next__()
|
||||
if not _IS_SYNC:
|
||||
closed: bool
|
||||
|
||||
def __next__(self) -> bytes: # noqa: F811, RUF100
|
||||
if _IS_SYNC:
|
||||
return super().__next__()
|
||||
else:
|
||||
raise TypeError(
|
||||
"AsyncGridOut does not support synchronous iteration. Use `async for` instead"
|
||||
)
|
||||
async def __anext__(self) -> bytes:
|
||||
line = await self.readline()
|
||||
if line:
|
||||
return line
|
||||
raise StopAsyncIteration()
|
||||
|
||||
async def to_list(self) -> list[bytes]:
|
||||
return [x async for x in self] # noqa: C416, RUF100
|
||||
|
||||
async def readline(self, size: int = -1) -> bytes:
|
||||
"""Read one line or up to `size` bytes from the file.
|
||||
|
||||
:param size: the maximum number of bytes to read
|
||||
"""
|
||||
return await self._read_size_or_line(size=size, line=True)
|
||||
|
||||
async def readlines(self, size: int = -1) -> list[bytes]:
|
||||
"""Read one line or up to `size` bytes from the file.
|
||||
|
||||
:param size: the maximum number of bytes to read
|
||||
"""
|
||||
await self.open()
|
||||
lines = []
|
||||
remainder = int(self.length) - self._position
|
||||
bytes_read = 0
|
||||
while remainder > 0:
|
||||
line = await self._read_size_or_line(line=True)
|
||||
bytes_read += len(line)
|
||||
lines.append(line)
|
||||
remainder = int(self.length) - self._position
|
||||
if 0 < size < bytes_read:
|
||||
break
|
||||
|
||||
return lines
|
||||
|
||||
async def open(self) -> None:
|
||||
if not self._file:
|
||||
@ -1616,18 +1639,11 @@ class AsyncGridOut(io.IOBase):
|
||||
"""
|
||||
return await self._read_size_or_line(size=size)
|
||||
|
||||
async def readline(self, size: int = -1) -> bytes: # type: ignore[override]
|
||||
"""Read one line or up to `size` bytes from the file.
|
||||
|
||||
:param size: the maximum number of bytes to read
|
||||
"""
|
||||
return await self._read_size_or_line(size=size, line=True)
|
||||
|
||||
def tell(self) -> int:
|
||||
"""Return the current position of this file."""
|
||||
return self._position
|
||||
|
||||
async def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
|
||||
async def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
|
||||
"""Set the current position of this file.
|
||||
|
||||
:param pos: the position (or offset if using relative
|
||||
@ -1690,12 +1706,15 @@ class AsyncGridOut(io.IOBase):
|
||||
"""
|
||||
return self
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
async def close(self) -> None:
|
||||
"""Make GridOut more generically file-like."""
|
||||
if self._chunk_iter:
|
||||
await self._chunk_iter.close()
|
||||
self._chunk_iter = None
|
||||
super().close()
|
||||
if _IS_SYNC:
|
||||
super().close()
|
||||
else:
|
||||
self.closed = True
|
||||
|
||||
def write(self, value: Any) -> NoReturn:
|
||||
raise io.UnsupportedOperation("write")
|
||||
|
||||
@ -38,7 +38,15 @@ def _a_grid_in_property(
|
||||
) -> Any:
|
||||
"""Create a GridIn property."""
|
||||
|
||||
warn_str = ""
|
||||
if docstring.startswith("DEPRECATED,"):
|
||||
warn_str = (
|
||||
f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
|
||||
)
|
||||
|
||||
def getter(self: Any) -> Any:
|
||||
if warn_str:
|
||||
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
|
||||
if closed_only and not self._closed:
|
||||
raise AttributeError("can only get %r on a closed file" % field_name)
|
||||
# Protect against PHP-237
|
||||
@ -46,6 +54,15 @@ def _a_grid_in_property(
|
||||
return self._file.get(field_name, 0)
|
||||
return self._file.get(field_name, None)
|
||||
|
||||
def setter(self: Any, value: Any) -> Any:
|
||||
if warn_str:
|
||||
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
|
||||
if self._closed:
|
||||
raise InvalidOperation(
|
||||
"AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead"
|
||||
)
|
||||
self._file[field_name] = value
|
||||
|
||||
if read_only:
|
||||
docstring += "\n\nThis attribute is read-only."
|
||||
elif closed_only:
|
||||
@ -56,6 +73,8 @@ def _a_grid_in_property(
|
||||
"has been called.",
|
||||
)
|
||||
|
||||
if not read_only and not closed_only:
|
||||
return property(getter, setter, doc=docstring)
|
||||
return property(getter, doc=docstring)
|
||||
|
||||
|
||||
|
||||
@ -1166,24 +1166,6 @@ class GridIn:
|
||||
raise AttributeError("GridIn object has no attribute '%s'" % name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
# For properties of this instance like _buffer, or descriptors set on
|
||||
# the class like filename, use regular __setattr__
|
||||
if name in self.__dict__ or name in self.__class__.__dict__:
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
if _IS_SYNC:
|
||||
# All other attributes are part of the document in db.fs.files.
|
||||
# Store them to be sent to server on close() or if closed, send
|
||||
# them now.
|
||||
self._file[name] = value
|
||||
if self._closed:
|
||||
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
|
||||
else:
|
||||
raise AttributeError(
|
||||
"GridIn does not support __setattr__. Use GridIn.set() instead"
|
||||
)
|
||||
|
||||
def set(self, name: str, value: Any) -> None:
|
||||
# For properties of this instance like _buffer, or descriptors set on
|
||||
# the class like filename, use regular __setattr__
|
||||
if name in self.__dict__ or name in self.__class__.__dict__:
|
||||
@ -1194,7 +1176,17 @@ class GridIn:
|
||||
# them now.
|
||||
self._file[name] = value
|
||||
if self._closed:
|
||||
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
|
||||
if _IS_SYNC:
|
||||
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
|
||||
else:
|
||||
raise AttributeError(
|
||||
"GridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use GridIn.set() instead"
|
||||
)
|
||||
|
||||
def set(self, name: str, value: Any) -> None:
|
||||
self._file[name] = value
|
||||
if self._closed:
|
||||
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
|
||||
|
||||
def _flush_data(self, data: Any, force: bool = False) -> None:
|
||||
"""Flush `data` to a chunk."""
|
||||
@ -1388,7 +1380,11 @@ class GridIn:
|
||||
return False
|
||||
|
||||
|
||||
class GridOut(io.IOBase):
|
||||
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any
|
||||
|
||||
|
||||
class GridOut(GRIDOUT_BASE_CLASS): # type: ignore
|
||||
|
||||
"""Class to read data out of GridFS."""
|
||||
|
||||
def __init__(
|
||||
@ -1448,6 +1444,8 @@ class GridOut(io.IOBase):
|
||||
self._position = 0
|
||||
self._file = file_document
|
||||
self._session = session
|
||||
if not _IS_SYNC:
|
||||
self.closed = False
|
||||
|
||||
_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
|
||||
filename: str = _grid_out_property("filename", "Name of this file.")
|
||||
@ -1474,14 +1472,43 @@ class GridOut(io.IOBase):
|
||||
_file: Any
|
||||
_chunk_iter: Any
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
return super().__next__()
|
||||
if not _IS_SYNC:
|
||||
closed: bool
|
||||
|
||||
def __next__(self) -> bytes: # noqa: F811, RUF100
|
||||
if _IS_SYNC:
|
||||
return super().__next__()
|
||||
else:
|
||||
raise TypeError("GridOut does not support synchronous iteration. Use `for` instead")
|
||||
def __next__(self) -> bytes:
|
||||
line = self.readline()
|
||||
if line:
|
||||
return line
|
||||
raise StopIteration()
|
||||
|
||||
def to_list(self) -> list[bytes]:
|
||||
return [x for x in self] # noqa: C416, RUF100
|
||||
|
||||
def readline(self, size: int = -1) -> bytes:
|
||||
"""Read one line or up to `size` bytes from the file.
|
||||
|
||||
:param size: the maximum number of bytes to read
|
||||
"""
|
||||
return self._read_size_or_line(size=size, line=True)
|
||||
|
||||
def readlines(self, size: int = -1) -> list[bytes]:
|
||||
"""Read one line or up to `size` bytes from the file.
|
||||
|
||||
:param size: the maximum number of bytes to read
|
||||
"""
|
||||
self.open()
|
||||
lines = []
|
||||
remainder = int(self.length) - self._position
|
||||
bytes_read = 0
|
||||
while remainder > 0:
|
||||
line = self._read_size_or_line(line=True)
|
||||
bytes_read += len(line)
|
||||
lines.append(line)
|
||||
remainder = int(self.length) - self._position
|
||||
if 0 < size < bytes_read:
|
||||
break
|
||||
|
||||
return lines
|
||||
|
||||
def open(self) -> None:
|
||||
if not self._file:
|
||||
@ -1602,18 +1629,11 @@ class GridOut(io.IOBase):
|
||||
"""
|
||||
return self._read_size_or_line(size=size)
|
||||
|
||||
def readline(self, size: int = -1) -> bytes: # type: ignore[override]
|
||||
"""Read one line or up to `size` bytes from the file.
|
||||
|
||||
:param size: the maximum number of bytes to read
|
||||
"""
|
||||
return self._read_size_or_line(size=size, line=True)
|
||||
|
||||
def tell(self) -> int:
|
||||
"""Return the current position of this file."""
|
||||
return self._position
|
||||
|
||||
def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
|
||||
def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
|
||||
"""Set the current position of this file.
|
||||
|
||||
:param pos: the position (or offset if using relative
|
||||
@ -1676,12 +1696,15 @@ class GridOut(io.IOBase):
|
||||
"""
|
||||
return self
|
||||
|
||||
def close(self) -> None: # type: ignore[override]
|
||||
def close(self) -> None:
|
||||
"""Make GridOut more generically file-like."""
|
||||
if self._chunk_iter:
|
||||
self._chunk_iter.close()
|
||||
self._chunk_iter = None
|
||||
super().close()
|
||||
if _IS_SYNC:
|
||||
super().close()
|
||||
else:
|
||||
self.closed = True
|
||||
|
||||
def write(self, value: Any) -> NoReturn:
|
||||
raise io.UnsupportedOperation("write")
|
||||
|
||||
@ -228,6 +228,10 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
|
||||
if not isinstance(database, AsyncDatabase):
|
||||
raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given")
|
||||
|
||||
if not name or ".." in name:
|
||||
raise InvalidName("collection names cannot be empty")
|
||||
|
||||
@ -119,9 +119,14 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern or client.read_concern,
|
||||
)
|
||||
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
|
||||
if not isinstance(client, AsyncMongoClient):
|
||||
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
|
||||
|
||||
if name != "$external":
|
||||
_check_name(name)
|
||||
|
||||
|
||||
@ -194,9 +194,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
|
||||
async def collection_info(
|
||||
self, database: AsyncDatabase[Mapping[str, Any]], filter: bytes
|
||||
) -> Optional[bytes]:
|
||||
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
|
||||
"""Get the collection info for a namespace.
|
||||
|
||||
The returned collection info is passed to libmongocrypt which reads
|
||||
@ -598,6 +596,9 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
||||
if not isinstance(codec_options, CodecOptions):
|
||||
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
|
||||
|
||||
if not isinstance(key_vault_client, AsyncMongoClient):
|
||||
raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}")
|
||||
|
||||
self._kms_providers = kms_providers
|
||||
self._key_vault_namespace = key_vault_namespace
|
||||
self._key_vault_client = key_vault_client
|
||||
@ -683,6 +684,11 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
||||
https://mongodb.com/docs/manual/reference/command/create
|
||||
|
||||
"""
|
||||
if not isinstance(database, AsyncDatabase):
|
||||
raise TypeError(
|
||||
f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given"
|
||||
)
|
||||
|
||||
encrypted_fields = deepcopy(encrypted_fields)
|
||||
for i, field in enumerate(encrypted_fields["fields"]):
|
||||
if isinstance(field, dict) and field.get("keyId") is None:
|
||||
|
||||
@ -70,8 +70,13 @@ def _handle_reauth(func: F) -> F:
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
anext = builtins.anext
|
||||
aiter = builtins.aiter
|
||||
else:
|
||||
|
||||
async def anext(cls: Any) -> Any:
|
||||
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
|
||||
return await cls.__anext__()
|
||||
|
||||
def aiter(cls: Any) -> Any:
|
||||
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
|
||||
return cls.__aiter__()
|
||||
|
||||
@ -2419,6 +2419,9 @@ class _MongoClientErrorHandler:
|
||||
def __init__(
|
||||
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
|
||||
):
|
||||
if not isinstance(client, AsyncMongoClient):
|
||||
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
|
||||
|
||||
self.client = client
|
||||
self.server_address = server.description.address
|
||||
self.session = session
|
||||
|
||||
@ -475,7 +475,7 @@ class Topology:
|
||||
if server:
|
||||
await server.pool.ready()
|
||||
|
||||
suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description
|
||||
suppress_event = sd_old == server_description
|
||||
if self._publish_server and not suppress_event:
|
||||
assert self._events is not None
|
||||
self._events.put(
|
||||
@ -497,7 +497,7 @@ class Topology:
|
||||
(td_old, self._description, self._topology_id),
|
||||
)
|
||||
)
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event:
|
||||
_debug_log(
|
||||
_SDAM_LOGGER,
|
||||
topologyId=self._topology_id,
|
||||
@ -521,7 +521,7 @@ class Topology:
|
||||
if server:
|
||||
await server.pool.reset(interrupt_connections=interrupt_connections)
|
||||
|
||||
# Wake waiters in select_servers().
|
||||
# Wake anything waiting in select_servers().
|
||||
self._condition.notify_all()
|
||||
|
||||
async def on_change(
|
||||
|
||||
@ -231,6 +231,10 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
)
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
from pymongo.synchronous.database import Database
|
||||
|
||||
if not isinstance(database, Database):
|
||||
raise TypeError(f"Collection requires a Database but {type(database)} given")
|
||||
|
||||
if not name or ".." in name:
|
||||
raise InvalidName("collection names cannot be empty")
|
||||
|
||||
@ -119,9 +119,14 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
read_concern or client.read_concern,
|
||||
)
|
||||
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("name must be an instance of str")
|
||||
|
||||
if not isinstance(client, MongoClient):
|
||||
raise TypeError(f"MongoClient required but given {type(client)}")
|
||||
|
||||
if name != "$external":
|
||||
_check_name(name)
|
||||
|
||||
|
||||
@ -194,9 +194,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
# Wrap I/O errors in PyMongo exceptions.
|
||||
_raise_connection_failure((host, port), error)
|
||||
|
||||
def collection_info(
|
||||
self, database: Database[Mapping[str, Any]], filter: bytes
|
||||
) -> Optional[bytes]:
|
||||
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
|
||||
"""Get the collection info for a namespace.
|
||||
|
||||
The returned collection info is passed to libmongocrypt which reads
|
||||
@ -596,6 +594,9 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
if not isinstance(codec_options, CodecOptions):
|
||||
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
|
||||
|
||||
if not isinstance(key_vault_client, MongoClient):
|
||||
raise TypeError(f"MongoClient required but given {type(key_vault_client)}")
|
||||
|
||||
self._kms_providers = kms_providers
|
||||
self._key_vault_namespace = key_vault_namespace
|
||||
self._key_vault_client = key_vault_client
|
||||
@ -681,6 +682,11 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
https://mongodb.com/docs/manual/reference/command/create
|
||||
|
||||
"""
|
||||
if not isinstance(database, Database):
|
||||
raise TypeError(
|
||||
f"create_encrypted_collection() requires a Database but {type(database)} given"
|
||||
)
|
||||
|
||||
encrypted_fields = deepcopy(encrypted_fields)
|
||||
for i, field in enumerate(encrypted_fields["fields"]):
|
||||
if isinstance(field, dict) and field.get("keyId") is None:
|
||||
|
||||
@ -70,8 +70,13 @@ def _handle_reauth(func: F) -> F:
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
next = builtins.next
|
||||
iter = builtins.iter
|
||||
else:
|
||||
|
||||
def next(cls: Any) -> Any:
|
||||
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
|
||||
return cls.__next__()
|
||||
|
||||
def iter(cls: Any) -> Any:
|
||||
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
|
||||
return cls.__iter__()
|
||||
|
||||
@ -2406,6 +2406,9 @@ class _MongoClientErrorHandler:
|
||||
)
|
||||
|
||||
def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
|
||||
if not isinstance(client, MongoClient):
|
||||
raise TypeError(f"MongoClient required but given {type(client)}")
|
||||
|
||||
self.client = client
|
||||
self.server_address = server.description.address
|
||||
self.session = session
|
||||
|
||||
@ -475,7 +475,7 @@ class Topology:
|
||||
if server:
|
||||
server.pool.ready()
|
||||
|
||||
suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description
|
||||
suppress_event = sd_old == server_description
|
||||
if self._publish_server and not suppress_event:
|
||||
assert self._events is not None
|
||||
self._events.put(
|
||||
@ -497,7 +497,7 @@ class Topology:
|
||||
(td_old, self._description, self._topology_id),
|
||||
)
|
||||
)
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
if _SDAM_LOGGER.isEnabledFor(logging.DEBUG) and not suppress_event:
|
||||
_debug_log(
|
||||
_SDAM_LOGGER,
|
||||
topologyId=self._topology_id,
|
||||
@ -521,7 +521,7 @@ class Topology:
|
||||
if server:
|
||||
server.pool.reset(interrupt_connections=interrupt_connections)
|
||||
|
||||
# Wake waiters in select_servers().
|
||||
# Wake anything waiting in select_servers().
|
||||
self._condition.notify_all()
|
||||
|
||||
def on_change(
|
||||
|
||||
@ -569,7 +569,10 @@ class ClientContext:
|
||||
def sec_count():
|
||||
return 0 if not self.client else len(self.client.secondaries)
|
||||
|
||||
return self._require(lambda: sec_count() >= count, "Not enough secondaries available")
|
||||
def check():
|
||||
return sec_count() >= count
|
||||
|
||||
return self._require(check, "Not enough secondaries available")
|
||||
|
||||
@property
|
||||
def supports_secondary_read_pref(self):
|
||||
@ -947,11 +950,11 @@ class UnitTest(PyMongoTestCase):
|
||||
|
||||
@classmethod
|
||||
def _setup_class(cls):
|
||||
cls._setup_class()
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _tearDown_class(cls):
|
||||
cls._tearDown_class()
|
||||
pass
|
||||
|
||||
|
||||
class IntegrationTest(PyMongoTestCase):
|
||||
|
||||
@ -571,7 +571,10 @@ class AsyncClientContext:
|
||||
async def sec_count():
|
||||
return 0 if not self.client else len(await self.client.secondaries)
|
||||
|
||||
return self._require(lambda: sec_count() >= count, "Not enough secondaries available")
|
||||
async def check():
|
||||
return await sec_count() >= count
|
||||
|
||||
return self._require(check, "Not enough secondaries available")
|
||||
|
||||
@property
|
||||
async def supports_secondary_read_pref(self):
|
||||
@ -949,11 +952,11 @@ class AsyncUnitTest(AsyncPyMongoTestCase):
|
||||
|
||||
@classmethod
|
||||
async def _setup_class(cls):
|
||||
await cls._setup_class()
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls._tearDown_class()
|
||||
pass
|
||||
|
||||
|
||||
class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
|
||||
1134
test/asynchronous/test_bulk.py
Normal file
1134
test/asynchronous/test_bulk.py
Normal file
File diff suppressed because it is too large
Load Diff
66
test/asynchronous/test_client_context.py
Normal file
66
test/asynchronous/test_client_context.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright 2018-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncUnitTest, SkipTest, async_client_context, unittest
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestAsyncClientContext(AsyncUnitTest):
|
||||
def test_must_connect(self):
|
||||
if "PYMONGO_MUST_CONNECT" not in os.environ:
|
||||
raise SkipTest("PYMONGO_MUST_CONNECT is not set")
|
||||
|
||||
self.assertTrue(
|
||||
async_client_context.connected,
|
||||
"client context must be connected when "
|
||||
"PYMONGO_MUST_CONNECT is set. Failed attempts:\n{}".format(
|
||||
async_client_context.connection_attempt_info()
|
||||
),
|
||||
)
|
||||
|
||||
def test_serverless(self):
|
||||
if "TEST_SERVERLESS" not in os.environ:
|
||||
raise SkipTest("TEST_SERVERLESS is not set")
|
||||
|
||||
self.assertTrue(
|
||||
async_client_context.connected and async_client_context.serverless,
|
||||
"client context must be connected to serverless when "
|
||||
f"TEST_SERVERLESS is set. Failed attempts:\n{async_client_context.connection_attempt_info()}",
|
||||
)
|
||||
|
||||
def test_enableTestCommands_is_disabled(self):
|
||||
if "PYMONGO_DISABLE_TEST_COMMANDS" not in os.environ:
|
||||
raise SkipTest("PYMONGO_DISABLE_TEST_COMMANDS is not set")
|
||||
|
||||
self.assertFalse(
|
||||
async_client_context.test_commands_enabled,
|
||||
"enableTestCommands must be disabled when PYMONGO_DISABLE_TEST_COMMANDS is set.",
|
||||
)
|
||||
|
||||
def test_setdefaultencoding_worked(self):
|
||||
if "SETDEFAULTENCODING" not in os.environ:
|
||||
raise SkipTest("SETDEFAULTENCODING is not set")
|
||||
|
||||
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
871
test/asynchronous/test_grid_file.py
Normal file
871
test/asynchronous/test_grid_file.py
Normal file
@ -0,0 +1,871 @@
|
||||
#
|
||||
# Copyright 2009-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the grid_file module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import io
|
||||
import sys
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from test.asynchronous import AsyncIntegrationTest, AsyncUnitTest, async_client_context
|
||||
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import IntegrationTest, qcheck, unittest
|
||||
from test.utils import EventListener, async_rs_or_single_client, rs_or_single_client
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from gridfs import GridFS
|
||||
from gridfs.asynchronous.grid_file import (
|
||||
_SEEK_CUR,
|
||||
_SEEK_END,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
AsyncGridFS,
|
||||
AsyncGridIn,
|
||||
AsyncGridOut,
|
||||
AsyncGridOutCursor,
|
||||
)
|
||||
from gridfs.errors import NoFile
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.asynchronous.helpers import aiter, anext
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, ServerSelectionTimeoutError
|
||||
from pymongo.message import _CursorAddress
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AsyncTestGridFileNoConnect(AsyncUnitTest):
|
||||
"""Test GridFile features on a client that does not connect."""
|
||||
|
||||
db: AsyncDatabase
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.db = AsyncMongoClient(connect=False).pymongo_test
|
||||
|
||||
def test_grid_in_custom_opts(self):
|
||||
self.assertRaises(TypeError, AsyncGridIn, "foo")
|
||||
|
||||
a = AsyncGridIn(
|
||||
self.db.fs,
|
||||
_id=5,
|
||||
filename="my_file",
|
||||
contentType="text/html",
|
||||
chunkSize=1000,
|
||||
aliases=["foo"],
|
||||
metadata={"foo": 1, "bar": 2},
|
||||
bar=3,
|
||||
baz="hello",
|
||||
)
|
||||
|
||||
self.assertEqual(5, a._id)
|
||||
self.assertEqual("my_file", a.filename)
|
||||
self.assertEqual("my_file", a.name)
|
||||
self.assertEqual("text/html", a.content_type)
|
||||
self.assertEqual(1000, a.chunk_size)
|
||||
self.assertEqual(["foo"], a.aliases)
|
||||
self.assertEqual({"foo": 1, "bar": 2}, a.metadata)
|
||||
self.assertEqual(3, a.bar)
|
||||
self.assertEqual("hello", a.baz)
|
||||
self.assertRaises(AttributeError, getattr, a, "mike")
|
||||
|
||||
b = AsyncGridIn(self.db.fs, content_type="text/html", chunk_size=1000, baz=100)
|
||||
self.assertEqual("text/html", b.content_type)
|
||||
self.assertEqual(1000, b.chunk_size)
|
||||
self.assertEqual(100, b.baz)
|
||||
|
||||
|
||||
class AsyncTestGridFile(AsyncIntegrationTest):
|
||||
async def asyncSetUp(self):
|
||||
await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks)
|
||||
|
||||
async def test_basic(self):
|
||||
f = AsyncGridIn(self.db.fs, filename="test")
|
||||
await f.write(b"hello world")
|
||||
await f.close()
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"hello world", await g.read())
|
||||
|
||||
# make sure it's still there...
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"hello world", await g.read())
|
||||
|
||||
f = AsyncGridIn(self.db.fs, filename="test")
|
||||
await f.close()
|
||||
self.assertEqual(2, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"", await g.read())
|
||||
|
||||
# test that reading 0 returns proper type
|
||||
self.assertEqual(b"", await g.read(0))
|
||||
|
||||
async def test_md5(self):
|
||||
f = AsyncGridIn(self.db.fs)
|
||||
await f.write(b"hello world\n")
|
||||
await f.close()
|
||||
self.assertEqual(None, f.md5)
|
||||
|
||||
async def test_alternate_collection(self):
|
||||
await self.db.alt.files.delete_many({})
|
||||
await self.db.alt.chunks.delete_many({})
|
||||
|
||||
f = AsyncGridIn(self.db.alt)
|
||||
await f.write(b"hello world")
|
||||
await f.close()
|
||||
|
||||
self.assertEqual(1, await self.db.alt.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.alt.chunks.count_documents({}))
|
||||
|
||||
g = AsyncGridOut(self.db.alt, f._id)
|
||||
self.assertEqual(b"hello world", await g.read())
|
||||
|
||||
async def test_grid_in_default_opts(self):
|
||||
self.assertRaises(TypeError, AsyncGridIn, "foo")
|
||||
|
||||
a = AsyncGridIn(self.db.fs)
|
||||
|
||||
self.assertTrue(isinstance(a._id, ObjectId))
|
||||
self.assertRaises(AttributeError, setattr, a, "_id", 5)
|
||||
|
||||
self.assertEqual(None, a.filename)
|
||||
self.assertEqual(None, a.name)
|
||||
a.filename = "my_file"
|
||||
self.assertEqual("my_file", a.filename)
|
||||
self.assertEqual("my_file", a.name)
|
||||
|
||||
self.assertEqual(None, a.content_type)
|
||||
a.content_type = "text/html"
|
||||
|
||||
self.assertEqual("text/html", a.content_type)
|
||||
|
||||
self.assertRaises(AttributeError, getattr, a, "length")
|
||||
self.assertRaises(AttributeError, setattr, a, "length", 5)
|
||||
|
||||
self.assertEqual(255 * 1024, a.chunk_size)
|
||||
self.assertRaises(AttributeError, setattr, a, "chunk_size", 5)
|
||||
|
||||
self.assertRaises(AttributeError, getattr, a, "upload_date")
|
||||
self.assertRaises(AttributeError, setattr, a, "upload_date", 5)
|
||||
|
||||
self.assertRaises(AttributeError, getattr, a, "aliases")
|
||||
a.aliases = ["foo"]
|
||||
|
||||
self.assertEqual(["foo"], a.aliases)
|
||||
|
||||
self.assertRaises(AttributeError, getattr, a, "metadata")
|
||||
a.metadata = {"foo": 1}
|
||||
|
||||
self.assertEqual({"foo": 1}, a.metadata)
|
||||
|
||||
self.assertRaises(AttributeError, setattr, a, "md5", 5)
|
||||
|
||||
await a.close()
|
||||
|
||||
if _IS_SYNC:
|
||||
a.forty_two = 42
|
||||
else:
|
||||
self.assertRaises(AttributeError, setattr, a, "forty_two", 42)
|
||||
await a.set("forty_two", 42)
|
||||
|
||||
self.assertEqual(42, a.forty_two)
|
||||
|
||||
self.assertTrue(isinstance(a._id, ObjectId))
|
||||
self.assertRaises(AttributeError, setattr, a, "_id", 5)
|
||||
|
||||
self.assertEqual("my_file", a.filename)
|
||||
self.assertEqual("my_file", a.name)
|
||||
|
||||
self.assertEqual("text/html", a.content_type)
|
||||
|
||||
self.assertEqual(0, a.length)
|
||||
self.assertRaises(AttributeError, setattr, a, "length", 5)
|
||||
|
||||
self.assertEqual(255 * 1024, a.chunk_size)
|
||||
self.assertRaises(AttributeError, setattr, a, "chunk_size", 5)
|
||||
|
||||
self.assertTrue(isinstance(a.upload_date, datetime.datetime))
|
||||
self.assertRaises(AttributeError, setattr, a, "upload_date", 5)
|
||||
|
||||
self.assertEqual(["foo"], a.aliases)
|
||||
|
||||
self.assertEqual({"foo": 1}, a.metadata)
|
||||
|
||||
self.assertEqual(None, a.md5)
|
||||
self.assertRaises(AttributeError, setattr, a, "md5", 5)
|
||||
|
||||
# Make sure custom attributes that were set both before and after
|
||||
# a.close() are reflected in b. PYTHON-411.
|
||||
b = await AsyncGridFS(self.db).get_last_version(filename=a.filename)
|
||||
self.assertEqual(a.metadata, b.metadata)
|
||||
self.assertEqual(a.aliases, b.aliases)
|
||||
self.assertEqual(a.forty_two, b.forty_two)
|
||||
|
||||
async def test_grid_out_default_opts(self):
|
||||
self.assertRaises(TypeError, AsyncGridOut, "foo")
|
||||
|
||||
gout = AsyncGridOut(self.db.fs, 5)
|
||||
with self.assertRaises(NoFile):
|
||||
if not _IS_SYNC:
|
||||
await gout.open()
|
||||
gout.name
|
||||
|
||||
a = AsyncGridIn(self.db.fs)
|
||||
await a.close()
|
||||
|
||||
b = AsyncGridOut(self.db.fs, a._id)
|
||||
if not _IS_SYNC:
|
||||
await b.open()
|
||||
|
||||
self.assertEqual(a._id, b._id)
|
||||
self.assertEqual(0, b.length)
|
||||
self.assertEqual(None, b.content_type)
|
||||
self.assertEqual(None, b.name)
|
||||
self.assertEqual(None, b.filename)
|
||||
self.assertEqual(255 * 1024, b.chunk_size)
|
||||
self.assertTrue(isinstance(b.upload_date, datetime.datetime))
|
||||
self.assertEqual(None, b.aliases)
|
||||
self.assertEqual(None, b.metadata)
|
||||
self.assertEqual(None, b.md5)
|
||||
|
||||
for attr in [
|
||||
"_id",
|
||||
"name",
|
||||
"content_type",
|
||||
"length",
|
||||
"chunk_size",
|
||||
"upload_date",
|
||||
"aliases",
|
||||
"metadata",
|
||||
"md5",
|
||||
]:
|
||||
self.assertRaises(AttributeError, setattr, b, attr, 5)
|
||||
|
||||
async def test_grid_out_cursor_options(self):
|
||||
self.assertRaises(
|
||||
TypeError, AsyncGridOutCursor.__init__, self.db.fs, {}, projection={"filename": 1}
|
||||
)
|
||||
|
||||
cursor = AsyncGridOutCursor(self.db.fs, {})
|
||||
cursor_clone = cursor.clone()
|
||||
|
||||
cursor_dict = cursor.__dict__.copy()
|
||||
cursor_dict.pop("_session")
|
||||
cursor_clone_dict = cursor_clone.__dict__.copy()
|
||||
cursor_clone_dict.pop("_session")
|
||||
self.assertDictEqual(cursor_dict, cursor_clone_dict)
|
||||
|
||||
self.assertRaises(NotImplementedError, cursor.add_option, 0)
|
||||
self.assertRaises(NotImplementedError, cursor.remove_option, 0)
|
||||
|
||||
async def test_grid_out_custom_opts(self):
|
||||
one = AsyncGridIn(
|
||||
self.db.fs,
|
||||
_id=5,
|
||||
filename="my_file",
|
||||
contentType="text/html",
|
||||
chunkSize=1000,
|
||||
aliases=["foo"],
|
||||
metadata={"foo": 1, "bar": 2},
|
||||
bar=3,
|
||||
baz="hello",
|
||||
)
|
||||
await one.write(b"hello world")
|
||||
await one.close()
|
||||
|
||||
two = AsyncGridOut(self.db.fs, 5)
|
||||
|
||||
if not _IS_SYNC:
|
||||
await two.open()
|
||||
|
||||
self.assertEqual("my_file", two.name)
|
||||
self.assertEqual("my_file", two.filename)
|
||||
self.assertEqual(5, two._id)
|
||||
self.assertEqual(11, two.length)
|
||||
self.assertEqual("text/html", two.content_type)
|
||||
self.assertEqual(1000, two.chunk_size)
|
||||
self.assertTrue(isinstance(two.upload_date, datetime.datetime))
|
||||
self.assertEqual(["foo"], two.aliases)
|
||||
self.assertEqual({"foo": 1, "bar": 2}, two.metadata)
|
||||
self.assertEqual(3, two.bar)
|
||||
self.assertEqual(None, two.md5)
|
||||
|
||||
for attr in [
|
||||
"_id",
|
||||
"name",
|
||||
"content_type",
|
||||
"length",
|
||||
"chunk_size",
|
||||
"upload_date",
|
||||
"aliases",
|
||||
"metadata",
|
||||
"md5",
|
||||
]:
|
||||
self.assertRaises(AttributeError, setattr, two, attr, 5)
|
||||
|
||||
async def test_grid_out_file_document(self):
|
||||
one = AsyncGridIn(self.db.fs)
|
||||
await one.write(b"foo bar")
|
||||
await one.close()
|
||||
|
||||
two = AsyncGridOut(self.db.fs, file_document=await self.db.fs.files.find_one())
|
||||
self.assertEqual(b"foo bar", await two.read())
|
||||
|
||||
three = AsyncGridOut(self.db.fs, 5, file_document=await self.db.fs.files.find_one())
|
||||
self.assertEqual(b"foo bar", await three.read())
|
||||
|
||||
four = AsyncGridOut(self.db.fs, file_document={})
|
||||
with self.assertRaises(NoFile):
|
||||
if not _IS_SYNC:
|
||||
await four.open()
|
||||
four.name
|
||||
|
||||
async def test_write_file_like(self):
|
||||
one = AsyncGridIn(self.db.fs)
|
||||
await one.write(b"hello world")
|
||||
await one.close()
|
||||
|
||||
two = AsyncGridOut(self.db.fs, one._id)
|
||||
|
||||
three = AsyncGridIn(self.db.fs)
|
||||
await three.write(two)
|
||||
await three.close()
|
||||
|
||||
four = AsyncGridOut(self.db.fs, three._id)
|
||||
self.assertEqual(b"hello world", await four.read())
|
||||
|
||||
five = AsyncGridIn(self.db.fs, chunk_size=2)
|
||||
await five.write(b"hello")
|
||||
buffer = BytesIO(b" world")
|
||||
await five.write(buffer)
|
||||
await five.write(b" and mongodb")
|
||||
await five.close()
|
||||
self.assertEqual(
|
||||
b"hello world and mongodb", await AsyncGridOut(self.db.fs, five._id).read()
|
||||
)
|
||||
|
||||
async def test_write_lines(self):
|
||||
a = AsyncGridIn(self.db.fs)
|
||||
await a.writelines([b"hello ", b"world"])
|
||||
await a.close()
|
||||
|
||||
self.assertEqual(b"hello world", await AsyncGridOut(self.db.fs, a._id).read())
|
||||
|
||||
async def test_close(self):
|
||||
f = AsyncGridIn(self.db.fs)
|
||||
await f.close()
|
||||
with self.assertRaises(ValueError):
|
||||
await f.write("test")
|
||||
await f.close()
|
||||
|
||||
async def test_closed(self):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=5)
|
||||
await f.write(b"Hello world.\nHow are you?")
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
if not _IS_SYNC:
|
||||
await g.open()
|
||||
self.assertFalse(g.closed)
|
||||
await g.read(1)
|
||||
self.assertFalse(g.closed)
|
||||
await g.read(100)
|
||||
self.assertFalse(g.closed)
|
||||
await g.close()
|
||||
self.assertTrue(g.closed)
|
||||
|
||||
async def test_multi_chunk_file(self):
|
||||
random_string = b"a" * (DEFAULT_CHUNK_SIZE + 1000)
|
||||
|
||||
f = AsyncGridIn(self.db.fs)
|
||||
await f.write(random_string)
|
||||
await f.close()
|
||||
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(2, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(random_string, await g.read())
|
||||
|
||||
# TODO: https://jira.mongodb.org/browse/PYTHON-4708
|
||||
@async_client_context.require_sync
|
||||
async def test_small_chunks(self):
|
||||
self.files = 0
|
||||
self.chunks = 0
|
||||
|
||||
async def helper(data):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=1)
|
||||
await f.write(data)
|
||||
await f.close()
|
||||
|
||||
self.files += 1
|
||||
self.chunks += len(data)
|
||||
|
||||
self.assertEqual(self.files, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(self.chunks, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(data, await g.read())
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(data, await g.read(10) + await g.read(10))
|
||||
return True
|
||||
|
||||
qcheck.check_unittest(self, helper, qcheck.gen_string(qcheck.gen_range(0, 20)))
|
||||
|
||||
async def test_seek(self):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=3)
|
||||
await f.write(b"hello world")
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"hello world", await g.read())
|
||||
await g.seek(0)
|
||||
self.assertEqual(b"hello world", await g.read())
|
||||
await g.seek(1)
|
||||
self.assertEqual(b"ello world", await g.read())
|
||||
with self.assertRaises(IOError):
|
||||
await g.seek(-1)
|
||||
|
||||
await g.seek(-3, _SEEK_END)
|
||||
self.assertEqual(b"rld", await g.read())
|
||||
await g.seek(0, _SEEK_END)
|
||||
self.assertEqual(b"", await g.read())
|
||||
with self.assertRaises(IOError):
|
||||
await g.seek(-100, _SEEK_END)
|
||||
|
||||
await g.seek(3)
|
||||
await g.seek(3, _SEEK_CUR)
|
||||
self.assertEqual(b"world", await g.read())
|
||||
with self.assertRaises(IOError):
|
||||
await g.seek(-100, _SEEK_CUR)
|
||||
|
||||
async def test_tell(self):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=3)
|
||||
await f.write(b"hello world")
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(0, g.tell())
|
||||
await g.read(0)
|
||||
self.assertEqual(0, g.tell())
|
||||
await g.read(1)
|
||||
self.assertEqual(1, g.tell())
|
||||
await g.read(2)
|
||||
self.assertEqual(3, g.tell())
|
||||
await g.read()
|
||||
self.assertEqual(g.length, g.tell())
|
||||
|
||||
async def test_multiple_reads(self):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=3)
|
||||
await f.write(b"hello world")
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"he", await g.read(2))
|
||||
self.assertEqual(b"ll", await g.read(2))
|
||||
self.assertEqual(b"o ", await g.read(2))
|
||||
self.assertEqual(b"wo", await g.read(2))
|
||||
self.assertEqual(b"rl", await g.read(2))
|
||||
self.assertEqual(b"d", await g.read(2))
|
||||
self.assertEqual(b"", await g.read(2))
|
||||
|
||||
async def test_readline(self):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=5)
|
||||
await f.write(
|
||||
b"""Hello world,
|
||||
How are you?
|
||||
Hope all is well.
|
||||
Bye"""
|
||||
)
|
||||
await f.close()
|
||||
|
||||
# Try read(), then readline().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"H", await g.read(1))
|
||||
self.assertEqual(b"ello world,\n", await g.readline())
|
||||
self.assertEqual(b"How a", await g.readline(5))
|
||||
self.assertEqual(b"", await g.readline(0))
|
||||
self.assertEqual(b"re you?\n", await g.readline())
|
||||
self.assertEqual(b"Hope all is well.\n", await g.readline(1000))
|
||||
self.assertEqual(b"Bye", await g.readline())
|
||||
self.assertEqual(b"", await g.readline())
|
||||
|
||||
# Try readline() first, then read().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"He", await g.readline(2))
|
||||
self.assertEqual(b"l", await g.read(1))
|
||||
self.assertEqual(b"lo", await g.readline(2))
|
||||
self.assertEqual(b" world,\n", await g.readline())
|
||||
|
||||
# Only readline().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"H", await g.readline(1))
|
||||
self.assertEqual(b"e", await g.readline(1))
|
||||
self.assertEqual(b"llo world,\n", await g.readline())
|
||||
|
||||
async def test_readlines(self):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=5)
|
||||
await f.write(
|
||||
b"""Hello world,
|
||||
How are you?
|
||||
Hope all is well.
|
||||
Bye"""
|
||||
)
|
||||
await f.close()
|
||||
|
||||
# Try read(), then readlines().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"He", await g.read(2))
|
||||
self.assertEqual([b"llo world,\n", b"How are you?\n"], await g.readlines(11))
|
||||
self.assertEqual([b"Hope all is well.\n", b"Bye"], await g.readlines())
|
||||
self.assertEqual([], await g.readlines())
|
||||
|
||||
# Try readline(), then readlines().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"Hello world,\n", await g.readline())
|
||||
self.assertEqual([b"How are you?\n", b"Hope all is well.\n"], await g.readlines(13))
|
||||
self.assertEqual(b"Bye", await g.readline())
|
||||
self.assertEqual([], await g.readlines())
|
||||
|
||||
# Only readlines().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(
|
||||
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
|
||||
await g.readlines(),
|
||||
)
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(
|
||||
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
|
||||
await g.readlines(0),
|
||||
)
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual([b"Hello world,\n"], await g.readlines(1))
|
||||
self.assertEqual([b"How are you?\n"], await g.readlines(12))
|
||||
self.assertEqual([b"Hope all is well.\n", b"Bye"], await g.readlines(18))
|
||||
|
||||
# Try readlines() first, then read().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual([b"Hello world,\n"], await g.readlines(1))
|
||||
self.assertEqual(b"H", await g.read(1))
|
||||
self.assertEqual([b"ow are you?\n", b"Hope all is well.\n"], await g.readlines(29))
|
||||
self.assertEqual([b"Bye"], await g.readlines(1))
|
||||
|
||||
# Try readlines() first, then readline().
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual([b"Hello world,\n"], await g.readlines(1))
|
||||
self.assertEqual(b"How are you?\n", await g.readline())
|
||||
self.assertEqual([b"Hope all is well.\n"], await g.readlines(17))
|
||||
self.assertEqual(b"Bye", await g.readline())
|
||||
|
||||
async def test_iterator(self):
|
||||
f = AsyncGridIn(self.db.fs)
|
||||
await f.close()
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([], list(g))
|
||||
else:
|
||||
self.assertEqual([], await g.to_list())
|
||||
|
||||
f = AsyncGridIn(self.db.fs)
|
||||
await f.write(b"hello world\nhere are\nsome lines.")
|
||||
await f.close()
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
|
||||
else:
|
||||
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], await g.to_list())
|
||||
|
||||
self.assertEqual(b"", await g.read(5))
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([], list(g))
|
||||
else:
|
||||
self.assertEqual([], await g.to_list())
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"hello world\n", await anext(aiter(g)))
|
||||
self.assertEqual(b"here", await g.read(4))
|
||||
self.assertEqual(b" are\n", await anext(aiter(g)))
|
||||
self.assertEqual(b"some lines", await g.read(10))
|
||||
self.assertEqual(b".", await anext(aiter(g)))
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await aiter(g).__anext__()
|
||||
|
||||
f = AsyncGridIn(self.db.fs, chunk_size=2)
|
||||
await f.write(b"hello world")
|
||||
await f.close()
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([b"hello world"], list(g))
|
||||
else:
|
||||
self.assertEqual([b"hello world"], await g.to_list())
|
||||
|
||||
async def test_read_unaligned_buffer_size(self):
|
||||
in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk."
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=16)
|
||||
await f.write(in_data)
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
out_data = b""
|
||||
while 1:
|
||||
s = await g.read(13)
|
||||
if not s:
|
||||
break
|
||||
out_data += s
|
||||
|
||||
self.assertEqual(in_data, out_data)
|
||||
|
||||
async def test_readchunk(self):
|
||||
in_data = b"a" * 10
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=3)
|
||||
await f.write(in_data)
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(3, len(await g.readchunk()))
|
||||
|
||||
self.assertEqual(2, len(await g.read(2)))
|
||||
self.assertEqual(1, len(await g.readchunk()))
|
||||
|
||||
self.assertEqual(3, len(await g.read(3)))
|
||||
|
||||
self.assertEqual(1, len(await g.readchunk()))
|
||||
|
||||
self.assertEqual(0, len(await g.readchunk()))
|
||||
|
||||
async def test_write_unicode(self):
|
||||
f = AsyncGridIn(self.db.fs)
|
||||
with self.assertRaises(TypeError):
|
||||
await f.write("foo")
|
||||
|
||||
f = AsyncGridIn(self.db.fs, encoding="utf-8")
|
||||
await f.write("foo")
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"foo", await g.read())
|
||||
|
||||
f = AsyncGridIn(self.db.fs, encoding="iso-8859-1")
|
||||
await f.write("aé")
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
self.assertEqual("aé".encode("iso-8859-1"), await g.read())
|
||||
|
||||
async def test_set_after_close(self):
|
||||
f = AsyncGridIn(self.db.fs, _id="foo", bar="baz")
|
||||
|
||||
self.assertEqual("foo", f._id)
|
||||
self.assertEqual("baz", f.bar)
|
||||
self.assertRaises(AttributeError, getattr, f, "baz")
|
||||
self.assertRaises(AttributeError, getattr, f, "uploadDate")
|
||||
|
||||
self.assertRaises(AttributeError, setattr, f, "_id", 5)
|
||||
if _IS_SYNC:
|
||||
f.bar = "foo"
|
||||
f.baz = 5
|
||||
else:
|
||||
await f.set("bar", "foo")
|
||||
await f.set("baz", 5)
|
||||
|
||||
self.assertEqual("foo", f._id)
|
||||
self.assertEqual("foo", f.bar)
|
||||
self.assertEqual(5, f.baz)
|
||||
self.assertRaises(AttributeError, getattr, f, "uploadDate")
|
||||
|
||||
await f.close()
|
||||
|
||||
self.assertEqual("foo", f._id)
|
||||
self.assertEqual("foo", f.bar)
|
||||
self.assertEqual(5, f.baz)
|
||||
self.assertTrue(f.uploadDate)
|
||||
|
||||
self.assertRaises(AttributeError, setattr, f, "_id", 5)
|
||||
if _IS_SYNC:
|
||||
f.bar = "a"
|
||||
f.baz = "b"
|
||||
else:
|
||||
await f.set("bar", "a")
|
||||
await f.set("baz", "b")
|
||||
self.assertRaises(AttributeError, setattr, f, "upload_date", 5)
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
if not _IS_SYNC:
|
||||
await g.open()
|
||||
self.assertEqual("a", g.bar)
|
||||
self.assertEqual("b", g.baz)
|
||||
# Versions 2.0.1 and older saved a _closed field for some reason.
|
||||
self.assertRaises(AttributeError, getattr, g, "_closed")
|
||||
|
||||
async def test_context_manager(self):
|
||||
contents = b"Imagine this is some important data..."
|
||||
|
||||
async with AsyncGridIn(self.db.fs, filename="important") as infile:
|
||||
await infile.write(contents)
|
||||
|
||||
async with AsyncGridOut(self.db.fs, infile._id) as outfile:
|
||||
self.assertEqual(contents, await outfile.read())
|
||||
|
||||
async def test_exception_file_non_existence(self):
|
||||
contents = b"Imagine this is some important data..."
|
||||
|
||||
with self.assertRaises(ConnectionError):
|
||||
async with AsyncGridIn(self.db.fs, filename="important") as infile:
|
||||
await infile.write(contents)
|
||||
raise ConnectionError("Test exception")
|
||||
|
||||
# Expectation: File chunks are written, entry in files doesn't appear.
|
||||
self.assertEqual(
|
||||
await self.db.fs.chunks.count_documents({"files_id": infile._id}), infile._chunk_number
|
||||
)
|
||||
|
||||
self.assertIsNone(await self.db.fs.files.find_one({"_id": infile._id}))
|
||||
self.assertTrue(infile.closed)
|
||||
|
||||
async def test_prechunked_string(self):
|
||||
async def write_me(s, chunk_size):
|
||||
buf = BytesIO(s)
|
||||
infile = AsyncGridIn(self.db.fs)
|
||||
while True:
|
||||
to_write = buf.read(chunk_size)
|
||||
if to_write == b"":
|
||||
break
|
||||
await infile.write(to_write)
|
||||
await infile.close()
|
||||
buf.close()
|
||||
|
||||
outfile = AsyncGridOut(self.db.fs, infile._id)
|
||||
data = await outfile.read()
|
||||
self.assertEqual(s, data)
|
||||
|
||||
s = b"x" * DEFAULT_CHUNK_SIZE * 4
|
||||
# Test with default chunk size
|
||||
await write_me(s, DEFAULT_CHUNK_SIZE)
|
||||
# Multiple
|
||||
await write_me(s, DEFAULT_CHUNK_SIZE * 3)
|
||||
# Custom
|
||||
await write_me(s, 262300)
|
||||
|
||||
async def test_grid_out_lazy_connect(self):
|
||||
fs = self.db.fs
|
||||
outfile = AsyncGridOut(fs, file_id=-1)
|
||||
with self.assertRaises(NoFile):
|
||||
await outfile.read()
|
||||
with self.assertRaises(NoFile):
|
||||
if not _IS_SYNC:
|
||||
await outfile.open()
|
||||
outfile.filename
|
||||
|
||||
infile = AsyncGridIn(fs, filename=1)
|
||||
await infile.close()
|
||||
|
||||
outfile = AsyncGridOut(fs, infile._id)
|
||||
await outfile.read()
|
||||
outfile.filename
|
||||
|
||||
outfile = AsyncGridOut(fs, infile._id)
|
||||
await outfile.readchunk()
|
||||
|
||||
async def test_grid_in_lazy_connect(self):
|
||||
client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
fs = client.db.fs
|
||||
infile = AsyncGridIn(fs, file_id=-1, chunk_size=1)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await infile.write(b"data")
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await infile.close()
|
||||
|
||||
async def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs)
|
||||
|
||||
async def test_survive_cursor_not_found(self):
|
||||
# By default the find command returns 101 documents in the first batch.
|
||||
# Use 102 batches to cause a single getMore.
|
||||
chunk_size = 1024
|
||||
data = b"d" * (102 * chunk_size)
|
||||
listener = EventListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
db = client.pymongo_test
|
||||
async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile:
|
||||
await infile.write(data)
|
||||
|
||||
async with AsyncGridOut(db.fs, infile._id) as outfile:
|
||||
self.assertEqual(len(await outfile.readchunk()), chunk_size)
|
||||
|
||||
# Kill the cursor to simulate the cursor timing out on the server
|
||||
# when an application spends a long time between two calls to
|
||||
# readchunk().
|
||||
assert await client.address is not None
|
||||
await client._close_cursor_now(
|
||||
outfile._chunk_iter._cursor.cursor_id,
|
||||
_CursorAddress(await client.address, db.fs.chunks.full_name), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Read the rest of the file without error.
|
||||
self.assertEqual(len(await outfile.read()), len(data) - chunk_size)
|
||||
|
||||
# Paranoid, ensure that a getMore was actually sent.
|
||||
self.assertIn("getMore", listener.started_command_names())
|
||||
|
||||
@async_client_context.require_sync
|
||||
async def test_zip(self):
|
||||
zf = BytesIO()
|
||||
z = zipfile.ZipFile(zf, "w")
|
||||
z.writestr("test.txt", b"hello world")
|
||||
z.close()
|
||||
zf.seek(0)
|
||||
|
||||
f = AsyncGridIn(self.db.fs, filename="test.zip")
|
||||
await f.write(zf)
|
||||
await f.close()
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(1, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
z = zipfile.ZipFile(g)
|
||||
self.assertSequenceEqual(z.namelist(), ["test.txt"])
|
||||
self.assertEqual(z.read("test.txt"), b"hello world")
|
||||
|
||||
async def test_grid_out_unsupported_operations(self):
|
||||
f = AsyncGridIn(self.db.fs, chunkSize=3)
|
||||
await f.write(b"hello world")
|
||||
await f.close()
|
||||
|
||||
g = AsyncGridOut(self.db.fs, f._id)
|
||||
|
||||
self.assertRaises(io.UnsupportedOperation, g.writelines, [b"some", b"lines"])
|
||||
self.assertRaises(io.UnsupportedOperation, g.write, b"some text")
|
||||
self.assertRaises(io.UnsupportedOperation, g.fileno)
|
||||
self.assertRaises(io.UnsupportedOperation, g.truncate)
|
||||
|
||||
self.assertFalse(g.writable())
|
||||
self.assertFalse(g.isatty())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -35,23 +35,13 @@ try:
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from test.version import Version
|
||||
from typing import Any, Callable, Dict, Generator, no_type_check
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo import common, message
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
if HAVE_SSL:
|
||||
|
||||
@ -44,14 +44,16 @@ from pymongo.operations import *
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class BulkTestBase(IntegrationTest):
|
||||
coll: Collection
|
||||
coll_w0: Collection
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.coll = cls.db.test
|
||||
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
|
||||
|
||||
@ -135,7 +137,8 @@ class BulkTestBase(IntegrationTest):
|
||||
|
||||
class TestBulk(BulkTestBase):
|
||||
def test_empty(self):
|
||||
self.assertRaises(InvalidOperation, self.coll.bulk_write, [])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
self.coll.bulk_write([])
|
||||
|
||||
def test_insert(self):
|
||||
expected = {
|
||||
@ -180,15 +183,19 @@ class TestBulk(BulkTestBase):
|
||||
self._test_update_many([{"$set": {"foo": "bar"}}])
|
||||
|
||||
def test_array_filters_validation(self):
|
||||
self.assertRaises(TypeError, UpdateMany, {}, {}, array_filters={})
|
||||
self.assertRaises(TypeError, UpdateOne, {}, {}, array_filters={})
|
||||
with self.assertRaises(TypeError):
|
||||
UpdateMany({}, {}, array_filters={}) # type: ignore[arg-type]
|
||||
with self.assertRaises(TypeError):
|
||||
UpdateOne({}, {}, array_filters={}) # type: ignore[arg-type]
|
||||
|
||||
def test_array_filters_unacknowledged(self):
|
||||
coll = self.coll_w0
|
||||
update_one = UpdateOne({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}])
|
||||
update_many = UpdateMany({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}])
|
||||
self.assertRaises(ConfigurationError, coll.bulk_write, [update_one])
|
||||
self.assertRaises(ConfigurationError, coll.bulk_write, [update_many])
|
||||
with self.assertRaises(ConfigurationError):
|
||||
coll.bulk_write([update_one])
|
||||
with self.assertRaises(ConfigurationError):
|
||||
coll.bulk_write([update_many])
|
||||
|
||||
def _test_update_one(self, update):
|
||||
expected = {
|
||||
@ -790,8 +797,8 @@ class BulkAuthorizationTestBase(BulkTestBase):
|
||||
@classmethod
|
||||
@client_context.require_auth
|
||||
@client_context.require_no_api_version
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -828,8 +835,16 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
]
|
||||
result = self.coll_w0.bulk_write(requests)
|
||||
self.assertFalse(result.acknowledged)
|
||||
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
|
||||
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
|
||||
|
||||
def predicate():
|
||||
return self.coll.count_documents({}) == 2
|
||||
|
||||
wait_until(predicate, "insert 2 documents")
|
||||
|
||||
def predicate():
|
||||
return self.coll.find_one({"_id": 1}) is None
|
||||
|
||||
wait_until(predicate, 'removed {"_id": 1}')
|
||||
|
||||
def test_no_results_ordered_failure(self):
|
||||
requests: list = [
|
||||
@ -843,7 +858,11 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
]
|
||||
result = self.coll_w0.bulk_write(requests)
|
||||
self.assertFalse(result.acknowledged)
|
||||
wait_until(lambda: self.coll.count_documents({}) == 3, "insert 3 documents")
|
||||
|
||||
def predicate():
|
||||
return self.coll.count_documents({}) == 3
|
||||
|
||||
wait_until(predicate, "insert 3 documents")
|
||||
self.assertEqual({"_id": 1}, self.coll.find_one({"_id": 1}))
|
||||
|
||||
def test_no_results_unordered_success(self):
|
||||
@ -855,8 +874,16 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
]
|
||||
result = self.coll_w0.bulk_write(requests, ordered=False)
|
||||
self.assertFalse(result.acknowledged)
|
||||
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
|
||||
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
|
||||
|
||||
def predicate():
|
||||
return self.coll.count_documents({}) == 2
|
||||
|
||||
wait_until(predicate, "insert 2 documents")
|
||||
|
||||
def predicate():
|
||||
return self.coll.find_one({"_id": 1}) is None
|
||||
|
||||
wait_until(predicate, 'removed {"_id": 1}')
|
||||
|
||||
def test_no_results_unordered_failure(self):
|
||||
requests: list = [
|
||||
@ -870,8 +897,16 @@ class TestBulkUnacknowledged(BulkTestBase):
|
||||
]
|
||||
result = self.coll_w0.bulk_write(requests, ordered=False)
|
||||
self.assertFalse(result.acknowledged)
|
||||
wait_until(lambda: self.coll.count_documents({}) == 2, "insert 2 documents")
|
||||
wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}')
|
||||
|
||||
def predicate():
|
||||
return self.coll.count_documents({}) == 2
|
||||
|
||||
wait_until(predicate, "insert 2 documents")
|
||||
|
||||
def predicate():
|
||||
return self.coll.find_one({"_id": 1}) is None
|
||||
|
||||
wait_until(predicate, 'removed {"_id": 1}')
|
||||
|
||||
|
||||
class TestBulkAuthorization(BulkAuthorizationTestBase):
|
||||
@ -883,7 +918,8 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
|
||||
)
|
||||
coll = cli.pymongo_test.test
|
||||
coll.find_one()
|
||||
self.assertRaises(OperationFailure, coll.bulk_write, [InsertOne({"x": 1})])
|
||||
with self.assertRaises(OperationFailure):
|
||||
coll.bulk_write([InsertOne({"x": 1})])
|
||||
|
||||
def test_no_remove(self):
|
||||
# We test that an authorization failure aborts the batch and is raised
|
||||
@ -899,7 +935,8 @@ class TestBulkAuthorization(BulkAuthorizationTestBase):
|
||||
DeleteMany({}), # Prohibited.
|
||||
InsertOne({"x": 3}), # Never attempted.
|
||||
]
|
||||
self.assertRaises(OperationFailure, coll.bulk_write, requests)
|
||||
with self.assertRaises(OperationFailure):
|
||||
coll.bulk_write(requests) # type: ignore[arg-type]
|
||||
self.assertEqual({1, 2}, set(self.coll.distinct("x")))
|
||||
|
||||
|
||||
@ -908,18 +945,18 @@ class TestBulkWriteConcern(BulkTestBase):
|
||||
secondary: MongoClient
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.w = client_context.w
|
||||
cls.secondary = None
|
||||
if cls.w is not None and cls.w > 1:
|
||||
for member in client_context.hello["hosts"]:
|
||||
if member != client_context.hello["primary"]:
|
||||
for member in (client_context.hello)["hosts"]:
|
||||
if member != (client_context.hello)["primary"]:
|
||||
cls.secondary = single_client(*partition_node(member))
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def async_tearDownClass(cls):
|
||||
if cls.secondary:
|
||||
cls.secondary.close()
|
||||
|
||||
|
||||
@ -18,10 +18,12 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import SkipTest, client_context, unittest
|
||||
from test import SkipTest, UnitTest, client_context, unittest
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestClientContext(unittest.TestCase):
|
||||
class TestClientContext(UnitTest):
|
||||
def test_must_connect(self):
|
||||
if "PYMONGO_MUST_CONNECT" not in os.environ:
|
||||
raise SkipTest("PYMONGO_MUST_CONNECT is not set")
|
||||
|
||||
@ -21,6 +21,7 @@ import io
|
||||
import sys
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from test import IntegrationTest, UnitTest, client_context
|
||||
|
||||
from pymongo.synchronous.database import Database
|
||||
|
||||
@ -36,16 +37,20 @@ from gridfs.synchronous.grid_file import (
|
||||
_SEEK_CUR,
|
||||
_SEEK_END,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
GridFS,
|
||||
GridIn,
|
||||
GridOut,
|
||||
GridOutCursor,
|
||||
)
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, ServerSelectionTimeoutError
|
||||
from pymongo.message import _CursorAddress
|
||||
from pymongo.synchronous.helpers import iter, next
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestGridFileNoConnect(unittest.TestCase):
|
||||
class TestGridFileNoConnect(UnitTest):
|
||||
"""Test GridFile features on a client that does not connect."""
|
||||
|
||||
db: Database
|
||||
@ -151,6 +156,7 @@ class TestGridFile(IntegrationTest):
|
||||
|
||||
self.assertEqual(None, a.content_type)
|
||||
a.content_type = "text/html"
|
||||
|
||||
self.assertEqual("text/html", a.content_type)
|
||||
|
||||
self.assertRaises(AttributeError, getattr, a, "length")
|
||||
@ -164,17 +170,24 @@ class TestGridFile(IntegrationTest):
|
||||
|
||||
self.assertRaises(AttributeError, getattr, a, "aliases")
|
||||
a.aliases = ["foo"]
|
||||
|
||||
self.assertEqual(["foo"], a.aliases)
|
||||
|
||||
self.assertRaises(AttributeError, getattr, a, "metadata")
|
||||
a.metadata = {"foo": 1}
|
||||
|
||||
self.assertEqual({"foo": 1}, a.metadata)
|
||||
|
||||
self.assertRaises(AttributeError, setattr, a, "md5", 5)
|
||||
|
||||
a.close()
|
||||
|
||||
a.forty_two = 42
|
||||
if _IS_SYNC:
|
||||
a.forty_two = 42
|
||||
else:
|
||||
self.assertRaises(AttributeError, setattr, a, "forty_two", 42)
|
||||
a.set("forty_two", 42)
|
||||
|
||||
self.assertEqual(42, a.forty_two)
|
||||
|
||||
self.assertTrue(isinstance(a._id, ObjectId))
|
||||
@ -213,12 +226,16 @@ class TestGridFile(IntegrationTest):
|
||||
|
||||
gout = GridOut(self.db.fs, 5)
|
||||
with self.assertRaises(NoFile):
|
||||
if not _IS_SYNC:
|
||||
gout.open()
|
||||
gout.name
|
||||
|
||||
a = GridIn(self.db.fs)
|
||||
a.close()
|
||||
|
||||
b = GridOut(self.db.fs, a._id)
|
||||
if not _IS_SYNC:
|
||||
b.open()
|
||||
|
||||
self.assertEqual(a._id, b._id)
|
||||
self.assertEqual(0, b.length)
|
||||
@ -278,6 +295,9 @@ class TestGridFile(IntegrationTest):
|
||||
|
||||
two = GridOut(self.db.fs, 5)
|
||||
|
||||
if not _IS_SYNC:
|
||||
two.open()
|
||||
|
||||
self.assertEqual("my_file", two.name)
|
||||
self.assertEqual("my_file", two.filename)
|
||||
self.assertEqual(5, two._id)
|
||||
@ -316,6 +336,8 @@ class TestGridFile(IntegrationTest):
|
||||
|
||||
four = GridOut(self.db.fs, file_document={})
|
||||
with self.assertRaises(NoFile):
|
||||
if not _IS_SYNC:
|
||||
four.open()
|
||||
four.name
|
||||
|
||||
def test_write_file_like(self):
|
||||
@ -350,7 +372,8 @@ class TestGridFile(IntegrationTest):
|
||||
def test_close(self):
|
||||
f = GridIn(self.db.fs)
|
||||
f.close()
|
||||
self.assertRaises(ValueError, f.write, "test")
|
||||
with self.assertRaises(ValueError):
|
||||
f.write("test")
|
||||
f.close()
|
||||
|
||||
def test_closed(self):
|
||||
@ -359,6 +382,8 @@ class TestGridFile(IntegrationTest):
|
||||
f.close()
|
||||
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
if not _IS_SYNC:
|
||||
g.open()
|
||||
self.assertFalse(g.closed)
|
||||
g.read(1)
|
||||
self.assertFalse(g.closed)
|
||||
@ -380,6 +405,8 @@ class TestGridFile(IntegrationTest):
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
self.assertEqual(random_string, g.read())
|
||||
|
||||
# TODO: https://jira.mongodb.org/browse/PYTHON-4708
|
||||
@client_context.require_sync
|
||||
def test_small_chunks(self):
|
||||
self.files = 0
|
||||
self.chunks = 0
|
||||
@ -415,18 +442,21 @@ class TestGridFile(IntegrationTest):
|
||||
self.assertEqual(b"hello world", g.read())
|
||||
g.seek(1)
|
||||
self.assertEqual(b"ello world", g.read())
|
||||
self.assertRaises(IOError, g.seek, -1)
|
||||
with self.assertRaises(IOError):
|
||||
g.seek(-1)
|
||||
|
||||
g.seek(-3, _SEEK_END)
|
||||
self.assertEqual(b"rld", g.read())
|
||||
g.seek(0, _SEEK_END)
|
||||
self.assertEqual(b"", g.read())
|
||||
self.assertRaises(IOError, g.seek, -100, _SEEK_END)
|
||||
with self.assertRaises(IOError):
|
||||
g.seek(-100, _SEEK_END)
|
||||
|
||||
g.seek(3)
|
||||
g.seek(3, _SEEK_CUR)
|
||||
self.assertEqual(b"world", g.read())
|
||||
self.assertRaises(IOError, g.seek, -100, _SEEK_CUR)
|
||||
with self.assertRaises(IOError):
|
||||
g.seek(-100, _SEEK_CUR)
|
||||
|
||||
def test_tell(self):
|
||||
f = GridIn(self.db.fs, chunkSize=3)
|
||||
@ -519,12 +549,14 @@ Bye"""
|
||||
# Only readlines().
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
self.assertEqual(
|
||||
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines()
|
||||
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
|
||||
g.readlines(),
|
||||
)
|
||||
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
self.assertEqual(
|
||||
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines(0)
|
||||
[b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"],
|
||||
g.readlines(0),
|
||||
)
|
||||
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
@ -550,15 +582,25 @@ Bye"""
|
||||
f = GridIn(self.db.fs)
|
||||
f.close()
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
self.assertEqual([], list(g))
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([], list(g))
|
||||
else:
|
||||
self.assertEqual([], g.to_list())
|
||||
|
||||
f = GridIn(self.db.fs)
|
||||
f.write(b"hello world\nhere are\nsome lines.")
|
||||
f.close()
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
|
||||
else:
|
||||
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], g.to_list())
|
||||
|
||||
self.assertEqual(b"", g.read(5))
|
||||
self.assertEqual([], list(g))
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([], list(g))
|
||||
else:
|
||||
self.assertEqual([], g.to_list())
|
||||
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
self.assertEqual(b"hello world\n", next(iter(g)))
|
||||
@ -566,13 +608,17 @@ Bye"""
|
||||
self.assertEqual(b" are\n", next(iter(g)))
|
||||
self.assertEqual(b"some lines", g.read(10))
|
||||
self.assertEqual(b".", next(iter(g)))
|
||||
self.assertRaises(StopIteration, iter(g).__next__)
|
||||
with self.assertRaises(StopIteration):
|
||||
iter(g).__next__()
|
||||
|
||||
f = GridIn(self.db.fs, chunk_size=2)
|
||||
f.write(b"hello world")
|
||||
f.close()
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
self.assertEqual([b"hello world"], list(g))
|
||||
if _IS_SYNC:
|
||||
self.assertEqual([b"hello world"], list(g))
|
||||
else:
|
||||
self.assertEqual([b"hello world"], g.to_list())
|
||||
|
||||
def test_read_unaligned_buffer_size(self):
|
||||
in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk."
|
||||
@ -610,7 +656,8 @@ Bye"""
|
||||
|
||||
def test_write_unicode(self):
|
||||
f = GridIn(self.db.fs)
|
||||
self.assertRaises(TypeError, f.write, "foo")
|
||||
with self.assertRaises(TypeError):
|
||||
f.write("foo")
|
||||
|
||||
f = GridIn(self.db.fs, encoding="utf-8")
|
||||
f.write("foo")
|
||||
@ -635,8 +682,12 @@ Bye"""
|
||||
self.assertRaises(AttributeError, getattr, f, "uploadDate")
|
||||
|
||||
self.assertRaises(AttributeError, setattr, f, "_id", 5)
|
||||
f.bar = "foo"
|
||||
f.baz = 5
|
||||
if _IS_SYNC:
|
||||
f.bar = "foo"
|
||||
f.baz = 5
|
||||
else:
|
||||
f.set("bar", "foo")
|
||||
f.set("baz", 5)
|
||||
|
||||
self.assertEqual("foo", f._id)
|
||||
self.assertEqual("foo", f.bar)
|
||||
@ -651,11 +702,17 @@ Bye"""
|
||||
self.assertTrue(f.uploadDate)
|
||||
|
||||
self.assertRaises(AttributeError, setattr, f, "_id", 5)
|
||||
f.bar = "a"
|
||||
f.baz = "b"
|
||||
if _IS_SYNC:
|
||||
f.bar = "a"
|
||||
f.baz = "b"
|
||||
else:
|
||||
f.set("bar", "a")
|
||||
f.set("baz", "b")
|
||||
self.assertRaises(AttributeError, setattr, f, "upload_date", 5)
|
||||
|
||||
g = GridOut(self.db.fs, f._id)
|
||||
if not _IS_SYNC:
|
||||
g.open()
|
||||
self.assertEqual("a", g.bar)
|
||||
self.assertEqual("b", g.baz)
|
||||
# Versions 2.0.1 and older saved a _closed field for some reason.
|
||||
@ -713,8 +770,12 @@ Bye"""
|
||||
def test_grid_out_lazy_connect(self):
|
||||
fs = self.db.fs
|
||||
outfile = GridOut(fs, file_id=-1)
|
||||
self.assertRaises(NoFile, outfile.read)
|
||||
self.assertRaises(NoFile, getattr, outfile, "filename")
|
||||
with self.assertRaises(NoFile):
|
||||
outfile.read()
|
||||
with self.assertRaises(NoFile):
|
||||
if not _IS_SYNC:
|
||||
outfile.open()
|
||||
outfile.filename
|
||||
|
||||
infile = GridIn(fs, filename=1)
|
||||
infile.close()
|
||||
@ -730,13 +791,15 @@ Bye"""
|
||||
client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10)
|
||||
fs = client.db.fs
|
||||
infile = GridIn(fs, file_id=-1, chunk_size=1)
|
||||
self.assertRaises(ServerSelectionTimeoutError, infile.write, b"data")
|
||||
self.assertRaises(ServerSelectionTimeoutError, infile.close)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
infile.write(b"data")
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
infile.close()
|
||||
|
||||
def test_unacknowledged(self):
|
||||
# w=0 is prohibited.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
GridIn(rs_or_single_client(w=0).pymongo_test.fs)
|
||||
GridIn((rs_or_single_client(w=0)).pymongo_test.fs)
|
||||
|
||||
def test_survive_cursor_not_found(self):
|
||||
# By default the find command returns 101 documents in the first batch.
|
||||
@ -758,7 +821,7 @@ Bye"""
|
||||
assert client.address is not None
|
||||
client._close_cursor_now(
|
||||
outfile._chunk_iter._cursor.cursor_id,
|
||||
_CursorAddress(client.address, db.fs.chunks.full_name),
|
||||
_CursorAddress(client.address, db.fs.chunks.full_name), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Read the rest of the file without error.
|
||||
@ -767,6 +830,7 @@ Bye"""
|
||||
# Paranoid, ensure that a getMore was actually sent.
|
||||
self.assertIn("getMore", listener.started_command_names())
|
||||
|
||||
@client_context.require_sync
|
||||
def test_zip(self):
|
||||
zf = BytesIO()
|
||||
z = zipfile.ZipFile(zf, "w")
|
||||
|
||||
@ -1954,7 +1954,11 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
|
||||
if client.get("ignoreExtraMessages", False):
|
||||
actual_logs = actual_logs[: len(client["messages"])]
|
||||
self.assertEqual(len(client["messages"]), len(actual_logs))
|
||||
self.assertEqual(
|
||||
len(client["messages"]),
|
||||
len(actual_logs),
|
||||
f"expected {client['messages']} but got {actual_logs}",
|
||||
)
|
||||
for expected_msg, actual_msg in zip(client["messages"], actual_logs):
|
||||
expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data")
|
||||
|
||||
|
||||
141
tools/convert_test_to_async.py
Normal file
141
tools/convert_test_to_async.py
Normal file
@ -0,0 +1,141 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
|
||||
replacements = {
|
||||
"Collection": "AsyncCollection",
|
||||
"Database": "AsyncDatabase",
|
||||
"Cursor": "AsyncCursor",
|
||||
"MongoClient": "AsyncMongoClient",
|
||||
"CommandCursor": "AsyncCommandCursor",
|
||||
"RawBatchCursor": "AsyncRawBatchCursor",
|
||||
"RawBatchCommandCursor": "AsyncRawBatchCommandCursor",
|
||||
"ClientSession": "AsyncClientSession",
|
||||
"ChangeStream": "AsyncChangeStream",
|
||||
"CollectionChangeStream": "AsyncCollectionChangeStream",
|
||||
"DatabaseChangeStream": "AsyncDatabaseChangeStream",
|
||||
"ClusterChangeStream": "AsyncClusterChangeStream",
|
||||
"_Bulk": "_AsyncBulk",
|
||||
"_ClientBulk": "_AsyncClientBulk",
|
||||
"Connection": "AsyncConnection",
|
||||
"synchronous": "asynchronous",
|
||||
"Synchronous": "Asynchronous",
|
||||
"next": "await anext",
|
||||
"_Lock": "_ALock",
|
||||
"_Condition": "_ACondition",
|
||||
"GridFS": "AsyncGridFS",
|
||||
"GridFSBucket": "AsyncGridFSBucket",
|
||||
"GridIn": "AsyncGridIn",
|
||||
"GridOut": "AsyncGridOut",
|
||||
"GridOutCursor": "AsyncGridOutCursor",
|
||||
"GridOutIterator": "AsyncGridOutIterator",
|
||||
"GridOutChunkIterator": "_AsyncGridOutChunkIterator",
|
||||
"_grid_in_property": "_a_grid_in_property",
|
||||
"_grid_out_property": "_a_grid_out_property",
|
||||
"ClientEncryption": "AsyncClientEncryption",
|
||||
"MongoCryptCallback": "AsyncMongoCryptCallback",
|
||||
"ExplicitEncrypter": "AsyncExplicitEncrypter",
|
||||
"AutoEncrypter": "AsyncAutoEncrypter",
|
||||
"ContextManager": "AsyncContextManager",
|
||||
"ClientContext": "AsyncClientContext",
|
||||
"TestCollection": "AsyncTestCollection",
|
||||
"IntegrationTest": "AsyncIntegrationTest",
|
||||
"PyMongoTestCase": "AsyncPyMongoTestCase",
|
||||
"MockClientTest": "AsyncMockClientTest",
|
||||
"client_context": "async_client_context",
|
||||
"setUp": "asyncSetUp",
|
||||
"tearDown": "asyncTearDown",
|
||||
"wait_until": "await async_wait_until",
|
||||
"addCleanup": "addAsyncCleanup",
|
||||
"TestCase": "IsolatedAsyncioTestCase",
|
||||
"UnitTest": "AsyncUnitTest",
|
||||
"MockClient": "AsyncMockClient",
|
||||
"SpecRunner": "AsyncSpecRunner",
|
||||
"TransactionsBase": "AsyncTransactionsBase",
|
||||
"get_pool": "await async_get_pool",
|
||||
"is_mongos": "await async_is_mongos",
|
||||
"rs_or_single_client": "await async_rs_or_single_client",
|
||||
"rs_or_single_client_noauth": "await async_rs_or_single_client_noauth",
|
||||
"rs_client": "await async_rs_client",
|
||||
"single_client": "await async_single_client",
|
||||
"from_client": "await async_from_client",
|
||||
"closing": "aclosing",
|
||||
"assertRaisesExactly": "asyncAssertRaisesExactly",
|
||||
"get_mock_client": "await get_async_mock_client",
|
||||
"close": "await aclose",
|
||||
}
|
||||
|
||||
async_classes = [AsyncMongoClient, AsyncDatabase, AsyncCollection, AsyncCursor, AsyncCommandCursor]
|
||||
|
||||
|
||||
def get_async_methods() -> set[str]:
|
||||
result: set[str] = set()
|
||||
for x in async_classes:
|
||||
methods = {
|
||||
k
|
||||
for k, v in vars(x).items()
|
||||
if callable(v)
|
||||
and not isinstance(v, classmethod)
|
||||
and asyncio.iscoroutinefunction(v)
|
||||
and v.__name__[0] != "_"
|
||||
}
|
||||
result = result | methods
|
||||
return result
|
||||
|
||||
|
||||
async_methods = get_async_methods()
|
||||
|
||||
|
||||
def apply_replacements(lines: list[str]) -> list[str]:
|
||||
for i in range(len(lines)):
|
||||
if "_IS_SYNC = True" in lines[i]:
|
||||
lines[i] = "_IS_SYNC = False"
|
||||
if "def test" in lines[i]:
|
||||
lines[i] = lines[i].replace("def test", "async def test")
|
||||
for k in replacements:
|
||||
if k in lines[i]:
|
||||
lines[i] = lines[i].replace(k, replacements[k])
|
||||
for k in async_methods:
|
||||
if k + "(" in lines[i]:
|
||||
tokens = lines[i].split(" ")
|
||||
for j in range(len(tokens)):
|
||||
if k + "(" in tokens[j]:
|
||||
if j < 2:
|
||||
tokens.insert(0, "await")
|
||||
else:
|
||||
tokens.insert(j, "await")
|
||||
break
|
||||
new_line = " ".join(tokens)
|
||||
|
||||
lines[i] = new_line
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def process_file(input_file: str, output_file: str) -> None:
|
||||
with open(input_file, "r+") as f:
|
||||
lines = f.readlines()
|
||||
lines = apply_replacements(lines)
|
||||
|
||||
with open(output_file, "w+") as f2:
|
||||
f2.seek(0)
|
||||
f2.writelines(lines)
|
||||
f2.truncate()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = sys.argv[1:]
|
||||
sync_file = "./test/" + args[0]
|
||||
async_file = "./" + args[0]
|
||||
|
||||
process_file(sync_file, async_file)
|
||||
|
||||
|
||||
main()
|
||||
@ -46,7 +46,10 @@ replacements = {
|
||||
"async_sendall": "sendall",
|
||||
"asynchronous": "synchronous",
|
||||
"Asynchronous": "Synchronous",
|
||||
"AsyncBulkTestBase": "BulkTestBase",
|
||||
"AsyncBulkAuthorizationTestBase": "BulkAuthorizationTestBase",
|
||||
"anext": "next",
|
||||
"aiter": "iter",
|
||||
"_ALock": "_Lock",
|
||||
"_ACondition": "_Condition",
|
||||
"AsyncGridFS": "GridFS",
|
||||
@ -98,6 +101,8 @@ replacements = {
|
||||
"default_async": "default",
|
||||
"aclose": "close",
|
||||
"PyMongo|async": "PyMongo",
|
||||
"AsyncTestGridFile": "TestGridFile",
|
||||
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
|
||||
}
|
||||
|
||||
docstring_replacements: dict[tuple[str, str], str] = {
|
||||
@ -154,15 +159,18 @@ converted_tests = [
|
||||
"conftest.py",
|
||||
"pymongo_mocks.py",
|
||||
"utils_spec_runner.py",
|
||||
"test_bulk.py",
|
||||
"test_client.py",
|
||||
"test_client_bulk_write.py",
|
||||
"test_collection.py",
|
||||
"test_cursor.py",
|
||||
"test_database.py",
|
||||
"test_encryption.py",
|
||||
"test_grid_file.py",
|
||||
"test_logger.py",
|
||||
"test_session.py",
|
||||
"test_transactions.py",
|
||||
"test_client_context.py",
|
||||
]
|
||||
|
||||
sync_test_files = [
|
||||
@ -294,6 +302,8 @@ def translate_docstrings(lines: list[str]) -> list[str]:
|
||||
lines[i] = lines[i].replace(k, replacements[k])
|
||||
if "Sync" in lines[i] and "Synchronous" not in lines[i] and replacements[k] in lines[i]:
|
||||
lines[i] = lines[i].replace("Sync", "")
|
||||
if "rsApplyStop" in lines[i]:
|
||||
lines[i] = lines[i].replace("rsApplyStop", "rsSyncApplyStop")
|
||||
if "async for" in lines[i] or "async with" in lines[i] or "async def" in lines[i]:
|
||||
lines[i] = lines[i].replace("async ", "")
|
||||
if "await " in lines[i] and "tailable" not in lines[i]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user