Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
51292decab
@ -9,6 +9,8 @@ PyMongo 4.12 brings a number of changes including:
|
||||
- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
|
||||
:class:`~pymongo.encryption_options.AutoEncryptionOpts`.
|
||||
- Support for $lookup in CSFLE and QE supported on MongoDB 8.1+.
|
||||
- Added :meth:`gridfs.asynchronous.grid_file.AsyncGridFSBucket.delete_by_name` and :meth:`gridfs.grid_file.GridFSBucket.delete_by_name`
|
||||
for more performant deletion of a file with multiple revisions.
|
||||
- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation.
|
||||
To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected.
|
||||
- Added index hinting support to the
|
||||
|
||||
@ -834,6 +834,35 @@ class AsyncGridFSBucket:
|
||||
if not res.deleted_count:
|
||||
raise NoFile("no file could be deleted because none matched %s" % file_id)
|
||||
|
||||
@_csot.apply
|
||||
async def delete_by_name(
|
||||
self, filename: str, session: Optional[AsyncClientSession] = None
|
||||
) -> None:
|
||||
"""Given a filename, delete this stored file's files collection document(s)
|
||||
and associated chunks from a GridFS bucket.
|
||||
|
||||
For example::
|
||||
|
||||
my_db = AsyncMongoClient().test
|
||||
fs = AsyncGridFSBucket(my_db)
|
||||
await fs.upload_from_stream("test_file", "data I want to store!")
|
||||
await fs.delete_by_name("test_file")
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists.
|
||||
|
||||
:param filename: The name of the file to be deleted.
|
||||
:param session: a :class:`~pymongo.client_session.AsyncClientSession`
|
||||
|
||||
.. versionadded:: 4.12
|
||||
"""
|
||||
_disallow_transactions(session)
|
||||
files = self._files.find({"filename": filename}, {"_id": 1}, session=session)
|
||||
file_ids = [file["_id"] async for file in files]
|
||||
res = await self._files.delete_many({"_id": {"$in": file_ids}}, session=session)
|
||||
await self._chunks.delete_many({"files_id": {"$in": file_ids}}, session=session)
|
||||
if not res.deleted_count:
|
||||
raise NoFile(f"no file could be deleted because none matched filename {filename!r}")
|
||||
|
||||
def find(self, *args: Any, **kwargs: Any) -> AsyncGridOutCursor:
|
||||
"""Find and return the files collection documents that match ``filter``
|
||||
|
||||
|
||||
@ -830,6 +830,33 @@ class GridFSBucket:
|
||||
if not res.deleted_count:
|
||||
raise NoFile("no file could be deleted because none matched %s" % file_id)
|
||||
|
||||
@_csot.apply
|
||||
def delete_by_name(self, filename: str, session: Optional[ClientSession] = None) -> None:
|
||||
"""Given a filename, delete this stored file's files collection document(s)
|
||||
and associated chunks from a GridFS bucket.
|
||||
|
||||
For example::
|
||||
|
||||
my_db = MongoClient().test
|
||||
fs = GridFSBucket(my_db)
|
||||
fs.upload_from_stream("test_file", "data I want to store!")
|
||||
fs.delete_by_name("test_file")
|
||||
|
||||
Raises :exc:`~gridfs.errors.NoFile` if no file with the given filename exists.
|
||||
|
||||
:param filename: The name of the file to be deleted.
|
||||
:param session: a :class:`~pymongo.client_session.ClientSession`
|
||||
|
||||
.. versionadded:: 4.12
|
||||
"""
|
||||
_disallow_transactions(session)
|
||||
files = self._files.find({"filename": filename}, {"_id": 1}, session=session)
|
||||
file_ids = [file["_id"] for file in files]
|
||||
res = self._files.delete_many({"_id": {"$in": file_ids}}, session=session)
|
||||
self._chunks.delete_many({"files_id": {"$in": file_ids}}, session=session)
|
||||
if not res.deleted_count:
|
||||
raise NoFile(f"no file could be deleted because none matched filename {filename!r}")
|
||||
|
||||
def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
|
||||
"""Find and return the files collection documents that match ``filter``
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ class _AsyncBulk:
|
||||
self,
|
||||
collection: AsyncCollection[_DocumentType],
|
||||
ordered: bool,
|
||||
bypass_document_validation: bool,
|
||||
bypass_document_validation: Optional[bool],
|
||||
comment: Optional[str] = None,
|
||||
let: Optional[Any] = None,
|
||||
) -> None:
|
||||
@ -516,8 +516,8 @@ class _AsyncBulk:
|
||||
if self.comment:
|
||||
cmd["comment"] = self.comment
|
||||
_csot.apply_write_concern(cmd, write_concern)
|
||||
if self.bypass_doc_val:
|
||||
cmd["bypassDocumentValidation"] = True
|
||||
if self.bypass_doc_val is not None:
|
||||
cmd["bypassDocumentValidation"] = self.bypass_doc_val
|
||||
if self.let is not None and run.op_type in (_DELETE, _UPDATE):
|
||||
cmd["let"] = self.let
|
||||
if session:
|
||||
|
||||
@ -701,7 +701,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
requests: Sequence[_WriteOp[_DocumentType]],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
@ -800,7 +800,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
ordered: bool,
|
||||
write_concern: WriteConcern,
|
||||
op_id: Optional[int],
|
||||
bypass_doc_val: bool,
|
||||
bypass_doc_val: Optional[bool],
|
||||
session: Optional[AsyncClientSession],
|
||||
comment: Optional[Any] = None,
|
||||
) -> Any:
|
||||
@ -814,8 +814,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def _insert_command(
|
||||
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
|
||||
) -> None:
|
||||
if bypass_doc_val:
|
||||
command["bypassDocumentValidation"] = True
|
||||
if bypass_doc_val is not None:
|
||||
command["bypassDocumentValidation"] = bypass_doc_val
|
||||
|
||||
result = await conn.command(
|
||||
self._database.name,
|
||||
@ -840,7 +840,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
async def insert_one(
|
||||
self,
|
||||
document: Union[_DocumentType, RawBSONDocument],
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> InsertOneResult:
|
||||
@ -906,7 +906,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
documents: Iterable[Union[_DocumentType, RawBSONDocument]],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> InsertManyResult:
|
||||
@ -986,7 +986,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
op_id: Optional[int] = None,
|
||||
ordered: bool = True,
|
||||
bypass_doc_val: Optional[bool] = False,
|
||||
bypass_doc_val: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
@ -1041,8 +1041,8 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
if comment is not None:
|
||||
command["comment"] = comment
|
||||
# Update command.
|
||||
if bypass_doc_val:
|
||||
command["bypassDocumentValidation"] = True
|
||||
if bypass_doc_val is not None:
|
||||
command["bypassDocumentValidation"] = bypass_doc_val
|
||||
|
||||
# The command result has to be published for APM unmodified
|
||||
# so we make a shallow copy here before adding updatedExisting.
|
||||
@ -1082,7 +1082,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
op_id: Optional[int] = None,
|
||||
ordered: bool = True,
|
||||
bypass_doc_val: Optional[bool] = False,
|
||||
bypass_doc_val: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
@ -1128,7 +1128,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
filter: Mapping[str, Any],
|
||||
replacement: Mapping[str, Any],
|
||||
upsert: bool = False,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[AsyncClientSession] = None,
|
||||
@ -1237,7 +1237,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
@ -2948,6 +2948,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
returning aggregate results using a cursor.
|
||||
- `collation` (optional): An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
- `bypassDocumentValidation` (bool): If ``True``, allows the write to opt-out of document level validation.
|
||||
|
||||
|
||||
:return: A :class:`~pymongo.asynchronous.command_cursor.AsyncCommandCursor` over the result
|
||||
|
||||
@ -64,11 +64,6 @@ from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.asynchronous.pool import (
|
||||
_configured_socket,
|
||||
_get_timeout_details,
|
||||
_raise_connection_failure,
|
||||
)
|
||||
from pymongo.common import CONNECT_TIMEOUT
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
|
||||
@ -80,12 +75,17 @@ from pymongo.errors import (
|
||||
NetworkTimeout,
|
||||
ServerSelectionTimeoutError,
|
||||
)
|
||||
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
|
||||
from pymongo.network_layer import async_socket_sendall
|
||||
from pymongo.operations import UpdateOne
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.pool_shared import (
|
||||
_async_configured_socket,
|
||||
_get_timeout_details,
|
||||
_raise_connection_failure,
|
||||
)
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.results import BulkWriteResult, DeleteResult
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser_shared import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
@ -113,7 +113,7 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
|
||||
|
||||
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
|
||||
try:
|
||||
return await _configured_socket(address, opts)
|
||||
return await _async_configured_socket(address, opts)
|
||||
except Exception as exc:
|
||||
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
|
||||
|
||||
@ -196,7 +196,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
try:
|
||||
conn = await _connect_kms(address, opts)
|
||||
try:
|
||||
await async_sendall(conn, message)
|
||||
await async_socket_sendall(conn, message)
|
||||
while kms_context.bytes_needed > 0:
|
||||
# CSOT: update timeout.
|
||||
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
|
||||
@ -2079,7 +2079,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# exhausted the result set we *must* close the socket
|
||||
# to stop the server from sending more data.
|
||||
assert conn_mgr.conn is not None
|
||||
conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
await conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
else:
|
||||
await self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr)
|
||||
if conn_mgr:
|
||||
|
||||
@ -36,7 +36,11 @@ from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
|
||||
from pymongo.asynchronous.pool import ( # type: ignore[attr-defined]
|
||||
AsyncConnection,
|
||||
Pool,
|
||||
_CancellationContext,
|
||||
)
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology
|
||||
|
||||
|
||||
@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -31,20 +30,16 @@ from typing import (
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo import _csot, helpers_shared, message
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.compression_support import _NO_COMPRESSION, decompress
|
||||
from pymongo.compression_support import _NO_COMPRESSION
|
||||
from pymongo.errors import (
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
)
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.message import _OpMsg
|
||||
from pymongo.monitoring import _is_speculative_authenticate
|
||||
from pymongo.network_layer import (
|
||||
_UNPACK_COMPRESSION_HEADER,
|
||||
_UNPACK_HEADER,
|
||||
async_receive_data,
|
||||
async_receive_message,
|
||||
async_sendall,
|
||||
)
|
||||
|
||||
@ -194,13 +189,13 @@ async def command(
|
||||
)
|
||||
|
||||
try:
|
||||
await async_sendall(conn.conn, msg)
|
||||
await async_sendall(conn.conn.get_conn, msg)
|
||||
if use_op_msg and unacknowledged:
|
||||
# Unacknowledged, fake a successful command response.
|
||||
reply = None
|
||||
response_doc: _DocumentOut = {"ok": 1}
|
||||
else:
|
||||
reply = await receive_message(conn, request_id)
|
||||
reply = await async_receive_message(conn, request_id)
|
||||
conn.more_to_come = reply.more_to_come
|
||||
unpacked_docs = reply.unpack_response(
|
||||
codec_options=codec_options, user_fields=user_fields
|
||||
@ -301,47 +296,3 @@ async def command(
|
||||
)
|
||||
|
||||
return response_doc # type: ignore[return-value]
|
||||
|
||||
|
||||
async def receive_message(
|
||||
conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
timeout = conn.conn.gettimeout()
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
|
||||
if length <= 16:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard message header size (16)"
|
||||
)
|
||||
if length > max_message_size:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) is larger than server max "
|
||||
f"message size ({max_message_size!r})"
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
|
||||
await async_receive_data(conn, 9, deadline)
|
||||
)
|
||||
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = await async_receive_data(conn, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
|
||||
@ -14,14 +14,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import time
|
||||
import weakref
|
||||
@ -40,8 +36,8 @@ from typing import (
|
||||
from bson import DEFAULT_CODEC_OPTIONS
|
||||
from pymongo import _csot, helpers_shared
|
||||
from pymongo.asynchronous.client_session import _validate_session_write_concern
|
||||
from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
|
||||
from pymongo.asynchronous.network import command, receive_message
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.asynchronous.network import command
|
||||
from pymongo.common import (
|
||||
MAX_BSON_SIZE,
|
||||
MAX_MESSAGE_SIZE,
|
||||
@ -52,16 +48,13 @@ from pymongo.common import (
|
||||
from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
AutoReconnect,
|
||||
ConfigurationError,
|
||||
ConnectionFailure,
|
||||
DocumentTooLarge,
|
||||
ExecutionTimeout,
|
||||
InvalidOperation,
|
||||
NetworkTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
WaitQueueTimeoutError,
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import (
|
||||
@ -79,13 +72,20 @@ from pymongo.monitoring import (
|
||||
ConnectionCheckOutFailedReason,
|
||||
ConnectionClosedReason,
|
||||
)
|
||||
from pymongo.network_layer import async_sendall
|
||||
from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.pool_shared import (
|
||||
_CancellationContext,
|
||||
_configured_protocol_interface,
|
||||
_get_timeout_details,
|
||||
_raise_connection_failure,
|
||||
format_timeout_details,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.ssl_support import HAS_SNI, SSLError
|
||||
from pymongo.ssl_support import SSLError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
@ -99,7 +99,6 @@ if TYPE_CHECKING:
|
||||
ZstdContext,
|
||||
)
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.typings import _Address, _CollationIn
|
||||
@ -123,133 +122,6 @@ except ImportError:
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_MAX_TCP_KEEPIDLE = 120
|
||||
_MAX_TCP_KEEPINTVL = 10
|
||||
_MAX_TCP_KEEPCNT = 9
|
||||
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
import _winreg as winreg
|
||||
except ImportError:
|
||||
import winreg
|
||||
|
||||
def _query(key, name, default):
|
||||
try:
|
||||
value, _ = winreg.QueryValueEx(key, name)
|
||||
# Ensure the value is a number or raise ValueError.
|
||||
return int(value)
|
||||
except (OSError, ValueError):
|
||||
# QueryValueEx raises OSError when the key does not exist (i.e.
|
||||
# the system is using the Windows default value).
|
||||
return default
|
||||
|
||||
try:
|
||||
with winreg.OpenKey(
|
||||
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
|
||||
) as key:
|
||||
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
|
||||
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
|
||||
except OSError:
|
||||
# We could not check the default values because winreg.OpenKey failed.
|
||||
# Assume the system is using the default values.
|
||||
_WINDOWS_TCP_IDLE_MS = 7200000
|
||||
_WINDOWS_TCP_INTERVAL_MS = 1000
|
||||
|
||||
def _set_keepalive_times(sock):
|
||||
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
|
||||
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
|
||||
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
|
||||
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
|
||||
|
||||
else:
|
||||
|
||||
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
|
||||
if hasattr(socket, tcp_option):
|
||||
sockopt = getattr(socket, tcp_option)
|
||||
try:
|
||||
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
|
||||
# TCP_KEEPIDLE and friends. Don't attempt to set the
|
||||
# values there.
|
||||
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
|
||||
if default > max_value:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _set_keepalive_times(sock: socket.socket) -> None:
|
||||
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
|
||||
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
|
||||
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
|
||||
|
||||
|
||||
def _raise_connection_failure(
|
||||
address: Any,
|
||||
error: Exception,
|
||||
msg_prefix: Optional[str] = None,
|
||||
timeout_details: Optional[dict[str, float]] = None,
|
||||
) -> NoReturn:
|
||||
"""Convert a socket.error to ConnectionFailure and raise it."""
|
||||
host, port = address
|
||||
# If connecting to a Unix socket, port will be None.
|
||||
if port is not None:
|
||||
msg = "%s:%d: %s" % (host, port, error)
|
||||
else:
|
||||
msg = f"{host}: {error}"
|
||||
if msg_prefix:
|
||||
msg = msg_prefix + msg
|
||||
if "configured timeouts" not in msg:
|
||||
msg += format_timeout_details(timeout_details)
|
||||
if isinstance(error, socket.timeout):
|
||||
raise NetworkTimeout(msg) from error
|
||||
elif isinstance(error, SSLError) and "timed out" in str(error):
|
||||
# Eventlet does not distinguish TLS network timeouts from other
|
||||
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
|
||||
# Luckily, we can work around this limitation because the phrase
|
||||
# 'timed out' appears in all the timeout related SSLErrors raised.
|
||||
raise NetworkTimeout(msg) from error
|
||||
else:
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
socket_timeout = options.socket_timeout
|
||||
connect_timeout = options.connect_timeout
|
||||
if timeout:
|
||||
details["timeoutMS"] = timeout * 1000
|
||||
if socket_timeout and not timeout:
|
||||
details["socketTimeoutMS"] = socket_timeout * 1000
|
||||
if connect_timeout:
|
||||
details["connectTimeoutMS"] = connect_timeout * 1000
|
||||
return details
|
||||
|
||||
|
||||
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
|
||||
result = ""
|
||||
if details:
|
||||
result += " (configured timeouts:"
|
||||
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
|
||||
if timeout in details:
|
||||
result += f" {timeout}: {details[timeout]}ms,"
|
||||
result = result[:-1]
|
||||
result += ")"
|
||||
return result
|
||||
|
||||
|
||||
class _CancellationContext:
|
||||
def __init__(self) -> None:
|
||||
self._cancelled = False
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel this context."""
|
||||
self._cancelled = True
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Was cancel called?"""
|
||||
return self._cancelled
|
||||
|
||||
|
||||
class AsyncConnection:
|
||||
"""Store a connection with some metadata.
|
||||
@ -261,7 +133,11 @@ class AsyncConnection:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
|
||||
self,
|
||||
conn: AsyncNetworkingInterface,
|
||||
pool: Pool,
|
||||
address: tuple[str, int],
|
||||
id: int,
|
||||
):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.conn = conn
|
||||
@ -318,7 +194,7 @@ class AsyncConnection:
|
||||
if timeout == self.last_timeout:
|
||||
return
|
||||
self.last_timeout = timeout
|
||||
self.conn.settimeout(timeout)
|
||||
self.conn.get_conn.settimeout(timeout)
|
||||
|
||||
def apply_timeout(
|
||||
self, client: AsyncMongoClient, cmd: Optional[MutableMapping[str, Any]]
|
||||
@ -364,7 +240,7 @@ class AsyncConnection:
|
||||
if pool:
|
||||
await pool.checkin(self)
|
||||
else:
|
||||
self.close_conn(ConnectionClosedReason.STALE)
|
||||
await self.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
def hello_cmd(self) -> dict[str, Any]:
|
||||
# Handshake spec requires us to use OP_MSG+hello command for the
|
||||
@ -559,7 +435,7 @@ class AsyncConnection:
|
||||
raise
|
||||
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
await self._raise_connection_failure(error)
|
||||
|
||||
async def send_message(self, message: bytes, max_doc_size: int) -> None:
|
||||
"""Send a raw BSON message or raise ConnectionFailure.
|
||||
@ -573,10 +449,10 @@ class AsyncConnection:
|
||||
)
|
||||
|
||||
try:
|
||||
await async_sendall(self.conn, message)
|
||||
await async_sendall(self.conn.get_conn, message)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
await self._raise_connection_failure(error)
|
||||
|
||||
async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise ConnectionFailure.
|
||||
@ -584,10 +460,10 @@ class AsyncConnection:
|
||||
If any exception is raised, the socket is closed.
|
||||
"""
|
||||
try:
|
||||
return await receive_message(self, request_id, self.max_message_size)
|
||||
return await async_receive_message(self, request_id, self.max_message_size)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
await self._raise_connection_failure(error)
|
||||
|
||||
def _raise_if_not_writable(self, unacknowledged: bool) -> None:
|
||||
"""Raise NotPrimaryError on unacknowledged write if this socket is not
|
||||
@ -673,11 +549,11 @@ class AsyncConnection:
|
||||
"Can only use session with the AsyncMongoClient that started it"
|
||||
)
|
||||
|
||||
def close_conn(self, reason: Optional[str]) -> None:
|
||||
async def close_conn(self, reason: Optional[str]) -> None:
|
||||
"""Close this connection with a reason."""
|
||||
if self.closed:
|
||||
return
|
||||
self._close_conn()
|
||||
await self._close_conn()
|
||||
if reason:
|
||||
if self.enabled_for_cmap:
|
||||
assert self.listeners is not None
|
||||
@ -694,7 +570,7 @@ class AsyncConnection:
|
||||
error=reason,
|
||||
)
|
||||
|
||||
def _close_conn(self) -> None:
|
||||
async def _close_conn(self) -> None:
|
||||
"""Close this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
@ -703,13 +579,16 @@ class AsyncConnection:
|
||||
# Note: We catch exceptions to avoid spurious errors on interpreter
|
||||
# shutdown.
|
||||
try:
|
||||
self.conn.close()
|
||||
await self.conn.close()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def conn_closed(self) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
return self.socket_checker.socket_closed(self.conn)
|
||||
if _IS_SYNC:
|
||||
return self.socket_checker.socket_closed(self.conn.get_conn)
|
||||
else:
|
||||
return self.conn.is_closing()
|
||||
|
||||
def send_cluster_time(
|
||||
self,
|
||||
@ -736,7 +615,7 @@ class AsyncConnection:
|
||||
"""Seconds since this socket was last checked into its pool."""
|
||||
return time.monotonic() - self.last_checkin_time
|
||||
|
||||
def _raise_connection_failure(self, error: BaseException) -> NoReturn:
|
||||
async def _raise_connection_failure(self, error: BaseException) -> NoReturn:
|
||||
# Catch *all* exceptions from socket methods and close the socket. In
|
||||
# regular Python, socket operations only raise socket.error, even if
|
||||
# the underlying cause was a Ctrl-C: a signal raised during socket.recv
|
||||
@ -756,7 +635,7 @@ class AsyncConnection:
|
||||
reason = None
|
||||
else:
|
||||
reason = ConnectionClosedReason.ERROR
|
||||
self.close_conn(reason)
|
||||
await self.close_conn(reason)
|
||||
# SSLError from PyOpenSSL inherits directly from Exception.
|
||||
if isinstance(error, (IOError, OSError, SSLError)):
|
||||
details = _get_timeout_details(self.opts)
|
||||
@ -781,145 +660,6 @@ class AsyncConnection:
|
||||
)
|
||||
|
||||
|
||||
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
"""Given (host, port) and PoolOptions, connect and return a socket object.
|
||||
|
||||
Can raise socket.error.
|
||||
|
||||
This is a modified version of create_connection from CPython >= 2.7.
|
||||
"""
|
||||
host, port = address
|
||||
|
||||
# Check if dealing with a unix domain socket
|
||||
if host.endswith(".sock"):
|
||||
if not hasattr(socket, "AF_UNIX"):
|
||||
raise ConnectionFailure("UNIX-sockets are not supported on this system")
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
# SOCK_CLOEXEC not supported for Unix sockets.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.connect(host)
|
||||
return sock
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
# Don't try IPv6 if we don't support it. Also skip it if host
|
||||
# is 'localhost' (::1 is fine). Avoids slow connect issues
|
||||
# like PYTHON-356.
|
||||
family = socket.AF_INET
|
||||
if socket.has_ipv6 and host != "localhost":
|
||||
family = socket.AF_UNSPEC
|
||||
|
||||
err = None
|
||||
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
|
||||
af, socktype, proto, dummy, sa = res
|
||||
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
|
||||
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
|
||||
# all file descriptors are created non-inheritable. See PEP 446.
|
||||
try:
|
||||
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
|
||||
except OSError:
|
||||
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
|
||||
# it?
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
# Fallback when SOCK_CLOEXEC isn't available.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# CSOT: apply timeout to socket connect.
|
||||
timeout = _csot.remaining()
|
||||
if timeout is None:
|
||||
timeout = options.connect_timeout
|
||||
elif timeout <= 0:
|
||||
raise socket.timeout("timed out")
|
||||
sock.settimeout(timeout)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
|
||||
_set_keepalive_times(sock)
|
||||
sock.connect(sa)
|
||||
return sock
|
||||
except OSError as e:
|
||||
err = e
|
||||
sock.close()
|
||||
|
||||
if err is not None:
|
||||
raise err
|
||||
else:
|
||||
# This likely means we tried to connect to an IPv6 only
|
||||
# host with an OS/kernel or Python interpreter that doesn't
|
||||
# support IPv6. The test case is Jython2.5.1 which doesn't
|
||||
# support IPv6 at all.
|
||||
raise OSError("getaddrinfo failed")
|
||||
|
||||
|
||||
async def _configured_socket(
|
||||
address: _Address, options: PoolOptions
|
||||
) -> Union[socket.socket, _sslConn]:
|
||||
"""Given (host, port) and PoolOptions, return a configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = await _create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
if _IS_SYNC:
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
|
||||
else:
|
||||
if hasattr(ssl_context, "a_wrap_socket"):
|
||||
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
|
||||
else:
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
|
||||
)
|
||||
else:
|
||||
if _IS_SYNC:
|
||||
ssl_sock = ssl_context.wrap_socket(sock)
|
||||
else:
|
||||
if hasattr(ssl_context, "a_wrap_socket"):
|
||||
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
|
||||
else:
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
class _PoolClosedError(PyMongoError):
|
||||
"""Internal error raised when a thread tries to get a connection from a
|
||||
closed pool.
|
||||
@ -1121,7 +861,7 @@ class Pool:
|
||||
# publishing the PoolClearedEvent.
|
||||
if close:
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_pool_closed(self.address)
|
||||
@ -1152,7 +892,7 @@ class Pool:
|
||||
serviceId=service_id,
|
||||
)
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
|
||||
"""Updates the is_writable attribute on all sockets currently in the
|
||||
@ -1197,7 +937,7 @@ class Pool:
|
||||
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
):
|
||||
conn = self.conns.pop()
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
await conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
|
||||
while True:
|
||||
async with self.size_cond:
|
||||
@ -1221,7 +961,7 @@ class Pool:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
if self.gen.get_overall() != reference_generation:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return
|
||||
self.conns.appendleft(conn)
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
@ -1266,7 +1006,7 @@ class Pool:
|
||||
)
|
||||
|
||||
try:
|
||||
sock = await _configured_socket(self.address, self.opts)
|
||||
networking_interface = await _configured_protocol_interface(self.address, self.opts)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
async with self.lock:
|
||||
@ -1293,7 +1033,7 @@ class Pool:
|
||||
|
||||
raise
|
||||
|
||||
conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type]
|
||||
conn = AsyncConnection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type]
|
||||
async with self.lock:
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
self.active_contexts.discard(tmp_context)
|
||||
@ -1311,7 +1051,7 @@ class Pool:
|
||||
except BaseException:
|
||||
async with self.lock:
|
||||
self.active_contexts.discard(conn.cancel_context)
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
await conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
raise
|
||||
|
||||
if handler:
|
||||
@ -1509,7 +1249,7 @@ class Pool:
|
||||
except IndexError:
|
||||
self._pending += 1
|
||||
if conn: # We got a socket from the pool
|
||||
if self._perished(conn):
|
||||
if await self._perished(conn):
|
||||
conn = None
|
||||
continue
|
||||
else: # We need to create a new connection
|
||||
@ -1523,7 +1263,7 @@ class Pool:
|
||||
except BaseException:
|
||||
if conn:
|
||||
# We checked out a socket but authentication failed.
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
await conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
async with self.size_cond:
|
||||
self.requests -= 1
|
||||
if incremented:
|
||||
@ -1583,7 +1323,7 @@ class Pool:
|
||||
await self.reset_without_pause()
|
||||
else:
|
||||
if self.closed:
|
||||
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
elif conn.closed:
|
||||
# CMAP requires the closed event be emitted after the check in.
|
||||
if self.enabled_for_cmap:
|
||||
@ -1607,7 +1347,7 @@ class Pool:
|
||||
# Hold the lock to ensure this section does not race with
|
||||
# Pool.reset().
|
||||
if self.stale_generation(conn.generation, conn.service_id):
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
else:
|
||||
conn.update_last_checkin_time()
|
||||
conn.update_is_writable(bool(self.is_writable))
|
||||
@ -1625,7 +1365,7 @@ class Pool:
|
||||
self.operation_count -= 1
|
||||
self.size_cond.notify()
|
||||
|
||||
def _perished(self, conn: AsyncConnection) -> bool:
|
||||
async def _perished(self, conn: AsyncConnection) -> bool:
|
||||
"""Return True and close the connection if it is "perished".
|
||||
|
||||
This side-effecty function checks if this socket has been idle for
|
||||
@ -1645,18 +1385,18 @@ class Pool:
|
||||
self.opts.max_idle_time_seconds is not None
|
||||
and idle_time_seconds > self.opts.max_idle_time_seconds
|
||||
):
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
await conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
return True
|
||||
|
||||
if self._check_interval_seconds is not None and (
|
||||
self._check_interval_seconds == 0 or idle_time_seconds > self._check_interval_seconds
|
||||
):
|
||||
if conn.conn_closed():
|
||||
conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
await conn.close_conn(ConnectionClosedReason.ERROR)
|
||||
return True
|
||||
|
||||
if self.stale_generation(conn.generation, conn.service_id):
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -1704,5 +1444,6 @@ class Pool:
|
||||
# Avoid ResourceWarnings in Python 3
|
||||
# Close all sockets without calling reset() or close() because it is
|
||||
# not safe to acquire a lock in __del__.
|
||||
for conn in self.conns:
|
||||
conn.close_conn(None)
|
||||
if _IS_SYNC:
|
||||
for conn in self.conns:
|
||||
conn.close_conn(None)
|
||||
|
||||
@ -16,21 +16,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import errno
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
from asyncio import AbstractEventLoop, Future
|
||||
from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import _csot, ssl_support
|
||||
from pymongo._asyncio_task import create_task
|
||||
from pymongo.errors import _OperationCancelled
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.compression_support import decompress
|
||||
from pymongo.errors import ProtocolError, _OperationCancelled
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
|
||||
try:
|
||||
@ -69,13 +74,15 @@ _POLL_TIMEOUT = 0.5
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
# These socket-based I/O methods are for KMS requests and any other network operations that do not use
|
||||
# the MongoDB wire protocol
|
||||
async def async_socket_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
|
||||
timeout = sock.gettimeout()
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout)
|
||||
await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout)
|
||||
else:
|
||||
await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type]
|
||||
except asyncio.TimeoutError as exc:
|
||||
@ -87,7 +94,7 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non
|
||||
|
||||
if sys.platform != "win32":
|
||||
|
||||
async def _async_sendall_ssl(
|
||||
async def _async_socket_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
@ -130,7 +137,7 @@ if sys.platform != "win32":
|
||||
loop.remove_reader(fd)
|
||||
loop.remove_writer(fd)
|
||||
|
||||
async def _async_receive_ssl(
|
||||
async def _async_socket_receive_ssl(
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
@ -184,7 +191,7 @@ else:
|
||||
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
|
||||
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
|
||||
# Note: In PYTHON-4493 we plan to replace this code with asyncio streams.
|
||||
async def _async_sendall_ssl(
|
||||
async def _async_socket_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
@ -205,7 +212,7 @@ else:
|
||||
backoff = min(backoff * 2, 0.512)
|
||||
total_sent += sent
|
||||
|
||||
async def _async_receive_ssl(
|
||||
async def _async_socket_receive_ssl(
|
||||
conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
@ -244,52 +251,6 @@ async def _poll_cancellation(conn: AsyncConnection) -> None:
|
||||
await asyncio.sleep(_POLL_TIMEOUT)
|
||||
|
||||
|
||||
async def async_receive_data(
|
||||
conn: AsyncConnection, length: int, deadline: Optional[float]
|
||||
) -> memoryview:
|
||||
sock = conn.conn
|
||||
sock_timeout = sock.gettimeout()
|
||||
timeout: Optional[Union[float, int]]
|
||||
if deadline:
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
timeout = max(deadline - time.monotonic(), 0)
|
||||
else:
|
||||
timeout = sock_timeout
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_running_loop()
|
||||
cancellation_task = create_task(_poll_cancellation(conn))
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
|
||||
else:
|
||||
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
|
||||
tasks = [read_task, cancellation_task]
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
await asyncio.wait(pending)
|
||||
if len(done) == 0:
|
||||
raise socket.timeout("timed out")
|
||||
if read_task in done:
|
||||
return read_task.result()
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
except asyncio.CancelledError:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.wait(tasks)
|
||||
raise
|
||||
|
||||
finally:
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def async_receive_data_socket(
|
||||
sock: Union[socket.socket, _sslConn], length: int
|
||||
) -> memoryview:
|
||||
@ -301,18 +262,23 @@ async def async_receive_data_socket(
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
return await asyncio.wait_for(
|
||||
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
|
||||
_async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
|
||||
return await asyncio.wait_for(
|
||||
_async_socket_receive(sock, length, loop), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise socket.timeout("timed out") from err
|
||||
finally:
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
|
||||
async def _async_socket_receive(
|
||||
conn: socket.socket, length: int, loop: AbstractEventLoop
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
bytes_read = 0
|
||||
while bytes_read < length:
|
||||
@ -328,7 +294,7 @@ _PYPY = "PyPy" in sys.version
|
||||
|
||||
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
|
||||
"""Block until at least one byte is read, or a timeout, or a cancel."""
|
||||
sock = conn.conn
|
||||
sock = conn.conn.sock
|
||||
timed_out = False
|
||||
# Check if the connection's socket has been manually closed
|
||||
if sock.fileno() == -1:
|
||||
@ -413,3 +379,403 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
|
||||
conn.set_conn_timeout(orig_timeout)
|
||||
|
||||
return mv
|
||||
|
||||
|
||||
class NetworkingInterfaceBase:
|
||||
def __init__(self, conn: Any):
|
||||
self.conn = conn
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def get_conn(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def sock(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncNetworkingInterface(NetworkingInterfaceBase):
|
||||
def __init__(self, conn: tuple[Transport, PyMongoProtocol]):
|
||||
super().__init__(conn)
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> float | None:
|
||||
return self.conn[1].gettimeout
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self.conn[1].settimeout(timeout)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.conn[1].close()
|
||||
await self.conn[1].wait_closed()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self.conn[0].is_closing()
|
||||
|
||||
@property
|
||||
def get_conn(self) -> PyMongoProtocol:
|
||||
return self.conn[1]
|
||||
|
||||
@property
|
||||
def sock(self) -> socket.socket:
|
||||
return self.conn[0].get_extra_info("socket")
|
||||
|
||||
|
||||
class NetworkingInterface(NetworkingInterfaceBase):
|
||||
def __init__(self, conn: Union[socket.socket, _sslConn]):
|
||||
super().__init__(conn)
|
||||
|
||||
def gettimeout(self) -> float | None:
|
||||
return self.conn.gettimeout()
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self.conn.settimeout(timeout)
|
||||
|
||||
def close(self) -> None:
|
||||
self.conn.close()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self.conn.is_closing()
|
||||
|
||||
@property
|
||||
def get_conn(self) -> Union[socket.socket, _sslConn]:
|
||||
return self.conn
|
||||
|
||||
@property
|
||||
def sock(self) -> Union[socket.socket, _sslConn]:
|
||||
return self.conn
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.conn.fileno()
|
||||
|
||||
def recv_into(self, buffer: bytes) -> int:
|
||||
return self.conn.recv_into(buffer)
|
||||
|
||||
|
||||
class PyMongoProtocol(BufferedProtocol):
|
||||
def __init__(self, timeout: Optional[float] = None):
|
||||
self.transport: Transport = None # type: ignore[assignment]
|
||||
# Each message is reader in 2-3 parts: header, compression header, and message body
|
||||
# The message buffer is allocated after the header is read.
|
||||
self._header = memoryview(bytearray(16))
|
||||
self._header_index = 0
|
||||
self._compression_header = memoryview(bytearray(9))
|
||||
self._compression_index = 0
|
||||
self._message: Optional[memoryview] = None
|
||||
self._message_index = 0
|
||||
# State. TODO: replace booleans with an enum?
|
||||
self._expecting_header = True
|
||||
self._expecting_compression = False
|
||||
self._message_size = 0
|
||||
self._op_code = 0
|
||||
self._connection_lost = False
|
||||
self._read_waiter: Optional[Future] = None
|
||||
self._timeout = timeout
|
||||
self._is_compressed = False
|
||||
self._compressor_id: Optional[int] = None
|
||||
self._max_message_size = MAX_MESSAGE_SIZE
|
||||
self._response_to: Optional[int] = None
|
||||
self._closed = asyncio.get_running_loop().create_future()
|
||||
self._pending_messages: collections.deque[Future] = collections.deque()
|
||||
self._done_messages: collections.deque[Future] = collections.deque()
|
||||
|
||||
def settimeout(self, timeout: float | None) -> None:
|
||||
self._timeout = timeout
|
||||
|
||||
@property
|
||||
def gettimeout(self) -> float | None:
|
||||
"""The configured timeout for the socket that underlies our protocol pair."""
|
||||
return self._timeout
|
||||
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called exactly once when a connection is made.
|
||||
The transport argument is the transport representing the write side of the connection.
|
||||
"""
|
||||
self.transport = transport # type: ignore[assignment]
|
||||
self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE)
|
||||
|
||||
async def write(self, message: bytes) -> None:
|
||||
"""Write a message to this connection's transport."""
|
||||
if self.transport.is_closing():
|
||||
raise OSError("Connection is closed")
|
||||
self.transport.write(message)
|
||||
self.transport.resume_reading()
|
||||
|
||||
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
|
||||
"""Read a single MongoDB Wire Protocol message from this connection."""
|
||||
if self.transport:
|
||||
try:
|
||||
self.transport.resume_reading()
|
||||
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
|
||||
except AttributeError:
|
||||
raise OSError("connection is already closed") from None
|
||||
self._max_message_size = max_message_size
|
||||
if self._done_messages:
|
||||
message = await self._done_messages.popleft()
|
||||
else:
|
||||
if self.transport and self.transport.is_closing():
|
||||
raise OSError("connection is already closed")
|
||||
read_waiter = asyncio.get_running_loop().create_future()
|
||||
self._pending_messages.append(read_waiter)
|
||||
try:
|
||||
message = await read_waiter
|
||||
finally:
|
||||
if read_waiter in self._done_messages:
|
||||
self._done_messages.remove(read_waiter)
|
||||
if message:
|
||||
op_code, compressor_id, response_to, data = message
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
raise ProtocolError(
|
||||
f"Got response id {response_to!r} but expected {request_id!r}"
|
||||
)
|
||||
if compressor_id is not None:
|
||||
data = decompress(data, compressor_id)
|
||||
return data, op_code
|
||||
raise OSError("connection closed")
|
||||
|
||||
def get_buffer(self, sizehint: int) -> memoryview:
|
||||
"""Called to allocate a new receive buffer.
|
||||
The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
|
||||
If any data does not fit into the returned buffer, this method will be called again until
|
||||
either no data remains or an empty buffer is returned.
|
||||
"""
|
||||
# Due to a bug, Python <=3.11 will call get_buffer() even after we raise
|
||||
# ProtocolError in buffer_updated() and call connection_lost(). We allocate
|
||||
# a temp buffer to drain the waiting data.
|
||||
if self._connection_lost:
|
||||
if not self._message:
|
||||
self._message = memoryview(bytearray(2**14))
|
||||
return self._message
|
||||
# TODO: optimize this by caching pointers to the buffers.
|
||||
# return self._buffer[self._index:]
|
||||
if self._expecting_header:
|
||||
return self._header[self._header_index :]
|
||||
if self._expecting_compression:
|
||||
return self._compression_header[self._compression_index :]
|
||||
return self._message[self._message_index :] # type: ignore[index]
|
||||
|
||||
def buffer_updated(self, nbytes: int) -> None:
|
||||
"""Called when the buffer was updated with the received data"""
|
||||
# Wrote 0 bytes into a non-empty buffer, signal connection closed
|
||||
if nbytes == 0:
|
||||
self.close(OSError("connection closed"))
|
||||
return
|
||||
if self._connection_lost:
|
||||
return
|
||||
if self._expecting_header:
|
||||
self._header_index += nbytes
|
||||
if self._header_index >= 16:
|
||||
self._expecting_header = False
|
||||
try:
|
||||
(
|
||||
self._message_size,
|
||||
self._op_code,
|
||||
self._response_to,
|
||||
self._expecting_compression,
|
||||
) = self.process_header()
|
||||
except ProtocolError as exc:
|
||||
self.close(exc)
|
||||
return
|
||||
self._message = memoryview(bytearray(self._message_size))
|
||||
return
|
||||
if self._expecting_compression:
|
||||
self._compression_index += nbytes
|
||||
if self._compression_index >= 9:
|
||||
self._expecting_compression = False
|
||||
self._op_code, self._compressor_id = self.process_compression_header()
|
||||
return
|
||||
|
||||
self._message_index += nbytes
|
||||
if self._message_index >= self._message_size:
|
||||
self._expecting_header = True
|
||||
# Pause reading to avoid storing an arbitrary number of messages in memory.
|
||||
self.transport.pause_reading()
|
||||
if self._pending_messages:
|
||||
result = self._pending_messages.popleft()
|
||||
else:
|
||||
result = asyncio.get_running_loop().create_future()
|
||||
# Future has been cancelled, close this connection
|
||||
if result.done():
|
||||
self.close(None)
|
||||
return
|
||||
# Necessary values to reconstruct and verify message
|
||||
result.set_result(
|
||||
(self._op_code, self._compressor_id, self._response_to, self._message)
|
||||
)
|
||||
self._done_messages.append(result)
|
||||
# Reset internal state to expect a new message
|
||||
self._header_index = 0
|
||||
self._compression_index = 0
|
||||
self._message_index = 0
|
||||
self._message_size = 0
|
||||
self._message = None
|
||||
self._op_code = 0
|
||||
self._compressor_id = None
|
||||
self._response_to = None
|
||||
|
||||
def process_header(self) -> tuple[int, int, int, bool]:
|
||||
"""Unpack a MongoDB Wire Protocol header."""
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(self._header)
|
||||
expecting_compression = False
|
||||
if op_code == 2012: # OP_COMPRESSED
|
||||
if length <= 25:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
|
||||
)
|
||||
expecting_compression = True
|
||||
length -= 9
|
||||
if length <= 16:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard message header size (16)"
|
||||
)
|
||||
if length > self._max_message_size:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) is larger than server max "
|
||||
f"message size ({self._max_message_size!r})"
|
||||
)
|
||||
|
||||
return length - 16, op_code, response_to, expecting_compression
|
||||
|
||||
def process_compression_header(self) -> tuple[int, int]:
|
||||
"""Unpack a MongoDB Wire Protocol compression header."""
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
|
||||
return op_code, compressor_id
|
||||
|
||||
def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None:
|
||||
pending = list(self._pending_messages)
|
||||
for msg in pending:
|
||||
if not msg.done():
|
||||
if exc is None:
|
||||
msg.set_result(None)
|
||||
else:
|
||||
msg.set_exception(exc)
|
||||
self._done_messages.append(msg)
|
||||
|
||||
def close(self, exc: Optional[Exception] = None) -> None:
|
||||
self.transport.abort()
|
||||
self._resolve_pending_messages(exc)
|
||||
self._connection_lost = True
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception] = None) -> None:
|
||||
self._resolve_pending_messages(exc)
|
||||
if not self._closed.done():
|
||||
self._closed.set_result(None)
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
await self._closed
|
||||
|
||||
|
||||
async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout)
|
||||
except asyncio.TimeoutError as exc:
|
||||
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
|
||||
raise socket.timeout("timed out") from exc
|
||||
|
||||
|
||||
async def async_receive_message(
|
||||
conn: AsyncConnection,
|
||||
request_id: Optional[int],
|
||||
max_message_size: int = MAX_MESSAGE_SIZE,
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
timeout: Optional[Union[float, int]]
|
||||
timeout = conn.conn.gettimeout
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
if deadline:
|
||||
# When the timeout has expired perform one final check to
|
||||
# see if the socket is readable. This helps avoid spurious
|
||||
# timeouts on AWS Lambda and other FaaS environments.
|
||||
timeout = max(deadline - time.monotonic(), 0)
|
||||
|
||||
cancellation_task = create_task(_poll_cancellation(conn))
|
||||
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
|
||||
tasks = [read_task, cancellation_task]
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
await asyncio.wait(pending)
|
||||
if len(done) == 0:
|
||||
raise socket.timeout("timed out")
|
||||
if read_task in done:
|
||||
data, op_code = read_task.result()
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
raise _OperationCancelled("operation cancelled")
|
||||
except asyncio.CancelledError:
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.wait(tasks)
|
||||
raise
|
||||
|
||||
|
||||
def receive_message(
|
||||
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
timeout = conn.conn.gettimeout()
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
|
||||
if length <= 16:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard message header size (16)"
|
||||
)
|
||||
if length > max_message_size:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) is larger than server max "
|
||||
f"message size ({max_message_size!r})"
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
|
||||
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = receive_data(conn, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
|
||||
546
pymongo/pool_shared.py
Normal file
546
pymongo/pool_shared.py
Normal file
@ -0,0 +1,546 @@
|
||||
# Copyright 2025-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pool utilities and shared helper methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pymongo import _csot
|
||||
from pymongo.asynchronous.helpers import _getaddrinfo
|
||||
from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
AutoReconnect,
|
||||
ConnectionFailure,
|
||||
NetworkTimeout,
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.ssl_support import HAS_SNI, SSLError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.typings import _Address
|
||||
|
||||
try:
|
||||
from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl
|
||||
|
||||
def _set_non_inheritable_non_atomic(fd: int) -> None:
|
||||
"""Set the close-on-exec flag on the given file descriptor."""
|
||||
flags = fcntl(fd, F_GETFD)
|
||||
fcntl(fd, F_SETFD, flags | FD_CLOEXEC)
|
||||
|
||||
except ImportError:
|
||||
# Windows, various platforms we don't claim to support
|
||||
# (Jython, IronPython, ..), systems that don't provide
|
||||
# everything we need from fcntl, etc.
|
||||
def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
|
||||
"""Dummy function for platforms that don't provide fcntl."""
|
||||
|
||||
|
||||
_MAX_TCP_KEEPIDLE = 120
|
||||
_MAX_TCP_KEEPINTVL = 10
|
||||
_MAX_TCP_KEEPCNT = 9
|
||||
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
import _winreg as winreg
|
||||
except ImportError:
|
||||
import winreg
|
||||
|
||||
def _query(key, name, default):
|
||||
try:
|
||||
value, _ = winreg.QueryValueEx(key, name)
|
||||
# Ensure the value is a number or raise ValueError.
|
||||
return int(value)
|
||||
except (OSError, ValueError):
|
||||
# QueryValueEx raises OSError when the key does not exist (i.e.
|
||||
# the system is using the Windows default value).
|
||||
return default
|
||||
|
||||
try:
|
||||
with winreg.OpenKey(
|
||||
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
|
||||
) as key:
|
||||
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
|
||||
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
|
||||
except OSError:
|
||||
# We could not check the default values because winreg.OpenKey failed.
|
||||
# Assume the system is using the default values.
|
||||
_WINDOWS_TCP_IDLE_MS = 7200000
|
||||
_WINDOWS_TCP_INTERVAL_MS = 1000
|
||||
|
||||
def _set_keepalive_times(sock):
|
||||
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
|
||||
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
|
||||
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
|
||||
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
|
||||
|
||||
else:
|
||||
|
||||
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
|
||||
if hasattr(socket, tcp_option):
|
||||
sockopt = getattr(socket, tcp_option)
|
||||
try:
|
||||
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
|
||||
# TCP_KEEPIDLE and friends. Don't attempt to set the
|
||||
# values there.
|
||||
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
|
||||
if default > max_value:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _set_keepalive_times(sock: socket.socket) -> None:
|
||||
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
|
||||
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
|
||||
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
|
||||
|
||||
|
||||
def _raise_connection_failure(
|
||||
address: Any,
|
||||
error: Exception,
|
||||
msg_prefix: Optional[str] = None,
|
||||
timeout_details: Optional[dict[str, float]] = None,
|
||||
) -> NoReturn:
|
||||
"""Convert a socket.error to ConnectionFailure and raise it."""
|
||||
host, port = address
|
||||
# If connecting to a Unix socket, port will be None.
|
||||
if port is not None:
|
||||
msg = "%s:%d: %s" % (host, port, error)
|
||||
else:
|
||||
msg = f"{host}: {error}"
|
||||
if msg_prefix:
|
||||
msg = msg_prefix + msg
|
||||
if "configured timeouts" not in msg:
|
||||
msg += format_timeout_details(timeout_details)
|
||||
if isinstance(error, socket.timeout):
|
||||
raise NetworkTimeout(msg) from error
|
||||
elif isinstance(error, SSLError) and "timed out" in str(error):
|
||||
# Eventlet does not distinguish TLS network timeouts from other
|
||||
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
|
||||
# Luckily, we can work around this limitation because the phrase
|
||||
# 'timed out' appears in all the timeout related SSLErrors raised.
|
||||
raise NetworkTimeout(msg) from error
|
||||
else:
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
socket_timeout = options.socket_timeout
|
||||
connect_timeout = options.connect_timeout
|
||||
if timeout:
|
||||
details["timeoutMS"] = timeout * 1000
|
||||
if socket_timeout and not timeout:
|
||||
details["socketTimeoutMS"] = socket_timeout * 1000
|
||||
if connect_timeout:
|
||||
details["connectTimeoutMS"] = connect_timeout * 1000
|
||||
return details
|
||||
|
||||
|
||||
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
|
||||
result = ""
|
||||
if details:
|
||||
result += " (configured timeouts:"
|
||||
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
|
||||
if timeout in details:
|
||||
result += f" {timeout}: {details[timeout]}ms,"
|
||||
result = result[:-1]
|
||||
result += ")"
|
||||
return result
|
||||
|
||||
|
||||
class _CancellationContext:
|
||||
def __init__(self) -> None:
|
||||
self._cancelled = False
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel this context."""
|
||||
self._cancelled = True
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Was cancel called?"""
|
||||
return self._cancelled
|
||||
|
||||
|
||||
async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
"""Given (host, port) and PoolOptions, connect and return a raw socket object.
|
||||
|
||||
Can raise socket.error.
|
||||
|
||||
This is a modified version of create_connection from CPython >= 2.7.
|
||||
"""
|
||||
host, port = address
|
||||
|
||||
# Check if dealing with a unix domain socket
|
||||
if host.endswith(".sock"):
|
||||
if not hasattr(socket, "AF_UNIX"):
|
||||
raise ConnectionFailure("UNIX-sockets are not supported on this system")
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
# SOCK_CLOEXEC not supported for Unix sockets.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.connect(host)
|
||||
return sock
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
# Don't try IPv6 if we don't support it. Also skip it if host
|
||||
# is 'localhost' (::1 is fine). Avoids slow connect issues
|
||||
# like PYTHON-356.
|
||||
family = socket.AF_INET
|
||||
if socket.has_ipv6 and host != "localhost":
|
||||
family = socket.AF_UNSPEC
|
||||
|
||||
err = None
|
||||
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM):
|
||||
af, socktype, proto, dummy, sa = res
|
||||
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
|
||||
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
|
||||
# all file descriptors are created non-inheritable. See PEP 446.
|
||||
try:
|
||||
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
|
||||
except OSError:
|
||||
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
|
||||
# it?
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
# Fallback when SOCK_CLOEXEC isn't available.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# CSOT: apply timeout to socket connect.
|
||||
timeout = _csot.remaining()
|
||||
if timeout is None:
|
||||
timeout = options.connect_timeout
|
||||
elif timeout <= 0:
|
||||
raise socket.timeout("timed out")
|
||||
sock.settimeout(timeout)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
|
||||
_set_keepalive_times(sock)
|
||||
sock.connect(sa)
|
||||
return sock
|
||||
except OSError as e:
|
||||
err = e
|
||||
sock.close()
|
||||
|
||||
if err is not None:
|
||||
raise err
|
||||
else:
|
||||
# This likely means we tried to connect to an IPv6 only
|
||||
# host with an OS/kernel or Python interpreter that doesn't
|
||||
# support IPv6. The test case is Jython2.5.1 which doesn't
|
||||
# support IPv6 at all.
|
||||
raise OSError("getaddrinfo failed")
|
||||
|
||||
|
||||
async def _async_configured_socket(
|
||||
address: _Address, options: PoolOptions
|
||||
) -> Union[socket.socket, _sslConn]:
|
||||
"""Given (host, port) and PoolOptions, return a raw configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = await _async_create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
if hasattr(ssl_context, "a_wrap_socket"):
|
||||
ssl_sock = await ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
|
||||
else:
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore]
|
||||
)
|
||||
else:
|
||||
if hasattr(ssl_context, "a_wrap_socket"):
|
||||
ssl_sock = await ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
|
||||
else:
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore]
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
async def _configured_protocol_interface(
|
||||
address: _Address, options: PoolOptions
|
||||
) -> AsyncNetworkingInterface:
|
||||
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets protocol's SSL and timeout options.
|
||||
"""
|
||||
sock = await _async_create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
timeout = options.socket_timeout
|
||||
|
||||
if ssl_context is None:
|
||||
return AsyncNetworkingInterface(
|
||||
await asyncio.get_running_loop().create_connection(
|
||||
lambda: PyMongoProtocol(timeout=timeout), sock=sock
|
||||
)
|
||||
)
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
|
||||
lambda: PyMongoProtocol(timeout=timeout),
|
||||
sock=sock,
|
||||
server_hostname=host,
|
||||
ssl=ssl_context,
|
||||
)
|
||||
except _CertificateError:
|
||||
transport.abort()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
transport.abort()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
|
||||
except _CertificateError:
|
||||
transport.abort()
|
||||
raise
|
||||
|
||||
return AsyncNetworkingInterface((transport, protocol))
|
||||
|
||||
|
||||
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
"""Given (host, port) and PoolOptions, connect and return a raw socket object.
|
||||
|
||||
Can raise socket.error.
|
||||
|
||||
This is a modified version of create_connection from CPython >= 2.7.
|
||||
"""
|
||||
host, port = address
|
||||
|
||||
# Check if dealing with a unix domain socket
|
||||
if host.endswith(".sock"):
|
||||
if not hasattr(socket, "AF_UNIX"):
|
||||
raise ConnectionFailure("UNIX-sockets are not supported on this system")
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
# SOCK_CLOEXEC not supported for Unix sockets.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.connect(host)
|
||||
return sock
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
# Don't try IPv6 if we don't support it. Also skip it if host
|
||||
# is 'localhost' (::1 is fine). Avoids slow connect issues
|
||||
# like PYTHON-356.
|
||||
family = socket.AF_INET
|
||||
if socket.has_ipv6 and host != "localhost":
|
||||
family = socket.AF_UNSPEC
|
||||
|
||||
err = None
|
||||
for res in socket.getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined, unused-ignore]
|
||||
af, socktype, proto, dummy, sa = res
|
||||
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
|
||||
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
|
||||
# all file descriptors are created non-inheritable. See PEP 446.
|
||||
try:
|
||||
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
|
||||
except OSError:
|
||||
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
|
||||
# it?
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
# Fallback when SOCK_CLOEXEC isn't available.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# CSOT: apply timeout to socket connect.
|
||||
timeout = _csot.remaining()
|
||||
if timeout is None:
|
||||
timeout = options.connect_timeout
|
||||
elif timeout <= 0:
|
||||
raise socket.timeout("timed out")
|
||||
sock.settimeout(timeout)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
|
||||
_set_keepalive_times(sock)
|
||||
sock.connect(sa)
|
||||
return sock
|
||||
except OSError as e:
|
||||
err = e
|
||||
sock.close()
|
||||
|
||||
if err is not None:
|
||||
raise err
|
||||
else:
|
||||
# This likely means we tried to connect to an IPv6 only
|
||||
# host with an OS/kernel or Python interpreter that doesn't
|
||||
# support IPv6. The test case is Jython2.5.1 which doesn't
|
||||
# support IPv6 at all.
|
||||
raise OSError("getaddrinfo failed")
|
||||
|
||||
|
||||
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
|
||||
"""Given (host, port) and PoolOptions, return a raw configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = _create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
|
||||
else:
|
||||
ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
|
||||
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = _create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return NetworkingInterface(sock)
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
|
||||
else:
|
||||
ssl_sock = ssl_context.wrap_socket(sock)
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined,unused-ignore]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return NetworkingInterface(ssl_sock)
|
||||
@ -87,7 +87,7 @@ class _Bulk:
|
||||
self,
|
||||
collection: Collection[_DocumentType],
|
||||
ordered: bool,
|
||||
bypass_document_validation: bool,
|
||||
bypass_document_validation: Optional[bool],
|
||||
comment: Optional[str] = None,
|
||||
let: Optional[Any] = None,
|
||||
) -> None:
|
||||
@ -516,8 +516,8 @@ class _Bulk:
|
||||
if self.comment:
|
||||
cmd["comment"] = self.comment
|
||||
_csot.apply_write_concern(cmd, write_concern)
|
||||
if self.bypass_doc_val:
|
||||
cmd["bypassDocumentValidation"] = True
|
||||
if self.bypass_doc_val is not None:
|
||||
cmd["bypassDocumentValidation"] = self.bypass_doc_val
|
||||
if self.let is not None and run.op_type in (_DELETE, _UPDATE):
|
||||
cmd["let"] = self.let
|
||||
if session:
|
||||
|
||||
@ -700,7 +700,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
requests: Sequence[_WriteOp[_DocumentType]],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
let: Optional[Mapping] = None,
|
||||
@ -799,7 +799,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
ordered: bool,
|
||||
write_concern: WriteConcern,
|
||||
op_id: Optional[int],
|
||||
bypass_doc_val: bool,
|
||||
bypass_doc_val: Optional[bool],
|
||||
session: Optional[ClientSession],
|
||||
comment: Optional[Any] = None,
|
||||
) -> Any:
|
||||
@ -813,8 +813,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def _insert_command(
|
||||
session: Optional[ClientSession], conn: Connection, retryable_write: bool
|
||||
) -> None:
|
||||
if bypass_doc_val:
|
||||
command["bypassDocumentValidation"] = True
|
||||
if bypass_doc_val is not None:
|
||||
command["bypassDocumentValidation"] = bypass_doc_val
|
||||
|
||||
result = conn.command(
|
||||
self._database.name,
|
||||
@ -839,7 +839,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
def insert_one(
|
||||
self,
|
||||
document: Union[_DocumentType, RawBSONDocument],
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> InsertOneResult:
|
||||
@ -905,7 +905,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
self,
|
||||
documents: Iterable[Union[_DocumentType, RawBSONDocument]],
|
||||
ordered: bool = True,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
comment: Optional[Any] = None,
|
||||
) -> InsertManyResult:
|
||||
@ -985,7 +985,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
op_id: Optional[int] = None,
|
||||
ordered: bool = True,
|
||||
bypass_doc_val: Optional[bool] = False,
|
||||
bypass_doc_val: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
@ -1040,8 +1040,8 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
if comment is not None:
|
||||
command["comment"] = comment
|
||||
# Update command.
|
||||
if bypass_doc_val:
|
||||
command["bypassDocumentValidation"] = True
|
||||
if bypass_doc_val is not None:
|
||||
command["bypassDocumentValidation"] = bypass_doc_val
|
||||
|
||||
# The command result has to be published for APM unmodified
|
||||
# so we make a shallow copy here before adding updatedExisting.
|
||||
@ -1081,7 +1081,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
write_concern: Optional[WriteConcern] = None,
|
||||
op_id: Optional[int] = None,
|
||||
ordered: bool = True,
|
||||
bypass_doc_val: Optional[bool] = False,
|
||||
bypass_doc_val: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
@ -1127,7 +1127,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
filter: Mapping[str, Any],
|
||||
replacement: Mapping[str, Any],
|
||||
upsert: bool = False,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
@ -1236,7 +1236,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
filter: Mapping[str, Any],
|
||||
update: Union[Mapping[str, Any], _Pipeline],
|
||||
upsert: bool = False,
|
||||
bypass_document_validation: bool = False,
|
||||
bypass_document_validation: Optional[bool] = None,
|
||||
collation: Optional[_CollationIn] = None,
|
||||
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
|
||||
hint: Optional[_IndexKeyHint] = None,
|
||||
@ -2941,6 +2941,7 @@ class Collection(common.BaseObject, Generic[_DocumentType]):
|
||||
returning aggregate results using a cursor.
|
||||
- `collation` (optional): An instance of
|
||||
:class:`~pymongo.collation.Collation`.
|
||||
- `bypassDocumentValidation` (bool): If ``True``, allows the write to opt-out of document level validation.
|
||||
|
||||
|
||||
:return: A :class:`~pymongo.command_cursor.CommandCursor` over the result
|
||||
|
||||
@ -70,21 +70,21 @@ from pymongo.errors import (
|
||||
NetworkTimeout,
|
||||
ServerSelectionTimeoutError,
|
||||
)
|
||||
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
|
||||
from pymongo.network_layer import sendall
|
||||
from pymongo.operations import UpdateOne
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.results import BulkWriteResult, DeleteResult
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.cursor import Cursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.synchronous.pool import (
|
||||
from pymongo.pool_shared import (
|
||||
_configured_socket,
|
||||
_get_timeout_details,
|
||||
_raise_connection_failure,
|
||||
)
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.results import BulkWriteResult, DeleteResult
|
||||
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.cursor import Cursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser_shared import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
@ -36,7 +36,11 @@ from pymongo.server_description import ServerDescription
|
||||
from pymongo.synchronous.srv_resolver import _SrvResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
||||
from pymongo.synchronous.pool import ( # type: ignore[attr-defined]
|
||||
Connection,
|
||||
Pool,
|
||||
_CancellationContext,
|
||||
)
|
||||
from pymongo.synchronous.settings import TopologySettings
|
||||
from pymongo.synchronous.topology import Topology
|
||||
|
||||
|
||||
@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -31,20 +30,16 @@ from typing import (
|
||||
|
||||
from bson import _decode_all_selective
|
||||
from pymongo import _csot, helpers_shared, message
|
||||
from pymongo.common import MAX_MESSAGE_SIZE
|
||||
from pymongo.compression_support import _NO_COMPRESSION, decompress
|
||||
from pymongo.compression_support import _NO_COMPRESSION
|
||||
from pymongo.errors import (
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
ProtocolError,
|
||||
)
|
||||
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
|
||||
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
|
||||
from pymongo.message import _OpMsg
|
||||
from pymongo.monitoring import _is_speculative_authenticate
|
||||
from pymongo.network_layer import (
|
||||
_UNPACK_COMPRESSION_HEADER,
|
||||
_UNPACK_HEADER,
|
||||
receive_data,
|
||||
receive_message,
|
||||
sendall,
|
||||
)
|
||||
|
||||
@ -194,7 +189,7 @@ def command(
|
||||
)
|
||||
|
||||
try:
|
||||
sendall(conn.conn, msg)
|
||||
sendall(conn.conn.get_conn, msg)
|
||||
if use_op_msg and unacknowledged:
|
||||
# Unacknowledged, fake a successful command response.
|
||||
reply = None
|
||||
@ -301,45 +296,3 @@ def command(
|
||||
)
|
||||
|
||||
return response_doc # type: ignore[return-value]
|
||||
|
||||
|
||||
def receive_message(
|
||||
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
||||
) -> Union[_OpReply, _OpMsg]:
|
||||
"""Receive a raw BSON message or raise socket.error."""
|
||||
if _csot.get_timeout():
|
||||
deadline = _csot.get_deadline()
|
||||
else:
|
||||
timeout = conn.conn.gettimeout()
|
||||
if timeout:
|
||||
deadline = time.monotonic() + timeout
|
||||
else:
|
||||
deadline = None
|
||||
# Ignore the response's request id.
|
||||
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
|
||||
# No request_id for exhaust cursor "getMore".
|
||||
if request_id is not None:
|
||||
if request_id != response_to:
|
||||
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
|
||||
if length <= 16:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) not longer than standard message header size (16)"
|
||||
)
|
||||
if length > max_message_size:
|
||||
raise ProtocolError(
|
||||
f"Message length ({length!r}) is larger than server max "
|
||||
f"message size ({max_message_size!r})"
|
||||
)
|
||||
if op_code == 2012:
|
||||
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
|
||||
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
|
||||
else:
|
||||
data = receive_data(conn, length - 16, deadline)
|
||||
|
||||
try:
|
||||
unpack_reply = _UNPACK_REPLY[op_code]
|
||||
except KeyError:
|
||||
raise ProtocolError(
|
||||
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
||||
) from None
|
||||
return unpack_reply(data)
|
||||
|
||||
@ -14,14 +14,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import time
|
||||
import weakref
|
||||
@ -49,16 +45,13 @@ from pymongo.common import (
|
||||
from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
AutoReconnect,
|
||||
ConfigurationError,
|
||||
ConnectionFailure,
|
||||
DocumentTooLarge,
|
||||
ExecutionTimeout,
|
||||
InvalidOperation,
|
||||
NetworkTimeout,
|
||||
NotPrimaryError,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
WaitQueueTimeoutError,
|
||||
_CertificateError,
|
||||
)
|
||||
from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.lock import (
|
||||
@ -76,16 +69,23 @@ from pymongo.monitoring import (
|
||||
ConnectionCheckOutFailedReason,
|
||||
ConnectionClosedReason,
|
||||
)
|
||||
from pymongo.network_layer import sendall
|
||||
from pymongo.network_layer import NetworkingInterface, receive_message, sendall
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.pool_shared import (
|
||||
_CancellationContext,
|
||||
_configured_socket_interface,
|
||||
_get_timeout_details,
|
||||
_raise_connection_failure,
|
||||
format_timeout_details,
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.ssl_support import HAS_SNI, SSLError
|
||||
from pymongo.ssl_support import SSLError
|
||||
from pymongo.synchronous.client_session import _validate_session_write_concern
|
||||
from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth
|
||||
from pymongo.synchronous.network import command, receive_message
|
||||
from pymongo.synchronous.helpers import _handle_reauth
|
||||
from pymongo.synchronous.network import command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
@ -96,7 +96,6 @@ if TYPE_CHECKING:
|
||||
ZstdContext,
|
||||
)
|
||||
from pymongo.message import _OpMsg, _OpReply
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import _ServerMode
|
||||
from pymongo.synchronous.auth import _AuthContext
|
||||
@ -123,133 +122,6 @@ except ImportError:
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
_MAX_TCP_KEEPIDLE = 120
|
||||
_MAX_TCP_KEEPINTVL = 10
|
||||
_MAX_TCP_KEEPCNT = 9
|
||||
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
import _winreg as winreg
|
||||
except ImportError:
|
||||
import winreg
|
||||
|
||||
def _query(key, name, default):
|
||||
try:
|
||||
value, _ = winreg.QueryValueEx(key, name)
|
||||
# Ensure the value is a number or raise ValueError.
|
||||
return int(value)
|
||||
except (OSError, ValueError):
|
||||
# QueryValueEx raises OSError when the key does not exist (i.e.
|
||||
# the system is using the Windows default value).
|
||||
return default
|
||||
|
||||
try:
|
||||
with winreg.OpenKey(
|
||||
winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
|
||||
) as key:
|
||||
_WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
|
||||
_WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
|
||||
except OSError:
|
||||
# We could not check the default values because winreg.OpenKey failed.
|
||||
# Assume the system is using the default values.
|
||||
_WINDOWS_TCP_IDLE_MS = 7200000
|
||||
_WINDOWS_TCP_INTERVAL_MS = 1000
|
||||
|
||||
def _set_keepalive_times(sock):
|
||||
idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
|
||||
interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
|
||||
if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
|
||||
sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))
|
||||
|
||||
else:
|
||||
|
||||
def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
|
||||
if hasattr(socket, tcp_option):
|
||||
sockopt = getattr(socket, tcp_option)
|
||||
try:
|
||||
# PYTHON-1350 - NetBSD doesn't implement getsockopt for
|
||||
# TCP_KEEPIDLE and friends. Don't attempt to set the
|
||||
# values there.
|
||||
default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
|
||||
if default > max_value:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _set_keepalive_times(sock: socket.socket) -> None:
|
||||
_set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
|
||||
_set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
|
||||
_set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
|
||||
|
||||
|
||||
def _raise_connection_failure(
|
||||
address: Any,
|
||||
error: Exception,
|
||||
msg_prefix: Optional[str] = None,
|
||||
timeout_details: Optional[dict[str, float]] = None,
|
||||
) -> NoReturn:
|
||||
"""Convert a socket.error to ConnectionFailure and raise it."""
|
||||
host, port = address
|
||||
# If connecting to a Unix socket, port will be None.
|
||||
if port is not None:
|
||||
msg = "%s:%d: %s" % (host, port, error)
|
||||
else:
|
||||
msg = f"{host}: {error}"
|
||||
if msg_prefix:
|
||||
msg = msg_prefix + msg
|
||||
if "configured timeouts" not in msg:
|
||||
msg += format_timeout_details(timeout_details)
|
||||
if isinstance(error, socket.timeout):
|
||||
raise NetworkTimeout(msg) from error
|
||||
elif isinstance(error, SSLError) and "timed out" in str(error):
|
||||
# Eventlet does not distinguish TLS network timeouts from other
|
||||
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
|
||||
# Luckily, we can work around this limitation because the phrase
|
||||
# 'timed out' appears in all the timeout related SSLErrors raised.
|
||||
raise NetworkTimeout(msg) from error
|
||||
else:
|
||||
raise AutoReconnect(msg) from error
|
||||
|
||||
|
||||
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
|
||||
details = {}
|
||||
timeout = _csot.get_timeout()
|
||||
socket_timeout = options.socket_timeout
|
||||
connect_timeout = options.connect_timeout
|
||||
if timeout:
|
||||
details["timeoutMS"] = timeout * 1000
|
||||
if socket_timeout and not timeout:
|
||||
details["socketTimeoutMS"] = socket_timeout * 1000
|
||||
if connect_timeout:
|
||||
details["connectTimeoutMS"] = connect_timeout * 1000
|
||||
return details
|
||||
|
||||
|
||||
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
|
||||
result = ""
|
||||
if details:
|
||||
result += " (configured timeouts:"
|
||||
for timeout in ["socketTimeoutMS", "timeoutMS", "connectTimeoutMS"]:
|
||||
if timeout in details:
|
||||
result += f" {timeout}: {details[timeout]}ms,"
|
||||
result = result[:-1]
|
||||
result += ")"
|
||||
return result
|
||||
|
||||
|
||||
class _CancellationContext:
|
||||
def __init__(self) -> None:
|
||||
self._cancelled = False
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel this context."""
|
||||
self._cancelled = True
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Was cancel called?"""
|
||||
return self._cancelled
|
||||
|
||||
|
||||
class Connection:
|
||||
"""Store a connection with some metadata.
|
||||
@ -261,7 +133,11 @@ class Connection:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, conn: Union[socket.socket, _sslConn], pool: Pool, address: tuple[str, int], id: int
|
||||
self,
|
||||
conn: NetworkingInterface,
|
||||
pool: Pool,
|
||||
address: tuple[str, int],
|
||||
id: int,
|
||||
):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.conn = conn
|
||||
@ -318,7 +194,7 @@ class Connection:
|
||||
if timeout == self.last_timeout:
|
||||
return
|
||||
self.last_timeout = timeout
|
||||
self.conn.settimeout(timeout)
|
||||
self.conn.get_conn.settimeout(timeout)
|
||||
|
||||
def apply_timeout(
|
||||
self, client: MongoClient, cmd: Optional[MutableMapping[str, Any]]
|
||||
@ -573,7 +449,7 @@ class Connection:
|
||||
)
|
||||
|
||||
try:
|
||||
sendall(self.conn, message)
|
||||
sendall(self.conn.get_conn, message)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
self._raise_connection_failure(error)
|
||||
@ -707,7 +583,10 @@ class Connection:
|
||||
|
||||
def conn_closed(self) -> bool:
|
||||
"""Return True if we know socket has been closed, False otherwise."""
|
||||
return self.socket_checker.socket_closed(self.conn)
|
||||
if _IS_SYNC:
|
||||
return self.socket_checker.socket_closed(self.conn.get_conn)
|
||||
else:
|
||||
return self.conn.is_closing()
|
||||
|
||||
def send_cluster_time(
|
||||
self,
|
||||
@ -779,143 +658,6 @@ class Connection:
|
||||
)
|
||||
|
||||
|
||||
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
"""Given (host, port) and PoolOptions, connect and return a socket object.
|
||||
|
||||
Can raise socket.error.
|
||||
|
||||
This is a modified version of create_connection from CPython >= 2.7.
|
||||
"""
|
||||
host, port = address
|
||||
|
||||
# Check if dealing with a unix domain socket
|
||||
if host.endswith(".sock"):
|
||||
if not hasattr(socket, "AF_UNIX"):
|
||||
raise ConnectionFailure("UNIX-sockets are not supported on this system")
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
# SOCK_CLOEXEC not supported for Unix sockets.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.connect(host)
|
||||
return sock
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
# Don't try IPv6 if we don't support it. Also skip it if host
|
||||
# is 'localhost' (::1 is fine). Avoids slow connect issues
|
||||
# like PYTHON-356.
|
||||
family = socket.AF_INET
|
||||
if socket.has_ipv6 and host != "localhost":
|
||||
family = socket.AF_UNSPEC
|
||||
|
||||
err = None
|
||||
for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
|
||||
af, socktype, proto, dummy, sa = res
|
||||
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
|
||||
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
|
||||
# all file descriptors are created non-inheritable. See PEP 446.
|
||||
try:
|
||||
sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
|
||||
except OSError:
|
||||
# Can SOCK_CLOEXEC be defined even if the kernel doesn't support
|
||||
# it?
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
# Fallback when SOCK_CLOEXEC isn't available.
|
||||
_set_non_inheritable_non_atomic(sock.fileno())
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# CSOT: apply timeout to socket connect.
|
||||
timeout = _csot.remaining()
|
||||
if timeout is None:
|
||||
timeout = options.connect_timeout
|
||||
elif timeout <= 0:
|
||||
raise socket.timeout("timed out")
|
||||
sock.settimeout(timeout)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
|
||||
_set_keepalive_times(sock)
|
||||
sock.connect(sa)
|
||||
return sock
|
||||
except OSError as e:
|
||||
err = e
|
||||
sock.close()
|
||||
|
||||
if err is not None:
|
||||
raise err
|
||||
else:
|
||||
# This likely means we tried to connect to an IPv6 only
|
||||
# host with an OS/kernel or Python interpreter that doesn't
|
||||
# support IPv6. The test case is Jython2.5.1 which doesn't
|
||||
# support IPv6 at all.
|
||||
raise OSError("getaddrinfo failed")
|
||||
|
||||
|
||||
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
|
||||
"""Given (host, port) and PoolOptions, return a configured socket.
|
||||
|
||||
Can raise socket.error, ConnectionFailure, or _CertificateError.
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = _create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
sock.settimeout(options.socket_timeout)
|
||||
return sock
|
||||
|
||||
host = address[0]
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
if _IS_SYNC:
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
|
||||
else:
|
||||
if hasattr(ssl_context, "a_wrap_socket"):
|
||||
ssl_sock = ssl_context.a_wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc]
|
||||
else:
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc]
|
||||
)
|
||||
else:
|
||||
if _IS_SYNC:
|
||||
ssl_sock = ssl_context.wrap_socket(sock)
|
||||
else:
|
||||
if hasattr(ssl_context, "a_wrap_socket"):
|
||||
ssl_sock = ssl_context.a_wrap_socket(sock) # type: ignore[assignment, misc]
|
||||
else:
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc]
|
||||
except _CertificateError:
|
||||
sock.close()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
details = _get_timeout_details(options)
|
||||
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
|
||||
if (
|
||||
ssl_context.verify_mode
|
||||
and not ssl_context.check_hostname
|
||||
and not options.tls_allow_invalid_hostnames
|
||||
):
|
||||
try:
|
||||
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
|
||||
except _CertificateError:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
|
||||
ssl_sock.settimeout(options.socket_timeout)
|
||||
return ssl_sock
|
||||
|
||||
|
||||
class _PoolClosedError(PyMongoError):
|
||||
"""Internal error raised when a thread tries to get a connection from a
|
||||
closed pool.
|
||||
@ -1260,7 +1002,7 @@ class Pool:
|
||||
)
|
||||
|
||||
try:
|
||||
sock = _configured_socket(self.address, self.opts)
|
||||
networking_interface = _configured_socket_interface(self.address, self.opts)
|
||||
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
|
||||
except BaseException as error:
|
||||
with self.lock:
|
||||
@ -1287,7 +1029,7 @@ class Pool:
|
||||
|
||||
raise
|
||||
|
||||
conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type]
|
||||
conn = Connection(networking_interface, self, self.address, conn_id) # type: ignore[arg-type]
|
||||
with self.lock:
|
||||
self.active_contexts.add(conn.cancel_context)
|
||||
self.active_contexts.discard(tmp_context)
|
||||
@ -1698,5 +1440,6 @@ class Pool:
|
||||
# Avoid ResourceWarnings in Python 3
|
||||
# Close all sockets without calling reset() or close() because it is
|
||||
# not safe to acquire a lock in __del__.
|
||||
for conn in self.conns:
|
||||
conn.close_conn(None)
|
||||
if _IS_SYNC:
|
||||
for conn in self.conns:
|
||||
conn.close_conn(None)
|
||||
|
||||
@ -116,6 +116,7 @@ filterwarnings = [
|
||||
"module:unclosed <socket.socket:ResourceWarning",
|
||||
"module:unclosed <ssl.SSLSocket:ResourceWarning",
|
||||
"module:unclosed <socket object:ResourceWarning",
|
||||
"module:unclosed transport:ResourceWarning",
|
||||
# https://github.com/eventlet/eventlet/issues/818
|
||||
"module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
|
||||
# https://github.com/dateutil/dateutil/issues/1314
|
||||
|
||||
@ -22,6 +22,8 @@ import sys
|
||||
import warnings
|
||||
from test.asynchronous import AsyncPyMongoTestCase
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
@ -30,6 +32,8 @@ from test.asynchronous.unified_format import generate_test_classes
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.asynchronous.auth_oidc import OIDCCallback
|
||||
|
||||
pytestmark = pytest.mark.auth
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")
|
||||
|
||||
@ -301,7 +301,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
|
||||
async def test_bulk_max_message_size(self):
|
||||
await self.coll.delete_many({})
|
||||
self.addCleanup(self.coll.delete_many, {})
|
||||
self.addAsyncCleanup(self.coll.delete_many, {})
|
||||
_16_MB = 16 * 1000 * 1000
|
||||
# Generate a list of documents such that the first batched OP_MSG is
|
||||
# as close as possible to the 48MB limit.
|
||||
@ -505,7 +505,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
|
||||
async def test_single_error_ordered_batch(self):
|
||||
await self.coll.create_index("a", unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("a", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
|
||||
requests: list = [
|
||||
InsertOne({"b": 1, "a": 1}),
|
||||
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
|
||||
@ -547,7 +547,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
|
||||
async def test_multiple_error_ordered_batch(self):
|
||||
await self.coll.create_index("a", unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("a", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
|
||||
requests: list = [
|
||||
InsertOne({"b": 1, "a": 1}),
|
||||
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
|
||||
@ -616,7 +616,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
|
||||
async def test_single_error_unordered_batch(self):
|
||||
await self.coll.create_index("a", unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("a", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
|
||||
requests: list = [
|
||||
InsertOne({"b": 1, "a": 1}),
|
||||
UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True),
|
||||
@ -659,7 +659,7 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
|
||||
async def test_multiple_error_unordered_batch(self):
|
||||
await self.coll.create_index("a", unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("a", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
|
||||
requests: list = [
|
||||
InsertOne({"b": 1, "a": 1}),
|
||||
UpdateOne({"b": 2}, {"$set": {"a": 3}}, upsert=True),
|
||||
@ -1002,7 +1002,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
|
||||
|
||||
await self.coll.delete_many({})
|
||||
await self.coll.create_index("a", unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("a", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
|
||||
|
||||
# Fail due to write concern support as well
|
||||
# as duplicate key error on ordered batch.
|
||||
@ -1077,7 +1077,7 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
|
||||
|
||||
await self.coll.delete_many({})
|
||||
await self.coll.create_index("a", unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("a", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("a", 1)])
|
||||
|
||||
# Fail due to write concern support as well
|
||||
# as duplicate key error on unordered batch.
|
||||
|
||||
@ -746,7 +746,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
|
||||
# Assert that if a socket is closed, a new one takes its place
|
||||
async with server._pool.checkout() as conn:
|
||||
conn.close_conn(None)
|
||||
await conn.close_conn(None)
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"a closed socket gets replaced from the pool",
|
||||
@ -1270,7 +1270,6 @@ class TestClient(AsyncIntegrationTest):
|
||||
no_timeout = self.client
|
||||
timeout_sec = 1
|
||||
timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
|
||||
self.addAsyncCleanup(timeout.close)
|
||||
|
||||
await no_timeout.pymongo_test.drop_collection("test")
|
||||
await no_timeout.pymongo_test.test.insert_one({"x": 1})
|
||||
@ -1337,7 +1336,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def test_socketKeepAlive(self):
|
||||
pool = await async_get_pool(self.client)
|
||||
async with pool.checkout() as conn:
|
||||
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
self.assertTrue(keepalive)
|
||||
|
||||
@no_type_check
|
||||
@ -1537,7 +1536,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
|
||||
# Cause a network error.
|
||||
conn = one(pool.conns)
|
||||
conn.conn.close()
|
||||
await conn.conn.close()
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
await anext(cursor)
|
||||
@ -1562,7 +1561,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
# Cause a network error on the actual socket.
|
||||
pool = await async_get_pool(c)
|
||||
conn = one(pool.conns)
|
||||
conn.conn.close()
|
||||
await conn.conn.close()
|
||||
|
||||
# AsyncConnection.authenticate logs, but gets a socket.error. Should be
|
||||
# reraised as AutoReconnect.
|
||||
@ -2254,7 +2253,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
|
||||
|
||||
# Cause a network error.
|
||||
conn = one(pool.conns)
|
||||
conn.conn.close()
|
||||
await conn.conn.close()
|
||||
|
||||
cursor = collection.find(cursor_type=CursorType.EXHAUST)
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
@ -2282,7 +2281,7 @@ class TestExhaustCursor(AsyncIntegrationTest):
|
||||
|
||||
# Cause a network error.
|
||||
conn = cursor._sock_mgr.conn
|
||||
conn.conn.close()
|
||||
await conn.conn.close()
|
||||
|
||||
# A getmore fails.
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
|
||||
@ -651,7 +651,6 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
|
||||
_OVERHEAD = 500
|
||||
|
||||
internal_client = await self.async_rs_or_single_client(timeoutMS=None)
|
||||
self.addAsyncCleanup(internal_client.close)
|
||||
|
||||
collection = internal_client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
|
||||
@ -273,7 +273,7 @@ class AsyncTestCMAP(AsyncIntegrationTest):
|
||||
for t in self.targets.values():
|
||||
await t.join(5)
|
||||
for conn in self.labels.values():
|
||||
conn.close_conn(None)
|
||||
await conn.close_conn(None)
|
||||
|
||||
self.addAsyncCleanup(cleanup)
|
||||
|
||||
|
||||
@ -1401,7 +1401,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_to_list_length(self):
|
||||
coll = self.db.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
self.addCleanup(coll.drop)
|
||||
self.addAsyncCleanup(coll.drop)
|
||||
c = coll.find()
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
@ -1812,6 +1812,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
|
||||
@async_client_context.require_version_min(5, 0, -1)
|
||||
@async_client_context.require_no_mongos
|
||||
@async_client_context.require_sync
|
||||
async def test_exhaust_cursor_db_set(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await self.async_rs_or_single_client(event_listeners=[listener])
|
||||
@ -1821,7 +1822,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
|
||||
listener.reset()
|
||||
|
||||
result = await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list()
|
||||
result = list(await c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1))
|
||||
|
||||
self.assertEqual(len(result), 3)
|
||||
|
||||
|
||||
@ -115,6 +115,17 @@ class TestGridfs(AsyncIntegrationTest):
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
async def test_delete_by_name(self):
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
gfs = gridfs.AsyncGridFSBucket(self.db)
|
||||
await gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1)
|
||||
self.assertEqual(1, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(5, await self.db.fs.chunks.count_documents({}))
|
||||
await gfs.delete_by_name("test_filename")
|
||||
self.assertEqual(0, await self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, await self.db.fs.chunks.count_documents({}))
|
||||
|
||||
async def test_empty_file(self):
|
||||
oid = await self.fs.upload_from_stream("test_filename", b"")
|
||||
self.assertEqual(b"", await (await self.fs.open_download_stream(oid)).read())
|
||||
|
||||
@ -219,7 +219,7 @@ class TestPooling(_TestPoolingBase):
|
||||
|
||||
async with cx_pool.checkout() as conn:
|
||||
# Use Connection's API to close the socket.
|
||||
conn.close_conn(None)
|
||||
await conn.close_conn(None)
|
||||
|
||||
self.assertEqual(0, len(cx_pool.conns))
|
||||
|
||||
@ -232,7 +232,7 @@ class TestPooling(_TestPoolingBase):
|
||||
async with cx_pool.checkout() as conn:
|
||||
# Simulate a closed socket without telling the Connection it's
|
||||
# closed.
|
||||
conn.conn.close()
|
||||
await conn.conn.close()
|
||||
self.assertTrue(conn.conn_closed())
|
||||
|
||||
async with cx_pool.checkout() as new_connection:
|
||||
@ -306,7 +306,7 @@ class TestPooling(_TestPoolingBase):
|
||||
async with cx_pool.checkout() as conn:
|
||||
# Simulate a closed socket without telling the Connection it's
|
||||
# closed.
|
||||
conn.conn.close()
|
||||
await conn.conn.close()
|
||||
|
||||
# Swap pool's address with a bad one.
|
||||
address, cx_pool.address = cx_pool.address, ("foo.com", 1234)
|
||||
|
||||
@ -137,6 +137,7 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest):
|
||||
self.deprecation_filter = DeprecationFilter()
|
||||
|
||||
async def asyncTearDown(self) -> None:
|
||||
await super().asyncTearDown()
|
||||
self.deprecation_filter.stop()
|
||||
|
||||
|
||||
@ -196,6 +197,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||
)
|
||||
self.knobs.disable()
|
||||
await super().asyncTearDown()
|
||||
|
||||
async def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
|
||||
@ -45,7 +45,7 @@ from test.utils_shared import (
|
||||
|
||||
from bson import DBRef
|
||||
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
|
||||
from pymongo import ASCENDING, AsyncMongoClient, monitoring
|
||||
from pymongo import ASCENDING, AsyncMongoClient, _csot, monitoring
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
@ -543,7 +543,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
(bucket.rename, [1, "f2"], {}),
|
||||
# Delete both files so _test_ops can run these operations twice.
|
||||
(bucket.delete, [1], {}),
|
||||
(bucket.delete, [2], {}),
|
||||
(bucket.delete_by_name, ["f"], {}),
|
||||
)
|
||||
|
||||
async def test_gridfsbucket_cursor(self):
|
||||
|
||||
@ -32,7 +32,7 @@ from typing import List
|
||||
|
||||
from bson import encode
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import WriteConcern
|
||||
from pymongo import WriteConcern, _csot
|
||||
from pymongo.asynchronous import client_session
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
@ -295,6 +295,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
"new-name",
|
||||
),
|
||||
),
|
||||
(bucket.delete_by_name, ("new-name",)),
|
||||
]
|
||||
|
||||
async with client.start_session() as s, await s.start_transaction():
|
||||
|
||||
@ -66,7 +66,7 @@ import pymongo
|
||||
from bson import SON, json_util
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS
|
||||
from bson.objectid import ObjectId
|
||||
from gridfs import AsyncGridFSBucket, GridOut
|
||||
from gridfs import AsyncGridFSBucket, GridOut, NoFile
|
||||
from pymongo import ASCENDING, AsyncMongoClient, CursorType, _csot
|
||||
from pymongo.asynchronous.change_stream import AsyncChangeStream
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession, TransactionOptions, _TxnState
|
||||
@ -632,7 +632,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
|
||||
# Connection errors are considered client errors.
|
||||
if isinstance(error, ConnectionFailure):
|
||||
self.assertNotIsInstance(error, NotPrimaryError)
|
||||
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)):
|
||||
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError, NoFile)):
|
||||
pass
|
||||
else:
|
||||
self.assertNotIsInstance(error, PyMongoError)
|
||||
|
||||
@ -159,6 +159,7 @@ class AsyncMockConnection:
|
||||
self.cancel_context = _CancellationContext()
|
||||
self.more_to_come = False
|
||||
self.id = random.randint(0, 100)
|
||||
self.server_connection_id = random.randint(0, 100)
|
||||
|
||||
def close_conn(self, reason):
|
||||
pass
|
||||
|
||||
493
test/crud/unified/bypassDocumentValidation.json
Normal file
493
test/crud/unified/bypassDocumentValidation.json
Normal file
@ -0,0 +1,493 @@
|
||||
{
|
||||
"description": "bypassDocumentValidation",
|
||||
"schemaVersion": "1.4",
|
||||
"runOnRequirements": [
|
||||
{
|
||||
"minServerVersion": "3.2",
|
||||
"serverless": "forbid"
|
||||
}
|
||||
],
|
||||
"createEntities": [
|
||||
{
|
||||
"client": {
|
||||
"id": "client0",
|
||||
"observeEvents": [
|
||||
"commandStartedEvent"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"database": {
|
||||
"id": "database0",
|
||||
"client": "client0",
|
||||
"databaseName": "crud"
|
||||
}
|
||||
},
|
||||
{
|
||||
"collection": {
|
||||
"id": "collection0",
|
||||
"database": "database0",
|
||||
"collectionName": "coll"
|
||||
}
|
||||
}
|
||||
],
|
||||
"initialData": [
|
||||
{
|
||||
"collectionName": "coll",
|
||||
"databaseName": "crud",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 1,
|
||||
"x": 11
|
||||
},
|
||||
{
|
||||
"_id": 2,
|
||||
"x": 22
|
||||
},
|
||||
{
|
||||
"_id": 3,
|
||||
"x": 33
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tests": [
|
||||
{
|
||||
"description": "Aggregate with $out passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "aggregate",
|
||||
"arguments": {
|
||||
"pipeline": [
|
||||
{
|
||||
"$sort": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
{
|
||||
"$match": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$out": "other_test_collection"
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"aggregate": "coll",
|
||||
"pipeline": [
|
||||
{
|
||||
"$sort": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
{
|
||||
"$match": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$out": "other_test_collection"
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
},
|
||||
"commandName": "aggregate",
|
||||
"databaseName": "crud"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "BulkWrite passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "bulkWrite",
|
||||
"arguments": {
|
||||
"requests": [
|
||||
{
|
||||
"insertOne": {
|
||||
"document": {
|
||||
"_id": 4,
|
||||
"x": 44
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"insert": "coll",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 4,
|
||||
"x": 44
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "FindOneAndReplace passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "findOneAndReplace",
|
||||
"arguments": {
|
||||
"filter": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"replacement": {
|
||||
"x": 32
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"findAndModify": "coll",
|
||||
"query": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
"x": 32
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "FindOneAndUpdate passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "findOneAndUpdate",
|
||||
"arguments": {
|
||||
"filter": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
"$inc": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"findAndModify": "coll",
|
||||
"query": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
"$inc": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "InsertMany passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "insertMany",
|
||||
"arguments": {
|
||||
"documents": [
|
||||
{
|
||||
"_id": 4,
|
||||
"x": 44
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"insert": "coll",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 4,
|
||||
"x": 44
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "InsertOne passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "insertOne",
|
||||
"arguments": {
|
||||
"document": {
|
||||
"_id": 4,
|
||||
"x": 44
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"insert": "coll",
|
||||
"documents": [
|
||||
{
|
||||
"_id": 4,
|
||||
"x": 44
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "ReplaceOne passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "replaceOne",
|
||||
"arguments": {
|
||||
"filter": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"replacement": {
|
||||
"x": 32
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"update": "coll",
|
||||
"updates": [
|
||||
{
|
||||
"q": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"u": {
|
||||
"x": 32
|
||||
},
|
||||
"multi": {
|
||||
"$$unsetOrMatches": false
|
||||
},
|
||||
"upsert": {
|
||||
"$$unsetOrMatches": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "UpdateMany passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "updateMany",
|
||||
"arguments": {
|
||||
"filter": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
"$inc": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"update": "coll",
|
||||
"updates": [
|
||||
{
|
||||
"q": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"u": {
|
||||
"$inc": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
"multi": true,
|
||||
"upsert": {
|
||||
"$$unsetOrMatches": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "UpdateOne passes bypassDocumentValidation: false",
|
||||
"operations": [
|
||||
{
|
||||
"object": "collection0",
|
||||
"name": "updateOne",
|
||||
"arguments": {
|
||||
"filter": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"update": {
|
||||
"$inc": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"events": [
|
||||
{
|
||||
"commandStartedEvent": {
|
||||
"command": {
|
||||
"update": "coll",
|
||||
"updates": [
|
||||
{
|
||||
"q": {
|
||||
"_id": {
|
||||
"$gt": 1
|
||||
}
|
||||
},
|
||||
"u": {
|
||||
"$inc": {
|
||||
"x": 1
|
||||
}
|
||||
},
|
||||
"multi": {
|
||||
"$$unsetOrMatches": false
|
||||
},
|
||||
"upsert": {
|
||||
"$$unsetOrMatches": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"bypassDocumentValidation": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
230
test/gridfs/deleteByName.json
Normal file
230
test/gridfs/deleteByName.json
Normal file
@ -0,0 +1,230 @@
|
||||
{
|
||||
"description": "gridfs-deleteByName",
|
||||
"schemaVersion": "1.0",
|
||||
"createEntities": [
|
||||
{
|
||||
"client": {
|
||||
"id": "client0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"database": {
|
||||
"id": "database0",
|
||||
"client": "client0",
|
||||
"databaseName": "gridfs-tests"
|
||||
}
|
||||
},
|
||||
{
|
||||
"bucket": {
|
||||
"id": "bucket0",
|
||||
"database": "database0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"collection": {
|
||||
"id": "bucket0_files_collection",
|
||||
"database": "database0",
|
||||
"collectionName": "fs.files"
|
||||
}
|
||||
},
|
||||
{
|
||||
"collection": {
|
||||
"id": "bucket0_chunks_collection",
|
||||
"database": "database0",
|
||||
"collectionName": "fs.chunks"
|
||||
}
|
||||
}
|
||||
],
|
||||
"initialData": [
|
||||
{
|
||||
"collectionName": "fs.files",
|
||||
"databaseName": "gridfs-tests",
|
||||
"documents": [
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000001"
|
||||
},
|
||||
"length": 0,
|
||||
"chunkSize": 4,
|
||||
"uploadDate": {
|
||||
"$date": "1970-01-01T00:00:00.000Z"
|
||||
},
|
||||
"filename": "filename",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000002"
|
||||
},
|
||||
"length": 0,
|
||||
"chunkSize": 4,
|
||||
"uploadDate": {
|
||||
"$date": "1970-01-01T00:00:00.000Z"
|
||||
},
|
||||
"filename": "filename",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000003"
|
||||
},
|
||||
"length": 2,
|
||||
"chunkSize": 4,
|
||||
"uploadDate": {
|
||||
"$date": "1970-01-01T00:00:00.000Z"
|
||||
},
|
||||
"filename": "filename",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000004"
|
||||
},
|
||||
"length": 8,
|
||||
"chunkSize": 4,
|
||||
"uploadDate": {
|
||||
"$date": "1970-01-01T00:00:00.000Z"
|
||||
},
|
||||
"filename": "otherfilename",
|
||||
"metadata": {}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"collectionName": "fs.chunks",
|
||||
"databaseName": "gridfs-tests",
|
||||
"documents": [
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000001"
|
||||
},
|
||||
"files_id": {
|
||||
"$oid": "000000000000000000000002"
|
||||
},
|
||||
"n": 0,
|
||||
"data": {
|
||||
"$binary": {
|
||||
"base64": "",
|
||||
"subType": "00"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000002"
|
||||
},
|
||||
"files_id": {
|
||||
"$oid": "000000000000000000000003"
|
||||
},
|
||||
"n": 0,
|
||||
"data": {
|
||||
"$binary": {
|
||||
"base64": "",
|
||||
"subType": "00"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000003"
|
||||
},
|
||||
"files_id": {
|
||||
"$oid": "000000000000000000000003"
|
||||
},
|
||||
"n": 0,
|
||||
"data": {
|
||||
"$binary": {
|
||||
"base64": "",
|
||||
"subType": "00"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000004"
|
||||
},
|
||||
"files_id": {
|
||||
"$oid": "000000000000000000000004"
|
||||
},
|
||||
"n": 0,
|
||||
"data": {
|
||||
"$binary": {
|
||||
"base64": "",
|
||||
"subType": "00"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tests": [
|
||||
{
|
||||
"description": "delete when multiple revisions of the file exist",
|
||||
"operations": [
|
||||
{
|
||||
"name": "deleteByName",
|
||||
"object": "bucket0",
|
||||
"arguments": {
|
||||
"filename": "filename"
|
||||
}
|
||||
}
|
||||
],
|
||||
"outcome": [
|
||||
{
|
||||
"collectionName": "fs.files",
|
||||
"databaseName": "gridfs-tests",
|
||||
"documents": [
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000004"
|
||||
},
|
||||
"length": 8,
|
||||
"chunkSize": 4,
|
||||
"uploadDate": {
|
||||
"$date": "1970-01-01T00:00:00.000Z"
|
||||
},
|
||||
"filename": "otherfilename",
|
||||
"metadata": {}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"collectionName": "fs.chunks",
|
||||
"databaseName": "gridfs-tests",
|
||||
"documents": [
|
||||
{
|
||||
"_id": {
|
||||
"$oid": "000000000000000000000004"
|
||||
},
|
||||
"files_id": {
|
||||
"$oid": "000000000000000000000004"
|
||||
},
|
||||
"n": 0,
|
||||
"data": {
|
||||
"$binary": {
|
||||
"base64": "",
|
||||
"subType": "00"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "delete when file name does not exist",
|
||||
"operations": [
|
||||
{
|
||||
"name": "deleteByName",
|
||||
"object": "bucket0",
|
||||
"arguments": {
|
||||
"filename": "missing-file"
|
||||
},
|
||||
"expectError": {
|
||||
"isClientError": true
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -1616,6 +1616,50 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "pinned connection is released when session ended",
|
||||
"operations": [
|
||||
{
|
||||
"name": "startTransaction",
|
||||
"object": "session0"
|
||||
},
|
||||
{
|
||||
"name": "insertOne",
|
||||
"object": "collection0",
|
||||
"arguments": {
|
||||
"document": {
|
||||
"x": 1
|
||||
},
|
||||
"session": "session0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "commitTransaction",
|
||||
"object": "session0"
|
||||
},
|
||||
{
|
||||
"name": "endSession",
|
||||
"object": "session0"
|
||||
}
|
||||
],
|
||||
"expectEvents": [
|
||||
{
|
||||
"client": "client0",
|
||||
"eventType": "cmap",
|
||||
"events": [
|
||||
{
|
||||
"connectionReadyEvent": {}
|
||||
},
|
||||
{
|
||||
"connectionCheckedOutEvent": {}
|
||||
},
|
||||
{
|
||||
"connectionCheckedInEvent": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -22,6 +22,8 @@ import sys
|
||||
import warnings
|
||||
from test import PyMongoTestCase
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
@ -30,6 +32,8 @@ from test.unified_format import generate_test_classes
|
||||
from pymongo import MongoClient
|
||||
from pymongo.synchronous.auth_oidc import OIDCCallback
|
||||
|
||||
pytestmark = pytest.mark.auth
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth")
|
||||
|
||||
@ -1233,7 +1233,6 @@ class TestClient(IntegrationTest):
|
||||
no_timeout = self.client
|
||||
timeout_sec = 1
|
||||
timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
|
||||
self.addCleanup(timeout.close)
|
||||
|
||||
no_timeout.pymongo_test.drop_collection("test")
|
||||
no_timeout.pymongo_test.test.insert_one({"x": 1})
|
||||
@ -1296,7 +1295,7 @@ class TestClient(IntegrationTest):
|
||||
def test_socketKeepAlive(self):
|
||||
pool = get_pool(self.client)
|
||||
with pool.checkout() as conn:
|
||||
keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
keepalive = conn.conn.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
|
||||
self.assertTrue(keepalive)
|
||||
|
||||
@no_type_check
|
||||
|
||||
@ -647,7 +647,6 @@ class TestClientBulkWriteCSOT(IntegrationTest):
|
||||
_OVERHEAD = 500
|
||||
|
||||
internal_client = self.rs_or_single_client(timeoutMS=None)
|
||||
self.addCleanup(internal_client.close)
|
||||
|
||||
collection = internal_client.db["coll"]
|
||||
self.addCleanup(collection.drop)
|
||||
|
||||
@ -1801,6 +1801,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
|
||||
@client_context.require_version_min(5, 0, -1)
|
||||
@client_context.require_no_mongos
|
||||
@client_context.require_sync
|
||||
def test_exhaust_cursor_db_set(self):
|
||||
listener = OvertCommandListener()
|
||||
client = self.rs_or_single_client(event_listeners=[listener])
|
||||
@ -1810,7 +1811,7 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
|
||||
listener.reset()
|
||||
|
||||
result = c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1).to_list()
|
||||
result = list(c.find({}, cursor_type=pymongo.CursorType.EXHAUST, batch_size=1))
|
||||
|
||||
self.assertEqual(len(result), 3)
|
||||
|
||||
|
||||
@ -115,6 +115,17 @@ class TestGridfs(IntegrationTest):
|
||||
self.assertEqual(0, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
def test_delete_by_name(self):
|
||||
self.assertEqual(0, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
gfs = gridfs.GridFSBucket(self.db)
|
||||
gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1)
|
||||
self.assertEqual(1, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(5, self.db.fs.chunks.count_documents({}))
|
||||
gfs.delete_by_name("test_filename")
|
||||
self.assertEqual(0, self.db.fs.files.count_documents({}))
|
||||
self.assertEqual(0, self.db.fs.chunks.count_documents({}))
|
||||
|
||||
def test_empty_file(self):
|
||||
oid = self.fs.upload_from_stream("test_filename", b"")
|
||||
self.assertEqual(b"", (self.fs.open_download_stream(oid)).read())
|
||||
|
||||
@ -137,6 +137,7 @@ class IgnoreDeprecationsTest(IntegrationTest):
|
||||
self.deprecation_filter = DeprecationFilter()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
super().tearDown()
|
||||
self.deprecation_filter.stop()
|
||||
|
||||
|
||||
@ -194,6 +195,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")])
|
||||
)
|
||||
self.knobs.disable()
|
||||
super().tearDown()
|
||||
|
||||
def test_supported_single_statement_no_retry(self):
|
||||
listener = OvertCommandListener()
|
||||
|
||||
@ -45,7 +45,7 @@ from test.utils_shared import (
|
||||
|
||||
from bson import DBRef
|
||||
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
|
||||
from pymongo import ASCENDING, MongoClient, monitoring
|
||||
from pymongo import ASCENDING, MongoClient, _csot, monitoring
|
||||
from pymongo.common import _MAX_END_SESSIONS
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
|
||||
from pymongo.operations import IndexModel, InsertOne, UpdateOne
|
||||
@ -543,7 +543,7 @@ class TestSession(IntegrationTest):
|
||||
(bucket.rename, [1, "f2"], {}),
|
||||
# Delete both files so _test_ops can run these operations twice.
|
||||
(bucket.delete, [1], {}),
|
||||
(bucket.delete, [2], {}),
|
||||
(bucket.delete_by_name, ["f"], {}),
|
||||
)
|
||||
|
||||
def test_gridfsbucket_cursor(self):
|
||||
|
||||
@ -32,7 +32,7 @@ from typing import List
|
||||
|
||||
from bson import encode
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from pymongo import WriteConcern
|
||||
from pymongo import WriteConcern, _csot
|
||||
from pymongo.errors import (
|
||||
CollectionInvalid,
|
||||
ConfigurationError,
|
||||
@ -287,6 +287,7 @@ class TestTransactions(TransactionsBase):
|
||||
"new-name",
|
||||
),
|
||||
),
|
||||
(bucket.delete_by_name, ("new-name",)),
|
||||
]
|
||||
|
||||
with client.start_session() as s, s.start_transaction():
|
||||
|
||||
@ -65,7 +65,7 @@ import pymongo
|
||||
from bson import SON, json_util
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS
|
||||
from bson.objectid import ObjectId
|
||||
from gridfs import GridFSBucket, GridOut
|
||||
from gridfs import GridFSBucket, GridOut, NoFile
|
||||
from pymongo import ASCENDING, CursorType, MongoClient, _csot
|
||||
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT
|
||||
from pymongo.errors import (
|
||||
@ -631,7 +631,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
# Connection errors are considered client errors.
|
||||
if isinstance(error, ConnectionFailure):
|
||||
self.assertNotIsInstance(error, NotPrimaryError)
|
||||
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)):
|
||||
elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError, NoFile)):
|
||||
pass
|
||||
else:
|
||||
self.assertNotIsInstance(error, PyMongoError)
|
||||
|
||||
@ -157,6 +157,7 @@ class MockConnection:
|
||||
self.cancel_context = _CancellationContext()
|
||||
self.more_to_come = False
|
||||
self.id = random.randint(0, 100)
|
||||
self.server_connection_id = random.randint(0, 100)
|
||||
|
||||
def close_conn(self, reason):
|
||||
pass
|
||||
|
||||
@ -615,6 +615,10 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac
|
||||
# Aggregate uses "batchSize", while find uses batch_size.
|
||||
elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate":
|
||||
continue
|
||||
elif arg_name == "bypassDocumentValidation" and (
|
||||
opname == "aggregate" or "find_one_and" in opname
|
||||
):
|
||||
continue
|
||||
elif arg_name == "timeoutMode":
|
||||
raise unittest.SkipTest("PyMongo does not support timeoutMode")
|
||||
# Requires boolean returnDocument.
|
||||
|
||||
@ -47,6 +47,7 @@ replacements = {
|
||||
"async_receive_message": "receive_message",
|
||||
"async_receive_data": "receive_data",
|
||||
"async_sendall": "sendall",
|
||||
"async_socket_sendall": "sendall",
|
||||
"asynchronous": "synchronous",
|
||||
"Asynchronous": "Synchronous",
|
||||
"AsyncBulkTestBase": "BulkTestBase",
|
||||
@ -119,6 +120,9 @@ replacements = {
|
||||
"_async_create_lock": "_create_lock",
|
||||
"_async_create_condition": "_create_condition",
|
||||
"_async_cond_wait": "_cond_wait",
|
||||
"AsyncNetworkingInterface": "NetworkingInterface",
|
||||
"_configured_protocol_interface": "_configured_socket_interface",
|
||||
"_async_configured_socket": "_configured_socket",
|
||||
"SpecRunnerTask": "SpecRunnerThread",
|
||||
"AsyncMockConnection": "MockConnection",
|
||||
"AsyncMockPool": "MockPool",
|
||||
@ -127,6 +131,7 @@ replacements = {
|
||||
"async_create_barrier": "create_barrier",
|
||||
"async_barrier_wait": "barrier_wait",
|
||||
"async_joinall": "joinall",
|
||||
"_async_create_connection": "_create_connection",
|
||||
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user