From dc59eb86c7b465d67d46a02fb0da81db963788d1 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Wed, 2 Aug 2023 12:09:06 -0700 Subject: [PATCH] PYTHON-3809 add types to monitoring.py (#1332) --- pymongo/client_options.py | 6 +- pymongo/monitoring.py | 163 +++++++++++++++++++++----------------- pymongo/network.py | 7 +- pymongo/server.py | 7 +- 4 files changed, 104 insertions(+), 79 deletions(-) diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 91ef51a52..b8a228fa4 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -15,7 +15,7 @@ """Tools to parse mongo client options.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, cast from bson.codec_options import _parse_codec_options from pymongo import common @@ -23,7 +23,7 @@ from pymongo.auth import MongoCredential, _build_credentials_tuple from pymongo.common import validate_boolean from pymongo.compression_support import CompressionSettings from pymongo.errors import ConfigurationError -from pymongo.monitoring import _EventListeners +from pymongo.monitoring import _EventListener, _EventListeners from pymongo.pool import PoolOptions from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ( @@ -152,7 +152,7 @@ def _parse_pool_options( connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) socket_timeout = options.get("sockettimeoutms") wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) - event_listeners = options.get("event_listeners") + event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) appname = options.get("appname") driver = options.get("driver") server_api = options.get("server_api") diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 2a3a662c9..73e15821d 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -187,7 +187,7 @@ from __future__ import annotations import datetime from collections import abc, namedtuple -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence from bson.objectid import ObjectId from pymongo.hello import Hello, HelloCompat @@ -195,6 +195,8 @@ from pymongo.helpers import _handle_exception from pymongo.typings import _Address, _DocumentOut if TYPE_CHECKING: + from datetime import timedelta + from pymongo.server_description import ServerDescription from pymongo.topology_description import TopologyDescription @@ -480,12 +482,14 @@ class ServerListener(_EventListener): raise NotImplementedError -def _to_micros(dur): +def _to_micros(dur: timedelta) -> int: """Convert duration 'dur' to microseconds.""" return int(dur.total_seconds() * 10e5) -def _validate_event_listeners(option, listeners): +def _validate_event_listeners( + option: str, listeners: Sequence[_EventListeners] +) -> Sequence[_EventListeners]: """Validate event listeners""" if not isinstance(listeners, abc.Sequence): raise TypeError(f"{option} must be a list or tuple") @@ -545,7 +549,7 @@ _SENSITIVE_COMMANDS: set = { # The "hello" command is also deemed sensitive when attempting speculative # authentication. -def _is_speculative_authenticate(command_name, doc): +def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: if ( command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) and "speculativeAuthenticate" in doc @@ -650,7 +654,7 @@ class CommandStartedEvent(_CommandEvent): """The name of the database this command was run against.""" return self.__db - def __repr__(self): + def __repr__(self) -> str: return ("<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}>").format( self.__class__.__name__, self.connection_id, @@ -707,7 +711,7 @@ class CommandSucceededEvent(_CommandEvent): """The server failure document for this operation.""" return self.__reply - def __repr__(self): + def __repr__(self) -> str: return ( "<{} {} command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}>" ).format( @@ -762,7 +766,7 @@ class CommandFailedEvent(_CommandEvent): """The server failure document for this operation.""" return self.__failure - def __repr__(self): + def __repr__(self) -> str: return ( "<{} {} command: {!r}, operation_id: {}, duration_micros: {}, " "failure: {!r}, service_id: {}>" @@ -792,7 +796,7 @@ class _PoolEvent: """ return self.__address - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.__address!r})" @@ -817,7 +821,7 @@ class PoolCreatedEvent(_PoolEvent): """Any non-default pool options that were set on this Connection Pool.""" return self.__options - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" @@ -861,7 +865,7 @@ class PoolClearedEvent(_PoolEvent): """ return self.__service_id - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r})" @@ -933,7 +937,7 @@ class _ConnectionEvent: """ return self.__address - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.__address!r})" @@ -951,7 +955,7 @@ class _ConnectionIdEvent(_ConnectionEvent): """The ID of the connection.""" return self.__connection_id - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" @@ -1000,12 +1004,12 @@ class ConnectionClosedEvent(_ConnectionIdEvent): __slots__ = ("__reason",) - def __init__(self, address, connection_id, reason): + def __init__(self, address: _Address, connection_id: int, reason: str): super().__init__(address, connection_id) self.__reason = reason @property - def reason(self): + def reason(self) -> str: """A reason explaining why this connection was closed. The reason must be one of the strings from the @@ -1013,7 +1017,7 @@ class ConnectionClosedEvent(_ConnectionIdEvent): """ return self.__reason - def __repr__(self): + def __repr__(self) -> str: return "{}({!r}, {!r}, {!r})".format( self.__class__.__name__, self.address, @@ -1061,7 +1065,7 @@ class ConnectionCheckOutFailedEvent(_ConnectionEvent): """ return self.__reason - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r})" @@ -1112,7 +1116,7 @@ class _ServerEvent: """A unique identifier for the topology this server is a part of.""" return self.__topology_id - def __repr__(self): + def __repr__(self) -> str: return "<{} {} topology_id: {}>".format( self.__class__.__name__, self.server_address, @@ -1152,7 +1156,7 @@ class ServerDescriptionChangedEvent(_ServerEvent): """ return self.__new_description - def __repr__(self): + def __repr__(self) -> str: return "<{} {} changed from: {}, to: {}>".format( self.__class__.__name__, self.server_address, @@ -1192,7 +1196,7 @@ class TopologyEvent: """A unique identifier for the topology this server is a part of.""" return self.__topology_id - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" @@ -1228,7 +1232,7 @@ class TopologyDescriptionChangedEvent(TopologyEvent): """ return self.__new_description - def __repr__(self): + def __repr__(self) -> str: return "<{} topology_id: {} changed from: {}, to: {}>".format( self.__class__.__name__, self.topology_id, @@ -1270,7 +1274,7 @@ class _ServerHeartbeatEvent: """ return self.__connection_id - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.connection_id}>" @@ -1319,7 +1323,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): """ return self.__awaited - def __repr__(self): + def __repr__(self) -> str: return "<{} {} duration: {}, awaited: {}, reply: {}>".format( self.__class__.__name__, self.connection_id, @@ -1366,7 +1370,7 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): """ return self.__awaited - def __repr__(self): + def __repr__(self) -> str: return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( self.__class__.__name__, self.connection_id, @@ -1385,7 +1389,7 @@ class _EventListeners: - `listeners`: A list of event listeners. """ - def __init__(self, listeners): + def __init__(self, listeners: Optional[Sequence[_EventListener]]): self.__command_listeners = _LISTENERS.command_listeners[:] self.__server_listeners = _LISTENERS.server_listeners[:] lst = _LISTENERS.server_heartbeat_listeners @@ -1411,31 +1415,31 @@ class _EventListeners: self.__enabled_for_cmap = bool(self.__cmap_listeners) @property - def enabled_for_commands(self): + def enabled_for_commands(self) -> bool: """Are any CommandListener instances registered?""" return self.__enabled_for_commands @property - def enabled_for_server(self): + def enabled_for_server(self) -> bool: """Are any ServerListener instances registered?""" return self.__enabled_for_server @property - def enabled_for_server_heartbeat(self): + def enabled_for_server_heartbeat(self) -> bool: """Are any ServerHeartbeatListener instances registered?""" return self.__enabled_for_server_heartbeat @property - def enabled_for_topology(self): + def enabled_for_topology(self) -> bool: """Are any TopologyListener instances registered?""" return self.__enabled_for_topology @property - def enabled_for_cmap(self): + def enabled_for_cmap(self) -> bool: """Are any ConnectionPoolListener instances registered?""" return self.__enabled_for_cmap - def event_listeners(self): + def event_listeners(self) -> List[_EventListeners]: """List of registered event listeners.""" return ( self.__command_listeners @@ -1446,8 +1450,14 @@ class _EventListeners: ) def publish_command_start( - self, command, database_name, request_id, connection_id, op_id=None, service_id=None - ): + self, + command: _DocumentOut, + database_name: str, + request_id: int, + connection_id: _Address, + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + ) -> None: """Publish a CommandStartedEvent to all command listeners. :Parameters: @@ -1473,15 +1483,15 @@ class _EventListeners: def publish_command_success( self, - duration, - reply, - command_name, - request_id, - connection_id, - op_id=None, - service_id=None, - speculative_hello=False, - ): + duration: timedelta, + reply: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + speculative_hello: bool = False, + ) -> None: """Publish a CommandSucceededEvent to all command listeners. :Parameters: @@ -1512,14 +1522,14 @@ class _EventListeners: def publish_command_failure( self, - duration, - failure, - command_name, - request_id, - connection_id, - op_id=None, - service_id=None, - ): + duration: timedelta, + failure: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + ) -> None: """Publish a CommandFailedEvent to all command listeners. :Parameters: @@ -1544,7 +1554,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_server_heartbeat_started(self, connection_id): + def publish_server_heartbeat_started(self, connection_id: _Address) -> None: """Publish a ServerHeartbeatStartedEvent to all server heartbeat listeners. @@ -1558,7 +1568,9 @@ class _EventListeners: except Exception: _handle_exception() - def publish_server_heartbeat_succeeded(self, connection_id, duration, reply, awaited): + def publish_server_heartbeat_succeeded( + self, connection_id: _Address, duration: float, reply: Hello, awaited: bool + ) -> None: """Publish a ServerHeartbeatSucceededEvent to all server heartbeat listeners. @@ -1576,7 +1588,9 @@ class _EventListeners: except Exception: _handle_exception() - def publish_server_heartbeat_failed(self, connection_id, duration, reply, awaited): + def publish_server_heartbeat_failed( + self, connection_id: _Address, duration: float, reply: Exception, awaited: bool + ) -> None: """Publish a ServerHeartbeatFailedEvent to all server heartbeat listeners. @@ -1594,7 +1608,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_server_opened(self, server_address, topology_id): + def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None: """Publish a ServerOpeningEvent to all server listeners. :Parameters: @@ -1609,7 +1623,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_server_closed(self, server_address, topology_id): + def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None: """Publish a ServerClosedEvent to all server listeners. :Parameters: @@ -1625,8 +1639,12 @@ class _EventListeners: _handle_exception() def publish_server_description_changed( - self, previous_description, new_description, server_address, topology_id - ): + self, + previous_description: ServerDescription, + new_description: ServerDescription, + server_address: _Address, + topology_id: ObjectId, + ) -> None: """Publish a ServerDescriptionChangedEvent to all server listeners. :Parameters: @@ -1645,7 +1663,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_topology_opened(self, topology_id): + def publish_topology_opened(self, topology_id: ObjectId) -> None: """Publish a TopologyOpenedEvent to all topology listeners. :Parameters: @@ -1659,7 +1677,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_topology_closed(self, topology_id): + def publish_topology_closed(self, topology_id: ObjectId) -> None: """Publish a TopologyClosedEvent to all topology listeners. :Parameters: @@ -1674,8 +1692,11 @@ class _EventListeners: _handle_exception() def publish_topology_description_changed( - self, previous_description, new_description, topology_id - ): + self, + previous_description: TopologyDescription, + new_description: TopologyDescription, + topology_id: ObjectId, + ) -> None: """Publish a TopologyDescriptionChangedEvent to all topology listeners. :Parameters: @@ -1691,7 +1712,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_pool_created(self, address, options): + def publish_pool_created(self, address: _Address, options: Dict[str, Any]) -> None: """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" event = PoolCreatedEvent(address, options) for subscriber in self.__cmap_listeners: @@ -1700,7 +1721,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_pool_ready(self, address): + def publish_pool_ready(self, address: _Address) -> None: """Publish a :class:`PoolReadyEvent` to all pool listeners.""" event = PoolReadyEvent(address) for subscriber in self.__cmap_listeners: @@ -1709,7 +1730,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_pool_cleared(self, address, service_id): + def publish_pool_cleared(self, address: _Address, service_id: Optional[ObjectId]) -> None: """Publish a :class:`PoolClearedEvent` to all pool listeners.""" event = PoolClearedEvent(address, service_id) for subscriber in self.__cmap_listeners: @@ -1718,7 +1739,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_pool_closed(self, address): + def publish_pool_closed(self, address: _Address) -> None: """Publish a :class:`PoolClosedEvent` to all pool listeners.""" event = PoolClosedEvent(address) for subscriber in self.__cmap_listeners: @@ -1727,7 +1748,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_connection_created(self, address, connection_id): + def publish_connection_created(self, address: _Address, connection_id: int) -> None: """Publish a :class:`ConnectionCreatedEvent` to all connection listeners. """ @@ -1738,7 +1759,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_connection_ready(self, address, connection_id): + def publish_connection_ready(self, address: _Address, connection_id: int) -> None: """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" event = ConnectionReadyEvent(address, connection_id) for subscriber in self.__cmap_listeners: @@ -1747,7 +1768,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_connection_closed(self, address, connection_id, reason): + def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: """Publish a :class:`ConnectionClosedEvent` to all connection listeners. """ @@ -1758,7 +1779,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_connection_check_out_started(self, address): + def publish_connection_check_out_started(self, address: _Address) -> None: """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection listeners. """ @@ -1769,7 +1790,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_connection_check_out_failed(self, address, reason): + def publish_connection_check_out_failed(self, address: _Address, reason: str) -> None: """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection listeners. """ @@ -1780,7 +1801,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_connection_checked_out(self, address, connection_id): + def publish_connection_checked_out(self, address: _Address, connection_id: int) -> None: """Publish a :class:`ConnectionCheckedOutEvent` to all connection listeners. """ @@ -1791,7 +1812,7 @@ class _EventListeners: except Exception: _handle_exception() - def publish_connection_checked_in(self, address, connection_id): + def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: """Publish a :class:`ConnectionCheckedInEvent` to all connection listeners. """ diff --git a/pymongo/network.py b/pymongo/network.py index 139f7b2ae..0d40955be 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -55,7 +55,7 @@ if TYPE_CHECKING: from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode - from pymongo.typings import _Address + from pymongo.typings import _Address, _DocumentOut from pymongo.write_concern import WriteConcern _UNPACK_HEADER = struct.Struct("