From e0b8b36f4162eccd0b0fbe00563a749e90ee37d6 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 4 Aug 2023 17:28:30 -0700 Subject: [PATCH] PYTHON-3813 add types to pool.py (#1318) --- pymongo/auth.py | 5 +- pymongo/auth_oidc.py | 18 +- pymongo/client_options.py | 5 +- pymongo/collection.py | 9 +- pymongo/command_cursor.py | 3 +- pymongo/compression_support.py | 4 +- pymongo/cursor.py | 2 +- pymongo/database.py | 6 +- pymongo/message.py | 14 +- pymongo/mongo_client.py | 7 +- pymongo/monitor.py | 12 +- pymongo/network.py | 10 +- pymongo/pool.py | 416 ++++++++++++++++++++------------- pymongo/server.py | 8 +- test/test_monitoring.py | 4 +- 15 files changed, 316 insertions(+), 207 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 063df41e4..f8e35f9c1 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -239,12 +239,15 @@ def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanis ctx = conn.auth_ctx if ctx and ctx.speculate_succeeded(): + assert isinstance(ctx, _ScramContext) + assert ctx.scram_data is not None nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate else: nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) res = conn.command(source, cmd) + assert res is not None server_first = res["payload"] parsed = _parse_scram_response(server_first) iterations = int(parsed[b"i"]) @@ -575,7 +578,7 @@ class _ScramContext(_AuthContext): class _X509Context(_AuthContext): - def speculate_command(self) -> Optional[MutableMapping[str, Any]]: + def speculate_command(self) -> MutableMapping[str, Any]: cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")]) if self.credentials.username is not None: cmd["user"] = self.credentials.username diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index e5d0afeb8..0ca74fc49 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -19,7 +19,17 @@ import os import threading from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + MutableMapping, + Optional, + Tuple, +) import bson from bson.binary import Binary @@ -242,7 +252,9 @@ class _OIDCAuthenticator: self.idp_resp = None self.token_exp_utc = None - def run_command(self, conn: Connection, cmd: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: + def run_command( + self, conn: Connection, cmd: MutableMapping[str, Any] + ) -> Optional[Mapping[str, Any]]: try: return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] except OperationFailure as exc: @@ -276,6 +288,7 @@ class _OIDCAuthenticator: assert cmd is not None resp = self.run_command(conn, cmd) + assert resp is not None if resp["done"]: conn.oidc_token_gen_id = self.token_gen_id return None @@ -297,6 +310,7 @@ class _OIDCAuthenticator: ] ) resp = self.run_command(conn, cmd) + assert resp is not None if not resp["done"]: self.clear() raise OperationFailure("SASL conversation failed to complete.") diff --git a/pymongo/client_options.py b/pymongo/client_options.py index b8a228fa4..a83216e9d 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -15,7 +15,7 @@ """Tools to parse mongo client options.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, cast +from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Tuple, cast from bson.codec_options import _parse_codec_options from pymongo import common @@ -309,11 +309,12 @@ class ClientOptions: return self.__load_balanced @property - def event_listeners(self) -> _EventListeners: + def event_listeners(self) -> List[_EventListeners]: """The event listeners registered for this client. See :mod:`~pymongo.monitoring` for details. .. versionadded:: 4.0 """ + assert self.__pool_options._event_listeners is not None return self.__pool_options._event_listeners.event_listeners() diff --git a/pymongo/collection.py b/pymongo/collection.py index fff3fcf18..fbbe7fb59 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -20,7 +20,6 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Container, ContextManager, Generic, Iterable, @@ -268,11 +267,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]): def _command( self, conn: Connection, - command: Mapping[str, Any], + command: MutableMapping[str, Any], read_preference: Optional[_ServerMode] = None, codec_options: Optional[CodecOptions] = None, check: bool = True, - allowable_errors: Optional[Container[Any]] = None, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_concern: Optional[ReadConcern] = None, write_concern: Optional[WriteConcern] = None, collation: Optional[_CollationIn] = None, @@ -1753,7 +1752,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): session: Optional[ClientSession], conn: Connection, read_preference: Optional[_ServerMode], - cmd: Mapping[str, Any], + cmd: SON[str, Any], collation: Optional[Collation], ) -> int: """Internal count command helper.""" @@ -1777,7 +1776,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]): self, conn: Connection, read_preference: Optional[_ServerMode], - cmd: Mapping[str, Any], + cmd: SON[str, Any], collation: Optional[_CollationIn], session: Optional[ClientSession], ) -> Optional[Mapping[str, Any]]: diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index ddd81acec..777b88b08 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -25,6 +25,7 @@ from typing import ( Mapping, NoReturn, Optional, + Sequence, Union, ) @@ -220,7 +221,7 @@ class CommandCursor(Generic[_DocumentType]): codec_options: CodecOptions[Mapping[str, Any]], user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[_DocumentOut]: + ) -> Sequence[_DocumentOut]: return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) def _refresh(self) -> int: diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index a9083b546..27fc3cdf2 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -14,7 +14,7 @@ from __future__ import annotations import warnings -from typing import Any, Iterable, List, Union +from typing import Any, Iterable, List, Optional, Union try: import snappy @@ -96,7 +96,7 @@ class CompressionSettings: self.zlib_compression_level = zlib_compression_level def get_compression_context( - self, compressors: List[str] + self, compressors: Optional[List[str]] ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]: if compressors: chosen = compressors[0] diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 4dab6858f..2bf420ac1 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -1130,7 +1130,7 @@ class Cursor(Generic[_DocumentType]): codec_options: CodecOptions, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[_DocumentOut]: + ) -> Sequence[_DocumentOut]: return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) def _read_preference(self) -> _ServerMode: diff --git a/pymongo/database.py b/pymongo/database.py index b30d8fb3b..133061424 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -694,7 +694,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): value: int = 1, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, - read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY, + read_preference: _ServerMode = ReadPreference.PRIMARY, codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS, write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, @@ -711,7 +711,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): value: int = 1, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, - read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY, + read_preference: _ServerMode = ReadPreference.PRIMARY, codec_options: CodecOptions[_CodecDocumentType] = ..., write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, @@ -727,7 +727,7 @@ class Database(common.BaseObject, Generic[_DocumentType]): value: int = 1, check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, - read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY, + read_preference: _ServerMode = ReadPreference.PRIMARY, codec_options: Union[ CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType] ] = DEFAULT_CODEC_OPTIONS, diff --git a/pymongo/message.py b/pymongo/message.py index 7565384e3..5a4b1753f 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -1018,7 +1018,7 @@ class _BulkWriteContext: cmd = self._start(cmd, request_id, docs) start = datetime.datetime.now() try: - result = self.conn.unack_write(msg, max_doc_size) + result = self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value] if self.publish: duration = (datetime.datetime.now() - start) + duration if result is not None: @@ -1050,7 +1050,7 @@ class _BulkWriteContext: request_id: int, msg: bytes, docs: List[Mapping[str, Any]], - ) -> Mapping[str, Any]: + ) -> Dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" if self.publish: assert self.start_time is not None @@ -1127,7 +1127,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext): def __batch_command( self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]] - ) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]: + ) -> Tuple[Dict[str, Any], List[Mapping[str, Any]]]: namespace = self.db_name + ".$cmd" msg, to_send = _encode_batched_write_command( namespace, self.op_type, cmd, docs, self.codec, self @@ -1517,7 +1517,7 @@ class _OpReply: codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[_DocumentOut]: + ) -> List[Dict[str, Any]]: """Unpack a response from the database and decode the BSON document(s). Check the response for errors and unpack, returning a dictionary @@ -1541,7 +1541,7 @@ class _OpReply: return bson.decode_all(self.documents, codec_options) return bson._decode_all_selective(self.documents, codec_options, user_fields) - def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]: + def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]: """Unpack a command response.""" docs = self.unpack_response(codec_options=codec_options) assert self.number_returned == 1 @@ -1604,7 +1604,7 @@ class _OpMsg: codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, - ) -> List[_DocumentOut]: + ) -> List[Dict[str, Any]]: """Unpack a OP_MSG command response. :Parameters: @@ -1619,7 +1619,7 @@ class _OpMsg: assert not legacy_response return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]: + def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]: """Unpack a command response.""" return self.unpack_response(codec_options=codec_options)[0] diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 63fdab04a..9655fde28 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -116,6 +116,7 @@ if TYPE_CHECKING: import sys from types import TracebackType + from bson.objectid import ObjectId from pymongo.bulk import _Bulk from pymongo.client_session import ClientSession, _ServerSession from pymongo.cursor import _ConnectionManager @@ -1898,7 +1899,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): else: yield None - def _send_cluster_time(self, command: MutableMapping[str, Any], session: ClientSession) -> None: + def _send_cluster_time( + self, command: MutableMapping[str, Any], session: Optional[ClientSession] + ) -> None: topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None if topology_time and session_time: @@ -2255,7 +2258,7 @@ class _MongoClientErrorHandler: # of the pool at the time the connection attempt was started." self.sock_generation = server.pool.gen.get_overall() self.completed_handshake = False - self.service_id = None + self.service_id: Optional[ObjectId] = None self.handled = False def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None: diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 3fb20d0b1..b2ff3404f 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -32,7 +32,7 @@ from pymongo.server_description import ServerDescription from pymongo.srv_resolver import _SrvResolver if TYPE_CHECKING: - from pymongo.pool import Connection, Pool + from pymongo.pool import Connection, Pool, _CancellationContext from pymongo.settings import TopologySettings from pymongo.topology import Topology @@ -131,9 +131,8 @@ class Monitor(MonitorBase): self._pool = pool self._settings = topology_settings self._listeners = self._settings._pool_options._event_listeners - pub = self._listeners is not None - self._publish = pub and self._listeners.enabled_for_server_heartbeat - self._cancel_context = None + self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat + self._cancel_context: Optional[_CancellationContext] = None self._rtt_monitor = _RttMonitor( topology, topology_settings, @@ -238,7 +237,8 @@ class Monitor(MonitorBase): address = sd.address duration = time.monotonic() - start if self._publish: - awaited = sd.is_server_type_known and sd.topology_version + awaited = bool(sd.is_server_type_known and sd.topology_version) + assert self._listeners is not None self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited) self._reset_connection() if isinstance(error, _OperationCancelled): @@ -254,6 +254,7 @@ class Monitor(MonitorBase): """ address = self._server_description.address if self._publish: + assert self._listeners is not None self._listeners.publish_server_heartbeat_started(address) if self._cancel_context and self._cancel_context.cancelled: @@ -267,6 +268,7 @@ class Monitor(MonitorBase): avg_rtt, min_rtt = self._rtt_monitor.get() sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt) if self._publish: + assert self._listeners is not None self._listeners.publish_server_heartbeat_succeeded( address, round_trip_time, response, response.awaitable ) diff --git a/pymongo/network.py b/pymongo/network.py index 5aaceda52..df540f1a3 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -47,14 +47,13 @@ from pymongo.socket_checker import _errno_from_exception if TYPE_CHECKING: from bson import CodecOptions from pymongo.client_session import ClientSession - from pymongo.collation import Collation from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.mongo_client import MongoClient from pymongo.monitoring import _EventListeners from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode - from pymongo.typings import _Address, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import WriteConcern _UNPACK_HEADER = struct.Struct(" Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise socket.error.""" if _csot.get_timeout(): diff --git a/pymongo/pool.py b/pymongo/pool.py index 9baaecc71..c89aa10e3 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -12,6 +12,8 @@ # implied. See the License for the specific language governing # permissions and limitations under the License. +from __future__ import annotations + import collections import contextlib import copy @@ -23,7 +25,21 @@ import sys import threading import time import weakref -from typing import Any, Dict, NoReturn, Optional +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterator, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Sequence, + Set, + Tuple, + Union, +) import bson from bson import DEFAULT_CODEC_OPTIONS @@ -59,7 +75,11 @@ from pymongo.errors import ( from pymongo.hello import Hello, HelloCompat from pymongo.helpers import _handle_reauth from pymongo.lock import _create_lock -from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason +from pymongo.monitoring import ( + ConnectionCheckOutFailedReason, + ConnectionClosedReason, + _EventListeners, +) from pymongo.network import command, receive_message from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command @@ -67,10 +87,31 @@ from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker from pymongo.ssl_support import HAS_SNI, SSLError +if TYPE_CHECKING: + from bson import CodecOptions + from bson.objectid import ObjectId + from pymongo.auth import MongoCredential, _AuthContext + from pymongo.client_session import ClientSession + from pymongo.compression_support import ( + CompressionSettings, + SnappyContext, + ZlibContext, + ZstdContext, + ) + from pymongo.driver_info import DriverInfo + from pymongo.message import _OpMsg, _OpReply + from pymongo.mongo_client import MongoClient, _MongoClientErrorHandler + from pymongo.pyopenssl_context import SSLContext, _sslConn + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.server_api import ServerApi + from pymongo.typings import _Address, _CollationIn + from pymongo.write_concern import WriteConcern + try: from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl - def _set_non_inheritable_non_atomic(fd): + def _set_non_inheritable_non_atomic(fd: int) -> None: """Set the close-on-exec flag on the given file descriptor.""" flags = fcntl(fd, F_GETFD) fcntl(fd, F_SETFD, flags | FD_CLOEXEC) @@ -79,7 +120,7 @@ except ImportError: # Windows, various platforms we don't claim to support # (Jython, IronPython, ...), systems that don't provide # everything we need from fcntl, etc. - def _set_non_inheritable_non_atomic(fd): + def _set_non_inheritable_non_atomic(fd: int) -> None: """Dummy function for platforms that don't provide fcntl.""" @@ -123,7 +164,7 @@ if sys.platform == "win32": else: - def _set_tcp_option(sock, tcp_option, max_value): + def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None: if hasattr(socket, tcp_option): sockopt = getattr(socket, tcp_option) try: @@ -136,7 +177,7 @@ else: except OSError: pass - def _set_keepalive_times(sock): + def _set_keepalive_times(sock: socket.socket) -> None: _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) @@ -302,7 +343,7 @@ _MAX_METADATA_SIZE = 512 # See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations -def _truncate_metadata(metadata): +def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: """Perform metadata truncation.""" if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: return @@ -365,7 +406,7 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -def _cond_wait(condition, deadline): +def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool: timeout = deadline - time.monotonic() if deadline else None return condition.wait(timeout) @@ -406,23 +447,23 @@ class PoolOptions: def __init__( self, - max_pool_size=MAX_POOL_SIZE, - min_pool_size=MIN_POOL_SIZE, - max_idle_time_seconds=MAX_IDLE_TIME_SEC, - connect_timeout=None, - socket_timeout=None, - wait_queue_timeout=WAIT_QUEUE_TIMEOUT, - ssl_context=None, - tls_allow_invalid_hostnames=False, - event_listeners=None, - appname=None, - driver=None, - compression_settings=None, - max_connecting=MAX_CONNECTING, - pause_enabled=True, - server_api=None, - load_balanced=None, - credentials=None, + max_pool_size: int = MAX_POOL_SIZE, + min_pool_size: int = MIN_POOL_SIZE, + max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, + connect_timeout: Optional[float] = None, + socket_timeout: Optional[float] = None, + wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, + ssl_context: Optional[SSLContext] = None, + tls_allow_invalid_hostnames: bool = False, + event_listeners: Optional[_EventListeners] = None, + appname: Optional[str] = None, + driver: Optional[DriverInfo] = None, + compression_settings: Optional[CompressionSettings] = None, + max_connecting: int = MAX_CONNECTING, + pause_enabled: bool = True, + server_api: Optional[ServerApi] = None, + load_balanced: Optional[bool] = None, + credentials: Optional[MongoCredential] = None, ): self.__max_pool_size = max_pool_size self.__min_pool_size = min_pool_size @@ -474,12 +515,12 @@ class PoolOptions: _truncate_metadata(self.__metadata) @property - def _credentials(self): + def _credentials(self) -> Optional[MongoCredential]: """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" return self.__credentials @property - def non_default_options(self): + def non_default_options(self) -> Dict[str, Any]: """The non-default options this pool was created with. Added for CMAP's :class:`PoolCreatedEvent`. @@ -490,15 +531,17 @@ class PoolOptions: if self.__min_pool_size != MIN_POOL_SIZE: opts["minPoolSize"] = self.__min_pool_size if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: + assert self.__max_idle_time_seconds is not None opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: + assert self.__wait_queue_timeout is not None opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 if self.__max_connecting != MAX_CONNECTING: opts["maxConnecting"] = self.__max_connecting return opts @property - def max_pool_size(self): + def max_pool_size(self) -> float: """The maximum allowable number of concurrent connections to each connected server. Requests to a server will block if there are `maxPoolSize` outstanding connections to the requested server. @@ -513,25 +556,25 @@ class PoolOptions: return self.__max_pool_size @property - def min_pool_size(self): + def min_pool_size(self) -> int: """The minimum required number of concurrent connections that the pool will maintain to each connected server. Default is 0. """ return self.__min_pool_size @property - def max_connecting(self): + def max_connecting(self) -> int: """The maximum number of concurrent connection creation attempts per pool. Defaults to 2. """ return self.__max_connecting @property - def pause_enabled(self): + def pause_enabled(self) -> bool: return self.__pause_enabled @property - def max_idle_time_seconds(self): + def max_idle_time_seconds(self) -> Optional[int]: """The maximum number of seconds that a connection can remain idle in the pool before being removed and replaced. Defaults to `None` (no limit). @@ -539,77 +582,77 @@ class PoolOptions: return self.__max_idle_time_seconds @property - def connect_timeout(self): + def connect_timeout(self) -> Optional[float]: """How long a connection can take to be opened before timing out.""" return self.__connect_timeout @property - def socket_timeout(self): + def socket_timeout(self) -> Optional[float]: """How long a send or receive on a socket can take before timing out.""" return self.__socket_timeout @property - def wait_queue_timeout(self): + def wait_queue_timeout(self) -> Optional[int]: """How long a thread will wait for a socket from the pool if the pool has no free sockets. """ return self.__wait_queue_timeout @property - def _ssl_context(self): + def _ssl_context(self) -> Optional[SSLContext]: """An SSLContext instance or None.""" return self.__ssl_context @property - def tls_allow_invalid_hostnames(self): + def tls_allow_invalid_hostnames(self) -> bool: """If True skip ssl.match_hostname.""" return self.__tls_allow_invalid_hostnames @property - def _event_listeners(self): + def _event_listeners(self) -> Optional[_EventListeners]: """An instance of pymongo.monitoring._EventListeners.""" return self.__event_listeners @property - def appname(self): + def appname(self) -> Optional[str]: """The application name, for sending with hello in server handshake.""" return self.__appname @property - def driver(self): + def driver(self) -> Optional[DriverInfo]: """Driver name and version, for sending with hello in handshake.""" return self.__driver @property - def _compression_settings(self): + def _compression_settings(self) -> Optional[CompressionSettings]: return self.__compression_settings @property - def metadata(self): + def metadata(self) -> SON[str, Any]: """A dict of metadata about the application, driver, os, and platform.""" return self.__metadata.copy() @property - def server_api(self): + def server_api(self) -> Optional[ServerApi]: """A pymongo.server_api.ServerApi or None.""" return self.__server_api @property - def load_balanced(self): + def load_balanced(self) -> Optional[bool]: """True if this Pool is configured in load balanced mode.""" return self.__load_balanced class _CancellationContext: - def __init__(self): + def __init__(self) -> None: self._cancelled = False - def cancel(self): + def cancel(self) -> None: """Cancel this context.""" self._cancelled = True @property - def cancelled(self): + def cancelled(self) -> bool: """Was cancel called?""" return self._cancelled @@ -624,47 +667,48 @@ class Connection: - `id`: the id of this socket in it's pool """ - def __init__(self, conn, pool, address, id): + def __init__( + self, conn: Union[socket.socket, _sslConn], pool: Pool, address: Tuple[str, int], id: int + ): self.pool_ref = weakref.ref(pool) self.conn = conn self.address = address self.id = id - self.authed = set() self.closed = False self.last_checkin_time = time.monotonic() self.performed_handshake = False - self.is_writable = False + self.is_writable: bool = False self.max_wire_version = MAX_WIRE_VERSION self.max_bson_size = MAX_BSON_SIZE self.max_message_size = MAX_MESSAGE_SIZE self.max_write_batch_size = MAX_WRITE_BATCH_SIZE self.supports_sessions = False - self.hello_ok = None + self.hello_ok: bool = False self.is_mongos = False self.op_msg_enabled = False self.listeners = pool.opts._event_listeners self.enabled_for_cmap = pool.enabled_for_cmap self.compression_settings = pool.opts._compression_settings - self.compression_context = None - self.socket_checker = SocketChecker() - self.oidc_token_gen_id = None + self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None + self.socket_checker: SocketChecker = SocketChecker() + self.oidc_token_gen_id: Optional[int] = None # Support for mechanism negotiation on the initial handshake. - self.negotiated_mechs = None - self.auth_ctx = None + self.negotiated_mechs: Optional[List[str]] = None + self.auth_ctx: Optional[_AuthContext] = None # The pool's generation changes with each reset() so we can close # sockets created before the last reset. self.pool_gen = pool.gen self.generation = self.pool_gen.get_overall() self.ready = False - self.cancel_context = None + self.cancel_context: Optional[_CancellationContext] = None if not pool.handshake: # This is a Monitor connection. self.cancel_context = _CancellationContext() self.opts = pool.opts - self.more_to_come = False + self.more_to_come: bool = False # For load balancer support. - self.service_id = None + self.service_id: Optional[ObjectId] = None # When executing a transaction in load balancing mode, this flag is # set to true to indicate that the session now owns the connection. self.pinned_txn = False @@ -673,14 +717,16 @@ class Connection: self.last_timeout = self.opts.socket_timeout self.connect_rtt = 0.0 - def set_conn_timeout(self, timeout): + def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" if timeout == self.last_timeout: return self.last_timeout = timeout self.conn.settimeout(timeout) - def apply_timeout(self, client, cmd): + def apply_timeout( + self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]] + ) -> Optional[float]: # CSOT: use remaining timeout when set. timeout = _csot.remaining() if timeout is None: @@ -704,22 +750,22 @@ class Connection: self.set_conn_timeout(timeout) return timeout - def pin_txn(self): + def pin_txn(self) -> None: self.pinned_txn = True assert not self.pinned_cursor - def pin_cursor(self): + def pin_cursor(self) -> None: self.pinned_cursor = True assert not self.pinned_txn - def unpin(self): + def unpin(self) -> None: pool = self.pool_ref() if pool: pool.checkin(self) else: self.close_conn(ConnectionClosedReason.STALE) - def hello_cmd(self): + def hello_cmd(self) -> SON[str, Any]: # Handshake spec requires us to use OP_MSG+hello command for the # initial handshake in load balanced or stable API mode. if self.opts.server_api or self.hello_ok or self.opts.load_balanced: @@ -728,10 +774,15 @@ class Connection: else: return SON([(HelloCompat.LEGACY_CMD, 1), ("helloOk", True)]) - def hello(self): + def hello(self) -> Hello[Dict[str, Any]]: return self._hello(None, None, None) - def _hello(self, cluster_time, topology_version, heartbeat_frequency): + def _hello( + self, + cluster_time: Optional[Mapping[str, Any]], + topology_version: Optional[Any], + heartbeat_frequency: Optional[int], + ) -> Hello[Dict[str, Any]]: cmd = self.hello_cmd() performing_handshake = not self.performed_handshake awaitable = False @@ -744,6 +795,7 @@ class Connection: cmd["loadBalanced"] = True elif topology_version is not None: cmd["topologyVersion"] = topology_version + assert heartbeat_frequency is not None cmd["maxAwaitTimeMS"] = int(heartbeat_frequency * 1000) awaitable = True # If connect_timeout is None there is no timeout. @@ -808,7 +860,7 @@ class Connection: self.generation = self.pool_gen.get(self.service_id) return hello - def _next_reply(self): + def _next_reply(self) -> Dict[str, Any]: reply = self.receive_message(None) self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() @@ -819,23 +871,23 @@ class Connection: @_handle_reauth def command( self, - dbname, - spec, - read_preference=ReadPreference.PRIMARY, - codec_options=DEFAULT_CODEC_OPTIONS, - check=True, - allowable_errors=None, - read_concern=None, - write_concern=None, - parse_write_concern_error=False, - collation=None, - session=None, - client=None, - retryable_write=False, - publish_events=True, - user_fields=None, - exhaust_allowed=False, - ): + dbname: str, + spec: MutableMapping[str, Any], + read_preference: _ServerMode = ReadPreference.PRIMARY, + codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + read_concern: Optional[ReadConcern] = None, + write_concern: Optional[WriteConcern] = None, + parse_write_concern_error: bool = False, + collation: Optional[_CollationIn] = None, + session: Optional[ClientSession] = None, + client: Optional[MongoClient] = None, + retryable_write: bool = False, + publish_events: bool = True, + user_fields: Optional[Mapping[str, Any]] = None, + exhaust_allowed: bool = False, + ) -> Dict[str, Any]: """Execute a command or raise an error. :Parameters: @@ -873,7 +925,7 @@ class Connection: session._apply_to(spec, retryable_write, read_preference, self) self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None - unacknowledged = write_concern and not write_concern.acknowledged + unacknowledged = bool(write_concern and not write_concern.acknowledged) if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: @@ -907,7 +959,7 @@ class Connection: except BaseException as error: self._raise_connection_failure(error) - def send_message(self, message, max_doc_size): + def send_message(self, message: bytes, max_doc_size: int) -> None: """Send a raw BSON message or raise ConnectionFailure. If a network exception is raised, the socket is closed. @@ -923,7 +975,7 @@ class Connection: except BaseException as error: self._raise_connection_failure(error) - def receive_message(self, request_id): + def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise ConnectionFailure. If any exception is raised, the socket is closed. @@ -933,7 +985,7 @@ class Connection: except BaseException as error: self._raise_connection_failure(error) - def _raise_if_not_writable(self, unacknowledged): + def _raise_if_not_writable(self, unacknowledged: bool) -> None: """Raise NotPrimaryError on unacknowledged write if this socket is not writable. """ @@ -941,7 +993,7 @@ class Connection: # Write won't succeed, bail as if we'd received a not primary error. raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) - def unack_write(self, msg, max_doc_size): + def unack_write(self, msg: bytes, max_doc_size: int) -> None: """Send unack OP_MSG. Can raise ConnectionFailure or InvalidDocument. @@ -953,7 +1005,9 @@ class Connection: self._raise_if_not_writable(True) self.send_message(msg, max_doc_size) - def write_command(self, request_id, msg, codec_options): + def write_command( + self, request_id: int, msg: bytes, codec_options: CodecOptions + ) -> Dict[str, Any]: """Send "insert" etc. command, returning response as a dict. Can raise ConnectionFailure or OperationFailure. @@ -970,7 +1024,7 @@ class Connection: helpers._check_command_response(result, self.max_wire_version) return result - def authenticate(self, reauthenticate=False): + def authenticate(self, reauthenticate: bool = False) -> None: """Authenticate to the server if needed. Can raise ConnectionFailure or OperationFailure. @@ -988,9 +1042,12 @@ class Connection: auth.authenticate(creds, self, reauthenticate=reauthenticate) self.ready = True if self.enabled_for_cmap: + assert self.listeners is not None self.listeners.publish_connection_ready(self.address, self.id) - def validate_session(self, client, session): + def validate_session( + self, client: Optional[MongoClient], session: Optional[ClientSession] + ) -> None: """Validate this session before use with client. Raises error if the client is not the one that created the session. @@ -999,15 +1056,16 @@ class Connection: if session._client is not client: raise InvalidOperation("Can only use session with the MongoClient that started it") - def close_conn(self, reason): + def close_conn(self, reason: Optional[str]) -> None: """Close this connection with a reason.""" if self.closed: return self._close_conn() if reason and self.enabled_for_cmap: + assert self.listeners is not None self.listeners.publish_connection_closed(self.address, self.id, reason) - def _close_conn(self): + def _close_conn(self) -> None: """Close this connection.""" if self.closed: return @@ -1021,31 +1079,36 @@ class Connection: except Exception: pass - def conn_closed(self): + def conn_closed(self) -> bool: """Return True if we know socket has been closed, False otherwise.""" return self.socket_checker.socket_closed(self.conn) - def send_cluster_time(self, command, session, client): + def send_cluster_time( + self, + command: MutableMapping[str, Any], + session: Optional[ClientSession], + client: Optional[MongoClient], + ) -> None: """Add $clusterTime.""" if client: client._send_cluster_time(command, session) - def add_server_api(self, command): + def add_server_api(self, command: MutableMapping[str, Any]) -> None: """Add server_api parameters.""" if self.opts.server_api: _add_to_command(command, self.opts.server_api) - def update_last_checkin_time(self): + def update_last_checkin_time(self) -> None: self.last_checkin_time = time.monotonic() - def update_is_writable(self, is_writable): + def update_is_writable(self, is_writable: bool) -> None: self.is_writable = is_writable - def idle_time_seconds(self): + def idle_time_seconds(self) -> float: """Seconds since this socket was last checked into its pool.""" return time.monotonic() - self.last_checkin_time - def _raise_connection_failure(self, error): + def _raise_connection_failure(self, error: BaseException) -> NoReturn: # Catch *all* exceptions from socket methods and close the socket. In # regular Python, socket operations only raise socket.error, even if # the underlying cause was a Ctrl-C: a signal raised during socket.recv @@ -1072,16 +1135,16 @@ class Connection: else: raise - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.conn == other.conn - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash(self.conn) - def __repr__(self): + def __repr__(self) -> str: return "Connection({}){} at {}".format( repr(self.conn), self.closed and " CLOSED" or "", @@ -1089,7 +1152,7 @@ class Connection: ) -def _create_connection(address, options): +def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: """Given (host, port) and PoolOptions, connect and return a socket object. Can raise socket.error. @@ -1160,7 +1223,7 @@ def _create_connection(address, options): raise OSError("getaddrinfo failed") -def _configured_socket(address, options): +def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: """Given (host, port) and PoolOptions, return a configured socket. Can raise socket.error, ConnectionFailure, or _CertificateError. @@ -1170,39 +1233,42 @@ def _configured_socket(address, options): sock = _create_connection(address, options) ssl_context = options._ssl_context - if ssl_context is not None: - host = address[0] - try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. - if HAS_SNI: - sock = ssl_context.wrap_socket(sock, server_hostname=host) - else: - sock = ssl_context.wrap_socket(sock) - except _CertificateError: - sock.close() - # Raise _CertificateError directly like we do after match_hostname - # below. - raise - except (OSError, SSLError) as exc: # noqa: B014 - sock.close() - # We raise AutoReconnect for transient and permanent SSL handshake - # failures alike. Permanent handshake failures, like protocol - # mismatch, will be turned into ServerSelectionTimeoutErrors later. - _raise_connection_failure(address, exc, "SSL handshake failed: ") - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: - ssl.match_hostname(sock.getpeercert(), hostname=host) - except _CertificateError: - sock.close() - raise + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock - sock.settimeout(options.socket_timeout) - return sock + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if HAS_SNI: + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) + else: + ssl_sock = ssl_context.wrap_socket(sock) + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, SSLError) as exc: # noqa: B014 + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + _raise_connection_failure(address, exc, "SSL handshake failed: ") + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock class _PoolClosedError(PyMongoError): @@ -1212,23 +1278,23 @@ class _PoolClosedError(PyMongoError): class _PoolGeneration: - def __init__(self): + def __init__(self) -> None: # Maps service_id to generation. - self._generations = collections.defaultdict(int) + self._generations: Dict[ObjectId, int] = collections.defaultdict(int) # Overall pool generation. self._generation = 0 - def get(self, service_id): + def get(self, service_id: Optional[ObjectId]) -> int: """Get the generation for the given service_id.""" if service_id is None: return self._generation return self._generations[service_id] - def get_overall(self): + def get_overall(self) -> int: """Get the Pool's overall generation.""" return self._generation - def inc(self, service_id): + def inc(self, service_id: Optional[ObjectId]) -> None: """Increment the generation for the given service_id.""" self._generation += 1 if service_id is None: @@ -1237,7 +1303,7 @@ class _PoolGeneration: else: self._generations[service_id] += 1 - def stale(self, gen, service_id): + def stale(self, gen: int, service_id: Optional[ObjectId]) -> bool: """Return if the given generation for a given service_id is stale.""" return gen != self.get(service_id) @@ -1251,7 +1317,7 @@ class PoolState: # Do *not* explicitly inherit from object or Jython won't call __del__ # http://bugs.jython.org/issue1057 class Pool: - def __init__(self, address, options, handshake=True): + def __init__(self, address: _Address, options: PoolOptions, handshake: bool = True): """ :Parameters: - `address`: a (hostname, port) tuple @@ -1274,7 +1340,7 @@ class Pool: # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 # Track whether the sockets in this pool are writeable or not. - self.is_writable = None + self.is_writable: Optional[bool] = None # Keep track of resets, so we notice sockets created before the most # recent reset and close them. @@ -1306,6 +1372,7 @@ class Pool: self._max_connecting = self.opts.max_connecting self._pending = 0 if self.enabled_for_cmap: + assert self.opts._event_listeners is not None self.opts._event_listeners.publish_pool_created( self.address, self.opts.non_default_options ) @@ -1314,23 +1381,26 @@ class Pool: # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets = set() + self.__pinned_sockets: Set[Connection] = set() self.ncursors = 0 self.ntxns = 0 - def ready(self): + def ready(self) -> None: # Take the lock to avoid the race condition described in PYTHON-2699. with self.lock: if self.state != PoolState.READY: self.state = PoolState.READY if self.enabled_for_cmap: + assert self.opts._event_listeners is not None self.opts._event_listeners.publish_pool_ready(self.address) @property - def closed(self): + def closed(self) -> bool: return self.state == PoolState.CLOSED - def _reset(self, close, pause=True, service_id=None): + def _reset( + self, close: bool, pause: bool = True, service_id: Optional[ObjectId] = None + ) -> None: old_state = self.state with self.size_cond: if self.closed: @@ -1370,14 +1440,16 @@ class Pool: for conn in sockets: conn.close_conn(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: + assert listeners is not None listeners.publish_pool_closed(self.address) else: if old_state != PoolState.PAUSED and self.enabled_for_cmap: + assert listeners is not None listeners.publish_pool_cleared(self.address, service_id=service_id) for conn in sockets: conn.close_conn(ConnectionClosedReason.STALE) - def update_is_writable(self, is_writable): + def update_is_writable(self, is_writable: Optional[bool]) -> None: """Updates the is_writable attribute on all sockets currently in the Pool. """ @@ -1386,19 +1458,19 @@ class Pool: for _socket in self.conns: _socket.update_is_writable(self.is_writable) - def reset(self, service_id=None): + def reset(self, service_id: Optional[ObjectId] = None) -> None: self._reset(close=False, service_id=service_id) - def reset_without_pause(self): + def reset_without_pause(self) -> None: self._reset(close=False, pause=False) - def close(self): + def close(self) -> None: self._reset(close=True) - def stale_generation(self, gen, service_id): + def stale_generation(self, gen: int, service_id: Optional[ObjectId]) -> bool: return self.gen.stale(gen, service_id) - def remove_stale_sockets(self, reference_generation): + def remove_stale_sockets(self, reference_generation: int) -> None: """Removes stale sockets then adds new ones if pool is too small and has not been reset. The `reference_generation` argument specifies the `generation` at the point in time this operation was requested on the @@ -1454,7 +1526,7 @@ class Pool: self.requests -= 1 self.size_cond.notify() - def connect(self, handler=None): + def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: """Connect to Mongo and return a new Connection. Can raise ConnectionFailure. @@ -1468,12 +1540,14 @@ class Pool: listeners = self.opts._event_listeners if self.enabled_for_cmap: + assert listeners is not None listeners.publish_connection_created(self.address, conn_id) try: sock = _configured_socket(self.address, self.opts) except BaseException as error: if self.enabled_for_cmap: + assert listeners is not None listeners.publish_connection_closed( self.address, conn_id, ConnectionClosedReason.ERROR ) @@ -1483,7 +1557,7 @@ class Pool: raise - conn = Connection(sock, self, self.address, conn_id) + conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type] try: if self.handshake: conn.hello() @@ -1499,7 +1573,7 @@ class Pool: return conn @contextlib.contextmanager - def checkout(self, handler=None): + def checkout(self, handler: Optional[_MongoClientErrorHandler] = None) -> Iterator[Connection]: """Get a connection from the pool. Use with a "with" statement. Returns a :class:`Connection` object wrapping a connected @@ -1518,11 +1592,13 @@ class Pool: """ listeners = self.opts._event_listeners if self.enabled_for_cmap: + assert listeners is not None listeners.publish_connection_check_out_started(self.address) conn = self._get_conn(handler=handler) if self.enabled_for_cmap: + assert listeners is not None listeners.publish_connection_checked_out(self.address, conn.id) try: yield conn @@ -1551,15 +1627,16 @@ class Pool: elif conn.active: self.checkin(conn) - def _raise_if_not_ready(self, emit_event): + def _raise_if_not_ready(self, emit_event: bool) -> None: if self.state != PoolState.READY: if self.enabled_for_cmap and emit_event: + assert self.opts._event_listeners is not None self.opts._event_listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.CONN_ERROR ) _raise_connection_failure(self.address, AutoReconnect("connection pool paused")) - def _get_conn(self, handler=None): + def _get_conn(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: """Get or create a Connection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of @@ -1569,6 +1646,7 @@ class Pool: if self.closed: if self.enabled_for_cmap: + assert self.opts._event_listeners is not None self.opts._event_listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.POOL_CLOSED ) @@ -1649,6 +1727,7 @@ class Pool: self.size_cond.notify() if self.enabled_for_cmap and not emitted_event: + assert self.opts._event_listeners is not None self.opts._event_listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.CONN_ERROR ) @@ -1657,7 +1736,7 @@ class Pool: conn.active = True return conn - def checkin(self, conn): + def checkin(self, conn: Connection) -> None: """Return the connection to the pool, or if it's closed discard it. :Parameters: @@ -1671,6 +1750,7 @@ class Pool: self.__pinned_sockets.discard(conn) listeners = self.opts._event_listeners if self.enabled_for_cmap: + assert listeners is not None listeners.publish_connection_checked_in(self.address, conn.id) if self.pid != os.getpid(): self.reset_without_pause() @@ -1680,6 +1760,7 @@ class Pool: elif conn.closed: # CMAP requires the closed event be emitted after the check in. if self.enabled_for_cmap: + assert listeners is not None listeners.publish_connection_closed( self.address, conn.id, ConnectionClosedReason.ERROR ) @@ -1691,7 +1772,7 @@ class Pool: conn.close_conn(ConnectionClosedReason.STALE) else: conn.update_last_checkin_time() - conn.update_is_writable(self.is_writable) + conn.update_is_writable(bool(self.is_writable)) self.conns.appendleft(conn) # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() @@ -1706,7 +1787,7 @@ class Pool: self.operation_count -= 1 self.size_cond.notify() - def _perished(self, conn): + def _perished(self, conn: Connection) -> bool: """Return True and close the connection if it is "perished". This side-effecty function checks if this socket has been idle for @@ -1745,6 +1826,7 @@ class Pool: def _raise_wait_queue_timeout(self) -> NoReturn: listeners = self.opts._event_listeners if self.enabled_for_cmap: + assert listeners is not None listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.TIMEOUT ) @@ -1768,7 +1850,7 @@ class Pool: "maxPoolSize: {}, timeout: {}".format(self.opts.max_pool_size, timeout) ) - def __del__(self): + def __del__(self) -> None: # Avoid ResourceWarnings in Python 3 # Close all sockets without calling reset() or close() because it is # not safe to acquire a lock in __del__. diff --git a/pymongo/server.py b/pymongo/server.py index a23a87911..2fe2443ee 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -109,7 +109,7 @@ class Server: conn: Connection, operation: Union[_Query, _GetMore], read_preference: _ServerMode, - listeners: _EventListeners, + listeners: Optional[_EventListeners], unpack_res: Callable[..., List[_DocumentOut]], ) -> Response: """Run a _Query or _GetMore operation and return a Response object. @@ -126,6 +126,7 @@ class Server: - `unpack_res`: A callable that decodes the wire protocol response. """ duration = None + assert listeners is not None publish = listeners.enabled_for_commands if publish: start = datetime.now() @@ -140,6 +141,7 @@ class Server: if publish: cmd, dbn = operation.as_command(conn) + assert listeners is not None listeners.publish_command_start( cmd, dbn, request_id, conn.address, service_id=conn.service_id ) @@ -177,6 +179,7 @@ class Server: failure: _DocumentOut = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) + assert listeners is not None listeners.publish_command_failure( duration, failure, @@ -196,11 +199,12 @@ class Server: elif operation.name == "explain": res = docs[0] if docs else {} else: - res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} + res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] if operation.name == "find": res["cursor"]["firstBatch"] = docs else: res["cursor"]["nextBatch"] = docs + assert listeners is not None listeners.publish_command_success( duration, res, diff --git a/test/test_monitoring.py b/test/test_monitoring.py index c7c793b38..7aa3d67eb 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -1084,10 +1084,10 @@ class TestCommandMonitoring(IntegrationTest): self.listener.reset() cmd = SON([("getnonce", 1)]) - listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address) + listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address) # type: ignore[arg-type] delta = datetime.timedelta(milliseconds=100) listeners.publish_command_success( - delta, {"nonce": "e474f4561c5eb40b", "ok": 1.0}, "getnonce", 12345, self.client.address + delta, {"nonce": "e474f4561c5eb40b", "ok": 1.0}, "getnonce", 12345, self.client.address # type: ignore[arg-type] ) started = self.listener.started_events[0] succeeded = self.listener.succeeded_events[0]