PYTHON-3813 add types to pool.py (#1318)
This commit is contained in:
parent
54840752d2
commit
e0b8b36f41
@ -239,12 +239,15 @@ def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanis
|
||||
|
||||
ctx = conn.auth_ctx
|
||||
if ctx and ctx.speculate_succeeded():
|
||||
assert isinstance(ctx, _ScramContext)
|
||||
assert ctx.scram_data is not None
|
||||
nonce, first_bare = ctx.scram_data
|
||||
res = ctx.speculative_authenticate
|
||||
else:
|
||||
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
|
||||
res = conn.command(source, cmd)
|
||||
|
||||
assert res is not None
|
||||
server_first = res["payload"]
|
||||
parsed = _parse_scram_response(server_first)
|
||||
iterations = int(parsed[b"i"])
|
||||
@ -575,7 +578,7 @@ class _ScramContext(_AuthContext):
|
||||
|
||||
|
||||
class _X509Context(_AuthContext):
|
||||
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
|
||||
def speculate_command(self) -> MutableMapping[str, Any]:
|
||||
cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")])
|
||||
if self.credentials.username is not None:
|
||||
cmd["user"] = self.credentials.username
|
||||
|
||||
@ -19,7 +19,17 @@ import os
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
@ -242,7 +252,9 @@ class _OIDCAuthenticator:
|
||||
self.idp_resp = None
|
||||
self.token_exp_utc = None
|
||||
|
||||
def run_command(self, conn: Connection, cmd: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
|
||||
def run_command(
|
||||
self, conn: Connection, cmd: MutableMapping[str, Any]
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
try:
|
||||
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
|
||||
except OperationFailure as exc:
|
||||
@ -276,6 +288,7 @@ class _OIDCAuthenticator:
|
||||
assert cmd is not None
|
||||
resp = self.run_command(conn, cmd)
|
||||
|
||||
assert resp is not None
|
||||
if resp["done"]:
|
||||
conn.oidc_token_gen_id = self.token_gen_id
|
||||
return None
|
||||
@ -297,6 +310,7 @@ class _OIDCAuthenticator:
|
||||
]
|
||||
)
|
||||
resp = self.run_command(conn, cmd)
|
||||
assert resp is not None
|
||||
if not resp["done"]:
|
||||
self.clear()
|
||||
raise OperationFailure("SASL conversation failed to complete.")
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
"""Tools to parse mongo client options."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Tuple, cast
|
||||
|
||||
from bson.codec_options import _parse_codec_options
|
||||
from pymongo import common
|
||||
@ -309,11 +309,12 @@ class ClientOptions:
|
||||
return self.__load_balanced
|
||||
|
||||
@property
|
||||
def event_listeners(self) -> _EventListeners:
|
||||
def event_listeners(self) -> List[_EventListeners]:
|
||||
"""The event listeners registered for this client.
|
||||
|
||||
See :mod:`~pymongo.monitoring` for details.
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
assert self.__pool_options._event_listeners is not None
|
||||
return self.__pool_options._event_listeners.event_listeners()
|
||||
|
||||
@ -20,7 +20,6 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
ContextManager,
|
||||
Generic,
|
||||
Iterable,
|
||||
@ -268,11 +267,11 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _command(
|
||||
self,
|
||||
conn: Connection,
|
||||
command: Mapping[str, Any],
|
||||
command: MutableMapping[str, Any],
|
||||
read_preference: Optional[_ServerMode] = None,
|
||||
codec_options: Optional[CodecOptions] = None,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Container[Any]] = None,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
@ -1753,7 +1752,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
session: Optional[ClientSession],
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: Mapping[str, Any],
|
||||
cmd: SON[str, Any],
|
||||
collation: Optional[Collation],
|
||||
) -> int:
|
||||
"""Internal count command helper."""
|
||||
@ -1777,7 +1776,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
conn: Connection,
|
||||
read_preference: Optional[_ServerMode],
|
||||
cmd: Mapping[str, Any],
|
||||
cmd: SON[str, Any],
|
||||
collation: Optional[_CollationIn],
|
||||
session: Optional[ClientSession],
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
|
||||
@ -25,6 +25,7 @@ from typing import (
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -220,7 +221,7 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
codec_options: CodecOptions[Mapping[str, Any]],
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[_DocumentOut]:
|
||||
) -> Sequence[_DocumentOut]:
|
||||
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
|
||||
|
||||
def _refresh(self) -> int:
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Iterable, List, Union
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
|
||||
try:
|
||||
import snappy
|
||||
@ -96,7 +96,7 @@ class CompressionSettings:
|
||||
self.zlib_compression_level = zlib_compression_level
|
||||
|
||||
def get_compression_context(
|
||||
self, compressors: List[str]
|
||||
self, compressors: Optional[List[str]]
|
||||
) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
|
||||
if compressors:
|
||||
chosen = compressors[0]
|
||||
|
||||
@ -1130,7 +1130,7 @@ class Cursor(Generic[_DocumentType]):
|
||||
codec_options: CodecOptions,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[_DocumentOut]:
|
||||
) -> Sequence[_DocumentOut]:
|
||||
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
|
||||
|
||||
def _read_preference(self) -> _ServerMode:
|
||||
|
||||
@ -694,7 +694,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
|
||||
read_preference: _ServerMode = ReadPreference.PRIMARY,
|
||||
codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
@ -711,7 +711,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
|
||||
read_preference: _ServerMode = ReadPreference.PRIMARY,
|
||||
codec_options: CodecOptions[_CodecDocumentType] = ...,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
@ -727,7 +727,7 @@ class Database(common.BaseObject, Generic[_DocumentType]):
|
||||
value: int = 1,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
|
||||
read_preference: _ServerMode = ReadPreference.PRIMARY,
|
||||
codec_options: Union[
|
||||
CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType]
|
||||
] = DEFAULT_CODEC_OPTIONS,
|
||||
|
||||
@ -1018,7 +1018,7 @@ class _BulkWriteContext:
|
||||
cmd = self._start(cmd, request_id, docs)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
result = self.conn.unack_write(msg, max_doc_size)
|
||||
result = self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value]
|
||||
if self.publish:
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
if result is not None:
|
||||
@ -1050,7 +1050,7 @@ class _BulkWriteContext:
|
||||
request_id: int,
|
||||
msg: bytes,
|
||||
docs: List[Mapping[str, Any]],
|
||||
) -> Mapping[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
"""A proxy for SocketInfo.write_command that handles event publishing."""
|
||||
if self.publish:
|
||||
assert self.start_time is not None
|
||||
@ -1127,7 +1127,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
|
||||
|
||||
def __batch_command(
|
||||
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
|
||||
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
|
||||
) -> Tuple[Dict[str, Any], List[Mapping[str, Any]]]:
|
||||
namespace = self.db_name + ".$cmd"
|
||||
msg, to_send = _encode_batched_write_command(
|
||||
namespace, self.op_type, cmd, docs, self.codec, self
|
||||
@ -1517,7 +1517,7 @@ class _OpReply:
|
||||
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[_DocumentOut]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Unpack a response from the database and decode the BSON document(s).
|
||||
|
||||
Check the response for errors and unpack, returning a dictionary
|
||||
@ -1541,7 +1541,7 @@ class _OpReply:
|
||||
return bson.decode_all(self.documents, codec_options)
|
||||
return bson._decode_all_selective(self.documents, codec_options, user_fields)
|
||||
|
||||
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
|
||||
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
|
||||
"""Unpack a command response."""
|
||||
docs = self.unpack_response(codec_options=codec_options)
|
||||
assert self.number_returned == 1
|
||||
@ -1604,7 +1604,7 @@ class _OpMsg:
|
||||
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
legacy_response: bool = False,
|
||||
) -> List[_DocumentOut]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Unpack a OP_MSG command response.
|
||||
|
||||
:Parameters:
|
||||
@ -1619,7 +1619,7 @@ class _OpMsg:
|
||||
assert not legacy_response
|
||||
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
|
||||
|
||||
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
|
||||
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
|
||||
"""Unpack a command response."""
|
||||
return self.unpack_response(codec_options=codec_options)[0]
|
||||
|
||||
|
||||
@ -116,6 +116,7 @@ if TYPE_CHECKING:
|
||||
import sys
|
||||
from types import TracebackType
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.bulk import _Bulk
|
||||
from pymongo.client_session import ClientSession, _ServerSession
|
||||
from pymongo.cursor import _ConnectionManager
|
||||
@ -1898,7 +1899,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
else:
|
||||
yield None
|
||||
|
||||
def _send_cluster_time(self, command: MutableMapping[str, Any], session: ClientSession) -> None:
|
||||
def _send_cluster_time(
|
||||
self, command: MutableMapping[str, Any], session: Optional[ClientSession]
|
||||
) -> None:
|
||||
topology_time = self._topology.max_cluster_time()
|
||||
session_time = session.cluster_time if session else None
|
||||
if topology_time and session_time:
|
||||
@ -2255,7 +2258,7 @@ class _MongoClientErrorHandler:
|
||||
# of the pool at the time the connection attempt was started."
|
||||
self.sock_generation = server.pool.gen.get_overall()
|
||||
self.completed_handshake = False
|
||||
self.service_id = None
|
||||
self.service_id: Optional[ObjectId] = None
|
||||
self.handled = False
|
||||
|
||||
def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None:
|
||||
|
||||
@ -32,7 +32,7 @@ from pymongo.server_description import ServerDescription
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pool import Connection, Pool
|
||||
from pymongo.pool import Connection, Pool, _CancellationContext
|
||||
from pymongo.settings import TopologySettings
|
||||
from pymongo.topology import Topology
|
||||
|
||||
@ -131,9 +131,8 @@ class Monitor(MonitorBase):
|
||||
self._pool = pool
|
||||
self._settings = topology_settings
|
||||
self._listeners = self._settings._pool_options._event_listeners
|
||||
pub = self._listeners is not None
|
||||
self._publish = pub and self._listeners.enabled_for_server_heartbeat
|
||||
self._cancel_context = None
|
||||
self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
|
||||
self._cancel_context: Optional[_CancellationContext] = None
|
||||
self._rtt_monitor = _RttMonitor(
|
||||
topology,
|
||||
topology_settings,
|
||||
@ -238,7 +237,8 @@ class Monitor(MonitorBase):
|
||||
address = sd.address
|
||||
duration = time.monotonic() - start
|
||||
if self._publish:
|
||||
awaited = sd.is_server_type_known and sd.topology_version
|
||||
awaited = bool(sd.is_server_type_known and sd.topology_version)
|
||||
assert self._listeners is not None
|
||||
self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
|
||||
self._reset_connection()
|
||||
if isinstance(error, _OperationCancelled):
|
||||
@ -254,6 +254,7 @@ class Monitor(MonitorBase):
|
||||
"""
|
||||
address = self._server_description.address
|
||||
if self._publish:
|
||||
assert self._listeners is not None
|
||||
self._listeners.publish_server_heartbeat_started(address)
|
||||
|
||||
if self._cancel_context and self._cancel_context.cancelled:
|
||||
@ -267,6 +268,7 @@ class Monitor(MonitorBase):
|
||||
avg_rtt, min_rtt = self._rtt_monitor.get()
|
||||
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
|
||||
if self._publish:
|
||||
assert self._listeners is not None
|
||||
self._listeners.publish_server_heartbeat_succeeded(
|
||||
address, round_trip_time, response, response.awaitable
|
||||
)
|
||||
|
||||
@ -47,14 +47,13 @@ from pymongo.socket_checker import _errno_from_exception
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.collation import Collation
|
||||
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.monitoring import _EventListeners
|
||||
from pymongo.pool import Connection
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address, _DocumentOut, _DocumentType
|
||||
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
@ -65,7 +64,7 @@ def command(
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
is_mongos: bool,
|
||||
read_preference: _ServerMode,
|
||||
read_preference: Optional[_ServerMode],
|
||||
codec_options: CodecOptions[_DocumentType],
|
||||
session: Optional[ClientSession],
|
||||
client: Optional[MongoClient],
|
||||
@ -76,7 +75,7 @@ def command(
|
||||
max_bson_size: Optional[int] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
collation: Optional[Collation] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
|
||||
use_op_msg: bool = False,
|
||||
unacknowledged: bool = False,
|
||||
@ -119,6 +118,7 @@ def command(
|
||||
# Publish the original command document, perhaps with lsid and $clusterTime.
|
||||
orig = spec
|
||||
if is_mongos and not use_op_msg:
|
||||
assert read_preference is not None
|
||||
spec = message._maybe_add_read_preference(spec, read_preference)
|
||||
if read_concern and not (session and session.in_transaction):
|
||||
if read_concern.level:
|
||||
@ -232,7 +232,7 @@ _UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
|
||||
|
||||
def receive_message(
|
||||
conn: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
|
||||
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
|
||||
416
pymongo/pool.py
416
pymongo/pool.py
@ -12,6 +12,8 @@
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
@ -23,7 +25,21 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Dict, NoReturn, Optional
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import bson
|
||||
from bson import DEFAULT_CODEC_OPTIONS
|
||||
@ -59,7 +75,11 @@ from pymongo.errors import (
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.helpers import _handle_reauth
|
||||
from pymongo.lock import _create_lock
|
||||
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
|
||||
from pymongo.monitoring import (
|
||||
ConnectionCheckOutFailedReason,
|
||||
ConnectionClosedReason,
|
||||
_EventListeners,
|
||||
)
|
||||
from pymongo.network import command, receive_message
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
@ -67,10 +87,31 @@ from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.ssl_support import HAS_SNI, SSLError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.auth import MongoCredential, _AuthContext
|
||||
from pymongo.client_session import ClientSession
|
||||
from pymongo.compression_support import (
|
||||
CompressionSettings,
|
||||
SnappyContext,
|
||||
ZlibContext,
|
||||
ZstdContext,
|
||||
)
|
||||
from pymongo.driver_info import DriverInfo
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.mongo_client import MongoClient, _MongoClientErrorHandler
|
||||
from pymongo.pyopenssl_context import SSLContext, _sslConn
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.typings import _Address, _CollationIn
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
try:
|
||||
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
|
||||
|
||||
def _set_non_inheritable_non_atomic(fd):
|
||||
def _set_non_inheritable_non_atomic(fd: int) -> None:
|
||||
"""Set the close-on-exec flag on the given file descriptor."""
|
||||
flags = fcntl(fd, F_GETFD)
|
||||
fcntl(fd, F_SETFD, flags | FD_CLOEXEC)
|
||||
@ -79,7 +120,7 @@ except ImportError:
|
||||
# Windows, various platforms we don't claim to support
|
||||
# (Jython, IronPython, ...), systems that don't provide
|
||||
# everything we need from fcntl, etc.
|
||||
def _set_non_inheritable_non_atomic(fd):
|
||||
def _set_non_inheritable_non_atomic(fd: int) -> None:
|
||||
"""Dummy function for platforms that don't provide fcntl."""
|
||||
|
||||
|
||||
@ -123,7 +164,7 @@ if sys.platform == "win32":
|
||||
|
||||
else:
|
||||
|
||||
def _set_tcp_option(sock, tcp_option, max_value):
|
||||
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
|
||||
if hasattr(socket, tcp_option):
|
||||
sockopt = getattr(socket, tcp_option)
|
||||
try:
|
||||
@ -136,7 +177,7 @@ else:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _set_keepalive_times(sock):
|
||||
def _set_keepalive_times(sock: socket.socket) -> None:
|
||||
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
|
||||
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
|
||||
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
|
||||
@ -302,7 +343,7 @@ _MAX_METADATA_SIZE = 512
|
||||
|
||||
|
||||
# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations
|
||||
def _truncate_metadata(metadata):
|
||||
def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None:
|
||||
"""Perform metadata truncation."""
|
||||
if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE:
|
||||
return
|
||||
@ -365,7 +406,7 @@ def _raise_connection_failure(
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
def _cond_wait(condition, deadline):
|
||||
def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool:
|
||||
timeout = deadline - time.monotonic() if deadline else None
|
||||
return condition.wait(timeout)
|
||||
|
||||
@ -406,23 +447,23 @@ class PoolOptions:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_pool_size=MAX_POOL_SIZE,
|
||||
min_pool_size=MIN_POOL_SIZE,
|
||||
max_idle_time_seconds=MAX_IDLE_TIME_SEC,
|
||||
connect_timeout=None,
|
||||
socket_timeout=None,
|
||||
wait_queue_timeout=WAIT_QUEUE_TIMEOUT,
|
||||
ssl_context=None,
|
||||
tls_allow_invalid_hostnames=False,
|
||||
event_listeners=None,
|
||||
appname=None,
|
||||
driver=None,
|
||||
compression_settings=None,
|
||||
max_connecting=MAX_CONNECTING,
|
||||
pause_enabled=True,
|
||||
server_api=None,
|
||||
load_balanced=None,
|
||||
credentials=None,
|
||||
max_pool_size: int = MAX_POOL_SIZE,
|
||||
min_pool_size: int = MIN_POOL_SIZE,
|
||||
max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC,
|
||||
connect_timeout: Optional[float] = None,
|
||||
socket_timeout: Optional[float] = None,
|
||||
wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT,
|
||||
ssl_context: Optional[SSLContext] = None,
|
||||
tls_allow_invalid_hostnames: bool = False,
|
||||
event_listeners: Optional[_EventListeners] = None,
|
||||
appname: Optional[str] = None,
|
||||
driver: Optional[DriverInfo] = None,
|
||||
compression_settings: Optional[CompressionSettings] = None,
|
||||
max_connecting: int = MAX_CONNECTING,
|
||||
pause_enabled: bool = True,
|
||||
server_api: Optional[ServerApi] = None,
|
||||
load_balanced: Optional[bool] = None,
|
||||
credentials: Optional[MongoCredential] = None,
|
||||
):
|
||||
self.__max_pool_size = max_pool_size
|
||||
self.__min_pool_size = min_pool_size
|
||||
@ -474,12 +515,12 @@ class PoolOptions:
|
||||
_truncate_metadata(self.__metadata)
|
||||
|
||||
@property
|
||||
def _credentials(self):
|
||||
def _credentials(self) -> Optional[MongoCredential]:
|
||||
"""A :class:`~pymongo.auth.MongoCredentials` instance or None."""
|
||||
return self.__credentials
|
||||
|
||||
@property
|
||||
def non_default_options(self):
|
||||
def non_default_options(self) -> Dict[str, Any]:
|
||||
"""The non-default options this pool was created with.
|
||||
|
||||
Added for CMAP's :class:`PoolCreatedEvent`.
|
||||
@ -490,15 +531,17 @@ class PoolOptions:
|
||||
if self.__min_pool_size != MIN_POOL_SIZE:
|
||||
opts["minPoolSize"] = self.__min_pool_size
|
||||
if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC:
|
||||
assert self.__max_idle_time_seconds is not None
|
||||
opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000
|
||||
if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT:
|
||||
assert self.__wait_queue_timeout is not None
|
||||
opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000
|
||||
if self.__max_connecting != MAX_CONNECTING:
|
||||
opts["maxConnecting"] = self.__max_connecting
|
||||
return opts
|
||||
|
||||
@property
|
||||
def max_pool_size(self):
|
||||
def max_pool_size(self) -> float:
|
||||
"""The maximum allowable number of concurrent connections to each
|
||||
connected server. Requests to a server will block if there are
|
||||
`maxPoolSize` outstanding connections to the requested server.
|
||||
@ -513,25 +556,25 @@ class PoolOptions:
|
||||
return self.__max_pool_size
|
||||
|
||||
@property
|
||||
def min_pool_size(self):
|
||||
def min_pool_size(self) -> int:
|
||||
"""The minimum required number of concurrent connections that the pool
|
||||
will maintain to each connected server. Default is 0.
|
||||
"""
|
||||
return self.__min_pool_size
|
||||
|
||||
@property
|
||||
def max_connecting(self):
|
||||
def max_connecting(self) -> int:
|
||||
"""The maximum number of concurrent connection creation attempts per
|
||||
pool. Defaults to 2.
|
||||
"""
|
||||
return self.__max_connecting
|
||||
|
||||
@property
|
||||
def pause_enabled(self):
|
||||
def pause_enabled(self) -> bool:
|
||||
return self.__pause_enabled
|
||||
|
||||
@property
|
||||
def max_idle_time_seconds(self):
|
||||
def max_idle_time_seconds(self) -> Optional[int]:
|
||||
"""The maximum number of seconds that a connection can remain
|
||||
idle in the pool before being removed and replaced. Defaults to
|
||||
`None` (no limit).
|
||||
@ -539,77 +582,77 @@ class PoolOptions:
|
||||
return self.__max_idle_time_seconds
|
||||
|
||||
@property
|
||||
def connect_timeout(self):
|
||||
def connect_timeout(self) -> Optional[float]:
|
||||
"""How long a connection can take to be opened before timing out."""
|
||||
return self.__connect_timeout
|
||||
|
||||
@property
|
||||
def socket_timeout(self):
|
||||
def socket_timeout(self) -> Optional[float]:
|
||||
"""How long a send or receive on a socket can take before timing out."""
|
||||
return self.__socket_timeout
|
||||
|
||||
@property
|
||||
def wait_queue_timeout(self):
|
||||
def wait_queue_timeout(self) -> Optional[int]:
|
||||
"""How long a thread will wait for a socket from the pool if the pool
|
||||
has no free sockets.
|
||||
"""
|
||||
return self.__wait_queue_timeout
|
||||
|
||||
@property
|
||||
def _ssl_context(self):
|
||||
def _ssl_context(self) -> Optional[SSLContext]:
|
||||
"""An SSLContext instance or None."""
|
||||
return self.__ssl_context
|
||||
|
||||
@property
|
||||
def tls_allow_invalid_hostnames(self):
|
||||
def tls_allow_invalid_hostnames(self) -> bool:
|
||||
"""If True skip ssl.match_hostname."""
|
||||
return self.__tls_allow_invalid_hostnames
|
||||
|
||||
@property
|
||||
def _event_listeners(self):
|
||||
def _event_listeners(self) -> Optional[_EventListeners]:
|
||||
"""An instance of pymongo.monitoring._EventListeners."""
|
||||
return self.__event_listeners
|
||||
|
||||
@property
|
||||
def appname(self):
|
||||
def appname(self) -> Optional[str]:
|
||||
"""The application name, for sending with hello in server handshake."""
|
||||
return self.__appname
|
||||
|
||||
@property
|
||||
def driver(self):
|
||||
def driver(self) -> Optional[DriverInfo]:
|
||||
"""Driver name and version, for sending with hello in handshake."""
|
||||
return self.__driver
|
||||
|
||||
@property
|
||||
def _compression_settings(self):
|
||||
def _compression_settings(self) -> Optional[CompressionSettings]:
|
||||
return self.__compression_settings
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
def metadata(self) -> SON[str, Any]:
|
||||
"""A dict of metadata about the application, driver, os, and platform."""
|
||||
return self.__metadata.copy()
|
||||
|
||||
@property
|
||||
def server_api(self):
|
||||
def server_api(self) -> Optional[ServerApi]:
|
||||
"""A pymongo.server_api.ServerApi or None."""
|
||||
return self.__server_api
|
||||
|
||||
@property
|
||||
def load_balanced(self):
|
||||
def load_balanced(self) -> Optional[bool]:
|
||||
"""True if this Pool is configured in load balanced mode."""
|
||||
return self.__load_balanced
|
||||
|
||||
|
||||
class _CancellationContext:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._cancelled = False
|
||||
|
||||
def cancel(self):
|
||||
def cancel(self) -> None:
|
||||
"""Cancel this context."""
|
||||
self._cancelled = True
|
||||
|
||||
@property
|
||||
def cancelled(self):
|
||||
def cancelled(self) -> bool:
|
||||
"""Was cancel called?"""
|
||||
return self._cancelled
|
||||
|
||||
@ -624,47 +667,48 @@ class Connection:
|
||||
- `id`: the id of this socket in it's pool
|
||||
"""
|
||||
|
||||
def __init__(self, conn, pool, address, id):
|
||||
def __init__(
|
||||
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: Tuple[str, int], id: int
|
||||
):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.conn = conn
|
||||
self.address = address
|
||||
self.id = id
|
||||
self.authed = set()
|
||||
self.closed = False
|
||||
self.last_checkin_time = time.monotonic()
|
||||
self.performed_handshake = False
|
||||
self.is_writable = False
|
||||
self.is_writable: bool = False
|
||||
self.max_wire_version = MAX_WIRE_VERSION
|
||||
self.max_bson_size = MAX_BSON_SIZE
|
||||
self.max_message_size = MAX_MESSAGE_SIZE
|
||||
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
|
||||
self.supports_sessions = False
|
||||
self.hello_ok = None
|
||||
self.hello_ok: bool = False
|
||||
self.is_mongos = False
|
||||
self.op_msg_enabled = False
|
||||
self.listeners = pool.opts._event_listeners
|
||||
self.enabled_for_cmap = pool.enabled_for_cmap
|
||||
self.compression_settings = pool.opts._compression_settings
|
||||
self.compression_context = None
|
||||
self.socket_checker = SocketChecker()
|
||||
self.oidc_token_gen_id = None
|
||||
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
|
||||
self.socket_checker: SocketChecker = SocketChecker()
|
||||
self.oidc_token_gen_id: Optional[int] = None
|
||||
# Support for mechanism negotiation on the initial handshake.
|
||||
self.negotiated_mechs = None
|
||||
self.auth_ctx = None
|
||||
self.negotiated_mechs: Optional[List[str]] = None
|
||||
self.auth_ctx: Optional[_AuthContext] = None
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
self.pool_gen = pool.gen
|
||||
self.generation = self.pool_gen.get_overall()
|
||||
self.ready = False
|
||||
self.cancel_context = None
|
||||
self.cancel_context: Optional[_CancellationContext] = None
|
||||
if not pool.handshake:
|
||||
# This is a Monitor connection.
|
||||
self.cancel_context = _CancellationContext()
|
||||
self.opts = pool.opts
|
||||
self.more_to_come = False
|
||||
self.more_to_come: bool = False
|
||||
# For load balancer support.
|
||||
self.service_id = None
|
||||
self.service_id: Optional[ObjectId] = None
|
||||
# When executing a transaction in load balancing mode, this flag is
|
||||
# set to true to indicate that the session now owns the connection.
|
||||
self.pinned_txn = False
|
||||
@ -673,14 +717,16 @@ class Connection:
|
||||
self.last_timeout = self.opts.socket_timeout
|
||||
self.connect_rtt = 0.0
|
||||
|
||||
def set_conn_timeout(self, timeout):
|
||||
def set_conn_timeout(self, timeout: Optional[float]) -> None:
|
||||
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
|
||||
if timeout == self.last_timeout:
|
||||
return
|
||||
self.last_timeout = timeout
|
||||
self.conn.settimeout(timeout)
|
||||
|
||||
def apply_timeout(self, client, cmd):
|
||||
def apply_timeout(
|
||||
self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]]
|
||||
) -> Optional[float]:
|
||||
# CSOT: use remaining timeout when set.
|
||||
timeout = _csot.remaining()
|
||||
if timeout is None:
|
||||
@ -704,22 +750,22 @@ class Connection:
|
||||
self.set_conn_timeout(timeout)
|
||||
return timeout
|
||||
|
||||
def pin_txn(self):
|
||||
def pin_txn(self) -> None:
|
||||
self.pinned_txn = True
|
||||
assert not self.pinned_cursor
|
||||
|
||||
def pin_cursor(self):
|
||||
def pin_cursor(self) -> None:
|
||||
self.pinned_cursor = True
|
||||
assert not self.pinned_txn
|
||||
|
||||
def unpin(self):
|
||||
def unpin(self) -> None:
|
||||
pool = self.pool_ref()
|
||||
if pool:
|
||||
pool.checkin(self)
|
||||
else:
|
||||
self.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
def hello_cmd(self):
|
||||
def hello_cmd(self) -> SON[str, Any]:
|
||||
# Handshake spec requires us to use OP_MSG+hello command for the
|
||||
# initial handshake in load balanced or stable API mode.
|
||||
if self.opts.server_api or self.hello_ok or self.opts.load_balanced:
|
||||
@ -728,10 +774,15 @@ class Connection:
|
||||
else:
|
||||
return SON([(HelloCompat.LEGACY_CMD, 1), ("helloOk", True)])
|
||||
|
||||
def hello(self):
|
||||
def hello(self) -> Hello[Dict[str, Any]]:
|
||||
return self._hello(None, None, None)
|
||||
|
||||
def _hello(self, cluster_time, topology_version, heartbeat_frequency):
|
||||
def _hello(
|
||||
self,
|
||||
cluster_time: Optional[Mapping[str, Any]],
|
||||
topology_version: Optional[Any],
|
||||
heartbeat_frequency: Optional[int],
|
||||
) -> Hello[Dict[str, Any]]:
|
||||
cmd = self.hello_cmd()
|
||||
performing_handshake = not self.performed_handshake
|
||||
awaitable = False
|
||||
@ -744,6 +795,7 @@ class Connection:
|
||||
cmd["loadBalanced"] = True
|
||||
elif topology_version is not None:
|
||||
cmd["topologyVersion"] = topology_version
|
||||
assert heartbeat_frequency is not None
|
||||
cmd["maxAwaitTimeMS"] = int(heartbeat_frequency * 1000)
|
||||
awaitable = True
|
||||
# If connect_timeout is None there is no timeout.
|
||||
@ -808,7 +860,7 @@ class Connection:
|
||||
self.generation = self.pool_gen.get(self.service_id)
|
||||
return hello
|
||||
|
||||
def _next_reply(self):
|
||||
def _next_reply(self) -> Dict[str, Any]:
|
||||
reply = self.receive_message(None)
|
||||
self.more_to_come = reply.more_to_come
|
||||
unpacked_docs = reply.unpack_response()
|
||||
@ -819,23 +871,23 @@ class Connection:
|
||||
@_handle_reauth
|
||||
def command(
|
||||
self,
|
||||
dbname,
|
||||
spec,
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
codec_options=DEFAULT_CODEC_OPTIONS,
|
||||
check=True,
|
||||
allowable_errors=None,
|
||||
read_concern=None,
|
||||
write_concern=None,
|
||||
parse_write_concern_error=False,
|
||||
collation=None,
|
||||
session=None,
|
||||
client=None,
|
||||
retryable_write=False,
|
||||
publish_events=True,
|
||||
user_fields=None,
|
||||
exhaust_allowed=False,
|
||||
):
|
||||
dbname: str,
|
||||
spec: MutableMapping[str, Any],
|
||||
read_preference: _ServerMode = ReadPreference.PRIMARY,
|
||||
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
|
||||
check: bool = True,
|
||||
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
|
||||
read_concern: Optional[ReadConcern] = None,
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
parse_write_concern_error: bool = False,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
client: Optional[MongoClient] = None,
|
||||
retryable_write: bool = False,
|
||||
publish_events: bool = True,
|
||||
user_fields: Optional[Mapping[str, Any]] = None,
|
||||
exhaust_allowed: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a command or raise an error.
|
||||
|
||||
:Parameters:
|
||||
@ -873,7 +925,7 @@ class Connection:
|
||||
session._apply_to(spec, retryable_write, read_preference, self)
|
||||
self.send_cluster_time(spec, session, client)
|
||||
listeners = self.listeners if publish_events else None
|
||||
unacknowledged = write_concern and not write_concern.acknowledged
|
||||
unacknowledged = bool(write_concern and not write_concern.acknowledged)
|
||||
if self.op_msg_enabled:
|
||||
self._raise_if_not_writable(unacknowledged)
|
||||
try:
|
||||
@ -907,7 +959,7 @@ class Connection:
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
def send_message(self, message, max_doc_size):
|
||||
def send_message(self, message: bytes, max_doc_size: int) -> None:
|
||||
"""Send a raw BSON message or raise ConnectionFailure.
|
||||
|
||||
If a network exception is raised, the socket is closed.
|
||||
@ -923,7 +975,7 @@ class Connection:
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
def receive_message(self, request_id):
|
||||
def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise ConnectionFailure.
|
||||
|
||||
If any exception is raised, the socket is closed.
|
||||
@ -933,7 +985,7 @@ class Connection:
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
|
||||
def _raise_if_not_writable(self, unacknowledged):
|
||||
def _raise_if_not_writable(self, unacknowledged: bool) -> None:
|
||||
"""Raise NotPrimaryError on unacknowledged write if this socket is not
|
||||
writable.
|
||||
"""
|
||||
@ -941,7 +993,7 @@ class Connection:
|
||||
# Write won't succeed, bail as if we'd received a not primary error.
|
||||
raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107})
|
||||
|
||||
def unack_write(self, msg, max_doc_size):
|
||||
def unack_write(self, msg: bytes, max_doc_size: int) -> None:
|
||||
"""Send unack OP_MSG.
|
||||
|
||||
Can raise ConnectionFailure or InvalidDocument.
|
||||
@ -953,7 +1005,9 @@ class Connection:
|
||||
self._raise_if_not_writable(True)
|
||||
self.send_message(msg, max_doc_size)
|
||||
|
||||
def write_command(self, request_id, msg, codec_options):
|
||||
def write_command(
|
||||
self, request_id: int, msg: bytes, codec_options: CodecOptions
|
||||
) -> Dict[str, Any]:
|
||||
"""Send "insert" etc. command, returning response as a dict.
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
@ -970,7 +1024,7 @@ class Connection:
|
||||
helpers._check_command_response(result, self.max_wire_version)
|
||||
return result
|
||||
|
||||
def authenticate(self, reauthenticate=False):
|
||||
def authenticate(self, reauthenticate: bool = False) -> None:
|
||||
"""Authenticate to the server if needed.
|
||||
|
||||
Can raise ConnectionFailure or OperationFailure.
|
||||
@ -988,9 +1042,12 @@ class Connection:
|
||||
auth.authenticate(creds, self, reauthenticate=reauthenticate)
|
||||
self.ready = True
|
||||
if self.enabled_for_cmap:
|
||||
assert self.listeners is not None
|
||||
self.listeners.publish_connection_ready(self.address, self.id)
|
||||
|
||||
def validate_session(self, client, session):
|
||||
def validate_session(
|
||||
self, client: Optional[MongoClient], session: Optional[ClientSession]
|
||||
) -> None:
|
||||
"""Validate this session before use with client.
|
||||
|
||||
Raises error if the client is not the one that created the session.
|
||||
@ -999,15 +1056,16 @@ class Connection:
|
||||
if session._client is not client:
|
||||
raise InvalidOperation("Can only use session with the MongoClient that started it")
|
||||
|
||||
def close_conn(self, reason):
|
||||
def close_conn(self, reason: Optional[str]) -> None:
|
||||
"""Close this connection with a reason."""
|
||||
if self.closed:
|
||||
return
|
||||
self._close_conn()
|
||||
if reason and self.enabled_for_cmap:
|
||||
assert self.listeners is not None
|
||||
self.listeners.publish_connection_closed(self.address, self.id, reason)
|
||||
|
||||
def _close_conn(self):
|
||||
def _close_conn(self) -> None:
|
||||
"""Close this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
@ -1021,31 +1079,36 @@ class Connection:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def conn_closed(self):
|
||||
def conn_closed(self) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
return self.socket_checker.socket_closed(self.conn)
|
||||
|
||||
def send_cluster_time(self, command, session, client):
|
||||
def send_cluster_time(
|
||||
self,
|
||||
command: MutableMapping[str, Any],
|
||||
session: Optional[ClientSession],
|
||||
client: Optional[MongoClient],
|
||||
) -> None:
|
||||
"""Add $clusterTime."""
|
||||
if client:
|
||||
client._send_cluster_time(command, session)
|
||||
|
||||
def add_server_api(self, command):
|
||||
def add_server_api(self, command: MutableMapping[str, Any]) -> None:
|
||||
"""Add server_api parameters."""
|
||||
if self.opts.server_api:
|
||||
_add_to_command(command, self.opts.server_api)
|
||||
|
||||
def update_last_checkin_time(self):
|
||||
def update_last_checkin_time(self) -> None:
|
||||
self.last_checkin_time = time.monotonic()
|
||||
|
||||
def update_is_writable(self, is_writable):
|
||||
def update_is_writable(self, is_writable: bool) -> None:
|
||||
self.is_writable = is_writable
|
||||
|
||||
def idle_time_seconds(self):
|
||||
def idle_time_seconds(self) -> float:
|
||||
"""Seconds since this socket was last checked into its pool."""
|
||||
return time.monotonic() - self.last_checkin_time
|
||||
|
||||
def _raise_connection_failure(self, error):
|
||||
def _raise_connection_failure(self, error: BaseException) -> NoReturn:
|
||||
# Catch *all* exceptions from socket methods and close the socket. In
|
||||
# regular Python, socket operations only raise socket.error, even if
|
||||
# the underlying cause was a Ctrl-C: a signal raised during socket.recv
|
||||
@ -1072,16 +1135,16 @@ class Connection:
|
||||
else:
|
||||
raise
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return self.conn == other.conn
|
||||
|
||||
def __ne__(self, other):
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.conn)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "Connection({}){} at {}".format(
|
||||
repr(self.conn),
|
||||
self.closed and " CLOSED" or "",
|
||||
@ -1089,7 +1152,7 @@ class Connection:
|
||||
)
|
||||
|
||||
|
||||
def _create_connection(address, options):
|
||||
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
"""Given (host, port) and PoolOptions, connect and return a socket object.
|
||||
|
||||
Can raise socket.error.
|
||||
@ -1160,7 +1223,7 @@ def _create_connection(address, options):
|
||||
raise OSError("getaddrinfo failed")
|
||||
|
||||
|
||||
def _configured_socket(address, options):
|
||||
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
|
||||
"""Given (host, port) and PoolOptions, return a configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
@ -1170,39 +1233,42 @@ def _configured_socket(address, options):
|
||||
sock = _create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is not None:
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
sock = ssl_context.wrap_socket(sock, server_hostname=host)
|
||||
else:
|
||||
sock = ssl_context.wrap_socket(sock)
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc: # noqa: B014
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ")
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(sock.getpeercert(), hostname=host)
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
raise
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
|
||||
else:
|
||||
ssl_sock = ssl_context.wrap_socket(sock)
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc: # noqa: B014
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ")
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
class _PoolClosedError(PyMongoError):
|
||||
@ -1212,23 +1278,23 @@ class _PoolClosedError(PyMongoError):
|
||||
|
||||
|
||||
class _PoolGeneration:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# Maps service_id to generation.
|
||||
self._generations = collections.defaultdict(int)
|
||||
self._generations: Dict[ObjectId, int] = collections.defaultdict(int)
|
||||
# Overall pool generation.
|
||||
self._generation = 0
|
||||
|
||||
def get(self, service_id):
|
||||
def get(self, service_id: Optional[ObjectId]) -> int:
|
||||
"""Get the generation for the given service_id."""
|
||||
if service_id is None:
|
||||
return self._generation
|
||||
return self._generations[service_id]
|
||||
|
||||
def get_overall(self):
|
||||
def get_overall(self) -> int:
|
||||
"""Get the Pool's overall generation."""
|
||||
return self._generation
|
||||
|
||||
def inc(self, service_id):
|
||||
def inc(self, service_id: Optional[ObjectId]) -> None:
|
||||
"""Increment the generation for the given service_id."""
|
||||
self._generation += 1
|
||||
if service_id is None:
|
||||
@ -1237,7 +1303,7 @@ class _PoolGeneration:
|
||||
else:
|
||||
self._generations[service_id] += 1
|
||||
|
||||
def stale(self, gen, service_id):
|
||||
def stale(self, gen: int, service_id: Optional[ObjectId]) -> bool:
|
||||
"""Return if the given generation for a given service_id is stale."""
|
||||
return gen != self.get(service_id)
|
||||
|
||||
@ -1251,7 +1317,7 @@ class PoolState:
|
||||
# Do *not* explicitly inherit from object or Jython won't call __del__
|
||||
# http://bugs.jython.org/issue1057
|
||||
class Pool:
|
||||
def __init__(self, address, options, handshake=True):
|
||||
def __init__(self, address: _Address, options: PoolOptions, handshake: bool = True):
|
||||
"""
|
||||
:Parameters:
|
||||
- `address`: a (hostname, port) tuple
|
||||
@ -1274,7 +1340,7 @@ class Pool:
|
||||
# Monotonically increasing connection ID required for CMAP Events.
|
||||
self.next_connection_id = 1
|
||||
# Track whether the sockets in this pool are writeable or not.
|
||||
self.is_writable = None
|
||||
self.is_writable: Optional[bool] = None
|
||||
|
||||
# Keep track of resets, so we notice sockets created before the most
|
||||
# recent reset and close them.
|
||||
@ -1306,6 +1372,7 @@ class Pool:
|
||||
self._max_connecting = self.opts.max_connecting
|
||||
self._pending = 0
|
||||
if self.enabled_for_cmap:
|
||||
assert self.opts._event_listeners is not None
|
||||
self.opts._event_listeners.publish_pool_created(
|
||||
self.address, self.opts.non_default_options
|
||||
)
|
||||
@ -1314,23 +1381,26 @@ class Pool:
|
||||
# Retain references to pinned connections to prevent the CPython GC
|
||||
# from thinking that a cursor's pinned connection can be GC'd when the
|
||||
# cursor is GC'd (see PYTHON-2751).
|
||||
self.__pinned_sockets = set()
|
||||
self.__pinned_sockets: Set[Connection] = set()
|
||||
self.ncursors = 0
|
||||
self.ntxns = 0
|
||||
|
||||
def ready(self):
|
||||
def ready(self) -> None:
|
||||
# Take the lock to avoid the race condition described in PYTHON-2699.
|
||||
with self.lock:
|
||||
if self.state != PoolState.READY:
|
||||
self.state = PoolState.READY
|
||||
if self.enabled_for_cmap:
|
||||
assert self.opts._event_listeners is not None
|
||||
self.opts._event_listeners.publish_pool_ready(self.address)
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
def closed(self) -> bool:
|
||||
return self.state == PoolState.CLOSED
|
||||
|
||||
def _reset(self, close, pause=True, service_id=None):
|
||||
def _reset(
|
||||
self, close: bool, pause: bool = True, service_id: Optional[ObjectId] = None
|
||||
) -> None:
|
||||
old_state = self.state
|
||||
with self.size_cond:
|
||||
if self.closed:
|
||||
@ -1370,14 +1440,16 @@ class Pool:
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_pool_closed(self.address)
|
||||
else:
|
||||
if old_state != PoolState.PAUSED and self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_pool_cleared(self.address, service_id=service_id)
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
def update_is_writable(self, is_writable):
|
||||
def update_is_writable(self, is_writable: Optional[bool]) -> None:
|
||||
"""Updates the is_writable attribute on all sockets currently in the
|
||||
Pool.
|
||||
"""
|
||||
@ -1386,19 +1458,19 @@ class Pool:
|
||||
for _socket in self.conns:
|
||||
_socket.update_is_writable(self.is_writable)
|
||||
|
||||
def reset(self, service_id=None):
|
||||
def reset(self, service_id: Optional[ObjectId] = None) -> None:
|
||||
self._reset(close=False, service_id=service_id)
|
||||
|
||||
def reset_without_pause(self):
|
||||
def reset_without_pause(self) -> None:
|
||||
self._reset(close=False, pause=False)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self._reset(close=True)
|
||||
|
||||
def stale_generation(self, gen, service_id):
|
||||
def stale_generation(self, gen: int, service_id: Optional[ObjectId]) -> bool:
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
def remove_stale_sockets(self, reference_generation):
|
||||
def remove_stale_sockets(self, reference_generation: int) -> None:
|
||||
"""Removes stale sockets then adds new ones if pool is too small and
|
||||
has not been reset. The `reference_generation` argument specifies the
|
||||
`generation` at the point in time this operation was requested on the
|
||||
@ -1454,7 +1526,7 @@ class Pool:
|
||||
self.requests -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
def connect(self, handler=None):
|
||||
def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
|
||||
"""Connect to Mongo and return a new Connection.
|
||||
|
||||
Can raise ConnectionFailure.
|
||||
@ -1468,12 +1540,14 @@ class Pool:
|
||||
|
||||
listeners = self.opts._event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_connection_created(self.address, conn_id)
|
||||
|
||||
try:
|
||||
sock = _configured_socket(self.address, self.opts)
|
||||
except BaseException as error:
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_connection_closed(
|
||||
self.address, conn_id, ConnectionClosedReason.ERROR
|
||||
)
|
||||
@ -1483,7 +1557,7 @@ class Pool:
|
||||
|
||||
raise
|
||||
|
||||
conn = Connection(sock, self, self.address, conn_id)
|
||||
conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type]
|
||||
try:
|
||||
if self.handshake:
|
||||
conn.hello()
|
||||
@ -1499,7 +1573,7 @@ class Pool:
|
||||
return conn
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkout(self, handler=None):
|
||||
def checkout(self, handler: Optional[_MongoClientErrorHandler] = None) -> Iterator[Connection]:
|
||||
"""Get a connection from the pool. Use with a "with" statement.
|
||||
|
||||
Returns a :class:`Connection` object wrapping a connected
|
||||
@ -1518,11 +1592,13 @@ class Pool:
|
||||
"""
|
||||
listeners = self.opts._event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_connection_check_out_started(self.address)
|
||||
|
||||
conn = self._get_conn(handler=handler)
|
||||
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_connection_checked_out(self.address, conn.id)
|
||||
try:
|
||||
yield conn
|
||||
@ -1551,15 +1627,16 @@ class Pool:
|
||||
elif conn.active:
|
||||
self.checkin(conn)
|
||||
|
||||
def _raise_if_not_ready(self, emit_event):
|
||||
def _raise_if_not_ready(self, emit_event: bool) -> None:
|
||||
if self.state != PoolState.READY:
|
||||
if self.enabled_for_cmap and emit_event:
|
||||
assert self.opts._event_listeners is not None
|
||||
self.opts._event_listeners.publish_connection_check_out_failed(
|
||||
self.address, ConnectionCheckOutFailedReason.CONN_ERROR
|
||||
)
|
||||
_raise_connection_failure(self.address, AutoReconnect("connection pool paused"))
|
||||
|
||||
def _get_conn(self, handler=None):
|
||||
def _get_conn(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection:
|
||||
"""Get or create a Connection. Can raise ConnectionFailure."""
|
||||
# We use the pid here to avoid issues with fork / multiprocessing.
|
||||
# See test.test_client:TestClient.test_fork for an example of
|
||||
@ -1569,6 +1646,7 @@ class Pool:
|
||||
|
||||
if self.closed:
|
||||
if self.enabled_for_cmap:
|
||||
assert self.opts._event_listeners is not None
|
||||
self.opts._event_listeners.publish_connection_check_out_failed(
|
||||
self.address, ConnectionCheckOutFailedReason.POOL_CLOSED
|
||||
)
|
||||
@ -1649,6 +1727,7 @@ class Pool:
|
||||
self.size_cond.notify()
|
||||
|
||||
if self.enabled_for_cmap and not emitted_event:
|
||||
assert self.opts._event_listeners is not None
|
||||
self.opts._event_listeners.publish_connection_check_out_failed(
|
||||
self.address, ConnectionCheckOutFailedReason.CONN_ERROR
|
||||
)
|
||||
@ -1657,7 +1736,7 @@ class Pool:
|
||||
conn.active = True
|
||||
return conn
|
||||
|
||||
def checkin(self, conn):
|
||||
def checkin(self, conn: Connection) -> None:
|
||||
"""Return the connection to the pool, or if it's closed discard it.
|
||||
|
||||
:Parameters:
|
||||
@ -1671,6 +1750,7 @@ class Pool:
|
||||
self.__pinned_sockets.discard(conn)
|
||||
listeners = self.opts._event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_connection_checked_in(self.address, conn.id)
|
||||
if self.pid != os.getpid():
|
||||
self.reset_without_pause()
|
||||
@ -1680,6 +1760,7 @@ class Pool:
|
||||
elif conn.closed:
|
||||
# CMAP requires the closed event be emitted after the check in.
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_connection_closed(
|
||||
self.address, conn.id, ConnectionClosedReason.ERROR
|
||||
)
|
||||
@ -1691,7 +1772,7 @@ class Pool:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
else:
|
||||
conn.update_last_checkin_time()
|
||||
conn.update_is_writable(self.is_writable)
|
||||
conn.update_is_writable(bool(self.is_writable))
|
||||
self.conns.appendleft(conn)
|
||||
# Notify any threads waiting to create a connection.
|
||||
self._max_connecting_cond.notify()
|
||||
@ -1706,7 +1787,7 @@ class Pool:
|
||||
self.operation_count -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
def _perished(self, conn):
|
||||
def _perished(self, conn: Connection) -> bool:
|
||||
"""Return True and close the connection if it is "perished".
|
||||
|
||||
This side-effecty function checks if this socket has been idle for
|
||||
@ -1745,6 +1826,7 @@ class Pool:
|
||||
def _raise_wait_queue_timeout(self) -> NoReturn:
|
||||
listeners = self.opts._event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_connection_check_out_failed(
|
||||
self.address, ConnectionCheckOutFailedReason.TIMEOUT
|
||||
)
|
||||
@ -1768,7 +1850,7 @@ class Pool:
|
||||
"maxPoolSize: {}, timeout: {}".format(self.opts.max_pool_size, timeout)
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
# Avoid ResourceWarnings in Python 3
|
||||
# Close all sockets without calling reset() or close() because it is
|
||||
# not safe to acquire a lock in __del__.
|
||||
|
||||
@ -109,7 +109,7 @@ class Server:
|
||||
conn: Connection,
|
||||
operation: Union[_Query, _GetMore],
|
||||
read_preference: _ServerMode,
|
||||
listeners: _EventListeners,
|
||||
listeners: Optional[_EventListeners],
|
||||
unpack_res: Callable[..., List[_DocumentOut]],
|
||||
) -> Response:
|
||||
"""Run a _Query or _GetMore operation and return a Response object.
|
||||
@ -126,6 +126,7 @@ class Server:
|
||||
- `unpack_res`: A callable that decodes the wire protocol response.
|
||||
"""
|
||||
duration = None
|
||||
assert listeners is not None
|
||||
publish = listeners.enabled_for_commands
|
||||
if publish:
|
||||
start = datetime.now()
|
||||
@ -140,6 +141,7 @@ class Server:
|
||||
|
||||
if publish:
|
||||
cmd, dbn = operation.as_command(conn)
|
||||
assert listeners is not None
|
||||
listeners.publish_command_start(
|
||||
cmd, dbn, request_id, conn.address, service_id=conn.service_id
|
||||
)
|
||||
@ -177,6 +179,7 @@ class Server:
|
||||
failure: _DocumentOut = exc.details # type: ignore[assignment]
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
assert listeners is not None
|
||||
listeners.publish_command_failure(
|
||||
duration,
|
||||
failure,
|
||||
@ -196,11 +199,12 @@ class Server:
|
||||
elif operation.name == "explain":
|
||||
res = docs[0] if docs else {}
|
||||
else:
|
||||
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1}
|
||||
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr]
|
||||
if operation.name == "find":
|
||||
res["cursor"]["firstBatch"] = docs
|
||||
else:
|
||||
res["cursor"]["nextBatch"] = docs
|
||||
assert listeners is not None
|
||||
listeners.publish_command_success(
|
||||
duration,
|
||||
res,
|
||||
|
||||
@ -1084,10 +1084,10 @@ class TestCommandMonitoring(IntegrationTest):
|
||||
|
||||
self.listener.reset()
|
||||
cmd = SON([("getnonce", 1)])
|
||||
listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address)
|
||||
listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address) # type: ignore[arg-type]
|
||||
delta = datetime.timedelta(milliseconds=100)
|
||||
listeners.publish_command_success(
|
||||
delta, {"nonce": "e474f4561c5eb40b", "ok": 1.0}, "getnonce", 12345, self.client.address
|
||||
delta, {"nonce": "e474f4561c5eb40b", "ok": 1.0}, "getnonce", 12345, self.client.address # type: ignore[arg-type]
|
||||
)
|
||||
started = self.listener.started_events[0]
|
||||
succeeded = self.listener.succeeded_events[0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user