PYTHON-5257 - Turn on mypy disallow_any_generics (#2456)

This commit is contained in:
Noah Stapp 2025-08-06 14:21:53 -04:00 committed by GitHub
parent d7074ba9ee
commit bbb6f88fae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
52 changed files with 323 additions and 297 deletions

View File

@ -844,7 +844,7 @@ def _encode_binary(data: bytes, subtype: int, json_options: JSONOptions) -> Any:
return {"$binary": {"base64": base64.b64encode(data).decode(), "subType": "%02x" % subtype}}
def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if (
json_options.datetime_representation == DatetimeRepresentation.ISO8601
and 0 <= int(obj) <= _MAX_UTC_MS
@ -855,7 +855,7 @@ def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
return {"$date": {"$numberLong": str(int(obj))}}
def _encode_code(obj: Code, json_options: JSONOptions) -> dict:
def _encode_code(obj: Code, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if obj.scope is None:
return {"$code": str(obj)}
else:
@ -873,7 +873,7 @@ def _encode_noop(obj: Any, dummy0: Any) -> Any:
return obj
def _encode_regex(obj: Any, json_options: JSONOptions) -> dict:
def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
flags = ""
if obj.flags & re.IGNORECASE:
flags += "i"
@ -918,7 +918,7 @@ def _encode_float(obj: float, json_options: JSONOptions) -> Any:
return obj
def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if json_options.datetime_representation == DatetimeRepresentation.ISO8601:
if not obj.tzinfo:
obj = obj.replace(tzinfo=utc)
@ -941,15 +941,15 @@ def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
return {"$date": {"$numberLong": str(millis)}}
def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict:
def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
return _encode_binary(obj, 0, json_options)
def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict:
def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
return _encode_binary(obj, obj.subtype, json_options)
def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict:
def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
if json_options.strict_uuid:
binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation)
return _encode_binary(binval, binval.subtype, json_options)
@ -957,27 +957,27 @@ def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict:
return {"$uuid": obj.hex}
def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict:
def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: # type: ignore[type-arg]
return {"$oid": str(obj)}
def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict:
def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
return {"$timestamp": {"t": obj.time, "i": obj.inc}}
def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict:
def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
return {"$numberDecimal": str(obj)}
def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict:
def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
return _json_convert(obj.as_doc(), json_options=json_options)
def _encode_minkey(dummy0: Any, dummy1: Any) -> dict:
def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
return {"$minKey": 1}
def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict:
def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
return {"$maxKey": 1}
@ -985,7 +985,7 @@ def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict:
# Each encoder function's signature is:
# - obj: a Python data type, e.g. a Python int for _encode_int
# - json_options: a JSONOptions
_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = {
_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { # type: ignore[type-arg]
bool: _encode_noop,
bytes: _encode_bytes,
datetime.datetime: _encode_datetime,
@ -1056,7 +1056,7 @@ def _get_datetime_size(obj: datetime.datetime) -> int:
return 5 + len(str(obj.time()))
def _get_regex_size(obj: Regex) -> int:
def _get_regex_size(obj: Regex) -> int: # type: ignore[type-arg]
return 18 + len(obj.pattern)

View File

@ -28,4 +28,4 @@ if TYPE_CHECKING:
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any])
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"]
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] # type: ignore[type-arg]

View File

@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[AsyncClientSession]) -> None:
class AsyncGridFS:
"""An instance of GridFS on top of a single Database."""
def __init__(self, database: AsyncDatabase, collection: str = "fs"):
def __init__(self, database: AsyncDatabase[Any], collection: str = "fs"):
"""Create a new instance of :class:`GridFS`.
Raises :class:`TypeError` if `database` is not an instance of
@ -463,7 +463,7 @@ class AsyncGridFSBucket:
def __init__(
self,
db: AsyncDatabase,
db: AsyncDatabase[Any],
bucket_name: str = "fs",
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
write_concern: Optional[WriteConcern] = None,
@ -513,11 +513,11 @@ class AsyncGridFSBucket:
self._bucket_name = bucket_name
self._collection = db[bucket_name]
self._chunks: AsyncCollection = self._collection.chunks.with_options(
self._chunks: AsyncCollection[Any] = self._collection.chunks.with_options(
write_concern=write_concern, read_preference=read_preference
)
self._files: AsyncCollection = self._collection.files.with_options(
self._files: AsyncCollection[Any] = self._collection.files.with_options(
write_concern=write_concern, read_preference=read_preference
)
@ -1085,7 +1085,7 @@ class AsyncGridIn:
def __init__(
self,
root_collection: AsyncCollection,
root_collection: AsyncCollection[Any],
session: Optional[AsyncClientSession] = None,
**kwargs: Any,
) -> None:
@ -1172,7 +1172,7 @@ class AsyncGridIn:
object.__setattr__(self, "_buffered_docs_size", 0)
async def _create_index(
self, collection: AsyncCollection, index_key: Any, unique: bool
self, collection: AsyncCollection[Any], index_key: Any, unique: bool
) -> None:
doc = await collection.find_one(projection={"_id": 1}, session=self._session)
if doc is None:
@ -1456,7 +1456,7 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore
def __init__(
self,
root_collection: AsyncCollection,
root_collection: AsyncCollection[Any],
file_id: Optional[int] = None,
file_document: Optional[Any] = None,
session: Optional[AsyncClientSession] = None,
@ -1829,7 +1829,7 @@ class _AsyncGridOutChunkIterator:
def __init__(
self,
grid_out: AsyncGridOut,
chunks: AsyncCollection,
chunks: AsyncCollection[Any],
session: Optional[AsyncClientSession],
next_chunk: Any,
) -> None:
@ -1842,7 +1842,7 @@ class _AsyncGridOutChunkIterator:
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
self._cursor = None
_cursor: Optional[AsyncCursor]
_cursor: Optional[AsyncCursor[Any]]
def expected_chunk_length(self, chunk_n: int) -> int:
if chunk_n < self._num_chunks - 1:
@ -1921,7 +1921,7 @@ class _AsyncGridOutChunkIterator:
class AsyncGridOutIterator:
def __init__(
self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession
self, grid_out: AsyncGridOut, chunks: AsyncCollection[Any], session: AsyncClientSession
):
self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0)
@ -1935,14 +1935,14 @@ class AsyncGridOutIterator:
__anext__ = next
class AsyncGridOutCursor(AsyncCursor):
class AsyncGridOutCursor(AsyncCursor): # type: ignore[type-arg]
"""A cursor / iterator for returning GridOut objects as the result
of an arbitrary query against the GridFS files collection.
"""
def __init__(
self,
collection: AsyncCollection,
collection: AsyncCollection[Any],
filter: Optional[Mapping[str, Any]] = None,
skip: int = 0,
limit: int = 0,

View File

@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[ClientSession]) -> None:
class GridFS:
"""An instance of GridFS on top of a single Database."""
def __init__(self, database: Database, collection: str = "fs"):
def __init__(self, database: Database[Any], collection: str = "fs"):
"""Create a new instance of :class:`GridFS`.
Raises :class:`TypeError` if `database` is not an instance of
@ -461,7 +461,7 @@ class GridFSBucket:
def __init__(
self,
db: Database,
db: Database[Any],
bucket_name: str = "fs",
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
write_concern: Optional[WriteConcern] = None,
@ -511,11 +511,11 @@ class GridFSBucket:
self._bucket_name = bucket_name
self._collection = db[bucket_name]
self._chunks: Collection = self._collection.chunks.with_options(
self._chunks: Collection[Any] = self._collection.chunks.with_options(
write_concern=write_concern, read_preference=read_preference
)
self._files: Collection = self._collection.files.with_options(
self._files: Collection[Any] = self._collection.files.with_options(
write_concern=write_concern, read_preference=read_preference
)
@ -1077,7 +1077,7 @@ class GridIn:
def __init__(
self,
root_collection: Collection,
root_collection: Collection[Any],
session: Optional[ClientSession] = None,
**kwargs: Any,
) -> None:
@ -1163,7 +1163,7 @@ class GridIn:
object.__setattr__(self, "_buffered_docs", [])
object.__setattr__(self, "_buffered_docs_size", 0)
def _create_index(self, collection: Collection, index_key: Any, unique: bool) -> None:
def _create_index(self, collection: Collection[Any], index_key: Any, unique: bool) -> None:
doc = collection.find_one(projection={"_id": 1}, session=self._session)
if doc is None:
try:
@ -1444,7 +1444,7 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore
def __init__(
self,
root_collection: Collection,
root_collection: Collection[Any],
file_id: Optional[int] = None,
file_document: Optional[Any] = None,
session: Optional[ClientSession] = None,
@ -1817,7 +1817,7 @@ class GridOutChunkIterator:
def __init__(
self,
grid_out: GridOut,
chunks: Collection,
chunks: Collection[Any],
session: Optional[ClientSession],
next_chunk: Any,
) -> None:
@ -1830,7 +1830,7 @@ class GridOutChunkIterator:
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
self._cursor = None
_cursor: Optional[Cursor]
_cursor: Optional[Cursor[Any]]
def expected_chunk_length(self, chunk_n: int) -> int:
if chunk_n < self._num_chunks - 1:
@ -1908,7 +1908,7 @@ class GridOutChunkIterator:
class GridOutIterator:
def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
def __init__(self, grid_out: GridOut, chunks: Collection[Any], session: ClientSession):
self._chunk_iter = GridOutChunkIterator(grid_out, chunks, session, 0)
def __iter__(self) -> GridOutIterator:
@ -1921,14 +1921,14 @@ class GridOutIterator:
__next__ = next
class GridOutCursor(Cursor):
class GridOutCursor(Cursor): # type: ignore[type-arg]
"""A cursor / iterator for returning GridOut objects as the result
of an arbitrary query against the GridFS files collection.
"""
def __init__(
self,
collection: Collection,
collection: Collection[Any],
filter: Optional[Mapping[str, Any]] = None,
skip: int = 0,
limit: int = 0,

View File

@ -93,7 +93,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin):
"""
def __init__(self) -> None:
self._waiters: Optional[collections.deque] = None
self._waiters: Optional[collections.deque[Any]] = None
self._locked = False
def __repr__(self) -> str:
@ -196,7 +196,7 @@ class Condition(_ContextManagerMixin, _LoopBoundMixin):
self.acquire = lock.acquire
self.release = lock.release
self._waiters: collections.deque = collections.deque()
self._waiters: collections.deque[Any] = collections.deque()
def __repr__(self) -> str:
res = super().__repr__()
@ -260,7 +260,7 @@ class Condition(_ContextManagerMixin, _LoopBoundMixin):
self._notify(1)
raise
async def wait_for(self, predicate: Any) -> Coroutine:
async def wait_for(self, predicate: Any) -> Coroutine[Any, Any, Any]:
"""Wait until a predicate becomes true.
The predicate should be a callable whose result will be

View File

@ -24,7 +24,7 @@ from typing import Any, Coroutine, Optional
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
class _Task(asyncio.Task):
class _Task(asyncio.Task[Any]):
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
super().__init__(coro, name=name)
self._cancel_requests = 0
@ -43,7 +43,7 @@ class _Task(asyncio.Task):
return self._cancel_requests
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task[Any]:
if sys.version_info >= (3, 11):
return asyncio.create_task(coro, name=name)
return _Task(coro, name=name)

View File

@ -68,7 +68,7 @@ def clamp_remaining(max_timeout: float) -> float:
return min(timeout, max_timeout)
class _TimeoutContext(AbstractContextManager):
class _TimeoutContext(AbstractContextManager[Any]):
"""Internal timeout context manager.
Use :func:`pymongo.timeout` instead::

View File

@ -46,8 +46,8 @@ class _AggregationCommand:
def __init__(
self,
target: Union[AsyncDatabase, AsyncCollection],
cursor_class: type[AsyncCommandCursor],
target: Union[AsyncDatabase[Any], AsyncCollection[Any]],
cursor_class: type[AsyncCommandCursor[Any]],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
@ -111,12 +111,12 @@ class _AggregationCommand:
"""The namespace in which the aggregate command is run."""
raise NotImplementedError
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection:
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection[Any]:
"""The AsyncCollection used for the aggregate command cursor."""
raise NotImplementedError
@property
def _database(self) -> AsyncDatabase:
def _database(self) -> AsyncDatabase[Any]:
"""The database against which the aggregation command is run."""
raise NotImplementedError
@ -205,7 +205,7 @@ class _AggregationCommand:
class _CollectionAggregationCommand(_AggregationCommand):
_target: AsyncCollection
_target: AsyncCollection[Any]
@property
def _aggregation_target(self) -> str:
@ -215,12 +215,12 @@ class _CollectionAggregationCommand(_AggregationCommand):
def _cursor_namespace(self) -> str:
return self._target.full_name
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]:
"""The AsyncCollection used for the aggregate command cursor."""
return self._target
@property
def _database(self) -> AsyncDatabase:
def _database(self) -> AsyncDatabase[Any]:
return self._target.database
@ -234,7 +234,7 @@ class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
class _DatabaseAggregationCommand(_AggregationCommand):
_target: AsyncDatabase
_target: AsyncDatabase[Any]
@property
def _aggregation_target(self) -> int:
@ -245,10 +245,10 @@ class _DatabaseAggregationCommand(_AggregationCommand):
return f"{self._target.name}.$cmd.aggregate"
@property
def _database(self) -> AsyncDatabase:
def _database(self) -> AsyncDatabase[Any]:
return self._target
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]:
"""The AsyncCollection used for the aggregate command cursor."""
# AsyncCollection level aggregate may not always return the "ns" field
# according to our MockupDB tests. Let's handle that case for db level

View File

@ -259,7 +259,7 @@ class _OIDCAuthenticator:
) -> Mapping[str, Any]:
self.access_token = None
self.refresh_token = None
start_payload: dict = bson.decode(start_resp["payload"])
start_payload: dict[str, Any] = bson.decode(start_resp["payload"])
if "issuer" in start_payload:
self.idp_info = OIDCIdPInfo(**start_payload)
access_token = await self._get_access_token()

View File

@ -248,7 +248,7 @@ class _AsyncBulk:
request_id: int,
msg: bytes,
docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
cmd[bwc.field] = docs
@ -334,7 +334,7 @@ class _AsyncBulk:
msg: bytes,
max_doc_size: int,
docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
@ -419,7 +419,7 @@ class _AsyncBulk:
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
) -> list[Mapping[str, Any]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
@ -446,7 +446,7 @@ class _AsyncBulk:
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)

View File

@ -164,7 +164,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
raise NotImplementedError
@property
def _client(self) -> AsyncMongoClient:
def _client(self) -> AsyncMongoClient: # type: ignore[type-arg]
"""The client against which the aggregation commands for
this AsyncChangeStream will be run.
"""
@ -206,7 +206,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
"""Return the full aggregation pipeline for this AsyncChangeStream."""
options = self._change_stream_options()
full_pipeline: list = [{"$changeStream": options}]
full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
@ -237,7 +237,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
async def _run_aggregation_cmd(
self, session: Optional[AsyncClientSession], explicit_session: bool
) -> AsyncCommandCursor:
) -> AsyncCommandCursor: # type: ignore[type-arg]
"""Run the full aggregation pipeline for this AsyncChangeStream and return
the corresponding AsyncCommandCursor.
"""
@ -257,7 +257,7 @@ class AsyncChangeStream(Generic[_DocumentType]):
operation=_Op.AGGREGATE,
)
async def _create_cursor(self) -> AsyncCommandCursor:
async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
async with self._client._tmp_session(self._session, close=False) as s:
return await self._run_aggregation_cmd(
session=s, explicit_session=self._session is not None

View File

@ -88,7 +88,7 @@ class _AsyncClientBulk:
def __init__(
self,
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
write_concern: WriteConcern,
ordered: bool = True,
bypass_document_validation: Optional[bool] = None,
@ -233,7 +233,7 @@ class _AsyncClientBulk:
msg: Union[bytes, dict[str, Any]],
op_docs: list[Mapping[str, Any]],
ns_docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
) -> dict[str, Any]:
"""A proxy for AsyncConnection.write_command that handles event publishing."""
cmd["ops"] = op_docs
@ -324,7 +324,7 @@ class _AsyncClientBulk:
msg: bytes,
op_docs: list[Mapping[str, Any]],
ns_docs: list[Mapping[str, Any]],
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):

View File

@ -396,7 +396,7 @@ class _TxnState:
class _Transaction:
"""Internal class to hold transaction information in a AsyncClientSession."""
def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient):
def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[Any]):
self.opts = opts
self.state = _TxnState.NONE
self.sharded = False
@ -459,7 +459,7 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
# From the transactions spec, all the retryable writes errors plus
# WriteConcernTimeout.
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset(
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # type: ignore[type-arg]
[
64, # WriteConcernTimeout
50, # MaxTimeMSExpired
@ -499,13 +499,13 @@ class AsyncClientSession:
def __init__(
self,
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
server_session: Any,
options: SessionOptions,
implicit: bool,
) -> None:
# An AsyncMongoClient, a _ServerSession, a SessionOptions, and a set.
self._client: AsyncMongoClient = client
self._client: AsyncMongoClient[Any] = client
self._server_session = server_session
self._options = options
self._cluster_time: Optional[Mapping[str, Any]] = None
@ -551,7 +551,7 @@ class AsyncClientSession:
await self._end_session(lock=True)
@property
def client(self) -> AsyncMongoClient:
def client(self) -> AsyncMongoClient[Any]:
"""The :class:`~pymongo.asynchronous.mongo_client.AsyncMongoClient` this session was
created from.
"""
@ -751,7 +751,7 @@ class AsyncClientSession:
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
max_commit_time_ms: Optional[int] = None,
) -> AsyncContextManager:
) -> AsyncContextManager[Any]:
"""Start a multi-statement transaction.
Takes the same arguments as :class:`TransactionOptions`.
@ -1123,7 +1123,7 @@ class _ServerSession:
self._transaction_id += 1
class _ServerSessionPool(collections.deque):
class _ServerSessionPool(collections.deque): # type: ignore[type-arg]
"""Pool of _ServerSession objects.
This class is thread-safe.

View File

@ -581,7 +581,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
conn: AsyncConnection,
command: MutableMapping[str, Any],
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions] = None,
codec_options: Optional[CodecOptions[Mapping[str, Any]]] = None,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_concern: Optional[ReadConcern] = None,
@ -704,7 +704,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
bypass_document_validation: Optional[bool] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
let: Optional[Mapping[str, Any]] = None,
) -> BulkWriteResult:
"""Send a batch of write operations to the server.
@ -2525,7 +2525,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
codec_options: CodecOptions = CodecOptions(SON)
codec_options: CodecOptions[Mapping[str, Any]] = CodecOptions(SON)
coll = cast(
AsyncCollection[MutableMapping[str, Any]],
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
@ -2871,7 +2871,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self,
aggregation_command: Type[_AggregationCommand],
pipeline: _Pipeline,
cursor_class: Type[AsyncCommandCursor],
cursor_class: Type[AsyncCommandCursor], # type: ignore[type-arg]
session: Optional[AsyncClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
@ -3114,7 +3114,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
comment: Optional[Any] = None,
hint: Optional[_IndexKeyHint] = None,
**kwargs: Any,
) -> list:
) -> list[str]:
"""Get a list of distinct values for `key` among all documents
in this collection.
@ -3177,7 +3177,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
_server: Server,
conn: AsyncConnection,
read_preference: Optional[_ServerMode],
) -> list:
) -> list: # type: ignore[type-arg]
return (
await self._command(
conn,
@ -3202,7 +3202,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[AsyncClientSession] = None,
let: Optional[Mapping] = None,
let: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Internal findAndModify helper."""

View File

@ -350,7 +350,7 @@ class AsyncCommandCursor(Generic[_DocumentType]):
else:
return None
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
"""Get all or some available documents from the cursor."""
if not len(self._data) and not self._killed:
await self._refresh()
@ -457,7 +457,7 @@ class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
codec_options: CodecOptions[dict[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[Mapping[str, Any]]:

View File

@ -216,7 +216,7 @@ class AsyncCursor(Generic[_DocumentType]):
# it anytime we change __limit.
self._empty = False
self._data: deque = deque()
self._data: deque = deque() # type: ignore[type-arg]
self._address: Optional[_Address] = None
self._retrieved = 0
@ -280,7 +280,7 @@ class AsyncCursor(Generic[_DocumentType]):
"""
return self._clone(True)
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor:
def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> AsyncCursor: # type: ignore[type-arg]
"""Internal clone helper."""
if not base:
if self._explicit_session:
@ -322,7 +322,7 @@ class AsyncCursor(Generic[_DocumentType]):
base.__dict__.update(data)
return base
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor:
def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor: # type: ignore[type-arg]
"""Creates an empty AsyncCursor object for information to be copied into."""
return self.__class__(self._collection, session=session)
@ -864,7 +864,7 @@ class AsyncCursor(Generic[_DocumentType]):
if self._has_filter:
spec = dict(self._spec)
else:
spec = cast(dict, self._spec)
spec = cast(dict, self._spec) # type: ignore[type-arg]
spec["$where"] = code
self._spec = spec
return self
@ -888,7 +888,7 @@ class AsyncCursor(Generic[_DocumentType]):
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
codec_options: CodecOptions, # type: ignore[type-arg]
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
@ -964,29 +964,33 @@ class AsyncCursor(Generic[_DocumentType]):
return self._clone(deepcopy=True)
@overload
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list:
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: # type: ignore[type-arg]
...
@overload
def _deepcopy(
self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None
) -> dict:
self,
x: SupportsItems, # type: ignore[type-arg]
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
) -> dict: # type: ignore[type-arg]
...
def _deepcopy(
self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None
) -> Union[list, dict]:
self,
x: Union[Iterable, SupportsItems], # type: ignore[type-arg]
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
) -> Union[list[Any], dict[str, Any]]:
"""Deepcopy helper for the data dictionary or list.
Regular expressions cannot be deep copied but as they are immutable we
don't have to copy them when cloning.
"""
y: Union[list, dict]
y: Union[list[Any], dict[str, Any]]
iterator: Iterable[tuple[Any, Any]]
if not hasattr(x, "items"):
y, is_list, iterator = [], True, enumerate(x)
else:
y, is_list, iterator = {}, False, cast("SupportsItems", x).items()
y, is_list, iterator = {}, False, cast("SupportsItems", x).items() # type: ignore[type-arg]
if memo is None:
memo = {}
val_id = id(x)
@ -1060,7 +1064,7 @@ class AsyncCursor(Generic[_DocumentType]):
"""Explicitly close / kill this cursor."""
await self._die_lock()
async def distinct(self, key: str) -> list:
async def distinct(self, key: str) -> list[str]:
"""Get a list of distinct values for `key` among all documents
in the result set of this query.
@ -1265,7 +1269,7 @@ class AsyncCursor(Generic[_DocumentType]):
else:
raise StopAsyncIteration
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
"""Get all or some documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
@ -1325,7 +1329,7 @@ class AsyncCursor(Generic[_DocumentType]):
return res
class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]):
class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]): # type: ignore[type-arg]
"""An asynchronous cursor / iterator over raw batches of BSON data from a query result."""
_query_class = _RawBatchQuery

View File

@ -771,7 +771,7 @@ class AsyncDatabase(common.BaseObject, Generic[_DocumentType]):
self._name,
command,
read_preference,
codec_options,
codec_options, # type: ignore[arg-type]
check,
allowable_errors,
write_concern=write_concern,

View File

@ -161,10 +161,10 @@ _ReadCall = Callable[
_IS_SYNC = False
_WriteOp = Union[
InsertOne,
InsertOne, # type: ignore[type-arg]
DeleteOne,
DeleteMany,
ReplaceOne,
ReplaceOne, # type: ignore[type-arg]
UpdateOne,
UpdateMany,
]
@ -176,7 +176,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# Define order to retrieve options from ClientOptions for __repr__.
# No host/port; these are retrieved from TopologySettings.
_constructor_args = ("document_class", "tz_aware", "connect")
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() # type: ignore[type-arg]
def __init__(
self,
@ -847,7 +847,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._default_database_name = dbase
self._lock = _async_create_lock()
self._kill_cursors_queue: list = []
self._kill_cursors_queue: list = [] # type: ignore[type-arg]
self._encrypter: Optional[_Encrypter] = None
@ -1064,7 +1064,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# Reset the session pool to avoid duplicate sessions in the child process.
self._topology._session_pool.reset()
def _duplicate(self, **kwargs: Any) -> AsyncMongoClient:
def _duplicate(self, **kwargs: Any) -> AsyncMongoClient: # type: ignore[type-arg]
args = self._init_kwargs.copy()
args.update(kwargs)
return AsyncMongoClient(**args)
@ -1548,7 +1548,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self, name, codec_options, read_preference, write_concern, read_concern
)
def _database_default_options(self, name: str) -> database.AsyncDatabase:
def _database_default_options(self, name: str) -> database.AsyncDatabase: # type: ignore[type-arg]
"""Get a AsyncDatabase instance with the default settings."""
return self.get_database(
name,
@ -1887,7 +1887,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
async def _run_operation(
self,
operation: Union[_Query, _GetMore],
unpack_res: Callable,
unpack_res: Callable, # type: ignore[type-arg]
address: Optional[_Address] = None,
) -> Response:
"""Run a _Query/_GetMore operation and return a Response.
@ -2261,7 +2261,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.asynccontextmanager
async def _tmp_session(
self, session: Optional[client_session.AsyncClientSession], close: bool = True
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]:
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.AsyncClientSession):
@ -2308,8 +2308,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
return cast(
dict,
return cast( # type: ignore[redundant-cast]
dict[str, Any],
await self.admin.command(
"buildinfo", read_preference=ReadPreference.PRIMARY, session=session
),
@ -2438,13 +2438,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
@_csot.apply
async def bulk_write(
self,
models: Sequence[_WriteOp[_DocumentType]],
models: Sequence[_WriteOp],
session: Optional[AsyncClientSession] = None,
ordered: bool = True,
verbose_results: bool = False,
bypass_document_validation: Optional[bool] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
let: Optional[Mapping[str, Any]] = None,
write_concern: Optional[WriteConcern] = None,
) -> ClientBulkWriteResult:
"""Send a batch of write operations, potentially across multiple namespaces, to the server.
@ -2631,7 +2631,10 @@ class _MongoClientErrorHandler:
)
def __init__(
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
self,
client: AsyncMongoClient, # type: ignore[type-arg]
server: Server,
session: Optional[AsyncClientSession],
):
if not isinstance(client, AsyncMongoClient):
# This is for compatibility with mocked and subclassed types, such as in Motor.
@ -2704,7 +2707,7 @@ class _ClientConnectionRetryable(Generic[T]):
def __init__(
self,
mongo_client: AsyncMongoClient,
mongo_client: AsyncMongoClient, # type: ignore[type-arg]
func: _WriteCall[T] | _ReadCall[T],
bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]],
operation: str,

View File

@ -351,7 +351,7 @@ class Monitor(MonitorBase):
)
return sd
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]:
async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]: # type: ignore[type-arg]
"""Return (Hello, round_trip_time).
Can raise ConnectionFailure or OperationFailure.

View File

@ -66,7 +66,7 @@ async def command(
read_preference: Optional[_ServerMode],
codec_options: CodecOptions[_DocumentType],
session: Optional[AsyncClientSession],
client: Optional[AsyncMongoClient],
client: Optional[AsyncMongoClient[Any]],
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
address: Optional[_Address] = None,

View File

@ -201,7 +201,7 @@ class AsyncConnection:
self.conn.get_conn.settimeout(timeout)
def apply_timeout(
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
) -> Optional[float]:
# CSOT: use remaining timeout when set.
timeout = _csot.remaining()
@ -255,7 +255,7 @@ class AsyncConnection:
else:
return {HelloCompat.LEGACY_CMD: 1, "helloOk": True}
async def hello(self) -> Hello:
async def hello(self) -> Hello[dict[str, Any]]:
return await self._hello(None, None)
async def _hello(
@ -357,7 +357,7 @@ class AsyncConnection:
dbname: str,
spec: MutableMapping[str, Any],
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
codec_options: CodecOptions[Mapping[str, Any]] = DEFAULT_CODEC_OPTIONS, # type: ignore[assignment]
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_concern: Optional[ReadConcern] = None,
@ -365,7 +365,7 @@ class AsyncConnection:
parse_write_concern_error: bool = False,
collation: Optional[_CollationIn] = None,
session: Optional[AsyncClientSession] = None,
client: Optional[AsyncMongoClient] = None,
client: Optional[AsyncMongoClient[Any]] = None,
retryable_write: bool = False,
publish_events: bool = True,
user_fields: Optional[Mapping[str, Any]] = None,
@ -417,7 +417,7 @@ class AsyncConnection:
spec,
self.is_mongos,
read_preference,
codec_options,
codec_options, # type: ignore[arg-type]
session,
client,
check,
@ -489,7 +489,7 @@ class AsyncConnection:
await self.send_message(msg, max_doc_size)
async def write_command(
self, request_id: int, msg: bytes, codec_options: CodecOptions
self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]]
) -> dict[str, Any]:
"""Send "insert" etc. command, returning response as a dict.
@ -541,7 +541,7 @@ class AsyncConnection:
)
def validate_session(
self, client: Optional[AsyncMongoClient], session: Optional[AsyncClientSession]
self, client: Optional[AsyncMongoClient[Any]], session: Optional[AsyncClientSession]
) -> None:
"""Validate this session before use with client.
@ -598,7 +598,7 @@ class AsyncConnection:
self,
command: MutableMapping[str, Any],
session: Optional[AsyncClientSession],
client: Optional[AsyncMongoClient],
client: Optional[AsyncMongoClient[Any]],
) -> None:
"""Add $clusterTime."""
if client:
@ -732,7 +732,7 @@ class Pool:
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
# and returned to pool from the left side. Stale sockets removed
# from the right side.
self.conns: collections.deque = collections.deque()
self.conns: collections.deque[AsyncConnection] = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _async_create_lock()
self._max_connecting_cond = _async_create_condition(self.lock)
@ -839,8 +839,8 @@ class Pool:
if service_id is None:
sockets, self.conns = self.conns, collections.deque()
else:
discard: collections.deque = collections.deque()
keep: collections.deque = collections.deque()
discard: collections.deque = collections.deque() # type: ignore[type-arg]
keep: collections.deque = collections.deque() # type: ignore[type-arg]
for conn in self.conns:
if conn.service_id == service_id:
discard.append(conn)
@ -866,7 +866,7 @@ class Pool:
if close:
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
else:
@ -903,7 +903,7 @@ class Pool:
)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
else:
@ -917,7 +917,7 @@ class Pool:
self.is_writable = is_writable
async with self.lock:
for _socket in self.conns:
_socket.update_is_writable(self.is_writable)
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]
async def reset(
self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False
@ -956,7 +956,7 @@ class Pool:
close_conns.append(self.conns.pop())
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value]
return_exceptions=True,
)
else:
@ -1477,4 +1477,4 @@ class Pool:
# not safe to acquire a lock in __del__.
if _IS_SYNC:
for conn in self.conns:
conn.close_conn(None)
conn.close_conn(None) # type: ignore[unused-coroutine]

View File

@ -66,7 +66,7 @@ class Server:
monitor: Monitor,
topology_id: Optional[ObjectId] = None,
listeners: Optional[_EventListeners] = None,
events: Optional[ReferenceType[Queue]] = None,
events: Optional[ReferenceType[Queue[Any]]] = None,
) -> None:
"""Represent one MongoDB server."""
self._description = server_description
@ -142,7 +142,7 @@ class Server:
read_preference: _ServerMode,
listeners: Optional[_EventListeners],
unpack_res: Callable[..., list[_DocumentOut]],
client: AsyncMongoClient,
client: AsyncMongoClient[Any],
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.

View File

@ -84,7 +84,7 @@ _IS_SYNC = False
_pymongo_dir = str(Path(__file__).parent)
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: # type: ignore[type-arg]
q = queue_ref()
if not q:
return False # Cancel PeriodicExecutor.
@ -186,7 +186,7 @@ class Topology:
if self._publish_server or self._publish_tp:
assert self._events is not None
weak: weakref.ReferenceType[queue.Queue]
weak: weakref.ReferenceType[queue.Queue[Any]]
async def target() -> bool:
return process_events_queue(weak)

View File

@ -247,7 +247,7 @@ class ClientOptions:
return self.__connect
@property
def codec_options(self) -> CodecOptions:
def codec_options(self) -> CodecOptions[Any]:
"""A :class:`~bson.codec_options.CodecOptions` instance."""
return self.__codec_options

View File

@ -56,7 +56,7 @@ if TYPE_CHECKING:
from pymongo.typings import _AgnosticClientSession
ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
ORDERED_TYPES: Sequence[Type[Any]] = (SON, OrderedDict)
# Defaults until we connect to a server and get updated limits.
MAX_BSON_SIZE = 16 * (1024**2)
@ -166,7 +166,7 @@ def clean_node(node: str) -> tuple[str, int]:
return host.lower(), port
def raise_config_error(key: str, suggestions: Optional[list] = None) -> NoReturn:
def raise_config_error(key: str, suggestions: Optional[list[str]] = None) -> NoReturn:
"""Raise ConfigurationError with the given key name."""
msg = f"Unknown option: {key}."
if suggestions:
@ -411,7 +411,7 @@ def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]
if not isinstance(value, list):
value = [value]
tag_sets: list = []
tag_sets: list[dict[str, Any]] = []
for tag_set in value:
if tag_set == "":
tag_sets.append({})
@ -497,7 +497,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni
def validate_document_class(
option: str, value: Any
) -> Union[Type[MutableMapping], Type[RawBSONDocument]]:
) -> Union[Type[MutableMapping[str, Any]], Type[RawBSONDocument]]:
"""Validate the document_class option."""
# issubclass can raise TypeError for generic aliases like SON[str, Any].
# In that case we can use the base class for the comparison.
@ -523,14 +523,14 @@ def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]:
return value
def validate_list(option: str, value: Any) -> list:
def validate_list(option: str, value: Any) -> list[Any]:
"""Validates that 'value' is a list."""
if not isinstance(value, list):
raise TypeError(f"{option} must be a list, not {type(value)}")
return value
def validate_list_or_none(option: Any, value: Any) -> Optional[list]:
def validate_list_or_none(option: Any, value: Any) -> Optional[list[Any]]:
"""Validates that 'value' is a list or None."""
if value is None:
return value
@ -597,7 +597,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]:
return value
def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]:
def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable[..., Any]]:
"""Validates that 'value' is a callable."""
if value is None:
return value
@ -829,7 +829,7 @@ def validate_auth_option(option: str, value: Any) -> tuple[str, Any]:
def _get_validator(
key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None
) -> Callable:
) -> Callable[[Any, Any], Any]:
normed_key = normed_key or key
try:
return validators[normed_key]
@ -917,7 +917,7 @@ class BaseObject:
def __init__(
self,
codec_options: CodecOptions,
codec_options: CodecOptions[Any],
read_preference: _ServerMode,
write_concern: WriteConcern,
read_concern: ReadConcern,
@ -947,7 +947,7 @@ class BaseObject:
self._read_concern = read_concern
@property
def codec_options(self) -> CodecOptions:
def codec_options(self) -> CodecOptions[Any]:
"""Read only access to the :class:`~bson.codec_options.CodecOptions`
of this instance.
"""

View File

@ -37,7 +37,7 @@ from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from pymongo.pyopenssl_context import SSLContext
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
from pymongo.typings import _AgnosticMongoClient
class AutoEncryptionOpts:
@ -47,7 +47,7 @@ class AutoEncryptionOpts:
self,
kms_providers: Mapping[str, Any],
key_vault_namespace: str,
key_vault_client: Optional[_AgnosticMongoClient[_DocumentTypeArg]] = None,
key_vault_client: Optional[_AgnosticMongoClient] = None,
schema_map: Optional[Mapping[str, Any]] = None,
bypass_auto_encryption: bool = False,
mongocryptd_uri: str = "mongodb://localhost:27020",

View File

@ -52,7 +52,7 @@ if TYPE_CHECKING:
# From the SDAM spec, the "node is shutting down" codes.
_SHUTDOWN_CODES: frozenset = frozenset(
_SHUTDOWN_CODES: frozenset[int] = frozenset(
[
11600, # InterruptedAtShutdown
91, # ShutdownInProgress
@ -61,7 +61,7 @@ _SHUTDOWN_CODES: frozenset = frozenset(
# From the SDAM spec, the "not primary" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_PRIMARY_CODES: frozenset = (
_NOT_PRIMARY_CODES: frozenset[int] = (
frozenset(
[
10058, # LegacyNotPrimary <=3.2 "not primary" error code
@ -75,7 +75,7 @@ _NOT_PRIMARY_CODES: frozenset = (
| _SHUTDOWN_CODES
)
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
_RETRYABLE_ERROR_CODES: frozenset[int] = _NOT_PRIMARY_CODES | frozenset(
[
7, # HostNotFound
6, # HostUnreachable
@ -95,7 +95,7 @@ _AUTHENTICATION_FAILURE_CODE: int = 18
# Note - to avoid bugs from forgetting which if these is all lowercase and
# which are camelCase, and at the same time avoid having to add a test for
# every command, use all lowercase here and test against command_name.lower().
_SENSITIVE_COMMANDS: set = {
_SENSITIVE_COMMANDS: set[str] = {
"authenticate",
"saslstart",
"saslcontinue",

View File

@ -333,7 +333,7 @@ def _op_msg_no_header(
command: Mapping[str, Any],
identifier: str,
docs: Optional[list[Mapping[str, Any]]],
opts: CodecOptions,
opts: CodecOptions[Any],
) -> tuple[bytes, int, int]:
"""Get a OP_MSG message.
@ -365,7 +365,7 @@ def _op_msg_compressed(
command: Mapping[str, Any],
identifier: str,
docs: Optional[list[Mapping[str, Any]]],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int, int]:
"""Internal OP_MSG message helper."""
@ -379,7 +379,7 @@ def _op_msg_uncompressed(
command: Mapping[str, Any],
identifier: str,
docs: Optional[list[Mapping[str, Any]]],
opts: CodecOptions,
opts: CodecOptions[Any],
) -> tuple[int, bytes, int, int]:
"""Internal compressed OP_MSG message helper."""
data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
@ -396,7 +396,7 @@ def _op_msg(
command: MutableMapping[str, Any],
dbname: str,
read_preference: Optional[_ServerMode],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int, int]:
"""Get a OP_MSG message."""
@ -430,7 +430,7 @@ def _query_impl(
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
opts: CodecOptions[Any],
) -> tuple[bytes, int]:
"""Get an OP_QUERY message."""
encoded = _dict_to_bson(query, False, opts)
@ -461,7 +461,7 @@ def _query_compressed(
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int]:
"""Internal compressed query message helper."""
@ -479,7 +479,7 @@ def _query_uncompressed(
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
opts: CodecOptions[Any],
) -> tuple[int, bytes, int]:
"""Internal query message helper."""
op_query, max_bson_size = _query_impl(
@ -500,7 +500,7 @@ def _query(
num_to_return: int,
query: Mapping[str, Any],
field_selector: Optional[Mapping[str, Any]],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int]:
"""Get a **query** message."""
@ -598,7 +598,7 @@ class _BulkWriteContextBase:
listeners: _EventListeners,
session: Optional[_AgnosticClientSession],
op_type: int,
codec: CodecOptions,
codec: CodecOptions[Any],
):
self.db_name = database_name
self.conn = conn
@ -679,7 +679,7 @@ class _BulkWriteContext(_BulkWriteContextBase):
listeners: _EventListeners,
session: Optional[_AgnosticClientSession],
op_type: int,
codec: CodecOptions,
codec: CodecOptions[Any],
):
super().__init__(
database_name,
@ -771,7 +771,7 @@ def _batched_op_msg_impl(
command: Mapping[str, Any],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _BulkWriteContext,
buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
@ -839,7 +839,7 @@ def _encode_batched_op_msg(
command: Mapping[str, Any],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
"""Encode the next batched insert, update, or delete operation
@ -860,7 +860,7 @@ def _batched_op_msg_compressed(
command: Mapping[str, Any],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
"""Create the next batched insert, update, or delete operation
@ -878,7 +878,7 @@ def _batched_op_msg(
command: Mapping[str, Any],
docs: list[Mapping[str, Any]],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
"""OP_MSG implementation entry point."""
@ -910,7 +910,7 @@ def _do_batched_op_msg(
operation: int,
command: MutableMapping[str, Any],
docs: list[Mapping[str, Any]],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
"""Create the next batched insert, update, or delete operation
@ -939,7 +939,7 @@ class _ClientBulkWriteContext(_BulkWriteContextBase):
operation_id: int,
listeners: _EventListeners,
session: Optional[_AgnosticClientSession],
codec: CodecOptions,
codec: CodecOptions[Any],
):
super().__init__(
database_name,
@ -1043,7 +1043,7 @@ def _client_batched_op_msg_impl(
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _ClientBulkWriteContext,
buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]], int]:
@ -1161,7 +1161,7 @@ def _client_encode_batched_op_msg(
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _ClientBulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Encode the next batched client-level bulkWrite
@ -1180,7 +1180,7 @@ def _client_batched_op_msg_compressed(
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Create the next batched client-level bulkWrite operation
@ -1200,7 +1200,7 @@ def _client_batched_op_msg(
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
ack: bool,
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""OP_MSG implementation entry point for client-level bulkWrite."""
@ -1229,7 +1229,7 @@ def _client_do_batched_op_msg(
command: MutableMapping[str, Any],
operations: list[tuple[str, Mapping[str, Any]]],
namespaces: list[str],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _ClientBulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
"""Create the next batched client-level bulkWrite
@ -1253,7 +1253,7 @@ def _encode_batched_write_command(
operation: int,
command: MutableMapping[str, Any],
docs: list[Mapping[str, Any]],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
"""Encode the next batched insert, update, or delete command."""
@ -1272,7 +1272,7 @@ def _batched_write_command_impl(
operation: int,
command: MutableMapping[str, Any],
docs: list[Mapping[str, Any]],
opts: CodecOptions,
opts: CodecOptions[Any],
ctx: _BulkWriteContext,
buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
@ -1383,7 +1383,7 @@ class _OpReply:
errobj = {"ok": 0, "errmsg": msg, "code": 43}
raise CursorNotFound(msg, 43, errobj)
elif self.flags & 2:
error_object: dict = bson.BSON(self.documents).decode()
error_object: dict[str, Any] = bson.BSON(self.documents).decode()
# Fake the ok field if it doesn't exist.
error_object.setdefault("ok", 0)
if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
@ -1405,7 +1405,7 @@ class _OpReply:
def unpack_response(
self,
cursor_id: Optional[int] = None,
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
codec_options: CodecOptions[Any] = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[dict[str, Any]]:
@ -1431,7 +1431,7 @@ class _OpReply:
return bson.decode_all(self.documents, codec_options)
return bson._decode_all_selective(self.documents, codec_options, user_fields)
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
def command_response(self, codec_options: CodecOptions[Any]) -> dict[str, Any]:
"""Unpack a command response."""
docs = self.unpack_response(codec_options=codec_options)
assert self.number_returned == 1
@ -1491,7 +1491,7 @@ class _OpMsg:
def unpack_response(
self,
cursor_id: Optional[int] = None,
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
codec_options: CodecOptions[Any] = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[dict[str, Any]]:
@ -1508,7 +1508,7 @@ class _OpMsg:
assert not legacy_response
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
def command_response(self, codec_options: CodecOptions[Any]) -> dict[str, Any]:
"""Unpack a command response."""
return self.unpack_response(codec_options=codec_options)[0]
@ -1583,7 +1583,7 @@ class _Query:
ntoskip: int,
spec: Mapping[str, Any],
fields: Optional[Mapping[str, Any]],
codec_options: CodecOptions,
codec_options: CodecOptions[Any],
read_preference: _ServerMode,
limit: int,
batch_size: int,
@ -1757,7 +1757,7 @@ class _GetMore:
coll: str,
ntoreturn: int,
cursor_id: int,
codec_options: CodecOptions,
codec_options: CodecOptions[Any],
read_preference: _ServerMode,
session: Optional[_AgnosticClientSession],
client: _AgnosticMongoClient,
@ -1871,7 +1871,7 @@ class _RawBatchGetMore(_GetMore):
return False
class _CursorAddress(tuple):
class _CursorAddress(tuple[Any, ...]):
"""The server address (host, port) of a cursor, with namespace property."""
__namespace: Any

View File

@ -1347,7 +1347,11 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
__slots__ = ("__duration", "__reply")
def __init__(
self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False
self,
duration: float,
reply: Hello[dict[str, Any]],
connection_id: _Address,
awaited: bool = False,
) -> None:
super().__init__(connection_id, awaited)
self.__duration = duration
@ -1359,7 +1363,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
return self.__duration
@property
def reply(self) -> Hello:
def reply(self) -> Hello[dict[str, Any]]:
"""An instance of :class:`~pymongo.hello.Hello`."""
return self.__reply
@ -1647,7 +1651,7 @@ class _EventListeners:
_handle_exception()
def publish_server_heartbeat_succeeded(
self, connection_id: _Address, duration: float, reply: Hello, awaited: bool
self, connection_id: _Address, duration: float, reply: Hello[dict[str, Any]], awaited: bool
) -> None:
"""Publish a ServerHeartbeatSucceededEvent to all server heartbeat
listeners.

View File

@ -96,7 +96,7 @@ if sys.platform != "win32":
view = memoryview(buf)
sent = 0
def _is_ready(fut: Future) -> None:
def _is_ready(fut: Future[Any]) -> None:
if fut.done():
return
fut.set_result(None)
@ -139,7 +139,7 @@ if sys.platform != "win32":
mv = memoryview(bytearray(length))
total_read = 0
def _is_ready(fut: Future) -> None:
def _is_ready(fut: Future[Any]) -> None:
if fut.done():
return
fut.set_result(None)
@ -486,15 +486,15 @@ class PyMongoProtocol(BufferedProtocol):
self._message_size = 0
self._op_code = 0
self._connection_lost = False
self._read_waiter: Optional[Future] = None
self._read_waiter: Optional[Future[Any]] = None
self._timeout = timeout
self._is_compressed = False
self._compressor_id: Optional[int] = None
self._max_message_size = MAX_MESSAGE_SIZE
self._response_to: Optional[int] = None
self._closed = asyncio.get_running_loop().create_future()
self._pending_messages: collections.deque[Future] = collections.deque()
self._done_messages: collections.deque[Future] = collections.deque()
self._pending_messages: collections.deque[Future[Any]] = collections.deque()
self._done_messages: collections.deque[Future[Any]] = collections.deque()
def settimeout(self, timeout: float | None) -> None:
self._timeout = timeout

View File

@ -53,7 +53,7 @@ class AsyncPeriodicExecutor:
self._min_interval = min_interval
self._target = target
self._stopped = False
self._task: Optional[asyncio.Task] = None
self._task: Optional[asyncio.Task[Any]] = None
self._name = name
self._skip_sleep = False

View File

@ -69,7 +69,7 @@ class ServerDescription:
def __init__(
self,
address: _Address,
hello: Optional[Hello] = None,
hello: Optional[Hello[dict[str, Any]]] = None,
round_trip_time: Optional[float] = None,
error: Optional[Exception] = None,
min_round_trip_time: float = 0.0,
@ -299,4 +299,4 @@ class ServerDescription:
)
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time: dict = {}
_host_to_round_trip_time: dict = {} # type: ignore[type-arg]

View File

@ -56,17 +56,22 @@ if HAVE_SSL:
if HAVE_PYSSL:
PYSSLError: Any = _pyssl.SSLError
BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS
BLOCKING_IO_READ_ERROR: tuple = (_pyssl.BLOCKING_IO_READ_ERROR, _ssl.BLOCKING_IO_READ_ERROR)
BLOCKING_IO_WRITE_ERROR: tuple = (
BLOCKING_IO_ERRORS: tuple = ( # type: ignore[type-arg]
_ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS
)
BLOCKING_IO_READ_ERROR: tuple = ( # type: ignore[type-arg]
_pyssl.BLOCKING_IO_READ_ERROR,
_ssl.BLOCKING_IO_READ_ERROR,
)
BLOCKING_IO_WRITE_ERROR: tuple = ( # type: ignore[type-arg]
_pyssl.BLOCKING_IO_WRITE_ERROR,
_ssl.BLOCKING_IO_WRITE_ERROR,
)
else:
PYSSLError = _ssl.SSLError
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
BLOCKING_IO_READ_ERROR = (_ssl.BLOCKING_IO_READ_ERROR,)
BLOCKING_IO_WRITE_ERROR = (_ssl.BLOCKING_IO_WRITE_ERROR,)
BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS # type: ignore[type-arg, no-redef]
BLOCKING_IO_READ_ERROR: tuple = (_ssl.BLOCKING_IO_READ_ERROR,) # type: ignore[type-arg, no-redef]
BLOCKING_IO_WRITE_ERROR: tuple = (_ssl.BLOCKING_IO_WRITE_ERROR,) # type: ignore[type-arg, no-redef]
SSLError = _ssl.SSLError
BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR
@ -131,7 +136,7 @@ else:
pass
IPADDR_SAFE = False
BLOCKING_IO_ERRORS = ()
BLOCKING_IO_ERRORS: tuple = () # type: ignore[type-arg, no-redef]
def _has_sni(is_sync: bool) -> bool: # noqa: ARG001
return False

View File

@ -46,8 +46,8 @@ class _AggregationCommand:
def __init__(
self,
target: Union[Database, Collection],
cursor_class: type[CommandCursor],
target: Union[Database[Any], Collection[Any]],
cursor_class: type[CommandCursor[Any]],
pipeline: _Pipeline,
options: MutableMapping[str, Any],
explicit_session: bool,
@ -111,12 +111,12 @@ class _AggregationCommand:
"""The namespace in which the aggregate command is run."""
raise NotImplementedError
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection:
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> Collection[Any]:
"""The Collection used for the aggregate command cursor."""
raise NotImplementedError
@property
def _database(self) -> Database:
def _database(self) -> Database[Any]:
"""The database against which the aggregation command is run."""
raise NotImplementedError
@ -205,7 +205,7 @@ class _AggregationCommand:
class _CollectionAggregationCommand(_AggregationCommand):
_target: Collection
_target: Collection[Any]
@property
def _aggregation_target(self) -> str:
@ -215,12 +215,12 @@ class _CollectionAggregationCommand(_AggregationCommand):
def _cursor_namespace(self) -> str:
return self._target.full_name
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection:
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]:
"""The Collection used for the aggregate command cursor."""
return self._target
@property
def _database(self) -> Database:
def _database(self) -> Database[Any]:
return self._target.database
@ -234,7 +234,7 @@ class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
class _DatabaseAggregationCommand(_AggregationCommand):
_target: Database
_target: Database[Any]
@property
def _aggregation_target(self) -> int:
@ -245,10 +245,10 @@ class _DatabaseAggregationCommand(_AggregationCommand):
return f"{self._target.name}.$cmd.aggregate"
@property
def _database(self) -> Database:
def _database(self) -> Database[Any]:
return self._target
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection:
def _cursor_collection(self, cursor: Mapping[str, Any]) -> Collection[Any]:
"""The Collection used for the aggregate command cursor."""
# Collection level aggregate may not always return the "ns" field
# according to our MockupDB tests. Let's handle that case for db level

View File

@ -257,7 +257,7 @@ class _OIDCAuthenticator:
) -> Mapping[str, Any]:
self.access_token = None
self.refresh_token = None
start_payload: dict = bson.decode(start_resp["payload"])
start_payload: dict[str, Any] = bson.decode(start_resp["payload"])
if "issuer" in start_payload:
self.idp_info = OIDCIdPInfo(**start_payload)
access_token = self._get_access_token()

View File

@ -248,7 +248,7 @@ class _Bulk:
request_id: int,
msg: bytes,
docs: list[Mapping[str, Any]],
client: MongoClient,
client: MongoClient[Any],
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
cmd[bwc.field] = docs
@ -334,7 +334,7 @@ class _Bulk:
msg: bytes,
max_doc_size: int,
docs: list[Mapping[str, Any]],
client: MongoClient,
client: MongoClient[Any],
) -> Optional[Mapping[str, Any]]:
"""A proxy for Connection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
@ -419,7 +419,7 @@ class _Bulk:
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: MongoClient,
client: MongoClient[Any],
) -> list[Mapping[str, Any]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
@ -446,7 +446,7 @@ class _Bulk:
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
cmd: dict[str, Any],
ops: list[Mapping[str, Any]],
client: MongoClient,
client: MongoClient[Any],
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
if self.is_encrypted:
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)

View File

@ -164,7 +164,7 @@ class ChangeStream(Generic[_DocumentType]):
raise NotImplementedError
@property
def _client(self) -> MongoClient:
def _client(self) -> MongoClient: # type: ignore[type-arg]
"""The client against which the aggregation commands for
this ChangeStream will be run.
"""
@ -206,7 +206,7 @@ class ChangeStream(Generic[_DocumentType]):
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
"""Return the full aggregation pipeline for this ChangeStream."""
options = self._change_stream_options()
full_pipeline: list = [{"$changeStream": options}]
full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}]
full_pipeline.extend(self._pipeline)
return full_pipeline
@ -237,7 +237,7 @@ class ChangeStream(Generic[_DocumentType]):
def _run_aggregation_cmd(
self, session: Optional[ClientSession], explicit_session: bool
) -> CommandCursor:
) -> CommandCursor: # type: ignore[type-arg]
"""Run the full aggregation pipeline for this ChangeStream and return
the corresponding CommandCursor.
"""
@ -257,7 +257,7 @@ class ChangeStream(Generic[_DocumentType]):
operation=_Op.AGGREGATE,
)
def _create_cursor(self) -> CommandCursor:
def _create_cursor(self) -> CommandCursor: # type: ignore[type-arg]
with self._client._tmp_session(self._session, close=False) as s:
return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None)

View File

@ -88,7 +88,7 @@ class _ClientBulk:
def __init__(
self,
client: MongoClient,
client: MongoClient[Any],
write_concern: WriteConcern,
ordered: bool = True,
bypass_document_validation: Optional[bool] = None,
@ -233,7 +233,7 @@ class _ClientBulk:
msg: Union[bytes, dict[str, Any]],
op_docs: list[Mapping[str, Any]],
ns_docs: list[Mapping[str, Any]],
client: MongoClient,
client: MongoClient[Any],
) -> dict[str, Any]:
"""A proxy for Connection.write_command that handles event publishing."""
cmd["ops"] = op_docs
@ -324,7 +324,7 @@ class _ClientBulk:
msg: bytes,
op_docs: list[Mapping[str, Any]],
ns_docs: list[Mapping[str, Any]],
client: MongoClient,
client: MongoClient[Any],
) -> Optional[Mapping[str, Any]]:
"""A proxy for Connection.unack_write that handles event publishing."""
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):

View File

@ -395,7 +395,7 @@ class _TxnState:
class _Transaction:
"""Internal class to hold transaction information in a ClientSession."""
def __init__(self, opts: Optional[TransactionOptions], client: MongoClient):
def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any]):
self.opts = opts
self.state = _TxnState.NONE
self.sharded = False
@ -458,7 +458,7 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
# From the transactions spec, all the retryable writes errors plus
# WriteConcernTimeout.
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset(
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset( # type: ignore[type-arg]
[
64, # WriteConcernTimeout
50, # MaxTimeMSExpired
@ -498,13 +498,13 @@ class ClientSession:
def __init__(
self,
client: MongoClient,
client: MongoClient[Any],
server_session: Any,
options: SessionOptions,
implicit: bool,
) -> None:
# A MongoClient, a _ServerSession, a SessionOptions, and a set.
self._client: MongoClient = client
self._client: MongoClient[Any] = client
self._server_session = server_session
self._options = options
self._cluster_time: Optional[Mapping[str, Any]] = None
@ -550,7 +550,7 @@ class ClientSession:
self._end_session(lock=True)
@property
def client(self) -> MongoClient:
def client(self) -> MongoClient[Any]:
"""The :class:`~pymongo.mongo_client.MongoClient` this session was
created from.
"""
@ -748,7 +748,7 @@ class ClientSession:
write_concern: Optional[WriteConcern] = None,
read_preference: Optional[_ServerMode] = None,
max_commit_time_ms: Optional[int] = None,
) -> ContextManager:
) -> ContextManager[Any]:
"""Start a multi-statement transaction.
Takes the same arguments as :class:`TransactionOptions`.
@ -1118,7 +1118,7 @@ class _ServerSession:
self._transaction_id += 1
class _ServerSessionPool(collections.deque):
class _ServerSessionPool(collections.deque): # type: ignore[type-arg]
"""Pool of _ServerSession objects.
This class is thread-safe.

View File

@ -582,7 +582,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
conn: Connection,
command: MutableMapping[str, Any],
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions] = None,
codec_options: Optional[CodecOptions[Mapping[str, Any]]] = None,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_concern: Optional[ReadConcern] = None,
@ -703,7 +703,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
bypass_document_validation: Optional[bool] = None,
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
let: Optional[Mapping[str, Any]] = None,
) -> BulkWriteResult:
"""Send a batch of write operations to the server.
@ -2522,7 +2522,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
) -> CommandCursor[MutableMapping[str, Any]]:
codec_options: CodecOptions = CodecOptions(SON)
codec_options: CodecOptions[Mapping[str, Any]] = CodecOptions(SON)
coll = cast(
Collection[MutableMapping[str, Any]],
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
@ -2864,7 +2864,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
self,
aggregation_command: Type[_AggregationCommand],
pipeline: _Pipeline,
cursor_class: Type[CommandCursor],
cursor_class: Type[CommandCursor], # type: ignore[type-arg]
session: Optional[ClientSession],
explicit_session: bool,
let: Optional[Mapping[str, Any]] = None,
@ -3107,7 +3107,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
comment: Optional[Any] = None,
hint: Optional[_IndexKeyHint] = None,
**kwargs: Any,
) -> list:
) -> list[str]:
"""Get a list of distinct values for `key` among all documents
in this collection.
@ -3170,7 +3170,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
_server: Server,
conn: Connection,
read_preference: Optional[_ServerMode],
) -> list:
) -> list: # type: ignore[type-arg]
return (
self._command(
conn,
@ -3195,7 +3195,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
let: Optional[Mapping] = None,
let: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Internal findAndModify helper."""

View File

@ -350,7 +350,7 @@ class CommandCursor(Generic[_DocumentType]):
else:
return None
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
"""Get all or some available documents from the cursor."""
if not len(self._data) and not self._killed:
self._refresh()
@ -457,7 +457,7 @@ class RawBatchCommandCursor(CommandCursor[_DocumentType]):
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
codec_options: CodecOptions[dict[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> list[Mapping[str, Any]]:

View File

@ -216,7 +216,7 @@ class Cursor(Generic[_DocumentType]):
# it anytime we change __limit.
self._empty = False
self._data: deque = deque()
self._data: deque = deque() # type: ignore[type-arg]
self._address: Optional[_Address] = None
self._retrieved = 0
@ -280,7 +280,7 @@ class Cursor(Generic[_DocumentType]):
"""
return self._clone(True)
def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor:
def _clone(self, deepcopy: bool = True, base: Optional[Cursor] = None) -> Cursor: # type: ignore[type-arg]
"""Internal clone helper."""
if not base:
if self._explicit_session:
@ -322,7 +322,7 @@ class Cursor(Generic[_DocumentType]):
base.__dict__.update(data)
return base
def _clone_base(self, session: Optional[ClientSession]) -> Cursor:
def _clone_base(self, session: Optional[ClientSession]) -> Cursor: # type: ignore[type-arg]
"""Creates an empty Cursor object for information to be copied into."""
return self.__class__(self._collection, session=session)
@ -862,7 +862,7 @@ class Cursor(Generic[_DocumentType]):
if self._has_filter:
spec = dict(self._spec)
else:
spec = cast(dict, self._spec)
spec = cast(dict, self._spec) # type: ignore[type-arg]
spec["$where"] = code
self._spec = spec
return self
@ -886,7 +886,7 @@ class Cursor(Generic[_DocumentType]):
self,
response: Union[_OpReply, _OpMsg],
cursor_id: Optional[int],
codec_options: CodecOptions,
codec_options: CodecOptions, # type: ignore[type-arg]
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> Sequence[_DocumentOut]:
@ -962,29 +962,33 @@ class Cursor(Generic[_DocumentType]):
return self._clone(deepcopy=True)
@overload
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list:
def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] = None) -> list: # type: ignore[type-arg]
...
@overload
def _deepcopy(
self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None
) -> dict:
self,
x: SupportsItems, # type: ignore[type-arg]
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
) -> dict: # type: ignore[type-arg]
...
def _deepcopy(
self, x: Union[Iterable, SupportsItems], memo: Optional[dict[int, Union[list, dict]]] = None
) -> Union[list, dict]:
self,
x: Union[Iterable, SupportsItems], # type: ignore[type-arg]
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
) -> Union[list[Any], dict[str, Any]]:
"""Deepcopy helper for the data dictionary or list.
Regular expressions cannot be deep copied but as they are immutable we
don't have to copy them when cloning.
"""
y: Union[list, dict]
y: Union[list[Any], dict[str, Any]]
iterator: Iterable[tuple[Any, Any]]
if not hasattr(x, "items"):
y, is_list, iterator = [], True, enumerate(x)
else:
y, is_list, iterator = {}, False, cast("SupportsItems", x).items()
y, is_list, iterator = {}, False, cast("SupportsItems", x).items() # type: ignore[type-arg]
if memo is None:
memo = {}
val_id = id(x)
@ -1058,7 +1062,7 @@ class Cursor(Generic[_DocumentType]):
"""Explicitly close / kill this cursor."""
self._die_lock()
def distinct(self, key: str) -> list:
def distinct(self, key: str) -> list[str]:
"""Get a list of distinct values for `key` among all documents
in the result set of this query.
@ -1263,7 +1267,7 @@ class Cursor(Generic[_DocumentType]):
else:
raise StopIteration
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
def _next_batch(self, result: list, total: Optional[int] = None) -> bool: # type: ignore[type-arg]
"""Get all or some documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
@ -1323,7 +1327,7 @@ class Cursor(Generic[_DocumentType]):
return res
class RawBatchCursor(Cursor, Generic[_DocumentType]):
class RawBatchCursor(Cursor, Generic[_DocumentType]): # type: ignore[type-arg]
"""A cursor / iterator over raw batches of BSON data from a query result."""
_query_class = _RawBatchQuery

View File

@ -771,7 +771,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
self._name,
command,
read_preference,
codec_options,
codec_options, # type: ignore[arg-type]
check,
allowable_errors,
write_concern=write_concern,

View File

@ -158,10 +158,10 @@ _ReadCall = Callable[
_IS_SYNC = True
_WriteOp = Union[
InsertOne,
InsertOne, # type: ignore[type-arg]
DeleteOne,
DeleteMany,
ReplaceOne,
ReplaceOne, # type: ignore[type-arg]
UpdateOne,
UpdateMany,
]
@ -173,7 +173,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# Define order to retrieve options from ClientOptions for __repr__.
# No host/port; these are retrieved from TopologySettings.
_constructor_args = ("document_class", "tz_aware", "connect")
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() # type: ignore[type-arg]
def __init__(
self,
@ -847,7 +847,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._default_database_name = dbase
self._lock = _create_lock()
self._kill_cursors_queue: list = []
self._kill_cursors_queue: list = [] # type: ignore[type-arg]
self._encrypter: Optional[_Encrypter] = None
@ -1064,7 +1064,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# Reset the session pool to avoid duplicate sessions in the child process.
self._topology._session_pool.reset()
def _duplicate(self, **kwargs: Any) -> MongoClient:
def _duplicate(self, **kwargs: Any) -> MongoClient: # type: ignore[type-arg]
args = self._init_kwargs.copy()
args.update(kwargs)
return MongoClient(**args)
@ -1546,7 +1546,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self, name, codec_options, read_preference, write_concern, read_concern
)
def _database_default_options(self, name: str) -> database.Database:
def _database_default_options(self, name: str) -> database.Database: # type: ignore[type-arg]
"""Get a Database instance with the default settings."""
return self.get_database(
name,
@ -1883,7 +1883,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _run_operation(
self,
operation: Union[_Query, _GetMore],
unpack_res: Callable,
unpack_res: Callable, # type: ignore[type-arg]
address: Optional[_Address] = None,
) -> Response:
"""Run a _Query/_GetMore operation and return a Response.
@ -2257,7 +2257,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
@contextlib.contextmanager
def _tmp_session(
self, session: Optional[client_session.ClientSession], close: bool = True
) -> Generator[Optional[client_session.ClientSession], None, None]:
) -> Generator[Optional[client_session.ClientSession], None]:
"""If provided session is None, lend a temporary session."""
if session is not None:
if not isinstance(session, client_session.ClientSession):
@ -2300,8 +2300,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionchanged:: 3.6
Added ``session`` parameter.
"""
return cast(
dict,
return cast( # type: ignore[redundant-cast]
dict[str, Any],
self.admin.command(
"buildinfo", read_preference=ReadPreference.PRIMARY, session=session
),
@ -2428,13 +2428,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
@_csot.apply
def bulk_write(
self,
models: Sequence[_WriteOp[_DocumentType]],
models: Sequence[_WriteOp],
session: Optional[ClientSession] = None,
ordered: bool = True,
verbose_results: bool = False,
bypass_document_validation: Optional[bool] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
let: Optional[Mapping[str, Any]] = None,
write_concern: Optional[WriteConcern] = None,
) -> ClientBulkWriteResult:
"""Send a batch of write operations, potentially across multiple namespaces, to the server.
@ -2620,7 +2620,12 @@ class _MongoClientErrorHandler:
"handled",
)
def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
def __init__(
self,
client: MongoClient, # type: ignore[type-arg]
server: Server,
session: Optional[ClientSession],
):
if not isinstance(client, MongoClient):
# This is for compatibility with mocked and subclassed types, such as in Motor.
if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__):
@ -2692,7 +2697,7 @@ class _ClientConnectionRetryable(Generic[T]):
def __init__(
self,
mongo_client: MongoClient,
mongo_client: MongoClient, # type: ignore[type-arg]
func: _WriteCall[T] | _ReadCall[T],
bulk: Optional[Union[_Bulk, _ClientBulk]],
operation: str,

View File

@ -349,7 +349,7 @@ class Monitor(MonitorBase):
)
return sd
def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]:
def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: # type: ignore[type-arg]
"""Return (Hello, round_trip_time).
Can raise ConnectionFailure or OperationFailure.

View File

@ -66,7 +66,7 @@ def command(
read_preference: Optional[_ServerMode],
codec_options: CodecOptions[_DocumentType],
session: Optional[ClientSession],
client: Optional[MongoClient],
client: Optional[MongoClient[Any]],
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
address: Optional[_Address] = None,

View File

@ -201,7 +201,7 @@ class Connection:
self.conn.get_conn.settimeout(timeout)
def apply_timeout(
self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]]
self, client: MongoClient[Any], cmd: Optional[MutableMapping[str, Any]]
) -> Optional[float]:
# CSOT: use remaining timeout when set.
timeout = _csot.remaining()
@ -255,7 +255,7 @@ class Connection:
else:
return {HelloCompat.LEGACY_CMD: 1, "helloOk": True}
def hello(self) -> Hello:
def hello(self) -> Hello[dict[str, Any]]:
return self._hello(None, None)
def _hello(
@ -357,7 +357,7 @@ class Connection:
dbname: str,
spec: MutableMapping[str, Any],
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
codec_options: CodecOptions[Mapping[str, Any]] = DEFAULT_CODEC_OPTIONS, # type: ignore[assignment]
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_concern: Optional[ReadConcern] = None,
@ -365,7 +365,7 @@ class Connection:
parse_write_concern_error: bool = False,
collation: Optional[_CollationIn] = None,
session: Optional[ClientSession] = None,
client: Optional[MongoClient] = None,
client: Optional[MongoClient[Any]] = None,
retryable_write: bool = False,
publish_events: bool = True,
user_fields: Optional[Mapping[str, Any]] = None,
@ -417,7 +417,7 @@ class Connection:
spec,
self.is_mongos,
read_preference,
codec_options,
codec_options, # type: ignore[arg-type]
session,
client,
check,
@ -489,7 +489,7 @@ class Connection:
self.send_message(msg, max_doc_size)
def write_command(
self, request_id: int, msg: bytes, codec_options: CodecOptions
self, request_id: int, msg: bytes, codec_options: CodecOptions[Mapping[str, Any]]
) -> dict[str, Any]:
"""Send "insert" etc. command, returning response as a dict.
@ -541,7 +541,7 @@ class Connection:
)
def validate_session(
self, client: Optional[MongoClient], session: Optional[ClientSession]
self, client: Optional[MongoClient[Any]], session: Optional[ClientSession]
) -> None:
"""Validate this session before use with client.
@ -596,7 +596,7 @@ class Connection:
self,
command: MutableMapping[str, Any],
session: Optional[ClientSession],
client: Optional[MongoClient],
client: Optional[MongoClient[Any]],
) -> None:
"""Add $clusterTime."""
if client:
@ -730,7 +730,7 @@ class Pool:
# LIFO pool. Sockets are ordered on idle time. Sockets claimed
# and returned to pool from the left side. Stale sockets removed
# from the right side.
self.conns: collections.deque = collections.deque()
self.conns: collections.deque[Connection] = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
self.lock = _create_lock()
self._max_connecting_cond = _create_condition(self.lock)
@ -837,8 +837,8 @@ class Pool:
if service_id is None:
sockets, self.conns = self.conns, collections.deque()
else:
discard: collections.deque = collections.deque()
keep: collections.deque = collections.deque()
discard: collections.deque = collections.deque() # type: ignore[type-arg]
keep: collections.deque = collections.deque() # type: ignore[type-arg]
for conn in self.conns:
if conn.service_id == service_id:
discard.append(conn)
@ -864,7 +864,7 @@ class Pool:
if close:
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
else:
@ -901,7 +901,7 @@ class Pool:
)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
else:
@ -915,7 +915,7 @@ class Pool:
self.is_writable = is_writable
with self.lock:
for _socket in self.conns:
_socket.update_is_writable(self.is_writable)
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]
def reset(
self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False
@ -952,7 +952,7 @@ class Pool:
close_conns.append(self.conns.pop())
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value]
return_exceptions=True,
)
else:
@ -1473,4 +1473,4 @@ class Pool:
# not safe to acquire a lock in __del__.
if _IS_SYNC:
for conn in self.conns:
conn.close_conn(None)
conn.close_conn(None) # type: ignore[unused-coroutine]

View File

@ -66,7 +66,7 @@ class Server:
monitor: Monitor,
topology_id: Optional[ObjectId] = None,
listeners: Optional[_EventListeners] = None,
events: Optional[ReferenceType[Queue]] = None,
events: Optional[ReferenceType[Queue[Any]]] = None,
) -> None:
"""Represent one MongoDB server."""
self._description = server_description
@ -142,7 +142,7 @@ class Server:
read_preference: _ServerMode,
listeners: Optional[_EventListeners],
unpack_res: Callable[..., list[_DocumentOut]],
client: MongoClient,
client: MongoClient[Any],
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.

View File

@ -84,7 +84,7 @@ _IS_SYNC = True
_pymongo_dir = str(Path(__file__).parent)
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool: # type: ignore[type-arg]
q = queue_ref()
if not q:
return False # Cancel PeriodicExecutor.
@ -186,7 +186,7 @@ class Topology:
if self._publish_server or self._publish_tp:
assert self._events is not None
weak: weakref.ReferenceType[queue.Queue]
weak: weakref.ReferenceType[queue.Queue[Any]]
def target() -> bool:
return process_events_queue(weak)

View File

@ -569,8 +569,8 @@ def _update_rs_from_primary(
return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
if server_description.max_wire_version is None or server_description.max_wire_version < 17:
new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
max_election_tuple: tuple = (max_set_version, max_election_id)
new_election_tuple: tuple = (server_description.set_version, server_description.election_id) # type: ignore[type-arg]
max_election_tuple: tuple = (max_set_version, max_election_id) # type: ignore[type-arg]
if None not in new_election_tuple:
if None not in max_election_tuple and new_election_tuple < max_election_tuple:
# Stale primary, set to type Unknown.

View File

@ -51,7 +51,7 @@ ClusterTime = Mapping[str, Any]
_T = TypeVar("_T")
# Type hinting types for compatibility between async and sync classes
_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"]
_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] # type: ignore[type-arg]
_AgnosticConnection = Union["AsyncConnection", "Connection"]
_AgnosticClientSession = Union["AsyncClientSession", "ClientSession"]
_AgnosticBulk = Union["_AsyncBulk", "_Bulk"]

View File

@ -149,11 +149,12 @@ markers = [
strict = true
show_error_codes = true
pretty = true
disable_error_code = ["type-arg", "no-any-return"]
disable_error_code = ["no-any-return"]
disallow_any_generics = true
[[tool.mypy.overrides]]
module = ["test.*"]
disable_error_code = ["no-untyped-def", "no-untyped-call"]
disable_error_code = ["type-arg", "no-untyped-def", "no-untyped-call"]
[[tool.mypy.overrides]]
module = ["service_identity.*"]