Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d8a478f4c | ||
|
|
1d913c9ee8 | ||
|
|
84d0d3db4d | ||
|
|
c52a456fd2 | ||
|
|
dd99f80ce3 | ||
|
|
fecd29c1f8 | ||
|
|
c11d0f4def | ||
|
|
f5836b3f6f | ||
|
|
38bc13db9c | ||
|
|
c6671e239a | ||
|
|
79cb34a7b7 | ||
|
|
c83784610e | ||
|
|
3377510a63 | ||
|
|
6b3e254332 | ||
|
|
de0e2336cb |
@ -735,18 +735,20 @@ buildvariants:
|
||||
- macos-14
|
||||
batchtime: 10080
|
||||
expansions:
|
||||
TEST_NAME: pyopenssl
|
||||
TEST_NAME: default
|
||||
SUB_TEST_NAME: pyopenssl
|
||||
PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3
|
||||
- name: pyopenssl-rhel8-python3.10
|
||||
tasks:
|
||||
- name: .replica_set .auth .ssl .sync
|
||||
- name: .7.0 .auth .ssl .sync
|
||||
- name: .replica_set .auth .ssl .sync_async
|
||||
- name: .7.0 .auth .ssl .sync_async
|
||||
display_name: PyOpenSSL RHEL8 Python3.10
|
||||
run_on:
|
||||
- rhel87-small
|
||||
batchtime: 10080
|
||||
expansions:
|
||||
TEST_NAME: pyopenssl
|
||||
TEST_NAME: default
|
||||
SUB_TEST_NAME: pyopenssl
|
||||
PYTHON_BINARY: /opt/python/3.10/bin/python3
|
||||
- name: pyopenssl-rhel8-python3.11
|
||||
tasks:
|
||||
@ -757,7 +759,8 @@ buildvariants:
|
||||
- rhel87-small
|
||||
batchtime: 10080
|
||||
expansions:
|
||||
TEST_NAME: pyopenssl
|
||||
TEST_NAME: default
|
||||
SUB_TEST_NAME: pyopenssl
|
||||
PYTHON_BINARY: /opt/python/3.11/bin/python3
|
||||
- name: pyopenssl-rhel8-python3.12
|
||||
tasks:
|
||||
@ -768,18 +771,20 @@ buildvariants:
|
||||
- rhel87-small
|
||||
batchtime: 10080
|
||||
expansions:
|
||||
TEST_NAME: pyopenssl
|
||||
TEST_NAME: default
|
||||
SUB_TEST_NAME: pyopenssl
|
||||
PYTHON_BINARY: /opt/python/3.12/bin/python3
|
||||
- name: pyopenssl-win64-python3.13
|
||||
tasks:
|
||||
- name: .replica_set .auth .ssl .sync
|
||||
- name: .7.0 .auth .ssl .sync
|
||||
- name: .replica_set .auth .ssl .sync_async
|
||||
- name: .7.0 .auth .ssl .sync_async
|
||||
display_name: PyOpenSSL Win64 Python3.13
|
||||
run_on:
|
||||
- windows-64-vsMulti-small
|
||||
batchtime: 10080
|
||||
expansions:
|
||||
TEST_NAME: pyopenssl
|
||||
TEST_NAME: default
|
||||
SUB_TEST_NAME: pyopenssl
|
||||
PYTHON_BINARY: C:/python/Python313/python.exe
|
||||
- name: pyopenssl-rhel8-pypy3.10
|
||||
tasks:
|
||||
@ -790,7 +795,8 @@ buildvariants:
|
||||
- rhel87-small
|
||||
batchtime: 10080
|
||||
expansions:
|
||||
TEST_NAME: pyopenssl
|
||||
TEST_NAME: default
|
||||
SUB_TEST_NAME: pyopenssl
|
||||
PYTHON_BINARY: /opt/python/pypy3.10/bin/python3
|
||||
|
||||
# Search index tests
|
||||
|
||||
@ -491,7 +491,7 @@ def create_enterprise_auth_variants():
|
||||
def create_pyopenssl_variants():
|
||||
base_name = "PyOpenSSL"
|
||||
batchtime = BATCHTIME_WEEK
|
||||
expansions = dict(TEST_NAME="pyopenssl")
|
||||
expansions = dict(TEST_NAME="default", SUB_TEST_NAME="pyopenssl")
|
||||
variants = []
|
||||
|
||||
for python in ALL_PYTHONS:
|
||||
@ -506,14 +506,25 @@ def create_pyopenssl_variants():
|
||||
host = DEFAULT_HOST
|
||||
|
||||
display_name = get_variant_name(base_name, host, python=python)
|
||||
variant = create_variant(
|
||||
[f".replica_set .{auth} .{ssl} .sync", f".7.0 .{auth} .{ssl} .sync"],
|
||||
display_name,
|
||||
python=python,
|
||||
host=host,
|
||||
expansions=expansions,
|
||||
batchtime=batchtime,
|
||||
)
|
||||
# only need to run some on async
|
||||
if python in (CPYTHONS[1], CPYTHONS[-1]):
|
||||
variant = create_variant(
|
||||
[f".replica_set .{auth} .{ssl} .sync_async", f".7.0 .{auth} .{ssl} .sync_async"],
|
||||
display_name,
|
||||
python=python,
|
||||
host=host,
|
||||
expansions=expansions,
|
||||
batchtime=batchtime,
|
||||
)
|
||||
else:
|
||||
variant = create_variant(
|
||||
[f".replica_set .{auth} .{ssl} .sync", f".7.0 .{auth} .{ssl} .sync"],
|
||||
display_name,
|
||||
python=python,
|
||||
host=host,
|
||||
expansions=expansions,
|
||||
batchtime=batchtime,
|
||||
)
|
||||
variants.append(variant)
|
||||
|
||||
return variants
|
||||
|
||||
1
.github/workflows/codeql.yml
vendored
1
.github/workflows/codeql.yml
vendored
@ -54,7 +54,6 @@ jobs:
|
||||
queries: security-extended
|
||||
config: |
|
||||
paths-ignore:
|
||||
- '.github/**'
|
||||
- 'doc/**'
|
||||
- 'tools/**'
|
||||
- 'test/**'
|
||||
|
||||
2
.github/workflows/release-python.yml
vendored
2
.github/workflows/release-python.yml
vendored
@ -20,7 +20,7 @@ env:
|
||||
# Changes per repo
|
||||
PRODUCT_NAME: PyMongo
|
||||
# Changes per branch
|
||||
EVERGREEN_PROJECT: mongo-python-driver
|
||||
EVERGREEN_PROJECT: mongo-python-driver-prev-rel
|
||||
# Constant
|
||||
# inputs will be empty on a scheduled run. so, we only set dry_run
|
||||
# to 'false' when the input is set to 'false'.
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
.. autodata:: MD5_SUBTYPE
|
||||
.. autodata:: COLUMN_SUBTYPE
|
||||
.. autodata:: SENSITIVE_SUBTYPE
|
||||
.. autodata:: VECTOR_SUBTYPE
|
||||
.. autodata:: USER_DEFINED_SUBTYPE
|
||||
|
||||
.. autoclass:: UuidRepresentation
|
||||
|
||||
@ -1,6 +1,31 @@
|
||||
Changelog
|
||||
=========
|
||||
|
||||
Changes in Version 4.12.1 (2025/04/29)
|
||||
--------------------------------------
|
||||
|
||||
Version 4.12.1 is a bug fix release.
|
||||
|
||||
- Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL.
|
||||
- Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels.
|
||||
- Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``.
|
||||
- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string
|
||||
could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description,
|
||||
nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type
|
||||
errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()".
|
||||
- Fixed a bug where MongoDB cluster topology changes could cause asynchronous operations to take much longer to complete
|
||||
due to holding the Topology lock while closing stale connections.
|
||||
- Fixed a bug that would cause AsyncMongoClient to attempt to use PyOpenSSL when available, resulting in errors such as
|
||||
"pymongo.errors.ServerSelectionTimeoutError: 'SSLContext' object has no attribute 'wrap_bio'".
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
|
||||
See the `PyMongo 4.12.1 release notes in JIRA`_ for the list of resolved issues
|
||||
in this release.
|
||||
|
||||
.. _PyMongo 4.12.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=43094
|
||||
|
||||
Changes in Version 4.12.0 (2025/04/08)
|
||||
--------------------------------------
|
||||
|
||||
|
||||
@ -105,6 +105,16 @@ from pymongo.synchronous.collection import ReturnDocument
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
# Public module compatibility imports
|
||||
# isort: off
|
||||
from pymongo import uri_parser # noqa: F401
|
||||
from pymongo import change_stream # noqa: F401
|
||||
from pymongo import client_session # noqa: F401
|
||||
from pymongo import collection # noqa: F401
|
||||
from pymongo import command_cursor # noqa: F401
|
||||
from pymongo import database # noqa: F401
|
||||
# isort: on
|
||||
|
||||
version = __version__
|
||||
"""Current version of PyMongo."""
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import re
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
__version__ = "4.12.0"
|
||||
__version__ = "4.12.2.dev0"
|
||||
|
||||
|
||||
def get_version_tuple(version: str) -> Tuple[Union[int, str], ...]:
|
||||
|
||||
@ -87,7 +87,7 @@ from pymongo.read_concern import ReadConcern
|
||||
from pymongo.results import BulkWriteResult, DeleteResult
|
||||
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser_shared import parse_host
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -157,6 +157,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
self.mongocryptd_client = mongocryptd_client
|
||||
self.opts = opts
|
||||
self._spawned = False
|
||||
self._kms_ssl_contexts = opts._kms_ssl_contexts(_IS_SYNC)
|
||||
|
||||
async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
|
||||
"""Complete a KMS request.
|
||||
@ -168,7 +169,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
endpoint = kms_context.endpoint
|
||||
message = kms_context.message
|
||||
provider = kms_context.kms_provider
|
||||
ctx = self.opts._kms_ssl_contexts.get(provider)
|
||||
ctx = self._kms_ssl_contexts.get(provider)
|
||||
if ctx is None:
|
||||
# Enable strict certificate verification, OCSP, match hostname, and
|
||||
# SNI using the system default CA certificates.
|
||||
@ -180,6 +181,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
False, # allow_invalid_certificates
|
||||
False, # allow_invalid_hostnames
|
||||
False, # disable_ocsp_endpoint_check
|
||||
_IS_SYNC,
|
||||
)
|
||||
# CSOT: set timeout for socket creation.
|
||||
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
|
||||
@ -396,6 +398,8 @@ class _Encrypter:
|
||||
encrypted_fields_map = _dict_to_bson(opts._encrypted_fields_map, False, _DATA_KEY_OPTS)
|
||||
self._bypass_auto_encryption = opts._bypass_auto_encryption
|
||||
self._internal_client = None
|
||||
# parsing kms_ssl_contexts here so that parsing errors will be raised before internal clients are created
|
||||
opts._kms_ssl_contexts(_IS_SYNC)
|
||||
|
||||
def _get_internal_client(
|
||||
encrypter: _Encrypter, mongo_client: AsyncMongoClient[_DocumentTypeArg]
|
||||
@ -675,6 +679,7 @@ class AsyncClientEncryption(Generic[_DocumentType]):
|
||||
kms_tls_options=kms_tls_options,
|
||||
key_expiration_ms=key_expiration_ms,
|
||||
)
|
||||
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(
|
||||
None, key_vault_coll, None, opts
|
||||
)
|
||||
|
||||
@ -109,6 +109,7 @@ from pymongo.operations import (
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import ClientBulkWriteResult
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
|
||||
@ -779,7 +780,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
keyword_opts["document_class"] = doc_class
|
||||
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||
|
||||
seeds = set()
|
||||
self._seeds = set()
|
||||
is_srv = False
|
||||
username = None
|
||||
password = None
|
||||
@ -804,18 +805,18 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
is_srv = entity.startswith(SRV_SCHEME)
|
||||
seeds.update(res["nodelist"])
|
||||
self._seeds.update(res["nodelist"])
|
||||
username = res["username"] or username
|
||||
password = res["password"] or password
|
||||
dbase = res["database"] or dbase
|
||||
opts = res["options"]
|
||||
fqdn = res["fqdn"]
|
||||
else:
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
if not seeds:
|
||||
self._seeds.update(split_hosts(entity, self._port))
|
||||
if not self._seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
for hostname in [node[0] for node in seeds]:
|
||||
for hostname in [node[0] for node in self._seeds]:
|
||||
if _detect_external_db(hostname):
|
||||
break
|
||||
|
||||
@ -838,7 +839,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
opts = self._normalize_and_validate_options(opts, self._seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", username)
|
||||
@ -857,7 +858,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"username": username,
|
||||
"password": password,
|
||||
"dbase": dbase,
|
||||
"seeds": seeds,
|
||||
"seeds": self._seeds,
|
||||
"fqdn": fqdn,
|
||||
"srv_service_name": srv_service_name,
|
||||
"pool_class": pool_class,
|
||||
@ -873,8 +874,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._options.read_concern,
|
||||
)
|
||||
|
||||
if not is_srv:
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
@ -975,6 +975,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
|
||||
)
|
||||
if self._options.auto_encryption_opts:
|
||||
from pymongo.asynchronous.encryption import _Encrypter
|
||||
@ -1205,6 +1206,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
|
||||
return TopologyDescription(
|
||||
TOPOLOGY_TYPE.Unknown,
|
||||
servers,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self._topology_settings,
|
||||
)
|
||||
return self._topology.description
|
||||
|
||||
@property
|
||||
@ -1218,6 +1229,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
to any servers, or a network partition causes it to lose connection
|
||||
to all servers.
|
||||
"""
|
||||
if self._topology is None:
|
||||
return frozenset()
|
||||
description = self._topology.description
|
||||
return frozenset(s.address for s in description.known_servers)
|
||||
|
||||
@ -1576,6 +1589,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
topology_type = self._topology._description.topology_type
|
||||
if (
|
||||
topology_type == TOPOLOGY_TYPE.Sharded
|
||||
@ -1598,6 +1613,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
AsyncMongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
return await self._topology.get_primary() # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
@ -1611,6 +1628,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
AsyncMongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
return await self._topology.get_secondaries()
|
||||
|
||||
@property
|
||||
@ -1621,6 +1640,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
connected to a replica set, there are no arbiters, or this client was
|
||||
created without the `replicaSet` option.
|
||||
"""
|
||||
if self._topology is None:
|
||||
await self._get_topology()
|
||||
return await self._topology.get_arbiters()
|
||||
|
||||
@property
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import logging
|
||||
@ -75,6 +76,7 @@ from pymongo.monitoring import (
|
||||
from pymongo.network_layer import AsyncNetworkingInterface, async_receive_message, async_sendall
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.pool_shared import (
|
||||
SSLErrors,
|
||||
_CancellationContext,
|
||||
_configured_protocol_interface,
|
||||
_get_timeout_details,
|
||||
@ -85,7 +87,6 @@ from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.ssl_support import SSLError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bson import CodecOptions
|
||||
@ -637,7 +638,7 @@ class AsyncConnection:
|
||||
reason = ConnectionClosedReason.ERROR
|
||||
await self.close_conn(reason)
|
||||
# SSLError from PyOpenSSL inherits directly from Exception.
|
||||
if isinstance(error, (IOError, OSError, SSLError)):
|
||||
if isinstance(error, (IOError, OSError, *SSLErrors)):
|
||||
details = _get_timeout_details(self.opts)
|
||||
_raise_connection_failure(self.address, error, timeout_details=details)
|
||||
else:
|
||||
@ -860,8 +861,14 @@ class Pool:
|
||||
# PoolClosedEvent but that reset() SHOULD close sockets *after*
|
||||
# publishing the PoolClearedEvent.
|
||||
if close:
|
||||
for conn in sockets:
|
||||
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
for conn in sockets:
|
||||
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_pool_closed(self.address)
|
||||
@ -891,8 +898,14 @@ class Pool:
|
||||
serverPort=self.address[1],
|
||||
serviceId=service_id,
|
||||
)
|
||||
for conn in sockets:
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
for conn in sockets:
|
||||
await conn.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
|
||||
"""Updates the is_writable attribute on all sockets currently in the
|
||||
@ -938,8 +951,14 @@ class Pool:
|
||||
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
):
|
||||
close_conns.append(self.conns.pop())
|
||||
for conn in close_conns:
|
||||
await conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
for conn in close_conns:
|
||||
await conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
|
||||
while True:
|
||||
async with self.size_cond:
|
||||
@ -1033,7 +1052,7 @@ class Pool:
|
||||
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
|
||||
error=ConnectionClosedReason.ERROR,
|
||||
)
|
||||
if isinstance(error, (IOError, OSError, SSLError)):
|
||||
if isinstance(error, (IOError, OSError, *SSLErrors)):
|
||||
details = _get_timeout_details(self.opts)
|
||||
_raise_connection_failure(self.address, error, timeout_details=details)
|
||||
|
||||
|
||||
@ -51,6 +51,7 @@ class TopologySettings:
|
||||
srv_service_name: str = common.SRV_SERVICE_NAME,
|
||||
srv_max_hosts: int = 0,
|
||||
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
):
|
||||
"""Represent MongoClient's configuration.
|
||||
|
||||
@ -78,8 +79,10 @@ class TopologySettings:
|
||||
self._srv_service_name = srv_service_name
|
||||
self._srv_max_hosts = srv_max_hosts or 0
|
||||
self._server_monitoring_mode = server_monitoring_mode
|
||||
|
||||
self._topology_id = ObjectId()
|
||||
if topology_id is not None:
|
||||
self._topology_id = topology_id
|
||||
else:
|
||||
self._topology_id = ObjectId()
|
||||
# Store the allocation traceback to catch unclosed clients in the
|
||||
# test suite.
|
||||
self._stack = "".join(traceback.format_stack()[:-2])
|
||||
|
||||
@ -96,6 +96,7 @@ class _SrvResolver:
|
||||
except Exception:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
|
||||
self.__slen = len(self.__plist)
|
||||
self.nparts = len(split_fqdn)
|
||||
|
||||
async def get_options(self) -> Optional[str]:
|
||||
from dns import resolver
|
||||
@ -137,12 +138,13 @@ class _SrvResolver:
|
||||
|
||||
# Validate hosts
|
||||
for node in nodes:
|
||||
if self.__fqdn == node[0].lower():
|
||||
srv_host = node[0].lower()
|
||||
if self.__fqdn == srv_host and self.nparts < 3:
|
||||
raise ConfigurationError(
|
||||
"Invalid SRV host: return address is identical to SRV hostname"
|
||||
)
|
||||
try:
|
||||
nlist = node[0].lower().split(".")[1:][-self.__slen :]
|
||||
nlist = srv_host.split(".")[1:][-self.__slen :]
|
||||
except Exception:
|
||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
|
||||
if self.__plist != nlist:
|
||||
|
||||
@ -529,12 +529,6 @@ class Topology:
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
# Clear the pool from a failed heartbeat.
|
||||
if reset_pool:
|
||||
server = self._servers.get(server_description.address)
|
||||
if server:
|
||||
await server.pool.reset(interrupt_connections=interrupt_connections)
|
||||
|
||||
# Wake anything waiting in select_servers().
|
||||
self._condition.notify_all()
|
||||
|
||||
@ -557,6 +551,11 @@ class Topology:
|
||||
# that didn't include this server.
|
||||
if self._opened and self._description.has_server(server_description.address):
|
||||
await self._process_change(server_description, reset_pool, interrupt_connections)
|
||||
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
|
||||
if reset_pool:
|
||||
server = self._servers.get(server_description.address)
|
||||
if server:
|
||||
await server.pool.reset(interrupt_connections=interrupt_connections)
|
||||
|
||||
async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
|
||||
"""Process a new seedlist on an opened topology.
|
||||
|
||||
@ -84,7 +84,9 @@ def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern:
|
||||
return ReadConcern(concern)
|
||||
|
||||
|
||||
def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]:
|
||||
def _parse_ssl_options(
|
||||
options: Mapping[str, Any], is_sync: bool
|
||||
) -> tuple[Optional[SSLContext], bool]:
|
||||
"""Parse ssl options."""
|
||||
use_tls = options.get("tls")
|
||||
if use_tls is not None:
|
||||
@ -138,6 +140,7 @@ def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext]
|
||||
allow_invalid_certificates,
|
||||
allow_invalid_hostnames,
|
||||
disable_ocsp_endpoint_check,
|
||||
is_sync,
|
||||
)
|
||||
return ctx, allow_invalid_hostnames
|
||||
return None, allow_invalid_hostnames
|
||||
@ -167,7 +170,7 @@ def _parse_pool_options(
|
||||
compression_settings = CompressionSettings(
|
||||
options.get("compressors", []), options.get("zlibcompressionlevel", -1)
|
||||
)
|
||||
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options)
|
||||
ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options, is_sync)
|
||||
load_balanced = options.get("loadbalanced")
|
||||
max_connecting = options.get("maxconnecting", common.MAX_CONNECTING)
|
||||
return PoolOptions(
|
||||
|
||||
@ -20,6 +20,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional
|
||||
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options
|
||||
|
||||
try:
|
||||
import pymongocrypt # type:ignore[import-untyped] # noqa: F401
|
||||
|
||||
@ -32,9 +34,9 @@ except ImportError:
|
||||
from bson import int64
|
||||
from pymongo.common import validate_is_mapping
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
|
||||
|
||||
|
||||
@ -236,10 +238,22 @@ class AutoEncryptionOpts:
|
||||
if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args):
|
||||
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
|
||||
# Maps KMS provider name to a SSLContext.
|
||||
self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options)
|
||||
self._kms_tls_options = kms_tls_options
|
||||
self._sync_kms_ssl_contexts: Optional[dict[str, SSLContext]] = None
|
||||
self._async_kms_ssl_contexts: Optional[dict[str, SSLContext]] = None
|
||||
self._bypass_query_analysis = bypass_query_analysis
|
||||
self._key_expiration_ms = key_expiration_ms
|
||||
|
||||
def _kms_ssl_contexts(self, is_sync: bool) -> dict[str, SSLContext]:
|
||||
if is_sync:
|
||||
if self._sync_kms_ssl_contexts is None:
|
||||
self._sync_kms_ssl_contexts = _parse_kms_tls_options(self._kms_tls_options, True)
|
||||
return self._sync_kms_ssl_contexts
|
||||
else:
|
||||
if self._async_kms_ssl_contexts is None:
|
||||
self._async_kms_ssl_contexts = _parse_kms_tls_options(self._kms_tls_options, False)
|
||||
return self._async_kms_ssl_contexts
|
||||
|
||||
|
||||
class RangeOpts:
|
||||
"""Options to configure encrypted queries using the range algorithm."""
|
||||
|
||||
@ -46,22 +46,18 @@ except ImportError:
|
||||
_HAVE_SSL = False
|
||||
|
||||
try:
|
||||
from pymongo.pyopenssl_context import (
|
||||
BLOCKING_IO_LOOKUP_ERROR,
|
||||
BLOCKING_IO_READ_ERROR,
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
_sslConn,
|
||||
)
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
|
||||
_HAVE_PYOPENSSL = True
|
||||
except ImportError:
|
||||
_HAVE_PYOPENSSL = False
|
||||
_sslConn = SSLSocket # type: ignore
|
||||
from pymongo.ssl_support import ( # type: ignore[assignment]
|
||||
BLOCKING_IO_LOOKUP_ERROR,
|
||||
BLOCKING_IO_READ_ERROR,
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
_sslConn = SSLSocket # type: ignore[assignment, misc]
|
||||
|
||||
from pymongo.ssl_support import (
|
||||
BLOCKING_IO_LOOKUP_ERROR,
|
||||
BLOCKING_IO_READ_ERROR,
|
||||
BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
@ -71,7 +67,7 @@ _UNPACK_HEADER = struct.Struct("<iiii").unpack
|
||||
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
|
||||
_POLL_TIMEOUT = 0.5
|
||||
# Errors raised by sockets (and TLS sockets) when in non-blocking mode.
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
BLOCKING_IO_ERRORS = (BlockingIOError, *BLOCKING_IO_LOOKUP_ERROR, *ssl_support.BLOCKING_IO_ERRORS)
|
||||
|
||||
|
||||
# These socket-based I/O methods are for KMS requests and any other network operations that do not use
|
||||
|
||||
@ -38,8 +38,9 @@ from pymongo.errors import ( # type:ignore[attr-defined]
|
||||
)
|
||||
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.ssl_support import HAS_SNI, SSLError
|
||||
from pymongo.ssl_support import PYSSLError, SSLError, _has_sni
|
||||
|
||||
SSLErrors = (PYSSLError, SSLError)
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import _sslConn
|
||||
from pymongo.typings import _Address
|
||||
@ -138,7 +139,7 @@ def _raise_connection_failure(
|
||||
msg += format_timeout_details(timeout_details)
|
||||
if isinstance(error, socket.timeout):
|
||||
raise NetworkTimeout(msg) from error
|
||||
elif isinstance(error, SSLError) and "timed out" in str(error):
|
||||
elif isinstance(error, SSLErrors) and "timed out" in str(error):
|
||||
# Eventlet does not distinguish TLS network timeouts from other
|
||||
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
|
||||
# Luckily, we can work around this limitation because the phrase
|
||||
@ -279,7 +280,7 @@ async def _async_configured_socket(
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
if _has_sni(False):
|
||||
loop = asyncio.get_running_loop()
|
||||
ssl_sock = await loop.run_in_executor(
|
||||
None,
|
||||
@ -293,7 +294,7 @@ async def _async_configured_socket(
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
@ -346,12 +347,10 @@ async def _configured_protocol_interface(
|
||||
ssl=ssl_context,
|
||||
)
|
||||
except _CertificateError:
|
||||
transport.abort()
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
transport.abort()
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
|
||||
@ -460,7 +459,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
if _has_sni(True):
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore]
|
||||
else:
|
||||
ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore]
|
||||
@ -469,7 +468,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
@ -509,7 +508,7 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
|
||||
try:
|
||||
# We have to pass hostname / ip address to wrap_socket
|
||||
# to use SSLContext.check_hostname.
|
||||
if HAS_SNI:
|
||||
if _has_sni(True):
|
||||
ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host)
|
||||
else:
|
||||
ssl_sock = ssl_context.wrap_socket(sock)
|
||||
@ -518,7 +517,7 @@ def _configured_socket_interface(address: _Address, options: PoolOptions) -> Net
|
||||
# Raise _CertificateError directly like we do after match_hostname
|
||||
# below.
|
||||
raise
|
||||
except (OSError, SSLError) as exc:
|
||||
except (OSError, *SSLErrors) as exc:
|
||||
sock.close()
|
||||
# We raise AutoReconnect for transient and permanent SSL handshake
|
||||
# failures alike. Permanent handshake failures, like protocol
|
||||
|
||||
@ -15,16 +15,19 @@
|
||||
"""Support for SSL in PyMongo."""
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
HAVE_SSL = True
|
||||
HAVE_PYSSL = True
|
||||
|
||||
try:
|
||||
import pymongo.pyopenssl_context as _ssl
|
||||
import pymongo.pyopenssl_context as _pyssl
|
||||
except (ImportError, AttributeError) as exc:
|
||||
HAVE_PYSSL = False
|
||||
if isinstance(exc, AttributeError):
|
||||
warnings.warn(
|
||||
"Failed to use the installed version of PyOpenSSL. "
|
||||
@ -35,10 +38,10 @@ except (ImportError, AttributeError) as exc:
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
try:
|
||||
import pymongo.ssl_context as _ssl # type: ignore[no-redef]
|
||||
except ImportError:
|
||||
HAVE_SSL = False
|
||||
try:
|
||||
import pymongo.ssl_context as _ssl
|
||||
except ImportError:
|
||||
HAVE_SSL = False
|
||||
|
||||
|
||||
if HAVE_SSL:
|
||||
@ -49,14 +52,29 @@ if HAVE_SSL:
|
||||
import ssl as _stdlibssl # noqa: F401
|
||||
from ssl import CERT_NONE, CERT_REQUIRED
|
||||
|
||||
HAS_SNI = _ssl.HAS_SNI
|
||||
IPADDR_SAFE = True
|
||||
|
||||
if HAVE_PYSSL:
|
||||
PYSSLError: Any = _pyssl.SSLError
|
||||
BLOCKING_IO_ERRORS: tuple = _ssl.BLOCKING_IO_ERRORS + _pyssl.BLOCKING_IO_ERRORS
|
||||
BLOCKING_IO_READ_ERROR: tuple = (_pyssl.BLOCKING_IO_READ_ERROR, _ssl.BLOCKING_IO_READ_ERROR)
|
||||
BLOCKING_IO_WRITE_ERROR: tuple = (
|
||||
_pyssl.BLOCKING_IO_WRITE_ERROR,
|
||||
_ssl.BLOCKING_IO_WRITE_ERROR,
|
||||
)
|
||||
else:
|
||||
PYSSLError = _ssl.SSLError
|
||||
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
|
||||
BLOCKING_IO_READ_ERROR = (_ssl.BLOCKING_IO_READ_ERROR,)
|
||||
BLOCKING_IO_WRITE_ERROR = (_ssl.BLOCKING_IO_WRITE_ERROR,)
|
||||
SSLError = _ssl.SSLError
|
||||
BLOCKING_IO_ERRORS = _ssl.BLOCKING_IO_ERRORS
|
||||
BLOCKING_IO_READ_ERROR = _ssl.BLOCKING_IO_READ_ERROR
|
||||
BLOCKING_IO_WRITE_ERROR = _ssl.BLOCKING_IO_WRITE_ERROR
|
||||
BLOCKING_IO_LOOKUP_ERROR = BLOCKING_IO_READ_ERROR
|
||||
|
||||
def _has_sni(is_sync: bool) -> bool:
|
||||
if is_sync and HAVE_PYSSL:
|
||||
return _pyssl.HAS_SNI
|
||||
return _ssl.HAS_SNI
|
||||
|
||||
def get_ssl_context(
|
||||
certfile: Optional[str],
|
||||
passphrase: Optional[str],
|
||||
@ -65,10 +83,15 @@ if HAVE_SSL:
|
||||
allow_invalid_certificates: bool,
|
||||
allow_invalid_hostnames: bool,
|
||||
disable_ocsp_endpoint_check: bool,
|
||||
) -> _ssl.SSLContext:
|
||||
is_sync: bool,
|
||||
) -> Union[_pyssl.SSLContext, _ssl.SSLContext]: # type: ignore[name-defined]
|
||||
"""Create and return an SSLContext object."""
|
||||
if is_sync and HAVE_PYSSL:
|
||||
ssl: types.ModuleType = _pyssl
|
||||
else:
|
||||
ssl = _ssl
|
||||
verify_mode = CERT_NONE if allow_invalid_certificates else CERT_REQUIRED
|
||||
ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23)
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
if verify_mode != CERT_NONE:
|
||||
ctx.check_hostname = not allow_invalid_hostnames
|
||||
else:
|
||||
@ -80,22 +103,20 @@ if HAVE_SSL:
|
||||
# up to date versions of MongoDB 2.4 and above already disable
|
||||
# SSLv2 and SSLv3, python disables SSLv2 by default in >= 2.7.7
|
||||
# and >= 3.3.4 and SSLv3 in >= 3.4.3.
|
||||
ctx.options |= _ssl.OP_NO_SSLv2
|
||||
ctx.options |= _ssl.OP_NO_SSLv3
|
||||
ctx.options |= _ssl.OP_NO_COMPRESSION
|
||||
ctx.options |= _ssl.OP_NO_RENEGOTIATION
|
||||
ctx.options |= ssl.OP_NO_SSLv2
|
||||
ctx.options |= ssl.OP_NO_SSLv3
|
||||
ctx.options |= ssl.OP_NO_COMPRESSION
|
||||
ctx.options |= ssl.OP_NO_RENEGOTIATION
|
||||
if certfile is not None:
|
||||
try:
|
||||
ctx.load_cert_chain(certfile, None, passphrase)
|
||||
except _ssl.SSLError as exc:
|
||||
except ssl.SSLError as exc:
|
||||
raise ConfigurationError(f"Private key doesn't match certificate: {exc}") from None
|
||||
if crlfile is not None:
|
||||
if _ssl.IS_PYOPENSSL:
|
||||
if ssl.IS_PYOPENSSL:
|
||||
raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL")
|
||||
# Match the server's behavior.
|
||||
ctx.verify_flags = getattr( # type:ignore[attr-defined]
|
||||
_ssl, "VERIFY_CRL_CHECK_LEAF", 0
|
||||
)
|
||||
ctx.verify_flags = getattr(ssl, "VERIFY_CRL_CHECK_LEAF", 0)
|
||||
ctx.load_verify_locations(crlfile)
|
||||
if ca_certs is not None:
|
||||
ctx.load_verify_locations(ca_certs)
|
||||
@ -109,9 +130,11 @@ else:
|
||||
class SSLError(Exception): # type: ignore
|
||||
pass
|
||||
|
||||
HAS_SNI = False
|
||||
IPADDR_SAFE = False
|
||||
BLOCKING_IO_ERRORS = () # type:ignore[assignment]
|
||||
BLOCKING_IO_ERRORS = ()
|
||||
|
||||
def _has_sni(is_sync: bool) -> bool: # noqa: ARG001
|
||||
return False
|
||||
|
||||
def get_ssl_context(*dummy): # type: ignore
|
||||
"""No ssl module, raise ConfigurationError."""
|
||||
|
||||
@ -86,7 +86,7 @@ from pymongo.synchronous.cursor import Cursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser_shared import parse_host
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -156,6 +156,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
self.mongocryptd_client = mongocryptd_client
|
||||
self.opts = opts
|
||||
self._spawned = False
|
||||
self._kms_ssl_contexts = opts._kms_ssl_contexts(_IS_SYNC)
|
||||
|
||||
def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
|
||||
"""Complete a KMS request.
|
||||
@ -167,7 +168,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
endpoint = kms_context.endpoint
|
||||
message = kms_context.message
|
||||
provider = kms_context.kms_provider
|
||||
ctx = self.opts._kms_ssl_contexts.get(provider)
|
||||
ctx = self._kms_ssl_contexts.get(provider)
|
||||
if ctx is None:
|
||||
# Enable strict certificate verification, OCSP, match hostname, and
|
||||
# SNI using the system default CA certificates.
|
||||
@ -179,6 +180,7 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
False, # allow_invalid_certificates
|
||||
False, # allow_invalid_hostnames
|
||||
False, # disable_ocsp_endpoint_check
|
||||
_IS_SYNC,
|
||||
)
|
||||
# CSOT: set timeout for socket creation.
|
||||
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
|
||||
@ -393,6 +395,8 @@ class _Encrypter:
|
||||
encrypted_fields_map = _dict_to_bson(opts._encrypted_fields_map, False, _DATA_KEY_OPTS)
|
||||
self._bypass_auto_encryption = opts._bypass_auto_encryption
|
||||
self._internal_client = None
|
||||
# parsing kms_ssl_contexts here so that parsing errors will be raised before internal clients are created
|
||||
opts._kms_ssl_contexts(_IS_SYNC)
|
||||
|
||||
def _get_internal_client(
|
||||
encrypter: _Encrypter, mongo_client: MongoClient[_DocumentTypeArg]
|
||||
@ -668,6 +672,7 @@ class ClientEncryption(Generic[_DocumentType]):
|
||||
kms_tls_options=kms_tls_options,
|
||||
key_expiration_ms=key_expiration_ms,
|
||||
)
|
||||
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(
|
||||
None, key_vault_coll, None, opts
|
||||
)
|
||||
|
||||
@ -101,6 +101,7 @@ from pymongo.operations import (
|
||||
)
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import ClientBulkWriteResult
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous import client_session, database, uri_parser
|
||||
@ -777,7 +778,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
keyword_opts["document_class"] = doc_class
|
||||
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||
|
||||
seeds = set()
|
||||
self._seeds = set()
|
||||
is_srv = False
|
||||
username = None
|
||||
password = None
|
||||
@ -802,18 +803,18 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
is_srv = entity.startswith(SRV_SCHEME)
|
||||
seeds.update(res["nodelist"])
|
||||
self._seeds.update(res["nodelist"])
|
||||
username = res["username"] or username
|
||||
password = res["password"] or password
|
||||
dbase = res["database"] or dbase
|
||||
opts = res["options"]
|
||||
fqdn = res["fqdn"]
|
||||
else:
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
if not seeds:
|
||||
self._seeds.update(split_hosts(entity, self._port))
|
||||
if not self._seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
for hostname in [node[0] for node in seeds]:
|
||||
for hostname in [node[0] for node in self._seeds]:
|
||||
if _detect_external_db(hostname):
|
||||
break
|
||||
|
||||
@ -836,7 +837,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
opts = self._normalize_and_validate_options(opts, self._seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", username)
|
||||
@ -855,7 +856,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"username": username,
|
||||
"password": password,
|
||||
"dbase": dbase,
|
||||
"seeds": seeds,
|
||||
"seeds": self._seeds,
|
||||
"fqdn": fqdn,
|
||||
"srv_service_name": srv_service_name,
|
||||
"pool_class": pool_class,
|
||||
@ -871,8 +872,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
self._options.read_concern,
|
||||
)
|
||||
|
||||
if not is_srv:
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
@ -973,6 +973,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
|
||||
)
|
||||
if self._options.auto_encryption_opts:
|
||||
from pymongo.synchronous.encryption import _Encrypter
|
||||
@ -1203,6 +1204,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 4.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
|
||||
return TopologyDescription(
|
||||
TOPOLOGY_TYPE.Unknown,
|
||||
servers,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
self._topology_settings,
|
||||
)
|
||||
return self._topology.description
|
||||
|
||||
@property
|
||||
@ -1216,6 +1227,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
to any servers, or a network partition causes it to lose connection
|
||||
to all servers.
|
||||
"""
|
||||
if self._topology is None:
|
||||
return frozenset()
|
||||
description = self._topology.description
|
||||
return frozenset(s.address for s in description.known_servers)
|
||||
|
||||
@ -1570,6 +1583,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
.. versionadded:: 3.0
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
topology_type = self._topology._description.topology_type
|
||||
if (
|
||||
topology_type == TOPOLOGY_TYPE.Sharded
|
||||
@ -1592,6 +1607,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
MongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
return self._topology.get_primary() # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
@ -1605,6 +1622,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionadded:: 3.0
|
||||
MongoClient gained this property in version 3.0.
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
return self._topology.get_secondaries()
|
||||
|
||||
@property
|
||||
@ -1615,6 +1634,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
connected to a replica set, there are no arbiters, or this client was
|
||||
created without the `replicaSet` option.
|
||||
"""
|
||||
if self._topology is None:
|
||||
self._get_topology()
|
||||
return self._topology.get_arbiters()
|
||||
|
||||
@property
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import logging
|
||||
@ -72,6 +73,7 @@ from pymongo.monitoring import (
|
||||
from pymongo.network_layer import NetworkingInterface, receive_message, sendall
|
||||
from pymongo.pool_options import PoolOptions
|
||||
from pymongo.pool_shared import (
|
||||
SSLErrors,
|
||||
_CancellationContext,
|
||||
_configured_socket_interface,
|
||||
_get_timeout_details,
|
||||
@ -82,7 +84,6 @@ from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_api import _add_to_command
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.ssl_support import SSLError
|
||||
from pymongo.synchronous.client_session import _validate_session_write_concern
|
||||
from pymongo.synchronous.helpers import _handle_reauth
|
||||
from pymongo.synchronous.network import command
|
||||
@ -635,7 +636,7 @@ class Connection:
|
||||
reason = ConnectionClosedReason.ERROR
|
||||
self.close_conn(reason)
|
||||
# SSLError from PyOpenSSL inherits directly from Exception.
|
||||
if isinstance(error, (IOError, OSError, SSLError)):
|
||||
if isinstance(error, (IOError, OSError, *SSLErrors)):
|
||||
details = _get_timeout_details(self.opts)
|
||||
_raise_connection_failure(self.address, error, timeout_details=details)
|
||||
else:
|
||||
@ -858,8 +859,14 @@ class Pool:
|
||||
# PoolClosedEvent but that reset() SHOULD close sockets *after*
|
||||
# publishing the PoolClearedEvent.
|
||||
if close:
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
|
||||
if self.enabled_for_cmap:
|
||||
assert listeners is not None
|
||||
listeners.publish_pool_closed(self.address)
|
||||
@ -889,8 +896,14 @@ class Pool:
|
||||
serverPort=self.address[1],
|
||||
serviceId=service_id,
|
||||
)
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
for conn in sockets:
|
||||
conn.close_conn(ConnectionClosedReason.STALE)
|
||||
|
||||
def update_is_writable(self, is_writable: Optional[bool]) -> None:
|
||||
"""Updates the is_writable attribute on all sockets currently in the
|
||||
@ -934,8 +947,14 @@ class Pool:
|
||||
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
|
||||
):
|
||||
close_conns.append(self.conns.pop())
|
||||
for conn in close_conns:
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
|
||||
return_exceptions=True,
|
||||
)
|
||||
else:
|
||||
for conn in close_conns:
|
||||
conn.close_conn(ConnectionClosedReason.IDLE)
|
||||
|
||||
while True:
|
||||
with self.size_cond:
|
||||
@ -1029,7 +1048,7 @@ class Pool:
|
||||
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
|
||||
error=ConnectionClosedReason.ERROR,
|
||||
)
|
||||
if isinstance(error, (IOError, OSError, SSLError)):
|
||||
if isinstance(error, (IOError, OSError, *SSLErrors)):
|
||||
details = _get_timeout_details(self.opts)
|
||||
_raise_connection_failure(self.address, error, timeout_details=details)
|
||||
|
||||
|
||||
@ -51,6 +51,7 @@ class TopologySettings:
|
||||
srv_service_name: str = common.SRV_SERVICE_NAME,
|
||||
srv_max_hosts: int = 0,
|
||||
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
|
||||
topology_id: Optional[ObjectId] = None,
|
||||
):
|
||||
"""Represent MongoClient's configuration.
|
||||
|
||||
@ -78,8 +79,10 @@ class TopologySettings:
|
||||
self._srv_service_name = srv_service_name
|
||||
self._srv_max_hosts = srv_max_hosts or 0
|
||||
self._server_monitoring_mode = server_monitoring_mode
|
||||
|
||||
self._topology_id = ObjectId()
|
||||
if topology_id is not None:
|
||||
self._topology_id = topology_id
|
||||
else:
|
||||
self._topology_id = ObjectId()
|
||||
# Store the allocation traceback to catch unclosed clients in the
|
||||
# test suite.
|
||||
self._stack = "".join(traceback.format_stack()[:-2])
|
||||
|
||||
@ -96,6 +96,7 @@ class _SrvResolver:
|
||||
except Exception:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
|
||||
self.__slen = len(self.__plist)
|
||||
self.nparts = len(split_fqdn)
|
||||
|
||||
def get_options(self) -> Optional[str]:
|
||||
from dns import resolver
|
||||
@ -137,12 +138,13 @@ class _SrvResolver:
|
||||
|
||||
# Validate hosts
|
||||
for node in nodes:
|
||||
if self.__fqdn == node[0].lower():
|
||||
srv_host = node[0].lower()
|
||||
if self.__fqdn == srv_host and self.nparts < 3:
|
||||
raise ConfigurationError(
|
||||
"Invalid SRV host: return address is identical to SRV hostname"
|
||||
)
|
||||
try:
|
||||
nlist = node[0].lower().split(".")[1:][-self.__slen :]
|
||||
nlist = srv_host.split(".")[1:][-self.__slen :]
|
||||
except Exception:
|
||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
|
||||
if self.__plist != nlist:
|
||||
|
||||
@ -529,12 +529,6 @@ class Topology:
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
# Clear the pool from a failed heartbeat.
|
||||
if reset_pool:
|
||||
server = self._servers.get(server_description.address)
|
||||
if server:
|
||||
server.pool.reset(interrupt_connections=interrupt_connections)
|
||||
|
||||
# Wake anything waiting in select_servers().
|
||||
self._condition.notify_all()
|
||||
|
||||
@ -557,6 +551,11 @@ class Topology:
|
||||
# that didn't include this server.
|
||||
if self._opened and self._description.has_server(server_description.address):
|
||||
self._process_change(server_description, reset_pool, interrupt_connections)
|
||||
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
|
||||
if reset_pool:
|
||||
server = self._servers.get(server_description.address)
|
||||
if server:
|
||||
server.pool.reset(interrupt_connections=interrupt_connections)
|
||||
|
||||
def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
|
||||
"""Process a new seedlist on an opened topology.
|
||||
|
||||
@ -420,7 +420,10 @@ def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
|
||||
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
|
||||
|
||||
|
||||
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
|
||||
def _parse_kms_tls_options(
|
||||
kms_tls_options: Optional[Mapping[str, Any]],
|
||||
is_sync: bool,
|
||||
) -> dict[str, SSLContext]:
|
||||
"""Parse KMS TLS connection options."""
|
||||
if not kms_tls_options:
|
||||
return {}
|
||||
@ -435,7 +438,7 @@ def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict
|
||||
opts = _handle_security_options(opts)
|
||||
opts = _normalize_options(opts)
|
||||
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
|
||||
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
|
||||
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts, is_sync)
|
||||
if ssl_context is None:
|
||||
raise ConfigurationError("TLS is required for KMS providers")
|
||||
if allow_invalid_hostnames:
|
||||
|
||||
@ -823,6 +823,14 @@ class ClientContext:
|
||||
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
|
||||
)
|
||||
|
||||
def require_async(self, func):
|
||||
"""Run a test only if using the asynchronous API.""" # unasync: off
|
||||
return self._require(
|
||||
lambda: not _IS_SYNC,
|
||||
"This test only works with the asynchronous API", # unasync: off
|
||||
func=func,
|
||||
)
|
||||
|
||||
def mongos_seeds(self):
|
||||
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
||||
|
||||
|
||||
@ -825,6 +825,14 @@ class AsyncClientContext:
|
||||
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
|
||||
)
|
||||
|
||||
def require_async(self, func):
|
||||
"""Run a test only if using the asynchronous API.""" # unasync: off
|
||||
return self._require(
|
||||
lambda: not _IS_SYNC,
|
||||
"This test only works with the asynchronous API", # unasync: off
|
||||
func=func,
|
||||
)
|
||||
|
||||
def mongos_seeds(self):
|
||||
return ",".join("{}:{}".format(*address) for address in self.mongoses)
|
||||
|
||||
|
||||
@ -849,6 +849,58 @@ class TestClient(AsyncIntegrationTest):
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
await c.pymongo_test.test.find_one()
|
||||
|
||||
@async_client_context.require_replica_set
|
||||
@async_client_context.require_no_load_balancer
|
||||
@async_client_context.require_tls
|
||||
async def test_init_disconnected_with_srv(self):
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# nodes returns an empty set if not connected
|
||||
self.assertEqual(c.nodes, frozenset())
|
||||
# topology_description returns the initial seed description if not connected
|
||||
topology_description = c.topology_description
|
||||
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
|
||||
self.assertEqual(
|
||||
{
|
||||
("test1.test.build.10gen.cc", None): ServerDescription(
|
||||
("test1.test.build.10gen.cc", None)
|
||||
)
|
||||
},
|
||||
topology_description.server_descriptions(),
|
||||
)
|
||||
|
||||
# address causes client to block until connected
|
||||
self.assertIsNotNone(await c.address)
|
||||
# Initial seed topology and connected topology have the same ID
|
||||
self.assertEqual(
|
||||
c._topology._topology_id, topology_description._topology_settings._topology_id
|
||||
)
|
||||
await c.close()
|
||||
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# primary causes client to block until connected
|
||||
await c.primary
|
||||
self.assertIsNotNone(c._topology)
|
||||
await c.close()
|
||||
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# secondaries causes client to block until connected
|
||||
await c.secondaries
|
||||
self.assertIsNotNone(c._topology)
|
||||
await c.close()
|
||||
|
||||
c = await self.async_rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# arbiters causes client to block until connected
|
||||
await c.arbiters
|
||||
self.assertIsNotNone(c._topology)
|
||||
|
||||
async def test_equality(self):
|
||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||
c = await self.async_rs_or_single_client(seed, connect=False)
|
||||
|
||||
@ -20,10 +20,15 @@ import os
|
||||
import socketserver
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from pathlib import Path
|
||||
from test.asynchronous.helpers import ConcurrentRunner
|
||||
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import (
|
||||
@ -370,6 +375,74 @@ class TestPoolManagement(AsyncIntegrationTest):
|
||||
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
|
||||
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
|
||||
|
||||
@async_client_context.require_failCommand_appName
|
||||
@async_client_context.require_test_commands
|
||||
@async_client_context.require_async
|
||||
async def test_connection_close_does_not_block_other_operations(self):
|
||||
listener = CMAPHeartbeatListener()
|
||||
client = await self.async_single_client(
|
||||
appName="SDAMConnectionCloseTest",
|
||||
event_listeners=[listener],
|
||||
heartbeatFrequencyMS=500,
|
||||
minPoolSize=10,
|
||||
)
|
||||
server = await (await client._get_topology()).select_server(
|
||||
writable_server_selector, _Op.TEST
|
||||
)
|
||||
await async_wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"pool initialized with 10 connections",
|
||||
)
|
||||
|
||||
await client.db.test.insert_one({"x": 1})
|
||||
close_delay = 0.1
|
||||
latencies = []
|
||||
should_exit = []
|
||||
|
||||
async def run_task():
|
||||
while True:
|
||||
start_time = time.monotonic()
|
||||
await client.db.test.find_one({})
|
||||
elapsed = time.monotonic() - start_time
|
||||
latencies.append(elapsed)
|
||||
if should_exit:
|
||||
break
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
task = ConcurrentRunner(target=run_task)
|
||||
await task.start()
|
||||
original_close = AsyncConnection.close_conn
|
||||
try:
|
||||
# Artificially delay the close operation to simulate a slow close
|
||||
async def mock_close(self, reason):
|
||||
await asyncio.sleep(close_delay)
|
||||
await original_close(self, reason)
|
||||
|
||||
AsyncConnection.close_conn = mock_close
|
||||
|
||||
fail_hello = {
|
||||
"mode": {"times": 4},
|
||||
"data": {
|
||||
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
|
||||
"errorCode": 91,
|
||||
"appName": "SDAMConnectionCloseTest",
|
||||
},
|
||||
}
|
||||
async with self.fail_point(fail_hello):
|
||||
# Wait for server heartbeat to fail
|
||||
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
|
||||
# Wait until all idle connections are closed to simulate real-world conditions
|
||||
await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10)
|
||||
# Wait for one more find to complete after the pool has been reset, then shutdown the task
|
||||
n = len(latencies)
|
||||
await async_wait_until(lambda: len(latencies) >= n + 1, "run one more find")
|
||||
should_exit.append(True)
|
||||
await task.join()
|
||||
# No operation latency should not significantly exceed close_delay
|
||||
self.assertLessEqual(max(latencies), close_delay * 5.0)
|
||||
finally:
|
||||
AsyncConnection.close_conn = original_close
|
||||
|
||||
|
||||
class TestServerMonitoringMode(AsyncIntegrationTest):
|
||||
@async_client_context.require_no_serverless
|
||||
|
||||
@ -220,12 +220,15 @@ class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase):
|
||||
mock_resolver.side_effect = mock_resolve
|
||||
domain = case["query"].split("._tcp.")[1]
|
||||
connection_string = f"mongodb+srv://{domain}"
|
||||
try:
|
||||
if "expected_error" not in case:
|
||||
await parse_uri(connection_string)
|
||||
except ConfigurationError as e:
|
||||
self.assertIn(case["expected_error"], str(e))
|
||||
else:
|
||||
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
|
||||
try:
|
||||
await parse_uri(connection_string)
|
||||
except ConfigurationError as e:
|
||||
self.assertIn(case["expected_error"], str(e))
|
||||
else:
|
||||
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
|
||||
|
||||
async def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
|
||||
with patch("dns.asyncresolver.resolve"):
|
||||
@ -289,6 +292,17 @@ class TestInitialDnsSeedlistDiscovery(AsyncPyMongoTestCase):
|
||||
]
|
||||
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
async def test_5_when_srv_hostname_has_two_dot_separated_parts_it_is_valid_for_the_returned_hostname_to_be_identical(
|
||||
self
|
||||
):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongodb.com",
|
||||
"mock_target": "blogs.mongodb.com",
|
||||
},
|
||||
]
|
||||
await self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -41,6 +41,7 @@ import pytest
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options
|
||||
|
||||
try:
|
||||
from pymongo.pyopenssl_context import IS_PYOPENSSL
|
||||
@ -141,7 +142,7 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
|
||||
self.assertEqual(opts._mongocryptd_bypass_spawn, False)
|
||||
self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd")
|
||||
self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"])
|
||||
self.assertEqual(opts._kms_ssl_contexts, {})
|
||||
self.assertEqual(opts._kms_tls_options, None)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def test_init_spawn_args(self):
|
||||
@ -165,30 +166,38 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
|
||||
)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def test_init_kms_tls_options(self):
|
||||
async def test_init_kms_tls_options(self):
|
||||
# Error cases:
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
|
||||
with self.assertRaisesRegex(TypeError, r'kms_tls_options\["kmip"\] must be a dict'):
|
||||
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
|
||||
AsyncMongoClient(auto_encryption_opts=opts)
|
||||
|
||||
tls_opts: Any
|
||||
for tls_opts in [
|
||||
{"kmip": {"tls": True, "tlsInsecure": True}},
|
||||
{"kmip": {"tls": True, "tlsAllowInvalidCertificates": True}},
|
||||
{"kmip": {"tls": True, "tlsAllowInvalidHostnames": True}},
|
||||
]:
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
|
||||
with self.assertRaisesRegex(ConfigurationError, "Insecure TLS options prohibited"):
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
|
||||
AsyncMongoClient(auto_encryption_opts=opts)
|
||||
opts = AutoEncryptionOpts(
|
||||
{}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}}
|
||||
)
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}})
|
||||
AsyncMongoClient(auto_encryption_opts=opts)
|
||||
# Success cases:
|
||||
tls_opts: Any
|
||||
for tls_opts in [None, {}]:
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
|
||||
self.assertEqual(opts._kms_ssl_contexts, {})
|
||||
kms_tls_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
self.assertEqual(kms_tls_contexts, {})
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}})
|
||||
ctx = opts._kms_ssl_contexts["kmip"]
|
||||
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
ctx = _kms_ssl_contexts["kmip"]
|
||||
self.assertEqual(ctx.check_hostname, True)
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
ctx = opts._kms_ssl_contexts["aws"]
|
||||
ctx = _kms_ssl_contexts["aws"]
|
||||
self.assertEqual(ctx.check_hostname, True)
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
opts = AutoEncryptionOpts(
|
||||
@ -196,7 +205,8 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase):
|
||||
"k.d",
|
||||
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
|
||||
)
|
||||
ctx = opts._kms_ssl_contexts["kmip"]
|
||||
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
ctx = _kms_ssl_contexts["kmip"]
|
||||
self.assertEqual(ctx.check_hostname, True)
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
|
||||
@ -2225,7 +2235,7 @@ class TestKmsTLSOptions(AsyncEncryptionIntegrationTest):
|
||||
encryption = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
|
||||
)
|
||||
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
|
||||
ctx = encryption._io_callbacks._kms_ssl_contexts["aws"]
|
||||
if not hasattr(ctx, "check_ocsp_endpoint"):
|
||||
raise self.skipTest("OCSP not enabled")
|
||||
self.assertFalse(ctx.check_ocsp_endpoint)
|
||||
|
||||
@ -43,7 +43,7 @@ from urllib.parse import quote_plus
|
||||
from pymongo import AsyncMongoClient, ssl_support
|
||||
from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context
|
||||
from pymongo.ssl_support import HAVE_PYSSL, HAVE_SSL, _ssl, get_ssl_context
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_HAVE_PYOPENSSL = False
|
||||
@ -134,7 +134,7 @@ class TestClientSSL(AsyncPyMongoTestCase):
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
|
||||
def test_use_pyopenssl_when_available(self):
|
||||
self.assertTrue(_ssl.IS_PYOPENSSL)
|
||||
self.assertTrue(HAVE_PYSSL)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL")
|
||||
def test_load_trusted_ca_certs(self):
|
||||
@ -177,7 +177,7 @@ class TestSSL(AsyncIntegrationTest):
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
|
||||
if not hasattr(ssl, "SSLContext") and not HAVE_PYSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self.simple_client,
|
||||
@ -309,13 +309,13 @@ class TestSSL(AsyncIntegrationTest):
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False)
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, True, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, True, False, False, _IS_SYNC)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, False, True, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, True, False, _IS_SYNC)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
|
||||
self.assertTrue(ctx.check_hostname)
|
||||
|
||||
response = await self.client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
@ -376,9 +376,11 @@ class TestSSL(AsyncIntegrationTest):
|
||||
)
|
||||
|
||||
@async_client_context.require_tlsCertificateKeyFile
|
||||
@async_client_context.require_sync
|
||||
@async_client_context.require_no_api_version
|
||||
@ignore_deprecations
|
||||
async def test_tlsCRLFile_support(self):
|
||||
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
|
||||
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or HAVE_PYSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self.simple_client,
|
||||
@ -469,7 +471,7 @@ class TestSSL(AsyncIntegrationTest):
|
||||
)
|
||||
|
||||
def test_system_certs_config_error(self):
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False)
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
|
||||
if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr(
|
||||
ctx, "load_default_certs"
|
||||
):
|
||||
@ -500,11 +502,11 @@ class TestSSL(AsyncIntegrationTest):
|
||||
# Force the test on Windows, regardless of environment.
|
||||
ssl_support.HAVE_WINCERTSTORE = False
|
||||
try:
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
|
||||
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where())
|
||||
finally:
|
||||
@ -521,11 +523,11 @@ class TestSSL(AsyncIntegrationTest):
|
||||
if not ssl_support.HAVE_WINCERTSTORE:
|
||||
raise SkipTest("Need wincertstore to test wincertstore.")
|
||||
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
|
||||
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name)
|
||||
|
||||
@ -657,6 +659,16 @@ class TestSSL(AsyncIntegrationTest):
|
||||
) as client:
|
||||
self.assertTrue(await client.admin.command("ping"))
|
||||
|
||||
@async_client_context.require_async
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
|
||||
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
|
||||
async def test_pyopenssl_ignored_in_async(self):
|
||||
client = AsyncMongoClient(
|
||||
"mongodb://localhost:27017?tls=true&tlsAllowInvalidCertificates=true"
|
||||
)
|
||||
await client.admin.command("ping") # command doesn't matter, just needs it to connect
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -26,7 +26,7 @@ import pytest
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
import pymongo
|
||||
from pymongo.ssl_support import HAS_SNI
|
||||
from pymongo.ssl_support import _has_sni
|
||||
|
||||
pytestmark = pytest.mark.atlas_connect
|
||||
|
||||
@ -57,7 +57,7 @@ class TestAtlasConnect(PyMongoTestCase):
|
||||
# No auth error
|
||||
client.test.test.count_documents({})
|
||||
|
||||
@unittest.skipUnless(HAS_SNI, "Free tier requires SNI support")
|
||||
@unittest.skipUnless(_has_sni(True), "Free tier requires SNI support")
|
||||
def test_free_tier(self):
|
||||
self.connect(URIS["ATLAS_FREE"])
|
||||
|
||||
@ -80,7 +80,7 @@ class TestAtlasConnect(PyMongoTestCase):
|
||||
self.connect(uri)
|
||||
self.assertIn("mongodb+srv://", uri)
|
||||
|
||||
@unittest.skipUnless(HAS_SNI, "Free tier requires SNI support")
|
||||
@unittest.skipUnless(_has_sni(True), "Free tier requires SNI support")
|
||||
def test_srv_free_tier(self):
|
||||
self.connect_srv(URIS["ATLAS_SRV_FREE"])
|
||||
|
||||
|
||||
@ -824,6 +824,58 @@ class TestClient(IntegrationTest):
|
||||
with self.assertRaises(ConnectionFailure):
|
||||
c.pymongo_test.test.find_one()
|
||||
|
||||
@client_context.require_replica_set
|
||||
@client_context.require_no_load_balancer
|
||||
@client_context.require_tls
|
||||
def test_init_disconnected_with_srv(self):
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# nodes returns an empty set if not connected
|
||||
self.assertEqual(c.nodes, frozenset())
|
||||
# topology_description returns the initial seed description if not connected
|
||||
topology_description = c.topology_description
|
||||
self.assertEqual(topology_description.topology_type, TOPOLOGY_TYPE.Unknown)
|
||||
self.assertEqual(
|
||||
{
|
||||
("test1.test.build.10gen.cc", None): ServerDescription(
|
||||
("test1.test.build.10gen.cc", None)
|
||||
)
|
||||
},
|
||||
topology_description.server_descriptions(),
|
||||
)
|
||||
|
||||
# address causes client to block until connected
|
||||
self.assertIsNotNone(c.address)
|
||||
# Initial seed topology and connected topology have the same ID
|
||||
self.assertEqual(
|
||||
c._topology._topology_id, topology_description._topology_settings._topology_id
|
||||
)
|
||||
c.close()
|
||||
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# primary causes client to block until connected
|
||||
c.primary
|
||||
self.assertIsNotNone(c._topology)
|
||||
c.close()
|
||||
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# secondaries causes client to block until connected
|
||||
c.secondaries
|
||||
self.assertIsNotNone(c._topology)
|
||||
c.close()
|
||||
|
||||
c = self.rs_or_single_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc", connect=False, tlsInsecure=True
|
||||
)
|
||||
# arbiters causes client to block until connected
|
||||
c.arbiters
|
||||
self.assertIsNotNone(c._topology)
|
||||
|
||||
def test_equality(self):
|
||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||
c = self.rs_or_single_client(seed, connect=False)
|
||||
|
||||
@ -209,6 +209,19 @@ class TestDefaultExports(unittest.TestCase):
|
||||
)
|
||||
from pymongo.write_concern import WriteConcern, validate_boolean
|
||||
|
||||
def test_pymongo_submodule_attributes(self):
|
||||
import pymongo
|
||||
|
||||
self.assertTrue(hasattr(pymongo, "uri_parser"))
|
||||
self.assertTrue(pymongo.uri_parser)
|
||||
self.assertTrue(pymongo.uri_parser.parse_uri)
|
||||
self.assertTrue(pymongo.change_stream)
|
||||
self.assertTrue(pymongo.client_session)
|
||||
self.assertTrue(pymongo.collection)
|
||||
self.assertTrue(pymongo.cursor)
|
||||
self.assertTrue(pymongo.command_cursor)
|
||||
self.assertTrue(pymongo.database)
|
||||
|
||||
def test_gridfs_imports(self):
|
||||
import gridfs
|
||||
from gridfs.errors import CorruptGridFile, FileExists, GridFSError, NoFile
|
||||
|
||||
@ -20,10 +20,15 @@ import os
|
||||
import socketserver
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from pathlib import Path
|
||||
from test.helpers import ConcurrentRunner
|
||||
|
||||
from pymongo.operations import _Op
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.synchronous.pool import Connection
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import (
|
||||
@ -370,6 +375,72 @@ class TestPoolManagement(IntegrationTest):
|
||||
listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
|
||||
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
|
||||
|
||||
@client_context.require_failCommand_appName
|
||||
@client_context.require_test_commands
|
||||
@client_context.require_async
|
||||
def test_connection_close_does_not_block_other_operations(self):
|
||||
listener = CMAPHeartbeatListener()
|
||||
client = self.single_client(
|
||||
appName="SDAMConnectionCloseTest",
|
||||
event_listeners=[listener],
|
||||
heartbeatFrequencyMS=500,
|
||||
minPoolSize=10,
|
||||
)
|
||||
server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST)
|
||||
wait_until(
|
||||
lambda: len(server._pool.conns) == 10,
|
||||
"pool initialized with 10 connections",
|
||||
)
|
||||
|
||||
client.db.test.insert_one({"x": 1})
|
||||
close_delay = 0.1
|
||||
latencies = []
|
||||
should_exit = []
|
||||
|
||||
def run_task():
|
||||
while True:
|
||||
start_time = time.monotonic()
|
||||
client.db.test.find_one({})
|
||||
elapsed = time.monotonic() - start_time
|
||||
latencies.append(elapsed)
|
||||
if should_exit:
|
||||
break
|
||||
time.sleep(0.001)
|
||||
|
||||
task = ConcurrentRunner(target=run_task)
|
||||
task.start()
|
||||
original_close = Connection.close_conn
|
||||
try:
|
||||
# Artificially delay the close operation to simulate a slow close
|
||||
def mock_close(self, reason):
|
||||
time.sleep(close_delay)
|
||||
original_close(self, reason)
|
||||
|
||||
Connection.close_conn = mock_close
|
||||
|
||||
fail_hello = {
|
||||
"mode": {"times": 4},
|
||||
"data": {
|
||||
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
|
||||
"errorCode": 91,
|
||||
"appName": "SDAMConnectionCloseTest",
|
||||
},
|
||||
}
|
||||
with self.fail_point(fail_hello):
|
||||
# Wait for server heartbeat to fail
|
||||
listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
|
||||
# Wait until all idle connections are closed to simulate real-world conditions
|
||||
listener.wait_for_event(monitoring.ConnectionClosedEvent, 10)
|
||||
# Wait for one more find to complete after the pool has been reset, then shutdown the task
|
||||
n = len(latencies)
|
||||
wait_until(lambda: len(latencies) >= n + 1, "run one more find")
|
||||
should_exit.append(True)
|
||||
task.join()
|
||||
# No operation latency should not significantly exceed close_delay
|
||||
self.assertLessEqual(max(latencies), close_delay * 5.0)
|
||||
finally:
|
||||
Connection.close_conn = original_close
|
||||
|
||||
|
||||
class TestServerMonitoringMode(IntegrationTest):
|
||||
@client_context.require_no_serverless
|
||||
|
||||
@ -218,12 +218,15 @@ class TestInitialDnsSeedlistDiscovery(PyMongoTestCase):
|
||||
mock_resolver.side_effect = mock_resolve
|
||||
domain = case["query"].split("._tcp.")[1]
|
||||
connection_string = f"mongodb+srv://{domain}"
|
||||
try:
|
||||
if "expected_error" not in case:
|
||||
parse_uri(connection_string)
|
||||
except ConfigurationError as e:
|
||||
self.assertIn(case["expected_error"], str(e))
|
||||
else:
|
||||
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
|
||||
try:
|
||||
parse_uri(connection_string)
|
||||
except ConfigurationError as e:
|
||||
self.assertIn(case["expected_error"], str(e))
|
||||
else:
|
||||
self.fail(f"ConfigurationError was not raised for query: {case['query']}")
|
||||
|
||||
def test_1_allow_srv_hosts_with_fewer_than_three_dot_separated_parts(self):
|
||||
with patch("dns.resolver.resolve"):
|
||||
@ -287,6 +290,17 @@ class TestInitialDnsSeedlistDiscovery(PyMongoTestCase):
|
||||
]
|
||||
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
def test_5_when_srv_hostname_has_two_dot_separated_parts_it_is_valid_for_the_returned_hostname_to_be_identical(
|
||||
self
|
||||
):
|
||||
test_cases = [
|
||||
{
|
||||
"query": "_mongodb._tcp.blogs.mongodb.com",
|
||||
"mock_target": "blogs.mongodb.com",
|
||||
},
|
||||
]
|
||||
self.run_initial_dns_seedlist_discovery_prose_tests(test_cases)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -41,6 +41,7 @@ import pytest
|
||||
from pymongo.daemon import _spawn_daemon
|
||||
from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options
|
||||
|
||||
try:
|
||||
from pymongo.pyopenssl_context import IS_PYOPENSSL
|
||||
@ -141,7 +142,7 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
|
||||
self.assertEqual(opts._mongocryptd_bypass_spawn, False)
|
||||
self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd")
|
||||
self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"])
|
||||
self.assertEqual(opts._kms_ssl_contexts, {})
|
||||
self.assertEqual(opts._kms_tls_options, None)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def test_init_spawn_args(self):
|
||||
@ -167,28 +168,36 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def test_init_kms_tls_options(self):
|
||||
# Error cases:
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
|
||||
with self.assertRaisesRegex(TypeError, r'kms_tls_options\["kmip"\] must be a dict'):
|
||||
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1})
|
||||
MongoClient(auto_encryption_opts=opts)
|
||||
|
||||
tls_opts: Any
|
||||
for tls_opts in [
|
||||
{"kmip": {"tls": True, "tlsInsecure": True}},
|
||||
{"kmip": {"tls": True, "tlsAllowInvalidCertificates": True}},
|
||||
{"kmip": {"tls": True, "tlsAllowInvalidHostnames": True}},
|
||||
]:
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
|
||||
with self.assertRaisesRegex(ConfigurationError, "Insecure TLS options prohibited"):
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
|
||||
MongoClient(auto_encryption_opts=opts)
|
||||
opts = AutoEncryptionOpts(
|
||||
{}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}}
|
||||
)
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}})
|
||||
MongoClient(auto_encryption_opts=opts)
|
||||
# Success cases:
|
||||
tls_opts: Any
|
||||
for tls_opts in [None, {}]:
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
|
||||
self.assertEqual(opts._kms_ssl_contexts, {})
|
||||
kms_tls_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
self.assertEqual(kms_tls_contexts, {})
|
||||
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}})
|
||||
ctx = opts._kms_ssl_contexts["kmip"]
|
||||
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
ctx = _kms_ssl_contexts["kmip"]
|
||||
self.assertEqual(ctx.check_hostname, True)
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
ctx = opts._kms_ssl_contexts["aws"]
|
||||
ctx = _kms_ssl_contexts["aws"]
|
||||
self.assertEqual(ctx.check_hostname, True)
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
opts = AutoEncryptionOpts(
|
||||
@ -196,7 +205,8 @@ class TestAutoEncryptionOpts(PyMongoTestCase):
|
||||
"k.d",
|
||||
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
|
||||
)
|
||||
ctx = opts._kms_ssl_contexts["kmip"]
|
||||
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
|
||||
ctx = _kms_ssl_contexts["kmip"]
|
||||
self.assertEqual(ctx.check_hostname, True)
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
|
||||
@ -2217,7 +2227,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest):
|
||||
encryption = self.create_client_encryption(
|
||||
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
|
||||
)
|
||||
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
|
||||
ctx = encryption._io_callbacks._kms_ssl_contexts["aws"]
|
||||
if not hasattr(ctx, "check_ocsp_endpoint"):
|
||||
raise self.skipTest("OCSP not enabled")
|
||||
self.assertFalse(ctx.check_ocsp_endpoint)
|
||||
|
||||
@ -43,7 +43,7 @@ from urllib.parse import quote_plus
|
||||
from pymongo import MongoClient, ssl_support
|
||||
from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context
|
||||
from pymongo.ssl_support import HAVE_PYSSL, HAVE_SSL, _ssl, get_ssl_context
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_HAVE_PYOPENSSL = False
|
||||
@ -134,7 +134,7 @@ class TestClientSSL(PyMongoTestCase):
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
|
||||
def test_use_pyopenssl_when_available(self):
|
||||
self.assertTrue(_ssl.IS_PYOPENSSL)
|
||||
self.assertTrue(HAVE_PYSSL)
|
||||
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL")
|
||||
def test_load_trusted_ca_certs(self):
|
||||
@ -177,7 +177,7 @@ class TestSSL(IntegrationTest):
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL:
|
||||
if not hasattr(ssl, "SSLContext") and not HAVE_PYSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self.simple_client,
|
||||
@ -309,13 +309,13 @@ class TestSSL(IntegrationTest):
|
||||
#
|
||||
# --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem
|
||||
# --sslCAFile=/path/to/pymongo/test/certificates/ca.pem
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False)
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, True, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, True, False, False, _IS_SYNC)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, False, True, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, True, False, _IS_SYNC)
|
||||
self.assertFalse(ctx.check_hostname)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
|
||||
self.assertTrue(ctx.check_hostname)
|
||||
|
||||
response = self.client.admin.command(HelloCompat.LEGACY_CMD)
|
||||
@ -376,9 +376,11 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
|
||||
@client_context.require_tlsCertificateKeyFile
|
||||
@client_context.require_sync
|
||||
@client_context.require_no_api_version
|
||||
@ignore_deprecations
|
||||
def test_tlsCRLFile_support(self):
|
||||
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL:
|
||||
if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or HAVE_PYSSL:
|
||||
self.assertRaises(
|
||||
ConfigurationError,
|
||||
self.simple_client,
|
||||
@ -469,7 +471,7 @@ class TestSSL(IntegrationTest):
|
||||
)
|
||||
|
||||
def test_system_certs_config_error(self):
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False)
|
||||
ctx = get_ssl_context(None, None, None, None, True, True, False, _IS_SYNC)
|
||||
if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr(
|
||||
ctx, "load_default_certs"
|
||||
):
|
||||
@ -500,11 +502,11 @@ class TestSSL(IntegrationTest):
|
||||
# Force the test on Windows, regardless of environment.
|
||||
ssl_support.HAVE_WINCERTSTORE = False
|
||||
try:
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
|
||||
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where())
|
||||
finally:
|
||||
@ -521,11 +523,11 @@ class TestSSL(IntegrationTest):
|
||||
if not ssl_support.HAVE_WINCERTSTORE:
|
||||
raise SkipTest("Need wincertstore to test wincertstore.")
|
||||
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, CA_PEM)
|
||||
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False)
|
||||
ctx = get_ssl_context(None, None, None, None, False, False, False, _IS_SYNC)
|
||||
ssl_sock = ctx.wrap_socket(socket.socket())
|
||||
self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name)
|
||||
|
||||
@ -657,6 +659,14 @@ class TestSSL(IntegrationTest):
|
||||
) as client:
|
||||
self.assertTrue(client.admin.command("ping"))
|
||||
|
||||
@client_context.require_async
|
||||
@unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.")
|
||||
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
|
||||
def test_pyopenssl_ignored_in_async(self):
|
||||
client = MongoClient("mongodb://localhost:27017?tls=true&tlsAllowInvalidCertificates=true")
|
||||
client.admin.command("ping") # command doesn't matter, just needs it to connect
|
||||
client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -35,6 +35,7 @@ def check_ocsp(host: str, port: int, capath: str) -> None:
|
||||
False, # allow_invalid_certificates
|
||||
False, # allow_invalid_hostnames
|
||||
False,
|
||||
True, # is sync
|
||||
) # disable_ocsp_endpoint_check
|
||||
|
||||
# Ensure we're using pyOpenSSL.
|
||||
|
||||
@ -288,7 +288,8 @@ def process_files(
|
||||
if file in docstring_translate_files:
|
||||
lines = translate_docstrings(lines)
|
||||
if file in sync_test_files:
|
||||
translate_imports(lines)
|
||||
lines = translate_imports(lines)
|
||||
lines = process_ignores(lines)
|
||||
f.seek(0)
|
||||
f.writelines(lines)
|
||||
f.truncate()
|
||||
@ -390,6 +391,14 @@ def translate_docstrings(lines: list[str]) -> list[str]:
|
||||
return [line for line in lines if line != "DOCSTRING_REMOVED"]
|
||||
|
||||
|
||||
def process_ignores(lines: list[str]) -> list[str]:
|
||||
for i in range(len(lines)):
|
||||
for k, v in replacements.items():
|
||||
if "unasync: off" in lines[i] and v in lines[i]:
|
||||
lines[i] = lines[i].replace(v, k)
|
||||
return lines
|
||||
|
||||
|
||||
def unasync_directory(files: list[str], src: str, dest: str, replacements: dict[str, str]) -> None:
|
||||
unasync_files(
|
||||
files,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user