Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
a216788aff
@ -1,14 +1,6 @@
|
||||
Changelog
|
||||
=========
|
||||
|
||||
Changes in Version 4.15.1 (XXXX/XX/XX)
|
||||
--------------------------------------
|
||||
|
||||
Version 4.15.1 is a bug fix release.
|
||||
|
||||
- Fixed a bug in ``AsyncMongoClient`` that caused a
|
||||
``ServerSelectionTimeoutError`` when used with ``uvicorn``, ``FastAPI``, or ``uvloop``.
|
||||
|
||||
Changes in Version 4.15.0 (2025/09/10)
|
||||
--------------------------------------
|
||||
|
||||
|
||||
@ -64,7 +64,6 @@ from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import AsyncBaseConnection
|
||||
from pymongo.common import CONNECT_TIMEOUT
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts, TextOpts
|
||||
@ -77,11 +76,11 @@ from pymongo.errors import (
|
||||
ServerSelectionTimeoutError,
|
||||
)
|
||||
from pymongo.helpers_shared import _get_timeout_details
|
||||
from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall
|
||||
from pymongo.network_layer import async_socket_sendall
|
||||
from pymongo.operations import UpdateOne
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.pool_shared import (
|
||||
_configured_protocol_interface,
|
||||
_async_configured_socket,
|
||||
_raise_connection_failure,
|
||||
)
|
||||
from pymongo.read_concern import ReadConcern
|
||||
@ -94,8 +93,10 @@ from pymongo.write_concern import WriteConcern
|
||||
if TYPE_CHECKING:
|
||||
from pymongocrypt.mongocrypt import MongoCryptKmsContext
|
||||
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.typings import _Address
|
||||
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_HTTPS_PORT = 443
|
||||
@ -110,10 +111,9 @@ _DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions(
|
||||
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
|
||||
|
||||
|
||||
async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection:
|
||||
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
|
||||
try:
|
||||
interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol)
|
||||
return AsyncBaseConnection(interface, opts)
|
||||
return await _async_configured_socket(address, opts)
|
||||
except Exception as exc:
|
||||
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
|
||||
|
||||
@ -198,11 +198,19 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
try:
|
||||
conn = await _connect_kms(address, opts)
|
||||
try:
|
||||
await async_sendall(conn.conn.get_conn, message)
|
||||
await async_socket_sendall(conn, message)
|
||||
while kms_context.bytes_needed > 0:
|
||||
# CSOT: update timeout.
|
||||
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data = await async_receive_kms(conn, kms_context.bytes_needed)
|
||||
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data: memoryview | bytes
|
||||
if _IS_SYNC:
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
else:
|
||||
from pymongo.network_layer import ( # type: ignore[attr-defined]
|
||||
async_receive_data_socket,
|
||||
)
|
||||
|
||||
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
|
||||
if not data:
|
||||
raise OSError("KMS connection closed")
|
||||
kms_context.feed(data)
|
||||
@ -221,7 +229,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
|
||||
)
|
||||
finally:
|
||||
await conn.close_conn(None)
|
||||
conn.close()
|
||||
except MongoCryptError:
|
||||
raise # Propagate MongoCryptError errors directly.
|
||||
except Exception as exc:
|
||||
|
||||
@ -123,19 +123,74 @@ except ImportError:
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class AsyncBaseConnection:
|
||||
"""A base connection object for server and kms connections."""
|
||||
class AsyncConnection:
|
||||
"""Store a connection with some metadata.
|
||||
|
||||
def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions):
|
||||
:param conn: a raw connection object
|
||||
:param pool: a Pool instance
|
||||
:param address: the server's (host, port)
|
||||
:param id: the id of this socket in it's pool
|
||||
:param is_sdam: SDAM connections do not call hello on creation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: AsyncNetworkingInterface,
|
||||
pool: Pool,
|
||||
address: tuple[str, int],
|
||||
id: int,
|
||||
is_sdam: bool,
|
||||
):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.conn = conn
|
||||
self.socket_checker: SocketChecker = SocketChecker()
|
||||
self.cancel_context: _CancellationContext = _CancellationContext()
|
||||
self.is_sdam = False
|
||||
self.address = address
|
||||
self.id = id
|
||||
self.is_sdam = is_sdam
|
||||
self.closed = False
|
||||
self.last_timeout: float | None = None
|
||||
self.more_to_come = False
|
||||
self.opts = opts
|
||||
self.max_wire_version = -1
|
||||
self.last_checkin_time = time.monotonic()
|
||||
self.performed_handshake = False
|
||||
self.is_writable: bool = False
|
||||
self.max_wire_version = MAX_WIRE_VERSION
|
||||
self.max_bson_size = MAX_BSON_SIZE
|
||||
self.max_message_size = MAX_MESSAGE_SIZE
|
||||
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
|
||||
self.supports_sessions = False
|
||||
self.hello_ok: bool = False
|
||||
self.is_mongos = False
|
||||
self.op_msg_enabled = False
|
||||
self.listeners = pool.opts._event_listeners
|
||||
self.enabled_for_cmap = pool.enabled_for_cmap
|
||||
self.enabled_for_logging = pool.enabled_for_logging
|
||||
self.compression_settings = pool.opts._compression_settings
|
||||
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
|
||||
self.socket_checker: SocketChecker = SocketChecker()
|
||||
self.oidc_token_gen_id: Optional[int] = None
|
||||
# Support for mechanism negotiation on the initial handshake.
|
||||
self.negotiated_mechs: Optional[list[str]] = None
|
||||
self.auth_ctx: Optional[_AuthContext] = None
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
self.pool_gen = pool.gen
|
||||
self.generation = self.pool_gen.get_overall()
|
||||
self.ready = False
|
||||
self.cancel_context: _CancellationContext = _CancellationContext()
|
||||
self.opts = pool.opts
|
||||
self.more_to_come: bool = False
|
||||
# For load balancer support.
|
||||
self.service_id: Optional[ObjectId] = None
|
||||
self.server_connection_id: Optional[int] = None
|
||||
# When executing a transaction in load balancing mode, this flag is
|
||||
# set to true to indicate that the session now owns the connection.
|
||||
self.pinned_txn = False
|
||||
self.pinned_cursor = False
|
||||
self.active = False
|
||||
self.last_timeout = self.opts.socket_timeout
|
||||
self.connect_rtt = 0.0
|
||||
self._client_id = pool._client_id
|
||||
self.creation_time = time.monotonic()
|
||||
# For gossiping $clusterTime from the connection handshake to the client.
|
||||
self._cluster_time = None
|
||||
|
||||
def set_conn_timeout(self, timeout: Optional[float]) -> None:
|
||||
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
|
||||
@ -164,111 +219,17 @@ class AsyncBaseConnection:
|
||||
formatted = format_timeout_details(timeout_details)
|
||||
# CSOT: raise an error without running the command since we know it will time out.
|
||||
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
|
||||
if self.max_wire_version != -1:
|
||||
raise ExecutionTimeout(
|
||||
errmsg,
|
||||
50,
|
||||
{"ok": 0, "errmsg": errmsg, "code": 50},
|
||||
self.max_wire_version,
|
||||
)
|
||||
else:
|
||||
raise TimeoutError(errmsg)
|
||||
raise ExecutionTimeout(
|
||||
errmsg,
|
||||
50,
|
||||
{"ok": 0, "errmsg": errmsg, "code": 50},
|
||||
self.max_wire_version,
|
||||
)
|
||||
if cmd is not None:
|
||||
cmd["maxTimeMS"] = int(max_time_ms * 1000)
|
||||
self.set_conn_timeout(timeout)
|
||||
return timeout
|
||||
|
||||
async def close_conn(self, reason: Optional[str]) -> None:
|
||||
"""Close this connection with a reason."""
|
||||
if self.closed:
|
||||
return
|
||||
await self._close_conn()
|
||||
|
||||
async def _close_conn(self) -> None:
|
||||
"""Close this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
self.closed = True
|
||||
self.cancel_context.cancel()
|
||||
# Note: We catch exceptions to avoid spurious errors on interpreter
|
||||
# shutdown.
|
||||
try:
|
||||
await self.conn.close()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def conn_closed(self) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
if _IS_SYNC:
|
||||
return self.socket_checker.socket_closed(self.conn.get_conn)
|
||||
else:
|
||||
return self.conn.is_closing()
|
||||
|
||||
|
||||
class AsyncConnection(AsyncBaseConnection):
|
||||
"""Store a connection with some metadata.
|
||||
|
||||
:param conn: a raw connection object
|
||||
:param pool: a Pool instance
|
||||
:param address: the server's (host, port)
|
||||
:param id: the id of this socket in it's pool
|
||||
:param is_sdam: SDAM connections do not call hello on creation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: AsyncNetworkingInterface,
|
||||
pool: Pool,
|
||||
address: tuple[str, int],
|
||||
id: int,
|
||||
is_sdam: bool,
|
||||
):
|
||||
super().__init__(conn, pool.opts)
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.address: tuple[str, int] = address
|
||||
self.id: int = id
|
||||
self.is_sdam = is_sdam
|
||||
self.last_checkin_time = time.monotonic()
|
||||
self.performed_handshake = False
|
||||
self.is_writable: bool = False
|
||||
self.max_wire_version = MAX_WIRE_VERSION
|
||||
self.max_bson_size: int = MAX_BSON_SIZE
|
||||
self.max_message_size: int = MAX_MESSAGE_SIZE
|
||||
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
|
||||
self.supports_sessions = False
|
||||
self.hello_ok: bool = False
|
||||
self.is_mongos: bool = False
|
||||
self.op_msg_enabled = False
|
||||
self.listeners = pool.opts._event_listeners
|
||||
self.enabled_for_cmap = pool.enabled_for_cmap
|
||||
self.enabled_for_logging = pool.enabled_for_logging
|
||||
self.compression_settings = pool.opts._compression_settings
|
||||
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
|
||||
self.oidc_token_gen_id: Optional[int] = None
|
||||
# Support for mechanism negotiation on the initial handshake.
|
||||
self.negotiated_mechs: Optional[list[str]] = None
|
||||
self.auth_ctx: Optional[_AuthContext] = None
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
self.pool_gen = pool.gen
|
||||
self.generation = self.pool_gen.get_overall()
|
||||
self.ready = False
|
||||
# For load balancer support.
|
||||
self.service_id: Optional[ObjectId] = None
|
||||
self.server_connection_id: Optional[int] = None
|
||||
# When executing a transaction in load balancing mode, this flag is
|
||||
# set to true to indicate that the session now owns the connection.
|
||||
self.pinned_txn = False
|
||||
self.pinned_cursor = False
|
||||
self.active = False
|
||||
self.last_timeout = self.opts.socket_timeout
|
||||
self.connect_rtt = 0.0
|
||||
self._client_id = pool._client_id
|
||||
self.creation_time = time.monotonic()
|
||||
# For gossiping $clusterTime from the connection handshake to the client.
|
||||
self._cluster_time = None
|
||||
|
||||
def pin_txn(self) -> None:
|
||||
self.pinned_txn = True
|
||||
assert not self.pinned_cursor
|
||||
@ -612,6 +573,26 @@ class AsyncConnection(AsyncBaseConnection):
|
||||
error=reason,
|
||||
)
|
||||
|
||||
async def _close_conn(self) -> None:
|
||||
"""Close this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
self.closed = True
|
||||
self.cancel_context.cancel()
|
||||
# Note: We catch exceptions to avoid spurious errors on interpreter
|
||||
# shutdown.
|
||||
try:
|
||||
await self.conn.close()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def conn_closed(self) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
if _IS_SYNC:
|
||||
return self.socket_checker.socket_closed(self.conn.get_conn)
|
||||
else:
|
||||
return self.conn.is_closing()
|
||||
|
||||
def send_cluster_time(
|
||||
self,
|
||||
command: MutableMapping[str, Any],
|
||||
|
||||
@ -22,11 +22,10 @@ import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from asyncio import BaseProtocol, BaseTransport, BufferedProtocol, Future, Transport
|
||||
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
@ -39,30 +38,208 @@ from pymongo.errors import ProtocolError, _OperationCancelled
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncBaseConnection, AsyncConnection
|
||||
try:
|
||||
from ssl import SSLError, SSLSocket
|
||||
|
||||
_HAVE_SSL = True
|
||||
except ImportError:
|
||||
_HAVE_SSL = False
|
||||
|
||||
try:
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.synchronous.pool import BaseConnection, Connection
|
||||
|
||||
_HAVE_PYOPENSSL = True
|
||||
except ImportError:
|
||||
_HAVE_PYOPENSSL = False
|
||||
_sslConn = SSLSocket # type: ignore[assignment, misc]
|
||||
|
||||
from pymongo.ssl_support import (
|
||||
BLOCKING_IO_LOOKUP_ERROR,
|
||||
BLOCKING_IO_READ_ERROR,
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
_UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
_POLL_TIMEOUT = 0.5
|
||||
_PYPY = "PyPy" in sys.version
|
||||
_WINDOWS = sys.platform == "win32"
|
||||
|
||||
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
|
||||
BLOCKING_IO_ERRORS = (
|
||||
BlockingIOError,
|
||||
*ssl_support.BLOCKING_IO_LOOKUP_ERROR,
|
||||
*ssl_support.BLOCKING_IO_ERRORS,
|
||||
)
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, *BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
# These socket-based I/O methods are for KMS requests and any other network operations that do not use
|
||||
# the MongoDB wire protocol
|
||||
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
timeout = sock.gettimeout()
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
|
||||
else:
|
||||
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
|
||||
except asyncio.TimeoutError as exc:
|
||||
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
|
||||
raise socket.timeout("timed out") from exc
|
||||
finally:
|
||||
sock.settimeout(timeout)
|
||||
|
||||
|
||||
if sys.platform != "win32":
|
||||
|
||||
async def _async_socket_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
sent = 0
|
||||
|
||||
def _is_ready(fut: Future[Any]) -> None:
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
|
||||
while sent < len(buf):
|
||||
try:
|
||||
sent += sock.send(view[sent:]) # type:ignore[arg-type]
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = sock.fileno()
|
||||
# Check for closed socket.
|
||||
if fd == -1:
|
||||
raise SSLError("Underlying socket has been closed") from None
|
||||
if isinstance(exc, BLOCKING_IO_READ_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_writer(fd)
|
||||
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
loop.remove_writer(fd)
|
||||
|
||||
async def _async_socket_receive_ssl(
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
|
||||
def _is_ready(fut: Future[Any]) -> None:
|
||||
if fut.done():
|
||||
return
|
||||
fut.set_result(None)
|
||||
|
||||
while total_read < length:
|
||||
try:
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
# KMS responses update their expected size after the first batch, stop reading after one loop
|
||||
if once:
|
||||
return mv[:read]
|
||||
total_read += read
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = conn.fileno()
|
||||
# Check for closed socket.
|
||||
if fd == -1:
|
||||
raise SSLError("Underlying socket has been closed") from None
|
||||
if isinstance(exc, BLOCKING_IO_READ_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
try:
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_writer(fd)
|
||||
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
|
||||
fut = loop.create_future()
|
||||
loop.add_reader(fd, _is_ready, fut)
|
||||
try:
|
||||
loop.add_writer(fd, _is_ready, fut)
|
||||
await fut
|
||||
finally:
|
||||
loop.remove_reader(fd)
|
||||
loop.remove_writer(fd)
|
||||
return mv
|
||||
|
||||
else:
|
||||
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
|
||||
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
|
||||
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
|
||||
async def _async_socket_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
total_length = len(buf)
|
||||
total_sent = 0
|
||||
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
|
||||
# down to 1ms.
|
||||
backoff = 0.001
|
||||
while total_sent < total_length:
|
||||
try:
|
||||
sent = sock.send(view[total_sent:])
|
||||
except BLOCKING_IO_ERRORS:
|
||||
await asyncio.sleep(backoff)
|
||||
sent = 0
|
||||
if sent > 0:
|
||||
backoff = max(backoff / 2, 0.001)
|
||||
else:
|
||||
backoff = min(backoff * 2, 0.512)
|
||||
total_sent += sent
|
||||
|
||||
async def _async_socket_receive_ssl(
|
||||
conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
# Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success
|
||||
# down to 1ms.
|
||||
backoff = 0.001
|
||||
while total_read < length:
|
||||
try:
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
# KMS responses update their expected size after the first batch, stop reading after one loop
|
||||
if once:
|
||||
return mv[:read]
|
||||
except BLOCKING_IO_ERRORS:
|
||||
await asyncio.sleep(backoff)
|
||||
read = 0
|
||||
if read > 0:
|
||||
backoff = max(backoff / 2, 0.001)
|
||||
else:
|
||||
backoff = min(backoff * 2, 0.512)
|
||||
total_read += read
|
||||
return mv
|
||||
|
||||
|
||||
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
sock.sendall(buf)
|
||||
|
||||
|
||||
async def _poll_cancellation(conn: AsyncBaseConnection) -> None:
|
||||
async def _poll_cancellation(conn: AsyncConnection) -> None:
|
||||
while True:
|
||||
if conn.cancel_context.cancelled:
|
||||
return
|
||||
@ -70,7 +247,49 @@ async def _poll_cancellation(conn: AsyncBaseConnection) -> None:
|
||||
await asyncio.sleep(_POLL_TIMEOUT)
|
||||
|
||||
|
||||
def wait_for_read(conn: BaseConnection, deadline: Optional[float]) -> None:
|
||||
async def async_receive_data_socket(
|
||||
sock: Union[socket.socket, _sslConn], length: int
|
||||
) -> memoryview:
|
||||
sock_timeout = sock.gettimeout()
|
||||
timeout = sock_timeout
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
return await asyncio.wait_for(
|
||||
_async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
return await asyncio.wait_for(
|
||||
_async_socket_receive(sock, length, loop), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise socket.timeout("timed out") from err
|
||||
finally:
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def _async_socket_receive(
|
||||
conn: socket.socket, length: int, loop: AbstractEventLoop
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
|
||||
if chunk_length == 0:
|
||||
raise OSError("connection closed")
|
||||
bytes_read += chunk_length
|
||||
return mv
|
||||
|
||||
|
||||
_PYPY = "PyPy" in sys.version
|
||||
_WINDOWS = sys.platform == "win32"
|
||||
|
||||
|
||||
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn.sock
|
||||
timed_out = False
|
||||
@ -103,7 +322,7 @@ def wait_for_read(conn: BaseConnection, deadline: Optional[float]) -> None:
|
||||
raise socket.timeout("timed out")
|
||||
|
||||
|
||||
def receive_data(conn: BaseConnection, length: int, deadline: Optional[float]) -> memoryview:
|
||||
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
|
||||
buf = bytearray(length)
|
||||
mv = memoryview(buf)
|
||||
bytes_read = 0
|
||||
@ -193,7 +412,7 @@ class NetworkingInterfaceBase:
|
||||
|
||||
|
||||
class AsyncNetworkingInterface(NetworkingInterfaceBase):
|
||||
def __init__(self, conn: tuple[Transport, PyMongoBaseProtocol]):
|
||||
def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
|
||||
super().__init__(conn)
|
||||
|
||||
@property
|
||||
@ -211,7 +430,7 @@ class AsyncNetworkingInterface(NetworkingInterfaceBase):
|
||||
return self.conn[0].is_closing()
|
||||
|
||||
@property
|
||||
def get_conn(self) -> PyMongoBaseProtocol:
|
||||
def get_conn(self) -> PyMongoProtocol:
|
||||
return self.conn[1]
|
||||
|
||||
@property
|
||||
@ -250,51 +469,9 @@ class NetworkingInterface(NetworkingInterfaceBase):
|
||||
return self.conn.recv_into(buffer)
|
||||
|
||||
|
||||
class PyMongoBaseProtocol(BaseProtocol):
|
||||
class PyMongoProtocol(BufferedProtocol):
|
||||
def __init__(self, timeout: Optional[float] = None):
|
||||
self.transport: Transport = None # type: ignore[assignment]
|
||||
self._timeout = timeout
|
||||
self._closed = asyncio.get_running_loop().create_future()
|
||||
self._connection_lost = False
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self._timeout = timeout
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> float | None:
|
||||
"""The configured timeout for the socket that underlies our protocol pair."""
|
||||
return self._timeout
|
||||
|
||||
def close(self, exc: Optional[Exception] = None) -> None:
|
||||
self.transport.abort()
|
||||
self._resolve_pending(exc)
|
||||
self._connection_lost = True
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception] = None) -> None:
|
||||
self._resolve_pending(exc)
|
||||
if not self._closed.done():
|
||||
self._closed.set_result(None)
|
||||
|
||||
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
|
||||
pass
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
await self._closed
|
||||
|
||||
async def write(self, message: bytes) -> None:
|
||||
"""Write a message to this connection's transport."""
|
||||
if self.transport.is_closing():
|
||||
raise OSError("Connection is closed")
|
||||
self.transport.write(message)
|
||||
self.transport.resume_reading()
|
||||
|
||||
async def read(self, *args: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol):
|
||||
def __init__(self, timeout: Optional[float] = None):
|
||||
super().__init__(timeout)
|
||||
# Each message is reader in 2-3 parts: header, compression header, and message body
|
||||
# The message buffer is allocated after the header is read.
|
||||
self._header = memoryview(bytearray(16))
|
||||
@ -308,14 +485,25 @@ class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol):
|
||||
self._expecting_compression = False
|
||||
self._message_size = 0
|
||||
self._op_code = 0
|
||||
self._connection_lost = False
|
||||
self._read_waiter: Optional[Future[Any]] = None
|
||||
self._timeout = timeout
|
||||
self._is_compressed = False
|
||||
self._compressor_id: Optional[int] = None
|
||||
self._max_message_size = MAX_MESSAGE_SIZE
|
||||
self._response_to: Optional[int] = None
|
||||
self._closed = asyncio.get_running_loop().create_future()
|
||||
self._pending_messages: collections.deque[Future[Any]] = collections.deque()
|
||||
self._done_messages: collections.deque[Future[Any]] = collections.deque()
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self._timeout = timeout
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> float | None:
|
||||
"""The configured timeout for the socket that underlies our protocol pair."""
|
||||
return self._timeout
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called exactly once when a connection is made.
|
||||
The transport argument is the transport representing the write side of the connection.
|
||||
@ -323,6 +511,13 @@ class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol):
|
||||
self.transport = transport # type: ignore[assignment]
|
||||
self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)
|
||||
|
||||
async def write(self, message: bytes) -> None:
|
||||
"""Write a message to this connection's transport."""
|
||||
if self.transport.is_closing():
|
||||
raise OSError("Connection is closed")
|
||||
self.transport.write(message)
|
||||
self.transport.resume_reading()
|
||||
|
||||
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
|
||||
"""Read a single MongoDB Wire Protocol message from this connection."""
|
||||
if self.transport:
|
||||
@ -465,7 +660,7 @@ class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol):
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
|
||||
return op_code, compressor_id
|
||||
|
||||
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
|
||||
def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
|
||||
pending = list(self._pending_messages)
|
||||
for msg in pending:
|
||||
if not msg.done():
|
||||
@ -475,92 +670,21 @@ class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol):
|
||||
msg.set_exception(exc)
|
||||
self._done_messages.append(msg)
|
||||
|
||||
def close(self, exc: Optional[Exception] = None) -> None:
|
||||
self.transport.abort()
|
||||
self._resolve_pending_messages(exc)
|
||||
self._connection_lost = True
|
||||
|
||||
class PyMongoKMSProtocol(PyMongoBaseProtocol):
|
||||
def __init__(self, timeout: Optional[float] = None):
|
||||
super().__init__(timeout)
|
||||
self._buffers: collections.deque[memoryview[bytes]] = collections.deque()
|
||||
self._bytes_ready = 0
|
||||
self._pending_reads: collections.deque[int] = collections.deque()
|
||||
self._pending_listeners: collections.deque[Future[Any]] = collections.deque()
|
||||
def connection_lost(self, exc: Optional[Exception] = None) -> None:
|
||||
self._resolve_pending_messages(exc)
|
||||
if not self._closed.done():
|
||||
self._closed.set_result(None)
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called exactly once when a connection is made.
|
||||
The transport argument is the transport representing the write side of the connection.
|
||||
"""
|
||||
self.transport = transport # type: ignore[assignment]
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
if self._connection_lost:
|
||||
return
|
||||
|
||||
self._bytes_ready += len(data)
|
||||
self._buffers.append(memoryview(data))
|
||||
|
||||
if not len(self._pending_reads):
|
||||
return
|
||||
|
||||
bytes_needed = self._pending_reads.popleft()
|
||||
data = self._read(bytes_needed)
|
||||
waiter = self._pending_listeners.popleft()
|
||||
waiter.set_result(data)
|
||||
|
||||
async def read(self, bytes_needed: int) -> bytes:
|
||||
"""Read up to the requested bytes from this connection."""
|
||||
# Note: all reads are "up-to" bytes_needed because we don't know if the kms_context
|
||||
# has processed a Content-Length header and is requesting a response or not.
|
||||
# Wait for other listeners first.
|
||||
if len(self._pending_listeners):
|
||||
await asyncio.gather(*self._pending_listeners)
|
||||
# If there are bytes ready, then there is no need to wait further.
|
||||
if self._bytes_ready > 0:
|
||||
return self._read(bytes_needed)
|
||||
if self.transport:
|
||||
try:
|
||||
self.transport.resume_reading()
|
||||
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
|
||||
except AttributeError:
|
||||
raise OSError("connection is already closed") from None
|
||||
if self.transport and self.transport.is_closing():
|
||||
raise OSError("connection is already closed")
|
||||
self._pending_reads.append(bytes_needed)
|
||||
read_waiter = asyncio.get_running_loop().create_future()
|
||||
self._pending_listeners.append(read_waiter)
|
||||
return await read_waiter
|
||||
|
||||
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
|
||||
while self._pending_listeners:
|
||||
fut = self._pending_listeners.popleft()
|
||||
fut.set_result(b"")
|
||||
|
||||
def _read(self, bytes_needed: int) -> bytes:
|
||||
"""Read bytes."""
|
||||
# Send the bytes to the listener.
|
||||
if self._bytes_ready < bytes_needed:
|
||||
bytes_needed = self._bytes_ready
|
||||
self._bytes_ready -= bytes_needed
|
||||
|
||||
output_buf = memoryview(bytearray(bytes_needed))
|
||||
n_remaining = bytes_needed
|
||||
out_index = 0
|
||||
while n_remaining > 0:
|
||||
buffer = self._buffers.popleft()
|
||||
buf_size = len(buffer)
|
||||
# if we didn't exhaust the buffer, read the partial data and return the buffer.
|
||||
if buf_size > n_remaining:
|
||||
output_buf[out_index : n_remaining + out_index] = buffer[:n_remaining]
|
||||
buffer = buffer[n_remaining:]
|
||||
n_remaining = 0
|
||||
self._buffers.appendleft(buffer)
|
||||
# otherwise exhaust the buffer.
|
||||
else:
|
||||
output_buf[out_index : out_index + buf_size] = buffer[:]
|
||||
out_index += buf_size
|
||||
n_remaining -= buf_size
|
||||
return bytes(output_buf)
|
||||
async def wait_closed(self) -> None:
|
||||
await self._closed
|
||||
|
||||
|
||||
async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
|
||||
async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
|
||||
except asyncio.TimeoutError as exc:
|
||||
@ -568,18 +692,12 @@ async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
|
||||
raise socket.timeout("timed out") from exc
|
||||
|
||||
|
||||
async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int) -> bytes:
|
||||
"""Receive raw bytes from the kms connection."""
|
||||
|
||||
def callback(result: Any) -> bytes:
|
||||
return result
|
||||
|
||||
return await _async_receive_data(conn, callback, bytes_needed)
|
||||
|
||||
|
||||
async def _async_receive_data(
|
||||
conn: AsyncBaseConnection, callback: Callable[..., Any], *args: Any
|
||||
) -> Any:
|
||||
async def async_receive_message(
|
||||
conn: AsyncConnection,
|
||||
request_id: Optional[int],
|
||||
max_message_size: int = MAX_MESSAGE_SIZE,
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
timeout: Optional[Union[float, int]]
|
||||
timeout = conn.conn.gettimeout
|
||||
if _csot.get_timeout():
|
||||
@ -595,8 +713,8 @@ async def _async_receive_data(
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
timeout = max(deadline - time.monotonic(), 0)
|
||||
|
||||
read_task = create_task(conn.conn.get_conn.read(*args))
|
||||
cancellation_task = create_task(_poll_cancellation(conn))
|
||||
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
|
||||
tasks = [read_task, cancellation_task]
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
@ -609,7 +727,14 @@ async def _async_receive_data(
|
||||
if len(done) == 0:
|
||||
raise socket.timeout("timed out")
|
||||
if read_task in done:
|
||||
return callback(read_task.result())
|
||||
data, op_code = read_task.result()
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
except asyncio.CancelledError:
|
||||
for task in tasks:
|
||||
@ -618,31 +743,6 @@ async def _async_receive_data(
|
||||
raise
|
||||
|
||||
|
||||
async def async_receive_message(
|
||||
conn: AsyncConnection,
|
||||
request_id: Optional[int],
|
||||
max_message_size: int = MAX_MESSAGE_SIZE,
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
|
||||
def callback(result: Any) -> _OpMsg | _OpReply:
|
||||
data, op_code = result
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
|
||||
return await _async_receive_data(conn, callback, request_id, max_message_size)
|
||||
|
||||
|
||||
def receive_kms(conn: BaseConnection, bytes_needed: int) -> bytes:
|
||||
"""Receive raw bytes from the kms connection."""
|
||||
return conn.conn.sock.recv(bytes_needed)
|
||||
|
||||
|
||||
def receive_message(
|
||||
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
@ -670,7 +770,7 @@ def receive_message(
|
||||
f"Message length ({length!r}) is larger than server max "
|
||||
f"message size ({max_message_size!r})"
|
||||
)
|
||||
data: bytes | memoryview
|
||||
data: memoryview | bytes
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
|
||||
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
@ -24,6 +25,7 @@ from typing import (
|
||||
Any,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import _csot
|
||||
@ -35,17 +37,13 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.helpers_shared import _get_timeout_details, format_timeout_details
|
||||
from pymongo.network_layer import (
|
||||
AsyncNetworkingInterface,
|
||||
NetworkingInterface,
|
||||
PyMongoBaseProtocol,
|
||||
PyMongoProtocol,
|
||||
)
|
||||
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.ssl_support import PYSSLError, SSLError, _has_sni
|
||||
|
||||
SSLErrors = (PYSSLError, SSLError)
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.typings import _Address
|
||||
|
||||
try:
|
||||
@ -246,10 +244,64 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
|
||||
raise OSError("getaddrinfo failed")
|
||||
|
||||
|
||||
async def _async_configured_socket(
|
||||
address: _Address, options: PoolOptions
|
||||
) -> Union[socket.socket, _sslConn]:
|
||||
"""Given (host, port) and PoolOptions, return a raw configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = await _async_create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if _has_sni(False):
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore]
|
||||
)
|
||||
else:
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore]
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
async def _configured_protocol_interface(
|
||||
address: _Address,
|
||||
options: PoolOptions,
|
||||
protocol_kls: type[PyMongoBaseProtocol] = PyMongoProtocol,
|
||||
address: _Address, options: PoolOptions
|
||||
) -> AsyncNetworkingInterface:
|
||||
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
|
||||
|
||||
@ -264,7 +316,7 @@ async def _configured_protocol_interface(
|
||||
if ssl_context is None:
|
||||
return AsyncNetworkingInterface(
|
||||
await asyncio.get_running_loop().create_connection(
|
||||
lambda: protocol_kls(timeout=timeout), sock=sock
|
||||
lambda: PyMongoProtocol(timeout=timeout), sock=sock
|
||||
)
|
||||
)
|
||||
|
||||
@ -273,7 +325,7 @@ async def _configured_protocol_interface(
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
|
||||
lambda: protocol_kls(timeout=timeout),
|
||||
lambda: PyMongoProtocol(timeout=timeout),
|
||||
sock=sock,
|
||||
server_hostname=host,
|
||||
ssl=ssl_context,
|
||||
@ -373,9 +425,56 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
|
||||
raise OSError("getaddrinfo failed")
|
||||
|
||||
|
||||
def _configured_socket_interface(
|
||||
address: _Address, options: PoolOptions, *args: Any
|
||||
) -> NetworkingInterface:
|
||||
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
|
||||
"""Given (host, port) and PoolOptions, return a raw configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = _create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if _has_sni(True):
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
|
||||
else:
|
||||
ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
|
||||
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
@ -71,11 +71,11 @@ from pymongo.errors import (
|
||||
ServerSelectionTimeoutError,
|
||||
)
|
||||
from pymongo.helpers_shared import _get_timeout_details
|
||||
from pymongo.network_layer import PyMongoKMSProtocol, receive_kms, sendall
|
||||
from pymongo.network_layer import sendall
|
||||
from pymongo.operations import UpdateOne
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.pool_shared import (
|
||||
_configured_socket_interface,
|
||||
_configured_socket,
|
||||
_raise_connection_failure,
|
||||
)
|
||||
from pymongo.read_concern import ReadConcern
|
||||
@ -85,7 +85,6 @@ from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.cursor import Cursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.pool import BaseConnection
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
@ -93,8 +92,10 @@ from pymongo.write_concern import WriteConcern
|
||||
if TYPE_CHECKING:
|
||||
from pymongocrypt.mongocrypt import MongoCryptKmsContext
|
||||
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.typings import _Address
|
||||
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
_HTTPS_PORT = 443
|
||||
@ -109,10 +110,9 @@ _DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions(
|
||||
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
|
||||
|
||||
|
||||
def _connect_kms(address: _Address, opts: PoolOptions) -> BaseConnection:
|
||||
def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
|
||||
try:
|
||||
interface = _configured_socket_interface(address, opts, PyMongoKMSProtocol)
|
||||
return BaseConnection(interface, opts)
|
||||
return _configured_socket(address, opts)
|
||||
except Exception as exc:
|
||||
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
|
||||
|
||||
@ -197,11 +197,19 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
try:
|
||||
conn = _connect_kms(address, opts)
|
||||
try:
|
||||
sendall(conn.conn.get_conn, message)
|
||||
sendall(conn, message)
|
||||
while kms_context.bytes_needed > 0:
|
||||
# CSOT: update timeout.
|
||||
conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data = receive_kms(conn, kms_context.bytes_needed)
|
||||
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data: memoryview | bytes
|
||||
if _IS_SYNC:
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
else:
|
||||
from pymongo.network_layer import ( # type: ignore[attr-defined]
|
||||
receive_data_socket,
|
||||
)
|
||||
|
||||
data = receive_data_socket(conn, kms_context.bytes_needed)
|
||||
if not data:
|
||||
raise OSError("KMS connection closed")
|
||||
kms_context.feed(data)
|
||||
@ -220,7 +228,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts)
|
||||
)
|
||||
finally:
|
||||
conn.close_conn(None)
|
||||
conn.close()
|
||||
except MongoCryptError:
|
||||
raise # Propagate MongoCryptError errors directly.
|
||||
except Exception as exc:
|
||||
|
||||
@ -123,19 +123,74 @@ except ImportError:
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class BaseConnection:
|
||||
"""A base connection object for server and kms connections."""
|
||||
class Connection:
|
||||
"""Store a connection with some metadata.
|
||||
|
||||
def __init__(self, conn: NetworkingInterface, opts: PoolOptions):
|
||||
:param conn: a raw connection object
|
||||
:param pool: a Pool instance
|
||||
:param address: the server's (host, port)
|
||||
:param id: the id of this socket in it's pool
|
||||
:param is_sdam: SDAM connections do not call hello on creation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: NetworkingInterface,
|
||||
pool: Pool,
|
||||
address: tuple[str, int],
|
||||
id: int,
|
||||
is_sdam: bool,
|
||||
):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.conn = conn
|
||||
self.socket_checker: SocketChecker = SocketChecker()
|
||||
self.cancel_context: _CancellationContext = _CancellationContext()
|
||||
self.is_sdam = False
|
||||
self.address = address
|
||||
self.id = id
|
||||
self.is_sdam = is_sdam
|
||||
self.closed = False
|
||||
self.last_timeout: float | None = None
|
||||
self.more_to_come = False
|
||||
self.opts = opts
|
||||
self.max_wire_version = -1
|
||||
self.last_checkin_time = time.monotonic()
|
||||
self.performed_handshake = False
|
||||
self.is_writable: bool = False
|
||||
self.max_wire_version = MAX_WIRE_VERSION
|
||||
self.max_bson_size = MAX_BSON_SIZE
|
||||
self.max_message_size = MAX_MESSAGE_SIZE
|
||||
self.max_write_batch_size = MAX_WRITE_BATCH_SIZE
|
||||
self.supports_sessions = False
|
||||
self.hello_ok: bool = False
|
||||
self.is_mongos = False
|
||||
self.op_msg_enabled = False
|
||||
self.listeners = pool.opts._event_listeners
|
||||
self.enabled_for_cmap = pool.enabled_for_cmap
|
||||
self.enabled_for_logging = pool.enabled_for_logging
|
||||
self.compression_settings = pool.opts._compression_settings
|
||||
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
|
||||
self.socket_checker: SocketChecker = SocketChecker()
|
||||
self.oidc_token_gen_id: Optional[int] = None
|
||||
# Support for mechanism negotiation on the initial handshake.
|
||||
self.negotiated_mechs: Optional[list[str]] = None
|
||||
self.auth_ctx: Optional[_AuthContext] = None
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
self.pool_gen = pool.gen
|
||||
self.generation = self.pool_gen.get_overall()
|
||||
self.ready = False
|
||||
self.cancel_context: _CancellationContext = _CancellationContext()
|
||||
self.opts = pool.opts
|
||||
self.more_to_come: bool = False
|
||||
# For load balancer support.
|
||||
self.service_id: Optional[ObjectId] = None
|
||||
self.server_connection_id: Optional[int] = None
|
||||
# When executing a transaction in load balancing mode, this flag is
|
||||
# set to true to indicate that the session now owns the connection.
|
||||
self.pinned_txn = False
|
||||
self.pinned_cursor = False
|
||||
self.active = False
|
||||
self.last_timeout = self.opts.socket_timeout
|
||||
self.connect_rtt = 0.0
|
||||
self._client_id = pool._client_id
|
||||
self.creation_time = time.monotonic()
|
||||
# For gossiping $clusterTime from the connection handshake to the client.
|
||||
self._cluster_time = None
|
||||
|
||||
def set_conn_timeout(self, timeout: Optional[float]) -> None:
|
||||
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
|
||||
@ -164,111 +219,17 @@ class BaseConnection:
|
||||
formatted = format_timeout_details(timeout_details)
|
||||
# CSOT: raise an error without running the command since we know it will time out.
|
||||
errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}"
|
||||
if self.max_wire_version != -1:
|
||||
raise ExecutionTimeout(
|
||||
errmsg,
|
||||
50,
|
||||
{"ok": 0, "errmsg": errmsg, "code": 50},
|
||||
self.max_wire_version,
|
||||
)
|
||||
else:
|
||||
raise TimeoutError(errmsg)
|
||||
raise ExecutionTimeout(
|
||||
errmsg,
|
||||
50,
|
||||
{"ok": 0, "errmsg": errmsg, "code": 50},
|
||||
self.max_wire_version,
|
||||
)
|
||||
if cmd is not None:
|
||||
cmd["maxTimeMS"] = int(max_time_ms * 1000)
|
||||
self.set_conn_timeout(timeout)
|
||||
return timeout
|
||||
|
||||
def close_conn(self, reason: Optional[str]) -> None:
|
||||
"""Close this connection with a reason."""
|
||||
if self.closed:
|
||||
return
|
||||
self._close_conn()
|
||||
|
||||
def _close_conn(self) -> None:
|
||||
"""Close this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
self.closed = True
|
||||
self.cancel_context.cancel()
|
||||
# Note: We catch exceptions to avoid spurious errors on interpreter
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def conn_closed(self) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
if _IS_SYNC:
|
||||
return self.socket_checker.socket_closed(self.conn.get_conn)
|
||||
else:
|
||||
return self.conn.is_closing()
|
||||
|
||||
|
||||
class Connection(BaseConnection):
|
||||
"""Store a connection with some metadata.
|
||||
|
||||
:param conn: a raw connection object
|
||||
:param pool: a Pool instance
|
||||
:param address: the server's (host, port)
|
||||
:param id: the id of this socket in it's pool
|
||||
:param is_sdam: SDAM connections do not call hello on creation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: NetworkingInterface,
|
||||
pool: Pool,
|
||||
address: tuple[str, int],
|
||||
id: int,
|
||||
is_sdam: bool,
|
||||
):
|
||||
super().__init__(conn, pool.opts)
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.address: tuple[str, int] = address
|
||||
self.id: int = id
|
||||
self.is_sdam = is_sdam
|
||||
self.last_checkin_time = time.monotonic()
|
||||
self.performed_handshake = False
|
||||
self.is_writable: bool = False
|
||||
self.max_wire_version = MAX_WIRE_VERSION
|
||||
self.max_bson_size: int = MAX_BSON_SIZE
|
||||
self.max_message_size: int = MAX_MESSAGE_SIZE
|
||||
self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE
|
||||
self.supports_sessions = False
|
||||
self.hello_ok: bool = False
|
||||
self.is_mongos: bool = False
|
||||
self.op_msg_enabled = False
|
||||
self.listeners = pool.opts._event_listeners
|
||||
self.enabled_for_cmap = pool.enabled_for_cmap
|
||||
self.enabled_for_logging = pool.enabled_for_logging
|
||||
self.compression_settings = pool.opts._compression_settings
|
||||
self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None
|
||||
self.oidc_token_gen_id: Optional[int] = None
|
||||
# Support for mechanism negotiation on the initial handshake.
|
||||
self.negotiated_mechs: Optional[list[str]] = None
|
||||
self.auth_ctx: Optional[_AuthContext] = None
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
self.pool_gen = pool.gen
|
||||
self.generation = self.pool_gen.get_overall()
|
||||
self.ready = False
|
||||
# For load balancer support.
|
||||
self.service_id: Optional[ObjectId] = None
|
||||
self.server_connection_id: Optional[int] = None
|
||||
# When executing a transaction in load balancing mode, this flag is
|
||||
# set to true to indicate that the session now owns the connection.
|
||||
self.pinned_txn = False
|
||||
self.pinned_cursor = False
|
||||
self.active = False
|
||||
self.last_timeout = self.opts.socket_timeout
|
||||
self.connect_rtt = 0.0
|
||||
self._client_id = pool._client_id
|
||||
self.creation_time = time.monotonic()
|
||||
# For gossiping $clusterTime from the connection handshake to the client.
|
||||
self._cluster_time = None
|
||||
|
||||
def pin_txn(self) -> None:
|
||||
self.pinned_txn = True
|
||||
assert not self.pinned_cursor
|
||||
@ -610,6 +571,26 @@ class Connection(BaseConnection):
|
||||
error=reason,
|
||||
)
|
||||
|
||||
def _close_conn(self) -> None:
|
||||
"""Close this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
self.closed = True
|
||||
self.cancel_context.cancel()
|
||||
# Note: We catch exceptions to avoid spurious errors on interpreter
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def conn_closed(self) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
if _IS_SYNC:
|
||||
return self.socket_checker.socket_closed(self.conn.get_conn)
|
||||
else:
|
||||
return self.conn.is_closing()
|
||||
|
||||
def send_cluster_time(
|
||||
self,
|
||||
command: MutableMapping[str, Any],
|
||||
|
||||
@ -120,9 +120,9 @@ replacements = {
|
||||
"_async_create_lock": "_create_lock",
|
||||
"_async_create_condition": "_create_condition",
|
||||
"_async_cond_wait": "_cond_wait",
|
||||
"async_receive_kms": "receive_kms",
|
||||
"AsyncNetworkingInterface": "NetworkingInterface",
|
||||
"_configured_protocol_interface": "_configured_socket_interface",
|
||||
"_async_configured_socket": "_configured_socket",
|
||||
"SpecRunnerTask": "SpecRunnerThread",
|
||||
"AsyncMockConnection": "MockConnection",
|
||||
"AsyncMockPool": "MockPool",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user