Compare commits

...

13 Commits

Author SHA1 Message Date
mongodb-dbx-release-bot[bot]
84d0d3db4d
BUMP 4.12.1
Signed-off-by: mongodb-dbx-release-bot[bot] <167856002+mongodb-dbx-release-bot[bot]@users.noreply.github.com>
2025-04-29 18:30:01 +00:00
Jeffrey A. Clark
c52a456fd2
PYTHON-5357 Update changelog for 4.12.1 release (#2321) (#2323) 2025-04-29 13:28:51 -04:00
Noah Stapp
dd99f80ce3
PYTHON-5309: [v4.12] AsyncMongoClient doesn't use PyOpenSSL (#2286) (#2319)
Co-authored-by: Iris <58442094+sleepyStick@users.noreply.github.com>
2025-04-28 13:49:21 -04:00
mongodb-drivers-pr-bot[bot]
fecd29c1f8
PYTHON-5336 Added VECTOR_SUBTYPE line to API docs (#2313) [v4.12] (#2314)
Co-authored-by: Casey Clements <caseyclements@users.noreply.github.com>
2025-04-25 12:30:46 -05:00
Noah Stapp
c11d0f4def
PYTHON-5306: [v4.12] - Fix use of public MongoClient attributes before connection (#2285) (#2311) 2025-04-24 15:08:37 -04:00
Noah Stapp
f5836b3f6f
PYTHON-5346: [v4.12] test_init_disconnected_with_srv cannot run against sharded Topologies (#2304) (#2309) 2025-04-24 13:05:31 -05:00
Steven Silvester
38bc13db9c
PYTHON-5212 [v4.12] Do not hold Topology lock while resetting pool (#2307)
Co-authored-by: Noah Stapp <noah.stapp@mongodb.com>
2025-04-24 11:19:24 -05:00
mongodb-drivers-pr-bot[bot]
c6671e239a
PYTHON-5348 Fix CodeQL Scanning for GitHub Actions (#2308) [v4.12] (#2310)
Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
2025-04-24 10:12:36 -05:00
Shane Harvey
79cb34a7b7
PYTHON-5314 [v4.12] Fix default imports for modules that worked in v4.8 (#2300) (#2303) 2025-04-23 11:34:51 -07:00
Shane Harvey
c83784610e
PYTHON-5310 [v4.12] Fix uri_parser AttributeError when used directly (#2283) (#2302)
Co-authored-by: Noah Stapp <noah.stapp@mongodb.com>
2025-04-23 10:43:11 -07:00
Steven Silvester
3377510a63
PYTHON-5295 [v4.12] Update lockfile for compat with older versions of uv (#2278) 2025-04-10 14:29:46 -05:00
Steven Silvester
6b3e254332
PYTHON-5297 [v4.12] AsyncMongoClient connection error causes UnboundLocalError (#2277)
Co-authored-by: Noah Stapp <noah.stapp@mongodb.com>
2025-04-10 13:58:35 -05:00
Steven Silvester
de0e2336cb
PYTHON-5288: [v4.12] SRV hostname validation fails when resolver and resolved hostnames are identical with three domain levels (#2276)
Co-authored-by: Jeffrey A. Clark <aclark@aclark.net>
2025-04-10 12:52:56 -05:00
42 changed files with 737 additions and 191 deletions

View File

@ -735,18 +735,20 @@ buildvariants:
- macos-14
batchtime: 10080
expansions:
TEST_NAME: pyopenssl
TEST_NAME: default
SUB_TEST_NAME: pyopenssl
PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3
- name: pyopenssl-rhel8-python3.10
tasks:
- name: .replica_set .auth .ssl .sync
- name: .7.0 .auth .ssl .sync
- name: .replica_set .auth .ssl .sync_async
- name: .7.0 .auth .ssl .sync_async
display_name: PyOpenSSL RHEL8 Python3.10
run_on:
- rhel87-small
batchtime: 10080
expansions:
TEST_NAME: pyopenssl
TEST_NAME: default
SUB_TEST_NAME: pyopenssl
PYTHON_BINARY: /opt/python/3.10/bin/python3
- name: pyopenssl-rhel8-python3.11
tasks:
@ -757,7 +759,8 @@ buildvariants:
- rhel87-small
batchtime: 10080
expansions:
TEST_NAME: pyopenssl
TEST_NAME: default
SUB_TEST_NAME: pyopenssl
PYTHON_BINARY: /opt/python/3.11/bin/python3
- name: pyopenssl-rhel8-python3.12
tasks:
@ -768,18 +771,20 @@ buildvariants:
- rhel87-small
batchtime: 10080
expansions:
TEST_NAME: pyopenssl
TEST_NAME: default
SUB_TEST_NAME: pyopenssl
PYTHON_BINARY: /opt/python/3.12/bin/python3
- name: pyopenssl-win64-python3.13
tasks:
- name: .replica_set .auth .ssl .sync
- name: .7.0 .auth .ssl .sync
- name: .replica_set .auth .ssl .sync_async
- name: .7.0 .auth .ssl .sync_async
display_name: PyOpenSSL Win64 Python3.13
run_on:
- windows-64-vsMulti-small
batchtime: 10080
expansions:
TEST_NAME: pyopenssl
TEST_NAME: default
SUB_TEST_NAME: pyopenssl
PYTHON_BINARY: C:/python/Python313/python.exe
- name: pyopenssl-rhel8-pypy3.10
tasks:
@ -790,7 +795,8 @@ buildvariants:
- rhel87-small
batchtime: 10080
expansions:
TEST_NAME: pyopenssl
TEST_NAME: default
SUB_TEST_NAME: pyopenssl
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
# Search index tests

View File

@ -491,7 +491,7 @@ def create_enterprise_auth_variants():
def create_pyopenssl_variants():
base_name = "PyOpenSSL"
batchtime = BATCHTIME_WEEK
expansions = dict(TEST_NAME="pyopenssl")
expansions = dict(TEST_NAME="default", SUB_TEST_NAME="pyopenssl")
variants = []
for python in ALL_PYTHONS:
@ -506,14 +506,25 @@ def create_pyopenssl_variants():
host = DEFAULT_HOST
display_name = get_variant_name(base_name, host, python=python)
variant = create_variant(
[f".replica_set .{auth} .{ssl} .sync", f".7.0 .{auth} .{ssl} .sync"],
display_name,
python=python,
host=host,
expansions=expansions,
batchtime=batchtime,
)
# only need to run some on async
if python in (CPYTHONS[1], CPYTHONS[-1]):
variant = create_variant(
[f".replica_set .{auth} .{ssl} .sync_async", f".7.0 .{auth} .{ssl} .sync_async"],
display_name,
python=python,
host=host,
expansions=expansions,
batchtime=batchtime,
)
else:
variant = create_variant(
[f".replica_set .{auth} .{ssl} .sync", f".7.0 .{auth} .{ssl} .sync"],
display_name,
python=python,
host=host,
expansions=expansions,
batchtime=batchtime,
)
variants.append(variant)
return variants

View File

@ -54,7 +54,6 @@ jobs:
queries: security-extended
config: |
paths-ignore:
- '.github/**'
- 'doc/**'
- 'tools/**'
- 'test/**'

View File

@ -16,6 +16,7 @@
.. autodata:: MD5_SUBTYPE
.. autodata:: COLUMN_SUBTYPE
.. autodata:: SENSITIVE_SUBTYPE
.. autodata:: VECTOR_SUBTYPE
.. autodata:: USER_DEFINED_SUBTYPE
.. autoclass:: UuidRepresentation

View File

@ -1,6 +1,31 @@
Changelog
=========
Changes in Version 4.12.1 (2025/04/29)
--------------------------------------
Version 4.12.1 is a bug fix release.
- Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL.
- Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels.
- Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``.
- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string
could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description,
nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type
errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()".
- Fixed a bug where MongoDB cluster topology changes could cause asynchronous operations to take much longer to complete
due to holding the Topology lock while closing stale connections.
- Fixed a bug that would cause AsyncMongoClient to attempt to use PyOpenSSL when available, resulting in errors such as
"pymongo.errors.ServerSelectionTimeoutError: 'SSLContext' object has no attribute 'wrap_bio'".
Issues Resolved
...............
See the `PyMongo 4.12.1 release notes in JIRA`_ for the list of resolved issues
in this release.
.. _PyMongo 4.12.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=43094
Changes in Version 4.12.0 (2025/04/08)
--------------------------------------

View File

@ -105,6 +105,16 @@ from pymongo.synchronous.collection import ReturnDocument
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.write_concern import WriteConcern
# Public module compatibility imports
# isort: off
from pymongo import uri_parser # noqa: F401
from pymongo import change_stream # noqa: F401
from pymongo import client_session # noqa: F401
from pymongo import collection # noqa: F401
from pymongo import command_cursor # noqa: F401
from pymongo import database # noqa: F401
# isort: on
version = __version__
"""Current version of PyMongo."""

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import re
from typing import List, Tuple, Union
__version__ = "4.12.0"
__version__ = "4.12.1"
def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]:

View File

@ -87,7 +87,7 @@ from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser_shared import parse_host
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
@ -157,6 +157,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
self.mongocryptd_client = mongocryptd_client
self.opts = opts
self._spawned = False
self._kms_ssl_contexts = opts._kms_ssl_contexts(_IS_SYNC)
async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
"""Complete a KMS request.
@ -168,7 +169,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
endpoint = kms_context.endpoint
message = kms_context.message
provider = kms_context.kms_provider
ctx = self.opts._kms_ssl_contexts.get(provider)
ctx = self._kms_ssl_contexts.get(provider)
if ctx is None:
# Enable strict certificate verification, OCSP, match hostname, and
# SNI using the system default CA certificates.
@ -180,6 +181,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False, # disable_ocsp_endpoint_check
_IS_SYNC,
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
@ -396,6 +398,8 @@ class _Encrypter:
encrypted_fields_map = _dict_to_bson(opts._encrypted_fields_map, False, _DATA_KEY_OPTS)
self._bypass_auto_encryption = opts._bypass_auto_encryption
self._internal_client = None
# parsing kms_ssl_contexts here so that parsing errors will be raised before internal clients are created
opts._kms_ssl_contexts(_IS_SYNC)
def _get_internal_client(
encrypter: _Encrypter, mongo_client: AsyncMongoClient[_DocumentTypeArg]
@ -675,6 +679,7 @@ class AsyncClientEncryption(Generic[_DocumentType]):
kms_tls_options=kms_tls_options,
key_expiration_ms=key_expiration_ms,
)
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(
None, key_vault_coll, None, opts
)

View File

@ -109,6 +109,7 @@ from pymongo.operations import (
)
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@ -779,7 +780,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
seeds = set()
self._seeds = set()
is_srv = False
username = None
password = None
@ -804,18 +805,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
self._seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
dbase = res["database"] or dbase
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
self._seeds.update(split_hosts(entity, self._port))
if not self._seeds:
raise ConfigurationError("need to specify at least one host")
for hostname in [node[0] for node in seeds]:
for hostname in [node[0] for node in self._seeds]:
if _detect_external_db(hostname):
break
@ -838,7 +839,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
opts = self._normalize_and_validate_options(opts, self._seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
@ -857,7 +858,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"seeds": self._seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
@ -873,8 +874,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
@ -975,6 +975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
if self._options.auto_encryption_opts:
from pymongo.asynchronous.encryption import _Encrypter
@ -1205,6 +1206,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 4.0
"""
if self._topology is None:
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
return TopologyDescription(
TOPOLOGY_TYPE.Unknown,
servers,
None,
None,
None,
self._topology_settings,
)
return self._topology.description
@property
@ -1218,6 +1229,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
to any servers, or a network partition causes it to lose connection
to all servers.
"""
if self._topology is None:
return frozenset()
description = self._topology.description
return frozenset(s.address for s in description.known_servers)
@ -1576,6 +1589,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
"""
if self._topology is None:
await self._get_topology()
topology_type = self._topology._description.topology_type
if (
topology_type == TOPOLOGY_TYPE.Sharded
@ -1598,6 +1613,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
AsyncMongoClient gained this property in version 3.0.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_primary() # type: ignore[return-value]
@property
@ -1611,6 +1628,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
AsyncMongoClient gained this property in version 3.0.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_secondaries()
@property
@ -1621,6 +1640,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
connected to a replica set, there are no arbiters, or this client was
created without the `replicaSet` option.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_arbiters()
@property

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import logging
@ -75,6 +76,7 @@ from pymongo.monitoring import (
from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
SSLErrors,
_CancellationContext,
_configured_protocol_interface,
_get_timeout_details,
@ -85,7 +87,6 @@ from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import SSLError
if TYPE_CHECKING:
from bson import CodecOptions
@ -637,7 +638,7 @@ class AsyncConnection:
reason = ConnectionClosedReason.ERROR
await self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)
else:
@ -860,8 +861,14 @@ class Pool:
# PoolClosedEvent but that reset() SHOULD close sockets *after*
# publishing the PoolClearedEvent.
if close:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
@ -891,8 +898,14 @@ class Pool:
serverPort=self.address[1],
serviceId=service_id,
)
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.STALE)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.STALE)
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
@ -938,8 +951,14 @@ class Pool:
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
close_conns.append(self.conns.pop())
for conn in close_conns:
await conn.close_conn(ConnectionClosedReason.IDLE)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
return_exceptions=True,
)
else:
for conn in close_conns:
await conn.close_conn(ConnectionClosedReason.IDLE)
while True:
async with self.size_cond:
@ -1033,7 +1052,7 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR,
)
if isinstance(error, (IOError, OSError, SSLError)):
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)

View File

@ -51,6 +51,7 @@ class TopologySettings:
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
"""Represent MongoClient's configuration.
@ -78,8 +79,10 @@ class TopologySettings:
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode
self._topology_id = ObjectId()
if topology_id is not None:
self._topology_id = topology_id
else:
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack()[:-2])

View File

@ -96,6 +96,7 @@ class _SrvResolver:
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
self.nparts = len(split_fqdn)
async def get_options(self) -> Optional[str]:
from dns import resolver
@ -137,12 +138,13 @@ class _SrvResolver:
# Validate hosts
for node in nodes:
if self.__fqdn == node[0].lower():
srv_host = node[0].lower()
if self.__fqdn == srv_host and self.nparts < 3:
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
nlist = srv_host.split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
if self.__plist != nlist:

View File

@ -529,12 +529,6 @@ class Topology:
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
await server.pool.reset(interrupt_connections=interrupt_connections)
# Wake anything waiting in select_servers().
self._condition.notify_all()
@ -557,6 +551,11 @@ class Topology:
# that didn't include this server.
if self._opened and self._description.has_server(server_description.address):
await self._process_change(server_description, reset_pool, interrupt_connections)
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
await server.pool.reset(interrupt_connections=interrupt_connections)
async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new seedlist on an opened topology.

View File

@ -84,7 +84,9 @@ def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
return ReadConcern(concern)
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
def _parse_ssl_options(
options: Mapping[str, Any], is_sync: bool
) -> tuple[Optional[SSLContext], bool]:
"""Parse ssl options."""
use_tls = options.get("tls")
if use_tls is not None:
@ -138,6 +140,7 @@ def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext]
allow_invalid_certificates,
allow_invalid_hostnames,
disable_ocsp_endpoint_check,
is_sync,
)
return ctx, allow_invalid_hostnames
return None, allow_invalid_hostnames
@ -167,7 +170,7 @@ def _parse_pool_options(
compression_settings = CompressionSettings(
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options, is_sync)
load_balanced = options.get("loadbalanced")
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
return PoolOptions(

View File

@ -20,6 +20,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Optional
from pymongo.uri_parser_shared import _parse_kms_tls_options
try:
import pymongocrypt # type:ignore[import-untyped] # noqa: F401
@ -32,9 +34,9 @@ except ImportError:
from bson import int64
from pymongo.common import validate_is_mapping
from pymongo.errors import ConfigurationError
from pymongo.uri_parser_shared import _parse_kms_tls_options
if TYPE_CHECKING:
from pymongo.pyopenssl_context import SSLContext
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
@ -236,10 +238,22 @@ class AutoEncryptionOpts:
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
# Maps KMS provider name to a SSLContext.
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
self._kms_tls_options = kms_tls_options
self._sync_kms_ssl_contexts: Optional[dict[str, SSLContext]] = None
self._async_kms_ssl_contexts: Optional[dict[str, SSLContext]] = None
self._bypass_query_analysis = bypass_query_analysis
self._key_expiration_ms = key_expiration_ms
def _kms_ssl_contexts(self, is_sync: bool) -> dict[str, SSLContext]:
if is_sync:
if self._sync_kms_ssl_contexts is None:
self._sync_kms_ssl_contexts = _parse_kms_tls_options(self._kms_tls_options, True)
return self._sync_kms_ssl_contexts
else:
if self._async_kms_ssl_contexts is None:
self._async_kms_ssl_contexts = _parse_kms_tls_options(self._kms_tls_options, False)
return self._async_kms_ssl_contexts
class RangeOpts:
"""Options to configure encrypted queries using the range algorithm."""

View File

@ -46,22 +46,18 @@ except ImportError:
_HAVE_SSL = False
try:
from pymongo.pyopenssl_context import (
BLOCKING_IO_LOOKUP_ERROR,
BLOCKING_IO_READ_ERROR,
BLOCKING_IO_WRITE_ERROR,
_sslConn,
)
from pymongo.pyopenssl_context import _sslConn
_HAVE_PYOPENSSL = True
except ImportError:
_HAVE_PYOPENSSL = False
_sslConn = SSLSocket # type: ignore
from pymongo.ssl_support import ( # type: ignore[assignment]
BLOCKING_IO_LOOKUP_ERROR,
BLOCKING_IO_READ_ERROR,
BLOCKING_IO_WRITE_ERROR,
)
_sslConn = SSLSocket # type: ignore[assignment, misc]
from pymongo.ssl_support import (
BLOCKING_IO_LOOKUP_ERROR,
BLOCKING_IO_READ_ERROR,
BLOCKING_IO_WRITE_ERROR,
)
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
@ -71,7 +67,7 @@ _UNPACK_HEADER = struct.Struct("<iiii").unpack
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
_POLL_TIMEOUT = 0.5
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
BLOCKING_IO_ERRORS = (BlockingIOError, *BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
# These socket-based I/O methods are for KMS requests and any other network operations that do not use

View File

@ -38,8 +38,9 @@ from pymongo.errors import ( # type:ignore[attr-defined]
)
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
from pymongo.pool_options import PoolOptions
from pymongo.ssl_support import HAS_SNI, SSLError
from pymongo.ssl_support import PYSSLError, SSLError, _has_sni
SSLErrors = (PYSSLError, SSLError)
if TYPE_CHECKING:
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
@ -138,7 +139,7 @@ def _raise_connection_failure(
msg += format_timeout_details(timeout_details)
if isinstance(error, socket.timeout):
raise NetworkTimeout(msg) from error
elif isinstance(error, SSLError) and "timed out" in str(error):
elif isinstance(error, SSLErrors) and "timed out" in str(error):
# Eventlet does not distinguish TLS network timeouts from other
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
# Luckily, we can work around this limitation because the phrase
@ -279,7 +280,7 @@ async def _async_configured_socket(
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if _has_sni(False):
loop = asyncio.get_running_loop()
ssl_sock = await loop.run_in_executor(
None,
@ -293,7 +294,7 @@ async def _async_configured_socket(
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
except (OSError, *SSLErrors) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
@ -346,12 +347,10 @@ async def _configured_protocol_interface(
ssl=ssl_context,
)
except _CertificateError:
transport.abort()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
transport.abort()
except (OSError, *SSLErrors) as exc:
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
@ -460,7 +459,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if _has_sni(True):
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
else:
ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
@ -469,7 +468,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
except (OSError, *SSLErrors) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
@ -509,7 +508,7 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
if HAS_SNI:
if _has_sni(True):
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
else:
ssl_sock = ssl_context.wrap_socket(sock)
@ -518,7 +517,7 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, SSLError) as exc:
except (OSError, *SSLErrors) as exc:
sock.close()
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol

View File

@ -15,16 +15,19 @@
"""Support for SSL in PyMongo."""
from __future__ import annotations
import types
import warnings
from typing import Optional
from typing import Any, Optional, Union
from pymongo.errors import ConfigurationError
HAVE_SSL = True
HAVE_PYSSL = True
try:
import pymongo.pyopenssl_context as _ssl
import pymongo.pyopenssl_context as _pyssl
except (ImportError, AttributeError) as exc:
HAVE_PYSSL = False
if isinstance(exc, AttributeError):
warnings.warn(
"Failed to use the installed version of PyOpenSSL. "
@ -35,10 +38,10 @@ except (ImportError, AttributeError) as exc:
UserWarning,
stacklevel=2,
)
try:
import pymongo.ssl_context as _ssl # type: ignore[no-redef]
except ImportError:
HAVE_SSL = False
try:
import pymongo.ssl_context as _ssl
except ImportError:
HAVE_SSL = False
if HAVE_SSL:
@ -49,14 +52,29 @@ if HAVE_SSL:
import ssl as _stdlibssl # noqa: F401
from ssl import CERT_NONE, CERT_REQUIRED
HAS_SNI = _ssl.HAS_SNI
IPADDR_SAFE = True
if HAVE_PYSSL:
PYSSLError: Any = _pyssl.SSLError
BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS
BLOCKING_IO_READ_ERROR: tuple = (_pyssl.BLOCKING_IO_READ_ERROR, _ssl.BLOCKING_IO_READ_ERROR)
BLOCKING_IO_WRITE_ERROR: tuple = (
_pyssl.BLOCKING_IO_WRITE_ERROR,
_ssl.BLOCKING_IO_WRITE_ERROR,
)
else:
PYSSLError = _ssl.SSLError
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
BLOCKING_IO_READ_ERROR = (_ssl.BLOCKING_IO_READ_ERROR,)
BLOCKING_IO_WRITE_ERROR = (_ssl.BLOCKING_IO_WRITE_ERROR,)
SSLError = _ssl.SSLError
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
BLOCKING_IO_READ_ERROR = _ssl.BLOCKING_IO_READ_ERROR
BLOCKING_IO_WRITE_ERROR = _ssl.BLOCKING_IO_WRITE_ERROR
BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR
def _has_sni(is_sync: bool) -> bool:
if is_sync and HAVE_PYSSL:
return _pyssl.HAS_SNI
return _ssl.HAS_SNI
def get_ssl_context(
certfile: Optional[str],
passphrase: Optional[str],
@ -65,10 +83,15 @@ if HAVE_SSL:
allow_invalid_certificates: bool,
allow_invalid_hostnames: bool,
disable_ocsp_endpoint_check: bool,
) -> _ssl.SSLContext:
is_sync: bool,
) -> Union[_pyssl.SSLContext, _ssl.SSLContext]: # type: ignore[name-defined]
"""Create and return an SSLContext object."""
if is_sync and HAVE_PYSSL:
ssl: types.ModuleType = _pyssl
else:
ssl = _ssl
verify_mode = CERT_NONE if allow_invalid_certificates else CERT_REQUIRED
ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23)
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
if verify_mode != CERT_NONE:
ctx.check_hostname = not allow_invalid_hostnames
else:
@ -80,22 +103,20 @@ if HAVE_SSL:
# up to date versions of MongoDB 2.4 and above already disable
# SSLv2 and SSLv3, python disables SSLv2 by default in >= 2.7.7
# and >= 3.3.4 and SSLv3 in >= 3.4.3.
ctx.options |= _ssl.OP_NO_SSLv2
ctx.options |= _ssl.OP_NO_SSLv3
ctx.options |= _ssl.OP_NO_COMPRESSION
ctx.options |= _ssl.OP_NO_RENEGOTIATION
ctx.options |= ssl.OP_NO_SSLv2
ctx.options |= ssl.OP_NO_SSLv3
ctx.options |= ssl.OP_NO_COMPRESSION
ctx.options |= ssl.OP_NO_RENEGOTIATION
if certfile is not None:
try:
ctx.load_cert_chain(certfile, None, passphrase)
except _ssl.SSLError as exc:
except ssl.SSLError as exc:
raise ConfigurationError(f"Private key doesn't match certificate: {exc}") from None
if crlfile is not None:
if _ssl.IS_PYOPENSSL:
if ssl.IS_PYOPENSSL:
raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL")
# Match the server's behavior.
ctx.verify_flags = getattr( # type:ignore[attr-defined]
_ssl, "VERIFY_CRL_CHECK_LEAF", 0
)
ctx.verify_flags = getattr(ssl, "VERIFY_CRL_CHECK_LEAF", 0)
ctx.load_verify_locations(crlfile)
if ca_certs is not None:
ctx.load_verify_locations(ca_certs)
@ -109,9 +130,11 @@ else:
class SSLError(Exception): # type: ignore
pass
HAS_SNI = False
IPADDR_SAFE = False
BLOCKING_IO_ERRORS = () # type:ignore[assignment]
BLOCKING_IO_ERRORS = ()
def _has_sni(is_sync: bool) -> bool: # noqa: ARG001
return False
def get_ssl_context(*dummy): # type: ignore
"""No ssl module, raise ConfigurationError."""

View File

@ -86,7 +86,7 @@ from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser_shared import parse_host
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
@ -156,6 +156,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
self.mongocryptd_client = mongocryptd_client
self.opts = opts
self._spawned = False
self._kms_ssl_contexts = opts._kms_ssl_contexts(_IS_SYNC)
def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
"""Complete a KMS request.
@ -167,7 +168,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
endpoint = kms_context.endpoint
message = kms_context.message
provider = kms_context.kms_provider
ctx = self.opts._kms_ssl_contexts.get(provider)
ctx = self._kms_ssl_contexts.get(provider)
if ctx is None:
# Enable strict certificate verification, OCSP, match hostname, and
# SNI using the system default CA certificates.
@ -179,6 +180,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False, # disable_ocsp_endpoint_check
_IS_SYNC,
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
@ -393,6 +395,8 @@ class _Encrypter:
encrypted_fields_map = _dict_to_bson(opts._encrypted_fields_map, False, _DATA_KEY_OPTS)
self._bypass_auto_encryption = opts._bypass_auto_encryption
self._internal_client = None
# parsing kms_ssl_contexts here so that parsing errors will be raised before internal clients are created
opts._kms_ssl_contexts(_IS_SYNC)
def _get_internal_client(
encrypter: _Encrypter, mongo_client: MongoClient[_DocumentTypeArg]
@ -668,6 +672,7 @@ class ClientEncryption(Generic[_DocumentType]):
kms_tls_options=kms_tls_options,
key_expiration_ms=key_expiration_ms,
)
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(
None, key_vault_coll, None, opts
)

View File

@ -101,6 +101,7 @@ from pymongo.operations import (
)
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, uri_parser
@ -777,7 +778,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
seeds = set()
self._seeds = set()
is_srv = False
username = None
password = None
@ -802,18 +803,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
self._seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
dbase = res["database"] or dbase
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
self._seeds.update(split_hosts(entity, self._port))
if not self._seeds:
raise ConfigurationError("need to specify at least one host")
for hostname in [node[0] for node in seeds]:
for hostname in [node[0] for node in self._seeds]:
if _detect_external_db(hostname):
break
@ -836,7 +837,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
opts = self._normalize_and_validate_options(opts, self._seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
@ -855,7 +856,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"seeds": self._seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
@ -871,8 +872,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
self._options.read_concern,
)
if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
@ -973,6 +973,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
if self._options.auto_encryption_opts:
from pymongo.synchronous.encryption import _Encrypter
@ -1203,6 +1204,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 4.0
"""
if self._topology is None:
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
return TopologyDescription(
TOPOLOGY_TYPE.Unknown,
servers,
None,
None,
None,
self._topology_settings,
)
return self._topology.description
@property
@ -1216,6 +1227,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
to any servers, or a network partition causes it to lose connection
to all servers.
"""
if self._topology is None:
return frozenset()
description = self._topology.description
return frozenset(s.address for s in description.known_servers)
@ -1570,6 +1583,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
"""
if self._topology is None:
self._get_topology()
topology_type = self._topology._description.topology_type
if (
topology_type == TOPOLOGY_TYPE.Sharded
@ -1592,6 +1607,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
MongoClient gained this property in version 3.0.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_primary() # type: ignore[return-value]
@property
@ -1605,6 +1622,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionadded:: 3.0
MongoClient gained this property in version 3.0.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_secondaries()
@property
@ -1615,6 +1634,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
connected to a replica set, there are no arbiters, or this client was
created without the `replicaSet` option.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_arbiters()
@property

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import logging
@ -72,6 +73,7 @@ from pymongo.monitoring import (
from pymongo.network_layer import NetworkingInterface, receive_message, sendall
from pymongo.pool_options import PoolOptions
from pymongo.pool_shared import (
SSLErrors,
_CancellationContext,
_configured_socket_interface,
_get_timeout_details,
@ -82,7 +84,6 @@ from pymongo.read_preferences import ReadPreference
from pymongo.server_api import _add_to_command
from pymongo.server_type import SERVER_TYPE
from pymongo.socket_checker import SocketChecker
from pymongo.ssl_support import SSLError
from pymongo.synchronous.client_session import _validate_session_write_concern
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.synchronous.network import command
@ -635,7 +636,7 @@ class Connection:
reason = ConnectionClosedReason.ERROR
self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)
else:
@ -858,8 +859,14 @@ class Pool:
# PoolClosedEvent but that reset() SHOULD close sockets *after*
# publishing the PoolClearedEvent.
if close:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
@ -889,8 +896,14 @@ class Pool:
serverPort=self.address[1],
serviceId=service_id,
)
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
@ -934,8 +947,14 @@ class Pool:
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
close_conns.append(self.conns.pop())
for conn in close_conns:
conn.close_conn(ConnectionClosedReason.IDLE)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
return_exceptions=True,
)
else:
for conn in close_conns:
conn.close_conn(ConnectionClosedReason.IDLE)
while True:
with self.size_cond:
@ -1029,7 +1048,7 @@ class Pool:
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
error=ConnectionClosedReason.ERROR,
)
if isinstance(error, (IOError, OSError, SSLError)):
if isinstance(error, (IOError, OSError, *SSLErrors)):
details = _get_timeout_details(self.opts)
_raise_connection_failure(self.address, error, timeout_details=details)

View File

@ -51,6 +51,7 @@ class TopologySettings:
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
"""Represent MongoClient's configuration.
@ -78,8 +79,10 @@ class TopologySettings:
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode
self._topology_id = ObjectId()
if topology_id is not None:
self._topology_id = topology_id
else:
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack()[:-2])

View File

@ -96,6 +96,7 @@ class _SrvResolver:
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
self.nparts = len(split_fqdn)
def get_options(self) -> Optional[str]:
from dns import resolver
@ -137,12 +138,13 @@ class _SrvResolver:
# Validate hosts
for node in nodes:
if self.__fqdn == node[0].lower():
srv_host = node[0].lower()
if self.__fqdn == srv_host and self.nparts < 3:
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
nlist = srv_host.split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
if self.__plist != nlist:

View File

@ -529,12 +529,6 @@ class Topology:
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
server.pool.reset(interrupt_connections=interrupt_connections)
# Wake anything waiting in select_servers().
self._condition.notify_all()
@ -557,6 +551,11 @@ class Topology:
# that didn't include this server.
if self._opened and self._description.has_server(server_description.address):
self._process_change(server_description, reset_pool, interrupt_connections)
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
server.pool.reset(interrupt_connections=interrupt_connections)
def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new seedlist on an opened topology.

View File

@ -420,7 +420,10 @@ def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
def _parse_kms_tls_options(
kms_tls_options: Optional[Mapping[str, Any]],
is_sync: bool,
) -> dict[str, SSLContext]:
"""Parse KMS TLS connection options."""
if not kms_tls_options:
return {}
@ -435,7 +438,7 @@ def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict
opts = _handle_security_options(opts)
opts = _normalize_options(opts)
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts, is_sync)
if ssl_context is None:
raise ConfigurationError("TLS is required for KMS providers")
if allow_invalid_hostnames:

View File

@ -823,6 +823,14 @@ class ClientContext:
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
)
def require_async(self, func):
"""Run a test only if using the asynchronous API.""" # unasync: off
return self._require(
lambda: not _IS_SYNC,
"This test only works with the asynchronous API", # unasync: off
func=func,
)
def mongos_seeds(self):
return ",".join("{}:{}".format(*address) for address in self.mongoses)

View File

@ -825,6 +825,14 @@ class AsyncClientContext:
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
)
def require_async(self, func):
"""Run a test only if using the asynchronous API.""" # unasync: off
return self._require(
lambda: not _IS_SYNC,
"This test only works with the asynchronous API", # unasync: off
func=func,
)
def mongos_seeds(self):
return ",".join("{}:{}".format(*address) for address in self.mongoses)

View File

@ -849,6 +849,58 @@ class TestClient(AsyncIntegrationTest):
with self.assertRaises(ConnectionFailure):
await c.pymongo_test.test.find_one()
@async_client_context.require_replica_set
@async_client_context.require_no_load_balancer
@async_client_context.require_tls
async def test_init_disconnected_with_srv(self):
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# nodes returns an empty set if not connected
self.assertEqual(c.nodes, frozenset())
# topology_description returns the initial seed description if not connected
topology_description = c.topology_description
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
self.assertEqual(
{
("test1.test.build.10gen.cc", None): ServerDescription(
("test1.test.build.10gen.cc", None)
)
},
topology_description.server_descriptions(),
)
# address causes client to block until connected
self.assertIsNotNone(await c.address)
# Initial seed topology and connected topology have the same ID
self.assertEqual(
c._topology._topology_id, topology_description._topology_settings._topology_id
)
await c.close()
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# primary causes client to block until connected
await c.primary
self.assertIsNotNone(c._topology)
await c.close()
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# secondaries causes client to block until connected
await c.secondaries
self.assertIsNotNone(c._topology)
await c.close()
c = await self.async_rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# arbiters causes client to block until connected
await c.arbiters
self.assertIsNotNone(c._topology)
async def test_equality(self):
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = await self.async_rs_or_single_client(seed, connect=False)

View File

@ -20,10 +20,15 @@ import os
import socketserver
import sys
import threading
import time
from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.asynchronous.helpers import ConcurrentRunner
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
sys.path[0:0] = [""]
from test.asynchronous import (
@ -370,6 +375,74 @@ class TestPoolManagement(AsyncIntegrationTest):
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
@async_client_context.require_failCommand_appName
@async_client_context.require_test_commands
@async_client_context.require_async
async def test_connection_close_does_not_block_other_operations(self):
listener = CMAPHeartbeatListener()
client = await self.async_single_client(
appName="SDAMConnectionCloseTest",
event_listeners=[listener],
heartbeatFrequencyMS=500,
minPoolSize=10,
)
server = await (await client._get_topology()).select_server(
writable_server_selector, _Op.TEST
)
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections",
)
await client.db.test.insert_one({"x": 1})
close_delay = 0.1
latencies = []
should_exit = []
async def run_task():
while True:
start_time = time.monotonic()
await client.db.test.find_one({})
elapsed = time.monotonic() - start_time
latencies.append(elapsed)
if should_exit:
break
await asyncio.sleep(0.001)
task = ConcurrentRunner(target=run_task)
await task.start()
original_close = AsyncConnection.close_conn
try:
# Artificially delay the close operation to simulate a slow close
async def mock_close(self, reason):
await asyncio.sleep(close_delay)
await original_close(self, reason)
AsyncConnection.close_conn = mock_close
fail_hello = {
"mode": {"times": 4},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"errorCode": 91,
"appName": "SDAMConnectionCloseTest",
},
}
async with self.fail_point(fail_hello):
# Wait for server heartbeat to fail
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
# Wait until all idle connections are closed to simulate real-world conditions
await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10)
# Wait for one more find to complete after the pool has been reset, then shutdown the task
n = len(latencies)
await async_wait_until(lambda: len(latencies) >= n + 1, "run one more find")
should_exit.append(True)
await task.join()
# No operation latency should not significantly exceed close_delay
self.assertLessEqual(max(latencies), close_delay * 5.0)
finally:
AsyncConnection.close_conn = original_close
class TestServerMonitoringMode(AsyncIntegrationTest):
@async_client_context.require_no_serverless

View File

@ -220,12 +220,15 @@ class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase):
mock_resolver.side_effect = mock_resolve
domain = case["query"].split("._tcp.")[1]
connection_string = f"mongodb+srv://{domain}"
try:
if "expected_error" not in case:
await parse_uri(connection_string)
except ConfigurationError as e:
self.assertIn(case["expected_error"], str(e))
else:
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
try:
await parse_uri(connection_string)
except ConfigurationError as e:
self.assertIn(case["expected_error"], str(e))
else:
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
async def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
with patch("dns.asyncresolver.resolve"):
@ -289,6 +292,17 @@ class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase):
]
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
async def test_5_when_srv_hostname_has_two_dot_separated_parts_it_is_valid_for_the_returned_hostname_to_be_identical(
self
):
test_cases = [
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "blogs.mongodb.com",
},
]
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
if __name__ == "__main__":
unittest.main()

View File

@ -41,6 +41,7 @@ import pytest
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.helpers import anext
from pymongo.daemon import _spawn_daemon
from pymongo.uri_parser_shared import _parse_kms_tls_options
try:
from pymongo.pyopenssl_context import IS_PYOPENSSL
@ -141,7 +142,7 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
self.assertEqual(opts._mongocryptd_bypass_spawn, False)
self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd")
self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"])
self.assertEqual(opts._kms_ssl_contexts, {})
self.assertEqual(opts._kms_tls_options, None)
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def test_init_spawn_args(self):
@ -165,30 +166,38 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
)
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def test_init_kms_tls_options(self):
async def test_init_kms_tls_options(self):
# Error cases:
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
with self.assertRaisesRegex(TypeError, r'kms_tls_options\["kmip"\] must be a dict'):
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
AsyncMongoClient(auto_encryption_opts=opts)
tls_opts: Any
for tls_opts in [
{"kmip": {"tls": True, "tlsInsecure": True}},
{"kmip": {"tls": True, "tlsAllowInvalidCertificates": True}},
{"kmip": {"tls": True, "tlsAllowInvalidHostnames": True}},
]:
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
with self.assertRaisesRegex(ConfigurationError, "Insecure TLS options prohibited"):
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
AsyncMongoClient(auto_encryption_opts=opts)
opts = AutoEncryptionOpts(
{}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}}
)
with self.assertRaises(FileNotFoundError):
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}})
AsyncMongoClient(auto_encryption_opts=opts)
# Success cases:
tls_opts: Any
for tls_opts in [None, {}]:
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
self.assertEqual(opts._kms_ssl_contexts, {})
kms_tls_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
self.assertEqual(kms_tls_contexts, {})
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}})
ctx = opts._kms_ssl_contexts["kmip"]
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
ctx = _kms_ssl_contexts["kmip"]
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
ctx = opts._kms_ssl_contexts["aws"]
ctx = _kms_ssl_contexts["aws"]
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
opts = AutoEncryptionOpts(
@ -196,7 +205,8 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
"k.d",
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
)
ctx = opts._kms_ssl_contexts["kmip"]
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
ctx = _kms_ssl_contexts["kmip"]
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
@ -2225,7 +2235,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
)
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
ctx = encryption._io_callbacks._kms_ssl_contexts["aws"]
if not hasattr(ctx, "check_ocsp_endpoint"):
raise self.skipTest("OCSP not enabled")
self.assertFalse(ctx.check_ocsp_endpoint)

View File

@ -43,7 +43,7 @@ from urllib.parse import quote_plus
from pymongo import AsyncMongoClient, ssl_support
from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure
from pymongo.hello import HelloCompat
from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context
from pymongo.ssl_support import HAVE_PYSSL, HAVE_SSL, _ssl, get_ssl_context
from pymongo.write_concern import WriteConcern
_HAVE_PYOPENSSL = False
@ -134,7 +134,7 @@ class TestClientSSL(AsyncPyMongoTestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_use_pyopenssl_when_available(self):
self.assertTrue(_ssl.IS_PYOPENSSL)
self.assertTrue(HAVE_PYSSL)
@unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL")
def test_load_trusted_ca_certs(self):
@ -177,7 +177,7 @@ class TestSSL(AsyncIntegrationTest):
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
if not hasattr(ssl, "SSLContext") and not HAVE_PYSSL:
self.assertRaises(
ConfigurationError,
self.simple_client,
@ -309,13 +309,13 @@ class TestSSL(AsyncIntegrationTest):
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
ctx = get_ssl_context(None, None, None, None, True, True, False)
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, True, False, False)
ctx = get_ssl_context(None, None, None, None, True, False, False, _IS_SYNC)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, False, True, False)
ctx = get_ssl_context(None, None, None, None, False, True, False, _IS_SYNC)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
self.assertTrue(ctx.check_hostname)
response = await self.client.admin.command(HelloCompat.LEGACY_CMD)
@ -376,9 +376,11 @@ class TestSSL(AsyncIntegrationTest):
)
@async_client_context.require_tlsCertificateKeyFile
@async_client_context.require_sync
@async_client_context.require_no_api_version
@ignore_deprecations
async def test_tlsCRLFile_support(self):
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or HAVE_PYSSL:
self.assertRaises(
ConfigurationError,
self.simple_client,
@ -469,7 +471,7 @@ class TestSSL(AsyncIntegrationTest):
)
def test_system_certs_config_error(self):
ctx = get_ssl_context(None, None, None, None, True, True, False)
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr(
ctx, "load_default_certs"
):
@ -500,11 +502,11 @@ class TestSSL(AsyncIntegrationTest):
# Force the test on Windows, regardless of environment.
ssl_support.HAVE_WINCERTSTORE = False
try:
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where())
finally:
@ -521,11 +523,11 @@ class TestSSL(AsyncIntegrationTest):
if not ssl_support.HAVE_WINCERTSTORE:
raise SkipTest("Need wincertstore to test wincertstore.")
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name)
@ -657,6 +659,16 @@ class TestSSL(AsyncIntegrationTest):
) as client:
self.assertTrue(await client.admin.command("ping"))
@async_client_context.require_async
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
async def test_pyopenssl_ignored_in_async(self):
client = AsyncMongoClient(
"mongodb://localhost:27017?tls=true&tlsAllowInvalidCertificates=true"
)
await client.admin.command("ping") # command doesn't matter, just needs it to connect
await client.close()
if __name__ == "__main__":
unittest.main()

View File

@ -26,7 +26,7 @@ import pytest
sys.path[0:0] = [""]
import pymongo
from pymongo.ssl_support import HAS_SNI
from pymongo.ssl_support import _has_sni
pytestmark = pytest.mark.atlas_connect
@ -57,7 +57,7 @@ class TestAtlasConnect(PyMongoTestCase):
# No auth error
client.test.test.count_documents({})
@unittest.skipUnless(HAS_SNI, "Free tier requires SNI support")
@unittest.skipUnless(_has_sni(True), "Free tier requires SNI support")
def test_free_tier(self):
self.connect(URIS["ATLAS_FREE"])
@ -80,7 +80,7 @@ class TestAtlasConnect(PyMongoTestCase):
self.connect(uri)
self.assertIn("mongodb+srv://", uri)
@unittest.skipUnless(HAS_SNI, "Free tier requires SNI support")
@unittest.skipUnless(_has_sni(True), "Free tier requires SNI support")
def test_srv_free_tier(self):
self.connect_srv(URIS["ATLAS_SRV_FREE"])

View File

@ -824,6 +824,58 @@ class TestClient(IntegrationTest):
with self.assertRaises(ConnectionFailure):
c.pymongo_test.test.find_one()
@client_context.require_replica_set
@client_context.require_no_load_balancer
@client_context.require_tls
def test_init_disconnected_with_srv(self):
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# nodes returns an empty set if not connected
self.assertEqual(c.nodes, frozenset())
# topology_description returns the initial seed description if not connected
topology_description = c.topology_description
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
self.assertEqual(
{
("test1.test.build.10gen.cc", None): ServerDescription(
("test1.test.build.10gen.cc", None)
)
},
topology_description.server_descriptions(),
)
# address causes client to block until connected
self.assertIsNotNone(c.address)
# Initial seed topology and connected topology have the same ID
self.assertEqual(
c._topology._topology_id, topology_description._topology_settings._topology_id
)
c.close()
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# primary causes client to block until connected
c.primary
self.assertIsNotNone(c._topology)
c.close()
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# secondaries causes client to block until connected
c.secondaries
self.assertIsNotNone(c._topology)
c.close()
c = self.rs_or_single_client(
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
)
# arbiters causes client to block until connected
c.arbiters
self.assertIsNotNone(c._topology)
def test_equality(self):
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = self.rs_or_single_client(seed, connect=False)

View File

@ -209,6 +209,19 @@ class TestDefaultExports(unittest.TestCase):
)
from pymongo.write_concern import WriteConcern, validate_boolean
def test_pymongo_submodule_attributes(self):
import pymongo
self.assertTrue(hasattr(pymongo, "uri_parser"))
self.assertTrue(pymongo.uri_parser)
self.assertTrue(pymongo.uri_parser.parse_uri)
self.assertTrue(pymongo.change_stream)
self.assertTrue(pymongo.client_session)
self.assertTrue(pymongo.collection)
self.assertTrue(pymongo.cursor)
self.assertTrue(pymongo.command_cursor)
self.assertTrue(pymongo.database)
def test_gridfs_imports(self):
import gridfs
from gridfs.errors import CorruptGridFile, FileExists, GridFSError, NoFile

View File

@ -20,10 +20,15 @@ import os
import socketserver
import sys
import threading
import time
from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.helpers import ConcurrentRunner
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import Connection
sys.path[0:0] = [""]
from test import (
@ -370,6 +375,72 @@ class TestPoolManagement(IntegrationTest):
listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
@client_context.require_failCommand_appName
@client_context.require_test_commands
@client_context.require_async
def test_connection_close_does_not_block_other_operations(self):
listener = CMAPHeartbeatListener()
client = self.single_client(
appName="SDAMConnectionCloseTest",
event_listeners=[listener],
heartbeatFrequencyMS=500,
minPoolSize=10,
)
server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST)
wait_until(
lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections",
)
client.db.test.insert_one({"x": 1})
close_delay = 0.1
latencies = []
should_exit = []
def run_task():
while True:
start_time = time.monotonic()
client.db.test.find_one({})
elapsed = time.monotonic() - start_time
latencies.append(elapsed)
if should_exit:
break
time.sleep(0.001)
task = ConcurrentRunner(target=run_task)
task.start()
original_close = Connection.close_conn
try:
# Artificially delay the close operation to simulate a slow close
def mock_close(self, reason):
time.sleep(close_delay)
original_close(self, reason)
Connection.close_conn = mock_close
fail_hello = {
"mode": {"times": 4},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"errorCode": 91,
"appName": "SDAMConnectionCloseTest",
},
}
with self.fail_point(fail_hello):
# Wait for server heartbeat to fail
listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
# Wait until all idle connections are closed to simulate real-world conditions
listener.wait_for_event(monitoring.ConnectionClosedEvent, 10)
# Wait for one more find to complete after the pool has been reset, then shutdown the task
n = len(latencies)
wait_until(lambda: len(latencies) >= n + 1, "run one more find")
should_exit.append(True)
task.join()
# No operation latency should not significantly exceed close_delay
self.assertLessEqual(max(latencies), close_delay * 5.0)
finally:
Connection.close_conn = original_close
class TestServerMonitoringMode(IntegrationTest):
@client_context.require_no_serverless

View File

@ -218,12 +218,15 @@ class TestInitialDnsSeedlistDiscovery(PyMongoTestCase):
mock_resolver.side_effect = mock_resolve
domain = case["query"].split("._tcp.")[1]
connection_string = f"mongodb+srv://{domain}"
try:
if "expected_error" not in case:
parse_uri(connection_string)
except ConfigurationError as e:
self.assertIn(case["expected_error"], str(e))
else:
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
try:
parse_uri(connection_string)
except ConfigurationError as e:
self.assertIn(case["expected_error"], str(e))
else:
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
with patch("dns.resolver.resolve"):
@ -287,6 +290,17 @@ class TestInitialDnsSeedlistDiscovery(PyMongoTestCase):
]
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
def test_5_when_srv_hostname_has_two_dot_separated_parts_it_is_valid_for_the_returned_hostname_to_be_identical(
self
):
test_cases = [
{
"query": "_mongodb._tcp.blogs.mongodb.com",
"mock_target": "blogs.mongodb.com",
},
]
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
if __name__ == "__main__":
unittest.main()

View File

@ -41,6 +41,7 @@ import pytest
from pymongo.daemon import _spawn_daemon
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.helpers import next
from pymongo.uri_parser_shared import _parse_kms_tls_options
try:
from pymongo.pyopenssl_context import IS_PYOPENSSL
@ -141,7 +142,7 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
self.assertEqual(opts._mongocryptd_bypass_spawn, False)
self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd")
self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"])
self.assertEqual(opts._kms_ssl_contexts, {})
self.assertEqual(opts._kms_tls_options, None)
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def test_init_spawn_args(self):
@ -167,28 +168,36 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def test_init_kms_tls_options(self):
# Error cases:
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
with self.assertRaisesRegex(TypeError, r'kms_tls_options\["kmip"\] must be a dict'):
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
MongoClient(auto_encryption_opts=opts)
tls_opts: Any
for tls_opts in [
{"kmip": {"tls": True, "tlsInsecure": True}},
{"kmip": {"tls": True, "tlsAllowInvalidCertificates": True}},
{"kmip": {"tls": True, "tlsAllowInvalidHostnames": True}},
]:
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
with self.assertRaisesRegex(ConfigurationError, "Insecure TLS options prohibited"):
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
MongoClient(auto_encryption_opts=opts)
opts = AutoEncryptionOpts(
{}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}}
)
with self.assertRaises(FileNotFoundError):
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}})
MongoClient(auto_encryption_opts=opts)
# Success cases:
tls_opts: Any
for tls_opts in [None, {}]:
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
self.assertEqual(opts._kms_ssl_contexts, {})
kms_tls_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
self.assertEqual(kms_tls_contexts, {})
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}})
ctx = opts._kms_ssl_contexts["kmip"]
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
ctx = _kms_ssl_contexts["kmip"]
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
ctx = opts._kms_ssl_contexts["aws"]
ctx = _kms_ssl_contexts["aws"]
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
opts = AutoEncryptionOpts(
@ -196,7 +205,8 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
"k.d",
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
)
ctx = opts._kms_ssl_contexts["kmip"]
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
ctx = _kms_ssl_contexts["kmip"]
self.assertEqual(ctx.check_hostname, True)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
@ -2217,7 +2227,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
)
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
ctx = encryption._io_callbacks._kms_ssl_contexts["aws"]
if not hasattr(ctx, "check_ocsp_endpoint"):
raise self.skipTest("OCSP not enabled")
self.assertFalse(ctx.check_ocsp_endpoint)

View File

@ -43,7 +43,7 @@ from urllib.parse import quote_plus
from pymongo import MongoClient, ssl_support
from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure
from pymongo.hello import HelloCompat
from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context
from pymongo.ssl_support import HAVE_PYSSL, HAVE_SSL, _ssl, get_ssl_context
from pymongo.write_concern import WriteConcern
_HAVE_PYOPENSSL = False
@ -134,7 +134,7 @@ class TestClientSSL(PyMongoTestCase):
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
def test_use_pyopenssl_when_available(self):
self.assertTrue(_ssl.IS_PYOPENSSL)
self.assertTrue(HAVE_PYSSL)
@unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL")
def test_load_trusted_ca_certs(self):
@ -177,7 +177,7 @@ class TestSSL(IntegrationTest):
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
if not hasattr(ssl, "SSLContext") and not HAVE_PYSSL:
self.assertRaises(
ConfigurationError,
self.simple_client,
@ -309,13 +309,13 @@ class TestSSL(IntegrationTest):
#
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
ctx = get_ssl_context(None, None, None, None, True, True, False)
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, True, False, False)
ctx = get_ssl_context(None, None, None, None, True, False, False, _IS_SYNC)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, False, True, False)
ctx = get_ssl_context(None, None, None, None, False, True, False, _IS_SYNC)
self.assertFalse(ctx.check_hostname)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
self.assertTrue(ctx.check_hostname)
response = self.client.admin.command(HelloCompat.LEGACY_CMD)
@ -376,9 +376,11 @@ class TestSSL(IntegrationTest):
)
@client_context.require_tlsCertificateKeyFile
@client_context.require_sync
@client_context.require_no_api_version
@ignore_deprecations
def test_tlsCRLFile_support(self):
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or HAVE_PYSSL:
self.assertRaises(
ConfigurationError,
self.simple_client,
@ -469,7 +471,7 @@ class TestSSL(IntegrationTest):
)
def test_system_certs_config_error(self):
ctx = get_ssl_context(None, None, None, None, True, True, False)
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr(
ctx, "load_default_certs"
):
@ -500,11 +502,11 @@ class TestSSL(IntegrationTest):
# Force the test on Windows, regardless of environment.
ssl_support.HAVE_WINCERTSTORE = False
try:
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where())
finally:
@ -521,11 +523,11 @@ class TestSSL(IntegrationTest):
if not ssl_support.HAVE_WINCERTSTORE:
raise SkipTest("Need wincertstore to test wincertstore.")
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
ctx = get_ssl_context(None, None, None, None, False, False, False)
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
ssl_sock = ctx.wrap_socket(socket.socket())
self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name)
@ -657,6 +659,14 @@ class TestSSL(IntegrationTest):
) as client:
self.assertTrue(client.admin.command("ping"))
@client_context.require_async
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
def test_pyopenssl_ignored_in_async(self):
client = MongoClient("mongodb://localhost:27017?tls=true&tlsAllowInvalidCertificates=true")
client.admin.command("ping") # command doesn't matter, just needs it to connect
client.close()
if __name__ == "__main__":
unittest.main()

View File

@ -35,6 +35,7 @@ def check_ocsp(host: str, port: int, capath: str) -> None:
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False,
True, # is sync
) # disable_ocsp_endpoint_check
# Ensure we're using pyOpenSSL.

View File

@ -288,7 +288,8 @@ def process_files(
if file in docstring_translate_files:
lines = translate_docstrings(lines)
if file in sync_test_files:
translate_imports(lines)
lines = translate_imports(lines)
lines = process_ignores(lines)
f.seek(0)
f.writelines(lines)
f.truncate()
@ -390,6 +391,14 @@ def translate_docstrings(lines: list[str]) -> list[str]:
return [line for line in lines if line != "DOCSTRING_REMOVED"]
def process_ignores(lines: list[str]) -> list[str]:
for i in range(len(lines)):
for k, v in replacements.items():
if "unasync: off" in lines[i] and v in lines[i]:
lines[i] = lines[i].replace(v, k)
return lines
def unasync_directory(files: list[str], src: str, dest: str, replacements: dict[str, str]) -> None:
unasync_files(
files,

1
uv.lock generated
View File

@ -998,6 +998,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/07/e9/ae44ea7d7605df9e5
[[package]]
name = "pymongo"
version = "4.12.0"
source = { editable = "." }
dependencies = [
{ name = "dnspython" },