PYTHON-3813 add types to pool.py (#1318)

This commit is contained in:
Iris 2023-08-04 17:28:30 -07:00 committed by GitHub
parent 54840752d2
commit e0b8b36f41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 316 additions and 207 deletions

View File

@ -239,12 +239,15 @@ def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanis
ctx = conn.auth_ctx
if ctx and ctx.speculate_succeeded():
assert isinstance(ctx, _ScramContext)
assert ctx.scram_data is not None
nonce, first_bare = ctx.scram_data
res = ctx.speculative_authenticate
else:
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
res = conn.command(source, cmd)
assert res is not None
server_first = res["payload"]
parsed = _parse_scram_response(server_first)
iterations = int(parsed[b"i"])
@ -575,7 +578,7 @@ class _ScramContext(_AuthContext):
class _X509Context(_AuthContext):
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
def speculate_command(self) -> MutableMapping[str, Any]:
cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")])
if self.credentials.username is not None:
cmd["user"] = self.credentials.username

View File

@ -19,7 +19,17 @@ import os
import threading
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
MutableMapping,
Optional,
Tuple,
)
import bson
from bson.binary import Binary
@ -242,7 +252,9 @@ class _OIDCAuthenticator:
self.idp_resp = None
self.token_exp_utc = None
def run_command(self, conn: Connection, cmd: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
def run_command(
self, conn: Connection, cmd: MutableMapping[str, Any]
) -> Optional[Mapping[str, Any]]:
try:
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure as exc:
@ -276,6 +288,7 @@ class _OIDCAuthenticator:
assert cmd is not None
resp = self.run_command(conn, cmd)
assert resp is not None
if resp["done"]:
conn.oidc_token_gen_id = self.token_gen_id
return None
@ -297,6 +310,7 @@ class _OIDCAuthenticator:
]
)
resp = self.run_command(conn, cmd)
assert resp is not None
if not resp["done"]:
self.clear()
raise OperationFailure("SASL conversation failed to complete.")

View File

@ -15,7 +15,7 @@
"""Tools to parse mongo client options."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, cast
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Tuple, cast
from bson.codec_options import _parse_codec_options
from pymongo import common
@ -309,11 +309,12 @@ class ClientOptions:
return self.__load_balanced
@property
def event_listeners(self) -> _EventListeners:
def event_listeners(self) -> List[_EventListeners]:
"""The event listeners registered for this client.
See :mod:`~pymongo.monitoring` for details.
.. versionadded:: 4.0
"""
assert self.__pool_options._event_listeners is not None
return self.__pool_options._event_listeners.event_listeners()

View File

@ -20,7 +20,6 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Container,
ContextManager,
Generic,
Iterable,
@ -268,11 +267,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _command(
self,
conn: Connection,
command: Mapping[str, Any],
command: MutableMapping[str, Any],
read_preference: Optional[_ServerMode] = None,
codec_options: Optional[CodecOptions] = None,
check: bool = True,
allowable_errors: Optional[Container[Any]] = None,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
collation: Optional[_CollationIn] = None,
@ -1753,7 +1752,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
session: Optional[ClientSession],
conn: Connection,
read_preference: Optional[_ServerMode],
cmd: Mapping[str, Any],
cmd: SON[str, Any],
collation: Optional[Collation],
) -> int:
"""Internal count command helper."""
@ -1777,7 +1776,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
self,
conn: Connection,
read_preference: Optional[_ServerMode],
cmd: Mapping[str, Any],
cmd: SON[str, Any],
collation: Optional[_CollationIn],
session: Optional[ClientSession],
) -> Optional[Mapping[str, Any]]:

View File

@ -25,6 +25,7 @@ from typing import (
Mapping,
NoReturn,
Optional,
Sequence,
Union,
)
@ -220,7 +221,7 @@ class CommandCursor(Generic[_DocumentType]):
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[_DocumentOut]:
) -> Sequence[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
def _refresh(self) -> int:

View File

@ -14,7 +14,7 @@
from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Union
from typing import Any, Iterable, List, Optional, Union
try:
import snappy
@ -96,7 +96,7 @@ class CompressionSettings:
self.zlib_compression_level = zlib_compression_level
def get_compression_context(
self, compressors: List[str]
self, compressors: Optional[List[str]]
) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
if compressors:
chosen = compressors[0]

View File

@ -1130,7 +1130,7 @@ class Cursor(Generic[_DocumentType]):
codec_options: CodecOptions,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[_DocumentOut]:
) -> Sequence[_DocumentOut]:
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
def _read_preference(self) -> _ServerMode:

View File

@ -694,7 +694,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
value: int = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
@ -711,7 +711,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
value: int = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: CodecOptions[_CodecDocumentType] = ...,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
@ -727,7 +727,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
value: int = 1,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: Union[
CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType]
] = DEFAULT_CODEC_OPTIONS,

View File

@ -1018,7 +1018,7 @@ class _BulkWriteContext:
cmd = self._start(cmd, request_id, docs)
start = datetime.datetime.now()
try:
result = self.conn.unack_write(msg, max_doc_size)
result = self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value]
if self.publish:
duration = (datetime.datetime.now() - start) + duration
if result is not None:
@ -1050,7 +1050,7 @@ class _BulkWriteContext:
request_id: int,
msg: bytes,
docs: List[Mapping[str, Any]],
) -> Mapping[str, Any]:
) -> Dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""
if self.publish:
assert self.start_time is not None
@ -1127,7 +1127,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
def __batch_command(
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
) -> Tuple[Dict[str, Any], List[Mapping[str, Any]]]:
namespace = self.db_name + ".$cmd"
msg, to_send = _encode_batched_write_command(
namespace, self.op_type, cmd, docs, self.codec, self
@ -1517,7 +1517,7 @@ class _OpReply:
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[_DocumentOut]:
) -> List[Dict[str, Any]]:
"""Unpack a response from the database and decode the BSON document(s).
Check the response for errors and unpack, returning a dictionary
@ -1541,7 +1541,7 @@ class _OpReply:
return bson.decode_all(self.documents, codec_options)
return bson._decode_all_selective(self.documents, codec_options, user_fields)
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
"""Unpack a command response."""
docs = self.unpack_response(codec_options=codec_options)
assert self.number_returned == 1
@ -1604,7 +1604,7 @@ class _OpMsg:
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
user_fields: Optional[Mapping[str, Any]] = None,
legacy_response: bool = False,
) -> List[_DocumentOut]:
) -> List[Dict[str, Any]]:
"""Unpack a OP_MSG command response.
:Parameters:
@ -1619,7 +1619,7 @@ class _OpMsg:
assert not legacy_response
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
"""Unpack a command response."""
return self.unpack_response(codec_options=codec_options)[0]

View File

@ -116,6 +116,7 @@ if TYPE_CHECKING:
import sys
from types import TracebackType
from bson.objectid import ObjectId
from pymongo.bulk import _Bulk
from pymongo.client_session import ClientSession, _ServerSession
from pymongo.cursor import _ConnectionManager
@ -1898,7 +1899,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
else:
yield None
def _send_cluster_time(self, command: MutableMapping[str, Any], session: ClientSession) -> None:
def _send_cluster_time(
self, command: MutableMapping[str, Any], session: Optional[ClientSession]
) -> None:
topology_time = self._topology.max_cluster_time()
session_time = session.cluster_time if session else None
if topology_time and session_time:
@ -2255,7 +2258,7 @@ class _MongoClientErrorHandler:
# of the pool at the time the connection attempt was started."
self.sock_generation = server.pool.gen.get_overall()
self.completed_handshake = False
self.service_id = None
self.service_id: Optional[ObjectId] = None
self.handled = False
def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None:

View File

@ -32,7 +32,7 @@ from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver
if TYPE_CHECKING:
from pymongo.pool import Connection, Pool
from pymongo.pool import Connection, Pool, _CancellationContext
from pymongo.settings import TopologySettings
from pymongo.topology import Topology
@ -131,9 +131,8 @@ class Monitor(MonitorBase):
self._pool = pool
self._settings = topology_settings
self._listeners = self._settings._pool_options._event_listeners
pub = self._listeners is not None
self._publish = pub and self._listeners.enabled_for_server_heartbeat
self._cancel_context = None
self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
self._cancel_context: Optional[_CancellationContext] = None
self._rtt_monitor = _RttMonitor(
topology,
topology_settings,
@ -238,7 +237,8 @@ class Monitor(MonitorBase):
address = sd.address
duration = time.monotonic() - start
if self._publish:
awaited = sd.is_server_type_known and sd.topology_version
awaited = bool(sd.is_server_type_known and sd.topology_version)
assert self._listeners is not None
self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
self._reset_connection()
if isinstance(error, _OperationCancelled):
@ -254,6 +254,7 @@ class Monitor(MonitorBase):
"""
address = self._server_description.address
if self._publish:
assert self._listeners is not None
self._listeners.publish_server_heartbeat_started(address)
if self._cancel_context and self._cancel_context.cancelled:
@ -267,6 +268,7 @@ class Monitor(MonitorBase):
avg_rtt, min_rtt = self._rtt_monitor.get()
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
if self._publish:
assert self._listeners is not None
self._listeners.publish_server_heartbeat_succeeded(
address, round_trip_time, response, response.awaitable
)

View File

@ -47,14 +47,13 @@ from pymongo.socket_checker import _errno_from_exception
if TYPE_CHECKING:
from bson import CodecOptions
from pymongo.client_session import ClientSession
from pymongo.collation import Collation
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.mongo_client import MongoClient
from pymongo.monitoring import _EventListeners
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _DocumentOut, _DocumentType
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
from pymongo.write_concern import WriteConcern
_UNPACK_HEADER = struct.Struct("<iiii").unpack
@ -65,7 +64,7 @@ def command(
dbname: str,
spec: MutableMapping[str, Any],
is_mongos: bool,
read_preference: _ServerMode,
read_preference: Optional[_ServerMode],
codec_options: CodecOptions[_DocumentType],
session: Optional[ClientSession],
client: Optional[MongoClient],
@ -76,7 +75,7 @@ def command(
max_bson_size: Optional[int] = None,
read_concern: Optional[ReadConcern] = None,
parse_write_concern_error: bool = False,
collation: Optional[Collation] = None,
collation: Optional[_CollationIn] = None,
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
use_op_msg: bool = False,
unacknowledged: bool = False,
@ -119,6 +118,7 @@ def command(
# Publish the original command document, perhaps with lsid and $clusterTime.
orig = spec
if is_mongos and not use_op_msg:
assert read_preference is not None
spec = message._maybe_add_read_preference(spec, read_preference)
if read_concern and not (session and session.in_transaction):
if read_concern.level:
@ -232,7 +232,7 @@ _UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
def receive_message(
conn: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():

View File

@ -12,6 +12,8 @@
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
import collections
import contextlib
import copy
@ -23,7 +25,21 @@ import sys
import threading
import time
import weakref
from typing import Any, Dict, NoReturn, Optional
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import bson
from bson import DEFAULT_CODEC_OPTIONS
@ -59,7 +75,11 @@ from pymongo.errors import (
from pymongo.hello import Hello, HelloCompat
from pymongo.helpers import _handle_reauth
from pymongo.lock import _create_lock
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
from pymongo.monitoring import (
ConnectionCheckOutFailedReason,
ConnectionClosedReason,
_EventListeners,
)
from pymongo.network import command, receive_message
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
@ -67,10 +87,31 @@ from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI, SSLError
if TYPE_CHECKING:
from bson import CodecOptions
from bson.objectid import ObjectId
from pymongo.auth import MongoCredential, _AuthContext
from pymongo.client_session import ClientSession
from pymongo.compression_support import (
CompressionSettings,
SnappyContext,
ZlibContext,
ZstdContext,
)
from pymongo.driver_info import DriverInfo
from pymongo.message import _OpMsg, _OpReply
from pymongo.mongo_client import MongoClient, _MongoClientErrorHandler
from pymongo.pyopenssl_context import SSLContext, _sslConn
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.server_api import ServerApi
from pymongo.typings import _Address, _CollationIn
from pymongo.write_concern import WriteConcern
try:
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
def _set_non_inheritable_non_atomic(fd):
def _set_non_inheritable_non_atomic(fd: int) -> None:
"""Set the close-on-exec flag on the given file descriptor."""
flags = fcntl(fd, F_GETFD)
fcntl(fd, F_SETFD, flags | FD_CLOEXEC)
@ -79,7 +120,7 @@ except ImportError:
# Windows, various platforms we don't claim to support
# (Jython, IronPython, ...), systems that don't provide
# everything we need from fcntl, etc.
def _set_non_inheritable_non_atomic(fd):
def _set_non_inheritable_non_atomic(fd: int) -> None:
"""Dummy function for platforms that don't provide fcntl."""
@ -123,7 +164,7 @@ if sys.platform == "win32":
else:
def _set_tcp_option(sock, tcp_option, max_value):
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
if hasattr(socket, tcp_option):
sockopt = getattr(socket, tcp_option)
try:
@ -136,7 +177,7 @@ else:
except OSError:
pass
def _set_keepalive_times(sock):
def _set_keepalive_times(sock: socket.socket) -> None:
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
@ -302,7 +343,7 @@ _MAX_METADATA_SIZE = 512
# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations
def _truncate_metadata(metadata):
def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None:
"""Perform metadata truncation."""
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
return
@ -365,7 +406,7 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error
def _cond_wait(condition, deadline):
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return condition.wait(timeout)
@ -406,23 +447,23 @@ class PoolOptions:
def __init__(
self,
max_pool_size=MAX_POOL_SIZE,
min_pool_size=MIN_POOL_SIZE,
max_idle_time_seconds=MAX_IDLE_TIME_SEC,
connect_timeout=None,
socket_timeout=None,
wait_queue_timeout=WAIT_QUEUE_TIMEOUT,
ssl_context=None,
tls_allow_invalid_hostnames=False,
event_listeners=None,
appname=None,
driver=None,
compression_settings=None,
max_connecting=MAX_CONNECTING,
pause_enabled=True,
server_api=None,
load_balanced=None,
credentials=None,
max_pool_size: int = MAX_POOL_SIZE,
min_pool_size: int = MIN_POOL_SIZE,
max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC,
connect_timeout: Optional[float] = None,
socket_timeout: Optional[float] = None,
wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT,
ssl_context: Optional[SSLContext] = None,
tls_allow_invalid_hostnames: bool = False,
event_listeners: Optional[_EventListeners] = None,
appname: Optional[str] = None,
driver: Optional[DriverInfo] = None,
compression_settings: Optional[CompressionSettings] = None,
max_connecting: int = MAX_CONNECTING,
pause_enabled: bool = True,
server_api: Optional[ServerApi] = None,
load_balanced: Optional[bool] = None,
credentials: Optional[MongoCredential] = None,
):
self.__max_pool_size = max_pool_size
self.__min_pool_size = min_pool_size
@ -474,12 +515,12 @@ class PoolOptions:
_truncate_metadata(self.__metadata)
@property
def _credentials(self):
def _credentials(self) -> Optional[MongoCredential]:
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
return self.__credentials
@property
def non_default_options(self):
def non_default_options(self) -> Dict[str, Any]:
"""The non-default options this pool was created with.
Added for CMAP's :class:`PoolCreatedEvent`.
@ -490,15 +531,17 @@ class PoolOptions:
if self.__min_pool_size != MIN_POOL_SIZE:
opts["minPoolSize"] = self.__min_pool_size
if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC:
assert self.__max_idle_time_seconds is not None
opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000
if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT:
assert self.__wait_queue_timeout is not None
opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000
if self.__max_connecting != MAX_CONNECTING:
opts["maxConnecting"] = self.__max_connecting
return opts
@property
def max_pool_size(self):
def max_pool_size(self) -> float:
"""The maximum allowable number of concurrent connections to each
connected server. Requests to a server will block if there are
`maxPoolSize` outstanding connections to the requested server.
@ -513,25 +556,25 @@ class PoolOptions:
return self.__max_pool_size
@property
def min_pool_size(self):
def min_pool_size(self) -> int:
"""The minimum required number of concurrent connections that the pool
will maintain to each connected server. Default is 0.
"""
return self.__min_pool_size
@property
def max_connecting(self):
def max_connecting(self) -> int:
"""The maximum number of concurrent connection creation attempts per
pool. Defaults to 2.
"""
return self.__max_connecting
@property
def pause_enabled(self):
def pause_enabled(self) -> bool:
return self.__pause_enabled
@property
def max_idle_time_seconds(self):
def max_idle_time_seconds(self) -> Optional[int]:
"""The maximum number of seconds that a connection can remain
idle in the pool before being removed and replaced. Defaults to
`None` (no limit).
@ -539,77 +582,77 @@ class PoolOptions:
return self.__max_idle_time_seconds
@property
def connect_timeout(self):
def connect_timeout(self) -> Optional[float]:
"""How long a connection can take to be opened before timing out."""
return self.__connect_timeout
@property
def socket_timeout(self):
def socket_timeout(self) -> Optional[float]:
"""How long a send or receive on a socket can take before timing out."""
return self.__socket_timeout
@property
def wait_queue_timeout(self):
def wait_queue_timeout(self) -> Optional[int]:
"""How long a thread will wait for a socket from the pool if the pool
has no free sockets.
"""
return self.__wait_queue_timeout
@property
def _ssl_context(self):
def _ssl_context(self) -> Optional[SSLContext]:
"""An SSLContext instance or None."""
return self.__ssl_context
@property
def tls_allow_invalid_hostnames(self):
def tls_allow_invalid_hostnames(self) -> bool:
"""If True skip ssl.match_hostname."""
return self.__tls_allow_invalid_hostnames
@property
def _event_listeners(self):
def _event_listeners(self) -> Optional[_EventListeners]:
"""An instance of pymongo.monitoring._EventListeners."""
return self.__event_listeners
@property
def appname(self):
def appname(self) -> Optional[str]:
"""The application name, for sending with hello in server handshake."""
return self.__appname
@property
def driver(self):
def driver(self) -> Optional[DriverInfo]:
"""Driver name and version, for sending with hello in handshake."""
return self.__driver
@property
def _compression_settings(self):
def _compression_settings(self) -> Optional[CompressionSettings]:
return self.__compression_settings
@property
def metadata(self):
def metadata(self) -> SON[str, Any]:
"""A dict of metadata about the application, driver, os, and platform."""
return self.__metadata.copy()
@property
def server_api(self):
def server_api(self) -> Optional[ServerApi]:
"""A pymongo.server_api.ServerApi or None."""
return self.__server_api
@property
def load_balanced(self):
def load_balanced(self) -> Optional[bool]:
"""True if this Pool is configured in load balanced mode."""
return self.__load_balanced
class _CancellationContext:
def __init__(self):
def __init__(self) -> None:
self._cancelled = False
def cancel(self):
def cancel(self) -> None:
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self):
def cancelled(self) -> bool:
"""Was cancel called?"""
return self._cancelled
@ -624,47 +667,48 @@ class Connection:
- `id`: the id of this socket in it's pool
"""
def __init__(self, conn, pool, address, id):
def __init__(
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: Tuple[str, int], id: int
):
self.pool_ref = weakref.ref(pool)
self.conn = conn
self.address = address
self.id = id
self.authed = set()
self.closed = False
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
self.is_writable = False
self.is_writable: bool = False
self.max_wire_version = MAX_WIRE_VERSION
self.max_bson_size = MAX_BSON_SIZE
self.max_message_size = MAX_MESSAGE_SIZE
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
self.supports_sessions = False
self.hello_ok = None
self.hello_ok: bool = False
self.is_mongos = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.compression_settings = pool.opts._compression_settings
self.compression_context = None
self.socket_checker = SocketChecker()
self.oidc_token_gen_id = None
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
self.socket_checker: SocketChecker = SocketChecker()
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs = None
self.auth_ctx = None
self.negotiated_mechs: Optional[List[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
self.cancel_context = None
self.cancel_context: Optional[_CancellationContext] = None
if not pool.handshake:
# This is a Monitor connection.
self.cancel_context = _CancellationContext()
self.opts = pool.opts
self.more_to_come = False
self.more_to_come: bool = False
# For load balancer support.
self.service_id = None
self.service_id: Optional[ObjectId] = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned_txn = False
@ -673,14 +717,16 @@ class Connection:
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
def set_conn_timeout(self, timeout):
def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
if timeout == self.last_timeout:
return
self.last_timeout = timeout
self.conn.settimeout(timeout)
def apply_timeout(self, client, cmd):
def apply_timeout(
self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]]
) -> Optional[float]:
# CSOT: use remaining timeout when set.
timeout = _csot.remaining()
if timeout is None:
@ -704,22 +750,22 @@ class Connection:
self.set_conn_timeout(timeout)
return timeout
def pin_txn(self):
def pin_txn(self) -> None:
self.pinned_txn = True
assert not self.pinned_cursor
def pin_cursor(self):
def pin_cursor(self) -> None:
self.pinned_cursor = True
assert not self.pinned_txn
def unpin(self):
def unpin(self) -> None:
pool = self.pool_ref()
if pool:
pool.checkin(self)
else:
self.close_conn(ConnectionClosedReason.STALE)
def hello_cmd(self):
def hello_cmd(self) -> SON[str, Any]:
# Handshake spec requires us to use OP_MSG+hello command for the
# initial handshake in load balanced or stable API mode.
if self.opts.server_api or self.hello_ok or self.opts.load_balanced:
@ -728,10 +774,15 @@ class Connection:
else:
return SON([(HelloCompat.LEGACY_CMD, 1), ("helloOk", True)])
def hello(self):
def hello(self) -> Hello[Dict[str, Any]]:
return self._hello(None, None, None)
def _hello(self, cluster_time, topology_version, heartbeat_frequency):
def _hello(
self,
cluster_time: Optional[Mapping[str, Any]],
topology_version: Optional[Any],
heartbeat_frequency: Optional[int],
) -> Hello[Dict[str, Any]]:
cmd = self.hello_cmd()
performing_handshake = not self.performed_handshake
awaitable = False
@ -744,6 +795,7 @@ class Connection:
cmd["loadBalanced"] = True
elif topology_version is not None:
cmd["topologyVersion"] = topology_version
assert heartbeat_frequency is not None
cmd["maxAwaitTimeMS"] = int(heartbeat_frequency * 1000)
awaitable = True
# If connect_timeout is None there is no timeout.
@ -808,7 +860,7 @@ class Connection:
self.generation = self.pool_gen.get(self.service_id)
return hello
def _next_reply(self):
def _next_reply(self) -> Dict[str, Any]:
reply = self.receive_message(None)
self.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response()
@ -819,23 +871,23 @@ class Connection:
@_handle_reauth
def command(
self,
dbname,
spec,
read_preference=ReadPreference.PRIMARY,
codec_options=DEFAULT_CODEC_OPTIONS,
check=True,
allowable_errors=None,
read_concern=None,
write_concern=None,
parse_write_concern_error=False,
collation=None,
session=None,
client=None,
retryable_write=False,
publish_events=True,
user_fields=None,
exhaust_allowed=False,
):
dbname: str,
spec: MutableMapping[str, Any],
read_preference: _ServerMode = ReadPreference.PRIMARY,
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
check: bool = True,
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
read_concern: Optional[ReadConcern] = None,
write_concern: Optional[WriteConcern] = None,
parse_write_concern_error: bool = False,
collation: Optional[_CollationIn] = None,
session: Optional[ClientSession] = None,
client: Optional[MongoClient] = None,
retryable_write: bool = False,
publish_events: bool = True,
user_fields: Optional[Mapping[str, Any]] = None,
exhaust_allowed: bool = False,
) -> Dict[str, Any]:
"""Execute a command or raise an error.
:Parameters:
@ -873,7 +925,7 @@ class Connection:
session._apply_to(spec, retryable_write, read_preference, self)
self.send_cluster_time(spec, session, client)
listeners = self.listeners if publish_events else None
unacknowledged = write_concern and not write_concern.acknowledged
unacknowledged = bool(write_concern and not write_concern.acknowledged)
if self.op_msg_enabled:
self._raise_if_not_writable(unacknowledged)
try:
@ -907,7 +959,7 @@ class Connection:
except BaseException as error:
self._raise_connection_failure(error)
def send_message(self, message, max_doc_size):
def send_message(self, message: bytes, max_doc_size: int) -> None:
"""Send a raw BSON message or raise ConnectionFailure.
If a network exception is raised, the socket is closed.
@ -923,7 +975,7 @@ class Connection:
except BaseException as error:
self._raise_connection_failure(error)
def receive_message(self, request_id):
def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise ConnectionFailure.
If any exception is raised, the socket is closed.
@ -933,7 +985,7 @@ class Connection:
except BaseException as error:
self._raise_connection_failure(error)
def _raise_if_not_writable(self, unacknowledged):
def _raise_if_not_writable(self, unacknowledged: bool) -> None:
"""Raise NotPrimaryError on unacknowledged write if this socket is not
writable.
"""
@ -941,7 +993,7 @@ class Connection:
# Write won't succeed, bail as if we'd received a not primary error.
raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107})
def unack_write(self, msg, max_doc_size):
def unack_write(self, msg: bytes, max_doc_size: int) -> None:
"""Send unack OP_MSG.
Can raise ConnectionFailure or InvalidDocument.
@ -953,7 +1005,9 @@ class Connection:
self._raise_if_not_writable(True)
self.send_message(msg, max_doc_size)
def write_command(self, request_id, msg, codec_options):
def write_command(
self, request_id: int, msg: bytes, codec_options: CodecOptions
) -> Dict[str, Any]:
"""Send "insert" etc. command, returning response as a dict.
Can raise ConnectionFailure or OperationFailure.
@ -970,7 +1024,7 @@ class Connection:
helpers._check_command_response(result, self.max_wire_version)
return result
def authenticate(self, reauthenticate=False):
def authenticate(self, reauthenticate: bool = False) -> None:
"""Authenticate to the server if needed.
Can raise ConnectionFailure or OperationFailure.
@ -988,9 +1042,12 @@ class Connection:
auth.authenticate(creds, self, reauthenticate=reauthenticate)
self.ready = True
if self.enabled_for_cmap:
assert self.listeners is not None
self.listeners.publish_connection_ready(self.address, self.id)
def validate_session(self, client, session):
def validate_session(
self, client: Optional[MongoClient], session: Optional[ClientSession]
) -> None:
"""Validate this session before use with client.
Raises error if the client is not the one that created the session.
@ -999,15 +1056,16 @@ class Connection:
if session._client is not client:
raise InvalidOperation("Can only use session with the MongoClient that started it")
def close_conn(self, reason):
def close_conn(self, reason: Optional[str]) -> None:
"""Close this connection with a reason."""
if self.closed:
return
self._close_conn()
if reason and self.enabled_for_cmap:
assert self.listeners is not None
self.listeners.publish_connection_closed(self.address, self.id, reason)
def _close_conn(self):
def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
@ -1021,31 +1079,36 @@ class Connection:
except Exception:
pass
def conn_closed(self):
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.conn)
def send_cluster_time(self, command, session, client):
def send_cluster_time(
self,
command: MutableMapping[str, Any],
session: Optional[ClientSession],
client: Optional[MongoClient],
) -> None:
"""Add $clusterTime."""
if client:
client._send_cluster_time(command, session)
def add_server_api(self, command):
def add_server_api(self, command: MutableMapping[str, Any]) -> None:
"""Add server_api parameters."""
if self.opts.server_api:
_add_to_command(command, self.opts.server_api)
def update_last_checkin_time(self):
def update_last_checkin_time(self) -> None:
self.last_checkin_time = time.monotonic()
def update_is_writable(self, is_writable):
def update_is_writable(self, is_writable: bool) -> None:
self.is_writable = is_writable
def idle_time_seconds(self):
def idle_time_seconds(self) -> float:
"""Seconds since this socket was last checked into its pool."""
return time.monotonic() - self.last_checkin_time
def _raise_connection_failure(self, error):
def _raise_connection_failure(self, error: BaseException) -> NoReturn:
# Catch *all* exceptions from socket methods and close the socket. In
# regular Python, socket operations only raise socket.error, even if
# the underlying cause was a Ctrl-C: a signal raised during socket.recv
@ -1072,16 +1135,16 @@ class Connection:
else:
raise
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return self.conn == other.conn
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __hash__(self):
def __hash__(self) -> int:
return hash(self.conn)
def __repr__(self):
def __repr__(self) -> str:
return "Connection({}){} at {}".format(
repr(self.conn),
self.closed and " CLOSED" or "",
@ -1089,7 +1152,7 @@ class Connection:
)
def _create_connection(address, options):
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.
Can raise socket.error.
@ -1160,7 +1223,7 @@ def _create_connection(address, options):
raise OSError("getaddrinfo failed")
def _configured_socket(address, options):
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
@ -1170,39 +1233,42 @@ def _configured_socket(address, options):
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is not None:
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
sock = ssl_context.wrap_socket(sock)
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc: # noqa: B014
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
_raise_connection_failure(address, exc, "SSL handshake failed: ")
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(sock.getpeercert(), hostname=host)
except _CertificateError:
sock.close()
raise
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
ssl_sock = ssl_context.wrap_socket(sock)
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc: # noqa: B014
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
_raise_connection_failure(address, exc, "SSL handshake failed: ")
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
class _PoolClosedError(PyMongoError):
@ -1212,23 +1278,23 @@ class _PoolClosedError(PyMongoError):
class _PoolGeneration:
def __init__(self):
def __init__(self) -> None:
# Maps service_id to generation.
self._generations = collections.defaultdict(int)
self._generations: Dict[ObjectId, int] = collections.defaultdict(int)
# Overall pool generation.
self._generation = 0
def get(self, service_id):
def get(self, service_id: Optional[ObjectId]) -> int:
"""Get the generation for the given service_id."""
if service_id is None:
return self._generation
return self._generations[service_id]
def get_overall(self):
def get_overall(self) -> int:
"""Get the Pool's overall generation."""
return self._generation
def inc(self, service_id):
def inc(self, service_id: Optional[ObjectId]) -> None:
"""Increment the generation for the given service_id."""
self._generation += 1
if service_id is None:
@ -1237,7 +1303,7 @@ class _PoolGeneration:
else:
self._generations[service_id] += 1
def stale(self, gen, service_id):
def stale(self, gen: int, service_id: Optional[ObjectId]) -> bool:
"""Return if the given generation for a given service_id is stale."""
return gen != self.get(service_id)
@ -1251,7 +1317,7 @@ class PoolState:
# Do *not* explicitly inherit from object or Jython won't call __del__
# http://bugs.jython.org/issue1057
class Pool:
def __init__(self, address, options, handshake=True):
def __init__(self, address: _Address, options: PoolOptions, handshake: bool = True):
"""
:Parameters:
- `address`: a (hostname, port) tuple
@ -1274,7 +1340,7 @@ class Pool:
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
# Track whether the sockets in this pool are writeable or not.
self.is_writable = None
self.is_writable: Optional[bool] = None
# Keep track of resets, so we notice sockets created before the most
# recent reset and close them.
@ -1306,6 +1372,7 @@ class Pool:
self._max_connecting = self.opts.max_connecting
self._pending = 0
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_created(
self.address, self.opts.non_default_options
)
@ -1314,23 +1381,26 @@ class Pool:
# Retain references to pinned connections to prevent the CPython GC
# from thinking that a cursor's pinned connection can be GC'd when the
# cursor is GC'd (see PYTHON-2751).
self.__pinned_sockets = set()
self.__pinned_sockets: Set[Connection] = set()
self.ncursors = 0
self.ntxns = 0
def ready(self):
def ready(self) -> None:
# Take the lock to avoid the race condition described in PYTHON-2699.
with self.lock:
if self.state != PoolState.READY:
self.state = PoolState.READY
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_pool_ready(self.address)
@property
def closed(self):
def closed(self) -> bool:
return self.state == PoolState.CLOSED
def _reset(self, close, pause=True, service_id=None):
def _reset(
self, close: bool, pause: bool = True, service_id: Optional[ObjectId] = None
) -> None:
old_state = self.state
with self.size_cond:
if self.closed:
@ -1370,14 +1440,16 @@ class Pool:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
else:
if old_state != PoolState.PAUSED and self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_cleared(self.address, service_id=service_id)
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
def update_is_writable(self, is_writable):
def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
Pool.
"""
@ -1386,19 +1458,19 @@ class Pool:
for _socket in self.conns:
_socket.update_is_writable(self.is_writable)
def reset(self, service_id=None):
def reset(self, service_id: Optional[ObjectId] = None) -> None:
self._reset(close=False, service_id=service_id)
def reset_without_pause(self):
def reset_without_pause(self) -> None:
self._reset(close=False, pause=False)
def close(self):
def close(self) -> None:
self._reset(close=True)
def stale_generation(self, gen, service_id):
def stale_generation(self, gen: int, service_id: Optional[ObjectId]) -> bool:
return self.gen.stale(gen, service_id)
def remove_stale_sockets(self, reference_generation):
def remove_stale_sockets(self, reference_generation: int) -> None:
"""Removes stale sockets then adds new ones if pool is too small and
has not been reset. The `reference_generation` argument specifies the
`generation` at the point in time this operation was requested on the
@ -1454,7 +1526,7 @@ class Pool:
self.requests -= 1
self.size_cond.notify()
def connect(self, handler=None):
def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
"""Connect to Mongo and return a new Connection.
Can raise ConnectionFailure.
@ -1468,12 +1540,14 @@ class Pool:
listeners = self.opts._event_listeners
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_created(self.address, conn_id)
try:
sock = _configured_socket(self.address, self.opts)
except BaseException as error:
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_closed(
self.address, conn_id, ConnectionClosedReason.ERROR
)
@ -1483,7 +1557,7 @@ class Pool:
raise
conn = Connection(sock, self, self.address, conn_id)
conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type]
try:
if self.handshake:
conn.hello()
@ -1499,7 +1573,7 @@ class Pool:
return conn
@contextlib.contextmanager
def checkout(self, handler=None):
def checkout(self, handler: Optional[_MongoClientErrorHandler] = None) -> Iterator[Connection]:
"""Get a connection from the pool. Use with a "with" statement.
Returns a :class:`Connection` object wrapping a connected
@ -1518,11 +1592,13 @@ class Pool:
"""
listeners = self.opts._event_listeners
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_check_out_started(self.address)
conn = self._get_conn(handler=handler)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_checked_out(self.address, conn.id)
try:
yield conn
@ -1551,15 +1627,16 @@ class Pool:
elif conn.active:
self.checkin(conn)
def _raise_if_not_ready(self, emit_event):
def _raise_if_not_ready(self, emit_event: bool) -> None:
if self.state != PoolState.READY:
if self.enabled_for_cmap and emit_event:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_connection_check_out_failed(
self.address, ConnectionCheckOutFailedReason.CONN_ERROR
)
_raise_connection_failure(self.address, AutoReconnect("connection pool paused"))
def _get_conn(self, handler=None):
def _get_conn(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
"""Get or create a Connection. Can raise ConnectionFailure."""
# We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of
@ -1569,6 +1646,7 @@ class Pool:
if self.closed:
if self.enabled_for_cmap:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_connection_check_out_failed(
self.address, ConnectionCheckOutFailedReason.POOL_CLOSED
)
@ -1649,6 +1727,7 @@ class Pool:
self.size_cond.notify()
if self.enabled_for_cmap and not emitted_event:
assert self.opts._event_listeners is not None
self.opts._event_listeners.publish_connection_check_out_failed(
self.address, ConnectionCheckOutFailedReason.CONN_ERROR
)
@ -1657,7 +1736,7 @@ class Pool:
conn.active = True
return conn
def checkin(self, conn):
def checkin(self, conn: Connection) -> None:
"""Return the connection to the pool, or if it's closed discard it.
:Parameters:
@ -1671,6 +1750,7 @@ class Pool:
self.__pinned_sockets.discard(conn)
listeners = self.opts._event_listeners
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_checked_in(self.address, conn.id)
if self.pid != os.getpid():
self.reset_without_pause()
@ -1680,6 +1760,7 @@ class Pool:
elif conn.closed:
# CMAP requires the closed event be emitted after the check in.
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_closed(
self.address, conn.id, ConnectionClosedReason.ERROR
)
@ -1691,7 +1772,7 @@ class Pool:
conn.close_conn(ConnectionClosedReason.STALE)
else:
conn.update_last_checkin_time()
conn.update_is_writable(self.is_writable)
conn.update_is_writable(bool(self.is_writable))
self.conns.appendleft(conn)
# Notify any threads waiting to create a connection.
self._max_connecting_cond.notify()
@ -1706,7 +1787,7 @@ class Pool:
self.operation_count -= 1
self.size_cond.notify()
def _perished(self, conn):
def _perished(self, conn: Connection) -> bool:
"""Return True and close the connection if it is "perished".
This side-effecty function checks if this socket has been idle for
@ -1745,6 +1826,7 @@ class Pool:
def _raise_wait_queue_timeout(self) -> NoReturn:
listeners = self.opts._event_listeners
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_check_out_failed(
self.address, ConnectionCheckOutFailedReason.TIMEOUT
)
@ -1768,7 +1850,7 @@ class Pool:
"maxPoolSize: {}, timeout: {}".format(self.opts.max_pool_size, timeout)
)
def __del__(self):
def __del__(self) -> None:
# Avoid ResourceWarnings in Python 3
# Close all sockets without calling reset() or close() because it is
# not safe to acquire a lock in __del__.

View File

@ -109,7 +109,7 @@ class Server:
conn: Connection,
operation: Union[_Query, _GetMore],
read_preference: _ServerMode,
listeners: _EventListeners,
listeners: Optional[_EventListeners],
unpack_res: Callable[..., List[_DocumentOut]],
) -> Response:
"""Run a _Query or _GetMore operation and return a Response object.
@ -126,6 +126,7 @@ class Server:
- `unpack_res`: A callable that decodes the wire protocol response.
"""
duration = None
assert listeners is not None
publish = listeners.enabled_for_commands
if publish:
start = datetime.now()
@ -140,6 +141,7 @@ class Server:
if publish:
cmd, dbn = operation.as_command(conn)
assert listeners is not None
listeners.publish_command_start(
cmd, dbn, request_id, conn.address, service_id=conn.service_id
)
@ -177,6 +179,7 @@ class Server:
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
assert listeners is not None
listeners.publish_command_failure(
duration,
failure,
@ -196,11 +199,12 @@ class Server:
elif operation.name == "explain":
res = docs[0] if docs else {}
else:
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1}
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr]
if operation.name == "find":
res["cursor"]["firstBatch"] = docs
else:
res["cursor"]["nextBatch"] = docs
assert listeners is not None
listeners.publish_command_success(
duration,
res,

View File

@ -1084,10 +1084,10 @@ class TestCommandMonitoring(IntegrationTest):
self.listener.reset()
cmd = SON([("getnonce", 1)])
listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address)
listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address) # type: ignore[arg-type]
delta = datetime.timedelta(milliseconds=100)
listeners.publish_command_success(
delta, {"nonce": "e474f4561c5eb40b", "ok": 1.0}, "getnonce", 12345, self.client.address
delta, {"nonce": "e474f4561c5eb40b", "ok": 1.0}, "getnonce", 12345, self.client.address # type: ignore[arg-type]
)
started = self.listener.started_events[0]
succeeded = self.listener.succeeded_events[0]