PYTHON-4493 - Use asyncio protocols instead of sockets for network IO (#2151)

Co-authored-by: Shane Harvey <shnhrv@gmail.com>
This commit is contained in:
Noah Stapp 2025-03-28 15:02:40 -04:00 committed by GitHub
parent f3ca1e0372
commit e51ad27d20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 1129 additions and 803 deletions

View File

@ -64,11 +64,6 @@ from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@ -80,12 +75,17 @@ from pymongo.errors import (
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
from pymongo.network_layer import async_socket_sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_async_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser_shared import parse_host
from pymongo.write_concern import WriteConcern
@ -113,7 +113,7 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return await _configured_socket(address, opts)
return await _async_configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
@ -196,7 +196,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
try:
conn = await _connect_kms(address, opts)
try:
await async_sendall(conn, message)
await async_socket_sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))

View File

@ -2079,7 +2079,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# exhausted the result set we *must* close the socket
# to stop the server from sending more data.
assert conn_mgr.conn is not None
conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
await conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
else:
await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
if conn_mgr:

View File

@ -36,7 +36,11 @@ from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
from pymongo.asynchronous.pool import ( # type: ignore[attr-defined]
AsyncConnection,
Pool,
_CancellationContext,
)
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology

View File

@ -17,7 +17,6 @@ from __future__ import annotations
import datetime
import logging
import time
from typing import (
TYPE_CHECKING,
Any,
@ -31,20 +30,16 @@ from typing import (
from bson import _decode_all_selective
from pymongo import _csot, helpers_shared, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import _NO_COMPRESSION, decompress
from pymongo.compression_support import _NO_COMPRESSION
from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.message import _OpMsg
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
async_receive_data,
async_receive_message,
async_sendall,
)
@ -194,13 +189,13 @@ async def command(
)
try:
await async_sendall(conn.conn, msg)
await async_sendall(conn.conn.get_conn, msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = await receive_message(conn, request_id)
reply = await async_receive_message(conn, request_id)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
@ -301,47 +296,3 @@ async def command(
)
return response_doc # type: ignore[return-value]
async def receive_message(
conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
await async_receive_data(conn, 9, deadline)
)
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
else:
data = await async_receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)

View File

@ -14,14 +14,10 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import functools
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -40,8 +36,8 @@ from typing import (
from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
from pymongo.asynchronous.network import command, receive_message
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.network import command
from pymongo.common import (
MAX_BSON_SIZE,
MAX_MESSAGE_SIZE,
@ -52,16 +48,13 @@ from pymongo.common import (
from pymongo.errors import ( # type:ignore[attr-defined]
AutoReconnect,
ConfigurationError,
ConnectionFailure,
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import (
@ -79,13 +72,20 @@ from pymongo.monitoring import (
ConnectionCheckOutFailedReason,
ConnectionClosedReason,
)
from pymongo.network_layer import async_sendall
from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_CancellationContext,
_configured_protocol_interface,
_get_timeout_details,
_raise_connection_failure,
format_timeout_details,
)
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI, SSLError
from pymongo.ssl_support import SSLError
if TYPE_CHECKING:
from bson import CodecOptions
@ -99,7 +99,6 @@ if TYPE_CHECKING:
ZstdContext,
)
from pymongo.message import _OpMsg, _OpReply
from pymongo.pyopenssl_context import _sslConn
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _CollationIn
@ -123,133 +122,6 @@ except ImportError:
_IS_SYNC = False
_MAX_TCP_KEEPIDLE = 120
_MAX_TCP_KEEPINTVL = 10
_MAX_TCP_KEEPCNT = 9
if sys.platform == "win32":
try:
import _winreg as winreg
except ImportError:
import winreg
def _query(key, name, default):
try:
value, _ = winreg.QueryValueEx(key, name)
# Ensure the value is a number or raise ValueError.
return int(value)
except (OSError, ValueError):
# QueryValueEx raises OSError when the key does not exist (i.e.
# the system is using the Windows default value).
return default
try:
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
) as key:
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
except OSError:
# We could not check the default values because winreg.OpenKey failed.
# Assume the system is using the default values.
_WINDOWS_TCP_IDLE_MS = 7200000
_WINDOWS_TCP_INTERVAL_MS = 1000
def _set_keepalive_times(sock):
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
else:
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
if hasattr(socket, tcp_option):
sockopt = getattr(socket, tcp_option)
try:
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
# TCP_KEEPIDLE and friends. Don't attempt to set the
# values there.
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
if default > max_value:
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
except OSError:
pass
def _set_keepalive_times(sock: socket.socket) -> None:
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
def _raise_connection_failure(
address: Any,
error: Exception,
msg_prefix: Optional[str] = None,
timeout_details: Optional[dict[str, float]] = None,
) -> NoReturn:
"""Convert a socket.error to ConnectionFailure and raise it."""
host, port = address
# If connecting to a Unix socket, port will be None.
if port is not None:
msg = "%s:%d: %s" % (host, port, error)
else:
msg = f"{host}: {error}"
if msg_prefix:
msg = msg_prefix + msg
if "configured timeouts" not in msg:
msg += format_timeout_details(timeout_details)
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, SSLError) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
# 'timed out' appears in all the timeout related SSLErrors raised.
raise NetworkTimeout(msg) from error
else:
raise AutoReconnect(msg) from error
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
socket_timeout = options.socket_timeout
connect_timeout = options.connect_timeout
if timeout:
details["timeoutMS"] = timeout * 1000
if socket_timeout and not timeout:
details["socketTimeoutMS"] = socket_timeout * 1000
if connect_timeout:
details["connectTimeoutMS"] = connect_timeout * 1000
return details
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
result = ""
if details:
result += " (configured timeouts:"
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
if timeout in details:
result += f" {timeout}: {details[timeout]}ms,"
result = result[:-1]
result += ")"
return result
class _CancellationContext:
def __init__(self) -> None:
self._cancelled = False
def cancel(self) -> None:
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self) -> bool:
"""Was cancel called?"""
return self._cancelled
class AsyncConnection:
"""Store a connection with some metadata.
@ -261,7 +133,11 @@ class AsyncConnection:
"""
def __init__(
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
self,
conn: AsyncNetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
):
self.pool_ref = weakref.ref(pool)
self.conn = conn
@ -318,7 +194,7 @@ class AsyncConnection:
if timeout == self.last_timeout:
return
self.last_timeout = timeout
self.conn.settimeout(timeout)
self.conn.get_conn.settimeout(timeout)
def apply_timeout(
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
@ -364,7 +240,7 @@ class AsyncConnection:
if pool:
await pool.checkin(self)
else:
self.close_conn(ConnectionClosedReason.STALE)
await self.close_conn(ConnectionClosedReason.STALE)
def hello_cmd(self) -> dict[str, Any]:
# Handshake spec requires us to use OP_MSG+hello command for the
@ -559,7 +435,7 @@ class AsyncConnection:
raise
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
except BaseException as error:
self._raise_connection_failure(error)
await self._raise_connection_failure(error)
async def send_message(self, message: bytes, max_doc_size: int) -> None:
"""Send a raw BSON message or raise ConnectionFailure.
@ -573,10 +449,10 @@ class AsyncConnection:
)
try:
await async_sendall(self.conn, message)
await async_sendall(self.conn.get_conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
await self._raise_connection_failure(error)
async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise ConnectionFailure.
@ -584,10 +460,10 @@ class AsyncConnection:
If any exception is raised, the socket is closed.
"""
try:
return await receive_message(self, request_id, self.max_message_size)
return await async_receive_message(self, request_id, self.max_message_size)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
await self._raise_connection_failure(error)
def _raise_if_not_writable(self, unacknowledged: bool) -> None:
"""Raise NotPrimaryError on unacknowledged write if this socket is not
@ -673,11 +549,11 @@ class AsyncConnection:
"Can only use session with the AsyncMongoClient that started it"
)
def close_conn(self, reason: Optional[str]) -> None:
async def close_conn(self, reason: Optional[str]) -> None:
"""Close this connection with a reason."""
if self.closed:
return
self._close_conn()
await self._close_conn()
if reason:
if self.enabled_for_cmap:
assert self.listeners is not None
@ -694,7 +570,7 @@ class AsyncConnection:
error=reason,
)
def _close_conn(self) -> None:
async def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
@ -703,13 +579,16 @@ class AsyncConnection:
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
self.conn.close()
await self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.conn)
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
def send_cluster_time(
self,
@ -736,7 +615,7 @@ class AsyncConnection:
"""Seconds since this socket was last checked into its pool."""
return time.monotonic() - self.last_checkin_time
def _raise_connection_failure(self, error: BaseException) -> NoReturn:
async def _raise_connection_failure(self, error: BaseException) -> NoReturn:
# Catch *all* exceptions from socket methods and close the socket. In
# regular Python, socket operations only raise socket.error, even if
# the underlying cause was a Ctrl-C: a signal raised during socket.recv
@ -756,7 +635,7 @@ class AsyncConnection:
reason = None
else:
reason = ConnectionClosedReason.ERROR
self.close_conn(reason)
await self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
details = _get_timeout_details(self.opts)
@ -781,145 +660,6 @@ class AsyncConnection:
)
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
async def _configured_socket(
address: _Address, options: PoolOptions
) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = await _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(
None,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
)
else:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
class _PoolClosedError(PyMongoError):
"""Internal error raised when a thread tries to get a connection from a
closed pool.
@ -1121,7 +861,7 @@ class Pool:
# publishing the PoolClearedEvent.
if close:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
@ -1152,7 +892,7 @@ class Pool:
serviceId=service_id,
)
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
@ -1197,7 +937,7 @@ class Pool:
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
conn = self.conns.pop()
conn.close_conn(ConnectionClosedReason.IDLE)
await conn.close_conn(ConnectionClosedReason.IDLE)
while True:
async with self.size_cond:
@ -1221,7 +961,7 @@ class Pool:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
return
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
@ -1266,7 +1006,7 @@ class Pool:
)
try:
sock = await _configured_socket(self.address, self.opts)
networking_interface = await _configured_protocol_interface(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
async with self.lock:
@ -1293,7 +1033,7 @@ class Pool:
raise
conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type]
conn = AsyncConnection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type]
async with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
@ -1311,7 +1051,7 @@ class Pool:
except BaseException:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
conn.close_conn(ConnectionClosedReason.ERROR)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
if handler:
@ -1509,7 +1249,7 @@ class Pool:
except IndexError:
self._pending += 1
if conn: # We got a socket from the pool
if self._perished(conn):
if await self._perished(conn):
conn = None
continue
else: # We need to create a new connection
@ -1523,7 +1263,7 @@ class Pool:
except BaseException:
if conn:
# We checked out a socket but authentication failed.
conn.close_conn(ConnectionClosedReason.ERROR)
await conn.close_conn(ConnectionClosedReason.ERROR)
async with self.size_cond:
self.requests -= 1
if incremented:
@ -1583,7 +1323,7 @@ class Pool:
await self.reset_without_pause()
else:
if self.closed:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
elif conn.closed:
# CMAP requires the closed event be emitted after the check in.
if self.enabled_for_cmap:
@ -1607,7 +1347,7 @@ class Pool:
# Hold the lock to ensure this section does not race with
# Pool.reset().
if self.stale_generation(conn.generation, conn.service_id):
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
else:
conn.update_last_checkin_time()
conn.update_is_writable(bool(self.is_writable))
@ -1625,7 +1365,7 @@ class Pool:
self.operation_count -= 1
self.size_cond.notify()
def _perished(self, conn: AsyncConnection) -> bool:
async def _perished(self, conn: AsyncConnection) -> bool:
"""Return True and close the connection if it is "perished".
This side-effecty function checks if this socket has been idle for
@ -1645,18 +1385,18 @@ class Pool:
self.opts.max_idle_time_seconds is not None
and idle_time_seconds > self.opts.max_idle_time_seconds
):
conn.close_conn(ConnectionClosedReason.IDLE)
await conn.close_conn(ConnectionClosedReason.IDLE)
return True
if self._check_interval_seconds is not None and (
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
):
if conn.conn_closed():
conn.close_conn(ConnectionClosedReason.ERROR)
await conn.close_conn(ConnectionClosedReason.ERROR)
return True
if self.stale_generation(conn.generation, conn.service_id):
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
return True
return False
@ -1704,5 +1444,6 @@ class Pool:
# Avoid ResourceWarnings in Python 3
# Close all sockets without calling reset() or close() because it is
# not safe to acquire a lock in __del__.
for conn in self.conns:
conn.close_conn(None)
if _IS_SYNC:
for conn in self.conns:
conn.close_conn(None)

View File

@ -16,21 +16,26 @@
from __future__ import annotations
import asyncio
import collections
import errno
import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, Future
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
)
from pymongo import _csot, ssl_support
from pymongo._asyncio_task import create_task
from pymongo.errors import _OperationCancelled
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import decompress
from pymongo.errors import ProtocolError, _OperationCancelled
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.socket_checker import _errno_from_exception
try:
@ -69,13 +74,15 @@ _POLL_TIMEOUT = 0.5
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
# These socket-based I/O methods are for KMS requests and any other network operations that do not use
# the MongoDB wire protocol
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
timeout = sock.gettimeout()
sock.settimeout(0.0)
loop = asyncio.get_running_loop()
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout)
await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
else:
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
except asyncio.TimeoutError as exc:
@ -87,7 +94,7 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non
if sys.platform != "win32":
async def _async_sendall_ssl(
async def _async_socket_sendall_ssl(
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
) -> None:
view = memoryview(buf)
@ -130,7 +137,7 @@ if sys.platform != "win32":
loop.remove_reader(fd)
loop.remove_writer(fd)
async def _async_receive_ssl(
async def _async_socket_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
@ -184,7 +191,7 @@ else:
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
async def _async_sendall_ssl(
async def _async_socket_sendall_ssl(
sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
) -> None:
view = memoryview(buf)
@ -205,7 +212,7 @@ else:
backoff = min(backoff * 2, 0.512)
total_sent += sent
async def _async_receive_ssl(
async def _async_socket_receive_ssl(
conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
@ -244,52 +251,6 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
await asyncio.sleep(_POLL_TIMEOUT)
async def async_receive_data(
conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
sock = conn.conn
sock_timeout = sock.gettimeout()
timeout: Optional[Union[float, int]]
if deadline:
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
timeout = max(deadline - time.monotonic(), 0)
else:
timeout = sock_timeout
sock.settimeout(0.0)
loop = asyncio.get_running_loop()
cancellation_task = create_task(_poll_cancellation(conn))
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
else:
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
tasks = [read_task, cancellation_task]
try:
done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
if pending:
await asyncio.wait(pending)
if len(done) == 0:
raise socket.timeout("timed out")
if read_task in done:
return read_task.result()
raise _OperationCancelled("operation cancelled")
except asyncio.CancelledError:
for task in tasks:
task.cancel()
await asyncio.wait(tasks)
raise
finally:
sock.settimeout(sock_timeout)
async def async_receive_data_socket(
sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
@ -301,18 +262,23 @@ async def async_receive_data_socket(
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
return await asyncio.wait_for(
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
_async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
timeout=timeout,
)
else:
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
return await asyncio.wait_for(
_async_socket_receive(sock, length, loop), # type: ignore[arg-type]
timeout=timeout,
)
except asyncio.TimeoutError as err:
raise socket.timeout("timed out") from err
finally:
sock.settimeout(sock_timeout)
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
async def _async_socket_receive(
conn: socket.socket, length: int, loop: AbstractEventLoop
) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0
while bytes_read < length:
@ -328,7 +294,7 @@ _PYPY = "PyPy" in sys.version
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn
sock = conn.conn.sock
timed_out = False
# Check if the connection's socket has been manually closed
if sock.fileno() == -1:
@ -413,3 +379,403 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
conn.set_conn_timeout(orig_timeout)
return mv
class NetworkingInterfaceBase:
def __init__(self, conn: Any):
self.conn = conn
@property
def gettimeout(self) -> Any:
raise NotImplementedError
def settimeout(self, timeout: float | None) -> None:
raise NotImplementedError
def close(self) -> Any:
raise NotImplementedError
def is_closing(self) -> bool:
raise NotImplementedError
@property
def get_conn(self) -> Any:
raise NotImplementedError
@property
def sock(self) -> Any:
raise NotImplementedError
class AsyncNetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
super().__init__(conn)
@property
def gettimeout(self) -> float | None:
return self.conn[1].gettimeout
def settimeout(self, timeout: float | None) -> None:
self.conn[1].settimeout(timeout)
async def close(self) -> None:
self.conn[1].close()
await self.conn[1].wait_closed()
def is_closing(self) -> bool:
return self.conn[0].is_closing()
@property
def get_conn(self) -> PyMongoProtocol:
return self.conn[1]
@property
def sock(self) -> socket.socket:
return self.conn[0].get_extra_info("socket")
class NetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: Union[socket.socket, _sslConn]):
super().__init__(conn)
def gettimeout(self) -> float | None:
return self.conn.gettimeout()
def settimeout(self, timeout: float | None) -> None:
self.conn.settimeout(timeout)
def close(self) -> None:
self.conn.close()
def is_closing(self) -> bool:
return self.conn.is_closing()
@property
def get_conn(self) -> Union[socket.socket, _sslConn]:
return self.conn
@property
def sock(self) -> Union[socket.socket, _sslConn]:
return self.conn
def fileno(self) -> int:
return self.conn.fileno()
def recv_into(self, buffer: bytes) -> int:
return self.conn.recv_into(buffer)
class PyMongoProtocol(BufferedProtocol):
def __init__(self, timeout: Optional[float] = None):
self.transport: Transport = None # type: ignore[assignment]
# Each message is reader in 2-3 parts: header, compression header, and message body
# The message buffer is allocated after the header is read.
self._header = memoryview(bytearray(16))
self._header_index = 0
self._compression_header = memoryview(bytearray(9))
self._compression_index = 0
self._message: Optional[memoryview] = None
self._message_index = 0
# State. TODO: replace booleans with an enum?
self._expecting_header = True
self._expecting_compression = False
self._message_size = 0
self._op_code = 0
self._connection_lost = False
self._read_waiter: Optional[Future] = None
self._timeout = timeout
self._is_compressed = False
self._compressor_id: Optional[int] = None
self._max_message_size = MAX_MESSAGE_SIZE
self._response_to: Optional[int] = None
self._closed = asyncio.get_running_loop().create_future()
self._pending_messages: collections.deque[Future] = collections.deque()
self._done_messages: collections.deque[Future] = collections.deque()
def settimeout(self, timeout: float | None) -> None:
self._timeout = timeout
@property
def gettimeout(self) -> float | None:
"""The configured timeout for the socket that underlies our protocol pair."""
return self._timeout
def connection_made(self, transport: BaseTransport) -> None:
"""Called exactly once when a connection is made.
The transport argument is the transport representing the write side of the connection.
"""
self.transport = transport # type: ignore[assignment]
self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)
async def write(self, message: bytes) -> None:
"""Write a message to this connection's transport."""
if self.transport.is_closing():
raise OSError("Connection is closed")
self.transport.write(message)
self.transport.resume_reading()
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
"""Read a single MongoDB Wire Protocol message from this connection."""
if self.transport:
try:
self.transport.resume_reading()
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
except AttributeError:
raise OSError("connection is already closed") from None
self._max_message_size = max_message_size
if self._done_messages:
message = await self._done_messages.popleft()
else:
if self.transport and self.transport.is_closing():
raise OSError("connection is already closed")
read_waiter = asyncio.get_running_loop().create_future()
self._pending_messages.append(read_waiter)
try:
message = await read_waiter
finally:
if read_waiter in self._done_messages:
self._done_messages.remove(read_waiter)
if message:
op_code, compressor_id, response_to, data = message
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(
f"Got response id {response_to!r} but expected {request_id!r}"
)
if compressor_id is not None:
data = decompress(data, compressor_id)
return data, op_code
raise OSError("connection closed")
def get_buffer(self, sizehint: int) -> memoryview:
"""Called to allocate a new receive buffer.
The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
If any data does not fit into the returned buffer, this method will be called again until
either no data remains or an empty buffer is returned.
"""
# Due to a bug, Python <=3.11 will call get_buffer() even after we raise
# ProtocolError in buffer_updated() and call connection_lost(). We allocate
# a temp buffer to drain the waiting data.
if self._connection_lost:
if not self._message:
self._message = memoryview(bytearray(2**14))
return self._message
# TODO: optimize this by caching pointers to the buffers.
# return self._buffer[self._index:]
if self._expecting_header:
return self._header[self._header_index :]
if self._expecting_compression:
return self._compression_header[self._compression_index :]
return self._message[self._message_index :] # type: ignore[index]
def buffer_updated(self, nbytes: int) -> None:
"""Called when the buffer was updated with the received data"""
# Wrote 0 bytes into a non-empty buffer, signal connection closed
if nbytes == 0:
self.close(OSError("connection closed"))
return
if self._connection_lost:
return
if self._expecting_header:
self._header_index += nbytes
if self._header_index >= 16:
self._expecting_header = False
try:
(
self._message_size,
self._op_code,
self._response_to,
self._expecting_compression,
) = self.process_header()
except ProtocolError as exc:
self.close(exc)
return
self._message = memoryview(bytearray(self._message_size))
return
if self._expecting_compression:
self._compression_index += nbytes
if self._compression_index >= 9:
self._expecting_compression = False
self._op_code, self._compressor_id = self.process_compression_header()
return
self._message_index += nbytes
if self._message_index >= self._message_size:
self._expecting_header = True
# Pause reading to avoid storing an arbitrary number of messages in memory.
self.transport.pause_reading()
if self._pending_messages:
result = self._pending_messages.popleft()
else:
result = asyncio.get_running_loop().create_future()
# Future has been cancelled, close this connection
if result.done():
self.close(None)
return
# Necessary values to reconstruct and verify message
result.set_result(
(self._op_code, self._compressor_id, self._response_to, self._message)
)
self._done_messages.append(result)
# Reset internal state to expect a new message
self._header_index = 0
self._compression_index = 0
self._message_index = 0
self._message_size = 0
self._message = None
self._op_code = 0
self._compressor_id = None
self._response_to = None
def process_header(self) -> tuple[int, int, int, bool]:
"""Unpack a MongoDB Wire Protocol header."""
length, _, response_to, op_code = _UNPACK_HEADER(self._header)
expecting_compression = False
if op_code == 2012: # OP_COMPRESSED
if length <= 25:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
)
expecting_compression = True
length -= 9
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > self._max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({self._max_message_size!r})"
)
return length - 16, op_code, response_to, expecting_compression
def process_compression_header(self) -> tuple[int, int]:
"""Unpack a MongoDB Wire Protocol compression header."""
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
return op_code, compressor_id
def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
pending = list(self._pending_messages)
for msg in pending:
if not msg.done():
if exc is None:
msg.set_result(None)
else:
msg.set_exception(exc)
self._done_messages.append(msg)
def close(self, exc: Optional[Exception] = None) -> None:
self.transport.abort()
self._resolve_pending_messages(exc)
self._connection_lost = True
def connection_lost(self, exc: Optional[Exception] = None) -> None:
self._resolve_pending_messages(exc)
if not self._closed.done():
self._closed.set_result(None)
async def wait_closed(self) -> None:
await self._closed
async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
try:
await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
except asyncio.TimeoutError as exc:
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
raise socket.timeout("timed out") from exc
async def async_receive_message(
conn: AsyncConnection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
timeout: Optional[Union[float, int]]
timeout = conn.conn.gettimeout
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
if deadline:
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
timeout = max(deadline - time.monotonic(), 0)
cancellation_task = create_task(_poll_cancellation(conn))
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
tasks = [read_task, cancellation_task]
try:
done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
if pending:
await asyncio.wait(pending)
if len(done) == 0:
raise socket.timeout("timed out")
if read_task in done:
data, op_code = read_task.result()
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
raise _OperationCancelled("operation cancelled")
except asyncio.CancelledError:
for task in tasks:
task.cancel()
await asyncio.wait(tasks)
raise
def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
else:
data = receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)

546
pymongo/pool_shared.py Normal file
View File

@ -0,0 +1,546 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pool utilities and shared helper methods."""
from __future__ import annotations
import asyncio
import functools
import socket
import ssl
import sys
from typing import (
TYPE_CHECKING,
Any,
NoReturn,
Optional,
Union,
)
from pymongo import _csot
from pymongo.asynchronous.helpers import _getaddrinfo
from pymongo.errors import ( # type:ignore[attr-defined]
AutoReconnect,
ConnectionFailure,
NetworkTimeout,
_CertificateError,
)
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
from pymongo.pool_options import PoolOptions
from pymongo.ssl_support import HAS_SNI, SSLError
if TYPE_CHECKING:
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
try:
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
def _set_non_inheritable_non_atomic(fd: int) -> None:
"""Set the close-on-exec flag on the given file descriptor."""
flags = fcntl(fd, F_GETFD)
fcntl(fd, F_SETFD, flags | FD_CLOEXEC)
except ImportError:
# Windows, various platforms we don't claim to support
# (Jython, IronPython, ..), systems that don't provide
# everything we need from fcntl, etc.
def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
"""Dummy function for platforms that don't provide fcntl."""
_MAX_TCP_KEEPIDLE = 120
_MAX_TCP_KEEPINTVL = 10
_MAX_TCP_KEEPCNT = 9
if sys.platform == "win32":
try:
import _winreg as winreg
except ImportError:
import winreg
def _query(key, name, default):
try:
value, _ = winreg.QueryValueEx(key, name)
# Ensure the value is a number or raise ValueError.
return int(value)
except (OSError, ValueError):
# QueryValueEx raises OSError when the key does not exist (i.e.
# the system is using the Windows default value).
return default
try:
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
) as key:
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
except OSError:
# We could not check the default values because winreg.OpenKey failed.
# Assume the system is using the default values.
_WINDOWS_TCP_IDLE_MS = 7200000
_WINDOWS_TCP_INTERVAL_MS = 1000
def _set_keepalive_times(sock):
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
else:
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
if hasattr(socket, tcp_option):
sockopt = getattr(socket, tcp_option)
try:
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
# TCP_KEEPIDLE and friends. Don't attempt to set the
# values there.
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
if default > max_value:
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
except OSError:
pass
def _set_keepalive_times(sock: socket.socket) -> None:
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
def _raise_connection_failure(
address: Any,
error: Exception,
msg_prefix: Optional[str] = None,
timeout_details: Optional[dict[str, float]] = None,
) -> NoReturn:
"""Convert a socket.error to ConnectionFailure and raise it."""
host, port = address
# If connecting to a Unix socket, port will be None.
if port is not None:
msg = "%s:%d: %s" % (host, port, error)
else:
msg = f"{host}: {error}"
if msg_prefix:
msg = msg_prefix + msg
if "configured timeouts" not in msg:
msg += format_timeout_details(timeout_details)
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, SSLError) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
# 'timed out' appears in all the timeout related SSLErrors raised.
raise NetworkTimeout(msg) from error
else:
raise AutoReconnect(msg) from error
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
socket_timeout = options.socket_timeout
connect_timeout = options.connect_timeout
if timeout:
details["timeoutMS"] = timeout * 1000
if socket_timeout and not timeout:
details["socketTimeoutMS"] = socket_timeout * 1000
if connect_timeout:
details["connectTimeoutMS"] = connect_timeout * 1000
return details
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
result = ""
if details:
result += " (configured timeouts:"
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
if timeout in details:
result += f" {timeout}: {details[timeout]}ms,"
result = result[:-1]
result += ")"
return result
class _CancellationContext:
def __init__(self) -> None:
self._cancelled = False
def cancel(self) -> None:
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self) -> bool:
"""Was cancel called?"""
return self._cancelled
async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a raw socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM):
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
async def _async_configured_socket(
address: _Address, options: PoolOptions
) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a raw configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(
None,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore]
)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
async def _configured_protocol_interface(
address: _Address, options: PoolOptions
) -> AsyncNetworkingInterface:
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets protocol's SSL and timeout options.
"""
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
timeout = options.socket_timeout
if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
)
)
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: PyMongoProtocol(timeout=timeout),
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
except _CertificateError:
transport.abort()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
transport.abort()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise
return AsyncNetworkingInterface((transport, protocol))
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a raw socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in socket.getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined, unused-ignore]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a raw configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
else:
ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return NetworkingInterface(sock)
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
ssl_sock = ssl_context.wrap_socket(sock)
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return NetworkingInterface(ssl_sock)

View File

@ -70,21 +70,21 @@ from pymongo.errors import (
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
from pymongo.network_layer import sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import (
from pymongo.pool_shared import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser_shared import parse_host
from pymongo.write_concern import WriteConcern

View File

@ -36,7 +36,11 @@ from pymongo.server_description import ServerDescription
from pymongo.synchronous.srv_resolver import _SrvResolver
if TYPE_CHECKING:
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
from pymongo.synchronous.pool import ( # type: ignore[attr-defined]
Connection,
Pool,
_CancellationContext,
)
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology

View File

@ -17,7 +17,6 @@ from __future__ import annotations
import datetime
import logging
import time
from typing import (
TYPE_CHECKING,
Any,
@ -31,20 +30,16 @@ from typing import (
from bson import _decode_all_selective
from pymongo import _csot, helpers_shared, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import _NO_COMPRESSION, decompress
from pymongo.compression_support import _NO_COMPRESSION
from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.message import _OpMsg
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
receive_data,
receive_message,
sendall,
)
@ -194,7 +189,7 @@ def command(
)
try:
sendall(conn.conn, msg)
sendall(conn.conn.get_conn, msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
@ -301,45 +296,3 @@ def command(
)
return response_doc # type: ignore[return-value]
def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
else:
data = receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)

View File

@ -14,14 +14,10 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import functools
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -49,16 +45,13 @@ from pymongo.common import (
from pymongo.errors import ( # type:ignore[attr-defined]
AutoReconnect,
ConfigurationError,
ConnectionFailure,
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import (
@ -76,16 +69,23 @@ from pymongo.monitoring import (
ConnectionCheckOutFailedReason,
ConnectionClosedReason,
)
from pymongo.network_layer import sendall
from pymongo.network_layer import NetworkingInterface, receive_message, sendall
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_CancellationContext,
_configured_socket_interface,
_get_timeout_details,
_raise_connection_failure,
format_timeout_details,
)
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI, SSLError
from pymongo.ssl_support import SSLError
from pymongo.synchronous.client_session import _validate_session_write_concern
from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth
from pymongo.synchronous.network import command, receive_message
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.synchronous.network import command
if TYPE_CHECKING:
from bson import CodecOptions
@ -96,7 +96,6 @@ if TYPE_CHECKING:
ZstdContext,
)
from pymongo.message import _OpMsg, _OpReply
from pymongo.pyopenssl_context import _sslConn
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.synchronous.auth import _AuthContext
@ -123,133 +122,6 @@ except ImportError:
_IS_SYNC = True
_MAX_TCP_KEEPIDLE = 120
_MAX_TCP_KEEPINTVL = 10
_MAX_TCP_KEEPCNT = 9
if sys.platform == "win32":
try:
import _winreg as winreg
except ImportError:
import winreg
def _query(key, name, default):
try:
value, _ = winreg.QueryValueEx(key, name)
# Ensure the value is a number or raise ValueError.
return int(value)
except (OSError, ValueError):
# QueryValueEx raises OSError when the key does not exist (i.e.
# the system is using the Windows default value).
return default
try:
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
) as key:
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
except OSError:
# We could not check the default values because winreg.OpenKey failed.
# Assume the system is using the default values.
_WINDOWS_TCP_IDLE_MS = 7200000
_WINDOWS_TCP_INTERVAL_MS = 1000
def _set_keepalive_times(sock):
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
else:
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
if hasattr(socket, tcp_option):
sockopt = getattr(socket, tcp_option)
try:
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
# TCP_KEEPIDLE and friends. Don't attempt to set the
# values there.
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
if default > max_value:
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
except OSError:
pass
def _set_keepalive_times(sock: socket.socket) -> None:
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
def _raise_connection_failure(
address: Any,
error: Exception,
msg_prefix: Optional[str] = None,
timeout_details: Optional[dict[str, float]] = None,
) -> NoReturn:
"""Convert a socket.error to ConnectionFailure and raise it."""
host, port = address
# If connecting to a Unix socket, port will be None.
if port is not None:
msg = "%s:%d: %s" % (host, port, error)
else:
msg = f"{host}: {error}"
if msg_prefix:
msg = msg_prefix + msg
if "configured timeouts" not in msg:
msg += format_timeout_details(timeout_details)
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, SSLError) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
# 'timed out' appears in all the timeout related SSLErrors raised.
raise NetworkTimeout(msg) from error
else:
raise AutoReconnect(msg) from error
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
socket_timeout = options.socket_timeout
connect_timeout = options.connect_timeout
if timeout:
details["timeoutMS"] = timeout * 1000
if socket_timeout and not timeout:
details["socketTimeoutMS"] = socket_timeout * 1000
if connect_timeout:
details["connectTimeoutMS"] = connect_timeout * 1000
return details
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
result = ""
if details:
result += " (configured timeouts:"
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
if timeout in details:
result += f" {timeout}: {details[timeout]}ms,"
result = result[:-1]
result += ")"
return result
class _CancellationContext:
def __init__(self) -> None:
self._cancelled = False
def cancel(self) -> None:
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self) -> bool:
"""Was cancel called?"""
return self._cancelled
class Connection:
"""Store a connection with some metadata.
@ -261,7 +133,11 @@ class Connection:
"""
def __init__(
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
self,
conn: NetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
):
self.pool_ref = weakref.ref(pool)
self.conn = conn
@ -318,7 +194,7 @@ class Connection:
if timeout == self.last_timeout:
return
self.last_timeout = timeout
self.conn.settimeout(timeout)
self.conn.get_conn.settimeout(timeout)
def apply_timeout(
self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]]
@ -573,7 +449,7 @@ class Connection:
)
try:
sendall(self.conn, message)
sendall(self.conn.get_conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
@ -707,7 +583,10 @@ class Connection:
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.conn)
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
def send_cluster_time(
self,
@ -779,143 +658,6 @@ class Connection:
)
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = loop.run_in_executor(
None,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
)
else:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
class _PoolClosedError(PyMongoError):
"""Internal error raised when a thread tries to get a connection from a
closed pool.
@ -1260,7 +1002,7 @@ class Pool:
)
try:
sock = _configured_socket(self.address, self.opts)
networking_interface = _configured_socket_interface(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
with self.lock:
@ -1287,7 +1029,7 @@ class Pool:
raise
conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type]
conn = Connection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type]
with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
@ -1698,5 +1440,6 @@ class Pool:
# Avoid ResourceWarnings in Python 3
# Close all sockets without calling reset() or close() because it is
# not safe to acquire a lock in __del__.
for conn in self.conns:
conn.close_conn(None)
if _IS_SYNC:
for conn in self.conns:
conn.close_conn(None)

View File

@ -116,6 +116,7 @@ filterwarnings = [
"module:unclosed <socket.socket:ResourceWarning",
"module:unclosed <ssl.SSLSocket:ResourceWarning",
"module:unclosed <socket object:ResourceWarning",
"module:unclosed transport:ResourceWarning",
# https://github.com/eventlet/eventlet/issues/818
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
# https://github.com/dateutil/dateutil/issues/1314

View File

@ -22,6 +22,8 @@ import sys
import warnings
from test.asynchronous import AsyncPyMongoTestCase
import pytest
sys.path[0:0] = [""]
from test import unittest
@ -30,6 +32,8 @@ from test.asynchronous.unified_format import generate_test_classes
from pymongo import AsyncMongoClient
from pymongo.asynchronous.auth_oidc import OIDCCallback
pytestmark = pytest.mark.auth
_IS_SYNC = False
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")

View File

@ -301,7 +301,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_bulk_max_message_size(self):
await self.coll.delete_many({})
self.addCleanup(self.coll.delete_many, {})
self.addAsyncCleanup(self.coll.delete_many, {})
_16_MB = 16 * 1000 * 1000
# Generate a list of documents such that the first batched OP_MSG is
# as close as possible to the 48MB limit.
@ -505,7 +505,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_single_error_ordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
@ -547,7 +547,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_multiple_error_ordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
@ -616,7 +616,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_single_error_unordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
@ -659,7 +659,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_multiple_error_unordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 3}}, upsert=True),
@ -1002,7 +1002,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
await self.coll.delete_many({})
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
# Fail due to write concern support as well
# as duplicate key error on ordered batch.
@ -1077,7 +1077,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
await self.coll.delete_many({})
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
# Fail due to write concern support as well
# as duplicate key error on unordered batch.

View File

@ -746,7 +746,7 @@ class TestClient(AsyncIntegrationTest):
# Assert that if a socket is closed, a new one takes its place
async with server._pool.checkout() as conn:
conn.close_conn(None)
await conn.close_conn(None)
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"a closed socket gets replaced from the pool",
@ -1270,7 +1270,6 @@ class TestClient(AsyncIntegrationTest):
no_timeout = self.client
timeout_sec = 1
timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
self.addAsyncCleanup(timeout.close)
await no_timeout.pymongo_test.drop_collection("test")
await no_timeout.pymongo_test.test.insert_one({"x": 1})
@ -1337,7 +1336,7 @@ class TestClient(AsyncIntegrationTest):
async def test_socketKeepAlive(self):
pool = await async_get_pool(self.client)
async with pool.checkout() as conn:
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
self.assertTrue(keepalive)
@no_type_check
@ -1537,7 +1536,7 @@ class TestClient(AsyncIntegrationTest):
# Cause a network error.
conn = one(pool.conns)
conn.conn.close()
await conn.conn.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST)
with self.assertRaises(ConnectionFailure):
await anext(cursor)
@ -1562,7 +1561,7 @@ class TestClient(AsyncIntegrationTest):
# Cause a network error on the actual socket.
pool = await async_get_pool(c)
conn = one(pool.conns)
conn.conn.close()
await conn.conn.close()
# AsyncConnection.authenticate logs, but gets a socket.error. Should be
# reraised as AutoReconnect.
@ -2254,7 +2253,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
# Cause a network error.
conn = one(pool.conns)
conn.conn.close()
await conn.conn.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST)
with self.assertRaises(ConnectionFailure):
@ -2282,7 +2281,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
# Cause a network error.
conn = cursor._sock_mgr.conn
conn.conn.close()
await conn.conn.close()
# A getmore fails.
with self.assertRaises(ConnectionFailure):

View File

@ -651,7 +651,6 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
_OVERHEAD = 500
internal_client = await self.async_rs_or_single_client(timeoutMS=None)
self.addAsyncCleanup(internal_client.close)
collection = internal_client.db["coll"]
self.addAsyncCleanup(collection.drop)

View File

@ -273,7 +273,7 @@ class AsyncTestCMAP(AsyncIntegrationTest):
for t in self.targets.values():
await t.join(5)
for conn in self.labels.values():
conn.close_conn(None)
await conn.close_conn(None)
self.addAsyncCleanup(cleanup)

View File

@ -1401,7 +1401,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_to_list_length(self):
coll = self.db.test
await coll.insert_many([{} for _ in range(5)])
self.addCleanup(coll.drop)
self.addAsyncCleanup(coll.drop)
c = coll.find()
docs = await c.to_list(3)
self.assertEqual(len(docs), 3)
@ -1812,6 +1812,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
@async_client_context.require_version_min(5, 0, -1)
@async_client_context.require_no_mongos
@async_client_context.require_sync
async def test_exhaust_cursor_db_set(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(event_listeners=[listener])
@ -1821,7 +1822,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
listener.reset()
result = await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list()
result = list(await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1))
self.assertEqual(len(result), 3)

View File

@ -219,7 +219,7 @@ class TestPooling(_TestPoolingBase):
async with cx_pool.checkout() as conn:
# Use Connection's API to close the socket.
conn.close_conn(None)
await conn.close_conn(None)
self.assertEqual(0, len(cx_pool.conns))
@ -232,7 +232,7 @@ class TestPooling(_TestPoolingBase):
async with cx_pool.checkout() as conn:
# Simulate a closed socket without telling the Connection it's
# closed.
conn.conn.close()
await conn.conn.close()
self.assertTrue(conn.conn_closed())
async with cx_pool.checkout() as new_connection:
@ -306,7 +306,7 @@ class TestPooling(_TestPoolingBase):
async with cx_pool.checkout() as conn:
# Simulate a closed socket without telling the Connection it's
# closed.
conn.conn.close()
await conn.conn.close()
# Swap pool's address with a bad one.
address, cx_pool.address = cx_pool.address, ("foo.com", 1234)

View File

@ -137,6 +137,7 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
self.deprecation_filter = DeprecationFilter()
async def asyncTearDown(self) -> None:
await super().asyncTearDown()
self.deprecation_filter.stop()
@ -196,6 +197,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
self.knobs.disable()
await super().asyncTearDown()
async def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()

View File

@ -159,6 +159,7 @@ class AsyncMockConnection:
self.cancel_context = _CancellationContext()
self.more_to_come = False
self.id = random.randint(0, 100)
self.server_connection_id = random.randint(0, 100)
def close_conn(self, reason):
pass

View File

@ -22,6 +22,8 @@ import sys
import warnings
from test import PyMongoTestCase
import pytest
sys.path[0:0] = [""]
from test import unittest
@ -30,6 +32,8 @@ from test.unified_format import generate_test_classes
from pymongo import MongoClient
from pymongo.synchronous.auth_oidc import OIDCCallback
pytestmark = pytest.mark.auth
_IS_SYNC = True
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")

View File

@ -1233,7 +1233,6 @@ class TestClient(IntegrationTest):
no_timeout = self.client
timeout_sec = 1
timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
self.addCleanup(timeout.close)
no_timeout.pymongo_test.drop_collection("test")
no_timeout.pymongo_test.test.insert_one({"x": 1})
@ -1296,7 +1295,7 @@ class TestClient(IntegrationTest):
def test_socketKeepAlive(self):
pool = get_pool(self.client)
with pool.checkout() as conn:
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
self.assertTrue(keepalive)
@no_type_check

View File

@ -647,7 +647,6 @@ class TestClientBulkWriteCSOT(IntegrationTest):
_OVERHEAD = 500
internal_client = self.rs_or_single_client(timeoutMS=None)
self.addCleanup(internal_client.close)
collection = internal_client.db["coll"]
self.addCleanup(collection.drop)

View File

@ -1801,6 +1801,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
@client_context.require_version_min(5, 0, -1)
@client_context.require_no_mongos
@client_context.require_sync
def test_exhaust_cursor_db_set(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(event_listeners=[listener])
@ -1810,7 +1811,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
listener.reset()
result = c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list()
result = list(c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1))
self.assertEqual(len(result), 3)

View File

@ -137,6 +137,7 @@ class IgnoreDeprecationsTest(IntegrationTest):
self.deprecation_filter = DeprecationFilter()
def tearDown(self) -> None:
super().tearDown()
self.deprecation_filter.stop()
@ -194,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
self.knobs.disable()
super().tearDown()
def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()

View File

@ -157,6 +157,7 @@ class MockConnection:
self.cancel_context = _CancellationContext()
self.more_to_come = False
self.id = random.randint(0, 100)
self.server_connection_id = random.randint(0, 100)
def close_conn(self, reason):
pass

View File

@ -47,6 +47,7 @@ replacements = {
"async_receive_message": "receive_message",
"async_receive_data": "receive_data",
"async_sendall": "sendall",
"async_socket_sendall": "sendall",
"asynchronous": "synchronous",
"Asynchronous": "Synchronous",
"AsyncBulkTestBase": "BulkTestBase",
@ -119,6 +120,9 @@ replacements = {
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
"AsyncNetworkingInterface": "NetworkingInterface",
"_configured_protocol_interface": "_configured_socket_interface",
"_async_configured_socket": "_configured_socket",
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
@ -127,6 +131,7 @@ replacements = {
"async_create_barrier": "create_barrier",
"async_barrier_wait": "barrier_wait",
"async_joinall": "joinall",
"_async_create_connection": "_create_connection",
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
}