PYTHON-3809 add types to monitoring.py (#1332)

This commit is contained in:
Iris 2023-08-02 12:09:06 -07:00 committed by GitHub
parent b7796e1794
commit dc59eb86c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 104 additions and 79 deletions

View File

@ -15,7 +15,7 @@
"""Tools to parse mongo client options."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, cast
from bson.codec_options import _parse_codec_options
from pymongo import common
@ -23,7 +23,7 @@ from pymongo.auth import MongoCredential, _build_credentials_tuple
from pymongo.common import validate_boolean
from pymongo.compression_support import CompressionSettings
from pymongo.errors import ConfigurationError
from pymongo.monitoring import _EventListeners
from pymongo.monitoring import _EventListener, _EventListeners
from pymongo.pool import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import (
@ -152,7 +152,7 @@ def _parse_pool_options(
connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT)
socket_timeout = options.get("sockettimeoutms")
wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT)
event_listeners = options.get("event_listeners")
event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners"))
appname = options.get("appname")
driver = options.get("driver")
server_api = options.get("server_api")

View File

@ -187,7 +187,7 @@ from __future__ import annotations
import datetime
from collections import abc, namedtuple
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from bson.objectid import ObjectId
from pymongo.hello import Hello, HelloCompat
@ -195,6 +195,8 @@ from pymongo.helpers import _handle_exception
from pymongo.typings import _Address, _DocumentOut
if TYPE_CHECKING:
from datetime import timedelta
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TopologyDescription
@ -480,12 +482,14 @@ class ServerListener(_EventListener):
raise NotImplementedError
def _to_micros(dur):
def _to_micros(dur: timedelta) -> int:
"""Convert duration 'dur' to microseconds."""
return int(dur.total_seconds() * 10e5)
def _validate_event_listeners(option, listeners):
def _validate_event_listeners(
option: str, listeners: Sequence[_EventListeners]
) -> Sequence[_EventListeners]:
"""Validate event listeners"""
if not isinstance(listeners, abc.Sequence):
raise TypeError(f"{option} must be a list or tuple")
@ -545,7 +549,7 @@ _SENSITIVE_COMMANDS: set = {
# The "hello" command is also deemed sensitive when attempting speculative
# authentication.
def _is_speculative_authenticate(command_name, doc):
def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool:
if (
command_name.lower() in ("hello", HelloCompat.LEGACY_CMD)
and "speculativeAuthenticate" in doc
@ -650,7 +654,7 @@ class CommandStartedEvent(_CommandEvent):
"""The name of the database this command was run against."""
return self.__db
def __repr__(self):
def __repr__(self) -> str:
return ("<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}>").format(
self.__class__.__name__,
self.connection_id,
@ -707,7 +711,7 @@ class CommandSucceededEvent(_CommandEvent):
"""The server failure document for this operation."""
return self.__reply
def __repr__(self):
def __repr__(self) -> str:
return (
"<{} {} command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}>"
).format(
@ -762,7 +766,7 @@ class CommandFailedEvent(_CommandEvent):
"""The server failure document for this operation."""
return self.__failure
def __repr__(self):
def __repr__(self) -> str:
return (
"<{} {} command: {!r}, operation_id: {}, duration_micros: {}, "
"failure: {!r}, service_id: {}>"
@ -792,7 +796,7 @@ class _PoolEvent:
"""
return self.__address
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.__address!r})"
@ -817,7 +821,7 @@ class PoolCreatedEvent(_PoolEvent):
"""Any non-default pool options that were set on this Connection Pool."""
return self.__options
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})"
@ -861,7 +865,7 @@ class PoolClearedEvent(_PoolEvent):
"""
return self.__service_id
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r})"
@ -933,7 +937,7 @@ class _ConnectionEvent:
"""
return self.__address
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.__address!r})"
@ -951,7 +955,7 @@ class _ConnectionIdEvent(_ConnectionEvent):
"""The ID of the connection."""
return self.__connection_id
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})"
@ -1000,12 +1004,12 @@ class ConnectionClosedEvent(_ConnectionIdEvent):
__slots__ = ("__reason",)
def __init__(self, address, connection_id, reason):
def __init__(self, address: _Address, connection_id: int, reason: str):
super().__init__(address, connection_id)
self.__reason = reason
@property
def reason(self):
def reason(self) -> str:
"""A reason explaining why this connection was closed.
The reason must be one of the strings from the
@ -1013,7 +1017,7 @@ class ConnectionClosedEvent(_ConnectionIdEvent):
"""
return self.__reason
def __repr__(self):
def __repr__(self) -> str:
return "{}({!r}, {!r}, {!r})".format(
self.__class__.__name__,
self.address,
@ -1061,7 +1065,7 @@ class ConnectionCheckOutFailedEvent(_ConnectionEvent):
"""
return self.__reason
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r})"
@ -1112,7 +1116,7 @@ class _ServerEvent:
"""A unique identifier for the topology this server is a part of."""
return self.__topology_id
def __repr__(self):
def __repr__(self) -> str:
return "<{} {} topology_id: {}>".format(
self.__class__.__name__,
self.server_address,
@ -1152,7 +1156,7 @@ class ServerDescriptionChangedEvent(_ServerEvent):
"""
return self.__new_description
def __repr__(self):
def __repr__(self) -> str:
return "<{} {} changed from: {}, to: {}>".format(
self.__class__.__name__,
self.server_address,
@ -1192,7 +1196,7 @@ class TopologyEvent:
"""A unique identifier for the topology this server is a part of."""
return self.__topology_id
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} topology_id: {self.topology_id}>"
@ -1228,7 +1232,7 @@ class TopologyDescriptionChangedEvent(TopologyEvent):
"""
return self.__new_description
def __repr__(self):
def __repr__(self) -> str:
return "<{} topology_id: {} changed from: {}, to: {}>".format(
self.__class__.__name__,
self.topology_id,
@ -1270,7 +1274,7 @@ class _ServerHeartbeatEvent:
"""
return self.__connection_id
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.connection_id}>"
@ -1319,7 +1323,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent):
"""
return self.__awaited
def __repr__(self):
def __repr__(self) -> str:
return "<{} {} duration: {}, awaited: {}, reply: {}>".format(
self.__class__.__name__,
self.connection_id,
@ -1366,7 +1370,7 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent):
"""
return self.__awaited
def __repr__(self):
def __repr__(self) -> str:
return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format(
self.__class__.__name__,
self.connection_id,
@ -1385,7 +1389,7 @@ class _EventListeners:
- `listeners`: A list of event listeners.
"""
def __init__(self, listeners):
def __init__(self, listeners: Optional[Sequence[_EventListener]]):
self.__command_listeners = _LISTENERS.command_listeners[:]
self.__server_listeners = _LISTENERS.server_listeners[:]
lst = _LISTENERS.server_heartbeat_listeners
@ -1411,31 +1415,31 @@ class _EventListeners:
self.__enabled_for_cmap = bool(self.__cmap_listeners)
@property
def enabled_for_commands(self):
def enabled_for_commands(self) -> bool:
"""Are any CommandListener instances registered?"""
return self.__enabled_for_commands
@property
def enabled_for_server(self):
def enabled_for_server(self) -> bool:
"""Are any ServerListener instances registered?"""
return self.__enabled_for_server
@property
def enabled_for_server_heartbeat(self):
def enabled_for_server_heartbeat(self) -> bool:
"""Are any ServerHeartbeatListener instances registered?"""
return self.__enabled_for_server_heartbeat
@property
def enabled_for_topology(self):
def enabled_for_topology(self) -> bool:
"""Are any TopologyListener instances registered?"""
return self.__enabled_for_topology
@property
def enabled_for_cmap(self):
def enabled_for_cmap(self) -> bool:
"""Are any ConnectionPoolListener instances registered?"""
return self.__enabled_for_cmap
def event_listeners(self):
def event_listeners(self) -> List[_EventListeners]:
"""List of registered event listeners."""
return (
self.__command_listeners
@ -1446,8 +1450,14 @@ class _EventListeners:
)
def publish_command_start(
self, command, database_name, request_id, connection_id, op_id=None, service_id=None
):
self,
command: _DocumentOut,
database_name: str,
request_id: int,
connection_id: _Address,
op_id: Optional[int] = None,
service_id: Optional[ObjectId] = None,
) -> None:
"""Publish a CommandStartedEvent to all command listeners.
:Parameters:
@ -1473,15 +1483,15 @@ class _EventListeners:
def publish_command_success(
self,
duration,
reply,
command_name,
request_id,
connection_id,
op_id=None,
service_id=None,
speculative_hello=False,
):
duration: timedelta,
reply: _DocumentOut,
command_name: str,
request_id: int,
connection_id: _Address,
op_id: Optional[int] = None,
service_id: Optional[ObjectId] = None,
speculative_hello: bool = False,
) -> None:
"""Publish a CommandSucceededEvent to all command listeners.
:Parameters:
@ -1512,14 +1522,14 @@ class _EventListeners:
def publish_command_failure(
self,
duration,
failure,
command_name,
request_id,
connection_id,
op_id=None,
service_id=None,
):
duration: timedelta,
failure: _DocumentOut,
command_name: str,
request_id: int,
connection_id: _Address,
op_id: Optional[int] = None,
service_id: Optional[ObjectId] = None,
) -> None:
"""Publish a CommandFailedEvent to all command listeners.
:Parameters:
@ -1544,7 +1554,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_server_heartbeat_started(self, connection_id):
def publish_server_heartbeat_started(self, connection_id: _Address) -> None:
"""Publish a ServerHeartbeatStartedEvent to all server heartbeat
listeners.
@ -1558,7 +1568,9 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_server_heartbeat_succeeded(self, connection_id, duration, reply, awaited):
def publish_server_heartbeat_succeeded(
self, connection_id: _Address, duration: float, reply: Hello, awaited: bool
) -> None:
"""Publish a ServerHeartbeatSucceededEvent to all server heartbeat
listeners.
@ -1576,7 +1588,9 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_server_heartbeat_failed(self, connection_id, duration, reply, awaited):
def publish_server_heartbeat_failed(
self, connection_id: _Address, duration: float, reply: Exception, awaited: bool
) -> None:
"""Publish a ServerHeartbeatFailedEvent to all server heartbeat
listeners.
@ -1594,7 +1608,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_server_opened(self, server_address, topology_id):
def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None:
"""Publish a ServerOpeningEvent to all server listeners.
:Parameters:
@ -1609,7 +1623,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_server_closed(self, server_address, topology_id):
def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None:
"""Publish a ServerClosedEvent to all server listeners.
:Parameters:
@ -1625,8 +1639,12 @@ class _EventListeners:
_handle_exception()
def publish_server_description_changed(
self, previous_description, new_description, server_address, topology_id
):
self,
previous_description: ServerDescription,
new_description: ServerDescription,
server_address: _Address,
topology_id: ObjectId,
) -> None:
"""Publish a ServerDescriptionChangedEvent to all server listeners.
:Parameters:
@ -1645,7 +1663,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_topology_opened(self, topology_id):
def publish_topology_opened(self, topology_id: ObjectId) -> None:
"""Publish a TopologyOpenedEvent to all topology listeners.
:Parameters:
@ -1659,7 +1677,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_topology_closed(self, topology_id):
def publish_topology_closed(self, topology_id: ObjectId) -> None:
"""Publish a TopologyClosedEvent to all topology listeners.
:Parameters:
@ -1674,8 +1692,11 @@ class _EventListeners:
_handle_exception()
def publish_topology_description_changed(
self, previous_description, new_description, topology_id
):
self,
previous_description: TopologyDescription,
new_description: TopologyDescription,
topology_id: ObjectId,
) -> None:
"""Publish a TopologyDescriptionChangedEvent to all topology listeners.
:Parameters:
@ -1691,7 +1712,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_pool_created(self, address, options):
def publish_pool_created(self, address: _Address, options: Dict[str, Any]) -> None:
"""Publish a :class:`PoolCreatedEvent` to all pool listeners."""
event = PoolCreatedEvent(address, options)
for subscriber in self.__cmap_listeners:
@ -1700,7 +1721,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_pool_ready(self, address):
def publish_pool_ready(self, address: _Address) -> None:
"""Publish a :class:`PoolReadyEvent` to all pool listeners."""
event = PoolReadyEvent(address)
for subscriber in self.__cmap_listeners:
@ -1709,7 +1730,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_pool_cleared(self, address, service_id):
def publish_pool_cleared(self, address: _Address, service_id: Optional[ObjectId]) -> None:
"""Publish a :class:`PoolClearedEvent` to all pool listeners."""
event = PoolClearedEvent(address, service_id)
for subscriber in self.__cmap_listeners:
@ -1718,7 +1739,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_pool_closed(self, address):
def publish_pool_closed(self, address: _Address) -> None:
"""Publish a :class:`PoolClosedEvent` to all pool listeners."""
event = PoolClosedEvent(address)
for subscriber in self.__cmap_listeners:
@ -1727,7 +1748,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_connection_created(self, address, connection_id):
def publish_connection_created(self, address: _Address, connection_id: int) -> None:
"""Publish a :class:`ConnectionCreatedEvent` to all connection
listeners.
"""
@ -1738,7 +1759,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_connection_ready(self, address, connection_id):
def publish_connection_ready(self, address: _Address, connection_id: int) -> None:
"""Publish a :class:`ConnectionReadyEvent` to all connection listeners."""
event = ConnectionReadyEvent(address, connection_id)
for subscriber in self.__cmap_listeners:
@ -1747,7 +1768,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_connection_closed(self, address, connection_id, reason):
def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None:
"""Publish a :class:`ConnectionClosedEvent` to all connection
listeners.
"""
@ -1758,7 +1779,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_connection_check_out_started(self, address):
def publish_connection_check_out_started(self, address: _Address) -> None:
"""Publish a :class:`ConnectionCheckOutStartedEvent` to all connection
listeners.
"""
@ -1769,7 +1790,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_connection_check_out_failed(self, address, reason):
def publish_connection_check_out_failed(self, address: _Address, reason: str) -> None:
"""Publish a :class:`ConnectionCheckOutFailedEvent` to all connection
listeners.
"""
@ -1780,7 +1801,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_connection_checked_out(self, address, connection_id):
def publish_connection_checked_out(self, address: _Address, connection_id: int) -> None:
"""Publish a :class:`ConnectionCheckedOutEvent` to all connection
listeners.
"""
@ -1791,7 +1812,7 @@ class _EventListeners:
except Exception:
_handle_exception()
def publish_connection_checked_in(self, address, connection_id):
def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None:
"""Publish a :class:`ConnectionCheckedInEvent` to all connection
listeners.
"""

View File

@ -55,7 +55,7 @@ if TYPE_CHECKING:
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address
from pymongo.typings import _Address, _DocumentOut
from pymongo.write_concern import WriteConcern
_UNPACK_HEADER = struct.Struct("<iiii").unpack
@ -166,6 +166,7 @@ def command(
if publish:
encoding_duration = datetime.datetime.now() - start
assert listeners is not None
assert address is not None
listeners.publish_command_start(
orig, dbname, request_id, address, service_id=conn.service_id
)
@ -198,10 +199,11 @@ def command(
if publish:
duration = (datetime.datetime.now() - start) + encoding_duration
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure = exc.details
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = message._convert_exception(exc)
assert listeners is not None
assert address is not None
listeners.publish_command_failure(
duration, failure, name, request_id, address, service_id=conn.service_id
)
@ -209,6 +211,7 @@ def command(
if publish:
duration = (datetime.datetime.now() - start) + encoding_duration
assert listeners is not None
assert address is not None
listeners.publish_command_success(
duration,
response_doc,

View File

@ -45,6 +45,7 @@ if TYPE_CHECKING:
from pymongo.pool import Connection, Pool
from pymongo.read_preferences import _ServerMode
from pymongo.server_description import ServerDescription
from pymongo.typings import _DocumentOut
_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}}
@ -174,7 +175,7 @@ class Server:
if publish:
duration = datetime.now() - start
if isinstance(exc, (NotPrimaryError, OperationFailure)):
failure = exc.details
failure: _DocumentOut = exc.details # type: ignore[assignment]
else:
failure = _convert_exception(exc)
listeners.publish_command_failure(
@ -192,9 +193,9 @@ class Server:
# Must publish in find / getMore / explain command response
# format.
if use_cmd:
res = docs[0]
res: _DocumentOut = docs[0] # type: ignore[assignment]
elif operation.name == "explain":
res = docs[0] if docs else {}
res = docs[0] if docs else {} # type: ignore[assignment]
else:
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1}
if operation.name == "find":