PYTHON-5257 - Turn on mypy disallow_any_generics (#2456)
This commit is contained in:
parent
d7074ba9ee
commit
bbb6f88fae
@ -844,7 +844,7 @@ def _encode_binary(data: bytes, subtype: int, json_options: JSONOptions) -> Any:
|
||||
return {"$binary": {"base64": base64.b64encode(data).decode(), "subType": "%02x" % subtype}}
|
||||
|
||||
|
||||
def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
|
||||
def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
if (
|
||||
json_options.datetime_representation == DatetimeRepresentation.ISO8601
|
||||
and 0 <= int(obj) <= _MAX_UTC_MS
|
||||
@ -855,7 +855,7 @@ def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
|
||||
return {"$date": {"$numberLong": str(int(obj))}}
|
||||
|
||||
|
||||
def _encode_code(obj: Code, json_options: JSONOptions) -> dict:
|
||||
def _encode_code(obj: Code, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
if obj.scope is None:
|
||||
return {"$code": str(obj)}
|
||||
else:
|
||||
@ -873,7 +873,7 @@ def _encode_noop(obj: Any, dummy0: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def _encode_regex(obj: Any, json_options: JSONOptions) -> dict:
|
||||
def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
flags = ""
|
||||
if obj.flags & re.IGNORECASE:
|
||||
flags += "i"
|
||||
@ -918,7 +918,7 @@ def _encode_float(obj: float, json_options: JSONOptions) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
|
||||
def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
if json_options.datetime_representation == DatetimeRepresentation.ISO8601:
|
||||
if not obj.tzinfo:
|
||||
obj = obj.replace(tzinfo=utc)
|
||||
@ -941,15 +941,15 @@ def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
|
||||
return {"$date": {"$numberLong": str(millis)}}
|
||||
|
||||
|
||||
def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict:
|
||||
def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
return _encode_binary(obj, 0, json_options)
|
||||
|
||||
|
||||
def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict:
|
||||
def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
return _encode_binary(obj, obj.subtype, json_options)
|
||||
|
||||
|
||||
def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict:
|
||||
def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
if json_options.strict_uuid:
|
||||
binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation)
|
||||
return _encode_binary(binval, binval.subtype, json_options)
|
||||
@ -957,27 +957,27 @@ def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict:
|
||||
return {"$uuid": obj.hex}
|
||||
|
||||
|
||||
def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict:
|
||||
def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: # type: ignore[type-arg]
|
||||
return {"$oid": str(obj)}
|
||||
|
||||
|
||||
def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict:
|
||||
def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
|
||||
return {"$timestamp": {"t": obj.time, "i": obj.inc}}
|
||||
|
||||
|
||||
def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict:
|
||||
def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
|
||||
return {"$numberDecimal": str(obj)}
|
||||
|
||||
|
||||
def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict:
|
||||
def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
|
||||
return _json_convert(obj.as_doc(), json_options=json_options)
|
||||
|
||||
|
||||
def _encode_minkey(dummy0: Any, dummy1: Any) -> dict:
|
||||
def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
|
||||
return {"$minKey": 1}
|
||||
|
||||
|
||||
def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict:
|
||||
def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
|
||||
return {"$maxKey": 1}
|
||||
|
||||
|
||||
@ -985,7 +985,7 @@ def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict:
|
||||
# Each encoder function's signature is:
|
||||
# - obj: a Python data type, e.g. a Python int for _encode_int
|
||||
# - json_options: a JSONOptions
|
||||
_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = {
|
||||
_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { # type: ignore[type-arg]
|
||||
bool: _encode_noop,
|
||||
bytes: _encode_bytes,
|
||||
datetime.datetime: _encode_datetime,
|
||||
@ -1056,7 +1056,7 @@ def _get_datetime_size(obj: datetime.datetime) -> int:
|
||||
return 5 + len(str(obj.time()))
|
||||
|
||||
|
||||
def _get_regex_size(obj: Regex) -> int:
|
||||
def _get_regex_size(obj: Regex) -> int: # type: ignore[type-arg]
|
||||
return 18 + len(obj.pattern)
|
||||
|
||||
|
||||
|
||||
@ -28,4 +28,4 @@ if TYPE_CHECKING:
|
||||
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
|
||||
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
|
||||
_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any])
|
||||
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"]
|
||||
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] # type: ignore[type-arg]
|
||||
|
||||
@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[AsyncClientSession]) -> None:
|
||||
class AsyncGridFS:
|
||||
"""An instance of GridFS on top of a single Database."""
|
||||
|
||||
def __init__(self, database: AsyncDatabase, collection: str = "fs"):
|
||||
def __init__(self, database: AsyncDatabase[Any], collection: str = "fs"):
|
||||
"""Create a new instance of :class:`GridFS`.
|
||||
|
||||
Raises :class:`TypeError` if `database` is not an instance of
|
||||
@ -463,7 +463,7 @@ class AsyncGridFSBucket:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: AsyncDatabase,
|
||||
db: AsyncDatabase[Any],
|
||||
bucket_name: str = "fs",
|
||||
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
@ -513,11 +513,11 @@ class AsyncGridFSBucket:
|
||||
|
||||
self._bucket_name = bucket_name
|
||||
self._collection = db[bucket_name]
|
||||
self._chunks: AsyncCollection = self._collection.chunks.with_options(
|
||||
self._chunks: AsyncCollection[Any] = self._collection.chunks.with_options(
|
||||
write_concern=write_concern, read_preference=read_preference
|
||||
)
|
||||
|
||||
self._files: AsyncCollection = self._collection.files.with_options(
|
||||
self._files: AsyncCollection[Any] = self._collection.files.with_options(
|
||||
write_concern=write_concern, read_preference=read_preference
|
||||
)
|
||||
|
||||
@ -1085,7 +1085,7 @@ class AsyncGridIn:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_collection: AsyncCollection,
|
||||
root_collection: AsyncCollection[Any],
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -1172,7 +1172,7 @@ class AsyncGridIn:
|
||||
object.__setattr__(self, "_buffered_docs_size", 0)
|
||||
|
||||
async def _create_index(
|
||||
self, collection: AsyncCollection, index_key: Any, unique: bool
|
||||
self, collection: AsyncCollection[Any], index_key: Any, unique: bool
|
||||
) -> None:
|
||||
doc = await collection.find_one(projection={"_id": 1}, session=self._session)
|
||||
if doc is None:
|
||||
@ -1456,7 +1456,7 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_collection: AsyncCollection,
|
||||
root_collection: AsyncCollection[Any],
|
||||
file_id: Optional[int] = None,
|
||||
file_document: Optional[Any] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
@ -1829,7 +1829,7 @@ class _AsyncGridOutChunkIterator:
|
||||
def __init__(
|
||||
self,
|
||||
grid_out: AsyncGridOut,
|
||||
chunks: AsyncCollection,
|
||||
chunks: AsyncCollection[Any],
|
||||
session: Optional[AsyncClientSession],
|
||||
next_chunk: Any,
|
||||
) -> None:
|
||||
@ -1842,7 +1842,7 @@ class _AsyncGridOutChunkIterator:
|
||||
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
|
||||
self._cursor = None
|
||||
|
||||
_cursor: Optional[AsyncCursor]
|
||||
_cursor: Optional[AsyncCursor[Any]]
|
||||
|
||||
def expected_chunk_length(self, chunk_n: int) -> int:
|
||||
if chunk_n < self._num_chunks - 1:
|
||||
@ -1921,7 +1921,7 @@ class _AsyncGridOutChunkIterator:
|
||||
|
||||
class AsyncGridOutIterator:
|
||||
def __init__(
|
||||
self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession
|
||||
self, grid_out: AsyncGridOut, chunks: AsyncCollection[Any], session: AsyncClientSession
|
||||
):
|
||||
self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0)
|
||||
|
||||
@ -1935,14 +1935,14 @@ class AsyncGridOutIterator:
|
||||
__anext__ = next
|
||||
|
||||
|
||||
class AsyncGridOutCursor(AsyncCursor):
|
||||
class AsyncGridOutCursor(AsyncCursor): # type: ignore[type-arg]
|
||||
"""A cursor / iterator for returning GridOut objects as the result
|
||||
of an arbitrary query against the GridFS files collection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection: AsyncCollection,
|
||||
collection: AsyncCollection[Any],
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 0,
|
||||
|
||||
@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[ClientSession]) -> None:
|
||||
class GridFS:
|
||||
"""An instance of GridFS on top of a single Database."""
|
||||
|
||||
def __init__(self, database: Database, collection: str = "fs"):
|
||||
def __init__(self, database: Database[Any], collection: str = "fs"):
|
||||
"""Create a new instance of :class:`GridFS`.
|
||||
|
||||
Raises :class:`TypeError` if `database` is not an instance of
|
||||
@ -461,7 +461,7 @@ class GridFSBucket:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Database,
|
||||
db: Database[Any],
|
||||
bucket_name: str = "fs",
|
||||
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
@ -511,11 +511,11 @@ class GridFSBucket:
|
||||
|
||||
self._bucket_name = bucket_name
|
||||
self._collection = db[bucket_name]
|
||||
self._chunks: Collection = self._collection.chunks.with_options(
|
||||
self._chunks: Collection[Any] = self._collection.chunks.with_options(
|
||||
write_concern=write_concern, read_preference=read_preference
|
||||
)
|
||||
|
||||
self._files: Collection = self._collection.files.with_options(
|
||||
self._files: Collection[Any] = self._collection.files.with_options(
|
||||
write_concern=write_concern, read_preference=read_preference
|
||||
)
|
||||
|
||||
@ -1077,7 +1077,7 @@ class GridIn:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_collection: Collection,
|
||||
root_collection: Collection[Any],
|
||||
session: Optional[ClientSession] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -1163,7 +1163,7 @@ class GridIn:
|
||||
object.__setattr__(self, "_buffered_docs", [])
|
||||
object.__setattr__(self, "_buffered_docs_size", 0)
|
||||
|
||||
def _create_index(self, collection: Collection, index_key: Any, unique: bool) -> None:
|
||||
def _create_index(self, collection: Collection[Any], index_key: Any, unique: bool) -> None:
|
||||
doc = collection.find_one(projection={"_id": 1}, session=self._session)
|
||||
if doc is None:
|
||||
try:
|
||||
@ -1444,7 +1444,7 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_collection: Collection,
|
||||
root_collection: Collection[Any],
|
||||
file_id: Optional[int] = None,
|
||||
file_document: Optional[Any] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
@ -1817,7 +1817,7 @@ class GridOutChunkIterator:
|
||||
def __init__(
|
||||
self,
|
||||
grid_out: GridOut,
|
||||
chunks: Collection,
|
||||
chunks: Collection[Any],
|
||||
session: Optional[ClientSession],
|
||||
next_chunk: Any,
|
||||
) -> None:
|
||||
@ -1830,7 +1830,7 @@ class GridOutChunkIterator:
|
||||
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
|
||||
self._cursor = None
|
||||
|
||||
_cursor: Optional[Cursor]
|
||||
_cursor: Optional[Cursor[Any]]
|
||||
|
||||
def expected_chunk_length(self, chunk_n: int) -> int:
|
||||
if chunk_n < self._num_chunks - 1:
|
||||
@ -1908,7 +1908,7 @@ class GridOutChunkIterator:
|
||||
|
||||
|
||||
class GridOutIterator:
|
||||
def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
|
||||
def __init__(self, grid_out: GridOut, chunks: Collection[Any], session: ClientSession):
|
||||
self._chunk_iter = GridOutChunkIterator(grid_out, chunks, session, 0)
|
||||
|
||||
def __iter__(self) -> GridOutIterator:
|
||||
@ -1921,14 +1921,14 @@ class GridOutIterator:
|
||||
__next__ = next
|
||||
|
||||
|
||||
class GridOutCursor(Cursor):
|
||||
class GridOutCursor(Cursor): # type: ignore[type-arg]
|
||||
"""A cursor / iterator for returning GridOut objects as the result
|
||||
of an arbitrary query against the GridFS files collection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection: Collection,
|
||||
collection: Collection[Any],
|
||||
filter: Optional[Mapping[str, Any]] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 0,
|
||||
|
||||
@ -93,7 +93,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._waiters: Optional[collections.deque] = None
|
||||
self._waiters: Optional[collections.deque[Any]] = None
|
||||
self._locked = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -196,7 +196,7 @@ class Condition(_ContextManagerMixin, _LoopBoundMixin):
|
||||
self.acquire = lock.acquire
|
||||
self.release = lock.release
|
||||
|
||||
self._waiters: collections.deque = collections.deque()
|
||||
self._waiters: collections.deque[Any] = collections.deque()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
res = super().__repr__()
|
||||
@ -260,7 +260,7 @@ class Condition(_ContextManagerMixin, _LoopBoundMixin):
|
||||
self._notify(1)
|
||||
raise
|
||||
|
||||
async def wait_for(self, predicate: Any) -> Coroutine:
|
||||
async def wait_for(self, predicate: Any) -> Coroutine[Any, Any, Any]:
|
||||
"""Wait until a predicate becomes true.
|
||||
|
||||
The predicate should be a callable whose result will be
|
||||
|
||||
@ -24,7 +24,7 @@ from typing import Any, Coroutine, Optional
|
||||
|
||||
|
||||
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
|
||||
class _Task(asyncio.Task):
|
||||
class _Task(asyncio.Task[Any]):
|
||||
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
|
||||
super().__init__(coro, name=name)
|
||||
self._cancel_requests = 0
|
||||
@ -43,7 +43,7 @@ class _Task(asyncio.Task):
|
||||
return self._cancel_requests
|
||||
|
||||
|
||||
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
|
||||
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task[Any]:
|
||||
if sys.version_info >= (3, 11):
|
||||
return asyncio.create_task(coro, name=name)
|
||||
return _Task(coro, name=name)
|
||||
|
||||
@ -68,7 +68,7 @@ def clamp_remaining(max_timeout: float) -> float:
|
||||
return min(timeout, max_timeout)
|
||||
|
||||
|
||||
class _TimeoutContext(AbstractContextManager):
|
||||
class _TimeoutContext(AbstractContextManager[Any]):
|
||||
"""Internal timeout context manager.
|
||||
|
||||
Use :func:`pymongo.timeout` instead::
|
||||
|
||||
@ -46,8 +46,8 @@ class _AggregationCommand:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[AsyncDatabase, AsyncCollection],
|
||||
cursor_class: type[AsyncCommandCursor],
|
||||
target: Union[AsyncDatabase[Any], AsyncCollection[Any]],
|
||||
cursor_class: type[AsyncCommandCursor[Any]],
|
||||
pipeline: _Pipeline,
|
||||
options: MutableMapping[str, Any],
|
||||
explicit_session: bool,
|
||||
@ -111,12 +111,12 @@ class _AggregationCommand:
|
||||
"""The namespace in which the aggregate command is run."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection:
|
||||
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection[Any]:
|
||||
"""The AsyncCollection used for the aggregate command cursor."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _database(self) -> AsyncDatabase:
|
||||
def _database(self) -> AsyncDatabase[Any]:
|
||||
"""The database against which the aggregation command is run."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -205,7 +205,7 @@ class _AggregationCommand:
|
||||
|
||||
|
||||
class _CollectionAggregationCommand(_AggregationCommand):
|
||||
_target: AsyncCollection
|
||||
_target: AsyncCollection[Any]
|
||||
|
||||
@property
|
||||
def _aggregation_target(self) -> str:
|
||||
@ -215,12 +215,12 @@ class _CollectionAggregationCommand(_AggregationCommand):
|
||||
def _cursor_namespace(self) -> str:
|
||||
return self._target.full_name
|
||||
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]:
|
||||
"""The AsyncCollection used for the aggregate command cursor."""
|
||||
return self._target
|
||||
|
||||
@property
|
||||
def _database(self) -> AsyncDatabase:
|
||||
def _database(self) -> AsyncDatabase[Any]:
|
||||
return self._target.database
|
||||
|
||||
|
||||
@ -234,7 +234,7 @@ class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
|
||||
|
||||
|
||||
class _DatabaseAggregationCommand(_AggregationCommand):
|
||||
_target: AsyncDatabase
|
||||
_target: AsyncDatabase[Any]
|
||||
|
||||
@property
|
||||
def _aggregation_target(self) -> int:
|
||||
@ -245,10 +245,10 @@ class _DatabaseAggregationCommand(_AggregationCommand):
|
||||
return f"{self._target.name}.$cmd.aggregate"
|
||||
|
||||
@property
|
||||
def _database(self) -> AsyncDatabase:
|
||||
def _database(self) -> AsyncDatabase[Any]:
|
||||
return self._target
|
||||
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]:
|
||||
"""The AsyncCollection used for the aggregate command cursor."""
|
||||
# AsyncCollection level aggregate may not always return the "ns" field
|
||||
# according to our MockupDB tests. Let's handle that case for db level
|
||||
|
||||
@ -259,7 +259,7 @@ class _OIDCAuthenticator:
|
||||
) -> Mapping[str, Any]:
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
start_payload: dict = bson.decode(start_resp["payload"])
|
||||
start_payload: dict[str, Any] = bson.decode(start_resp["payload"])
|
||||
if "issuer" in start_payload:
|
||||
self.idp_info = OIDCIdPInfo(**start_payload)
|
||||
access_token = await self._get_access_token()
|
||||
|
||||
@ -248,7 +248,7 @@ class _AsyncBulk:
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
cmd[bwc.field] = docs
|
||||
@ -334,7 +334,7 @@ class _AsyncBulk:
|
||||
msg: bytes,
|
||||
max_doc_size: int,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
@ -419,7 +419,7 @@ class _AsyncBulk:
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
) -> list[Mapping[str, Any]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
@ -446,7 +446,7 @@ class _AsyncBulk:
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
|
||||
@ -164,7 +164,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _client(self) -> AsyncMongoClient:
|
||||
def _client(self) -> AsyncMongoClient: # type: ignore[type-arg]
|
||||
"""The client against which the aggregation commands for
|
||||
this AsyncChangeStream will be run.
|
||||
"""
|
||||
@ -206,7 +206,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
|
||||
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
|
||||
"""Return the full aggregation pipeline for this AsyncChangeStream."""
|
||||
options = self._change_stream_options()
|
||||
full_pipeline: list = [{"$changeStream": options}]
|
||||
full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}]
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
@ -237,7 +237,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
|
||||
|
||||
async def _run_aggregation_cmd(
|
||||
self, session: Optional[AsyncClientSession], explicit_session: bool
|
||||
) -> AsyncCommandCursor:
|
||||
) -> AsyncCommandCursor: # type: ignore[type-arg]
|
||||
"""Run the full aggregation pipeline for this AsyncChangeStream and return
|
||||
the corresponding AsyncCommandCursor.
|
||||
"""
|
||||
@ -257,7 +257,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
|
||||
operation=_Op.AGGREGATE,
|
||||
)
|
||||
|
||||
async def _create_cursor(self) -> AsyncCommandCursor:
|
||||
async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
|
||||
async with self._client._tmp_session(self._session, close=False) as s:
|
||||
return await self._run_aggregation_cmd(
|
||||
session=s, explicit_session=self._session is not None
|
||||
|
||||
@ -88,7 +88,7 @@ class _AsyncClientBulk:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
write_concern: WriteConcern,
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
@ -233,7 +233,7 @@ class _AsyncClientBulk:
|
||||
msg: Union[bytes, dict[str, Any]],
|
||||
op_docs: list[Mapping[str, Any]],
|
||||
ns_docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for AsyncConnection.write_command that handles event publishing."""
|
||||
cmd["ops"] = op_docs
|
||||
@ -324,7 +324,7 @@ class _AsyncClientBulk:
|
||||
msg: bytes,
|
||||
op_docs: list[Mapping[str, Any]],
|
||||
ns_docs: list[Mapping[str, Any]],
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
|
||||
@ -396,7 +396,7 @@ class _TxnState:
|
||||
class _Transaction:
|
||||
"""Internal class to hold transaction information in a AsyncClientSession."""
|
||||
|
||||
def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient):
|
||||
def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[Any]):
|
||||
self.opts = opts
|
||||
self.state = _TxnState.NONE
|
||||
self.sharded = False
|
||||
@ -459,7 +459,7 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
|
||||
|
||||
# From the transactions spec, all the retryable writes errors plus
|
||||
# WriteConcernTimeout.
|
||||
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset(
|
||||
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # type: ignore[type-arg]
|
||||
[
|
||||
64, # WriteConcernTimeout
|
||||
50, # MaxTimeMSExpired
|
||||
@ -499,13 +499,13 @@ class AsyncClientSession:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
server_session: Any,
|
||||
options: SessionOptions,
|
||||
implicit: bool,
|
||||
) -> None:
|
||||
# An AsyncMongoClient, a _ServerSession, a SessionOptions, and a set.
|
||||
self._client: AsyncMongoClient = client
|
||||
self._client: AsyncMongoClient[Any] = client
|
||||
self._server_session = server_session
|
||||
self._options = options
|
||||
self._cluster_time: Optional[Mapping[str, Any]] = None
|
||||
@ -551,7 +551,7 @@ class AsyncClientSession:
|
||||
await self._end_session(lock=True)
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncMongoClient:
|
||||
def client(self) -> AsyncMongoClient[Any]:
|
||||
"""The :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` this session was
|
||||
created from.
|
||||
"""
|
||||
@ -751,7 +751,7 @@ class AsyncClientSession:
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
max_commit_time_ms: Optional[int] = None,
|
||||
) -> AsyncContextManager:
|
||||
) -> AsyncContextManager[Any]:
|
||||
"""Start a multi-statement transaction.
|
||||
|
||||
Takes the same arguments as :class:`TransactionOptions`.
|
||||
@ -1123,7 +1123,7 @@ class _ServerSession:
|
||||
self._transaction_id += 1
|
||||
|
||||
|
||||
class _ServerSessionPool(collections.deque):
|
||||
class _ServerSessionPool(collections.deque): # type: ignore[type-arg]
|
||||
"""Pool of _ServerSession objects.
|
||||
|
||||
This class is thread-safe.
|
||||
|
||||
@ -581,7 +581,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
conn: AsyncConnection,
|
||||
command: MutableMapping[str, Any],
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
codec_options: Optional[CodecOptions[Mapping[str, Any]]] = None,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
@ -704,7 +704,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
) -> BulkWriteResult:
|
||||
"""Send a batch of write operations to the server.
|
||||
|
||||
@ -2525,7 +2525,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
|
||||
codec_options: CodecOptions = CodecOptions(SON)
|
||||
codec_options: CodecOptions[Mapping[str, Any]] = CodecOptions(SON)
|
||||
coll = cast(
|
||||
AsyncCollection[MutableMapping[str, Any]],
|
||||
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
|
||||
@ -2871,7 +2871,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
aggregation_command: Type[_AggregationCommand],
|
||||
pipeline: _Pipeline,
|
||||
cursor_class: Type[AsyncCommandCursor],
|
||||
cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg]
|
||||
session: Optional[AsyncClientSession],
|
||||
explicit_session: bool,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
@ -3114,7 +3114,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
comment: Optional[Any] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
**kwargs: Any,
|
||||
) -> list:
|
||||
) -> list[str]:
|
||||
"""Get a list of distinct values for `key` among all documents
|
||||
in this collection.
|
||||
|
||||
@ -3177,7 +3177,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
_server: Server,
|
||||
conn: AsyncConnection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> list:
|
||||
) -> list: # type: ignore[type-arg]
|
||||
return (
|
||||
await self._command(
|
||||
conn,
|
||||
@ -3202,7 +3202,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Internal findAndModify helper."""
|
||||
|
||||
@ -350,7 +350,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
|
||||
"""Get all or some available documents from the cursor."""
|
||||
if not len(self._data) and not self._killed:
|
||||
await self._refresh()
|
||||
@ -457,7 +457,7 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
|
||||
self,
|
||||
response: Union[_OpReply, _OpMsg],
|
||||
cursor_id: Optional[int],
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions[dict[str, Any]],
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
|
||||
@ -216,7 +216,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
# it anytime we change __limit.
|
||||
self._empty = False
|
||||
|
||||
self._data: deque = deque()
|
||||
self._data: deque = deque() # type: ignore[type-arg]
|
||||
self._address: Optional[_Address] = None
|
||||
self._retrieved = 0
|
||||
|
||||
@ -280,7 +280,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
"""
|
||||
return self._clone(True)
|
||||
|
||||
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor:
|
||||
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg]
|
||||
"""Internal clone helper."""
|
||||
if not base:
|
||||
if self._explicit_session:
|
||||
@ -322,7 +322,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
base.__dict__.update(data)
|
||||
return base
|
||||
|
||||
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor:
|
||||
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor: # type: ignore[type-arg]
|
||||
"""Creates an empty AsyncCursor object for information to be copied into."""
|
||||
return self.__class__(self._collection, session=session)
|
||||
|
||||
@ -864,7 +864,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
if self._has_filter:
|
||||
spec = dict(self._spec)
|
||||
else:
|
||||
spec = cast(dict, self._spec)
|
||||
spec = cast(dict, self._spec) # type: ignore[type-arg]
|
||||
spec["$where"] = code
|
||||
self._spec = spec
|
||||
return self
|
||||
@ -888,7 +888,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
self,
|
||||
response: Union[_OpReply, _OpMsg],
|
||||
cursor_id: Optional[int],
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions, # type: ignore[type-arg]
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> Sequence[_DocumentOut]:
|
||||
@ -964,29 +964,33 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
return self._clone(deepcopy=True)
|
||||
|
||||
@overload
|
||||
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list:
|
||||
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: # type: ignore[type-arg]
|
||||
...
|
||||
|
||||
@overload
|
||||
def _deepcopy(
|
||||
self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None
|
||||
) -> dict:
|
||||
self,
|
||||
x: SupportsItems, # type: ignore[type-arg]
|
||||
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
|
||||
) -> dict: # type: ignore[type-arg]
|
||||
...
|
||||
|
||||
def _deepcopy(
|
||||
self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None
|
||||
) -> Union[list, dict]:
|
||||
self,
|
||||
x: Union[Iterable, SupportsItems], # type: ignore[type-arg]
|
||||
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
|
||||
) -> Union[list[Any], dict[str, Any]]:
|
||||
"""Deepcopy helper for the data dictionary or list.
|
||||
|
||||
Regular expressions cannot be deep copied but as they are immutable we
|
||||
don't have to copy them when cloning.
|
||||
"""
|
||||
y: Union[list, dict]
|
||||
y: Union[list[Any], dict[str, Any]]
|
||||
iterator: Iterable[tuple[Any, Any]]
|
||||
if not hasattr(x, "items"):
|
||||
y, is_list, iterator = [], True, enumerate(x)
|
||||
else:
|
||||
y, is_list, iterator = {}, False, cast("SupportsItems", x).items()
|
||||
y, is_list, iterator = {}, False, cast("SupportsItems", x).items() # type: ignore[type-arg]
|
||||
if memo is None:
|
||||
memo = {}
|
||||
val_id = id(x)
|
||||
@ -1060,7 +1064,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
"""Explicitly close / kill this cursor."""
|
||||
await self._die_lock()
|
||||
|
||||
async def distinct(self, key: str) -> list:
|
||||
async def distinct(self, key: str) -> list[str]:
|
||||
"""Get a list of distinct values for `key` among all documents
|
||||
in the result set of this query.
|
||||
|
||||
@ -1265,7 +1269,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
|
||||
"""Get all or some documents from the cursor."""
|
||||
if not self._exhaust_checked:
|
||||
self._exhaust_checked = True
|
||||
@ -1325,7 +1329,7 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
return res
|
||||
|
||||
|
||||
class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]):
|
||||
class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]): # type: ignore[type-arg]
|
||||
"""An asynchronous cursor / iterator over raw batches of BSON data from a query result."""
|
||||
|
||||
_query_class = _RawBatchQuery
|
||||
|
||||
@ -771,7 +771,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
|
||||
self._name,
|
||||
command,
|
||||
read_preference,
|
||||
codec_options,
|
||||
codec_options, # type: ignore[arg-type]
|
||||
check,
|
||||
allowable_errors,
|
||||
write_concern=write_concern,
|
||||
|
||||
@ -161,10 +161,10 @@ _ReadCall = Callable[
|
||||
_IS_SYNC = False
|
||||
|
||||
_WriteOp = Union[
|
||||
InsertOne,
|
||||
InsertOne, # type: ignore[type-arg]
|
||||
DeleteOne,
|
||||
DeleteMany,
|
||||
ReplaceOne,
|
||||
ReplaceOne, # type: ignore[type-arg]
|
||||
UpdateOne,
|
||||
UpdateMany,
|
||||
]
|
||||
@ -176,7 +176,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Define order to retrieve options from ClientOptions for __repr__.
|
||||
# No host/port; these are retrieved from TopologySettings.
|
||||
_constructor_args = ("document_class", "tz_aware", "connect")
|
||||
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
||||
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() # type: ignore[type-arg]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -847,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
self._default_database_name = dbase
|
||||
self._lock = _async_create_lock()
|
||||
self._kill_cursors_queue: list = []
|
||||
self._kill_cursors_queue: list = [] # type: ignore[type-arg]
|
||||
|
||||
self._encrypter: Optional[_Encrypter] = None
|
||||
|
||||
@ -1064,7 +1064,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Reset the session pool to avoid duplicate sessions in the child process.
|
||||
self._topology._session_pool.reset()
|
||||
|
||||
def _duplicate(self, **kwargs: Any) -> AsyncMongoClient:
|
||||
def _duplicate(self, **kwargs: Any) -> AsyncMongoClient: # type: ignore[type-arg]
|
||||
args = self._init_kwargs.copy()
|
||||
args.update(kwargs)
|
||||
return AsyncMongoClient(**args)
|
||||
@ -1548,7 +1548,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self, name, codec_options, read_preference, write_concern, read_concern
|
||||
)
|
||||
|
||||
def _database_default_options(self, name: str) -> database.AsyncDatabase:
|
||||
def _database_default_options(self, name: str) -> database.AsyncDatabase: # type: ignore[type-arg]
|
||||
"""Get a AsyncDatabase instance with the default settings."""
|
||||
return self.get_database(
|
||||
name,
|
||||
@ -1887,7 +1887,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _run_operation(
|
||||
self,
|
||||
operation: Union[_Query, _GetMore],
|
||||
unpack_res: Callable,
|
||||
unpack_res: Callable, # type: ignore[type-arg]
|
||||
address: Optional[_Address] = None,
|
||||
) -> Response:
|
||||
"""Run a _Query/_GetMore operation and return a Response.
|
||||
@ -2261,7 +2261,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
@contextlib.asynccontextmanager
|
||||
async def _tmp_session(
|
||||
self, session: Optional[client_session.AsyncClientSession], close: bool = True
|
||||
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]:
|
||||
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session is not None:
|
||||
if not isinstance(session, client_session.AsyncClientSession):
|
||||
@ -2308,8 +2308,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
return cast(
|
||||
dict,
|
||||
return cast( # type: ignore[redundant-cast]
|
||||
dict[str, Any],
|
||||
await self.admin.command(
|
||||
"buildinfo", read_preference=ReadPreference.PRIMARY, session=session
|
||||
),
|
||||
@ -2438,13 +2438,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
@_csot.apply
|
||||
async def bulk_write(
|
||||
self,
|
||||
models: Sequence[_WriteOp[_DocumentType]],
|
||||
models: Sequence[_WriteOp],
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
ordered: bool = True,
|
||||
verbose_results: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
comment: Optional[Any] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
) -> ClientBulkWriteResult:
|
||||
"""Send a batch of write operations, potentially across multiple namespaces, to the server.
|
||||
@ -2631,7 +2631,10 @@ class _MongoClientErrorHandler:
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
|
||||
self,
|
||||
client: AsyncMongoClient, # type: ignore[type-arg]
|
||||
server: Server,
|
||||
session: Optional[AsyncClientSession],
|
||||
):
|
||||
if not isinstance(client, AsyncMongoClient):
|
||||
# This is for compatibility with mocked and subclassed types, such as in Motor.
|
||||
@ -2704,7 +2707,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mongo_client: AsyncMongoClient,
|
||||
mongo_client: AsyncMongoClient, # type: ignore[type-arg]
|
||||
func: _WriteCall[T] | _ReadCall[T],
|
||||
bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]],
|
||||
operation: str,
|
||||
|
||||
@ -351,7 +351,7 @@ class Monitor(MonitorBase):
|
||||
)
|
||||
return sd
|
||||
|
||||
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]:
|
||||
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]: # type: ignore[type-arg]
|
||||
"""Return (Hello, round_trip_time).
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
|
||||
@ -66,7 +66,7 @@ async def command(
|
||||
read_preference: Optional[_ServerMode],
|
||||
codec_options: CodecOptions[_DocumentType],
|
||||
session: Optional[AsyncClientSession],
|
||||
client: Optional[AsyncMongoClient],
|
||||
client: Optional[AsyncMongoClient[Any]],
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
address: Optional[_Address] = None,
|
||||
|
||||
@ -201,7 +201,7 @@ class AsyncConnection:
|
||||
self.conn.get_conn.settimeout(timeout)
|
||||
|
||||
def apply_timeout(
|
||||
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
|
||||
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
|
||||
) -> Optional[float]:
|
||||
# CSOT: use remaining timeout when set.
|
||||
timeout = _csot.remaining()
|
||||
@ -255,7 +255,7 @@ class AsyncConnection:
|
||||
else:
|
||||
return {HelloCompat.LEGACY_CMD: 1, "helloOk": True}
|
||||
|
||||
async def hello(self) -> Hello:
|
||||
async def hello(self) -> Hello[dict[str, Any]]:
|
||||
return await self._hello(None, None)
|
||||
|
||||
async def _hello(
|
||||
@ -357,7 +357,7 @@ class AsyncConnection:
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
read_preference: _ServerMode = ReadPreference.PRIMARY,
|
||||
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
|
||||
codec_options: CodecOptions[Mapping[str, Any]] = DEFAULT_CODEC_OPTIONS, # type: ignore[assignment]
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
@ -365,7 +365,7 @@ class AsyncConnection:
|
||||
parse_write_concern_error: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
client: Optional[AsyncMongoClient] = None,
|
||||
client: Optional[AsyncMongoClient[Any]] = None,
|
||||
retryable_write: bool = False,
|
||||
publish_events: bool = True,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
@ -417,7 +417,7 @@ class AsyncConnection:
|
||||
spec,
|
||||
self.is_mongos,
|
||||
read_preference,
|
||||
codec_options,
|
||||
codec_options, # type: ignore[arg-type]
|
||||
session,
|
||||
client,
|
||||
check,
|
||||
@ -489,7 +489,7 @@ class AsyncConnection:
|
||||
await self.send_message(msg, max_doc_size)
|
||||
|
||||
async def write_command(
|
||||
self, request_id: int, msg: bytes, codec_options: CodecOptions
|
||||
self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Send "insert" etc. command, returning response as a dict.
|
||||
|
||||
@ -541,7 +541,7 @@ class AsyncConnection:
|
||||
)
|
||||
|
||||
def validate_session(
|
||||
self, client: Optional[AsyncMongoClient], session: Optional[AsyncClientSession]
|
||||
self, client: Optional[AsyncMongoClient[Any]], session: Optional[AsyncClientSession]
|
||||
) -> None:
|
||||
"""Validate this session before use with client.
|
||||
|
||||
@ -598,7 +598,7 @@ class AsyncConnection:
|
||||
self,
|
||||
command: MutableMapping[str, Any],
|
||||
session: Optional[AsyncClientSession],
|
||||
client: Optional[AsyncMongoClient],
|
||||
client: Optional[AsyncMongoClient[Any]],
|
||||
) -> None:
|
||||
"""Add $clusterTime."""
|
||||
if client:
|
||||
@ -732,7 +732,7 @@ class Pool:
|
||||
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
|
||||
# and returned to pool from the left side. Stale sockets removed
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.conns: collections.deque[AsyncConnection] = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
self.lock = _async_create_lock()
|
||||
self._max_connecting_cond = _async_create_condition(self.lock)
|
||||
@ -839,8 +839,8 @@ class Pool:
|
||||
if service_id is None:
|
||||
sockets, self.conns = self.conns, collections.deque()
|
||||
else:
|
||||
discard: collections.deque = collections.deque()
|
||||
keep: collections.deque = collections.deque()
|
||||
discard: collections.deque = collections.deque() # type: ignore[type-arg]
|
||||
keep: collections.deque = collections.deque() # type: ignore[type-arg]
|
||||
for conn in self.conns:
|
||||
if conn.service_id == service_id:
|
||||
discard.append(conn)
|
||||
@ -866,7 +866,7 @@ class Pool:
|
||||
if close:
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
|
||||
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
@ -903,7 +903,7 @@ class Pool:
|
||||
)
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
|
||||
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
@ -917,7 +917,7 @@ class Pool:
|
||||
self.is_writable = is_writable
|
||||
async with self.lock:
|
||||
for _socket in self.conns:
|
||||
_socket.update_is_writable(self.is_writable)
|
||||
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]
|
||||
|
||||
async def reset(
|
||||
self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False
|
||||
@ -956,7 +956,7 @@ class Pool:
|
||||
close_conns.append(self.conns.pop())
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
|
||||
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
@ -1477,4 +1477,4 @@ class Pool:
|
||||
# not safe to acquire a lock in __del__.
|
||||
if _IS_SYNC:
|
||||
for conn in self.conns:
|
||||
conn.close_conn(None)
|
||||
conn.close_conn(None) # type: ignore[unused-coroutine]
|
||||
|
||||
@ -66,7 +66,7 @@ class Server:
|
||||
monitor: Monitor,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
listeners: Optional[_EventListeners] = None,
|
||||
events: Optional[ReferenceType[Queue]] = None,
|
||||
events: Optional[ReferenceType[Queue[Any]]] = None,
|
||||
) -> None:
|
||||
"""Represent one MongoDB server."""
|
||||
self._description = server_description
|
||||
@ -142,7 +142,7 @@ class Server:
|
||||
read_preference: _ServerMode,
|
||||
listeners: Optional[_EventListeners],
|
||||
unpack_res: Callable[..., list[_DocumentOut]],
|
||||
client: AsyncMongoClient,
|
||||
client: AsyncMongoClient[Any],
|
||||
) -> Response:
|
||||
"""Run a _Query or _GetMore operation and return a Response object.
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ _IS_SYNC = False
|
||||
_pymongo_dir = str(Path(__file__).parent)
|
||||
|
||||
|
||||
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
|
||||
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: # type: ignore[type-arg]
|
||||
q = queue_ref()
|
||||
if not q:
|
||||
return False # Cancel PeriodicExecutor.
|
||||
@ -186,7 +186,7 @@ class Topology:
|
||||
|
||||
if self._publish_server or self._publish_tp:
|
||||
assert self._events is not None
|
||||
weak: weakref.ReferenceType[queue.Queue]
|
||||
weak: weakref.ReferenceType[queue.Queue[Any]]
|
||||
|
||||
async def target() -> bool:
|
||||
return process_events_queue(weak)
|
||||
|
||||
@ -247,7 +247,7 @@ class ClientOptions:
|
||||
return self.__connect
|
||||
|
||||
@property
|
||||
def codec_options(self) -> CodecOptions:
|
||||
def codec_options(self) -> CodecOptions[Any]:
|
||||
"""A :class:`~bson.codec_options.CodecOptions` instance."""
|
||||
return self.__codec_options
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.typings import _AgnosticClientSession
|
||||
|
||||
|
||||
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
|
||||
ORDERED_TYPES: Sequence[Type[Any]] = (SON, OrderedDict)
|
||||
|
||||
# Defaults until we connect to a server and get updated limits.
|
||||
MAX_BSON_SIZE = 16 * (1024**2)
|
||||
@ -166,7 +166,7 @@ def clean_node(node: str) -> tuple[str, int]:
|
||||
return host.lower(), port
|
||||
|
||||
|
||||
def raise_config_error(key: str, suggestions: Optional[list] = None) -> NoReturn:
|
||||
def raise_config_error(key: str, suggestions: Optional[list[str]] = None) -> NoReturn:
|
||||
"""Raise ConfigurationError with the given key name."""
|
||||
msg = f"Unknown option: {key}."
|
||||
if suggestions:
|
||||
@ -411,7 +411,7 @@ def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
|
||||
tag_sets: list = []
|
||||
tag_sets: list[dict[str, Any]] = []
|
||||
for tag_set in value:
|
||||
if tag_set == "":
|
||||
tag_sets.append({})
|
||||
@ -497,7 +497,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
|
||||
|
||||
def validate_document_class(
|
||||
option: str, value: Any
|
||||
) -> Union[Type[MutableMapping], Type[RawBSONDocument]]:
|
||||
) -> Union[Type[MutableMapping[str, Any]], Type[RawBSONDocument]]:
|
||||
"""Validate the document_class option."""
|
||||
# issubclass can raise TypeError for generic aliases like SON[str, Any].
|
||||
# In that case we can use the base class for the comparison.
|
||||
@ -523,14 +523,14 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
|
||||
return value
|
||||
|
||||
|
||||
def validate_list(option: str, value: Any) -> list:
|
||||
def validate_list(option: str, value: Any) -> list[Any]:
|
||||
"""Validates that 'value' is a list."""
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"{option} must be a list, not {type(value)}")
|
||||
return value
|
||||
|
||||
|
||||
def validate_list_or_none(option: Any, value: Any) -> Optional[list]:
|
||||
def validate_list_or_none(option: Any, value: Any) -> Optional[list[Any]]:
|
||||
"""Validates that 'value' is a list or None."""
|
||||
if value is None:
|
||||
return value
|
||||
@ -597,7 +597,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]:
|
||||
return value
|
||||
|
||||
|
||||
def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]:
|
||||
def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable[..., Any]]:
|
||||
"""Validates that 'value' is a callable."""
|
||||
if value is None:
|
||||
return value
|
||||
@ -829,7 +829,7 @@ def validate_auth_option(option: str, value: Any) -> tuple[str, Any]:
|
||||
|
||||
def _get_validator(
|
||||
key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None
|
||||
) -> Callable:
|
||||
) -> Callable[[Any, Any], Any]:
|
||||
normed_key = normed_key or key
|
||||
try:
|
||||
return validators[normed_key]
|
||||
@ -917,7 +917,7 @@ class BaseObject:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions[Any],
|
||||
read_preference: _ServerMode,
|
||||
write_concern: WriteConcern,
|
||||
read_concern: ReadConcern,
|
||||
@ -947,7 +947,7 @@ class BaseObject:
|
||||
self._read_concern = read_concern
|
||||
|
||||
@property
|
||||
def codec_options(self) -> CodecOptions:
|
||||
def codec_options(self) -> CodecOptions[Any]:
|
||||
"""Read only access to the :class:`~bson.codec_options.CodecOptions`
|
||||
of this instance.
|
||||
"""
|
||||
|
||||
@ -37,7 +37,7 @@ from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
|
||||
from pymongo.typings import _AgnosticMongoClient
|
||||
|
||||
|
||||
class AutoEncryptionOpts:
|
||||
@ -47,7 +47,7 @@ class AutoEncryptionOpts:
|
||||
self,
|
||||
kms_providers: Mapping[str, Any],
|
||||
key_vault_namespace: str,
|
||||
key_vault_client: Optional[_AgnosticMongoClient[_DocumentTypeArg]] = None,
|
||||
key_vault_client: Optional[_AgnosticMongoClient] = None,
|
||||
schema_map: Optional[Mapping[str, Any]] = None,
|
||||
bypass_auto_encryption: bool = False,
|
||||
mongocryptd_uri: str = "mongodb://localhost:27020",
|
||||
|
||||
@ -52,7 +52,7 @@ if TYPE_CHECKING:
|
||||
|
||||
# From the SDAM spec, the "node is shutting down" codes.
|
||||
|
||||
_SHUTDOWN_CODES: frozenset = frozenset(
|
||||
_SHUTDOWN_CODES: frozenset[int] = frozenset(
|
||||
[
|
||||
11600, # InterruptedAtShutdown
|
||||
91, # ShutdownInProgress
|
||||
@ -61,7 +61,7 @@ _SHUTDOWN_CODES: frozenset = frozenset(
|
||||
# From the SDAM spec, the "not primary" error codes are combined with the
|
||||
# "node is recovering" error codes (of which the "node is shutting down"
|
||||
# errors are a subset).
|
||||
_NOT_PRIMARY_CODES: frozenset = (
|
||||
_NOT_PRIMARY_CODES: frozenset[int] = (
|
||||
frozenset(
|
||||
[
|
||||
10058, # LegacyNotPrimary <=3.2 "not primary" error code
|
||||
@ -75,7 +75,7 @@ _NOT_PRIMARY_CODES: frozenset = (
|
||||
| _SHUTDOWN_CODES
|
||||
)
|
||||
# From the retryable writes spec.
|
||||
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
|
||||
_RETRYABLE_ERROR_CODES: frozenset[int] = _NOT_PRIMARY_CODES | frozenset(
|
||||
[
|
||||
7, # HostNotFound
|
||||
6, # HostUnreachable
|
||||
@ -95,7 +95,7 @@ _AUTHENTICATION_FAILURE_CODE: int = 18
|
||||
# Note - to avoid bugs from forgetting which if these is all lowercase and
|
||||
# which are camelCase, and at the same time avoid having to add a test for
|
||||
# every command, use all lowercase here and test against command_name.lower().
|
||||
_SENSITIVE_COMMANDS: set = {
|
||||
_SENSITIVE_COMMANDS: set[str] = {
|
||||
"authenticate",
|
||||
"saslstart",
|
||||
"saslcontinue",
|
||||
|
||||
@ -333,7 +333,7 @@ def _op_msg_no_header(
|
||||
command: Mapping[str, Any],
|
||||
identifier: str,
|
||||
docs: Optional[list[Mapping[str, Any]]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
) -> tuple[bytes, int, int]:
|
||||
"""Get a OP_MSG message.
|
||||
|
||||
@ -365,7 +365,7 @@ def _op_msg_compressed(
|
||||
command: Mapping[str, Any],
|
||||
identifier: str,
|
||||
docs: Optional[list[Mapping[str, Any]]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
|
||||
) -> tuple[int, bytes, int, int]:
|
||||
"""Internal OP_MSG message helper."""
|
||||
@ -379,7 +379,7 @@ def _op_msg_uncompressed(
|
||||
command: Mapping[str, Any],
|
||||
identifier: str,
|
||||
docs: Optional[list[Mapping[str, Any]]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
) -> tuple[int, bytes, int, int]:
|
||||
"""Internal compressed OP_MSG message helper."""
|
||||
data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
|
||||
@ -396,7 +396,7 @@ def _op_msg(
|
||||
command: MutableMapping[str, Any],
|
||||
dbname: str,
|
||||
read_preference: Optional[_ServerMode],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
|
||||
) -> tuple[int, bytes, int, int]:
|
||||
"""Get a OP_MSG message."""
|
||||
@ -430,7 +430,7 @@ def _query_impl(
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
) -> tuple[bytes, int]:
|
||||
"""Get an OP_QUERY message."""
|
||||
encoded = _dict_to_bson(query, False, opts)
|
||||
@ -461,7 +461,7 @@ def _query_compressed(
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
|
||||
) -> tuple[int, bytes, int]:
|
||||
"""Internal compressed query message helper."""
|
||||
@ -479,7 +479,7 @@ def _query_uncompressed(
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
) -> tuple[int, bytes, int]:
|
||||
"""Internal query message helper."""
|
||||
op_query, max_bson_size = _query_impl(
|
||||
@ -500,7 +500,7 @@ def _query(
|
||||
num_to_return: int,
|
||||
query: Mapping[str, Any],
|
||||
field_selector: Optional[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
|
||||
) -> tuple[int, bytes, int]:
|
||||
"""Get a **query** message."""
|
||||
@ -598,7 +598,7 @@ class _BulkWriteContextBase:
|
||||
listeners: _EventListeners,
|
||||
session: Optional[_AgnosticClientSession],
|
||||
op_type: int,
|
||||
codec: CodecOptions,
|
||||
codec: CodecOptions[Any],
|
||||
):
|
||||
self.db_name = database_name
|
||||
self.conn = conn
|
||||
@ -679,7 +679,7 @@ class _BulkWriteContext(_BulkWriteContextBase):
|
||||
listeners: _EventListeners,
|
||||
session: Optional[_AgnosticClientSession],
|
||||
op_type: int,
|
||||
codec: CodecOptions,
|
||||
codec: CodecOptions[Any],
|
||||
):
|
||||
super().__init__(
|
||||
database_name,
|
||||
@ -771,7 +771,7 @@ def _batched_op_msg_impl(
|
||||
command: Mapping[str, Any],
|
||||
docs: list[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _BulkWriteContext,
|
||||
buf: _BytesIO,
|
||||
) -> tuple[list[Mapping[str, Any]], int]:
|
||||
@ -839,7 +839,7 @@ def _encode_batched_op_msg(
|
||||
command: Mapping[str, Any],
|
||||
docs: list[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _BulkWriteContext,
|
||||
) -> tuple[bytes, list[Mapping[str, Any]]]:
|
||||
"""Encode the next batched insert, update, or delete operation
|
||||
@ -860,7 +860,7 @@ def _batched_op_msg_compressed(
|
||||
command: Mapping[str, Any],
|
||||
docs: list[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _BulkWriteContext,
|
||||
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
|
||||
"""Create the next batched insert, update, or delete operation
|
||||
@ -878,7 +878,7 @@ def _batched_op_msg(
|
||||
command: Mapping[str, Any],
|
||||
docs: list[Mapping[str, Any]],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _BulkWriteContext,
|
||||
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
|
||||
"""OP_MSG implementation entry point."""
|
||||
@ -910,7 +910,7 @@ def _do_batched_op_msg(
|
||||
operation: int,
|
||||
command: MutableMapping[str, Any],
|
||||
docs: list[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _BulkWriteContext,
|
||||
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
|
||||
"""Create the next batched insert, update, or delete operation
|
||||
@ -939,7 +939,7 @@ class _ClientBulkWriteContext(_BulkWriteContextBase):
|
||||
operation_id: int,
|
||||
listeners: _EventListeners,
|
||||
session: Optional[_AgnosticClientSession],
|
||||
codec: CodecOptions,
|
||||
codec: CodecOptions[Any],
|
||||
):
|
||||
super().__init__(
|
||||
database_name,
|
||||
@ -1043,7 +1043,7 @@ def _client_batched_op_msg_impl(
|
||||
operations: list[tuple[str, Mapping[str, Any]]],
|
||||
namespaces: list[str],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _ClientBulkWriteContext,
|
||||
buf: _BytesIO,
|
||||
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]], int]:
|
||||
@ -1161,7 +1161,7 @@ def _client_encode_batched_op_msg(
|
||||
operations: list[tuple[str, Mapping[str, Any]]],
|
||||
namespaces: list[str],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _ClientBulkWriteContext,
|
||||
) -> tuple[bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
|
||||
"""Encode the next batched client-level bulkWrite
|
||||
@ -1180,7 +1180,7 @@ def _client_batched_op_msg_compressed(
|
||||
operations: list[tuple[str, Mapping[str, Any]]],
|
||||
namespaces: list[str],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _ClientBulkWriteContext,
|
||||
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
|
||||
"""Create the next batched client-level bulkWrite operation
|
||||
@ -1200,7 +1200,7 @@ def _client_batched_op_msg(
|
||||
operations: list[tuple[str, Mapping[str, Any]]],
|
||||
namespaces: list[str],
|
||||
ack: bool,
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _ClientBulkWriteContext,
|
||||
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
|
||||
"""OP_MSG implementation entry point for client-level bulkWrite."""
|
||||
@ -1229,7 +1229,7 @@ def _client_do_batched_op_msg(
|
||||
command: MutableMapping[str, Any],
|
||||
operations: list[tuple[str, Mapping[str, Any]]],
|
||||
namespaces: list[str],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _ClientBulkWriteContext,
|
||||
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
|
||||
"""Create the next batched client-level bulkWrite
|
||||
@ -1253,7 +1253,7 @@ def _encode_batched_write_command(
|
||||
operation: int,
|
||||
command: MutableMapping[str, Any],
|
||||
docs: list[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _BulkWriteContext,
|
||||
) -> tuple[bytes, list[Mapping[str, Any]]]:
|
||||
"""Encode the next batched insert, update, or delete command."""
|
||||
@ -1272,7 +1272,7 @@ def _batched_write_command_impl(
|
||||
operation: int,
|
||||
command: MutableMapping[str, Any],
|
||||
docs: list[Mapping[str, Any]],
|
||||
opts: CodecOptions,
|
||||
opts: CodecOptions[Any],
|
||||
ctx: _BulkWriteContext,
|
||||
buf: _BytesIO,
|
||||
) -> tuple[list[Mapping[str, Any]], int]:
|
||||
@ -1383,7 +1383,7 @@ class _OpReply:
|
||||
errobj = {"ok": 0, "errmsg": msg, "code": 43}
|
||||
raise CursorNotFound(msg, 43, errobj)
|
||||
elif self.flags & 2:
|
||||
error_object: dict = bson.BSON(self.documents).decode()
|
||||
error_object: dict[str, Any] = bson.BSON(self.documents).decode()
|
||||
# Fake the ok field if it doesn't exist.
|
||||
error_object.setdefault("ok", 0)
|
||||
if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
|
||||
@ -1405,7 +1405,7 @@ class _OpReply:
|
||||
def unpack_response(
|
||||
self,
|
||||
cursor_id: Optional[int] = None,
|
||||
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
codec_options: CodecOptions[Any] = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
@ -1431,7 +1431,7 @@ class _OpReply:
|
||||
return bson.decode_all(self.documents, codec_options)
|
||||
return bson._decode_all_selective(self.documents, codec_options, user_fields)
|
||||
|
||||
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
|
||||
def command_response(self, codec_options: CodecOptions[Any]) -> dict[str, Any]:
|
||||
"""Unpack a command response."""
|
||||
docs = self.unpack_response(codec_options=codec_options)
|
||||
assert self.number_returned == 1
|
||||
@ -1491,7 +1491,7 @@ class _OpMsg:
|
||||
def unpack_response(
|
||||
self,
|
||||
cursor_id: Optional[int] = None,
|
||||
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
codec_options: CodecOptions[Any] = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
@ -1508,7 +1508,7 @@ class _OpMsg:
|
||||
assert not legacy_response
|
||||
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
|
||||
|
||||
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
|
||||
def command_response(self, codec_options: CodecOptions[Any]) -> dict[str, Any]:
|
||||
"""Unpack a command response."""
|
||||
return self.unpack_response(codec_options=codec_options)[0]
|
||||
|
||||
@ -1583,7 +1583,7 @@ class _Query:
|
||||
ntoskip: int,
|
||||
spec: Mapping[str, Any],
|
||||
fields: Optional[Mapping[str, Any]],
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions[Any],
|
||||
read_preference: _ServerMode,
|
||||
limit: int,
|
||||
batch_size: int,
|
||||
@ -1757,7 +1757,7 @@ class _GetMore:
|
||||
coll: str,
|
||||
ntoreturn: int,
|
||||
cursor_id: int,
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions[Any],
|
||||
read_preference: _ServerMode,
|
||||
session: Optional[_AgnosticClientSession],
|
||||
client: _AgnosticMongoClient,
|
||||
@ -1871,7 +1871,7 @@ class _RawBatchGetMore(_GetMore):
|
||||
return False
|
||||
|
||||
|
||||
class _CursorAddress(tuple):
|
||||
class _CursorAddress(tuple[Any, ...]):
|
||||
"""The server address (host, port) of a cursor, with namespace property."""
|
||||
|
||||
__namespace: Any
|
||||
|
||||
@ -1347,7 +1347,11 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
|
||||
__slots__ = ("__duration", "__reply")
|
||||
|
||||
def __init__(
|
||||
self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False
|
||||
self,
|
||||
duration: float,
|
||||
reply: Hello[dict[str, Any]],
|
||||
connection_id: _Address,
|
||||
awaited: bool = False,
|
||||
) -> None:
|
||||
super().__init__(connection_id, awaited)
|
||||
self.__duration = duration
|
||||
@ -1359,7 +1363,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
|
||||
return self.__duration
|
||||
|
||||
@property
|
||||
def reply(self) -> Hello:
|
||||
def reply(self) -> Hello[dict[str, Any]]:
|
||||
"""An instance of :class:`~pymongo.hello.Hello`."""
|
||||
return self.__reply
|
||||
|
||||
@ -1647,7 +1651,7 @@ class _EventListeners:
|
||||
_handle_exception()
|
||||
|
||||
def publish_server_heartbeat_succeeded(
|
||||
self, connection_id: _Address, duration: float, reply: Hello, awaited: bool
|
||||
self, connection_id: _Address, duration: float, reply: Hello[dict[str, Any]], awaited: bool
|
||||
) -> None:
|
||||
"""Publish a ServerHeartbeatSucceededEvent to all server heartbeat
|
||||
listeners.
|
||||
|
||||
@ -96,7 +96,7 @@ if sys.platform != "win32":
|
||||
view = memoryview(buf)
|
||||
sent = 0
|
||||
|
||||
def _is_ready(fut: Future) -> None:
|
||||
def _is_ready(fut: Future[Any]) -> None:
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
@ -139,7 +139,7 @@ if sys.platform != "win32":
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
|
||||
def _is_ready(fut: Future) -> None:
|
||||
def _is_ready(fut: Future[Any]) -> None:
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
@ -486,15 +486,15 @@ class PyMongoProtocol(BufferedProtocol):
|
||||
self._message_size = 0
|
||||
self._op_code = 0
|
||||
self._connection_lost = False
|
||||
self._read_waiter: Optional[Future] = None
|
||||
self._read_waiter: Optional[Future[Any]] = None
|
||||
self._timeout = timeout
|
||||
self._is_compressed = False
|
||||
self._compressor_id: Optional[int] = None
|
||||
self._max_message_size = MAX_MESSAGE_SIZE
|
||||
self._response_to: Optional[int] = None
|
||||
self._closed = asyncio.get_running_loop().create_future()
|
||||
self._pending_messages: collections.deque[Future] = collections.deque()
|
||||
self._done_messages: collections.deque[Future] = collections.deque()
|
||||
self._pending_messages: collections.deque[Future[Any]] = collections.deque()
|
||||
self._done_messages: collections.deque[Future[Any]] = collections.deque()
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self._timeout = timeout
|
||||
|
||||
@ -53,7 +53,7 @@ class AsyncPeriodicExecutor:
|
||||
self._min_interval = min_interval
|
||||
self._target = target
|
||||
self._stopped = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._task: Optional[asyncio.Task[Any]] = None
|
||||
self._name = name
|
||||
self._skip_sleep = False
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ class ServerDescription:
|
||||
def __init__(
|
||||
self,
|
||||
address: _Address,
|
||||
hello: Optional[Hello] = None,
|
||||
hello: Optional[Hello[dict[str, Any]]] = None,
|
||||
round_trip_time: Optional[float] = None,
|
||||
error: Optional[Exception] = None,
|
||||
min_round_trip_time: float = 0.0,
|
||||
@ -299,4 +299,4 @@ class ServerDescription:
|
||||
)
|
||||
|
||||
# For unittesting only. Use under no circumstances!
|
||||
_host_to_round_trip_time: dict = {}
|
||||
_host_to_round_trip_time: dict = {} # type: ignore[type-arg]
|
||||
|
||||
@ -56,17 +56,22 @@ if HAVE_SSL:
|
||||
|
||||
if HAVE_PYSSL:
|
||||
PYSSLError: Any = _pyssl.SSLError
|
||||
BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS
|
||||
BLOCKING_IO_READ_ERROR: tuple = (_pyssl.BLOCKING_IO_READ_ERROR, _ssl.BLOCKING_IO_READ_ERROR)
|
||||
BLOCKING_IO_WRITE_ERROR: tuple = (
|
||||
BLOCKING_IO_ERRORS: tuple = ( # type: ignore[type-arg]
|
||||
_ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS
|
||||
)
|
||||
BLOCKING_IO_READ_ERROR: tuple = ( # type: ignore[type-arg]
|
||||
_pyssl.BLOCKING_IO_READ_ERROR,
|
||||
_ssl.BLOCKING_IO_READ_ERROR,
|
||||
)
|
||||
BLOCKING_IO_WRITE_ERROR: tuple = ( # type: ignore[type-arg]
|
||||
_pyssl.BLOCKING_IO_WRITE_ERROR,
|
||||
_ssl.BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
else:
|
||||
PYSSLError = _ssl.SSLError
|
||||
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
|
||||
BLOCKING_IO_READ_ERROR = (_ssl.BLOCKING_IO_READ_ERROR,)
|
||||
BLOCKING_IO_WRITE_ERROR = (_ssl.BLOCKING_IO_WRITE_ERROR,)
|
||||
BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS # type: ignore[type-arg, no-redef]
|
||||
BLOCKING_IO_READ_ERROR: tuple = (_ssl.BLOCKING_IO_READ_ERROR,) # type: ignore[type-arg, no-redef]
|
||||
BLOCKING_IO_WRITE_ERROR: tuple = (_ssl.BLOCKING_IO_WRITE_ERROR,) # type: ignore[type-arg, no-redef]
|
||||
SSLError = _ssl.SSLError
|
||||
BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR
|
||||
|
||||
@ -131,7 +136,7 @@ else:
|
||||
pass
|
||||
|
||||
IPADDR_SAFE = False
|
||||
BLOCKING_IO_ERRORS = ()
|
||||
BLOCKING_IO_ERRORS: tuple = () # type: ignore[type-arg, no-redef]
|
||||
|
||||
def _has_sni(is_sync: bool) -> bool: # noqa: ARG001
|
||||
return False
|
||||
|
||||
@ -46,8 +46,8 @@ class _AggregationCommand:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Database, Collection],
|
||||
cursor_class: type[CommandCursor],
|
||||
target: Union[Database[Any], Collection[Any]],
|
||||
cursor_class: type[CommandCursor[Any]],
|
||||
pipeline: _Pipeline,
|
||||
options: MutableMapping[str, Any],
|
||||
explicit_session: bool,
|
||||
@ -111,12 +111,12 @@ class _AggregationCommand:
|
||||
"""The namespace in which the aggregate command is run."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection:
|
||||
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection[Any]:
|
||||
"""The Collection used for the aggregate command cursor."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _database(self) -> Database:
|
||||
def _database(self) -> Database[Any]:
|
||||
"""The database against which the aggregation command is run."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -205,7 +205,7 @@ class _AggregationCommand:
|
||||
|
||||
|
||||
class _CollectionAggregationCommand(_AggregationCommand):
|
||||
_target: Collection
|
||||
_target: Collection[Any]
|
||||
|
||||
@property
|
||||
def _aggregation_target(self) -> str:
|
||||
@ -215,12 +215,12 @@ class _CollectionAggregationCommand(_AggregationCommand):
|
||||
def _cursor_namespace(self) -> str:
|
||||
return self._target.full_name
|
||||
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection:
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]:
|
||||
"""The Collection used for the aggregate command cursor."""
|
||||
return self._target
|
||||
|
||||
@property
|
||||
def _database(self) -> Database:
|
||||
def _database(self) -> Database[Any]:
|
||||
return self._target.database
|
||||
|
||||
|
||||
@ -234,7 +234,7 @@ class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
|
||||
|
||||
|
||||
class _DatabaseAggregationCommand(_AggregationCommand):
|
||||
_target: Database
|
||||
_target: Database[Any]
|
||||
|
||||
@property
|
||||
def _aggregation_target(self) -> int:
|
||||
@ -245,10 +245,10 @@ class _DatabaseAggregationCommand(_AggregationCommand):
|
||||
return f"{self._target.name}.$cmd.aggregate"
|
||||
|
||||
@property
|
||||
def _database(self) -> Database:
|
||||
def _database(self) -> Database[Any]:
|
||||
return self._target
|
||||
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection:
|
||||
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]:
|
||||
"""The Collection used for the aggregate command cursor."""
|
||||
# Collection level aggregate may not always return the "ns" field
|
||||
# according to our MockupDB tests. Let's handle that case for db level
|
||||
|
||||
@ -257,7 +257,7 @@ class _OIDCAuthenticator:
|
||||
) -> Mapping[str, Any]:
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
start_payload: dict = bson.decode(start_resp["payload"])
|
||||
start_payload: dict[str, Any] = bson.decode(start_resp["payload"])
|
||||
if "issuer" in start_payload:
|
||||
self.idp_info = OIDCIdPInfo(**start_payload)
|
||||
access_token = self._get_access_token()
|
||||
|
||||
@ -248,7 +248,7 @@ class _Bulk:
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
cmd[bwc.field] = docs
|
||||
@ -334,7 +334,7 @@ class _Bulk:
|
||||
msg: bytes,
|
||||
max_doc_size: int,
|
||||
docs: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for Connection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
@ -419,7 +419,7 @@ class _Bulk:
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
) -> list[Mapping[str, Any]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
@ -446,7 +446,7 @@ class _Bulk:
|
||||
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
|
||||
cmd: dict[str, Any],
|
||||
ops: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
|
||||
if self.is_encrypted:
|
||||
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
|
||||
|
||||
@ -164,7 +164,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _client(self) -> MongoClient:
|
||||
def _client(self) -> MongoClient: # type: ignore[type-arg]
|
||||
"""The client against which the aggregation commands for
|
||||
this ChangeStream will be run.
|
||||
"""
|
||||
@ -206,7 +206,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
|
||||
"""Return the full aggregation pipeline for this ChangeStream."""
|
||||
options = self._change_stream_options()
|
||||
full_pipeline: list = [{"$changeStream": options}]
|
||||
full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}]
|
||||
full_pipeline.extend(self._pipeline)
|
||||
return full_pipeline
|
||||
|
||||
@ -237,7 +237,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
|
||||
def _run_aggregation_cmd(
|
||||
self, session: Optional[ClientSession], explicit_session: bool
|
||||
) -> CommandCursor:
|
||||
) -> CommandCursor: # type: ignore[type-arg]
|
||||
"""Run the full aggregation pipeline for this ChangeStream and return
|
||||
the corresponding CommandCursor.
|
||||
"""
|
||||
@ -257,7 +257,7 @@ class ChangeStream(Generic[_DocumentType]):
|
||||
operation=_Op.AGGREGATE,
|
||||
)
|
||||
|
||||
def _create_cursor(self) -> CommandCursor:
|
||||
def _create_cursor(self) -> CommandCursor: # type: ignore[type-arg]
|
||||
with self._client._tmp_session(self._session, close=False) as s:
|
||||
return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)
|
||||
|
||||
|
||||
@ -88,7 +88,7 @@ class _ClientBulk:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
write_concern: WriteConcern,
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
@ -233,7 +233,7 @@ class _ClientBulk:
|
||||
msg: Union[bytes, dict[str, Any]],
|
||||
op_docs: list[Mapping[str, Any]],
|
||||
ns_docs: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
) -> dict[str, Any]:
|
||||
"""A proxy for Connection.write_command that handles event publishing."""
|
||||
cmd["ops"] = op_docs
|
||||
@ -324,7 +324,7 @@ class _ClientBulk:
|
||||
msg: bytes,
|
||||
op_docs: list[Mapping[str, Any]],
|
||||
ns_docs: list[Mapping[str, Any]],
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
"""A proxy for Connection.unack_write that handles event publishing."""
|
||||
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
|
||||
|
||||
@ -395,7 +395,7 @@ class _TxnState:
|
||||
class _Transaction:
|
||||
"""Internal class to hold transaction information in a ClientSession."""
|
||||
|
||||
def __init__(self, opts: Optional[TransactionOptions], client: MongoClient):
|
||||
def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any]):
|
||||
self.opts = opts
|
||||
self.state = _TxnState.NONE
|
||||
self.sharded = False
|
||||
@ -458,7 +458,7 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
|
||||
|
||||
# From the transactions spec, all the retryable writes errors plus
|
||||
# WriteConcernTimeout.
|
||||
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset(
|
||||
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # type: ignore[type-arg]
|
||||
[
|
||||
64, # WriteConcernTimeout
|
||||
50, # MaxTimeMSExpired
|
||||
@ -498,13 +498,13 @@ class ClientSession:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
server_session: Any,
|
||||
options: SessionOptions,
|
||||
implicit: bool,
|
||||
) -> None:
|
||||
# A MongoClient, a _ServerSession, a SessionOptions, and a set.
|
||||
self._client: MongoClient = client
|
||||
self._client: MongoClient[Any] = client
|
||||
self._server_session = server_session
|
||||
self._options = options
|
||||
self._cluster_time: Optional[Mapping[str, Any]] = None
|
||||
@ -550,7 +550,7 @@ class ClientSession:
|
||||
self._end_session(lock=True)
|
||||
|
||||
@property
|
||||
def client(self) -> MongoClient:
|
||||
def client(self) -> MongoClient[Any]:
|
||||
"""The :class:`~pymongo.mongo_client.MongoClient` this session was
|
||||
created from.
|
||||
"""
|
||||
@ -748,7 +748,7 @@ class ClientSession:
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
max_commit_time_ms: Optional[int] = None,
|
||||
) -> ContextManager:
|
||||
) -> ContextManager[Any]:
|
||||
"""Start a multi-statement transaction.
|
||||
|
||||
Takes the same arguments as :class:`TransactionOptions`.
|
||||
@ -1118,7 +1118,7 @@ class _ServerSession:
|
||||
self._transaction_id += 1
|
||||
|
||||
|
||||
class _ServerSessionPool(collections.deque):
|
||||
class _ServerSessionPool(collections.deque): # type: ignore[type-arg]
|
||||
"""Pool of _ServerSession objects.
|
||||
|
||||
This class is thread-safe.
|
||||
|
||||
@ -582,7 +582,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
conn: Connection,
|
||||
command: MutableMapping[str, Any],
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
codec_options: Optional[CodecOptions[Mapping[str, Any]]] = None,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
@ -703,7 +703,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
) -> BulkWriteResult:
|
||||
"""Send a batch of write operations to the server.
|
||||
|
||||
@ -2522,7 +2522,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
session: Optional[ClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> CommandCursor[MutableMapping[str, Any]]:
|
||||
codec_options: CodecOptions = CodecOptions(SON)
|
||||
codec_options: CodecOptions[Mapping[str, Any]] = CodecOptions(SON)
|
||||
coll = cast(
|
||||
Collection[MutableMapping[str, Any]],
|
||||
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
|
||||
@ -2864,7 +2864,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
aggregation_command: Type[_AggregationCommand],
|
||||
pipeline: _Pipeline,
|
||||
cursor_class: Type[CommandCursor],
|
||||
cursor_class: Type[CommandCursor], # type: ignore[type-arg]
|
||||
session: Optional[ClientSession],
|
||||
explicit_session: bool,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
@ -3107,7 +3107,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
comment: Optional[Any] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
**kwargs: Any,
|
||||
) -> list:
|
||||
) -> list[str]:
|
||||
"""Get a list of distinct values for `key` among all documents
|
||||
in this collection.
|
||||
|
||||
@ -3170,7 +3170,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
_server: Server,
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
) -> list:
|
||||
) -> list: # type: ignore[type-arg]
|
||||
return (
|
||||
self._command(
|
||||
conn,
|
||||
@ -3195,7 +3195,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Internal findAndModify helper."""
|
||||
|
||||
@ -350,7 +350,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
|
||||
"""Get all or some available documents from the cursor."""
|
||||
if not len(self._data) and not self._killed:
|
||||
self._refresh()
|
||||
@ -457,7 +457,7 @@ class RawBatchCommandCursor(CommandCursor[_DocumentType]):
|
||||
self,
|
||||
response: Union[_OpReply, _OpMsg],
|
||||
cursor_id: Optional[int],
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions[dict[str, Any]],
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
|
||||
@ -216,7 +216,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
# it anytime we change __limit.
|
||||
self._empty = False
|
||||
|
||||
self._data: deque = deque()
|
||||
self._data: deque = deque() # type: ignore[type-arg]
|
||||
self._address: Optional[_Address] = None
|
||||
self._retrieved = 0
|
||||
|
||||
@ -280,7 +280,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
"""
|
||||
return self._clone(True)
|
||||
|
||||
def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor:
|
||||
def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor: # type: ignore[type-arg]
|
||||
"""Internal clone helper."""
|
||||
if not base:
|
||||
if self._explicit_session:
|
||||
@ -322,7 +322,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
base.__dict__.update(data)
|
||||
return base
|
||||
|
||||
def _clone_base(self, session: Optional[ClientSession]) -> Cursor:
|
||||
def _clone_base(self, session: Optional[ClientSession]) -> Cursor: # type: ignore[type-arg]
|
||||
"""Creates an empty Cursor object for information to be copied into."""
|
||||
return self.__class__(self._collection, session=session)
|
||||
|
||||
@ -862,7 +862,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
if self._has_filter:
|
||||
spec = dict(self._spec)
|
||||
else:
|
||||
spec = cast(dict, self._spec)
|
||||
spec = cast(dict, self._spec) # type: ignore[type-arg]
|
||||
spec["$where"] = code
|
||||
self._spec = spec
|
||||
return self
|
||||
@ -886,7 +886,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
self,
|
||||
response: Union[_OpReply, _OpMsg],
|
||||
cursor_id: Optional[int],
|
||||
codec_options: CodecOptions,
|
||||
codec_options: CodecOptions, # type: ignore[type-arg]
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> Sequence[_DocumentOut]:
|
||||
@ -962,29 +962,33 @@ class Cursor(Generic[_DocumentType]):
|
||||
return self._clone(deepcopy=True)
|
||||
|
||||
@overload
|
||||
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list:
|
||||
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: # type: ignore[type-arg]
|
||||
...
|
||||
|
||||
@overload
|
||||
def _deepcopy(
|
||||
self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None
|
||||
) -> dict:
|
||||
self,
|
||||
x: SupportsItems, # type: ignore[type-arg]
|
||||
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
|
||||
) -> dict: # type: ignore[type-arg]
|
||||
...
|
||||
|
||||
def _deepcopy(
|
||||
self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None
|
||||
) -> Union[list, dict]:
|
||||
self,
|
||||
x: Union[Iterable, SupportsItems], # type: ignore[type-arg]
|
||||
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
|
||||
) -> Union[list[Any], dict[str, Any]]:
|
||||
"""Deepcopy helper for the data dictionary or list.
|
||||
|
||||
Regular expressions cannot be deep copied but as they are immutable we
|
||||
don't have to copy them when cloning.
|
||||
"""
|
||||
y: Union[list, dict]
|
||||
y: Union[list[Any], dict[str, Any]]
|
||||
iterator: Iterable[tuple[Any, Any]]
|
||||
if not hasattr(x, "items"):
|
||||
y, is_list, iterator = [], True, enumerate(x)
|
||||
else:
|
||||
y, is_list, iterator = {}, False, cast("SupportsItems", x).items()
|
||||
y, is_list, iterator = {}, False, cast("SupportsItems", x).items() # type: ignore[type-arg]
|
||||
if memo is None:
|
||||
memo = {}
|
||||
val_id = id(x)
|
||||
@ -1058,7 +1062,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
"""Explicitly close / kill this cursor."""
|
||||
self._die_lock()
|
||||
|
||||
def distinct(self, key: str) -> list:
|
||||
def distinct(self, key: str) -> list[str]:
|
||||
"""Get a list of distinct values for `key` among all documents
|
||||
in the result set of this query.
|
||||
|
||||
@ -1263,7 +1267,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
|
||||
"""Get all or some documents from the cursor."""
|
||||
if not self._exhaust_checked:
|
||||
self._exhaust_checked = True
|
||||
@ -1323,7 +1327,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
return res
|
||||
|
||||
|
||||
class RawBatchCursor(Cursor, Generic[_DocumentType]):
|
||||
class RawBatchCursor(Cursor, Generic[_DocumentType]): # type: ignore[type-arg]
|
||||
"""A cursor / iterator over raw batches of BSON data from a query result."""
|
||||
|
||||
_query_class = _RawBatchQuery
|
||||
|
||||
@ -771,7 +771,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
self._name,
|
||||
command,
|
||||
read_preference,
|
||||
codec_options,
|
||||
codec_options, # type: ignore[arg-type]
|
||||
check,
|
||||
allowable_errors,
|
||||
write_concern=write_concern,
|
||||
|
||||
@ -158,10 +158,10 @@ _ReadCall = Callable[
|
||||
_IS_SYNC = True
|
||||
|
||||
_WriteOp = Union[
|
||||
InsertOne,
|
||||
InsertOne, # type: ignore[type-arg]
|
||||
DeleteOne,
|
||||
DeleteMany,
|
||||
ReplaceOne,
|
||||
ReplaceOne, # type: ignore[type-arg]
|
||||
UpdateOne,
|
||||
UpdateMany,
|
||||
]
|
||||
@ -173,7 +173,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Define order to retrieve options from ClientOptions for __repr__.
|
||||
# No host/port; these are retrieved from TopologySettings.
|
||||
_constructor_args = ("document_class", "tz_aware", "connect")
|
||||
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
||||
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() # type: ignore[type-arg]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -847,7 +847,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
self._default_database_name = dbase
|
||||
self._lock = _create_lock()
|
||||
self._kill_cursors_queue: list = []
|
||||
self._kill_cursors_queue: list = [] # type: ignore[type-arg]
|
||||
|
||||
self._encrypter: Optional[_Encrypter] = None
|
||||
|
||||
@ -1064,7 +1064,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Reset the session pool to avoid duplicate sessions in the child process.
|
||||
self._topology._session_pool.reset()
|
||||
|
||||
def _duplicate(self, **kwargs: Any) -> MongoClient:
|
||||
def _duplicate(self, **kwargs: Any) -> MongoClient: # type: ignore[type-arg]
|
||||
args = self._init_kwargs.copy()
|
||||
args.update(kwargs)
|
||||
return MongoClient(**args)
|
||||
@ -1546,7 +1546,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self, name, codec_options, read_preference, write_concern, read_concern
|
||||
)
|
||||
|
||||
def _database_default_options(self, name: str) -> database.Database:
|
||||
def _database_default_options(self, name: str) -> database.Database: # type: ignore[type-arg]
|
||||
"""Get a Database instance with the default settings."""
|
||||
return self.get_database(
|
||||
name,
|
||||
@ -1883,7 +1883,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
def _run_operation(
|
||||
self,
|
||||
operation: Union[_Query, _GetMore],
|
||||
unpack_res: Callable,
|
||||
unpack_res: Callable, # type: ignore[type-arg]
|
||||
address: Optional[_Address] = None,
|
||||
) -> Response:
|
||||
"""Run a _Query/_GetMore operation and return a Response.
|
||||
@ -2257,7 +2257,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
@contextlib.contextmanager
|
||||
def _tmp_session(
|
||||
self, session: Optional[client_session.ClientSession], close: bool = True
|
||||
) -> Generator[Optional[client_session.ClientSession], None, None]:
|
||||
) -> Generator[Optional[client_session.ClientSession], None]:
|
||||
"""If provided session is None, lend a temporary session."""
|
||||
if session is not None:
|
||||
if not isinstance(session, client_session.ClientSession):
|
||||
@ -2300,8 +2300,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionchanged:: 3.6
|
||||
Added ``session`` parameter.
|
||||
"""
|
||||
return cast(
|
||||
dict,
|
||||
return cast( # type: ignore[redundant-cast]
|
||||
dict[str, Any],
|
||||
self.admin.command(
|
||||
"buildinfo", read_preference=ReadPreference.PRIMARY, session=session
|
||||
),
|
||||
@ -2428,13 +2428,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
@_csot.apply
|
||||
def bulk_write(
|
||||
self,
|
||||
models: Sequence[_WriteOp[_DocumentType]],
|
||||
models: Sequence[_WriteOp],
|
||||
session: Optional[ClientSession] = None,
|
||||
ordered: bool = True,
|
||||
verbose_results: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
comment: Optional[Any] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
let: Optional[Mapping[str, Any]] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
) -> ClientBulkWriteResult:
|
||||
"""Send a batch of write operations, potentially across multiple namespaces, to the server.
|
||||
@ -2620,7 +2620,12 @@ class _MongoClientErrorHandler:
|
||||
"handled",
|
||||
)
|
||||
|
||||
def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
|
||||
def __init__(
|
||||
self,
|
||||
client: MongoClient, # type: ignore[type-arg]
|
||||
server: Server,
|
||||
session: Optional[ClientSession],
|
||||
):
|
||||
if not isinstance(client, MongoClient):
|
||||
# This is for compatibility with mocked and subclassed types, such as in Motor.
|
||||
if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__):
|
||||
@ -2692,7 +2697,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mongo_client: MongoClient,
|
||||
mongo_client: MongoClient, # type: ignore[type-arg]
|
||||
func: _WriteCall[T] | _ReadCall[T],
|
||||
bulk: Optional[Union[_Bulk, _ClientBulk]],
|
||||
operation: str,
|
||||
|
||||
@ -349,7 +349,7 @@ class Monitor(MonitorBase):
|
||||
)
|
||||
return sd
|
||||
|
||||
def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:
|
||||
def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: # type: ignore[type-arg]
|
||||
"""Return (Hello, round_trip_time).
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
|
||||
@ -66,7 +66,7 @@ def command(
|
||||
read_preference: Optional[_ServerMode],
|
||||
codec_options: CodecOptions[_DocumentType],
|
||||
session: Optional[ClientSession],
|
||||
client: Optional[MongoClient],
|
||||
client: Optional[MongoClient[Any]],
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
address: Optional[_Address] = None,
|
||||
|
||||
@ -201,7 +201,7 @@ class Connection:
|
||||
self.conn.get_conn.settimeout(timeout)
|
||||
|
||||
def apply_timeout(
|
||||
self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]]
|
||||
self, client: MongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
|
||||
) -> Optional[float]:
|
||||
# CSOT: use remaining timeout when set.
|
||||
timeout = _csot.remaining()
|
||||
@ -255,7 +255,7 @@ class Connection:
|
||||
else:
|
||||
return {HelloCompat.LEGACY_CMD: 1, "helloOk": True}
|
||||
|
||||
def hello(self) -> Hello:
|
||||
def hello(self) -> Hello[dict[str, Any]]:
|
||||
return self._hello(None, None)
|
||||
|
||||
def _hello(
|
||||
@ -357,7 +357,7 @@ class Connection:
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
read_preference: _ServerMode = ReadPreference.PRIMARY,
|
||||
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
|
||||
codec_options: CodecOptions[Mapping[str, Any]] = DEFAULT_CODEC_OPTIONS, # type: ignore[assignment]
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
@ -365,7 +365,7 @@ class Connection:
|
||||
parse_write_concern_error: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
client: Optional[MongoClient] = None,
|
||||
client: Optional[MongoClient[Any]] = None,
|
||||
retryable_write: bool = False,
|
||||
publish_events: bool = True,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
@ -417,7 +417,7 @@ class Connection:
|
||||
spec,
|
||||
self.is_mongos,
|
||||
read_preference,
|
||||
codec_options,
|
||||
codec_options, # type: ignore[arg-type]
|
||||
session,
|
||||
client,
|
||||
check,
|
||||
@ -489,7 +489,7 @@ class Connection:
|
||||
self.send_message(msg, max_doc_size)
|
||||
|
||||
def write_command(
|
||||
self, request_id: int, msg: bytes, codec_options: CodecOptions
|
||||
self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Send "insert" etc. command, returning response as a dict.
|
||||
|
||||
@ -541,7 +541,7 @@ class Connection:
|
||||
)
|
||||
|
||||
def validate_session(
|
||||
self, client: Optional[MongoClient], session: Optional[ClientSession]
|
||||
self, client: Optional[MongoClient[Any]], session: Optional[ClientSession]
|
||||
) -> None:
|
||||
"""Validate this session before use with client.
|
||||
|
||||
@ -596,7 +596,7 @@ class Connection:
|
||||
self,
|
||||
command: MutableMapping[str, Any],
|
||||
session: Optional[ClientSession],
|
||||
client: Optional[MongoClient],
|
||||
client: Optional[MongoClient[Any]],
|
||||
) -> None:
|
||||
"""Add $clusterTime."""
|
||||
if client:
|
||||
@ -730,7 +730,7 @@ class Pool:
|
||||
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
|
||||
# and returned to pool from the left side. Stale sockets removed
|
||||
# from the right side.
|
||||
self.conns: collections.deque = collections.deque()
|
||||
self.conns: collections.deque[Connection] = collections.deque()
|
||||
self.active_contexts: set[_CancellationContext] = set()
|
||||
self.lock = _create_lock()
|
||||
self._max_connecting_cond = _create_condition(self.lock)
|
||||
@ -837,8 +837,8 @@ class Pool:
|
||||
if service_id is None:
|
||||
sockets, self.conns = self.conns, collections.deque()
|
||||
else:
|
||||
discard: collections.deque = collections.deque()
|
||||
keep: collections.deque = collections.deque()
|
||||
discard: collections.deque = collections.deque() # type: ignore[type-arg]
|
||||
keep: collections.deque = collections.deque() # type: ignore[type-arg]
|
||||
for conn in self.conns:
|
||||
if conn.service_id == service_id:
|
||||
discard.append(conn)
|
||||
@ -864,7 +864,7 @@ class Pool:
|
||||
if close:
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
|
||||
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
@ -901,7 +901,7 @@ class Pool:
|
||||
)
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
|
||||
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
@ -915,7 +915,7 @@ class Pool:
|
||||
self.is_writable = is_writable
|
||||
with self.lock:
|
||||
for _socket in self.conns:
|
||||
_socket.update_is_writable(self.is_writable)
|
||||
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]
|
||||
|
||||
def reset(
|
||||
self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False
|
||||
@ -952,7 +952,7 @@ class Pool:
|
||||
close_conns.append(self.conns.pop())
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
|
||||
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
@ -1473,4 +1473,4 @@ class Pool:
|
||||
# not safe to acquire a lock in __del__.
|
||||
if _IS_SYNC:
|
||||
for conn in self.conns:
|
||||
conn.close_conn(None)
|
||||
conn.close_conn(None) # type: ignore[unused-coroutine]
|
||||
|
||||
@ -66,7 +66,7 @@ class Server:
|
||||
monitor: Monitor,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
listeners: Optional[_EventListeners] = None,
|
||||
events: Optional[ReferenceType[Queue]] = None,
|
||||
events: Optional[ReferenceType[Queue[Any]]] = None,
|
||||
) -> None:
|
||||
"""Represent one MongoDB server."""
|
||||
self._description = server_description
|
||||
@ -142,7 +142,7 @@ class Server:
|
||||
read_preference: _ServerMode,
|
||||
listeners: Optional[_EventListeners],
|
||||
unpack_res: Callable[..., list[_DocumentOut]],
|
||||
client: MongoClient,
|
||||
client: MongoClient[Any],
|
||||
) -> Response:
|
||||
"""Run a _Query or _GetMore operation and return a Response object.
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ _IS_SYNC = True
|
||||
_pymongo_dir = str(Path(__file__).parent)
|
||||
|
||||
|
||||
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
|
||||
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: # type: ignore[type-arg]
|
||||
q = queue_ref()
|
||||
if not q:
|
||||
return False # Cancel PeriodicExecutor.
|
||||
@ -186,7 +186,7 @@ class Topology:
|
||||
|
||||
if self._publish_server or self._publish_tp:
|
||||
assert self._events is not None
|
||||
weak: weakref.ReferenceType[queue.Queue]
|
||||
weak: weakref.ReferenceType[queue.Queue[Any]]
|
||||
|
||||
def target() -> bool:
|
||||
return process_events_queue(weak)
|
||||
|
||||
@ -569,8 +569,8 @@ def _update_rs_from_primary(
|
||||
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
|
||||
|
||||
if server_description.max_wire_version is None or server_description.max_wire_version < 17:
|
||||
new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
|
||||
max_election_tuple: tuple = (max_set_version, max_election_id)
|
||||
new_election_tuple: tuple = (server_description.set_version, server_description.election_id) # type: ignore[type-arg]
|
||||
max_election_tuple: tuple = (max_set_version, max_election_id) # type: ignore[type-arg]
|
||||
if None not in new_election_tuple:
|
||||
if None not in max_election_tuple and new_election_tuple < max_election_tuple:
|
||||
# Stale primary, set to type Unknown.
|
||||
|
||||
@ -51,7 +51,7 @@ ClusterTime = Mapping[str, Any]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# Type hinting types for compatibility between async and sync classes
|
||||
_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"]
|
||||
_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] # type: ignore[type-arg]
|
||||
_AgnosticConnection = Union["AsyncConnection", "Connection"]
|
||||
_AgnosticClientSession = Union["AsyncClientSession", "ClientSession"]
|
||||
_AgnosticBulk = Union["_AsyncBulk", "_Bulk"]
|
||||
|
||||
@ -149,11 +149,12 @@ markers = [
|
||||
strict = true
|
||||
show_error_codes = true
|
||||
pretty = true
|
||||
disable_error_code = ["type-arg", "no-any-return"]
|
||||
disable_error_code = ["no-any-return"]
|
||||
disallow_any_generics = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["test.*"]
|
||||
disable_error_code = ["no-untyped-def", "no-untyped-call"]
|
||||
disable_error_code = ["type-arg", "no-untyped-def", "no-untyped-call"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["service_identity.*"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user