PYTHON-3867 add types to topology.py (#1346)

This commit is contained in:
Iris 2023-08-09 14:21:43 -07:00 committed by GitHub
parent 5d6d8ca68e
commit 34da931b3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 144 additions and 86 deletions

View File

@ -179,7 +179,7 @@ if TYPE_CHECKING:
from pymongo.pool import Connection
from pymongo.server import Server
from pymongo.typings import _Address
from pymongo.typings import ClusterTime, _Address
class SessionOptions:
@ -562,7 +562,7 @@ class ClientSession:
return self._server_session.session_id
@property
def cluster_time(self) -> Optional[Mapping[str, Any]]:
def cluster_time(self) -> Optional[ClusterTime]:
"""The cluster time returned by the last operation executed
in this session.
"""

View File

@ -22,7 +22,7 @@ from typing import Any, Generic, List, Mapping, Optional, Set, Tuple
from bson.objectid import ObjectId
from pymongo import common
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import _DocumentType
from pymongo.typings import ClusterTime, _DocumentType
class HelloCompat:
@ -155,7 +155,7 @@ class Hello(Generic[_DocumentType]):
return self._doc.get("electionId")
@property
def cluster_time(self) -> Optional[Mapping[str, Any]]:
def cluster_time(self) -> Optional[ClusterTime]:
return self._doc.get("$clusterTime")
@property

View File

@ -98,6 +98,7 @@ from pymongo.settings import TopologySettings
from pymongo.topology import Topology, _ErrorContext
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
from pymongo.typings import (
ClusterTime,
_Address,
_CollationIn,
_DocumentType,
@ -1106,10 +1107,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
MongoClient gained this property in version 3.0.
"""
return self._topology.get_primary()
return self._topology.get_primary() # type: ignore[return-value]
@property
def secondaries(self) -> Set[Tuple[str, int]]:
def secondaries(self) -> Set[_Address]:
"""The secondary members known to this client.
A sequence of (host, port) pairs. Empty if this client is not
@ -1122,7 +1123,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return self._topology.get_secondaries()
@property
def arbiters(self) -> Set[Tuple[str, int]]:
def arbiters(self) -> Set[_Address]:
"""Arbiters in the replica set.
A sequence of (host, port) pairs. Empty if this client is not
@ -1729,7 +1730,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if address:
# address could be a tuple or _CursorAddress, but
# select_server_by_address needs (host, port).
server = topology.select_server_by_address(tuple(address))
server = topology.select_server_by_address(tuple(address)) # type: ignore[arg-type]
else:
# Application called close_cursor() with no address.
server = topology.select_server(writable_server_selector)
@ -1906,7 +1907,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
session_time = session.cluster_time if session else None
if topology_time and session_time:
if topology_time["clusterTime"] > session_time["clusterTime"]:
cluster_time = topology_time
cluster_time: Optional[ClusterTime] = topology_time
else:
cluster_time = session_time
else:
@ -2271,7 +2272,7 @@ class _MongoClientErrorHandler:
def handle(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException]
) -> None:
if self.handled or exc_type is None:
if self.handled or exc_val is None:
return
self.handled = True
if self.session:
@ -2285,7 +2286,6 @@ class _MongoClientErrorHandler:
"RetryableWriteError"
):
self.session._unpin()
err_ctx = _ErrorContext(
exc_val,
self.max_wire_version,
@ -2300,8 +2300,8 @@ class _MongoClientErrorHandler:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_type: Optional[Type[Exception]],
exc_val: Optional[Exception],
exc_tb: Optional[TracebackType],
) -> None:
return self.handle(exc_type, exc_val)

View File

@ -105,7 +105,7 @@ if TYPE_CHECKING:
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.server_api import ServerApi
from pymongo.typings import _Address, _CollationIn
from pymongo.typings import ClusterTime, _Address, _CollationIn
from pymongo.write_concern import WriteConcern
try:
@ -779,7 +779,7 @@ class Connection:
def _hello(
self,
cluster_time: Optional[Mapping[str, Any]],
cluster_time: Optional[ClusterTime],
topology_version: Optional[Any],
heartbeat_frequency: Optional[int],
) -> Hello[Dict[str, Any]]:

View File

@ -22,7 +22,7 @@ from bson import EPOCH_NAIVE
from bson.objectid import ObjectId
from pymongo.hello import Hello
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import _Address
from pymongo.typings import ClusterTime, _Address
class ServerDescription:
@ -176,7 +176,7 @@ class ServerDescription:
return self._election_id
@property
def cluster_time(self) -> Optional[Mapping[str, Any]]:
def cluster_time(self) -> Optional[ClusterTime]:
return self._cluster_time
@property

View File

@ -95,11 +95,11 @@ class TopologySettings:
return self._pool_options
@property
def monitor_class(self) -> Optional[Type[monitor.Monitor]]:
def monitor_class(self) -> Type[monitor.Monitor]:
return self._monitor_class
@property
def condition_class(self) -> Optional[Type[threading.Condition]]:
def condition_class(self) -> Type[threading.Condition]:
return self._condition_class
@property

View File

@ -14,16 +14,29 @@
"""Internal class to monitor a topology of one or more servers."""
from __future__ import annotations
import os
import queue
import random
import time
import warnings
import weakref
from typing import Any
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Set,
Tuple,
cast,
)
from pymongo import _csot, common, helpers, periodic_executor
from pymongo.client_session import _ServerSessionPool
from pymongo.client_session import _ServerSession, _ServerSessionPool
from pymongo.errors import (
ConfigurationError,
ConnectionFailure,
@ -38,7 +51,7 @@ from pymongo.errors import (
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.monitor import SrvMonitor
from pymongo.pool import PoolOptions
from pymongo.pool import Pool, PoolOptions
from pymongo.server import Server
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import (
@ -57,8 +70,13 @@ from pymongo.topology_description import (
updated_topology_description,
)
if TYPE_CHECKING:
from bson import ObjectId
from pymongo.settings import TopologySettings
from pymongo.typings import ClusterTime, _Address
def process_events_queue(queue_ref):
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
q = queue_ref()
if not q:
return False # Cancel PeriodicExecutor.
@ -78,12 +96,11 @@ def process_events_queue(queue_ref):
class Topology:
"""Monitor a topology of one or more servers."""
def __init__(self, topology_settings):
def __init__(self, topology_settings: TopologySettings):
self._topology_id = topology_settings._topology_id
self._listeners = topology_settings._pool_options._event_listeners
pub = self._listeners is not None
self._publish_server = pub and self._listeners.enabled_for_server
self._publish_tp = pub and self._listeners.enabled_for_topology
self._publish_server = self._listeners is not None and self._listeners.enabled_for_server
self._publish_tp = self._listeners is not None and self._listeners.enabled_for_topology
# Create events queue if there are publishers.
self._events = None
@ -129,14 +146,16 @@ class Topology:
self._closed = False
self._lock = _create_lock()
self._condition = self._settings.condition_class(self._lock)
self._servers = {}
self._pid = None
self._max_cluster_time = None
self._servers: Dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
self._session_pool = _ServerSessionPool()
if self._publish_server or self._publish_tp:
assert self._events is not None
weak: weakref.ReferenceType[queue.Queue]
def target():
def target() -> bool:
return process_events_queue(weak)
executor = periodic_executor.PeriodicExecutor(
@ -157,7 +176,7 @@ class Topology:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)
def open(self):
def open(self) -> None:
"""Start monitoring, or restart after a fork.
No effect if called multiple times.
@ -191,14 +210,19 @@ class Topology:
with self._lock:
self._ensure_opened()
def get_server_selection_timeout(self):
def get_server_selection_timeout(self) -> float:
# CSOT: use remaining timeout when set.
timeout = _csot.remaining()
if timeout is None:
return self._settings.server_selection_timeout
return timeout
def select_servers(self, selector, server_selection_timeout=None, address=None):
def select_servers(
self,
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
) -> List[Server]:
"""Return a list of Servers matching selector, or time out.
:Parameters:
@ -222,9 +246,16 @@ class Topology:
with self._lock:
server_descriptions = self._select_servers_loop(selector, server_timeout, address)
return [self.get_server_by_address(sd.address) for sd in server_descriptions]
return [
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
]
def _select_servers_loop(self, selector, timeout, address):
def _select_servers_loop(
self,
selector: Callable[[Selection], Selection],
timeout: float,
address: Optional[_Address],
) -> List[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
end_time = now + timeout
@ -256,7 +287,12 @@ class Topology:
self._description.check_compatible()
return server_descriptions
def _select_server(self, selector, server_selection_timeout=None, address=None):
def _select_server(
self,
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
) -> Server:
servers = self.select_servers(selector, server_selection_timeout, address)
if len(servers) == 1:
return servers[0]
@ -266,14 +302,21 @@ class Topology:
else:
return server2
def select_server(self, selector, server_selection_timeout=None, address=None):
def select_server(
self,
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
) -> Server:
"""Like select_servers, but choose a random server if several match."""
server = self._select_server(selector, server_selection_timeout, address)
if _csot.get_timeout():
_csot.set_rtt(server.description.min_round_trip_time)
return server
def select_server_by_address(self, address, server_selection_timeout=None):
def select_server_by_address(
self, address: _Address, server_selection_timeout: Optional[int] = None
) -> Server:
"""Return a Server for "address", reconnecting if necessary.
If the server's type is not known, request an immediate check of all
@ -293,7 +336,9 @@ class Topology:
"""
return self.select_server(any_server_selector, server_selection_timeout, address)
def _process_change(self, server_description, reset_pool=False):
def _process_change(
self, server_description: ServerDescription, reset_pool: bool = False
) -> None:
"""Process a new ServerDescription on an opened topology.
Hold the lock when calling this.
@ -354,7 +399,7 @@ class Topology:
# Wake waiters in select_servers().
self._condition.notify_all()
def on_change(self, server_description, reset_pool=False):
def on_change(self, server_description: ServerDescription, reset_pool: bool = False) -> None:
"""Process a new ServerDescription after an hello call completes."""
# We do no I/O holding the lock.
with self._lock:
@ -369,7 +414,7 @@ class Topology:
if self._opened and self._description.has_server(server_description.address):
self._process_change(server_description, reset_pool)
def _process_srv_update(self, seedlist):
def _process_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None:
"""Process a new seedlist on an opened topology.
Hold the lock when calling this.
"""
@ -389,14 +434,14 @@ class Topology:
)
)
def on_srv_update(self, seedlist):
def on_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None:
"""Process a new list of nodes obtained from scanning SRV records."""
# We do no I/O holding the lock.
with self._lock:
if self._opened:
self._process_srv_update(seedlist)
def get_server_by_address(self, address):
def get_server_by_address(self, address: _Address) -> Optional[Server]:
"""Get a Server or None.
Returns the current version of the server immediately, even if it's
@ -406,10 +451,10 @@ class Topology:
"""
return self._servers.get(address)
def has_server(self, address):
def has_server(self, address: _Address) -> bool:
return address in self._servers
def get_primary(self):
def get_primary(self) -> Optional[_Address]:
"""Return primary's address or None."""
# Implemented here in Topology instead of MongoClient, so it can lock.
with self._lock:
@ -419,7 +464,7 @@ class Topology:
return writable_server_selector(self._new_selection())[0].address
def _get_replica_set_members(self, selector):
def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> Set[_Address]:
"""Return set of replica set member addresses."""
# Implemented here in Topology instead of MongoClient, so it can lock.
with self._lock:
@ -430,21 +475,21 @@ class Topology:
):
return set()
return {sd.address for sd in selector(self._new_selection())}
return {sd.address for sd in iter(selector(self._new_selection()))}
def get_secondaries(self):
def get_secondaries(self) -> Set[_Address]:
"""Return set of secondary addresses."""
return self._get_replica_set_members(secondary_server_selector)
def get_arbiters(self):
def get_arbiters(self) -> Set[_Address]:
"""Return set of arbiter addresses."""
return self._get_replica_set_members(arbiter_server_selector)
def max_cluster_time(self):
def max_cluster_time(self) -> Optional[ClusterTime]:
"""Return a document, the highest seen $clusterTime."""
return self._max_cluster_time
def _receive_cluster_time_no_lock(self, cluster_time):
def _receive_cluster_time_no_lock(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
# Driver Sessions Spec: "Whenever a driver receives a cluster time from
# a server it MUST compare it to the current highest seen cluster time
# for the deployment. If the new cluster time is higher than the
@ -459,17 +504,17 @@ class Topology:
):
self._max_cluster_time = cluster_time
def receive_cluster_time(self, cluster_time):
def receive_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
with self._lock:
self._receive_cluster_time_no_lock(cluster_time)
def request_check_all(self, wait_time=5):
def request_check_all(self, wait_time: int = 5) -> None:
"""Wake all monitors, wait for at least one to check its server."""
with self._lock:
self._request_check_all()
self._condition.wait(wait_time)
def data_bearing_servers(self):
def data_bearing_servers(self) -> List[ServerDescription]:
"""Return a list of all data-bearing servers.
This includes any server that might be selected for an operation.
@ -478,7 +523,7 @@ class Topology:
return self._description.known_servers
return self._description.readable_servers
def update_pool(self):
def update_pool(self) -> None:
# Remove any stale sockets and add new sockets if pool is too small.
servers = []
with self._lock:
@ -495,7 +540,7 @@ class Topology:
self.handle_error(server.description.address, ctx)
raise
def close(self):
def close(self) -> None:
"""Clear pools and terminate monitors. Topology does not reopen on
demand. Any further operations will raise
:exc:`~.errors.InvalidOperation`.
@ -525,19 +570,19 @@ class Topology:
self.__events_executor.close()
@property
def description(self):
def description(self) -> TopologyDescription:
return self._description
def pop_all_sessions(self):
def pop_all_sessions(self) -> List[_ServerSession]:
"""Pop all session ids from the pool."""
with self._lock:
return self._session_pool.pop_all()
def _check_implicit_session_support(self):
def _check_implicit_session_support(self) -> None:
with self._lock:
self._check_session_support()
def _check_session_support(self):
def _check_session_support(self) -> float:
"""Internal check for session support on clusters."""
if self._settings.load_balanced:
# Sessions never time out in load balanced mode.
@ -560,13 +605,13 @@ class Topology:
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
return session_timeout
def get_server_session(self):
def get_server_session(self) -> _ServerSession:
"""Start or resume a server session, or raise ConfigurationError."""
with self._lock:
session_timeout = self._check_session_support()
return self._session_pool.get_server_session(session_timeout)
def return_server_session(self, server_session, lock):
def return_server_session(self, server_session: _ServerSession, lock: bool) -> None:
if lock:
with self._lock:
self._session_pool.return_server_session(
@ -576,14 +621,14 @@ class Topology:
# Called from a __del__ method, can't use a lock.
self._session_pool.return_server_session_no_lock(server_session)
def _new_selection(self):
def _new_selection(self) -> Selection:
"""A Selection object, initially including all known servers.
Hold the lock when calling this.
"""
return Selection.from_topology_description(self._description)
def _ensure_opened(self):
def _ensure_opened(self) -> None:
"""Start monitors, or restart after a fork.
Hold the lock when calling this.
@ -616,7 +661,7 @@ class Topology:
for server in self._servers.values():
server.open()
def _is_stale_error(self, address, err_ctx):
def _is_stale_error(self, address: _Address, err_ctx: _ErrorContext) -> bool:
server = self._servers.get(address)
if server is None:
# Another thread removed this server from the topology.
@ -636,13 +681,12 @@ class Topology:
return _is_stale_error_topology_version(cur_tv, error_tv)
def _handle_error(self, address, err_ctx):
def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
if self._is_stale_error(address, err_ctx):
return
server = self._servers[address]
error = err_ctx.error
exc_type = type(error)
service_id = err_ctx.service_id
# Ignore a handshake error if the server is behind a load balancer but
@ -652,16 +696,16 @@ class Topology:
if self._settings.load_balanced and not service_id and not err_ctx.completed_handshake:
return
if issubclass(exc_type, NetworkTimeout) and err_ctx.completed_handshake:
if isinstance(error, NetworkTimeout) and err_ctx.completed_handshake:
# The socket has been closed. Don't reset the server.
# Server Discovery And Monitoring Spec: "When an application
# operation fails because of any network error besides a socket
# timeout...."
return
elif issubclass(exc_type, WriteError):
elif isinstance(error, WriteError):
# Ignore writeErrors.
return
elif issubclass(exc_type, (NotPrimaryError, OperationFailure)):
elif isinstance(error, (NotPrimaryError, OperationFailure)):
# As per the SDAM spec if:
# - the server sees a "not primary" error, and
# - the server is not shutting down, and
@ -675,7 +719,7 @@ class Topology:
else:
# Default error code if one does not exist.
default = 10107 if isinstance(error, NotPrimaryError) else None
err_code = error.details.get("code", default)
err_code = error.details.get("code", default) # type: ignore[union-attr]
if err_code in helpers._NOT_PRIMARY_CODES:
is_shutting_down = err_code in helpers._SHUTDOWN_CODES
# Mark server Unknown, clear the pool, and request check.
@ -691,7 +735,7 @@ class Topology:
self._process_change(ServerDescription(address, error=error))
# Clear the pool.
server.reset(service_id)
elif issubclass(exc_type, ConnectionFailure):
elif isinstance(error, ConnectionFailure):
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."
if not self._settings.load_balanced:
@ -703,7 +747,7 @@ class Topology:
# that server and close the current monitoring connection."
server._monitor.cancel_check()
def handle_error(self, address, err_ctx):
def handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
"""Handle an application error.
May reset the server to Unknown, clear the pool, and request an
@ -712,12 +756,12 @@ class Topology:
with self._lock:
self._handle_error(address, err_ctx)
def _request_check_all(self):
def _request_check_all(self) -> None:
"""Wake all monitors. Hold the lock when calling this."""
for server in self._servers.values():
server.request_check()
def _update_servers(self):
def _update_servers(self) -> None:
"""Sync our Servers from TopologyDescription.server_descriptions.
Hold the lock while calling this.
@ -759,10 +803,10 @@ class Topology:
server.close()
self._servers.pop(address)
def _create_pool_for_server(self, address):
def _create_pool_for_server(self, address: _Address) -> Pool:
return self._settings.pool_class(address, self._settings.pool_options)
def _create_pool_for_monitor(self, address):
def _create_pool_for_monitor(self, address: _Address) -> Pool:
options = self._settings.pool_options
# According to the Server Discovery And Monitoring Spec, monitors use
@ -782,7 +826,7 @@ class Topology:
return self._settings.pool_class(address, monitor_pool_options, handshake=False)
def _error_message(self, selector):
def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
"""Format an error message if server selection fails.
Hold the lock when calling this.
@ -840,7 +884,7 @@ class Topology:
else:
return ",".join(str(server.error) for server in servers if server.error)
def __repr__(self):
def __repr__(self) -> str:
msg = ""
if not self._opened:
msg = "CLOSED "
@ -851,19 +895,26 @@ class Topology:
ts = self._settings
return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name)
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.eq_props() == other.eq_props()
return NotImplemented
def __hash__(self):
def __hash__(self) -> int:
return hash(self.eq_props())
class _ErrorContext:
"""An error with context for SDAM error handling."""
def __init__(self, error, max_wire_version, sock_generation, completed_handshake, service_id):
def __init__(
self,
error: BaseException,
max_wire_version: int,
sock_generation: int,
completed_handshake: bool,
service_id: Optional[ObjectId],
):
self.error = error
self.max_wire_version = max_wire_version
self.sock_generation = sock_generation
@ -871,7 +922,9 @@ class _ErrorContext:
self.service_id = service_id
def _is_stale_error_topology_version(current_tv, error_tv):
def _is_stale_error_topology_version(
current_tv: Optional[Mapping[str, Any]], error_tv: Optional[Mapping[str, Any]]
) -> bool:
"""Return True if the error's topologyVersion is <= current."""
if current_tv is None or error_tv is None:
return False
@ -880,7 +933,7 @@ def _is_stale_error_topology_version(current_tv, error_tv):
return current_tv["counter"] >= error_tv["counter"]
def _is_stale_server_description(current_sd, new_sd):
def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDescription) -> bool:
"""Return True if the new topologyVersion is < current."""
current_tv, new_tv = current_sd.topology_version, new_sd.topology_version
if current_tv is None or new_tv is None:

View File

@ -34,6 +34,7 @@ if TYPE_CHECKING:
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_Pipeline = Sequence[Mapping[str, Any]]
ClusterTime = Mapping[str, Any]
_T = TypeVar("_T")

View File

@ -123,6 +123,9 @@ class TestMaxStaleness(unittest.TestCase):
time.sleep(1)
server = client._topology.select_server(writable_server_selector)
second = server.description.last_write_date
assert first is not None
assert second is not None
self.assertGreater(second, first)
self.assertLess(second, first + 10)

View File

@ -265,7 +265,7 @@ class TestReadPreferences(TestReadPreferencesBase):
not_used = data_members.difference(used)
latencies = ", ".join(
"%s: %dms" % (server.description.address, server.description.round_trip_time)
"%s: %sms" % (server.description.address, server.description.round_trip_time)
for server in c._get_topology().select_servers(readable_server_selector)
)

View File

@ -122,6 +122,7 @@ class TestStreamingProtocol(IntegrationTest):
# XXX: Add a public TopologyDescription getter to MongoClient?
topology = client._topology
sd = topology.description.server_descriptions()[address]
assert sd.round_trip_time is not None
return sd.round_trip_time > 0.250
wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT")