PYTHON-3867 add types to topology.py (#1346)
This commit is contained in:
parent
5d6d8ca68e
commit
34da931b3a
@ -179,7 +179,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.server import Server
|
||||
from pymongo.typings import _Address
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
|
||||
class SessionOptions:
|
||||
@ -562,7 +562,7 @@ class ClientSession:
|
||||
return self._server_session.session_id
|
||||
|
||||
@property
|
||||
def cluster_time(self) -> Optional[Mapping[str, Any]]:
|
||||
def cluster_time(self) -> Optional[ClusterTime]:
|
||||
"""The cluster time returned by the last operation executed
|
||||
in this session.
|
||||
"""
|
||||
|
||||
@ -22,7 +22,7 @@ from typing import Any, Generic, List, Mapping, Optional, Set, Tuple
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo import common
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import _DocumentType
|
||||
from pymongo.typings import ClusterTime, _DocumentType
|
||||
|
||||
|
||||
class HelloCompat:
|
||||
@ -155,7 +155,7 @@ class Hello(Generic[_DocumentType]):
|
||||
return self._doc.get("electionId")
|
||||
|
||||
@property
|
||||
def cluster_time(self) -> Optional[Mapping[str, Any]]:
|
||||
def cluster_time(self) -> Optional[ClusterTime]:
|
||||
return self._doc.get("$clusterTime")
|
||||
|
||||
@property
|
||||
|
||||
@ -98,6 +98,7 @@ from pymongo.settings import TopologySettings
|
||||
from pymongo.topology import Topology, _ErrorContext
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
|
||||
from pymongo.typings import (
|
||||
ClusterTime,
|
||||
_Address,
|
||||
_CollationIn,
|
||||
_DocumentType,
|
||||
@ -1106,10 +1107,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
MongoClient gained this property in version 3.0.
|
||||
"""
|
||||
return self._topology.get_primary()
|
||||
return self._topology.get_primary() # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
def secondaries(self) -> Set[Tuple[str, int]]:
|
||||
def secondaries(self) -> Set[_Address]:
|
||||
"""The secondary members known to this client.
|
||||
|
||||
A sequence of (host, port) pairs. Empty if this client is not
|
||||
@ -1122,7 +1123,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return self._topology.get_secondaries()
|
||||
|
||||
@property
|
||||
def arbiters(self) -> Set[Tuple[str, int]]:
|
||||
def arbiters(self) -> Set[_Address]:
|
||||
"""Arbiters in the replica set.
|
||||
|
||||
A sequence of (host, port) pairs. Empty if this client is not
|
||||
@ -1729,7 +1730,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if address:
|
||||
# address could be a tuple or _CursorAddress, but
|
||||
# select_server_by_address needs (host, port).
|
||||
server = topology.select_server_by_address(tuple(address))
|
||||
server = topology.select_server_by_address(tuple(address)) # type: ignore[arg-type]
|
||||
else:
|
||||
# Application called close_cursor() with no address.
|
||||
server = topology.select_server(writable_server_selector)
|
||||
@ -1906,7 +1907,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
session_time = session.cluster_time if session else None
|
||||
if topology_time and session_time:
|
||||
if topology_time["clusterTime"] > session_time["clusterTime"]:
|
||||
cluster_time = topology_time
|
||||
cluster_time: Optional[ClusterTime] = topology_time
|
||||
else:
|
||||
cluster_time = session_time
|
||||
else:
|
||||
@ -2271,7 +2272,7 @@ class _MongoClientErrorHandler:
|
||||
def handle(
|
||||
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException]
|
||||
) -> None:
|
||||
if self.handled or exc_type is None:
|
||||
if self.handled or exc_val is None:
|
||||
return
|
||||
self.handled = True
|
||||
if self.session:
|
||||
@ -2285,7 +2286,6 @@ class _MongoClientErrorHandler:
|
||||
"RetryableWriteError"
|
||||
):
|
||||
self.session._unpin()
|
||||
|
||||
err_ctx = _ErrorContext(
|
||||
exc_val,
|
||||
self.max_wire_version,
|
||||
@ -2300,8 +2300,8 @@ class _MongoClientErrorHandler:
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_type: Optional[Type[Exception]],
|
||||
exc_val: Optional[Exception],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
return self.handle(exc_type, exc_val)
|
||||
|
||||
@ -105,7 +105,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.typings import _Address, _CollationIn
|
||||
from pymongo.typings import ClusterTime, _Address, _CollationIn
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
try:
|
||||
@ -779,7 +779,7 @@ class Connection:
|
||||
|
||||
def _hello(
|
||||
self,
|
||||
cluster_time: Optional[Mapping[str, Any]],
|
||||
cluster_time: Optional[ClusterTime],
|
||||
topology_version: Optional[Any],
|
||||
heartbeat_frequency: Optional[int],
|
||||
) -> Hello[Dict[str, Any]]:
|
||||
|
||||
@ -22,7 +22,7 @@ from bson import EPOCH_NAIVE
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.typings import _Address
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
|
||||
class ServerDescription:
|
||||
@ -176,7 +176,7 @@ class ServerDescription:
|
||||
return self._election_id
|
||||
|
||||
@property
|
||||
def cluster_time(self) -> Optional[Mapping[str, Any]]:
|
||||
def cluster_time(self) -> Optional[ClusterTime]:
|
||||
return self._cluster_time
|
||||
|
||||
@property
|
||||
|
||||
@ -95,11 +95,11 @@ class TopologySettings:
|
||||
return self._pool_options
|
||||
|
||||
@property
|
||||
def monitor_class(self) -> Optional[Type[monitor.Monitor]]:
|
||||
def monitor_class(self) -> Type[monitor.Monitor]:
|
||||
return self._monitor_class
|
||||
|
||||
@property
|
||||
def condition_class(self) -> Optional[Type[threading.Condition]]:
|
||||
def condition_class(self) -> Type[threading.Condition]:
|
||||
return self._condition_class
|
||||
|
||||
@property
|
||||
|
||||
@ -14,16 +14,29 @@
|
||||
|
||||
"""Internal class to monitor a topology of one or more servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import Any
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pymongo import _csot, common, helpers, periodic_executor
|
||||
from pymongo.client_session import _ServerSessionPool
|
||||
from pymongo.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.errors import (
|
||||
ConfigurationError,
|
||||
ConnectionFailure,
|
||||
@ -38,7 +51,7 @@ from pymongo.errors import (
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.monitor import SrvMonitor
|
||||
from pymongo.pool import PoolOptions
|
||||
from pymongo.pool import Pool, PoolOptions
|
||||
from pymongo.server import Server
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import (
|
||||
@ -57,8 +70,13 @@ from pymongo.topology_description import (
|
||||
updated_topology_description,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import ObjectId
|
||||
from pymongo.settings import TopologySettings
|
||||
from pymongo.typings import ClusterTime, _Address
|
||||
|
||||
def process_events_queue(queue_ref):
|
||||
|
||||
def process_events_queue(queue_ref: weakref.ReferenceType[queue.Queue]) -> bool:
|
||||
q = queue_ref()
|
||||
if not q:
|
||||
return False # Cancel PeriodicExecutor.
|
||||
@ -78,12 +96,11 @@ def process_events_queue(queue_ref):
|
||||
class Topology:
|
||||
"""Monitor a topology of one or more servers."""
|
||||
|
||||
def __init__(self, topology_settings):
|
||||
def __init__(self, topology_settings: TopologySettings):
|
||||
self._topology_id = topology_settings._topology_id
|
||||
self._listeners = topology_settings._pool_options._event_listeners
|
||||
pub = self._listeners is not None
|
||||
self._publish_server = pub and self._listeners.enabled_for_server
|
||||
self._publish_tp = pub and self._listeners.enabled_for_topology
|
||||
self._publish_server = self._listeners is not None and self._listeners.enabled_for_server
|
||||
self._publish_tp = self._listeners is not None and self._listeners.enabled_for_topology
|
||||
|
||||
# Create events queue if there are publishers.
|
||||
self._events = None
|
||||
@ -129,14 +146,16 @@ class Topology:
|
||||
self._closed = False
|
||||
self._lock = _create_lock()
|
||||
self._condition = self._settings.condition_class(self._lock)
|
||||
self._servers = {}
|
||||
self._pid = None
|
||||
self._max_cluster_time = None
|
||||
self._servers: Dict[_Address, Server] = {}
|
||||
self._pid: Optional[int] = None
|
||||
self._max_cluster_time: Optional[ClusterTime] = None
|
||||
self._session_pool = _ServerSessionPool()
|
||||
|
||||
if self._publish_server or self._publish_tp:
|
||||
assert self._events is not None
|
||||
weak: weakref.ReferenceType[queue.Queue]
|
||||
|
||||
def target():
|
||||
def target() -> bool:
|
||||
return process_events_queue(weak)
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
@ -157,7 +176,7 @@ class Topology:
|
||||
if self._settings.fqdn is not None and not self._settings.load_balanced:
|
||||
self._srv_monitor = SrvMonitor(self, self._settings)
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
No effect if called multiple times.
|
||||
@ -191,14 +210,19 @@ class Topology:
|
||||
with self._lock:
|
||||
self._ensure_opened()
|
||||
|
||||
def get_server_selection_timeout(self):
|
||||
def get_server_selection_timeout(self) -> float:
|
||||
# CSOT: use remaining timeout when set.
|
||||
timeout = _csot.remaining()
|
||||
if timeout is None:
|
||||
return self._settings.server_selection_timeout
|
||||
return timeout
|
||||
|
||||
def select_servers(self, selector, server_selection_timeout=None, address=None):
|
||||
def select_servers(
|
||||
self,
|
||||
selector: Callable[[Selection], Selection],
|
||||
server_selection_timeout: Optional[float] = None,
|
||||
address: Optional[_Address] = None,
|
||||
) -> List[Server]:
|
||||
"""Return a list of Servers matching selector, or time out.
|
||||
|
||||
:Parameters:
|
||||
@ -222,9 +246,16 @@ class Topology:
|
||||
with self._lock:
|
||||
server_descriptions = self._select_servers_loop(selector, server_timeout, address)
|
||||
|
||||
return [self.get_server_by_address(sd.address) for sd in server_descriptions]
|
||||
return [
|
||||
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
|
||||
]
|
||||
|
||||
def _select_servers_loop(self, selector, timeout, address):
|
||||
def _select_servers_loop(
|
||||
self,
|
||||
selector: Callable[[Selection], Selection],
|
||||
timeout: float,
|
||||
address: Optional[_Address],
|
||||
) -> List[ServerDescription]:
|
||||
"""select_servers() guts. Hold the lock when calling this."""
|
||||
now = time.monotonic()
|
||||
end_time = now + timeout
|
||||
@ -256,7 +287,12 @@ class Topology:
|
||||
self._description.check_compatible()
|
||||
return server_descriptions
|
||||
|
||||
def _select_server(self, selector, server_selection_timeout=None, address=None):
|
||||
def _select_server(
|
||||
self,
|
||||
selector: Callable[[Selection], Selection],
|
||||
server_selection_timeout: Optional[float] = None,
|
||||
address: Optional[_Address] = None,
|
||||
) -> Server:
|
||||
servers = self.select_servers(selector, server_selection_timeout, address)
|
||||
if len(servers) == 1:
|
||||
return servers[0]
|
||||
@ -266,14 +302,21 @@ class Topology:
|
||||
else:
|
||||
return server2
|
||||
|
||||
def select_server(self, selector, server_selection_timeout=None, address=None):
|
||||
def select_server(
|
||||
self,
|
||||
selector: Callable[[Selection], Selection],
|
||||
server_selection_timeout: Optional[float] = None,
|
||||
address: Optional[_Address] = None,
|
||||
) -> Server:
|
||||
"""Like select_servers, but choose a random server if several match."""
|
||||
server = self._select_server(selector, server_selection_timeout, address)
|
||||
if _csot.get_timeout():
|
||||
_csot.set_rtt(server.description.min_round_trip_time)
|
||||
return server
|
||||
|
||||
def select_server_by_address(self, address, server_selection_timeout=None):
|
||||
def select_server_by_address(
|
||||
self, address: _Address, server_selection_timeout: Optional[int] = None
|
||||
) -> Server:
|
||||
"""Return a Server for "address", reconnecting if necessary.
|
||||
|
||||
If the server's type is not known, request an immediate check of all
|
||||
@ -293,7 +336,9 @@ class Topology:
|
||||
"""
|
||||
return self.select_server(any_server_selector, server_selection_timeout, address)
|
||||
|
||||
def _process_change(self, server_description, reset_pool=False):
|
||||
def _process_change(
|
||||
self, server_description: ServerDescription, reset_pool: bool = False
|
||||
) -> None:
|
||||
"""Process a new ServerDescription on an opened topology.
|
||||
|
||||
Hold the lock when calling this.
|
||||
@ -354,7 +399,7 @@ class Topology:
|
||||
# Wake waiters in select_servers().
|
||||
self._condition.notify_all()
|
||||
|
||||
def on_change(self, server_description, reset_pool=False):
|
||||
def on_change(self, server_description: ServerDescription, reset_pool: bool = False) -> None:
|
||||
"""Process a new ServerDescription after an hello call completes."""
|
||||
# We do no I/O holding the lock.
|
||||
with self._lock:
|
||||
@ -369,7 +414,7 @@ class Topology:
|
||||
if self._opened and self._description.has_server(server_description.address):
|
||||
self._process_change(server_description, reset_pool)
|
||||
|
||||
def _process_srv_update(self, seedlist):
|
||||
def _process_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None:
|
||||
"""Process a new seedlist on an opened topology.
|
||||
Hold the lock when calling this.
|
||||
"""
|
||||
@ -389,14 +434,14 @@ class Topology:
|
||||
)
|
||||
)
|
||||
|
||||
def on_srv_update(self, seedlist):
|
||||
def on_srv_update(self, seedlist: List[Tuple[str, Any]]) -> None:
|
||||
"""Process a new list of nodes obtained from scanning SRV records."""
|
||||
# We do no I/O holding the lock.
|
||||
with self._lock:
|
||||
if self._opened:
|
||||
self._process_srv_update(seedlist)
|
||||
|
||||
def get_server_by_address(self, address):
|
||||
def get_server_by_address(self, address: _Address) -> Optional[Server]:
|
||||
"""Get a Server or None.
|
||||
|
||||
Returns the current version of the server immediately, even if it's
|
||||
@ -406,10 +451,10 @@ class Topology:
|
||||
"""
|
||||
return self._servers.get(address)
|
||||
|
||||
def has_server(self, address):
|
||||
def has_server(self, address: _Address) -> bool:
|
||||
return address in self._servers
|
||||
|
||||
def get_primary(self):
|
||||
def get_primary(self) -> Optional[_Address]:
|
||||
"""Return primary's address or None."""
|
||||
# Implemented here in Topology instead of MongoClient, so it can lock.
|
||||
with self._lock:
|
||||
@ -419,7 +464,7 @@ class Topology:
|
||||
|
||||
return writable_server_selector(self._new_selection())[0].address
|
||||
|
||||
def _get_replica_set_members(self, selector):
|
||||
def _get_replica_set_members(self, selector: Callable[[Selection], Selection]) -> Set[_Address]:
|
||||
"""Return set of replica set member addresses."""
|
||||
# Implemented here in Topology instead of MongoClient, so it can lock.
|
||||
with self._lock:
|
||||
@ -430,21 +475,21 @@ class Topology:
|
||||
):
|
||||
return set()
|
||||
|
||||
return {sd.address for sd in selector(self._new_selection())}
|
||||
return {sd.address for sd in iter(selector(self._new_selection()))}
|
||||
|
||||
def get_secondaries(self):
|
||||
def get_secondaries(self) -> Set[_Address]:
|
||||
"""Return set of secondary addresses."""
|
||||
return self._get_replica_set_members(secondary_server_selector)
|
||||
|
||||
def get_arbiters(self):
|
||||
def get_arbiters(self) -> Set[_Address]:
|
||||
"""Return set of arbiter addresses."""
|
||||
return self._get_replica_set_members(arbiter_server_selector)
|
||||
|
||||
def max_cluster_time(self):
|
||||
def max_cluster_time(self) -> Optional[ClusterTime]:
|
||||
"""Return a document, the highest seen $clusterTime."""
|
||||
return self._max_cluster_time
|
||||
|
||||
def _receive_cluster_time_no_lock(self, cluster_time):
|
||||
def _receive_cluster_time_no_lock(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
|
||||
# Driver Sessions Spec: "Whenever a driver receives a cluster time from
|
||||
# a server it MUST compare it to the current highest seen cluster time
|
||||
# for the deployment. If the new cluster time is higher than the
|
||||
@ -459,17 +504,17 @@ class Topology:
|
||||
):
|
||||
self._max_cluster_time = cluster_time
|
||||
|
||||
def receive_cluster_time(self, cluster_time):
|
||||
def receive_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None:
|
||||
with self._lock:
|
||||
self._receive_cluster_time_no_lock(cluster_time)
|
||||
|
||||
def request_check_all(self, wait_time=5):
|
||||
def request_check_all(self, wait_time: int = 5) -> None:
|
||||
"""Wake all monitors, wait for at least one to check its server."""
|
||||
with self._lock:
|
||||
self._request_check_all()
|
||||
self._condition.wait(wait_time)
|
||||
|
||||
def data_bearing_servers(self):
|
||||
def data_bearing_servers(self) -> List[ServerDescription]:
|
||||
"""Return a list of all data-bearing servers.
|
||||
|
||||
This includes any server that might be selected for an operation.
|
||||
@ -478,7 +523,7 @@ class Topology:
|
||||
return self._description.known_servers
|
||||
return self._description.readable_servers
|
||||
|
||||
def update_pool(self):
|
||||
def update_pool(self) -> None:
|
||||
# Remove any stale sockets and add new sockets if pool is too small.
|
||||
servers = []
|
||||
with self._lock:
|
||||
@ -495,7 +540,7 @@ class Topology:
|
||||
self.handle_error(server.description.address, ctx)
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Clear pools and terminate monitors. Topology does not reopen on
|
||||
demand. Any further operations will raise
|
||||
:exc:`~.errors.InvalidOperation`.
|
||||
@ -525,19 +570,19 @@ class Topology:
|
||||
self.__events_executor.close()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
def description(self) -> TopologyDescription:
|
||||
return self._description
|
||||
|
||||
def pop_all_sessions(self):
|
||||
def pop_all_sessions(self) -> List[_ServerSession]:
|
||||
"""Pop all session ids from the pool."""
|
||||
with self._lock:
|
||||
return self._session_pool.pop_all()
|
||||
|
||||
def _check_implicit_session_support(self):
|
||||
def _check_implicit_session_support(self) -> None:
|
||||
with self._lock:
|
||||
self._check_session_support()
|
||||
|
||||
def _check_session_support(self):
|
||||
def _check_session_support(self) -> float:
|
||||
"""Internal check for session support on clusters."""
|
||||
if self._settings.load_balanced:
|
||||
# Sessions never time out in load balanced mode.
|
||||
@ -560,13 +605,13 @@ class Topology:
|
||||
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
|
||||
return session_timeout
|
||||
|
||||
def get_server_session(self):
|
||||
def get_server_session(self) -> _ServerSession:
|
||||
"""Start or resume a server session, or raise ConfigurationError."""
|
||||
with self._lock:
|
||||
session_timeout = self._check_session_support()
|
||||
return self._session_pool.get_server_session(session_timeout)
|
||||
|
||||
def return_server_session(self, server_session, lock):
|
||||
def return_server_session(self, server_session: _ServerSession, lock: bool) -> None:
|
||||
if lock:
|
||||
with self._lock:
|
||||
self._session_pool.return_server_session(
|
||||
@ -576,14 +621,14 @@ class Topology:
|
||||
# Called from a __del__ method, can't use a lock.
|
||||
self._session_pool.return_server_session_no_lock(server_session)
|
||||
|
||||
def _new_selection(self):
|
||||
def _new_selection(self) -> Selection:
|
||||
"""A Selection object, initially including all known servers.
|
||||
|
||||
Hold the lock when calling this.
|
||||
"""
|
||||
return Selection.from_topology_description(self._description)
|
||||
|
||||
def _ensure_opened(self):
|
||||
def _ensure_opened(self) -> None:
|
||||
"""Start monitors, or restart after a fork.
|
||||
|
||||
Hold the lock when calling this.
|
||||
@ -616,7 +661,7 @@ class Topology:
|
||||
for server in self._servers.values():
|
||||
server.open()
|
||||
|
||||
def _is_stale_error(self, address, err_ctx):
|
||||
def _is_stale_error(self, address: _Address, err_ctx: _ErrorContext) -> bool:
|
||||
server = self._servers.get(address)
|
||||
if server is None:
|
||||
# Another thread removed this server from the topology.
|
||||
@ -636,13 +681,12 @@ class Topology:
|
||||
|
||||
return _is_stale_error_topology_version(cur_tv, error_tv)
|
||||
|
||||
def _handle_error(self, address, err_ctx):
|
||||
def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
|
||||
if self._is_stale_error(address, err_ctx):
|
||||
return
|
||||
|
||||
server = self._servers[address]
|
||||
error = err_ctx.error
|
||||
exc_type = type(error)
|
||||
service_id = err_ctx.service_id
|
||||
|
||||
# Ignore a handshake error if the server is behind a load balancer but
|
||||
@ -652,16 +696,16 @@ class Topology:
|
||||
if self._settings.load_balanced and not service_id and not err_ctx.completed_handshake:
|
||||
return
|
||||
|
||||
if issubclass(exc_type, NetworkTimeout) and err_ctx.completed_handshake:
|
||||
if isinstance(error, NetworkTimeout) and err_ctx.completed_handshake:
|
||||
# The socket has been closed. Don't reset the server.
|
||||
# Server Discovery And Monitoring Spec: "When an application
|
||||
# operation fails because of any network error besides a socket
|
||||
# timeout...."
|
||||
return
|
||||
elif issubclass(exc_type, WriteError):
|
||||
elif isinstance(error, WriteError):
|
||||
# Ignore writeErrors.
|
||||
return
|
||||
elif issubclass(exc_type, (NotPrimaryError, OperationFailure)):
|
||||
elif isinstance(error, (NotPrimaryError, OperationFailure)):
|
||||
# As per the SDAM spec if:
|
||||
# - the server sees a "not primary" error, and
|
||||
# - the server is not shutting down, and
|
||||
@ -675,7 +719,7 @@ class Topology:
|
||||
else:
|
||||
# Default error code if one does not exist.
|
||||
default = 10107 if isinstance(error, NotPrimaryError) else None
|
||||
err_code = error.details.get("code", default)
|
||||
err_code = error.details.get("code", default) # type: ignore[union-attr]
|
||||
if err_code in helpers._NOT_PRIMARY_CODES:
|
||||
is_shutting_down = err_code in helpers._SHUTDOWN_CODES
|
||||
# Mark server Unknown, clear the pool, and request check.
|
||||
@ -691,7 +735,7 @@ class Topology:
|
||||
self._process_change(ServerDescription(address, error=error))
|
||||
# Clear the pool.
|
||||
server.reset(service_id)
|
||||
elif issubclass(exc_type, ConnectionFailure):
|
||||
elif isinstance(error, ConnectionFailure):
|
||||
# "Client MUST replace the server's description with type Unknown
|
||||
# ... MUST NOT request an immediate check of the server."
|
||||
if not self._settings.load_balanced:
|
||||
@ -703,7 +747,7 @@ class Topology:
|
||||
# that server and close the current monitoring connection."
|
||||
server._monitor.cancel_check()
|
||||
|
||||
def handle_error(self, address, err_ctx):
|
||||
def handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None:
|
||||
"""Handle an application error.
|
||||
|
||||
May reset the server to Unknown, clear the pool, and request an
|
||||
@ -712,12 +756,12 @@ class Topology:
|
||||
with self._lock:
|
||||
self._handle_error(address, err_ctx)
|
||||
|
||||
def _request_check_all(self):
|
||||
def _request_check_all(self) -> None:
|
||||
"""Wake all monitors. Hold the lock when calling this."""
|
||||
for server in self._servers.values():
|
||||
server.request_check()
|
||||
|
||||
def _update_servers(self):
|
||||
def _update_servers(self) -> None:
|
||||
"""Sync our Servers from TopologyDescription.server_descriptions.
|
||||
|
||||
Hold the lock while calling this.
|
||||
@ -759,10 +803,10 @@ class Topology:
|
||||
server.close()
|
||||
self._servers.pop(address)
|
||||
|
||||
def _create_pool_for_server(self, address):
|
||||
def _create_pool_for_server(self, address: _Address) -> Pool:
|
||||
return self._settings.pool_class(address, self._settings.pool_options)
|
||||
|
||||
def _create_pool_for_monitor(self, address):
|
||||
def _create_pool_for_monitor(self, address: _Address) -> Pool:
|
||||
options = self._settings.pool_options
|
||||
|
||||
# According to the Server Discovery And Monitoring Spec, monitors use
|
||||
@ -782,7 +826,7 @@ class Topology:
|
||||
|
||||
return self._settings.pool_class(address, monitor_pool_options, handshake=False)
|
||||
|
||||
def _error_message(self, selector):
|
||||
def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
|
||||
"""Format an error message if server selection fails.
|
||||
|
||||
Hold the lock when calling this.
|
||||
@ -840,7 +884,7 @@ class Topology:
|
||||
else:
|
||||
return ",".join(str(server.error) for server in servers if server.error)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
msg = ""
|
||||
if not self._opened:
|
||||
msg = "CLOSED "
|
||||
@ -851,19 +895,26 @@ class Topology:
|
||||
ts = self._settings
|
||||
return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return self.eq_props() == other.eq_props()
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.eq_props())
|
||||
|
||||
|
||||
class _ErrorContext:
|
||||
"""An error with context for SDAM error handling."""
|
||||
|
||||
def __init__(self, error, max_wire_version, sock_generation, completed_handshake, service_id):
|
||||
def __init__(
|
||||
self,
|
||||
error: BaseException,
|
||||
max_wire_version: int,
|
||||
sock_generation: int,
|
||||
completed_handshake: bool,
|
||||
service_id: Optional[ObjectId],
|
||||
):
|
||||
self.error = error
|
||||
self.max_wire_version = max_wire_version
|
||||
self.sock_generation = sock_generation
|
||||
@ -871,7 +922,9 @@ class _ErrorContext:
|
||||
self.service_id = service_id
|
||||
|
||||
|
||||
def _is_stale_error_topology_version(current_tv, error_tv):
|
||||
def _is_stale_error_topology_version(
|
||||
current_tv: Optional[Mapping[str, Any]], error_tv: Optional[Mapping[str, Any]]
|
||||
) -> bool:
|
||||
"""Return True if the error's topologyVersion is <= current."""
|
||||
if current_tv is None or error_tv is None:
|
||||
return False
|
||||
@ -880,7 +933,7 @@ def _is_stale_error_topology_version(current_tv, error_tv):
|
||||
return current_tv["counter"] >= error_tv["counter"]
|
||||
|
||||
|
||||
def _is_stale_server_description(current_sd, new_sd):
|
||||
def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDescription) -> bool:
|
||||
"""Return True if the new topologyVersion is < current."""
|
||||
current_tv, new_tv = current_sd.topology_version, new_sd.topology_version
|
||||
if current_tv is None or new_tv is None:
|
||||
|
||||
@ -34,6 +34,7 @@ if TYPE_CHECKING:
|
||||
_Address = Tuple[str, Optional[int]]
|
||||
_CollationIn = Union[Mapping[str, Any], "Collation"]
|
||||
_Pipeline = Sequence[Mapping[str, Any]]
|
||||
ClusterTime = Mapping[str, Any]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@ -123,6 +123,9 @@ class TestMaxStaleness(unittest.TestCase):
|
||||
time.sleep(1)
|
||||
server = client._topology.select_server(writable_server_selector)
|
||||
second = server.description.last_write_date
|
||||
assert first is not None
|
||||
|
||||
assert second is not None
|
||||
self.assertGreater(second, first)
|
||||
self.assertLess(second, first + 10)
|
||||
|
||||
|
||||
@ -265,7 +265,7 @@ class TestReadPreferences(TestReadPreferencesBase):
|
||||
|
||||
not_used = data_members.difference(used)
|
||||
latencies = ", ".join(
|
||||
"%s: %dms" % (server.description.address, server.description.round_trip_time)
|
||||
"%s: %sms" % (server.description.address, server.description.round_trip_time)
|
||||
for server in c._get_topology().select_servers(readable_server_selector)
|
||||
)
|
||||
|
||||
|
||||
@ -122,6 +122,7 @@ class TestStreamingProtocol(IntegrationTest):
|
||||
# XXX: Add a public TopologyDescription getter to MongoClient?
|
||||
topology = client._topology
|
||||
sd = topology.description.server_descriptions()[address]
|
||||
assert sd.round_trip_time is not None
|
||||
return sd.round_trip_time > 0.250
|
||||
|
||||
wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user