PYTHON-5215 Add an asyncio.Protocol implementation for KMS (#2460)

This commit is contained in:
Steven Silvester 2025-08-19 08:45:24 -05:00 committed by GitHub
parent 37d327fbd8
commit e4b7eb52e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 436 additions and 611 deletions

View File

@ -64,6 +64,7 @@ from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncBaseConnection
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@ -76,11 +77,11 @@ from pymongo.errors import (
ServerSelectionTimeoutError,
)
from pymongo.helpers_shared import _get_timeout_details
from pymongo.network_layer import async_socket_sendall
from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_async_configured_socket,
_configured_protocol_interface,
_raise_connection_failure,
)
from pymongo.read_concern import ReadConcern
@ -93,10 +94,8 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
_IS_SYNC = False
_HTTPS_PORT = 443
@ -111,9 +110,10 @@ _DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions(
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection:
try:
return await _async_configured_socket(address, opts)
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
return AsyncBaseConnection(interface, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
@ -198,18 +198,11 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
try:
conn = await _connect_kms(address, opts)
try:
await async_socket_sendall(conn, message)
await async_sendall(conn.conn.get_conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
if _IS_SYNC:
data = conn.recv(kms_context.bytes_needed)
else:
from pymongo.network_layer import ( # type: ignore[attr-defined]
async_receive_data_socket,
)
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = await async_receive_kms(conn, kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
@ -228,7 +221,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
)
finally:
conn.close()
await conn.close_conn(None)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:

View File

@ -123,74 +123,19 @@ except ImportError:
_IS_SYNC = False
class AsyncConnection:
"""Store a connection with some metadata.
class AsyncBaseConnection:
"""A base connection object for server and kms connections."""
:param conn: a raw connection object
:param pool: a Pool instance
:param address: the server's (host, port)
:param id: the id of this socket in it's pool
:param is_sdam: SDAM connections do not call hello on creation
"""
def __init__(
self,
conn: AsyncNetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
is_sdam: bool,
):
self.pool_ref = weakref.ref(pool)
def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions):
self.conn = conn
self.address = address
self.id = id
self.is_sdam = is_sdam
self.closed = False
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
self.is_writable: bool = False
self.max_wire_version = MAX_WIRE_VERSION
self.max_bson_size = MAX_BSON_SIZE
self.max_message_size = MAX_MESSAGE_SIZE
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
self.supports_sessions = False
self.hello_ok: bool = False
self.is_mongos = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.enabled_for_logging = pool.enabled_for_logging
self.compression_settings = pool.opts._compression_settings
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
self.socket_checker: SocketChecker = SocketChecker()
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs: Optional[list[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
self.cancel_context: _CancellationContext = _CancellationContext()
self.opts = pool.opts
self.more_to_come: bool = False
# For load balancer support.
self.service_id: Optional[ObjectId] = None
self.server_connection_id: Optional[int] = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned_txn = False
self.pinned_cursor = False
self.active = False
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
self._client_id = pool._client_id
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
self.is_sdam = False
self.closed = False
self.last_timeout: float | None = None
self.more_to_come = False
self.opts = opts
self.max_wire_version = -1
def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
@ -219,17 +164,111 @@ class AsyncConnection:
formatted = format_timeout_details(timeout_details)
# CSOT: raise an error without running the command since we know it will time out.
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
raise ExecutionTimeout(
errmsg,
50,
{"ok": 0, "errmsg": errmsg, "code": 50},
self.max_wire_version,
)
if self.max_wire_version != -1:
raise ExecutionTimeout(
errmsg,
50,
{"ok": 0, "errmsg": errmsg, "code": 50},
self.max_wire_version,
)
else:
raise TimeoutError(errmsg)
if cmd is not None:
cmd["maxTimeMS"] = int(max_time_ms * 1000)
self.set_conn_timeout(timeout)
return timeout
async def close_conn(self, reason: Optional[str]) -> None:
"""Close this connection with a reason."""
if self.closed:
return
await self._close_conn()
async def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
self.closed = True
self.cancel_context.cancel()
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
await self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
class AsyncConnection(AsyncBaseConnection):
"""Store a connection with some metadata.
:param conn: a raw connection object
:param pool: a Pool instance
:param address: the server's (host, port)
:param id: the id of this socket in it's pool
:param is_sdam: SDAM connections do not call hello on creation
"""
def __init__(
self,
conn: AsyncNetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
is_sdam: bool,
):
super().__init__(conn, pool.opts)
self.pool_ref = weakref.ref(pool)
self.address: tuple[str, int] = address
self.id: int = id
self.is_sdam = is_sdam
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
self.is_writable: bool = False
self.max_wire_version = MAX_WIRE_VERSION
self.max_bson_size: int = MAX_BSON_SIZE
self.max_message_size: int = MAX_MESSAGE_SIZE
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
self.supports_sessions = False
self.hello_ok: bool = False
self.is_mongos: bool = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.enabled_for_logging = pool.enabled_for_logging
self.compression_settings = pool.opts._compression_settings
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs: Optional[list[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
# For load balancer support.
self.service_id: Optional[ObjectId] = None
self.server_connection_id: Optional[int] = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned_txn = False
self.pinned_cursor = False
self.active = False
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
self._client_id = pool._client_id
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
def pin_txn(self) -> None:
self.pinned_txn = True
assert not self.pinned_cursor
@ -573,26 +612,6 @@ class AsyncConnection:
error=reason,
)
async def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
self.closed = True
self.cancel_context.cancel()
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
await self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
def send_cluster_time(
self,
command: MutableMapping[str, Any],

View File

@ -22,10 +22,11 @@ import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
from asyncio import BaseTransport, BufferedProtocol, Future, Protocol, Transport
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)
@ -38,208 +39,30 @@ from pymongo.errors import ProtocolError, _OperationCancelled
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.socket_checker import _errno_from_exception
try:
from ssl import SSLError, SSLSocket
_HAVE_SSL = True
except ImportError:
_HAVE_SSL = False
try:
from pymongo.pyopenssl_context import _sslConn
_HAVE_PYOPENSSL = True
except ImportError:
_HAVE_PYOPENSSL = False
_sslConn = SSLSocket # type: ignore[assignment, misc]
from pymongo.ssl_support import (
BLOCKING_IO_LOOKUP_ERROR,
BLOCKING_IO_READ_ERROR,
BLOCKING_IO_WRITE_ERROR,
)
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.synchronous.pool import Connection
from pymongo.asynchronous.pool import AsyncBaseConnection, AsyncConnection
from pymongo.pyopenssl_context import _sslConn
from pymongo.synchronous.pool import BaseConnection, Connection
_UNPACK_HEADER = struct.Struct("<iiii").unpack
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
_POLL_TIMEOUT = 0.5
_PYPY = "PyPy" in sys.version
_WINDOWS = sys.platform == "win32"
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
BLOCKING_IO_ERRORS = (BlockingIOError, *BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
# These socket-based I/O methods are for KMS requests and any other network operations that do not use
# the MongoDB wire protocol
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
timeout = sock.gettimeout()
sock.settimeout(0.0)
loop = asyncio.get_running_loop()
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
else:
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
except asyncio.TimeoutError as exc:
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
raise socket.timeout("timed out") from exc
finally:
sock.settimeout(timeout)
if sys.platform != "win32":
async def _async_socket_sendall_ssl(
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
) -> None:
view = memoryview(buf)
sent = 0
def _is_ready(fut: Future[Any]) -> None:
if fut.done():
return
fut.set_result(None)
while sent < len(buf):
try:
sent += sock.send(view[sent:])
except BLOCKING_IO_ERRORS as exc:
fd = sock.fileno()
# Check for closed socket.
if fd == -1:
raise SSLError("Underlying socket has been closed") from None
if isinstance(exc, BLOCKING_IO_READ_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
try:
await fut
finally:
loop.remove_reader(fd)
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
fut = loop.create_future()
loop.add_writer(fd, _is_ready, fut)
try:
await fut
finally:
loop.remove_writer(fd)
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
try:
loop.add_writer(fd, _is_ready, fut)
await fut
finally:
loop.remove_reader(fd)
loop.remove_writer(fd)
async def _async_socket_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
total_read = 0
def _is_ready(fut: Future[Any]) -> None:
if fut.done():
return
fut.set_result(None)
while total_read < length:
try:
read = conn.recv_into(mv[total_read:])
if read == 0:
raise OSError("connection closed")
# KMS responses update their expected size after the first batch, stop reading after one loop
if once:
return mv[:read]
total_read += read
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
# Check for closed socket.
if fd == -1:
raise SSLError("Underlying socket has been closed") from None
if isinstance(exc, BLOCKING_IO_READ_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
try:
await fut
finally:
loop.remove_reader(fd)
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
fut = loop.create_future()
loop.add_writer(fd, _is_ready, fut)
try:
await fut
finally:
loop.remove_writer(fd)
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
try:
loop.add_writer(fd, _is_ready, fut)
await fut
finally:
loop.remove_reader(fd)
loop.remove_writer(fd)
return mv
else:
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
async def _async_socket_sendall_ssl(
sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
) -> None:
view = memoryview(buf)
total_length = len(buf)
total_sent = 0
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
# down to 1ms.
backoff = 0.001
while total_sent < total_length:
try:
sent = sock.send(view[total_sent:])
except BLOCKING_IO_ERRORS:
await asyncio.sleep(backoff)
sent = 0
if sent > 0:
backoff = max(backoff / 2, 0.001)
else:
backoff = min(backoff * 2, 0.512)
total_sent += sent
async def _async_socket_receive_ssl(
conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
total_read = 0
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
# down to 1ms.
backoff = 0.001
while total_read < length:
try:
read = conn.recv_into(mv[total_read:])
if read == 0:
raise OSError("connection closed")
# KMS responses update their expected size after the first batch, stop reading after one loop
if once:
return mv[:read]
except BLOCKING_IO_ERRORS:
await asyncio.sleep(backoff)
read = 0
if read > 0:
backoff = max(backoff / 2, 0.001)
else:
backoff = min(backoff * 2, 0.512)
total_read += read
return mv
BLOCKING_IO_ERRORS = (
BlockingIOError,
*ssl_support.BLOCKING_IO_LOOKUP_ERROR,
*ssl_support.BLOCKING_IO_ERRORS,
)
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
sock.sendall(buf)
async def _poll_cancellation(conn: AsyncConnection) -> None:
async def _poll_cancellation(conn: AsyncBaseConnection) -> None:
while True:
if conn.cancel_context.cancelled:
return
@ -247,49 +70,7 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
await asyncio.sleep(_POLL_TIMEOUT)
async def async_receive_data_socket(
sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
sock_timeout = sock.gettimeout()
timeout = sock_timeout
sock.settimeout(0.0)
loop = asyncio.get_running_loop()
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
return await asyncio.wait_for(
_async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
timeout=timeout,
)
else:
return await asyncio.wait_for(
_async_socket_receive(sock, length, loop), # type: ignore[arg-type]
timeout=timeout,
)
except asyncio.TimeoutError as err:
raise socket.timeout("timed out") from err
finally:
sock.settimeout(sock_timeout)
async def _async_socket_receive(
conn: socket.socket, length: int, loop: AbstractEventLoop
) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0
while bytes_read < length:
chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
if chunk_length == 0:
raise OSError("connection closed")
bytes_read += chunk_length
return mv
_PYPY = "PyPy" in sys.version
_WINDOWS = sys.platform == "win32"
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
def wait_for_read(conn: BaseConnection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn.sock
timed_out = False
@ -322,7 +103,7 @@ def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
raise socket.timeout("timed out")
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
def receive_data(conn: BaseConnection, length: int, deadline: Optional[float]) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
@ -412,7 +193,7 @@ class NetworkingInterfaceBase:
class AsyncNetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
def __init__(self, conn: tuple[Transport, PyMongoBaseProtocol]):
super().__init__(conn)
@property
@ -430,7 +211,7 @@ class AsyncNetworkingInterface(NetworkingInterfaceBase):
return self.conn[0].is_closing()
@property
def get_conn(self) -> PyMongoProtocol:
def get_conn(self) -> PyMongoBaseProtocol:
return self.conn[1]
@property
@ -469,9 +250,51 @@ class NetworkingInterface(NetworkingInterfaceBase):
return self.conn.recv_into(buffer)
class PyMongoProtocol(BufferedProtocol):
class PyMongoBaseProtocol(Protocol):
def __init__(self, timeout: Optional[float] = None):
self.transport: Transport = None # type: ignore[assignment]
self._timeout = timeout
self._closed = asyncio.get_running_loop().create_future()
self._connection_lost = False
def settimeout(self, timeout: float | None) -> None:
self._timeout = timeout
@property
def gettimeout(self) -> float | None:
"""The configured timeout for the socket that underlies our protocol pair."""
return self._timeout
def close(self, exc: Optional[Exception] = None) -> None:
self.transport.abort()
self._resolve_pending(exc)
self._connection_lost = True
def connection_lost(self, exc: Optional[Exception] = None) -> None:
self._resolve_pending(exc)
if not self._closed.done():
self._closed.set_result(None)
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
pass
async def wait_closed(self) -> None:
await self._closed
async def write(self, message: bytes) -> None:
"""Write a message to this connection's transport."""
if self.transport.is_closing():
raise OSError("Connection is closed")
self.transport.write(message)
self.transport.resume_reading()
async def read(self, *args: Any) -> Any:
raise NotImplementedError
class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol):
def __init__(self, timeout: Optional[float] = None):
super().__init__(timeout)
# Each message is reader in 2-3 parts: header, compression header, and message body
# The message buffer is allocated after the header is read.
self._header = memoryview(bytearray(16))
@ -485,25 +308,14 @@ class PyMongoProtocol(BufferedProtocol):
self._expecting_compression = False
self._message_size = 0
self._op_code = 0
self._connection_lost = False
self._read_waiter: Optional[Future[Any]] = None
self._timeout = timeout
self._is_compressed = False
self._compressor_id: Optional[int] = None
self._max_message_size = MAX_MESSAGE_SIZE
self._response_to: Optional[int] = None
self._closed = asyncio.get_running_loop().create_future()
self._pending_messages: collections.deque[Future[Any]] = collections.deque()
self._done_messages: collections.deque[Future[Any]] = collections.deque()
def settimeout(self, timeout: float | None) -> None:
self._timeout = timeout
@property
def gettimeout(self) -> float | None:
"""The configured timeout for the socket that underlies our protocol pair."""
return self._timeout
def connection_made(self, transport: BaseTransport) -> None:
"""Called exactly once when a connection is made.
The transport argument is the transport representing the write side of the connection.
@ -511,13 +323,6 @@ class PyMongoProtocol(BufferedProtocol):
self.transport = transport # type: ignore[assignment]
self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)
async def write(self, message: bytes) -> None:
"""Write a message to this connection's transport."""
if self.transport.is_closing():
raise OSError("Connection is closed")
self.transport.write(message)
self.transport.resume_reading()
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
"""Read a single MongoDB Wire Protocol message from this connection."""
if self.transport:
@ -660,7 +465,7 @@ class PyMongoProtocol(BufferedProtocol):
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
return op_code, compressor_id
def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
pending = list(self._pending_messages)
for msg in pending:
if not msg.done():
@ -670,21 +475,92 @@ class PyMongoProtocol(BufferedProtocol):
msg.set_exception(exc)
self._done_messages.append(msg)
def close(self, exc: Optional[Exception] = None) -> None:
self.transport.abort()
self._resolve_pending_messages(exc)
self._connection_lost = True
def connection_lost(self, exc: Optional[Exception] = None) -> None:
self._resolve_pending_messages(exc)
if not self._closed.done():
self._closed.set_result(None)
class PyMongoKMSProtocol(PyMongoBaseProtocol):
def __init__(self, timeout: Optional[float] = None):
super().__init__(timeout)
self._buffers: collections.deque[memoryview[bytes]] = collections.deque()
self._bytes_ready = 0
self._pending_reads: collections.deque[int] = collections.deque()
self._pending_listeners: collections.deque[Future[Any]] = collections.deque()
async def wait_closed(self) -> None:
await self._closed
def connection_made(self, transport: BaseTransport) -> None:
"""Called exactly once when a connection is made.
The transport argument is the transport representing the write side of the connection.
"""
self.transport = transport # type: ignore[assignment]
def data_received(self, data: bytes) -> None:
if self._connection_lost:
return
self._bytes_ready += len(data)
self._buffers.append(memoryview(data))
if not len(self._pending_reads):
return
bytes_needed = self._pending_reads.popleft()
data = self._read(bytes_needed)
waiter = self._pending_listeners.popleft()
waiter.set_result(data)
async def read(self, bytes_needed: int) -> bytes:
"""Read up to the requested bytes from this connection."""
# Note: all reads are "up-to" bytes_needed because we don't know if the kms_context
# has processed a Content-Length header and is requesting a response or not.
# Wait for other listeners first.
if len(self._pending_listeners):
await asyncio.gather(*self._pending_listeners)
# If there are bytes ready, then there is no need to wait further.
if self._bytes_ready > 0:
return self._read(bytes_needed)
if self.transport:
try:
self.transport.resume_reading()
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
except AttributeError:
raise OSError("connection is already closed") from None
if self.transport and self.transport.is_closing():
raise OSError("connection is already closed")
self._pending_reads.append(bytes_needed)
read_waiter = asyncio.get_running_loop().create_future()
self._pending_listeners.append(read_waiter)
return await read_waiter
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
while self._pending_listeners:
fut = self._pending_listeners.popleft()
fut.set_result(b"")
def _read(self, bytes_needed: int) -> memoryview:
"""Read bytes."""
# Send the bytes to the listener.
if self._bytes_ready < bytes_needed:
bytes_needed = self._bytes_ready
self._bytes_ready -= bytes_needed
output_buf = bytearray(bytes_needed)
n_remaining = bytes_needed
out_index = 0
while n_remaining > 0:
buffer = self._buffers.popleft()
buf_size = len(buffer)
# if we didn't exhaust the buffer, read the partial data and return the buffer.
if buf_size > n_remaining:
output_buf[out_index : n_remaining + out_index] = buffer[:n_remaining]
buffer = buffer[n_remaining:]
n_remaining = 0
self._buffers.appendleft(buffer)
# otherwise exhaust the buffer.
else:
output_buf[out_index : out_index + buf_size] = buffer[:]
out_index += buf_size
n_remaining -= buf_size
return memoryview(output_buf)
async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
try:
await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
except asyncio.TimeoutError as exc:
@ -692,12 +568,18 @@ async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
raise socket.timeout("timed out") from exc
async def async_receive_message(
conn: AsyncConnection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int) -> bytes:
"""Receive raw bytes from the kms connection."""
def callback(result: Any) -> bytes:
return result
return await _async_receive_data(conn, callback, bytes_needed)
async def _async_receive_data(
conn: AsyncBaseConnection, callback: Callable[..., Any], *args: Any
) -> Any:
timeout: Optional[Union[float, int]]
timeout = conn.conn.gettimeout
if _csot.get_timeout():
@ -713,8 +595,8 @@ async def async_receive_message(
# timeouts on AWS Lambda and other FaaS environments.
timeout = max(deadline - time.monotonic(), 0)
read_task = create_task(conn.conn.get_conn.read(*args))
cancellation_task = create_task(_poll_cancellation(conn))
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
tasks = [read_task, cancellation_task]
try:
done, pending = await asyncio.wait(
@ -727,14 +609,7 @@ async def async_receive_message(
if len(done) == 0:
raise socket.timeout("timed out")
if read_task in done:
data, op_code = read_task.result()
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
return callback(read_task.result())
raise _OperationCancelled("operation cancelled")
except asyncio.CancelledError:
for task in tasks:
@ -743,6 +618,31 @@ async def async_receive_message(
raise
async def async_receive_message(
conn: AsyncConnection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
def callback(result: Any) -> _OpMsg | _OpReply:
data, op_code = result
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
return await _async_receive_data(conn, callback, request_id, max_message_size)
def receive_kms(conn: BaseConnection, bytes_needed: int) -> bytes:
"""Receive raw bytes from the kms connection."""
return conn.conn.sock.recv(bytes_needed)
def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:

View File

@ -16,7 +16,6 @@
from __future__ import annotations
import asyncio
import functools
import socket
import ssl
import sys
@ -25,7 +24,6 @@ from typing import (
Any,
NoReturn,
Optional,
Union,
)
from pymongo import _csot
@ -37,13 +35,17 @@ from pymongo.errors import ( # type:ignore[attr-defined]
_CertificateError,
)
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
from pymongo.network_layer import (
AsyncNetworkingInterface,
NetworkingInterface,
PyMongoBaseProtocol,
PyMongoProtocol,
)
from pymongo.pool_options import PoolOptions
from pymongo.ssl_support import PYSSLError, SSLError, _has_sni
SSLErrors = (PYSSLError, SSLError)
if TYPE_CHECKING:
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
try:
@ -244,64 +246,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
raise OSError("getaddrinfo failed")
async def _async_configured_socket(
address: _Address, options: PoolOptions
) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a raw configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if _has_sni(False):
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(
None,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore]
)
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, *SSLErrors) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
async def _configured_protocol_interface(
address: _Address, options: PoolOptions
address: _Address,
options: PoolOptions,
protocol_kls: type[PyMongoBaseProtocol] = PyMongoProtocol,
) -> AsyncNetworkingInterface:
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
@ -316,7 +264,7 @@ async def _configured_protocol_interface(
if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
lambda: protocol_kls(timeout=timeout), sock=sock
)
)
@ -325,7 +273,7 @@ async def _configured_protocol_interface(
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: PyMongoProtocol(timeout=timeout),
lambda: protocol_kls(timeout=timeout),
sock=sock,
server_hostname=host,
ssl=ssl_context,
@ -425,56 +373,9 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
raise OSError("getaddrinfo failed")
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a raw configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if _has_sni(True):
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
else:
ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, *SSLErrors) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
def _configured_socket_interface(
address: _Address, options: PoolOptions, *args: Any
) -> NetworkingInterface:
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.

View File

@ -71,11 +71,11 @@ from pymongo.errors import (
ServerSelectionTimeoutError,
)
from pymongo.helpers_shared import _get_timeout_details
from pymongo.network_layer import sendall
from pymongo.network_layer import PyMongoKMSProtocol, receive_kms, sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_configured_socket,
_configured_socket_interface,
_raise_connection_failure,
)
from pymongo.read_concern import ReadConcern
@ -85,6 +85,7 @@ from pymongo.synchronous.collection import Collection
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import BaseConnection
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
from pymongo.write_concern import WriteConcern
@ -92,10 +93,8 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
_IS_SYNC = True
_HTTPS_PORT = 443
@ -110,9 +109,10 @@ _DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions(
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
def _connect_kms(address: _Address, opts: PoolOptions) -> BaseConnection:
try:
return _configured_socket(address, opts)
interface = _configured_socket_interface(address, opts, PyMongoKMSProtocol)
return BaseConnection(interface, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
@ -197,18 +197,11 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
try:
conn = _connect_kms(address, opts)
try:
sendall(conn, message)
sendall(conn.conn.get_conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
if _IS_SYNC:
data = conn.recv(kms_context.bytes_needed)
else:
from pymongo.network_layer import ( # type: ignore[attr-defined]
receive_data_socket,
)
data = receive_data_socket(conn, kms_context.bytes_needed)
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = receive_kms(conn, kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
@ -227,7 +220,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
)
finally:
conn.close()
conn.close_conn(None)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:

View File

@ -123,74 +123,19 @@ except ImportError:
_IS_SYNC = True
class Connection:
"""Store a connection with some metadata.
class BaseConnection:
"""A base connection object for server and kms connections."""
:param conn: a raw connection object
:param pool: a Pool instance
:param address: the server's (host, port)
:param id: the id of this socket in it's pool
:param is_sdam: SDAM connections do not call hello on creation
"""
def __init__(
self,
conn: NetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
is_sdam: bool,
):
self.pool_ref = weakref.ref(pool)
def __init__(self, conn: NetworkingInterface, opts: PoolOptions):
self.conn = conn
self.address = address
self.id = id
self.is_sdam = is_sdam
self.closed = False
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
self.is_writable: bool = False
self.max_wire_version = MAX_WIRE_VERSION
self.max_bson_size = MAX_BSON_SIZE
self.max_message_size = MAX_MESSAGE_SIZE
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
self.supports_sessions = False
self.hello_ok: bool = False
self.is_mongos = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.enabled_for_logging = pool.enabled_for_logging
self.compression_settings = pool.opts._compression_settings
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
self.socket_checker: SocketChecker = SocketChecker()
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs: Optional[list[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
self.cancel_context: _CancellationContext = _CancellationContext()
self.opts = pool.opts
self.more_to_come: bool = False
# For load balancer support.
self.service_id: Optional[ObjectId] = None
self.server_connection_id: Optional[int] = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned_txn = False
self.pinned_cursor = False
self.active = False
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
self._client_id = pool._client_id
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
self.is_sdam = False
self.closed = False
self.last_timeout: float | None = None
self.more_to_come = False
self.opts = opts
self.max_wire_version = -1
def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
@ -219,17 +164,111 @@ class Connection:
formatted = format_timeout_details(timeout_details)
# CSOT: raise an error without running the command since we know it will time out.
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
raise ExecutionTimeout(
errmsg,
50,
{"ok": 0, "errmsg": errmsg, "code": 50},
self.max_wire_version,
)
if self.max_wire_version != -1:
raise ExecutionTimeout(
errmsg,
50,
{"ok": 0, "errmsg": errmsg, "code": 50},
self.max_wire_version,
)
else:
raise TimeoutError(errmsg)
if cmd is not None:
cmd["maxTimeMS"] = int(max_time_ms * 1000)
self.set_conn_timeout(timeout)
return timeout
def close_conn(self, reason: Optional[str]) -> None:
"""Close this connection with a reason."""
if self.closed:
return
self._close_conn()
def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
self.closed = True
self.cancel_context.cancel()
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
class Connection(BaseConnection):
"""Store a connection with some metadata.
:param conn: a raw connection object
:param pool: a Pool instance
:param address: the server's (host, port)
:param id: the id of this socket in it's pool
:param is_sdam: SDAM connections do not call hello on creation
"""
def __init__(
self,
conn: NetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
is_sdam: bool,
):
super().__init__(conn, pool.opts)
self.pool_ref = weakref.ref(pool)
self.address: tuple[str, int] = address
self.id: int = id
self.is_sdam = is_sdam
self.last_checkin_time = time.monotonic()
self.performed_handshake = False
self.is_writable: bool = False
self.max_wire_version = MAX_WIRE_VERSION
self.max_bson_size: int = MAX_BSON_SIZE
self.max_message_size: int = MAX_MESSAGE_SIZE
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
self.supports_sessions = False
self.hello_ok: bool = False
self.is_mongos: bool = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.enabled_for_logging = pool.enabled_for_logging
self.compression_settings = pool.opts._compression_settings
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
self.oidc_token_gen_id: Optional[int] = None
# Support for mechanism negotiation on the initial handshake.
self.negotiated_mechs: Optional[list[str]] = None
self.auth_ctx: Optional[_AuthContext] = None
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
# For load balancer support.
self.service_id: Optional[ObjectId] = None
self.server_connection_id: Optional[int] = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned_txn = False
self.pinned_cursor = False
self.active = False
self.last_timeout = self.opts.socket_timeout
self.connect_rtt = 0.0
self._client_id = pool._client_id
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
def pin_txn(self) -> None:
self.pinned_txn = True
assert not self.pinned_cursor
@ -571,26 +610,6 @@ class Connection:
error=reason,
)
def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
self.closed = True
self.cancel_context.cancel()
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
def send_cluster_time(
self,
command: MutableMapping[str, Any],

View File

@ -120,9 +120,9 @@ replacements = {
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
"async_receive_kms": "receive_kms",
"AsyncNetworkingInterface": "NetworkingInterface",
"_configured_protocol_interface": "_configured_socket_interface",
"_async_configured_socket": "_configured_socket",
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",