Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-03-31 12:59:54 -05:00
commit 51292decab
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
47 changed files with 2024 additions and 843 deletions

View File

@ -9,6 +9,8 @@ PyMongo 4.12 brings a number of changes including:
- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
:class:`~pymongo.encryption_options.AutoEncryptionOpts`.
- Support for $lookup in CSFLE and QE supported on MongoDB 8.1+.
- Added :meth:`gridfs.asynchronous.grid_file.AsyncGridFSBucket.delete_by_name` and :meth:`gridfs.grid_file.GridFSBucket.delete_by_name`
for more performant deletion of a file with multiple revisions.
- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation.
To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected.
- Added index hinting support to the

View File

@ -834,6 +834,35 @@ class AsyncGridFSBucket:
if not res.deleted_count:
raise NoFile("no file could be deleted because none matched %s" % file_id)
@_csot.apply
async def delete_by_name(
self, filename: str, session: Optional[AsyncClientSession] = None
) -> None:
"""Given a filename, delete this stored file's files collection document(s)
and associated chunks from a GridFS bucket.
For example::
my_db = AsyncMongoClient().test
fs = AsyncGridFSBucket(my_db)
await fs.upload_from_stream("test_file", "data I want to store!")
await fs.delete_by_name("test_file")
Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists.
:param filename: The name of the file to be deleted.
:param session: a :class:`~pymongo.client_session.AsyncClientSession`
.. versionadded:: 4.12
"""
_disallow_transactions(session)
files = self._files.find({"filename": filename}, {"_id": 1}, session=session)
file_ids = [file["_id"] async for file in files]
res = await self._files.delete_many({"_id": {"$in": file_ids}}, session=session)
await self._chunks.delete_many({"files_id": {"$in": file_ids}}, session=session)
if not res.deleted_count:
raise NoFile(f"no file could be deleted because none matched filename {filename!r}")
def find(self, *args: Any, **kwargs: Any) -> AsyncGridOutCursor:
"""Find and return the files collection documents that match ``filter``

View File

@ -830,6 +830,33 @@ class GridFSBucket:
if not res.deleted_count:
raise NoFile("no file could be deleted because none matched %s" % file_id)
@_csot.apply
def delete_by_name(self, filename: str, session: Optional[ClientSession] = None) -> None:
"""Given a filename, delete this stored file's files collection document(s)
and associated chunks from a GridFS bucket.
For example::
my_db = MongoClient().test
fs = GridFSBucket(my_db)
fs.upload_from_stream("test_file", "data I want to store!")
fs.delete_by_name("test_file")
Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists.
:param filename: The name of the file to be deleted.
:param session: a :class:`~pymongo.client_session.ClientSession`
.. versionadded:: 4.12
"""
_disallow_transactions(session)
files = self._files.find({"filename": filename}, {"_id": 1}, session=session)
file_ids = [file["_id"] for file in files]
res = self._files.delete_many({"_id": {"$in": file_ids}}, session=session)
self._chunks.delete_many({"files_id": {"$in": file_ids}}, session=session)
if not res.deleted_count:
raise NoFile(f"no file could be deleted because none matched filename {filename!r}")
def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
"""Find and return the files collection documents that match ``filter``

View File

@ -87,7 +87,7 @@ class _AsyncBulk:
self,
collection: AsyncCollection[_DocumentType],
ordered: bool,
bypass_document_validation: bool,
bypass_document_validation: Optional[bool],
comment: Optional[str] = None,
let: Optional[Any] = None,
) -> None:
@ -516,8 +516,8 @@ class _AsyncBulk:
if self.comment:
cmd["comment"] = self.comment
_csot.apply_write_concern(cmd, write_concern)
if self.bypass_doc_val:
cmd["bypassDocumentValidation"] = True
if self.bypass_doc_val is not None:
cmd["bypassDocumentValidation"] = self.bypass_doc_val
if self.let is not None and run.op_type in (_DELETE, _UPDATE):
cmd["let"] = self.let
if session:

View File

@ -701,7 +701,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self,
requests: Sequence[_WriteOp[_DocumentType]],
ordered: bool = True,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
@ -800,7 +800,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
ordered: bool,
write_concern: WriteConcern,
op_id: Optional[int],
bypass_doc_val: bool,
bypass_doc_val: Optional[bool],
session: Optional[AsyncClientSession],
comment: Optional[Any] = None,
) -> Any:
@ -814,8 +814,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def _insert_command(
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
) -> None:
if bypass_doc_val:
command["bypassDocumentValidation"] = True
if bypass_doc_val is not None:
command["bypassDocumentValidation"] = bypass_doc_val
result = await conn.command(
self._database.name,
@ -840,7 +840,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
async def insert_one(
self,
document: Union[_DocumentType, RawBSONDocument],
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> InsertOneResult:
@ -906,7 +906,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
self,
documents: Iterable[Union[_DocumentType, RawBSONDocument]],
ordered: bool = True,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
) -> InsertManyResult:
@ -986,7 +986,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write_concern: Optional[WriteConcern] = None,
op_id: Optional[int] = None,
ordered: bool = True,
bypass_doc_val: Optional[bool] = False,
bypass_doc_val: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
@ -1041,8 +1041,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
# Update command.
if bypass_doc_val:
command["bypassDocumentValidation"] = True
if bypass_doc_val is not None:
command["bypassDocumentValidation"] = bypass_doc_val
# The command result has to be published for APM unmodified
# so we make a shallow copy here before adding updatedExisting.
@ -1082,7 +1082,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
write_concern: Optional[WriteConcern] = None,
op_id: Optional[int] = None,
ordered: bool = True,
bypass_doc_val: Optional[bool] = False,
bypass_doc_val: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
@ -1128,7 +1128,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
filter: Mapping[str, Any],
replacement: Mapping[str, Any],
upsert: bool = False,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[AsyncClientSession] = None,
@ -1237,7 +1237,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
@ -2948,6 +2948,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
returning aggregate results using a cursor.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`.
- `bypassDocumentValidation` (bool): If ``True``, allows the write to opt-out of document level validation.
:return: A :class:`~pymongo.asynchronous.command_cursor.AsyncCommandCursor` over the result

View File

@ -64,11 +64,6 @@ from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@ -80,12 +75,17 @@ from pymongo.errors import (
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
from pymongo.network_layer import async_socket_sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_async_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser_shared import parse_host
from pymongo.write_concern import WriteConcern
@ -113,7 +113,7 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return await _configured_socket(address, opts)
return await _async_configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
@ -196,7 +196,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
try:
conn = await _connect_kms(address, opts)
try:
await async_sendall(conn, message)
await async_socket_sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))

View File

@ -2079,7 +2079,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# exhausted the result set we *must* close the socket
# to stop the server from sending more data.
assert conn_mgr.conn is not None
conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
await conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
else:
await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
if conn_mgr:

View File

@ -36,7 +36,11 @@ from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
from pymongo.asynchronous.pool import ( # type: ignore[attr-defined]
AsyncConnection,
Pool,
_CancellationContext,
)
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology

View File

@ -17,7 +17,6 @@ from __future__ import annotations
import datetime
import logging
import time
from typing import (
TYPE_CHECKING,
Any,
@ -31,20 +30,16 @@ from typing import (
from bson import _decode_all_selective
from pymongo import _csot, helpers_shared, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import _NO_COMPRESSION, decompress
from pymongo.compression_support import _NO_COMPRESSION
from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.message import _OpMsg
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
async_receive_data,
async_receive_message,
async_sendall,
)
@ -194,13 +189,13 @@ async def command(
)
try:
await async_sendall(conn.conn, msg)
await async_sendall(conn.conn.get_conn, msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = await receive_message(conn, request_id)
reply = await async_receive_message(conn, request_id)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
@ -301,47 +296,3 @@ async def command(
)
return response_doc # type: ignore[return-value]
async def receive_message(
conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
await async_receive_data(conn, 9, deadline)
)
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
else:
data = await async_receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)

View File

@ -14,14 +14,10 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import functools
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -40,8 +36,8 @@ from typing import (
from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
from pymongo.asynchronous.network import command, receive_message
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.network import command
from pymongo.common import (
MAX_BSON_SIZE,
MAX_MESSAGE_SIZE,
@ -52,16 +48,13 @@ from pymongo.common import (
from pymongo.errors import ( # type:ignore[attr-defined]
AutoReconnect,
ConfigurationError,
ConnectionFailure,
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import (
@ -79,13 +72,20 @@ from pymongo.monitoring import (
ConnectionCheckOutFailedReason,
ConnectionClosedReason,
)
from pymongo.network_layer import async_sendall
from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_CancellationContext,
_configured_protocol_interface,
_get_timeout_details,
_raise_connection_failure,
format_timeout_details,
)
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI, SSLError
from pymongo.ssl_support import SSLError
if TYPE_CHECKING:
from bson import CodecOptions
@ -99,7 +99,6 @@ if TYPE_CHECKING:
ZstdContext,
)
from pymongo.message import _OpMsg, _OpReply
from pymongo.pyopenssl_context import _sslConn
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _CollationIn
@ -123,133 +122,6 @@ except ImportError:
_IS_SYNC = False
_MAX_TCP_KEEPIDLE = 120
_MAX_TCP_KEEPINTVL = 10
_MAX_TCP_KEEPCNT = 9
if sys.platform == "win32":
try:
import _winreg as winreg
except ImportError:
import winreg
def _query(key, name, default):
try:
value, _ = winreg.QueryValueEx(key, name)
# Ensure the value is a number or raise ValueError.
return int(value)
except (OSError, ValueError):
# QueryValueEx raises OSError when the key does not exist (i.e.
# the system is using the Windows default value).
return default
try:
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
) as key:
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
except OSError:
# We could not check the default values because winreg.OpenKey failed.
# Assume the system is using the default values.
_WINDOWS_TCP_IDLE_MS = 7200000
_WINDOWS_TCP_INTERVAL_MS = 1000
def _set_keepalive_times(sock):
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
else:
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
if hasattr(socket, tcp_option):
sockopt = getattr(socket, tcp_option)
try:
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
# TCP_KEEPIDLE and friends. Don't attempt to set the
# values there.
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
if default > max_value:
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
except OSError:
pass
def _set_keepalive_times(sock: socket.socket) -> None:
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
def _raise_connection_failure(
address: Any,
error: Exception,
msg_prefix: Optional[str] = None,
timeout_details: Optional[dict[str, float]] = None,
) -> NoReturn:
"""Convert a socket.error to ConnectionFailure and raise it."""
host, port = address
# If connecting to a Unix socket, port will be None.
if port is not None:
msg = "%s:%d: %s" % (host, port, error)
else:
msg = f"{host}: {error}"
if msg_prefix:
msg = msg_prefix + msg
if "configured timeouts" not in msg:
msg += format_timeout_details(timeout_details)
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, SSLError) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
# 'timed out' appears in all the timeout related SSLErrors raised.
raise NetworkTimeout(msg) from error
else:
raise AutoReconnect(msg) from error
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
socket_timeout = options.socket_timeout
connect_timeout = options.connect_timeout
if timeout:
details["timeoutMS"] = timeout * 1000
if socket_timeout and not timeout:
details["socketTimeoutMS"] = socket_timeout * 1000
if connect_timeout:
details["connectTimeoutMS"] = connect_timeout * 1000
return details
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
result = ""
if details:
result += " (configured timeouts:"
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
if timeout in details:
result += f" {timeout}: {details[timeout]}ms,"
result = result[:-1]
result += ")"
return result
class _CancellationContext:
def __init__(self) -> None:
self._cancelled = False
def cancel(self) -> None:
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self) -> bool:
"""Was cancel called?"""
return self._cancelled
class AsyncConnection:
"""Store a connection with some metadata.
@ -261,7 +133,11 @@ class AsyncConnection:
"""
def __init__(
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
self,
conn: AsyncNetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
):
self.pool_ref = weakref.ref(pool)
self.conn = conn
@ -318,7 +194,7 @@ class AsyncConnection:
if timeout == self.last_timeout:
return
self.last_timeout = timeout
self.conn.settimeout(timeout)
self.conn.get_conn.settimeout(timeout)
def apply_timeout(
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
@ -364,7 +240,7 @@ class AsyncConnection:
if pool:
await pool.checkin(self)
else:
self.close_conn(ConnectionClosedReason.STALE)
await self.close_conn(ConnectionClosedReason.STALE)
def hello_cmd(self) -> dict[str, Any]:
# Handshake spec requires us to use OP_MSG+hello command for the
@ -559,7 +435,7 @@ class AsyncConnection:
raise
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
except BaseException as error:
self._raise_connection_failure(error)
await self._raise_connection_failure(error)
async def send_message(self, message: bytes, max_doc_size: int) -> None:
"""Send a raw BSON message or raise ConnectionFailure.
@ -573,10 +449,10 @@ class AsyncConnection:
)
try:
await async_sendall(self.conn, message)
await async_sendall(self.conn.get_conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
await self._raise_connection_failure(error)
async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise ConnectionFailure.
@ -584,10 +460,10 @@ class AsyncConnection:
If any exception is raised, the socket is closed.
"""
try:
return await receive_message(self, request_id, self.max_message_size)
return await async_receive_message(self, request_id, self.max_message_size)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
await self._raise_connection_failure(error)
def _raise_if_not_writable(self, unacknowledged: bool) -> None:
"""Raise NotPrimaryError on unacknowledged write if this socket is not
@ -673,11 +549,11 @@ class AsyncConnection:
"Can only use session with the AsyncMongoClient that started it"
)
def close_conn(self, reason: Optional[str]) -> None:
async def close_conn(self, reason: Optional[str]) -> None:
"""Close this connection with a reason."""
if self.closed:
return
self._close_conn()
await self._close_conn()
if reason:
if self.enabled_for_cmap:
assert self.listeners is not None
@ -694,7 +570,7 @@ class AsyncConnection:
error=reason,
)
def _close_conn(self) -> None:
async def _close_conn(self) -> None:
"""Close this connection."""
if self.closed:
return
@ -703,13 +579,16 @@ class AsyncConnection:
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
self.conn.close()
await self.conn.close()
except Exception: # noqa: S110
pass
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.conn)
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
def send_cluster_time(
self,
@ -736,7 +615,7 @@ class AsyncConnection:
"""Seconds since this socket was last checked into its pool."""
return time.monotonic() - self.last_checkin_time
def _raise_connection_failure(self, error: BaseException) -> NoReturn:
async def _raise_connection_failure(self, error: BaseException) -> NoReturn:
# Catch *all* exceptions from socket methods and close the socket. In
# regular Python, socket operations only raise socket.error, even if
# the underlying cause was a Ctrl-C: a signal raised during socket.recv
@ -756,7 +635,7 @@ class AsyncConnection:
reason = None
else:
reason = ConnectionClosedReason.ERROR
self.close_conn(reason)
await self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
details = _get_timeout_details(self.opts)
@ -781,145 +660,6 @@ class AsyncConnection:
)
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
async def _configured_socket(
address: _Address, options: PoolOptions
) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = await _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(
None,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
)
else:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
class _PoolClosedError(PyMongoError):
"""Internal error raised when a thread tries to get a connection from a
closed pool.
@ -1121,7 +861,7 @@ class Pool:
# publishing the PoolClearedEvent.
if close:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
@ -1152,7 +892,7 @@ class Pool:
serviceId=service_id,
)
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
@ -1197,7 +937,7 @@ class Pool:
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
conn = self.conns.pop()
conn.close_conn(ConnectionClosedReason.IDLE)
await conn.close_conn(ConnectionClosedReason.IDLE)
while True:
async with self.size_cond:
@ -1221,7 +961,7 @@ class Pool:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
return
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
@ -1266,7 +1006,7 @@ class Pool:
)
try:
sock = await _configured_socket(self.address, self.opts)
networking_interface = await _configured_protocol_interface(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
async with self.lock:
@ -1293,7 +1033,7 @@ class Pool:
raise
conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type]
conn = AsyncConnection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type]
async with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
@ -1311,7 +1051,7 @@ class Pool:
except BaseException:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
conn.close_conn(ConnectionClosedReason.ERROR)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
if handler:
@ -1509,7 +1249,7 @@ class Pool:
except IndexError:
self._pending += 1
if conn: # We got a socket from the pool
if self._perished(conn):
if await self._perished(conn):
conn = None
continue
else: # We need to create a new connection
@ -1523,7 +1263,7 @@ class Pool:
except BaseException:
if conn:
# We checked out a socket but authentication failed.
conn.close_conn(ConnectionClosedReason.ERROR)
await conn.close_conn(ConnectionClosedReason.ERROR)
async with self.size_cond:
self.requests -= 1
if incremented:
@ -1583,7 +1323,7 @@ class Pool:
await self.reset_without_pause()
else:
if self.closed:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
elif conn.closed:
# CMAP requires the closed event be emitted after the check in.
if self.enabled_for_cmap:
@ -1607,7 +1347,7 @@ class Pool:
# Hold the lock to ensure this section does not race with
# Pool.reset().
if self.stale_generation(conn.generation, conn.service_id):
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
else:
conn.update_last_checkin_time()
conn.update_is_writable(bool(self.is_writable))
@ -1625,7 +1365,7 @@ class Pool:
self.operation_count -= 1
self.size_cond.notify()
def _perished(self, conn: AsyncConnection) -> bool:
async def _perished(self, conn: AsyncConnection) -> bool:
"""Return True and close the connection if it is "perished".
This side-effecty function checks if this socket has been idle for
@ -1645,18 +1385,18 @@ class Pool:
self.opts.max_idle_time_seconds is not None
and idle_time_seconds > self.opts.max_idle_time_seconds
):
conn.close_conn(ConnectionClosedReason.IDLE)
await conn.close_conn(ConnectionClosedReason.IDLE)
return True
if self._check_interval_seconds is not None and (
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
):
if conn.conn_closed():
conn.close_conn(ConnectionClosedReason.ERROR)
await conn.close_conn(ConnectionClosedReason.ERROR)
return True
if self.stale_generation(conn.generation, conn.service_id):
conn.close_conn(ConnectionClosedReason.STALE)
await conn.close_conn(ConnectionClosedReason.STALE)
return True
return False
@ -1704,5 +1444,6 @@ class Pool:
# Avoid ResourceWarnings in Python 3
# Close all sockets without calling reset() or close() because it is
# not safe to acquire a lock in __del__.
for conn in self.conns:
conn.close_conn(None)
if _IS_SYNC:
for conn in self.conns:
conn.close_conn(None)

View File

@ -16,21 +16,26 @@
from __future__ import annotations
import asyncio
import collections
import errno
import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, Future
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
)
from pymongo import _csot, ssl_support
from pymongo._asyncio_task import create_task
from pymongo.errors import _OperationCancelled
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import decompress
from pymongo.errors import ProtocolError, _OperationCancelled
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.socket_checker import _errno_from_exception
try:
@ -69,13 +74,15 @@ _POLL_TIMEOUT = 0.5
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
# These socket-based I/O methods are for KMS requests and any other network operations that do not use
# the MongoDB wire protocol
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
timeout = sock.gettimeout()
sock.settimeout(0.0)
loop = asyncio.get_running_loop()
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout)
await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
else:
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
except asyncio.TimeoutError as exc:
@ -87,7 +94,7 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non
if sys.platform != "win32":
async def _async_sendall_ssl(
async def _async_socket_sendall_ssl(
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
) -> None:
view = memoryview(buf)
@ -130,7 +137,7 @@ if sys.platform != "win32":
loop.remove_reader(fd)
loop.remove_writer(fd)
async def _async_receive_ssl(
async def _async_socket_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
@ -184,7 +191,7 @@ else:
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
async def _async_sendall_ssl(
async def _async_socket_sendall_ssl(
sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
) -> None:
view = memoryview(buf)
@ -205,7 +212,7 @@ else:
backoff = min(backoff * 2, 0.512)
total_sent += sent
async def _async_receive_ssl(
async def _async_socket_receive_ssl(
conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
@ -244,52 +251,6 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
await asyncio.sleep(_POLL_TIMEOUT)
async def async_receive_data(
conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
sock = conn.conn
sock_timeout = sock.gettimeout()
timeout: Optional[Union[float, int]]
if deadline:
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
timeout = max(deadline - time.monotonic(), 0)
else:
timeout = sock_timeout
sock.settimeout(0.0)
loop = asyncio.get_running_loop()
cancellation_task = create_task(_poll_cancellation(conn))
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
else:
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
tasks = [read_task, cancellation_task]
try:
done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
if pending:
await asyncio.wait(pending)
if len(done) == 0:
raise socket.timeout("timed out")
if read_task in done:
return read_task.result()
raise _OperationCancelled("operation cancelled")
except asyncio.CancelledError:
for task in tasks:
task.cancel()
await asyncio.wait(tasks)
raise
finally:
sock.settimeout(sock_timeout)
async def async_receive_data_socket(
sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
@ -301,18 +262,23 @@ async def async_receive_data_socket(
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
return await asyncio.wait_for(
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
_async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
timeout=timeout,
)
else:
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
return await asyncio.wait_for(
_async_socket_receive(sock, length, loop), # type: ignore[arg-type]
timeout=timeout,
)
except asyncio.TimeoutError as err:
raise socket.timeout("timed out") from err
finally:
sock.settimeout(sock_timeout)
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
async def _async_socket_receive(
conn: socket.socket, length: int, loop: AbstractEventLoop
) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0
while bytes_read < length:
@ -328,7 +294,7 @@ _PYPY = "PyPy" in sys.version
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn
sock = conn.conn.sock
timed_out = False
# Check if the connection's socket has been manually closed
if sock.fileno() == -1:
@ -413,3 +379,403 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
conn.set_conn_timeout(orig_timeout)
return mv
class NetworkingInterfaceBase:
def __init__(self, conn: Any):
self.conn = conn
@property
def gettimeout(self) -> Any:
raise NotImplementedError
def settimeout(self, timeout: float | None) -> None:
raise NotImplementedError
def close(self) -> Any:
raise NotImplementedError
def is_closing(self) -> bool:
raise NotImplementedError
@property
def get_conn(self) -> Any:
raise NotImplementedError
@property
def sock(self) -> Any:
raise NotImplementedError
class AsyncNetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
super().__init__(conn)
@property
def gettimeout(self) -> float | None:
return self.conn[1].gettimeout
def settimeout(self, timeout: float | None) -> None:
self.conn[1].settimeout(timeout)
async def close(self) -> None:
self.conn[1].close()
await self.conn[1].wait_closed()
def is_closing(self) -> bool:
return self.conn[0].is_closing()
@property
def get_conn(self) -> PyMongoProtocol:
return self.conn[1]
@property
def sock(self) -> socket.socket:
return self.conn[0].get_extra_info("socket")
class NetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: Union[socket.socket, _sslConn]):
super().__init__(conn)
def gettimeout(self) -> float | None:
return self.conn.gettimeout()
def settimeout(self, timeout: float | None) -> None:
self.conn.settimeout(timeout)
def close(self) -> None:
self.conn.close()
def is_closing(self) -> bool:
return self.conn.is_closing()
@property
def get_conn(self) -> Union[socket.socket, _sslConn]:
return self.conn
@property
def sock(self) -> Union[socket.socket, _sslConn]:
return self.conn
def fileno(self) -> int:
return self.conn.fileno()
def recv_into(self, buffer: bytes) -> int:
return self.conn.recv_into(buffer)
class PyMongoProtocol(BufferedProtocol):
def __init__(self, timeout: Optional[float] = None):
self.transport: Transport = None # type: ignore[assignment]
# Each message is reader in 2-3 parts: header, compression header, and message body
# The message buffer is allocated after the header is read.
self._header = memoryview(bytearray(16))
self._header_index = 0
self._compression_header = memoryview(bytearray(9))
self._compression_index = 0
self._message: Optional[memoryview] = None
self._message_index = 0
# State. TODO: replace booleans with an enum?
self._expecting_header = True
self._expecting_compression = False
self._message_size = 0
self._op_code = 0
self._connection_lost = False
self._read_waiter: Optional[Future] = None
self._timeout = timeout
self._is_compressed = False
self._compressor_id: Optional[int] = None
self._max_message_size = MAX_MESSAGE_SIZE
self._response_to: Optional[int] = None
self._closed = asyncio.get_running_loop().create_future()
self._pending_messages: collections.deque[Future] = collections.deque()
self._done_messages: collections.deque[Future] = collections.deque()
def settimeout(self, timeout: float | None) -> None:
self._timeout = timeout
@property
def gettimeout(self) -> float | None:
"""The configured timeout for the socket that underlies our protocol pair."""
return self._timeout
def connection_made(self, transport: BaseTransport) -> None:
"""Called exactly once when a connection is made.
The transport argument is the transport representing the write side of the connection.
"""
self.transport = transport # type: ignore[assignment]
self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)
async def write(self, message: bytes) -> None:
"""Write a message to this connection's transport."""
if self.transport.is_closing():
raise OSError("Connection is closed")
self.transport.write(message)
self.transport.resume_reading()
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
"""Read a single MongoDB Wire Protocol message from this connection."""
if self.transport:
try:
self.transport.resume_reading()
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
except AttributeError:
raise OSError("connection is already closed") from None
self._max_message_size = max_message_size
if self._done_messages:
message = await self._done_messages.popleft()
else:
if self.transport and self.transport.is_closing():
raise OSError("connection is already closed")
read_waiter = asyncio.get_running_loop().create_future()
self._pending_messages.append(read_waiter)
try:
message = await read_waiter
finally:
if read_waiter in self._done_messages:
self._done_messages.remove(read_waiter)
if message:
op_code, compressor_id, response_to, data = message
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(
f"Got response id {response_to!r} but expected {request_id!r}"
)
if compressor_id is not None:
data = decompress(data, compressor_id)
return data, op_code
raise OSError("connection closed")
def get_buffer(self, sizehint: int) -> memoryview:
"""Called to allocate a new receive buffer.
The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
If any data does not fit into the returned buffer, this method will be called again until
either no data remains or an empty buffer is returned.
"""
# Due to a bug, Python <=3.11 will call get_buffer() even after we raise
# ProtocolError in buffer_updated() and call connection_lost(). We allocate
# a temp buffer to drain the waiting data.
if self._connection_lost:
if not self._message:
self._message = memoryview(bytearray(2**14))
return self._message
# TODO: optimize this by caching pointers to the buffers.
# return self._buffer[self._index:]
if self._expecting_header:
return self._header[self._header_index :]
if self._expecting_compression:
return self._compression_header[self._compression_index :]
return self._message[self._message_index :] # type: ignore[index]
def buffer_updated(self, nbytes: int) -> None:
"""Called when the buffer was updated with the received data"""
# Wrote 0 bytes into a non-empty buffer, signal connection closed
if nbytes == 0:
self.close(OSError("connection closed"))
return
if self._connection_lost:
return
if self._expecting_header:
self._header_index += nbytes
if self._header_index >= 16:
self._expecting_header = False
try:
(
self._message_size,
self._op_code,
self._response_to,
self._expecting_compression,
) = self.process_header()
except ProtocolError as exc:
self.close(exc)
return
self._message = memoryview(bytearray(self._message_size))
return
if self._expecting_compression:
self._compression_index += nbytes
if self._compression_index >= 9:
self._expecting_compression = False
self._op_code, self._compressor_id = self.process_compression_header()
return
self._message_index += nbytes
if self._message_index >= self._message_size:
self._expecting_header = True
# Pause reading to avoid storing an arbitrary number of messages in memory.
self.transport.pause_reading()
if self._pending_messages:
result = self._pending_messages.popleft()
else:
result = asyncio.get_running_loop().create_future()
# Future has been cancelled, close this connection
if result.done():
self.close(None)
return
# Necessary values to reconstruct and verify message
result.set_result(
(self._op_code, self._compressor_id, self._response_to, self._message)
)
self._done_messages.append(result)
# Reset internal state to expect a new message
self._header_index = 0
self._compression_index = 0
self._message_index = 0
self._message_size = 0
self._message = None
self._op_code = 0
self._compressor_id = None
self._response_to = None
def process_header(self) -> tuple[int, int, int, bool]:
"""Unpack a MongoDB Wire Protocol header."""
length, _, response_to, op_code = _UNPACK_HEADER(self._header)
expecting_compression = False
if op_code == 2012: # OP_COMPRESSED
if length <= 25:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
)
expecting_compression = True
length -= 9
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > self._max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({self._max_message_size!r})"
)
return length - 16, op_code, response_to, expecting_compression
def process_compression_header(self) -> tuple[int, int]:
"""Unpack a MongoDB Wire Protocol compression header."""
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
return op_code, compressor_id
def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
pending = list(self._pending_messages)
for msg in pending:
if not msg.done():
if exc is None:
msg.set_result(None)
else:
msg.set_exception(exc)
self._done_messages.append(msg)
def close(self, exc: Optional[Exception] = None) -> None:
self.transport.abort()
self._resolve_pending_messages(exc)
self._connection_lost = True
def connection_lost(self, exc: Optional[Exception] = None) -> None:
self._resolve_pending_messages(exc)
if not self._closed.done():
self._closed.set_result(None)
async def wait_closed(self) -> None:
await self._closed
async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
try:
await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
except asyncio.TimeoutError as exc:
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
raise socket.timeout("timed out") from exc
async def async_receive_message(
conn: AsyncConnection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
timeout: Optional[Union[float, int]]
timeout = conn.conn.gettimeout
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
if deadline:
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
timeout = max(deadline - time.monotonic(), 0)
cancellation_task = create_task(_poll_cancellation(conn))
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
tasks = [read_task, cancellation_task]
try:
done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
if pending:
await asyncio.wait(pending)
if len(done) == 0:
raise socket.timeout("timed out")
if read_task in done:
data, op_code = read_task.result()
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)
raise _OperationCancelled("operation cancelled")
except asyncio.CancelledError:
for task in tasks:
task.cancel()
await asyncio.wait(tasks)
raise
def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
else:
data = receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)

546
pymongo/pool_shared.py Normal file
View File

@ -0,0 +1,546 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pool utilities and shared helper methods."""
from __future__ import annotations
import asyncio
import functools
import socket
import ssl
import sys
from typing import (
TYPE_CHECKING,
Any,
NoReturn,
Optional,
Union,
)
from pymongo import _csot
from pymongo.asynchronous.helpers import _getaddrinfo
from pymongo.errors import ( # type:ignore[attr-defined]
AutoReconnect,
ConnectionFailure,
NetworkTimeout,
_CertificateError,
)
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
from pymongo.pool_options import PoolOptions
from pymongo.ssl_support import HAS_SNI, SSLError
if TYPE_CHECKING:
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
try:
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
def _set_non_inheritable_non_atomic(fd: int) -> None:
"""Set the close-on-exec flag on the given file descriptor."""
flags = fcntl(fd, F_GETFD)
fcntl(fd, F_SETFD, flags | FD_CLOEXEC)
except ImportError:
# Windows, various platforms we don't claim to support
# (Jython, IronPython, ..), systems that don't provide
# everything we need from fcntl, etc.
def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
"""Dummy function for platforms that don't provide fcntl."""
_MAX_TCP_KEEPIDLE = 120
_MAX_TCP_KEEPINTVL = 10
_MAX_TCP_KEEPCNT = 9
if sys.platform == "win32":
try:
import _winreg as winreg
except ImportError:
import winreg
def _query(key, name, default):
try:
value, _ = winreg.QueryValueEx(key, name)
# Ensure the value is a number or raise ValueError.
return int(value)
except (OSError, ValueError):
# QueryValueEx raises OSError when the key does not exist (i.e.
# the system is using the Windows default value).
return default
try:
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
) as key:
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
except OSError:
# We could not check the default values because winreg.OpenKey failed.
# Assume the system is using the default values.
_WINDOWS_TCP_IDLE_MS = 7200000
_WINDOWS_TCP_INTERVAL_MS = 1000
def _set_keepalive_times(sock):
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
else:
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
if hasattr(socket, tcp_option):
sockopt = getattr(socket, tcp_option)
try:
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
# TCP_KEEPIDLE and friends. Don't attempt to set the
# values there.
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
if default > max_value:
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
except OSError:
pass
def _set_keepalive_times(sock: socket.socket) -> None:
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
def _raise_connection_failure(
address: Any,
error: Exception,
msg_prefix: Optional[str] = None,
timeout_details: Optional[dict[str, float]] = None,
) -> NoReturn:
"""Convert a socket.error to ConnectionFailure and raise it."""
host, port = address
# If connecting to a Unix socket, port will be None.
if port is not None:
msg = "%s:%d: %s" % (host, port, error)
else:
msg = f"{host}: {error}"
if msg_prefix:
msg = msg_prefix + msg
if "configured timeouts" not in msg:
msg += format_timeout_details(timeout_details)
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, SSLError) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
# 'timed out' appears in all the timeout related SSLErrors raised.
raise NetworkTimeout(msg) from error
else:
raise AutoReconnect(msg) from error
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
socket_timeout = options.socket_timeout
connect_timeout = options.connect_timeout
if timeout:
details["timeoutMS"] = timeout * 1000
if socket_timeout and not timeout:
details["socketTimeoutMS"] = socket_timeout * 1000
if connect_timeout:
details["connectTimeoutMS"] = connect_timeout * 1000
return details
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
result = ""
if details:
result += " (configured timeouts:"
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
if timeout in details:
result += f" {timeout}: {details[timeout]}ms,"
result = result[:-1]
result += ")"
return result
class _CancellationContext:
def __init__(self) -> None:
self._cancelled = False
def cancel(self) -> None:
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self) -> bool:
"""Was cancel called?"""
return self._cancelled
async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a raw socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM):
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
async def _async_configured_socket(
address: _Address, options: PoolOptions
) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a raw configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(
None,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore]
)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
else:
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
async def _configured_protocol_interface(
address: _Address, options: PoolOptions
) -> AsyncNetworkingInterface:
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets protocol's SSL and timeout options.
"""
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
timeout = options.socket_timeout
if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
)
)
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: PyMongoProtocol(timeout=timeout),
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
except _CertificateError:
transport.abort()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
transport.abort()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise
return AsyncNetworkingInterface((transport, protocol))
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a raw socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in socket.getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined, unused-ignore]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a raw configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
else:
ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return NetworkingInterface(sock)
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
ssl_sock = ssl_context.wrap_socket(sock)
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return NetworkingInterface(ssl_sock)

View File

@ -87,7 +87,7 @@ class _Bulk:
self,
collection: Collection[_DocumentType],
ordered: bool,
bypass_document_validation: bool,
bypass_document_validation: Optional[bool],
comment: Optional[str] = None,
let: Optional[Any] = None,
) -> None:
@ -516,8 +516,8 @@ class _Bulk:
if self.comment:
cmd["comment"] = self.comment
_csot.apply_write_concern(cmd, write_concern)
if self.bypass_doc_val:
cmd["bypassDocumentValidation"] = True
if self.bypass_doc_val is not None:
cmd["bypassDocumentValidation"] = self.bypass_doc_val
if self.let is not None and run.op_type in (_DELETE, _UPDATE):
cmd["let"] = self.let
if session:

View File

@ -700,7 +700,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
self,
requests: Sequence[_WriteOp[_DocumentType]],
ordered: bool = True,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
@ -799,7 +799,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
ordered: bool,
write_concern: WriteConcern,
op_id: Optional[int],
bypass_doc_val: bool,
bypass_doc_val: Optional[bool],
session: Optional[ClientSession],
comment: Optional[Any] = None,
) -> Any:
@ -813,8 +813,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def _insert_command(
session: Optional[ClientSession], conn: Connection, retryable_write: bool
) -> None:
if bypass_doc_val:
command["bypassDocumentValidation"] = True
if bypass_doc_val is not None:
command["bypassDocumentValidation"] = bypass_doc_val
result = conn.command(
self._database.name,
@ -839,7 +839,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
def insert_one(
self,
document: Union[_DocumentType, RawBSONDocument],
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
) -> InsertOneResult:
@ -905,7 +905,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
self,
documents: Iterable[Union[_DocumentType, RawBSONDocument]],
ordered: bool = True,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
) -> InsertManyResult:
@ -985,7 +985,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern: Optional[WriteConcern] = None,
op_id: Optional[int] = None,
ordered: bool = True,
bypass_doc_val: Optional[bool] = False,
bypass_doc_val: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
@ -1040,8 +1040,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
if comment is not None:
command["comment"] = comment
# Update command.
if bypass_doc_val:
command["bypassDocumentValidation"] = True
if bypass_doc_val is not None:
command["bypassDocumentValidation"] = bypass_doc_val
# The command result has to be published for APM unmodified
# so we make a shallow copy here before adding updatedExisting.
@ -1081,7 +1081,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
write_concern: Optional[WriteConcern] = None,
op_id: Optional[int] = None,
ordered: bool = True,
bypass_doc_val: Optional[bool] = False,
bypass_doc_val: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
@ -1127,7 +1127,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
filter: Mapping[str, Any],
replacement: Mapping[str, Any],
upsert: bool = False,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
session: Optional[ClientSession] = None,
@ -1236,7 +1236,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
filter: Mapping[str, Any],
update: Union[Mapping[str, Any], _Pipeline],
upsert: bool = False,
bypass_document_validation: bool = False,
bypass_document_validation: Optional[bool] = None,
collation: Optional[_CollationIn] = None,
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
hint: Optional[_IndexKeyHint] = None,
@ -2941,6 +2941,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
returning aggregate results using a cursor.
- `collation` (optional): An instance of
:class:`~pymongo.collation.Collation`.
- `bypassDocumentValidation` (bool): If ``True``, allows the write to opt-out of document level validation.
:return: A :class:`~pymongo.command_cursor.CommandCursor` over the result

View File

@ -70,21 +70,21 @@ from pymongo.errors import (
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
from pymongo.network_layer import sendall
from pymongo.operations import UpdateOne
from pymongo.pool_options import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import (
from pymongo.pool_shared import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser_shared import parse_host
from pymongo.write_concern import WriteConcern

View File

@ -36,7 +36,11 @@ from pymongo.server_description import ServerDescription
from pymongo.synchronous.srv_resolver import _SrvResolver
if TYPE_CHECKING:
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
from pymongo.synchronous.pool import ( # type: ignore[attr-defined]
Connection,
Pool,
_CancellationContext,
)
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology

View File

@ -17,7 +17,6 @@ from __future__ import annotations
import datetime
import logging
import time
from typing import (
TYPE_CHECKING,
Any,
@ -31,20 +30,16 @@ from typing import (
from bson import _decode_all_selective
from pymongo import _csot, helpers_shared, message
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.compression_support import _NO_COMPRESSION, decompress
from pymongo.compression_support import _NO_COMPRESSION
from pymongo.errors import (
NotPrimaryError,
OperationFailure,
ProtocolError,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.message import _OpMsg
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
receive_data,
receive_message,
sendall,
)
@ -194,7 +189,7 @@ def command(
)
try:
sendall(conn.conn, msg)
sendall(conn.conn.get_conn, msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
@ -301,45 +296,3 @@ def command(
)
return response_doc # type: ignore[return-value]
def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
if length <= 16:
raise ProtocolError(
f"Message length ({length!r}) not longer than standard message header size (16)"
)
if length > max_message_size:
raise ProtocolError(
f"Message length ({length!r}) is larger than server max "
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
else:
data = receive_data(conn, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
except KeyError:
raise ProtocolError(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)

View File

@ -14,14 +14,10 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import functools
import logging
import os
import socket
import ssl
import sys
import time
import weakref
@ -49,16 +45,13 @@ from pymongo.common import (
from pymongo.errors import ( # type:ignore[attr-defined]
AutoReconnect,
ConfigurationError,
ConnectionFailure,
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NetworkTimeout,
NotPrimaryError,
OperationFailure,
PyMongoError,
WaitQueueTimeoutError,
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import (
@ -76,16 +69,23 @@ from pymongo.monitoring import (
ConnectionCheckOutFailedReason,
ConnectionClosedReason,
)
from pymongo.network_layer import sendall
from pymongo.network_layer import NetworkingInterface, receive_message, sendall
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
_CancellationContext,
_configured_socket_interface,
_get_timeout_details,
_raise_connection_failure,
format_timeout_details,
)
from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import HAS_SNI, SSLError
from pymongo.ssl_support import SSLError
from pymongo.synchronous.client_session import _validate_session_write_concern
from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth
from pymongo.synchronous.network import command, receive_message
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.synchronous.network import command
if TYPE_CHECKING:
from bson import CodecOptions
@ -96,7 +96,6 @@ if TYPE_CHECKING:
ZstdContext,
)
from pymongo.message import _OpMsg, _OpReply
from pymongo.pyopenssl_context import _sslConn
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.synchronous.auth import _AuthContext
@ -123,133 +122,6 @@ except ImportError:
_IS_SYNC = True
_MAX_TCP_KEEPIDLE = 120
_MAX_TCP_KEEPINTVL = 10
_MAX_TCP_KEEPCNT = 9
if sys.platform == "win32":
try:
import _winreg as winreg
except ImportError:
import winreg
def _query(key, name, default):
try:
value, _ = winreg.QueryValueEx(key, name)
# Ensure the value is a number or raise ValueError.
return int(value)
except (OSError, ValueError):
# QueryValueEx raises OSError when the key does not exist (i.e.
# the system is using the Windows default value).
return default
try:
with winreg.OpenKey(
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
) as key:
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
except OSError:
# We could not check the default values because winreg.OpenKey failed.
# Assume the system is using the default values.
_WINDOWS_TCP_IDLE_MS = 7200000
_WINDOWS_TCP_INTERVAL_MS = 1000
def _set_keepalive_times(sock):
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
else:
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
if hasattr(socket, tcp_option):
sockopt = getattr(socket, tcp_option)
try:
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
# TCP_KEEPIDLE and friends. Don't attempt to set the
# values there.
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
if default > max_value:
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
except OSError:
pass
def _set_keepalive_times(sock: socket.socket) -> None:
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
def _raise_connection_failure(
address: Any,
error: Exception,
msg_prefix: Optional[str] = None,
timeout_details: Optional[dict[str, float]] = None,
) -> NoReturn:
"""Convert a socket.error to ConnectionFailure and raise it."""
host, port = address
# If connecting to a Unix socket, port will be None.
if port is not None:
msg = "%s:%d: %s" % (host, port, error)
else:
msg = f"{host}: {error}"
if msg_prefix:
msg = msg_prefix + msg
if "configured timeouts" not in msg:
msg += format_timeout_details(timeout_details)
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, SSLError) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
# 'timed out' appears in all the timeout related SSLErrors raised.
raise NetworkTimeout(msg) from error
else:
raise AutoReconnect(msg) from error
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
socket_timeout = options.socket_timeout
connect_timeout = options.connect_timeout
if timeout:
details["timeoutMS"] = timeout * 1000
if socket_timeout and not timeout:
details["socketTimeoutMS"] = socket_timeout * 1000
if connect_timeout:
details["connectTimeoutMS"] = connect_timeout * 1000
return details
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
result = ""
if details:
result += " (configured timeouts:"
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
if timeout in details:
result += f" {timeout}: {details[timeout]}ms,"
result = result[:-1]
result += ")"
return result
class _CancellationContext:
def __init__(self) -> None:
self._cancelled = False
def cancel(self) -> None:
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self) -> bool:
"""Was cancel called?"""
return self._cancelled
class Connection:
"""Store a connection with some metadata.
@ -261,7 +133,11 @@ class Connection:
"""
def __init__(
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
self,
conn: NetworkingInterface,
pool: Pool,
address: tuple[str, int],
id: int,
):
self.pool_ref = weakref.ref(pool)
self.conn = conn
@ -318,7 +194,7 @@ class Connection:
if timeout == self.last_timeout:
return
self.last_timeout = timeout
self.conn.settimeout(timeout)
self.conn.get_conn.settimeout(timeout)
def apply_timeout(
self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]]
@ -573,7 +449,7 @@ class Connection:
)
try:
sendall(self.conn, message)
sendall(self.conn.get_conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
@ -707,7 +583,10 @@ class Connection:
def conn_closed(self) -> bool:
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.conn)
if _IS_SYNC:
return self.socket_checker.socket_closed(self.conn.get_conn)
else:
return self.conn.is_closing()
def send_cluster_time(
self,
@ -779,143 +658,6 @@ class Connection:
)
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
"""Given (host, port) and PoolOptions, connect and return a socket object.
Can raise socket.error.
This is a modified version of create_connection from CPython >= 2.7.
"""
host, port = address
# Check if dealing with a unix domain socket
if host.endswith(".sock"):
if not hasattr(socket, "AF_UNIX"):
raise ConnectionFailure("UNIX-sockets are not supported on this system")
sock = socket.socket(socket.AF_UNIX)
# SOCK_CLOEXEC not supported for Unix sockets.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.connect(host)
return sock
except OSError:
sock.close()
raise
# Don't try IPv6 if we don't support it. Also skip it if host
# is 'localhost' (::1 is fine). Avoids slow connect issues
# like PYTHON-356.
family = socket.AF_INET
if socket.has_ipv6 and host != "localhost":
family = socket.AF_UNSPEC
err = None
for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
af, socktype, proto, dummy, sa = res
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
# all file descriptors are created non-inheritable. See PEP 446.
try:
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
except OSError:
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
# it?
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
timeout = _csot.remaining()
if timeout is None:
timeout = options.connect_timeout
elif timeout <= 0:
raise socket.timeout("timed out")
sock.settimeout(timeout)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
_set_keepalive_times(sock)
sock.connect(sa)
return sock
except OSError as e:
err = e
sock.close()
if err is not None:
raise err
else:
# This likely means we tried to connect to an IPv6 only
# host with an OS/kernel or Python interpreter that doesn't
# support IPv6. The test case is Jython2.5.1 which doesn't
# support IPv6 at all.
raise OSError("getaddrinfo failed")
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
"""Given (host, port) and PoolOptions, return a configured socket.
Can raise socket.error, ConnectionFailure, or _CertificateError.
Sets socket's SSL and timeout options.
"""
sock = _create_connection(address, options)
ssl_context = options._ssl_context
if ssl_context is None:
sock.settimeout(options.socket_timeout)
return sock
host = address[0]
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = loop.run_in_executor(
None,
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
)
else:
if _IS_SYNC:
ssl_sock = ssl_context.wrap_socket(sock)
else:
if hasattr(ssl_context, "a_wrap_socket"):
ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
else:
loop = asyncio.get_running_loop()
ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
except _CertificateError:
sock.close()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
class _PoolClosedError(PyMongoError):
"""Internal error raised when a thread tries to get a connection from a
closed pool.
@ -1260,7 +1002,7 @@ class Pool:
)
try:
sock = _configured_socket(self.address, self.opts)
networking_interface = _configured_socket_interface(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
with self.lock:
@ -1287,7 +1029,7 @@ class Pool:
raise
conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type]
conn = Connection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type]
with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
@ -1698,5 +1440,6 @@ class Pool:
# Avoid ResourceWarnings in Python 3
# Close all sockets without calling reset() or close() because it is
# not safe to acquire a lock in __del__.
for conn in self.conns:
conn.close_conn(None)
if _IS_SYNC:
for conn in self.conns:
conn.close_conn(None)

View File

@ -116,6 +116,7 @@ filterwarnings = [
"module:unclosed <socket.socket:ResourceWarning",
"module:unclosed <ssl.SSLSocket:ResourceWarning",
"module:unclosed <socket object:ResourceWarning",
"module:unclosed transport:ResourceWarning",
# https://github.com/eventlet/eventlet/issues/818
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
# https://github.com/dateutil/dateutil/issues/1314

View File

@ -22,6 +22,8 @@ import sys
import warnings
from test.asynchronous import AsyncPyMongoTestCase
import pytest
sys.path[0:0] = [""]
from test import unittest
@ -30,6 +32,8 @@ from test.asynchronous.unified_format import generate_test_classes
from pymongo import AsyncMongoClient
from pymongo.asynchronous.auth_oidc import OIDCCallback
pytestmark = pytest.mark.auth
_IS_SYNC = False
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")

View File

@ -301,7 +301,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_bulk_max_message_size(self):
await self.coll.delete_many({})
self.addCleanup(self.coll.delete_many, {})
self.addAsyncCleanup(self.coll.delete_many, {})
_16_MB = 16 * 1000 * 1000
# Generate a list of documents such that the first batched OP_MSG is
# as close as possible to the 48MB limit.
@ -505,7 +505,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_single_error_ordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
@ -547,7 +547,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_multiple_error_ordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
@ -616,7 +616,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_single_error_unordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
@ -659,7 +659,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
async def test_multiple_error_unordered_batch(self):
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
requests: list = [
InsertOne({"b": 1, "a": 1}),
UpdateOne({"b": 2}, {"$set": {"a": 3}}, upsert=True),
@ -1002,7 +1002,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
await self.coll.delete_many({})
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
# Fail due to write concern support as well
# as duplicate key error on ordered batch.
@ -1077,7 +1077,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
await self.coll.delete_many({})
await self.coll.create_index("a", unique=True)
self.addCleanup(self.coll.drop_index, [("a", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
# Fail due to write concern support as well
# as duplicate key error on unordered batch.

View File

@ -746,7 +746,7 @@ class TestClient(AsyncIntegrationTest):
# Assert that if a socket is closed, a new one takes its place
async with server._pool.checkout() as conn:
conn.close_conn(None)
await conn.close_conn(None)
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"a closed socket gets replaced from the pool",
@ -1270,7 +1270,6 @@ class TestClient(AsyncIntegrationTest):
no_timeout = self.client
timeout_sec = 1
timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
self.addAsyncCleanup(timeout.close)
await no_timeout.pymongo_test.drop_collection("test")
await no_timeout.pymongo_test.test.insert_one({"x": 1})
@ -1337,7 +1336,7 @@ class TestClient(AsyncIntegrationTest):
async def test_socketKeepAlive(self):
pool = await async_get_pool(self.client)
async with pool.checkout() as conn:
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
self.assertTrue(keepalive)
@no_type_check
@ -1537,7 +1536,7 @@ class TestClient(AsyncIntegrationTest):
# Cause a network error.
conn = one(pool.conns)
conn.conn.close()
await conn.conn.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST)
with self.assertRaises(ConnectionFailure):
await anext(cursor)
@ -1562,7 +1561,7 @@ class TestClient(AsyncIntegrationTest):
# Cause a network error on the actual socket.
pool = await async_get_pool(c)
conn = one(pool.conns)
conn.conn.close()
await conn.conn.close()
# AsyncConnection.authenticate logs, but gets a socket.error. Should be
# reraised as AutoReconnect.
@ -2254,7 +2253,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
# Cause a network error.
conn = one(pool.conns)
conn.conn.close()
await conn.conn.close()
cursor = collection.find(cursor_type=CursorType.EXHAUST)
with self.assertRaises(ConnectionFailure):
@ -2282,7 +2281,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
# Cause a network error.
conn = cursor._sock_mgr.conn
conn.conn.close()
await conn.conn.close()
# A getmore fails.
with self.assertRaises(ConnectionFailure):

View File

@ -651,7 +651,6 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
_OVERHEAD = 500
internal_client = await self.async_rs_or_single_client(timeoutMS=None)
self.addAsyncCleanup(internal_client.close)
collection = internal_client.db["coll"]
self.addAsyncCleanup(collection.drop)

View File

@ -273,7 +273,7 @@ class AsyncTestCMAP(AsyncIntegrationTest):
for t in self.targets.values():
await t.join(5)
for conn in self.labels.values():
conn.close_conn(None)
await conn.close_conn(None)
self.addAsyncCleanup(cleanup)

View File

@ -1401,7 +1401,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_to_list_length(self):
coll = self.db.test
await coll.insert_many([{} for _ in range(5)])
self.addCleanup(coll.drop)
self.addAsyncCleanup(coll.drop)
c = coll.find()
docs = await c.to_list(3)
self.assertEqual(len(docs), 3)
@ -1812,6 +1812,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
@async_client_context.require_version_min(5, 0, -1)
@async_client_context.require_no_mongos
@async_client_context.require_sync
async def test_exhaust_cursor_db_set(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(event_listeners=[listener])
@ -1821,7 +1822,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
listener.reset()
result = await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list()
result = list(await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1))
self.assertEqual(len(result), 3)

View File

@ -115,6 +115,17 @@ class TestGridfs(AsyncIntegrationTest):
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
async def test_delete_by_name(self):
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
gfs = gridfs.AsyncGridFSBucket(self.db)
await gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1)
self.assertEqual(1, await self.db.fs.files.count_documents({}))
self.assertEqual(5, await self.db.fs.chunks.count_documents({}))
await gfs.delete_by_name("test_filename")
self.assertEqual(0, await self.db.fs.files.count_documents({}))
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
async def test_empty_file(self):
oid = await self.fs.upload_from_stream("test_filename", b"")
self.assertEqual(b"", await (await self.fs.open_download_stream(oid)).read())

View File

@ -219,7 +219,7 @@ class TestPooling(_TestPoolingBase):
async with cx_pool.checkout() as conn:
# Use Connection's API to close the socket.
conn.close_conn(None)
await conn.close_conn(None)
self.assertEqual(0, len(cx_pool.conns))
@ -232,7 +232,7 @@ class TestPooling(_TestPoolingBase):
async with cx_pool.checkout() as conn:
# Simulate a closed socket without telling the Connection it's
# closed.
conn.conn.close()
await conn.conn.close()
self.assertTrue(conn.conn_closed())
async with cx_pool.checkout() as new_connection:
@ -306,7 +306,7 @@ class TestPooling(_TestPoolingBase):
async with cx_pool.checkout() as conn:
# Simulate a closed socket without telling the Connection it's
# closed.
conn.conn.close()
await conn.conn.close()
# Swap pool's address with a bad one.
address, cx_pool.address = cx_pool.address, ("foo.com", 1234)

View File

@ -137,6 +137,7 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
self.deprecation_filter = DeprecationFilter()
async def asyncTearDown(self) -> None:
await super().asyncTearDown()
self.deprecation_filter.stop()
@ -196,6 +197,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
self.knobs.disable()
await super().asyncTearDown()
async def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()

View File

@ -45,7 +45,7 @@ from test.utils_shared import (
from bson import DBRef
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo import ASCENDING, AsyncMongoClient, monitoring
from pymongo import ASCENDING, AsyncMongoClient, _csot, monitoring
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.helpers import anext
@ -543,7 +543,7 @@ class TestSession(AsyncIntegrationTest):
(bucket.rename, [1, "f2"], {}),
# Delete both files so _test_ops can run these operations twice.
(bucket.delete, [1], {}),
(bucket.delete, [2], {}),
(bucket.delete_by_name, ["f"], {}),
)
async def test_gridfsbucket_cursor(self):

View File

@ -32,7 +32,7 @@ from typing import List
from bson import encode
from bson.raw_bson import RawBSONDocument
from pymongo import WriteConcern
from pymongo import WriteConcern, _csot
from pymongo.asynchronous import client_session
from pymongo.asynchronous.client_session import TransactionOptions
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
@ -295,6 +295,7 @@ class TestTransactions(AsyncTransactionsBase):
"new-name",
),
),
(bucket.delete_by_name, ("new-name",)),
]
async with client.start_session() as s, await s.start_transaction():

View File

@ -66,7 +66,7 @@ import pymongo
from bson import SON, json_util
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.objectid import ObjectId
from gridfs import AsyncGridFSBucket, GridOut
from gridfs import AsyncGridFSBucket, GridOut, NoFile
from pymongo import ASCENDING, AsyncMongoClient, CursorType, _csot
from pymongo.asynchronous.change_stream import AsyncChangeStream
from pymongo.asynchronous.client_session import AsyncClientSession, TransactionOptions, _TxnState
@ -632,7 +632,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
# Connection errors are considered client errors.
if isinstance(error, ConnectionFailure):
self.assertNotIsInstance(error, NotPrimaryError)
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)):
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError, NoFile)):
pass
else:
self.assertNotIsInstance(error, PyMongoError)

View File

@ -159,6 +159,7 @@ class AsyncMockConnection:
self.cancel_context = _CancellationContext()
self.more_to_come = False
self.id = random.randint(0, 100)
self.server_connection_id = random.randint(0, 100)
def close_conn(self, reason):
pass

View File

@ -0,0 +1,493 @@
{
"description": "bypassDocumentValidation",
"schemaVersion": "1.4",
"runOnRequirements": [
{
"minServerVersion": "3.2",
"serverless": "forbid"
}
],
"createEntities": [
{
"client": {
"id": "client0",
"observeEvents": [
"commandStartedEvent"
]
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "crud"
}
},
{
"collection": {
"id": "collection0",
"database": "database0",
"collectionName": "coll"
}
}
],
"initialData": [
{
"collectionName": "coll",
"databaseName": "crud",
"documents": [
{
"_id": 1,
"x": 11
},
{
"_id": 2,
"x": 22
},
{
"_id": 3,
"x": 33
}
]
}
],
"tests": [
{
"description": "Aggregate with $out passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "aggregate",
"arguments": {
"pipeline": [
{
"$sort": {
"x": 1
}
},
{
"$match": {
"_id": {
"$gt": 1
}
}
},
{
"$out": "other_test_collection"
}
],
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"aggregate": "coll",
"pipeline": [
{
"$sort": {
"x": 1
}
},
{
"$match": {
"_id": {
"$gt": 1
}
}
},
{
"$out": "other_test_collection"
}
],
"bypassDocumentValidation": false
},
"commandName": "aggregate",
"databaseName": "crud"
}
}
]
}
]
},
{
"description": "BulkWrite passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "bulkWrite",
"arguments": {
"requests": [
{
"insertOne": {
"document": {
"_id": 4,
"x": 44
}
}
}
],
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "coll",
"documents": [
{
"_id": 4,
"x": 44
}
],
"bypassDocumentValidation": false
}
}
}
]
}
]
},
{
"description": "FindOneAndReplace passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "findOneAndReplace",
"arguments": {
"filter": {
"_id": {
"$gt": 1
}
},
"replacement": {
"x": 32
},
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"findAndModify": "coll",
"query": {
"_id": {
"$gt": 1
}
},
"update": {
"x": 32
},
"bypassDocumentValidation": false
}
}
}
]
}
]
},
{
"description": "FindOneAndUpdate passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "findOneAndUpdate",
"arguments": {
"filter": {
"_id": {
"$gt": 1
}
},
"update": {
"$inc": {
"x": 1
}
},
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"findAndModify": "coll",
"query": {
"_id": {
"$gt": 1
}
},
"update": {
"$inc": {
"x": 1
}
},
"bypassDocumentValidation": false
}
}
}
]
}
]
},
{
"description": "InsertMany passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "insertMany",
"arguments": {
"documents": [
{
"_id": 4,
"x": 44
}
],
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "coll",
"documents": [
{
"_id": 4,
"x": 44
}
],
"bypassDocumentValidation": false
}
}
}
]
}
]
},
{
"description": "InsertOne passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "insertOne",
"arguments": {
"document": {
"_id": 4,
"x": 44
},
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"insert": "coll",
"documents": [
{
"_id": 4,
"x": 44
}
],
"bypassDocumentValidation": false
}
}
}
]
}
]
},
{
"description": "ReplaceOne passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "replaceOne",
"arguments": {
"filter": {
"_id": {
"$gt": 1
}
},
"replacement": {
"x": 32
},
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"update": "coll",
"updates": [
{
"q": {
"_id": {
"$gt": 1
}
},
"u": {
"x": 32
},
"multi": {
"$$unsetOrMatches": false
},
"upsert": {
"$$unsetOrMatches": false
}
}
],
"bypassDocumentValidation": false
}
}
}
]
}
]
},
{
"description": "UpdateMany passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "updateMany",
"arguments": {
"filter": {
"_id": {
"$gt": 1
}
},
"update": {
"$inc": {
"x": 1
}
},
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"update": "coll",
"updates": [
{
"q": {
"_id": {
"$gt": 1
}
},
"u": {
"$inc": {
"x": 1
}
},
"multi": true,
"upsert": {
"$$unsetOrMatches": false
}
}
],
"bypassDocumentValidation": false
}
}
}
]
}
]
},
{
"description": "UpdateOne passes bypassDocumentValidation: false",
"operations": [
{
"object": "collection0",
"name": "updateOne",
"arguments": {
"filter": {
"_id": {
"$gt": 1
}
},
"update": {
"$inc": {
"x": 1
}
},
"bypassDocumentValidation": false
}
}
],
"expectEvents": [
{
"client": "client0",
"events": [
{
"commandStartedEvent": {
"command": {
"update": "coll",
"updates": [
{
"q": {
"_id": {
"$gt": 1
}
},
"u": {
"$inc": {
"x": 1
}
},
"multi": {
"$$unsetOrMatches": false
},
"upsert": {
"$$unsetOrMatches": false
}
}
],
"bypassDocumentValidation": false
}
}
}
]
}
]
}
]
}

View File

@ -0,0 +1,230 @@
{
"description": "gridfs-deleteByName",
"schemaVersion": "1.0",
"createEntities": [
{
"client": {
"id": "client0"
}
},
{
"database": {
"id": "database0",
"client": "client0",
"databaseName": "gridfs-tests"
}
},
{
"bucket": {
"id": "bucket0",
"database": "database0"
}
},
{
"collection": {
"id": "bucket0_files_collection",
"database": "database0",
"collectionName": "fs.files"
}
},
{
"collection": {
"id": "bucket0_chunks_collection",
"database": "database0",
"collectionName": "fs.chunks"
}
}
],
"initialData": [
{
"collectionName": "fs.files",
"databaseName": "gridfs-tests",
"documents": [
{
"_id": {
"$oid": "000000000000000000000001"
},
"length": 0,
"chunkSize": 4,
"uploadDate": {
"$date": "1970-01-01T00:00:00.000Z"
},
"filename": "filename",
"metadata": {}
},
{
"_id": {
"$oid": "000000000000000000000002"
},
"length": 0,
"chunkSize": 4,
"uploadDate": {
"$date": "1970-01-01T00:00:00.000Z"
},
"filename": "filename",
"metadata": {}
},
{
"_id": {
"$oid": "000000000000000000000003"
},
"length": 2,
"chunkSize": 4,
"uploadDate": {
"$date": "1970-01-01T00:00:00.000Z"
},
"filename": "filename",
"metadata": {}
},
{
"_id": {
"$oid": "000000000000000000000004"
},
"length": 8,
"chunkSize": 4,
"uploadDate": {
"$date": "1970-01-01T00:00:00.000Z"
},
"filename": "otherfilename",
"metadata": {}
}
]
},
{
"collectionName": "fs.chunks",
"databaseName": "gridfs-tests",
"documents": [
{
"_id": {
"$oid": "000000000000000000000001"
},
"files_id": {
"$oid": "000000000000000000000002"
},
"n": 0,
"data": {
"$binary": {
"base64": "",
"subType": "00"
}
}
},
{
"_id": {
"$oid": "000000000000000000000002"
},
"files_id": {
"$oid": "000000000000000000000003"
},
"n": 0,
"data": {
"$binary": {
"base64": "",
"subType": "00"
}
}
},
{
"_id": {
"$oid": "000000000000000000000003"
},
"files_id": {
"$oid": "000000000000000000000003"
},
"n": 0,
"data": {
"$binary": {
"base64": "",
"subType": "00"
}
}
},
{
"_id": {
"$oid": "000000000000000000000004"
},
"files_id": {
"$oid": "000000000000000000000004"
},
"n": 0,
"data": {
"$binary": {
"base64": "",
"subType": "00"
}
}
}
]
}
],
"tests": [
{
"description": "delete when multiple revisions of the file exist",
"operations": [
{
"name": "deleteByName",
"object": "bucket0",
"arguments": {
"filename": "filename"
}
}
],
"outcome": [
{
"collectionName": "fs.files",
"databaseName": "gridfs-tests",
"documents": [
{
"_id": {
"$oid": "000000000000000000000004"
},
"length": 8,
"chunkSize": 4,
"uploadDate": {
"$date": "1970-01-01T00:00:00.000Z"
},
"filename": "otherfilename",
"metadata": {}
}
]
},
{
"collectionName": "fs.chunks",
"databaseName": "gridfs-tests",
"documents": [
{
"_id": {
"$oid": "000000000000000000000004"
},
"files_id": {
"$oid": "000000000000000000000004"
},
"n": 0,
"data": {
"$binary": {
"base64": "",
"subType": "00"
}
}
}
]
}
]
},
{
"description": "delete when file name does not exist",
"operations": [
{
"name": "deleteByName",
"object": "bucket0",
"arguments": {
"filename": "missing-file"
},
"expectError": {
"isClientError": true
}
}
]
}
]
}

View File

@ -1616,6 +1616,50 @@
]
}
]
},
{
"description": "pinned connection is released when session ended",
"operations": [
{
"name": "startTransaction",
"object": "session0"
},
{
"name": "insertOne",
"object": "collection0",
"arguments": {
"document": {
"x": 1
},
"session": "session0"
}
},
{
"name": "commitTransaction",
"object": "session0"
},
{
"name": "endSession",
"object": "session0"
}
],
"expectEvents": [
{
"client": "client0",
"eventType": "cmap",
"events": [
{
"connectionReadyEvent": {}
},
{
"connectionCheckedOutEvent": {}
},
{
"connectionCheckedInEvent": {}
}
]
}
]
}
]
}

View File

@ -22,6 +22,8 @@ import sys
import warnings
from test import PyMongoTestCase
import pytest
sys.path[0:0] = [""]
from test import unittest
@ -30,6 +32,8 @@ from test.unified_format import generate_test_classes
from pymongo import MongoClient
from pymongo.synchronous.auth_oidc import OIDCCallback
pytestmark = pytest.mark.auth
_IS_SYNC = True
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")

View File

@ -1233,7 +1233,6 @@ class TestClient(IntegrationTest):
no_timeout = self.client
timeout_sec = 1
timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
self.addCleanup(timeout.close)
no_timeout.pymongo_test.drop_collection("test")
no_timeout.pymongo_test.test.insert_one({"x": 1})
@ -1296,7 +1295,7 @@ class TestClient(IntegrationTest):
def test_socketKeepAlive(self):
pool = get_pool(self.client)
with pool.checkout() as conn:
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
self.assertTrue(keepalive)
@no_type_check

View File

@ -647,7 +647,6 @@ class TestClientBulkWriteCSOT(IntegrationTest):
_OVERHEAD = 500
internal_client = self.rs_or_single_client(timeoutMS=None)
self.addCleanup(internal_client.close)
collection = internal_client.db["coll"]
self.addCleanup(collection.drop)

View File

@ -1801,6 +1801,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
@client_context.require_version_min(5, 0, -1)
@client_context.require_no_mongos
@client_context.require_sync
def test_exhaust_cursor_db_set(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(event_listeners=[listener])
@ -1810,7 +1811,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
listener.reset()
result = c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list()
result = list(c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1))
self.assertEqual(len(result), 3)

View File

@ -115,6 +115,17 @@ class TestGridfs(IntegrationTest):
self.assertEqual(0, self.db.fs.files.count_documents({}))
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
def test_delete_by_name(self):
self.assertEqual(0, self.db.fs.files.count_documents({}))
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
gfs = gridfs.GridFSBucket(self.db)
gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1)
self.assertEqual(1, self.db.fs.files.count_documents({}))
self.assertEqual(5, self.db.fs.chunks.count_documents({}))
gfs.delete_by_name("test_filename")
self.assertEqual(0, self.db.fs.files.count_documents({}))
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
def test_empty_file(self):
oid = self.fs.upload_from_stream("test_filename", b"")
self.assertEqual(b"", (self.fs.open_download_stream(oid)).read())

View File

@ -137,6 +137,7 @@ class IgnoreDeprecationsTest(IntegrationTest):
self.deprecation_filter = DeprecationFilter()
def tearDown(self) -> None:
super().tearDown()
self.deprecation_filter.stop()
@ -194,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
)
self.knobs.disable()
super().tearDown()
def test_supported_single_statement_no_retry(self):
listener = OvertCommandListener()

View File

@ -45,7 +45,7 @@ from test.utils_shared import (
from bson import DBRef
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo import ASCENDING, MongoClient, monitoring
from pymongo import ASCENDING, MongoClient, _csot, monitoring
from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
@ -543,7 +543,7 @@ class TestSession(IntegrationTest):
(bucket.rename, [1, "f2"], {}),
# Delete both files so _test_ops can run these operations twice.
(bucket.delete, [1], {}),
(bucket.delete, [2], {}),
(bucket.delete_by_name, ["f"], {}),
)
def test_gridfsbucket_cursor(self):

View File

@ -32,7 +32,7 @@ from typing import List
from bson import encode
from bson.raw_bson import RawBSONDocument
from pymongo import WriteConcern
from pymongo import WriteConcern, _csot
from pymongo.errors import (
CollectionInvalid,
ConfigurationError,
@ -287,6 +287,7 @@ class TestTransactions(TransactionsBase):
"new-name",
),
),
(bucket.delete_by_name, ("new-name",)),
]
with client.start_session() as s, s.start_transaction():

View File

@ -65,7 +65,7 @@ import pymongo
from bson import SON, json_util
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.objectid import ObjectId
from gridfs import GridFSBucket, GridOut
from gridfs import GridFSBucket, GridOut, NoFile
from pymongo import ASCENDING, CursorType, MongoClient, _csot
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT
from pymongo.errors import (
@ -631,7 +631,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
# Connection errors are considered client errors.
if isinstance(error, ConnectionFailure):
self.assertNotIsInstance(error, NotPrimaryError)
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)):
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError, NoFile)):
pass
else:
self.assertNotIsInstance(error, PyMongoError)

View File

@ -157,6 +157,7 @@ class MockConnection:
self.cancel_context = _CancellationContext()
self.more_to_come = False
self.id = random.randint(0, 100)
self.server_connection_id = random.randint(0, 100)
def close_conn(self, reason):
pass

View File

@ -615,6 +615,10 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac
# Aggregate uses "batchSize", while find uses batch_size.
elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate":
continue
elif arg_name == "bypassDocumentValidation" and (
opname == "aggregate" or "find_one_and" in opname
):
continue
elif arg_name == "timeoutMode":
raise unittest.SkipTest("PyMongo does not support timeoutMode")
# Requires boolean returnDocument.

View File

@ -47,6 +47,7 @@ replacements = {
"async_receive_message": "receive_message",
"async_receive_data": "receive_data",
"async_sendall": "sendall",
"async_socket_sendall": "sendall",
"asynchronous": "synchronous",
"Asynchronous": "Synchronous",
"AsyncBulkTestBase": "BulkTestBase",
@ -119,6 +120,9 @@ replacements = {
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
"AsyncNetworkingInterface": "NetworkingInterface",
"_configured_protocol_interface": "_configured_socket_interface",
"_async_configured_socket": "_configured_socket",
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
@ -127,6 +131,7 @@ replacements = {
"async_create_barrier": "create_barrier",
"async_barrier_wait": "barrier_wait",
"async_joinall": "joinall",
"_async_create_connection": "_create_connection",
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
}