Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2025-03-25 20:14:27 -05:00
commit 47f5804d90
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
36 changed files with 1741 additions and 914 deletions

View File

@ -1257,12 +1257,12 @@ buildvariants:
# Stable api tests
- name: stable-api-require-v1-rhel8-python3.9-auth
tasks:
- name: .standalone .5.0 .noauth .nossl .sync_async
- name: .standalone .6.0 .noauth .nossl .sync_async
- name: .standalone .7.0 .noauth .nossl .sync_async
- name: .standalone .8.0 .noauth .nossl .sync_async
- name: .standalone .rapid .noauth .nossl .sync_async
- name: .standalone .latest .noauth .nossl .sync_async
- name: "!.replica_set .5.0 .noauth .nossl .sync_async"
- name: "!.replica_set .6.0 .noauth .nossl .sync_async"
- name: "!.replica_set .7.0 .noauth .nossl .sync_async"
- name: "!.replica_set .8.0 .noauth .nossl .sync_async"
- name: "!.replica_set .rapid .noauth .nossl .sync_async"
- name: "!.replica_set .latest .noauth .nossl .sync_async"
display_name: Stable API require v1 RHEL8 Python3.9 Auth
run_on:
- rhel87-small
@ -1290,12 +1290,12 @@ buildvariants:
tags: [versionedApi_tag]
- name: stable-api-require-v1-rhel8-python3.13-auth
tasks:
- name: .standalone .5.0 .noauth .nossl .sync_async
- name: .standalone .6.0 .noauth .nossl .sync_async
- name: .standalone .7.0 .noauth .nossl .sync_async
- name: .standalone .8.0 .noauth .nossl .sync_async
- name: .standalone .rapid .noauth .nossl .sync_async
- name: .standalone .latest .noauth .nossl .sync_async
- name: "!.replica_set .5.0 .noauth .nossl .sync_async"
- name: "!.replica_set .6.0 .noauth .nossl .sync_async"
- name: "!.replica_set .7.0 .noauth .nossl .sync_async"
- name: "!.replica_set .8.0 .noauth .nossl .sync_async"
- name: "!.replica_set .rapid .noauth .nossl .sync_async"
- name: "!.replica_set .latest .noauth .nossl .sync_async"
display_name: Stable API require v1 RHEL8 Python3.13 Auth
run_on:
- rhel87-small

View File

@ -546,7 +546,6 @@ def create_storage_engine_variants():
def create_stable_api_variants():
host = DEFAULT_HOST
tags = ["versionedApi_tag"]
tasks = [f".standalone .{v} .noauth .nossl .sync_async" for v in get_versions_from("5.0")]
variants = []
types = ["require v1", "accept v2"]
@ -560,11 +559,17 @@ def create_stable_api_variants():
expansions["REQUIRE_API_VERSION"] = "1"
# MONGODB_API_VERSION is the apiVersion to use in the test suite.
expansions["MONGODB_API_VERSION"] = "1"
tasks = [
f"!.replica_set .{v} .noauth .nossl .sync_async" for v in get_versions_from("5.0")
]
else:
# Test against a cluster with acceptApiVersion2 but without
# requireApiVersion, and don't automatically add apiVersion to
# clients created in the test suite.
expansions["ORCHESTRATION_FILE"] = "versioned-api-testing.json"
tasks = [
f".standalone .{v} .noauth .nossl .sync_async" for v in get_versions_from("5.0")
]
base_display_name = f"Stable API {test_type}"
display_name = get_display_name(base_display_name, host, python=python, **expansions)
variant = create_variant(

View File

@ -9,9 +9,16 @@ PyMongo 4.12 brings a number of changes including:
- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
:class:`~pymongo.encryption_options.AutoEncryptionOpts`.
- Support for $lookup in CSFLE and QE supported on MongoDB 8.1+.
- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation.
To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected.
- Added index hinting support to the
:meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct` and
:meth:`~pymongo.collection.Collection.distinct` commands.
- Deprecated the ``hedge`` parameter for
:class:`~pymongo.read_preferences.PrimaryPreferred`,
:class:`~pymongo.read_preferences.Secondary`,
:class:`~pymongo.read_preferences.SecondaryPreferred`,
:class:`~pymongo.read_preferences.Nearest`. Support for ``hedge`` will be removed in PyMongo 5.0.
Issues Resolved
...............

View File

@ -87,7 +87,7 @@ from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser import parse_host
from pymongo.uri_parser_shared import parse_host
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:

View File

@ -44,6 +44,7 @@ from typing import (
AsyncContextManager,
AsyncGenerator,
Callable,
Collection,
Coroutine,
FrozenSet,
Generic,
@ -60,8 +61,8 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo.asynchronous import client_session, database
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
@ -113,11 +114,14 @@ from pymongo.typings import (
_DocumentTypeArg,
_Pipeline,
)
from pymongo.uri_parser import (
from pymongo.uri_parser_shared import (
SRV_SCHEME,
_check_options,
_handle_option_deprecations,
_handle_security_options,
_normalize_options,
_validate_uri,
split_hosts,
)
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
@ -128,6 +132,7 @@ if TYPE_CHECKING:
from pymongo.asynchronous.bulk import _AsyncBulk
from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession
from pymongo.asynchronous.cursor import _ConnectionManager
from pymongo.asynchronous.encryption import _Encrypter
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.asynchronous.server import Server
from pymongo.read_concern import ReadConcern
@ -750,6 +755,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
port = self.PORT
if not isinstance(port, int):
raise TypeError(f"port must be an instance of int, not {type(port)}")
self._host = host
self._port = port
self._topology: Topology = None # type: ignore[assignment]
# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
@ -760,8 +768,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# Parse options passed as kwargs.
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
seeds = set()
is_srv = False
username = None
password = None
dbase = None
@ -769,29 +779,22 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
fqdn = None
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
if len([h for h in host if "/" in h]) > 1:
if len([h for h in self._host if "/" in h]) > 1:
raise ConfigurationError("host must not contain multiple MongoDB URIs")
for entity in host:
for entity in self._host:
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
# it must be a URI,
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
if "/" in entity:
# Determine connection timeout from kwargs.
timeout = keyword_opts.get("connecttimeoutms")
if timeout is not None:
timeout = common.validate_timeout_or_none_or_zero(
keyword_opts.cased_key("connecttimeoutms"), timeout
)
res = uri_parser.parse_uri(
res = _validate_uri(
entity,
port,
validate=True,
warn=True,
normalize=False,
connect_timeout=timeout,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
@ -799,7 +802,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(uri_parser.split_hosts(entity, port))
seeds.update(split_hosts(entity, self._port))
if not seeds:
raise ConfigurationError("need to specify at least one host")
@ -820,80 +823,179 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
keyword_opts["tz_aware"] = tz_aware
keyword_opts["connect"] = connect
# Handle deprecated options in kwarg options.
keyword_opts = _handle_option_deprecations(keyword_opts)
# Validate kwarg options.
keyword_opts = common._CaseInsensitiveDictionary(
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
)
# Override connection string options with kwarg options.
opts.update(keyword_opts)
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
if srv_service_name is None:
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
# Handle security-option conflicts in combined options.
opts = _handle_security_options(opts)
# Normalize combined options.
opts = _normalize_options(opts)
_check_options(seeds, opts)
opts = self._normalize_and_validate_options(opts, seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
password = opts.get("password", password)
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
self._default_database_name = dbase
self._lock = _async_create_lock()
self._kill_cursors_queue: list = []
self._event_listeners = options.pool_options._event_listeners
super().__init__(
options.codec_options,
options.read_preference,
options.write_concern,
options.read_concern,
self._encrypter: Optional[_Encrypter] = None
self._resolve_srv_info.update(
{
"is_srv": is_srv,
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
"monitor_class": monitor_class,
"condition_class": condition_class,
}
)
self._topology_settings = TopologySettings(
seeds=seeds,
replica_set_name=options.replica_set_name,
pool_class=pool_class,
pool_options=options.pool_options,
monitor_class=monitor_class,
condition_class=condition_class,
local_threshold_ms=options.local_threshold_ms,
server_selection_timeout=options.server_selection_timeout,
server_selector=options.server_selector,
heartbeat_frequency=options.heartbeat_frequency,
fqdn=fqdn,
direct_connection=options.direct_connection,
load_balanced=options.load_balanced,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=options.server_monitoring_mode,
super().__init__(
self._options.codec_options,
self._options.read_preference,
self._options.write_concern,
self._options.read_concern,
)
if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
self._init_background()
if not is_srv:
self._init_background()
if _IS_SYNC and connect:
self._get_topology() # type: ignore[unused-coroutine]
self._encrypter = None
async def _resolve_srv(self) -> None:
keyword_opts = self._resolve_srv_info["keyword_opts"]
seeds = set()
opts = common._CaseInsensitiveDictionary()
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
for entity in self._host:
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
# it must be a URI,
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
if "/" in entity:
# Determine connection timeout from kwargs.
timeout = keyword_opts.get("connecttimeoutms")
if timeout is not None:
timeout = common.validate_timeout_or_none_or_zero(
keyword_opts.cased_key("connecttimeoutms"), timeout
)
res = await uri_parser._parse_srv(
entity,
self._port,
validate=True,
warn=True,
normalize=False,
connect_timeout=timeout,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
)
seeds.update(res["nodelist"])
opts = res["options"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
raise ConfigurationError("need to specify at least one host")
for hostname in [node[0] for node in seeds]:
if _detect_external_db(hostname):
break
# Add options with named keyword arguments to the parsed kwarg options.
tz_aware = keyword_opts["tz_aware"]
connect = keyword_opts["connect"]
if tz_aware is None:
tz_aware = opts.get("tz_aware", False)
if connect is None:
# Default to connect=True unless on a FaaS system, which might use fork.
from pymongo.pool_options import _is_faas
connect = opts.get("connect", not _is_faas())
keyword_opts["tz_aware"] = tz_aware
keyword_opts["connect"] = connect
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
if srv_service_name is None:
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", self._resolve_srv_info["username"])
password = opts.get("password", self._resolve_srv_info["password"])
self._options = ClientOptions(
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
)
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
def _init_based_on_options(
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
) -> None:
self._event_listeners = self._options.pool_options._event_listeners
self._topology_settings = TopologySettings(
seeds=seeds,
replica_set_name=self._options.replica_set_name,
pool_class=self._resolve_srv_info["pool_class"],
pool_options=self._options.pool_options,
monitor_class=self._resolve_srv_info["monitor_class"],
condition_class=self._resolve_srv_info["condition_class"],
local_threshold_ms=self._options.local_threshold_ms,
server_selection_timeout=self._options.server_selection_timeout,
server_selector=self._options.server_selector,
heartbeat_frequency=self._options.heartbeat_frequency,
fqdn=self._resolve_srv_info["fqdn"],
direct_connection=self._options.direct_connection,
load_balanced=self._options.load_balanced,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
)
if self._options.auto_encryption_opts:
from pymongo.asynchronous.encryption import _Encrypter
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
self._timeout = self._options.timeout
if _HAS_REGISTER_AT_FORK:
# Add this client to the list of weakly referenced items.
# This will be used later if we fork.
AsyncMongoClient._clients[self._topology._topology_id] = self
def _normalize_and_validate_options(
self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]]
) -> common._CaseInsensitiveDictionary:
# Handle security-option conflicts in combined options.
opts = _handle_security_options(opts)
# Normalize combined options.
opts = _normalize_options(opts)
_check_options(seeds, opts)
return opts
def _validate_kwargs_and_update_opts(
self,
keyword_opts: common._CaseInsensitiveDictionary,
opts: common._CaseInsensitiveDictionary,
) -> common._CaseInsensitiveDictionary:
# Handle deprecated options in kwarg options.
keyword_opts = _handle_option_deprecations(keyword_opts)
# Validate kwarg options.
keyword_opts = common._CaseInsensitiveDictionary(
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
)
# Override connection string options with kwarg options.
opts.update(keyword_opts)
return opts
async def aconnect(self) -> None:
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
@ -901,6 +1003,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
def _init_background(self, old_pid: Optional[int] = None) -> None:
self._topology = Topology(self._topology_settings)
if _HAS_REGISTER_AT_FORK:
# Add this client to the list of weakly referenced items.
# This will be used later if we fork.
AsyncMongoClient._clients[self._topology._topology_id] = self
# Seed the topology with the old one's pid so we can detect clients
# that are opened before a fork and used after.
self._topology._pid = old_pid
@ -1115,16 +1221,24 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
"""
return self._options
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
return (
tuple(sorted(self._resolve_srv_info["seeds"])),
self._options.replica_set_name,
self._resolve_srv_info["fqdn"],
self._resolve_srv_info["srv_service_name"],
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self._topology == other._topology
return self.eq_props() == other.eq_props()
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __hash__(self) -> int:
return hash(self._topology)
return hash(self.eq_props())
def _repr_helper(self) -> str:
def option_repr(option: str, value: Any) -> str:
@ -1140,13 +1254,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
return f"{option}={value!r}"
# Host first...
options = [
"host=%r"
% [
"%s:%d" % (host, port) if port is not None else host
for host, port in self._topology_settings.seeds
if self._topology is None:
options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"]
else:
options = [
"host=%r"
% [
"%s:%d" % (host, port) if port is not None else host
for host, port in self._topology_settings.seeds
]
]
]
# ... then everything in self._constructor_args...
options.extend(
option_repr(key, self._options._options[key]) for key in self._constructor_args
@ -1552,6 +1669,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionchanged:: 3.6
End all server sessions created by this client.
"""
if self._topology is None:
return
session_ids = self._topology.pop_all_sessions()
if session_ids:
await self._end_sessions(session_ids)
@ -1582,6 +1701,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
launches the connection process in the background.
"""
if not self._opened:
if self._resolve_srv_info["is_srv"]:
await self._resolve_srv()
self._init_background()
await self._topology.open()
async with self._lock:
self._kill_cursors_executor.open()
@ -2511,6 +2633,7 @@ class _MongoClientErrorHandler:
self.completed_handshake,
self.service_id,
)
assert self.client._topology is not None
await self.client._topology.handle_error(self.server_address, err_ctx)
async def __aenter__(self) -> _MongoClientErrorHandler:

View File

@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Optional
from pymongo import common, periodic_executor
from pymongo._csot import MovingMinimum
from pymongo.asynchronous.srv_resolver import _SrvResolver
from pymongo.errors import NetworkTimeout, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _async_create_lock
@ -33,7 +34,6 @@ from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver
if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
@ -395,7 +395,7 @@ class SrvMonitor(MonitorBase):
# Don't poll right after creation, wait 60 seconds first
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
return
seedlist = self._get_seedlist()
seedlist = await self._get_seedlist()
if seedlist:
self._seedlist = seedlist
try:
@ -404,7 +404,7 @@ class SrvMonitor(MonitorBase):
# Topology was garbage-collected.
await self.close()
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
"""Poll SRV records for a seedlist.
Returns a list of ServerDescriptions.
@ -415,7 +415,7 @@ class SrvMonitor(MonitorBase):
self._settings.pool_options.connect_timeout,
self._settings.srv_service_name,
)
seedlist, ttl = resolver.get_hosts_and_min_ttl()
seedlist, ttl = await resolver.get_hosts_and_min_ttl()
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception

View File

@ -0,0 +1,160 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
from __future__ import annotations
import ipaddress
import random
from typing import TYPE_CHECKING, Any, Optional, Union
from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from dns import resolver
_IS_SYNC = False
def _have_dnspython() -> bool:
try:
import dns # noqa: F401
return True
except ImportError:
return False
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text: Union[str, bytes]) -> str:
if isinstance(text, bytes):
return text.decode()
return text
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
if _IS_SYNC:
from dns import resolver
if hasattr(resolver, "resolve"):
# dnspython >= 2
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)
else:
from dns import asyncresolver
if hasattr(asyncresolver, "resolve"):
# dnspython >= 2
return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
raise ConfigurationError(
"Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections."
)
_INVALID_HOST_MSG = (
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
"Did you mean to use 'mongodb://'?"
)
class _SrvResolver:
def __init__(
self,
fqdn: str,
connect_timeout: Optional[float],
srv_service_name: str,
srv_max_hosts: int = 0,
):
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
self.__srv_max_hosts = srv_max_hosts or 0
# Validate the fully qualified domain name.
try:
ipaddress.ip_address(fqdn)
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
except ValueError:
pass
try:
self.__plist = self.__fqdn.split(".")[1:]
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
async def get_options(self) -> Optional[str]:
from dns import resolver
try:
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
except (resolver.NoAnswer, resolver.NXDOMAIN):
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc)) from None
if len(results) > 1:
raise ConfigurationError("Only one TXT record is supported")
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined]
async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
try:
results = await _resolve(
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
)
except Exception as exc:
if not encapsulate_errors:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc)) from None
return results
async def _get_srv_response_and_hosts(
self, encapsulate_errors: bool
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
results = await self._resolve_uri(encapsulate_errors)
# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined]
for res in results
]
# Validate hosts
for node in nodes:
try:
nlist = node[0].lower().split(".")[1:][-self.__slen :]
except Exception:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes
async def get_hosts(self) -> list[tuple[str, Any]]:
_, nodes = await self._get_srv_response_and_hosts(True)
return nodes
async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
results, nodes = await self._get_srv_response_and_hosts(False)
rrset = results.rrset
ttl = rrset.ttl if rrset else 0
return nodes, ttl

View File

@ -0,0 +1,188 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse and validate a MongoDB URI."""
from __future__ import annotations
from typing import Any, Optional
from urllib.parse import unquote_plus
from pymongo.asynchronous.srv_resolver import _SrvResolver
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.uri_parser_shared import (
_ALLOWED_TXT_OPTS,
DEFAULT_PORT,
SCHEME,
SCHEME_LEN,
SRV_SCHEME_LEN,
_check_options,
_validate_uri,
split_hosts,
split_options,
)
_IS_SYNC = False
async def parse_uri(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
"""Parse and validate a MongoDB URI.
Returns a dict of the form::
{
'nodelist': <list of (host, port) tuples>,
'username': <username> or None,
'password': <password> or None,
'database': <database name> or None,
'collection': <collection name> or None,
'options': <dict of MongoDB URI options>,
'fqdn': <fqdn of the MongoDB+SRV URI> or None
}
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
to build nodelist and options.
:param uri: The MongoDB URI to parse.
:param default_port: The port number to use when one wasn't specified
for a host in the URI.
:param validate: If ``True`` (the default), validate and
normalize all options. Default: ``True``.
:param warn: When validating, if ``True`` then will warn
the user then ignore any invalid options or values. If ``False``,
validation will error when options are unsupported or values are
invalid. Default: ``False``.
:param normalize: If ``True``, convert names of URI options
to their internally-used names. Default: ``True``.
:param connect_timeout: The maximum time in milliseconds to
wait for a response from the DNS server.
:param srv_service_name: A custom SRV service name
.. versionchanged:: 4.6
The delimiting slash (``/``) between hosts and connection options is now optional.
For example, "mongodb://example.com?tls=true" is now a valid URI.
.. versionchanged:: 4.0
To better follow RFC 3986, unquoted percent signs ("%") are no longer
supported.
.. versionchanged:: 3.9
Added the ``normalize`` parameter.
.. versionchanged:: 3.6
Added support for mongodb+srv:// URIs.
.. versionchanged:: 3.5
Return the original value of the ``readPreference`` MongoDB URI option
instead of the validated read preference mode.
.. versionchanged:: 3.1
``warn`` added so invalid options can be ignored.
"""
result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
result.update(
await _parse_srv(
uri,
default_port,
validate,
warn,
normalize,
connect_timeout,
srv_service_name,
srv_max_hosts,
)
)
return result
async def _parse_srv(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
if uri.startswith(SCHEME):
is_srv = False
scheme_free = uri[SCHEME_LEN:]
else:
is_srv = True
scheme_free = uri[SRV_SCHEME_LEN:]
options = _CaseInsensitiveDictionary()
host_plus_db_part, _, opts = scheme_free.partition("?")
if "/" in host_plus_db_part:
host_part, _, _ = host_plus_db_part.partition("/")
else:
host_part = host_plus_db_part
if opts:
options.update(split_options(opts, validate, warn, normalize))
if srv_service_name is None:
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
if "@" in host_part:
_, _, hosts = host_part.rpartition("@")
else:
hosts = host_part
hosts = unquote_plus(hosts)
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
if is_srv:
nodes = split_hosts(hosts, default_port=None)
fqdn, port = nodes[0]
# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
nodes = await dns_resolver.get_hosts()
dns_options = await dns_resolver.get_options()
if dns_options:
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
raise ConfigurationError(
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
)
for opt, val in parsed_dns_options.items():
if opt not in options:
options[opt] = val
if options.get("loadBalanced") and srv_max_hosts:
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
if options.get("replicaSet") and srv_max_hosts:
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
if "tls" not in options and "ssl" not in options:
options["tls"] = True if validate else "true"
else:
nodes = split_hosts(hosts, default_port=default_port)
_check_options(nodes, options)
return {
"nodelist": nodes,
"options": options,
}

View File

@ -32,7 +32,7 @@ except ImportError:
from bson import int64
from pymongo.common import validate_is_mapping
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import _parse_kms_tls_options
from pymongo.uri_parser_shared import _parse_kms_tls_options
if TYPE_CHECKING:
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg

View File

@ -19,6 +19,7 @@
from __future__ import annotations
import warnings
from collections import abc
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence
@ -103,6 +104,11 @@ def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
if not isinstance(hedge, dict):
raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
warnings.warn(
"The read preference 'hedge' option is deprecated in PyMongo 4.12+ because hedged reads are deprecated in MongoDB version 8.0+. Support for 'hedge' will be removed in PyMongo 5.0.",
DeprecationWarning,
stacklevel=4,
)
return hedge
@ -183,7 +189,9 @@ class _ServerMode:
@property
def hedge(self) -> Optional[_Hedge]:
"""The read preference ``hedge`` parameter.
"""**DEPRECATED** - The read preference 'hedge' option is deprecated in PyMongo 4.12+ because hedged reads are deprecated in MongoDB version 8.0+. Support for 'hedge' will be removed in PyMongo 5.0.
The read preference ``hedge`` parameter.
A dictionary that configures how the server will perform hedged reads.
It consists of the following keys:
@ -203,6 +211,12 @@ class _ServerMode:
.. versionadded:: 3.11
"""
if self.__hedge is not None:
warnings.warn(
"The read preference 'hedge' option is deprecated in PyMongo 4.12+ because hedged reads are deprecated in MongoDB version 8.0+. Support for 'hedge' will be removed in PyMongo 5.0.",
DeprecationWarning,
stacklevel=2,
)
return self.__hedge
@property
@ -312,7 +326,7 @@ class PrimaryPreferred(_ServerMode):
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` to use if the primary is not available.
:param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
@ -354,7 +368,7 @@ class Secondary(_ServerMode):
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
:param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
@ -397,7 +411,7 @@ class SecondaryPreferred(_ServerMode):
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
:param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.
@ -441,7 +455,7 @@ class Nearest(_ServerMode):
replication before it will no longer be selected for operations.
Default -1, meaning no maximum. If it is set, it must be at least
90 seconds.
:param hedge: The :attr:`~hedge` for this read preference.
:param hedge: **DEPRECATED** - The :attr:`~hedge` for this read preference.
.. versionchanged:: 3.11
Added ``hedge`` parameter.

View File

@ -86,7 +86,7 @@ from pymongo.synchronous.pool import (
_raise_connection_failure,
)
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser import parse_host
from pymongo.uri_parser_shared import parse_host
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:

View File

@ -42,6 +42,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
ContextManager,
FrozenSet,
Generator,
@ -59,7 +60,7 @@ from typing import (
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
from bson.timestamp import Timestamp
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.client_options import ClientOptions
from pymongo.errors import (
AutoReconnect,
@ -96,7 +97,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database
from pymongo.synchronous import client_session, database, uri_parser
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
@ -112,11 +113,14 @@ from pymongo.typings import (
_DocumentTypeArg,
_Pipeline,
)
from pymongo.uri_parser import (
from pymongo.uri_parser_shared import (
SRV_SCHEME,
_check_options,
_handle_option_deprecations,
_handle_security_options,
_normalize_options,
_validate_uri,
split_hosts,
)
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
@ -130,6 +134,7 @@ if TYPE_CHECKING:
from pymongo.synchronous.bulk import _Bulk
from pymongo.synchronous.client_session import ClientSession, _ServerSession
from pymongo.synchronous.cursor import _ConnectionManager
from pymongo.synchronous.encryption import _Encrypter
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.server import Server
@ -748,6 +753,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
port = self.PORT
if not isinstance(port, int):
raise TypeError(f"port must be an instance of int, not {type(port)}")
self._host = host
self._port = port
self._topology: Topology = None # type: ignore[assignment]
# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
@ -758,8 +766,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# Parse options passed as kwargs.
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
seeds = set()
is_srv = False
username = None
password = None
dbase = None
@ -767,29 +777,22 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
fqdn = None
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
if len([h for h in host if "/" in h]) > 1:
if len([h for h in self._host if "/" in h]) > 1:
raise ConfigurationError("host must not contain multiple MongoDB URIs")
for entity in host:
for entity in self._host:
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
# it must be a URI,
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
if "/" in entity:
# Determine connection timeout from kwargs.
timeout = keyword_opts.get("connecttimeoutms")
if timeout is not None:
timeout = common.validate_timeout_or_none_or_zero(
keyword_opts.cased_key("connecttimeoutms"), timeout
)
res = uri_parser.parse_uri(
res = _validate_uri(
entity,
port,
validate=True,
warn=True,
normalize=False,
connect_timeout=timeout,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
@ -797,7 +800,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(uri_parser.split_hosts(entity, port))
seeds.update(split_hosts(entity, self._port))
if not seeds:
raise ConfigurationError("need to specify at least one host")
@ -818,80 +821,179 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
keyword_opts["tz_aware"] = tz_aware
keyword_opts["connect"] = connect
# Handle deprecated options in kwarg options.
keyword_opts = _handle_option_deprecations(keyword_opts)
# Validate kwarg options.
keyword_opts = common._CaseInsensitiveDictionary(
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
)
# Override connection string options with kwarg options.
opts.update(keyword_opts)
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
if srv_service_name is None:
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
# Handle security-option conflicts in combined options.
opts = _handle_security_options(opts)
# Normalize combined options.
opts = _normalize_options(opts)
_check_options(seeds, opts)
opts = self._normalize_and_validate_options(opts, seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
password = opts.get("password", password)
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
self._default_database_name = dbase
self._lock = _create_lock()
self._kill_cursors_queue: list = []
self._event_listeners = options.pool_options._event_listeners
super().__init__(
options.codec_options,
options.read_preference,
options.write_concern,
options.read_concern,
self._encrypter: Optional[_Encrypter] = None
self._resolve_srv_info.update(
{
"is_srv": is_srv,
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
"monitor_class": monitor_class,
"condition_class": condition_class,
}
)
self._topology_settings = TopologySettings(
seeds=seeds,
replica_set_name=options.replica_set_name,
pool_class=pool_class,
pool_options=options.pool_options,
monitor_class=monitor_class,
condition_class=condition_class,
local_threshold_ms=options.local_threshold_ms,
server_selection_timeout=options.server_selection_timeout,
server_selector=options.server_selector,
heartbeat_frequency=options.heartbeat_frequency,
fqdn=fqdn,
direct_connection=options.direct_connection,
load_balanced=options.load_balanced,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=options.server_monitoring_mode,
super().__init__(
self._options.codec_options,
self._options.read_preference,
self._options.write_concern,
self._options.read_concern,
)
if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._opened = False
self._closed = False
self._init_background()
if not is_srv:
self._init_background()
if _IS_SYNC and connect:
self._get_topology() # type: ignore[unused-coroutine]
self._encrypter = None
def _resolve_srv(self) -> None:
keyword_opts = self._resolve_srv_info["keyword_opts"]
seeds = set()
opts = common._CaseInsensitiveDictionary()
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
for entity in self._host:
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
# it must be a URI,
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
if "/" in entity:
# Determine connection timeout from kwargs.
timeout = keyword_opts.get("connecttimeoutms")
if timeout is not None:
timeout = common.validate_timeout_or_none_or_zero(
keyword_opts.cased_key("connecttimeoutms"), timeout
)
res = uri_parser._parse_srv(
entity,
self._port,
validate=True,
warn=True,
normalize=False,
connect_timeout=timeout,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
)
seeds.update(res["nodelist"])
opts = res["options"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
raise ConfigurationError("need to specify at least one host")
for hostname in [node[0] for node in seeds]:
if _detect_external_db(hostname):
break
# Add options with named keyword arguments to the parsed kwarg options.
tz_aware = keyword_opts["tz_aware"]
connect = keyword_opts["connect"]
if tz_aware is None:
tz_aware = opts.get("tz_aware", False)
if connect is None:
# Default to connect=True unless on a FaaS system, which might use fork.
from pymongo.pool_options import _is_faas
connect = opts.get("connect", not _is_faas())
keyword_opts["tz_aware"] = tz_aware
keyword_opts["connect"] = connect
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
if srv_service_name is None:
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
# Username and password passed as kwargs override user info in URI.
username = opts.get("username", self._resolve_srv_info["username"])
password = opts.get("password", self._resolve_srv_info["password"])
self._options = ClientOptions(
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
)
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
def _init_based_on_options(
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
) -> None:
self._event_listeners = self._options.pool_options._event_listeners
self._topology_settings = TopologySettings(
seeds=seeds,
replica_set_name=self._options.replica_set_name,
pool_class=self._resolve_srv_info["pool_class"],
pool_options=self._options.pool_options,
monitor_class=self._resolve_srv_info["monitor_class"],
condition_class=self._resolve_srv_info["condition_class"],
local_threshold_ms=self._options.local_threshold_ms,
server_selection_timeout=self._options.server_selection_timeout,
server_selector=self._options.server_selector,
heartbeat_frequency=self._options.heartbeat_frequency,
fqdn=self._resolve_srv_info["fqdn"],
direct_connection=self._options.direct_connection,
load_balanced=self._options.load_balanced,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
)
if self._options.auto_encryption_opts:
from pymongo.synchronous.encryption import _Encrypter
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
self._timeout = self._options.timeout
if _HAS_REGISTER_AT_FORK:
# Add this client to the list of weakly referenced items.
# This will be used later if we fork.
MongoClient._clients[self._topology._topology_id] = self
def _normalize_and_validate_options(
self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]]
) -> common._CaseInsensitiveDictionary:
# Handle security-option conflicts in combined options.
opts = _handle_security_options(opts)
# Normalize combined options.
opts = _normalize_options(opts)
_check_options(seeds, opts)
return opts
def _validate_kwargs_and_update_opts(
self,
keyword_opts: common._CaseInsensitiveDictionary,
opts: common._CaseInsensitiveDictionary,
) -> common._CaseInsensitiveDictionary:
# Handle deprecated options in kwarg options.
keyword_opts = _handle_option_deprecations(keyword_opts)
# Validate kwarg options.
keyword_opts = common._CaseInsensitiveDictionary(
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
)
# Override connection string options with kwarg options.
opts.update(keyword_opts)
return opts
def _connect(self) -> None:
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
@ -899,6 +1001,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _init_background(self, old_pid: Optional[int] = None) -> None:
self._topology = Topology(self._topology_settings)
if _HAS_REGISTER_AT_FORK:
# Add this client to the list of weakly referenced items.
# This will be used later if we fork.
MongoClient._clients[self._topology._topology_id] = self
# Seed the topology with the old one's pid so we can detect clients
# that are opened before a fork and used after.
self._topology._pid = old_pid
@ -1113,16 +1219,24 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""
return self._options
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
return (
tuple(sorted(self._resolve_srv_info["seeds"])),
self._options.replica_set_name,
self._resolve_srv_info["fqdn"],
self._resolve_srv_info["srv_service_name"],
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self._topology == other._topology
return self.eq_props() == other.eq_props()
return NotImplemented
def __ne__(self, other: Any) -> bool:
return not self == other
def __hash__(self) -> int:
return hash(self._topology)
return hash(self.eq_props())
def _repr_helper(self) -> str:
def option_repr(option: str, value: Any) -> str:
@ -1138,13 +1252,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
return f"{option}={value!r}"
# Host first...
options = [
"host=%r"
% [
"%s:%d" % (host, port) if port is not None else host
for host, port in self._topology_settings.seeds
if self._topology is None:
options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"]
else:
options = [
"host=%r"
% [
"%s:%d" % (host, port) if port is not None else host
for host, port in self._topology_settings.seeds
]
]
]
# ... then everything in self._constructor_args...
options.extend(
option_repr(key, self._options._options[key]) for key in self._constructor_args
@ -1546,6 +1663,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
.. versionchanged:: 3.6
End all server sessions created by this client.
"""
if self._topology is None:
return
session_ids = self._topology.pop_all_sessions()
if session_ids:
self._end_sessions(session_ids)
@ -1576,6 +1695,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
launches the connection process in the background.
"""
if not self._opened:
if self._resolve_srv_info["is_srv"]:
self._resolve_srv()
self._init_background()
self._topology.open()
with self._lock:
self._kill_cursors_executor.open()
@ -2497,6 +2619,7 @@ class _MongoClientErrorHandler:
self.completed_handshake,
self.service_id,
)
assert self.client._topology is not None
self.client._topology.handle_error(self.server_address, err_ctx)
def __enter__(self) -> _MongoClientErrorHandler:

View File

@ -33,7 +33,7 @@ from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.srv_resolver import _SrvResolver
from pymongo.synchronous.srv_resolver import _SrvResolver
if TYPE_CHECKING:
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext

View File

@ -25,6 +25,8 @@ from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from dns import resolver
_IS_SYNC = True
def _have_dnspython() -> bool:
try:
@ -45,13 +47,23 @@ def maybe_decode(text: Union[str, bytes]) -> str:
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
from dns import resolver
if _IS_SYNC:
from dns import resolver
if hasattr(resolver, "resolve"):
# dnspython >= 2
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)
if hasattr(resolver, "resolve"):
# dnspython >= 2
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)
else:
from dns import asyncresolver
if hasattr(asyncresolver, "resolve"):
# dnspython >= 2
return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
raise ConfigurationError(
"Upgrade to dnspython version >= 2.0 to use MongoClient with mongodb+srv:// connections."
)
_INVALID_HOST_MSG = (

View File

@ -0,0 +1,188 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse and validate a MongoDB URI."""
from __future__ import annotations
from typing import Any, Optional
from urllib.parse import unquote_plus
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.synchronous.srv_resolver import _SrvResolver
from pymongo.uri_parser_shared import (
_ALLOWED_TXT_OPTS,
DEFAULT_PORT,
SCHEME,
SCHEME_LEN,
SRV_SCHEME_LEN,
_check_options,
_validate_uri,
split_hosts,
split_options,
)
_IS_SYNC = True
def parse_uri(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
"""Parse and validate a MongoDB URI.
Returns a dict of the form::
{
'nodelist': <list of (host, port) tuples>,
'username': <username> or None,
'password': <password> or None,
'database': <database name> or None,
'collection': <collection name> or None,
'options': <dict of MongoDB URI options>,
'fqdn': <fqdn of the MongoDB+SRV URI> or None
}
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
to build nodelist and options.
:param uri: The MongoDB URI to parse.
:param default_port: The port number to use when one wasn't specified
for a host in the URI.
:param validate: If ``True`` (the default), validate and
normalize all options. Default: ``True``.
:param warn: When validating, if ``True`` then will warn
the user then ignore any invalid options or values. If ``False``,
validation will error when options are unsupported or values are
invalid. Default: ``False``.
:param normalize: If ``True``, convert names of URI options
to their internally-used names. Default: ``True``.
:param connect_timeout: The maximum time in milliseconds to
wait for a response from the DNS server.
:param srv_service_name: A custom SRV service name
.. versionchanged:: 4.6
The delimiting slash (``/``) between hosts and connection options is now optional.
For example, "mongodb://example.com?tls=true" is now a valid URI.
.. versionchanged:: 4.0
To better follow RFC 3986, unquoted percent signs ("%") are no longer
supported.
.. versionchanged:: 3.9
Added the ``normalize`` parameter.
.. versionchanged:: 3.6
Added support for mongodb+srv:// URIs.
.. versionchanged:: 3.5
Return the original value of the ``readPreference`` MongoDB URI option
instead of the validated read preference mode.
.. versionchanged:: 3.1
``warn`` added so invalid options can be ignored.
"""
result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
result.update(
_parse_srv(
uri,
default_port,
validate,
warn,
normalize,
connect_timeout,
srv_service_name,
srv_max_hosts,
)
)
return result
def _parse_srv(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
if uri.startswith(SCHEME):
is_srv = False
scheme_free = uri[SCHEME_LEN:]
else:
is_srv = True
scheme_free = uri[SRV_SCHEME_LEN:]
options = _CaseInsensitiveDictionary()
host_plus_db_part, _, opts = scheme_free.partition("?")
if "/" in host_plus_db_part:
host_part, _, _ = host_plus_db_part.partition("/")
else:
host_part = host_plus_db_part
if opts:
options.update(split_options(opts, validate, warn, normalize))
if srv_service_name is None:
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
if "@" in host_part:
_, _, hosts = host_part.rpartition("@")
else:
hosts = host_part
hosts = unquote_plus(hosts)
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
if is_srv:
nodes = split_hosts(hosts, default_port=None)
fqdn, port = nodes[0]
# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
nodes = dns_resolver.get_hosts()
dns_options = dns_resolver.get_options()
if dns_options:
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
raise ConfigurationError(
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
)
for opt, val in parsed_dns_options.items():
if opt not in options:
options[opt] = val
if options.get("loadBalanced") and srv_max_hosts:
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
if options.get("replicaSet") and srv_max_hosts:
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
if "tls" not in options and "ssl" not in options:
options["tls"] = True if validate else "true"
else:
nodes = split_hosts(hosts, default_port=default_port)
_check_options(nodes, options)
return {
"nodelist": nodes,
"options": options,
}

View File

@ -13,627 +13,32 @@
# permissions and limitations under the License.
"""Tools to parse and validate a MongoDB URI.
.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs.
"""
"""Re-import of synchronous URI Parser API for compatibility."""
from __future__ import annotations
import re
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
Mapping,
MutableMapping,
Optional,
Sized,
Union,
cast,
)
from urllib.parse import unquote_plus
from pymongo.client_options import _parse_ssl_options
from pymongo.common import (
INTERNAL_URI_OPTION_NAME_MAP,
SRV_SERVICE_NAME,
URI_OPTIONS_DEPRECATION_MAP,
_CaseInsensitiveDictionary,
get_validated_options,
)
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.srv_resolver import _have_dnspython, _SrvResolver
from pymongo.typings import _Address
if TYPE_CHECKING:
from pymongo.pyopenssl_context import SSLContext
SCHEME = "mongodb://"
SCHEME_LEN = len(SCHEME)
SRV_SCHEME = "mongodb+srv://"
SRV_SCHEME_LEN = len(SRV_SCHEME)
DEFAULT_PORT = 27017
def _unquoted_percent(s: str) -> bool:
"""Check for unescaped percent signs.
:param s: A string. `s` can have things like '%25', '%2525',
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
"""
for i in range(len(s)):
if s[i] == "%":
sub = s[i : i + 3]
# If unquoting yields the same string this means there was an
# unquoted %.
if unquote_plus(sub) == sub:
return True
return False
def parse_userinfo(userinfo: str) -> tuple[str, str]:
"""Validates the format of user information in a MongoDB URI.
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
"]", "@") as per RFC 3986 must be escaped.
Returns a 2-tuple containing the unescaped username followed
by the unescaped password.
:param userinfo: A string of the form <username>:<password>
"""
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
raise InvalidURI(
"Username and password must be escaped according to "
"RFC 3986, use urllib.parse.quote_plus"
)
user, _, passwd = userinfo.partition(":")
# No password is expected with GSSAPI authentication.
if not user:
raise InvalidURI("The empty string is not valid username")
return unquote_plus(user), unquote_plus(passwd)
def parse_ipv6_literal_host(
entity: str, default_port: Optional[int]
) -> tuple[str, Optional[Union[str, int]]]:
"""Validates an IPv6 literal host:port string.
Returns a 2-tuple of IPv6 literal followed by port where
port is default_port if it wasn't specified in entity.
:param entity: A string that represents an IPv6 literal enclosed
in braces (e.g. '[::1]' or '[::1]:27017').
:param default_port: The port number to use when one wasn't
specified in entity.
"""
if entity.find("]") == -1:
raise ValueError(
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
)
i = entity.find("]:")
if i == -1:
return entity[1:-1], default_port
return entity[1:i], entity[i + 2 :]
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
"""Validates a host string
Returns a 2-tuple of host followed by port where port is default_port
if it wasn't specified in the string.
:param entity: A host or host:port string where host could be a
hostname or IP address.
:param default_port: The port number to use when one wasn't
specified in entity.
"""
host = entity
port: Optional[Union[str, int]] = default_port
if entity[0] == "[":
host, port = parse_ipv6_literal_host(entity, default_port)
elif entity.endswith(".sock"):
return entity, default_port
elif entity.find(":") != -1:
if entity.count(":") > 1:
raise ValueError(
"Reserved characters such as ':' must be "
"escaped according RFC 2396. An IPv6 "
"address literal must be enclosed in '[' "
"and ']' according to RFC 2732."
)
host, port = host.split(":", 1)
if isinstance(port, str):
if not port.isdigit():
# Special case check for mistakes like "mongodb://localhost:27017 ".
if all(c.isspace() or c.isdigit() for c in port):
for c in port:
if c.isspace():
raise ValueError(f"Port contains whitespace character: {c!r}")
# A non-digit port indicates that the URI is invalid, likely because the password
# or username were not escaped.
raise ValueError(
"Port contains non-digit characters. Hint: username and password must be escaped according to "
"RFC 3986, use urllib.parse.quote_plus"
)
if int(port) > 65535 or int(port) <= 0:
raise ValueError("Port must be an integer between 0 and 65535")
port = int(port)
# Normalize hostname to lowercase, since DNS is case-insensitive:
# https://tools.ietf.org/html/rfc4343
# This prevents useless rediscovery if "foo.com" is in the seed list but
# "FOO.com" is in the hello response.
return host.lower(), port
# Options whose values are implicitly determined by tlsInsecure.
_IMPLICIT_TLSINSECURE_OPTS = {
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
}
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
"""Helper method for split_options which creates the options dict.
Also handles the creation of a list for the URI tag_sets/
readpreferencetags portion, and the use of a unicode options string.
"""
options = _CaseInsensitiveDictionary()
for uriopt in opts.split(delim):
key, value = uriopt.split("=")
if key.lower() == "readpreferencetags":
options.setdefault(key, []).append(value)
else:
if key in options:
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
if key.lower() == "authmechanismproperties":
val = value
else:
val = unquote_plus(value)
options[key] = val
return options
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Raise appropriate errors when conflicting TLS options are present in
the options dictionary.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Implicitly defined options must not be explicitly specified.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
if opt in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
)
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
if tlsallowinvalidcerts is not None:
if "tlsdisableocspendpointcheck" in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
)
if tlsallowinvalidcerts is True:
options["tlsdisableocspendpointcheck"] = True
# Handle co-occurence of CRL and OCSP-related options.
tlscrlfile = options.get("tlscrlfile")
if tlscrlfile is not None:
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
if options.get(opt) is True:
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
raise InvalidURI(err_msg % (opt,))
if "ssl" in options and "tls" in options:
def truth_value(val: Any) -> Any:
if val in ("true", "false"):
return val == "true"
if isinstance(val, bool):
return val
return val
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
err_msg = "Can not specify conflicting values for URI options %s and %s."
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
return options
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Issue appropriate warnings when deprecated options are present in the
options dictionary. Removes deprecated option key, value pairs if the
options dictionary is found to also have the renamed option.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
for optname in list(options):
if optname in URI_OPTIONS_DEPRECATION_MAP:
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
if mode == "renamed":
newoptname = message
if newoptname in options:
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
warnings.warn(
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
DeprecationWarning,
stacklevel=2,
)
options.pop(optname)
continue
warn_msg = "Option '%s' is deprecated, use '%s' instead."
warnings.warn(
warn_msg % (options.cased_key(optname), newoptname),
DeprecationWarning,
stacklevel=2,
)
elif mode == "removed":
warn_msg = "Option '%s' is deprecated. %s."
warnings.warn(
warn_msg % (options.cased_key(optname), message),
DeprecationWarning,
stacklevel=2,
)
return options
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Normalizes option names in the options dictionary by converting them to
their internally-used names.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Expand the tlsInsecure option.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
# Implicit options are logically the same as tlsInsecure.
options[opt] = tlsinsecure
for optname in list(options):
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
if intname is not None:
options[intname] = options.pop(optname)
return options
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
"""Validates and normalizes options passed in a MongoDB URI.
Returns a new dictionary of validated and normalized options. If warn is
False then errors will be thrown for invalid options, otherwise they will
be ignored and a warning will be issued.
:param opts: A dict of MongoDB URI options.
:param warn: If ``True`` then warnings will be logged and
invalid options will be ignored. Otherwise invalid options will
cause errors.
"""
return get_validated_options(opts, warn)
def split_options(
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
) -> MutableMapping[str, Any]:
"""Takes the options portion of a MongoDB URI, validates each option
and returns the options in a dictionary.
:param opt: A string representing MongoDB URI options.
:param validate: If ``True`` (the default), validate and normalize all
options.
:param warn: If ``False`` (the default), suppress all warnings raised
during validation of options.
:param normalize: If ``True`` (the default), renames all options to their
internally-used names.
"""
and_idx = opts.find("&")
semi_idx = opts.find(";")
try:
if and_idx >= 0 and semi_idx >= 0:
raise InvalidURI("Can not mix '&' and ';' for option separators")
elif and_idx >= 0:
options = _parse_options(opts, "&")
elif semi_idx >= 0:
options = _parse_options(opts, ";")
elif opts.find("=") != -1:
options = _parse_options(opts, None)
else:
raise ValueError
except ValueError:
raise InvalidURI("MongoDB URI options are key=value pairs") from None
options = _handle_security_options(options)
options = _handle_option_deprecations(options)
if normalize:
options = _normalize_options(options)
if validate:
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
if options.get("authsource") == "":
raise InvalidURI("the authSource database cannot be an empty string")
return options
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
"""Takes a string of the form host1[:port],host2[:port]... and
splits it into (host, port) tuples. If [:port] isn't present the
default_port is used.
Returns a set of 2-tuples containing the host name (or IP) followed by
port number.
:param hosts: A string of the form host1[:port],host2[:port],...
:param default_port: The port number to use when one wasn't specified
for a host.
"""
nodes = []
for entity in hosts.split(","):
if not entity:
raise ConfigurationError("Empty host (or extra comma in host list)")
port = default_port
# Unix socket entities don't have ports
if entity.endswith(".sock"):
port = None
nodes.append(parse_host(entity, port))
return nodes
# Prohibited characters in database name. DB names also can't have ".", but for
# backward-compat we allow "db.collection" in URI.
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
_ALLOWED_TXT_OPTS = frozenset(
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
)
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
# Ensure directConnection was not True if there are multiple seeds.
if len(nodes) > 1 and options.get("directconnection"):
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
if options.get("loadbalanced"):
if len(nodes) > 1:
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
if options.get("directconnection"):
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
if options.get("replicaset"):
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
def parse_uri(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
"""Parse and validate a MongoDB URI.
Returns a dict of the form::
{
'nodelist': <list of (host, port) tuples>,
'username': <username> or None,
'password': <password> or None,
'database': <database name> or None,
'collection': <collection name> or None,
'options': <dict of MongoDB URI options>,
'fqdn': <fqdn of the MongoDB+SRV URI> or None
}
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
to build nodelist and options.
:param uri: The MongoDB URI to parse.
:param default_port: The port number to use when one wasn't specified
for a host in the URI.
:param validate: If ``True`` (the default), validate and
normalize all options. Default: ``True``.
:param warn: When validating, if ``True`` then will warn
the user then ignore any invalid options or values. If ``False``,
validation will error when options are unsupported or values are
invalid. Default: ``False``.
:param normalize: If ``True``, convert names of URI options
to their internally-used names. Default: ``True``.
:param connect_timeout: The maximum time in milliseconds to
wait for a response from the DNS server.
:param srv_service_name: A custom SRV service name
.. versionchanged:: 4.6
The delimiting slash (``/``) between hosts and connection options is now optional.
For example, "mongodb://example.com?tls=true" is now a valid URI.
.. versionchanged:: 4.0
To better follow RFC 3986, unquoted percent signs ("%") are no longer
supported.
.. versionchanged:: 3.9
Added the ``normalize`` parameter.
.. versionchanged:: 3.6
Added support for mongodb+srv:// URIs.
.. versionchanged:: 3.5
Return the original value of the ``readPreference`` MongoDB URI option
instead of the validated read preference mode.
.. versionchanged:: 3.1
``warn`` added so invalid options can be ignored.
"""
if uri.startswith(SCHEME):
is_srv = False
scheme_free = uri[SCHEME_LEN:]
elif uri.startswith(SRV_SCHEME):
if not _have_dnspython():
python_path = sys.executable or "python"
raise ConfigurationError(
'The "dnspython" module must be '
"installed to use mongodb+srv:// URIs. "
"To fix this error install pymongo again:\n "
"%s -m pip install pymongo>=4.3" % (python_path)
)
is_srv = True
scheme_free = uri[SRV_SCHEME_LEN:]
else:
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
if not scheme_free:
raise InvalidURI("Must provide at least one hostname or IP")
user = None
passwd = None
dbase = None
collection = None
options = _CaseInsensitiveDictionary()
host_plus_db_part, _, opts = scheme_free.partition("?")
if "/" in host_plus_db_part:
host_part, _, dbase = host_plus_db_part.partition("/")
else:
host_part = host_plus_db_part
if dbase:
dbase = unquote_plus(dbase)
if "." in dbase:
dbase, collection = dbase.split(".", 1)
if _BAD_DB_CHARS.search(dbase):
raise InvalidURI('Bad database name "%s"' % dbase)
else:
dbase = None
if opts:
options.update(split_options(opts, validate, warn, normalize))
if srv_service_name is None:
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
if "@" in host_part:
userinfo, _, hosts = host_part.rpartition("@")
user, passwd = parse_userinfo(userinfo)
else:
hosts = host_part
if "/" in hosts:
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
hosts = unquote_plus(hosts)
fqdn = None
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
if is_srv:
if options.get("directConnection"):
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
nodes = split_hosts(hosts, default_port=None)
if len(nodes) != 1:
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
fqdn, port = nodes[0]
if port is not None:
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
nodes = dns_resolver.get_hosts()
dns_options = dns_resolver.get_options()
if dns_options:
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
raise ConfigurationError(
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
)
for opt, val in parsed_dns_options.items():
if opt not in options:
options[opt] = val
if options.get("loadBalanced") and srv_max_hosts:
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
if options.get("replicaSet") and srv_max_hosts:
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
if "tls" not in options and "ssl" not in options:
options["tls"] = True if validate else "true"
elif not is_srv and options.get("srvServiceName") is not None:
raise ConfigurationError(
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
)
elif not is_srv and srv_max_hosts:
raise ConfigurationError(
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
)
else:
nodes = split_hosts(hosts, default_port=default_port)
_check_options(nodes, options)
return {
"nodelist": nodes,
"username": user,
"password": passwd,
"database": dbase,
"collection": collection,
"options": options,
"fqdn": fqdn,
}
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
"""Parse KMS TLS connection options."""
if not kms_tls_options:
return {}
if not isinstance(kms_tls_options, dict):
raise TypeError("kms_tls_options must be a dict")
contexts = {}
for provider, options in kms_tls_options.items():
if not isinstance(options, dict):
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
options.setdefault("tls", True)
opts = _CaseInsensitiveDictionary(options)
opts = _handle_security_options(opts)
opts = _normalize_options(opts)
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
if ssl_context is None:
raise ConfigurationError("TLS is required for KMS providers")
if allow_invalid_hostnames:
raise ConfigurationError("Insecure TLS options prohibited")
for n in [
"tlsInsecure",
"tlsAllowInvalidCertificates",
"tlsAllowInvalidHostnames",
"tlsDisableCertificateRevocationCheck",
]:
if n in opts:
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
contexts[provider] = ssl_context
return contexts
from pymongo.errors import InvalidURI
from pymongo.synchronous.uri_parser import * # noqa: F403
from pymongo.synchronous.uri_parser import __doc__ as original_doc
from pymongo.uri_parser_shared import * # noqa: F403
__doc__ = original_doc
__all__ = [ # noqa: F405
"parse_userinfo",
"parse_ipv6_literal_host",
"parse_host",
"validate_options",
"split_options",
"split_hosts",
"parse_uri",
]
if __name__ == "__main__":
import pprint
try:
pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203
pprint.pprint(parse_uri(sys.argv[1])) # noqa: F405, T203
except InvalidURI as exc:
print(exc) # noqa: T201
sys.exit(0)

View File

@ -0,0 +1,549 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Tools to parse and validate a MongoDB URI.
.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs.
"""
from __future__ import annotations
import re
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
Mapping,
MutableMapping,
Optional,
Sized,
Union,
cast,
)
from urllib.parse import unquote_plus
from pymongo.asynchronous.srv_resolver import _have_dnspython
from pymongo.client_options import _parse_ssl_options
from pymongo.common import (
INTERNAL_URI_OPTION_NAME_MAP,
URI_OPTIONS_DEPRECATION_MAP,
_CaseInsensitiveDictionary,
get_validated_options,
)
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.typings import _Address
if TYPE_CHECKING:
from pymongo.pyopenssl_context import SSLContext
SCHEME = "mongodb://"
SCHEME_LEN = len(SCHEME)
SRV_SCHEME = "mongodb+srv://"
SRV_SCHEME_LEN = len(SRV_SCHEME)
DEFAULT_PORT = 27017
def _unquoted_percent(s: str) -> bool:
"""Check for unescaped percent signs.
:param s: A string. `s` can have things like '%25', '%2525',
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
"""
for i in range(len(s)):
if s[i] == "%":
sub = s[i : i + 3]
# If unquoting yields the same string this means there was an
# unquoted %.
if unquote_plus(sub) == sub:
return True
return False
def parse_userinfo(userinfo: str) -> tuple[str, str]:
"""Validates the format of user information in a MongoDB URI.
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
"]", "@") as per RFC 3986 must be escaped.
Returns a 2-tuple containing the unescaped username followed
by the unescaped password.
:param userinfo: A string of the form <username>:<password>
"""
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
raise InvalidURI(
"Username and password must be escaped according to "
"RFC 3986, use urllib.parse.quote_plus"
)
user, _, passwd = userinfo.partition(":")
# No password is expected with GSSAPI authentication.
if not user:
raise InvalidURI("The empty string is not valid username")
return unquote_plus(user), unquote_plus(passwd)
def parse_ipv6_literal_host(
entity: str, default_port: Optional[int]
) -> tuple[str, Optional[Union[str, int]]]:
"""Validates an IPv6 literal host:port string.
Returns a 2-tuple of IPv6 literal followed by port where
port is default_port if it wasn't specified in entity.
:param entity: A string that represents an IPv6 literal enclosed
in braces (e.g. '[::1]' or '[::1]:27017').
:param default_port: The port number to use when one wasn't
specified in entity.
"""
if entity.find("]") == -1:
raise ValueError(
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
)
i = entity.find("]:")
if i == -1:
return entity[1:-1], default_port
return entity[1:i], entity[i + 2 :]
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
"""Validates a host string
Returns a 2-tuple of host followed by port where port is default_port
if it wasn't specified in the string.
:param entity: A host or host:port string where host could be a
hostname or IP address.
:param default_port: The port number to use when one wasn't
specified in entity.
"""
host = entity
port: Optional[Union[str, int]] = default_port
if entity[0] == "[":
host, port = parse_ipv6_literal_host(entity, default_port)
elif entity.endswith(".sock"):
return entity, default_port
elif entity.find(":") != -1:
if entity.count(":") > 1:
raise ValueError(
"Reserved characters such as ':' must be "
"escaped according RFC 2396. An IPv6 "
"address literal must be enclosed in '[' "
"and ']' according to RFC 2732."
)
host, port = host.split(":", 1)
if isinstance(port, str):
if not port.isdigit():
# Special case check for mistakes like "mongodb://localhost:27017 ".
if all(c.isspace() or c.isdigit() for c in port):
for c in port:
if c.isspace():
raise ValueError(f"Port contains whitespace character: {c!r}")
# A non-digit port indicates that the URI is invalid, likely because the password
# or username were not escaped.
raise ValueError(
"Port contains non-digit characters. Hint: username and password must be escaped according to "
"RFC 3986, use urllib.parse.quote_plus"
)
if int(port) > 65535 or int(port) <= 0:
raise ValueError("Port must be an integer between 0 and 65535")
port = int(port)
# Normalize hostname to lowercase, since DNS is case-insensitive:
# https://tools.ietf.org/html/rfc4343
# This prevents useless rediscovery if "foo.com" is in the seed list but
# "FOO.com" is in the hello response.
return host.lower(), port
# Options whose values are implicitly determined by tlsInsecure.
_IMPLICIT_TLSINSECURE_OPTS = {
"tlsallowinvalidcertificates",
"tlsallowinvalidhostnames",
"tlsdisableocspendpointcheck",
}
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
"""Helper method for split_options which creates the options dict.
Also handles the creation of a list for the URI tag_sets/
readpreferencetags portion, and the use of a unicode options string.
"""
options = _CaseInsensitiveDictionary()
for uriopt in opts.split(delim):
key, value = uriopt.split("=")
if key.lower() == "readpreferencetags":
options.setdefault(key, []).append(value)
else:
if key in options:
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
if key.lower() == "authmechanismproperties":
val = value
else:
val = unquote_plus(value)
options[key] = val
return options
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Raise appropriate errors when conflicting TLS options are present in
the options dictionary.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Implicitly defined options must not be explicitly specified.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
if opt in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
)
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
if tlsallowinvalidcerts is not None:
if "tlsdisableocspendpointcheck" in options:
err_msg = "URI options %s and %s cannot be specified simultaneously."
raise InvalidURI(
err_msg
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
)
if tlsallowinvalidcerts is True:
options["tlsdisableocspendpointcheck"] = True
# Handle co-occurence of CRL and OCSP-related options.
tlscrlfile = options.get("tlscrlfile")
if tlscrlfile is not None:
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
if options.get(opt) is True:
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
raise InvalidURI(err_msg % (opt,))
if "ssl" in options and "tls" in options:
def truth_value(val: Any) -> Any:
if val in ("true", "false"):
return val == "true"
if isinstance(val, bool):
return val
return val
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
err_msg = "Can not specify conflicting values for URI options %s and %s."
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
return options
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Issue appropriate warnings when deprecated options are present in the
options dictionary. Removes deprecated option key, value pairs if the
options dictionary is found to also have the renamed option.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
for optname in list(options):
if optname in URI_OPTIONS_DEPRECATION_MAP:
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
if mode == "renamed":
newoptname = message
if newoptname in options:
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
warnings.warn(
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
DeprecationWarning,
stacklevel=2,
)
options.pop(optname)
continue
warn_msg = "Option '%s' is deprecated, use '%s' instead."
warnings.warn(
warn_msg % (options.cased_key(optname), newoptname),
DeprecationWarning,
stacklevel=2,
)
elif mode == "removed":
warn_msg = "Option '%s' is deprecated. %s."
warnings.warn(
warn_msg % (options.cased_key(optname), message),
DeprecationWarning,
stacklevel=2,
)
return options
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
"""Normalizes option names in the options dictionary by converting them to
their internally-used names.
:param options: Instance of _CaseInsensitiveDictionary containing
MongoDB URI options.
"""
# Expand the tlsInsecure option.
tlsinsecure = options.get("tlsinsecure")
if tlsinsecure is not None:
for opt in _IMPLICIT_TLSINSECURE_OPTS:
# Implicit options are logically the same as tlsInsecure.
options[opt] = tlsinsecure
for optname in list(options):
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
if intname is not None:
options[intname] = options.pop(optname)
return options
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
"""Validates and normalizes options passed in a MongoDB URI.
Returns a new dictionary of validated and normalized options. If warn is
False then errors will be thrown for invalid options, otherwise they will
be ignored and a warning will be issued.
:param opts: A dict of MongoDB URI options.
:param warn: If ``True`` then warnings will be logged and
invalid options will be ignored. Otherwise invalid options will
cause errors.
"""
return get_validated_options(opts, warn)
def split_options(
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
) -> MutableMapping[str, Any]:
"""Takes the options portion of a MongoDB URI, validates each option
and returns the options in a dictionary.
:param opt: A string representing MongoDB URI options.
:param validate: If ``True`` (the default), validate and normalize all
options.
:param warn: If ``False`` (the default), suppress all warnings raised
during validation of options.
:param normalize: If ``True`` (the default), renames all options to their
internally-used names.
"""
and_idx = opts.find("&")
semi_idx = opts.find(";")
try:
if and_idx >= 0 and semi_idx >= 0:
raise InvalidURI("Can not mix '&' and ';' for option separators")
elif and_idx >= 0:
options = _parse_options(opts, "&")
elif semi_idx >= 0:
options = _parse_options(opts, ";")
elif opts.find("=") != -1:
options = _parse_options(opts, None)
else:
raise ValueError
except ValueError:
raise InvalidURI("MongoDB URI options are key=value pairs") from None
options = _handle_security_options(options)
options = _handle_option_deprecations(options)
if normalize:
options = _normalize_options(options)
if validate:
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
if options.get("authsource") == "":
raise InvalidURI("the authSource database cannot be an empty string")
return options
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
"""Takes a string of the form host1[:port],host2[:port]... and
splits it into (host, port) tuples. If [:port] isn't present the
default_port is used.
Returns a set of 2-tuples containing the host name (or IP) followed by
port number.
:param hosts: A string of the form host1[:port],host2[:port],...
:param default_port: The port number to use when one wasn't specified
for a host.
"""
nodes = []
for entity in hosts.split(","):
if not entity:
raise ConfigurationError("Empty host (or extra comma in host list)")
port = default_port
# Unix socket entities don't have ports
if entity.endswith(".sock"):
port = None
nodes.append(parse_host(entity, port))
return nodes
# Prohibited characters in database name. DB names also can't have ".", but for
# backward-compat we allow "db.collection" in URI.
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
_ALLOWED_TXT_OPTS = frozenset(
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
)
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
# Ensure directConnection was not True if there are multiple seeds.
if len(nodes) > 1 and options.get("directconnection"):
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
if options.get("loadbalanced"):
if len(nodes) > 1:
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
if options.get("directconnection"):
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
if options.get("replicaset"):
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
"""Parse KMS TLS connection options."""
if not kms_tls_options:
return {}
if not isinstance(kms_tls_options, dict):
raise TypeError("kms_tls_options must be a dict")
contexts = {}
for provider, options in kms_tls_options.items():
if not isinstance(options, dict):
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
options.setdefault("tls", True)
opts = _CaseInsensitiveDictionary(options)
opts = _handle_security_options(opts)
opts = _normalize_options(opts)
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
if ssl_context is None:
raise ConfigurationError("TLS is required for KMS providers")
if allow_invalid_hostnames:
raise ConfigurationError("Insecure TLS options prohibited")
for n in [
"tlsInsecure",
"tlsAllowInvalidCertificates",
"tlsAllowInvalidHostnames",
"tlsDisableCertificateRevocationCheck",
]:
if n in opts:
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
contexts[provider] = ssl_context
return contexts
def _validate_uri(
uri: str,
default_port: Optional[int] = DEFAULT_PORT,
validate: bool = True,
warn: bool = False,
normalize: bool = True,
srv_max_hosts: Optional[int] = None,
) -> dict[str, Any]:
if uri.startswith(SCHEME):
is_srv = False
scheme_free = uri[SCHEME_LEN:]
elif uri.startswith(SRV_SCHEME):
if not _have_dnspython():
python_path = sys.executable or "python"
raise ConfigurationError(
'The "dnspython" module must be '
"installed to use mongodb+srv:// URIs. "
"To fix this error install pymongo again:\n "
"%s -m pip install pymongo>=4.3" % (python_path)
)
is_srv = True
scheme_free = uri[SRV_SCHEME_LEN:]
else:
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
if not scheme_free:
raise InvalidURI("Must provide at least one hostname or IP")
user = None
passwd = None
dbase = None
collection = None
options = _CaseInsensitiveDictionary()
host_plus_db_part, _, opts = scheme_free.partition("?")
if "/" in host_plus_db_part:
host_part, _, dbase = host_plus_db_part.partition("/")
else:
host_part = host_plus_db_part
if dbase:
dbase = unquote_plus(dbase)
if "." in dbase:
dbase, collection = dbase.split(".", 1)
if _BAD_DB_CHARS.search(dbase):
raise InvalidURI('Bad database name "%s"' % dbase)
else:
dbase = None
if opts:
options.update(split_options(opts, validate, warn, normalize))
if "@" in host_part:
userinfo, _, hosts = host_part.rpartition("@")
user, passwd = parse_userinfo(userinfo)
else:
hosts = host_part
if "/" in hosts:
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
hosts = unquote_plus(hosts)
fqdn = None
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
if is_srv:
if options.get("directConnection"):
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
nodes = split_hosts(hosts, default_port=None)
if len(nodes) != 1:
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
fqdn, port = nodes[0]
if port is not None:
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
elif not is_srv and options.get("srvServiceName") is not None:
raise ConfigurationError(
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
)
elif not is_srv and srv_max_hosts:
raise ConfigurationError(
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
)
else:
nodes = split_hosts(hosts, default_port=default_port)
_check_options(nodes, options)
return {
"nodelist": nodes,
"username": user,
"password": passwd,
"database": dbase,
"collection": collection,
"options": options,
"fqdn": fqdn,
}

View File

@ -32,7 +32,7 @@ import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo.uri_parser import parse_uri
from pymongo.synchronous.uri_parser import parse_uri
try:
import ipaddress

View File

@ -32,7 +32,7 @@ import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo.uri_parser import parse_uri
from pymongo.asynchronous.uri_parser import parse_uri
try:
import ipaddress
@ -1027,7 +1027,7 @@ class AsyncPyMongoTestCase(unittest.TestCase):
auth_mech = kwargs.get("authMechanism", "")
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
res = await parse_uri(uri)
if (
not res["username"]
and not res["password"]
@ -1058,7 +1058,7 @@ class AsyncPyMongoTestCase(unittest.TestCase):
auth_mech = kwargs.get("authMechanism", "")
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
# Only add the default username or password if one is not provided.
res = parse_uri(uri)
res = await parse_uri(uri)
if (
not res["username"]
and not res["password"]

View File

@ -47,7 +47,7 @@ from bson.son import SON
from pymongo import common, message
from pymongo.read_preferences import ReadPreference
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.uri_parser import parse_uri
from pymongo.synchronous.uri_parser import parse_uri
if HAVE_SSL:
import ssl

View File

@ -512,13 +512,13 @@ class AsyncClientUnitTest(AsyncUnitTest):
async def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
# Patch the resolver.
from pymongo.srv_resolver import _resolve
from pymongo.asynchronous.srv_resolver import _resolve
patched_resolver = FunctionCallRecorder(_resolve)
pymongo.srv_resolver._resolve = patched_resolver
pymongo.asynchronous.srv_resolver._resolve = patched_resolver
def reset_resolver():
pymongo.srv_resolver._resolve = _resolve
pymongo.asynchronous.srv_resolver._resolve = _resolve
self.addCleanup(reset_resolver)
@ -607,7 +607,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
with self.assertRaisesRegex(ConfigurationError, expected):
AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type]
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
@patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts")
def test_detected_environment_logging(self, mock_get_hosts):
normal_hosts = [
"normal.host.com",
@ -629,7 +629,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
self.assertEqual(len(logs), 7)
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
@patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts")
async def test_detected_environment_warning(self, mock_get_hosts):
with self._caplog.at_level(logging.WARN):
normal_hosts = [
@ -933,6 +933,15 @@ class TestClient(AsyncIntegrationTest):
async with eval(the_repr) as client_two:
self.assertEqual(client_two, client)
async def test_repr_srv_host(self):
client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
# before srv resolution
self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client))
await client.aconnect()
# after srv resolution
self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client))
await client.close()
async def test_getters(self):
await async_wait_until(
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
@ -1911,28 +1920,37 @@ class TestClient(AsyncIntegrationTest):
srvServiceName="customname",
connect=False,
)
await client.aconnect()
self.assertEqual(client._topology_settings.srv_service_name, "customname")
await client.close()
client = AsyncMongoClient(
"mongodb+srv://user:password@test22.test.build.10gen.cc"
"/?srvServiceName=shouldbeoverriden",
srvServiceName="customname",
connect=False,
)
await client.aconnect()
self.assertEqual(client._topology_settings.srv_service_name, "customname")
await client.close()
client = AsyncMongoClient(
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
connect=False,
)
await client.aconnect()
self.assertEqual(client._topology_settings.srv_service_name, "customname")
await client.close()
async def test_srv_max_hosts_kwarg(self):
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
await client.aconnect()
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
await client.aconnect()
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
client = self.simple_client(
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
)
await client.aconnect()
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
@unittest.skipIf(

View File

@ -54,6 +54,7 @@ from bson import Timestamp, json_util
from pymongo import common, monitoring
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.asynchronous.uri_parser import parse_uri
from pymongo.errors import (
AutoReconnect,
ConfigurationError,
@ -66,7 +67,6 @@ from pymongo.helpers_shared import _check_command_response, _check_write_command
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.uri_parser import parse_uri
_IS_SYNC = False
@ -81,7 +81,7 @@ else:
async def create_mock_topology(uri, monitor_class=DummyMonitor):
parsed_uri = parse_uri(uri)
parsed_uri = await parse_uri(uri)
replica_set_name = None
direct_connection = None
load_balanced = None

View File

@ -31,9 +31,10 @@ from test.asynchronous import (
)
from test.utils_shared import async_wait_until
from pymongo.asynchronous.uri_parser import parse_uri
from pymongo.common import validate_read_preference_tags
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import parse_uri, split_hosts
from pymongo.uri_parser_shared import split_hosts
_IS_SYNC = False
@ -109,7 +110,7 @@ def create_test(test_case):
hosts = frozenset(split_hosts(",".join(hosts)))
if seeds or num_seeds:
result = parse_uri(uri, validate=True)
result = await parse_uri(uri, validate=True)
if seeds is not None:
self.assertEqual(sorted(result["nodelist"]), sorted(seeds))
if num_seeds is not None:
@ -161,7 +162,7 @@ def create_test(test_case):
# and re-run these assertions.
else:
try:
parse_uri(uri)
await parse_uri(uri)
except (ConfigurationError, ValueError):
pass
else:
@ -185,35 +186,24 @@ create_tests(TestDNSSharded)
class TestParsingErrors(AsyncPyMongoTestCase):
async def test_invalid_host(self):
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb is not",
self.simple_client,
"mongodb+srv://mongodb",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb.com is not",
self.simple_client,
"mongodb+srv://mongodb.com",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
self.simple_client,
"mongodb+srv://127.0.0.1",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
self.simple_client,
"mongodb+srv://[::1]",
)
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
client = self.simple_client("mongodb+srv://mongodb")
await client.aconnect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
client = self.simple_client("mongodb+srv://mongodb.com")
await client.aconnect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://127.0.0.1")
await client.aconnect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://[::1]")
await client.aconnect()
class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
async def test_connect_case_insensitive(self):
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
await client.aconnect()
self.assertGreater(len(client.topology_description.server_descriptions()), 1)

View File

@ -35,6 +35,7 @@ from test.asynchronous import (
)
from test.utils_shared import (
OvertCommandListener,
_ignore_deprecations,
async_wait_until,
one,
)
@ -542,33 +543,44 @@ class TestMongosAndReadPreference(AsyncIntegrationTest):
for mode, cls in cases.items():
with self.assertRaises(TypeError):
cls(hedge=[]) # type: ignore
with _ignore_deprecations():
pref = cls(hedge={})
self.assertEqual(pref.document, {"mode": mode})
out = _maybe_add_read_preference({}, pref)
if cls == SecondaryPreferred:
# SecondaryPreferred without hedge doesn't add $readPreference.
self.assertEqual(out, {})
else:
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = cls(hedge={})
self.assertEqual(pref.document, {"mode": mode})
out = _maybe_add_read_preference({}, pref)
if cls == SecondaryPreferred:
# SecondaryPreferred without hedge doesn't add $readPreference.
self.assertEqual(out, {})
else:
hedge: dict[str, Any] = {"enabled": True}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge: dict[str, Any] = {"enabled": True}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False, "extra": "option"}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False, "extra": "option"}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
def test_read_preference_hedge_deprecated(self):
cases = {
"primaryPreferred": PrimaryPreferred,
"secondary": Secondary,
"secondaryPreferred": SecondaryPreferred,
"nearest": Nearest,
}
for _, cls in cases.items():
with self.assertRaises(DeprecationWarning):
cls(hedge={"enabled": True})
async def test_send_hedge(self):
cases = {
@ -582,7 +594,8 @@ class TestMongosAndReadPreference(AsyncIntegrationTest):
client = await self.async_rs_client(event_listeners=[listener])
await client.admin.command("ping")
for _mode, cls in cases.items():
pref = cls(hedge={"enabled": True})
with _ignore_deprecations():
pref = cls(hedge={"enabled": True})
coll = client.test.get_collection("test", read_preference=pref)
listener.reset()
await coll.find_one()

View File

@ -28,8 +28,8 @@ from test.asynchronous.utils import async_wait_until
import pymongo
from pymongo import common
from pymongo.asynchronous.srv_resolver import _have_dnspython
from pymongo.errors import ConfigurationError
from pymongo.srv_resolver import _have_dnspython
_IS_SYNC = False
@ -54,14 +54,16 @@ class SrvPollingKnobs:
def enable(self):
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
self.old_dns_resolver_response = (
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
)
if self.min_srv_rescan_interval is not None:
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
def mock_get_hosts_and_min_ttl(resolver, *args):
async def mock_get_hosts_and_min_ttl(resolver, *args):
assert self.old_dns_resolver_response is not None
nodes, ttl = self.old_dns_resolver_response(resolver)
nodes, ttl = await self.old_dns_resolver_response(resolver)
if self.nodelist_callback is not None:
nodes = self.nodelist_callback()
if self.ttl_time is not None:
@ -74,14 +76,14 @@ class SrvPollingKnobs:
else:
patch_func = mock_get_hosts_and_min_ttl
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
def __enter__(self):
self.enable()
def disable(self):
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
self.old_dns_resolver_response
)
@ -134,7 +136,10 @@ class TestSrvPolling(AsyncPyMongoTestCase):
def predicate():
if set(expected_nodelist) == set(self.get_nodelist(client)):
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
return (
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count
>= 1
)
return False
await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
@ -144,7 +149,7 @@ class TestSrvPolling(AsyncPyMongoTestCase):
msg = "Client nodelist %s changed unexpectedly (expected %s)"
raise self.fail(msg % (nodelist, expected_nodelist))
self.assertGreaterEqual(
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
1,
"resolver was never called",
)

View File

@ -32,7 +32,7 @@ except ImportError:
from pymongo import MongoClient
from pymongo.errors import OperationFailure
from pymongo.uri_parser import parse_uri
from pymongo.synchronous.uri_parser import parse_uri
pytestmark = pytest.mark.auth_aws

View File

@ -49,7 +49,7 @@ from pymongo.synchronous.auth_oidc import (
OIDCCallbackResult,
_get_authenticator,
)
from pymongo.uri_parser import parse_uri
from pymongo.synchronous.uri_parser import parse_uri
ROOT = Path(__file__).parent.parent.resolve()
TEST_PATH = ROOT / "auth" / "unified"

View File

@ -47,7 +47,7 @@ from bson.son import SON
from pymongo import common, message
from pymongo.read_preferences import ReadPreference
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.uri_parser import parse_uri
from pymongo.synchronous.uri_parser import parse_uri
if HAVE_SSL:
import ssl

View File

@ -505,13 +505,13 @@ class ClientUnitTest(UnitTest):
def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
# Patch the resolver.
from pymongo.srv_resolver import _resolve
from pymongo.synchronous.srv_resolver import _resolve
patched_resolver = FunctionCallRecorder(_resolve)
pymongo.srv_resolver._resolve = patched_resolver
pymongo.synchronous.srv_resolver._resolve = patched_resolver
def reset_resolver():
pymongo.srv_resolver._resolve = _resolve
pymongo.synchronous.srv_resolver._resolve = _resolve
self.addCleanup(reset_resolver)
@ -600,7 +600,7 @@ class ClientUnitTest(UnitTest):
with self.assertRaisesRegex(ConfigurationError, expected):
MongoClient(**{typo: "standard"}) # type: ignore[arg-type]
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
@patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
def test_detected_environment_logging(self, mock_get_hosts):
normal_hosts = [
"normal.host.com",
@ -622,7 +622,7 @@ class ClientUnitTest(UnitTest):
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
self.assertEqual(len(logs), 7)
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
@patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
def test_detected_environment_warning(self, mock_get_hosts):
with self._caplog.at_level(logging.WARN):
normal_hosts = [
@ -908,6 +908,15 @@ class TestClient(IntegrationTest):
with eval(the_repr) as client_two:
self.assertEqual(client_two, client)
def test_repr_srv_host(self):
client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
# before srv resolution
self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client))
client._connect()
# after srv resolution
self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client))
client.close()
def test_getters(self):
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
@ -1868,28 +1877,37 @@ class TestClient(IntegrationTest):
srvServiceName="customname",
connect=False,
)
client._connect()
self.assertEqual(client._topology_settings.srv_service_name, "customname")
client.close()
client = MongoClient(
"mongodb+srv://user:password@test22.test.build.10gen.cc"
"/?srvServiceName=shouldbeoverriden",
srvServiceName="customname",
connect=False,
)
client._connect()
self.assertEqual(client._topology_settings.srv_service_name, "customname")
client.close()
client = MongoClient(
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
connect=False,
)
client._connect()
self.assertEqual(client._topology_settings.srv_service_name, "customname")
client.close()
def test_srv_max_hosts_kwarg(self):
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
client._connect()
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
client._connect()
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
client = self.simple_client(
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
)
client._connect()
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
@unittest.skipIf(

View File

@ -65,8 +65,8 @@ from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStarte
from pymongo.server_description import SERVER_TYPE, ServerDescription
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
from pymongo.synchronous.uri_parser import parse_uri
from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.uri_parser import parse_uri
_IS_SYNC = True

View File

@ -33,7 +33,8 @@ from test.utils_shared import wait_until
from pymongo.common import validate_read_preference_tags
from pymongo.errors import ConfigurationError
from pymongo.uri_parser import parse_uri, split_hosts
from pymongo.synchronous.uri_parser import parse_uri
from pymongo.uri_parser_shared import split_hosts
_IS_SYNC = True
@ -183,35 +184,24 @@ create_tests(TestDNSSharded)
class TestParsingErrors(PyMongoTestCase):
def test_invalid_host(self):
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb is not",
self.simple_client,
"mongodb+srv://mongodb",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: mongodb.com is not",
self.simple_client,
"mongodb+srv://mongodb.com",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
self.simple_client,
"mongodb+srv://127.0.0.1",
)
self.assertRaisesRegex(
ConfigurationError,
"Invalid URI host: an IP address is not",
self.simple_client,
"mongodb+srv://[::1]",
)
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
client = self.simple_client("mongodb+srv://mongodb")
client._connect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
client = self.simple_client("mongodb+srv://mongodb.com")
client._connect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://127.0.0.1")
client._connect()
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
client = self.simple_client("mongodb+srv://[::1]")
client._connect()
class TestCaseInsensitive(IntegrationTest):
def test_connect_case_insensitive(self):
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
client._connect()
self.assertGreater(len(client.topology_description.server_descriptions()), 1)

View File

@ -35,6 +35,7 @@ from test import (
)
from test.utils_shared import (
OvertCommandListener,
_ignore_deprecations,
one,
wait_until,
)
@ -522,33 +523,44 @@ class TestMongosAndReadPreference(IntegrationTest):
for mode, cls in cases.items():
with self.assertRaises(TypeError):
cls(hedge=[]) # type: ignore
with _ignore_deprecations():
pref = cls(hedge={})
self.assertEqual(pref.document, {"mode": mode})
out = _maybe_add_read_preference({}, pref)
if cls == SecondaryPreferred:
# SecondaryPreferred without hedge doesn't add $readPreference.
self.assertEqual(out, {})
else:
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
pref = cls(hedge={})
self.assertEqual(pref.document, {"mode": mode})
out = _maybe_add_read_preference({}, pref)
if cls == SecondaryPreferred:
# SecondaryPreferred without hedge doesn't add $readPreference.
self.assertEqual(out, {})
else:
hedge: dict[str, Any] = {"enabled": True}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge: dict[str, Any] = {"enabled": True}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False, "extra": "option"}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
hedge = {"enabled": False, "extra": "option"}
pref = cls(hedge=hedge)
self.assertEqual(pref.document, {"mode": mode, "hedge": hedge})
out = _maybe_add_read_preference({}, pref)
self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)]))
def test_read_preference_hedge_deprecated(self):
cases = {
"primaryPreferred": PrimaryPreferred,
"secondary": Secondary,
"secondaryPreferred": SecondaryPreferred,
"nearest": Nearest,
}
for _, cls in cases.items():
with self.assertRaises(DeprecationWarning):
cls(hedge={"enabled": True})
def test_send_hedge(self):
cases = {
@ -562,7 +574,8 @@ class TestMongosAndReadPreference(IntegrationTest):
client = self.rs_client(event_listeners=[listener])
client.admin.command("ping")
for _mode, cls in cases.items():
pref = cls(hedge={"enabled": True})
with _ignore_deprecations():
pref = cls(hedge={"enabled": True})
coll = client.test.get_collection("test", read_preference=pref)
listener.reset()
coll.find_one()

View File

@ -29,7 +29,7 @@ from test.utils import wait_until
import pymongo
from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.srv_resolver import _have_dnspython
from pymongo.synchronous.srv_resolver import _have_dnspython
_IS_SYNC = True
@ -54,7 +54,9 @@ class SrvPollingKnobs:
def enable(self):
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
self.old_dns_resolver_response = (
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
)
if self.min_srv_rescan_interval is not None:
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
@ -74,14 +76,14 @@ class SrvPollingKnobs:
else:
patch_func = mock_get_hosts_and_min_ttl
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
def __enter__(self):
self.enable()
def disable(self):
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
self.old_dns_resolver_response
)
@ -134,7 +136,10 @@ class TestSrvPolling(PyMongoTestCase):
def predicate():
if set(expected_nodelist) == set(self.get_nodelist(client)):
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
return (
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count
>= 1
)
return False
wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
@ -144,7 +149,7 @@ class TestSrvPolling(PyMongoTestCase):
msg = "Client nodelist %s changed unexpectedly (expected %s)"
raise self.fail(msg % (nodelist, expected_nodelist))
self.assertGreaterEqual(
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
1,
"resolver was never called",
)

View File

@ -28,8 +28,8 @@ from test import unittest
from bson.binary import JAVA_LEGACY
from pymongo import ReadPreference
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.uri_parser import (
parse_uri,
from pymongo.synchronous.uri_parser import parse_uri
from pymongo.uri_parser_shared import (
parse_userinfo,
split_hosts,
split_options,

View File

@ -29,7 +29,7 @@ from test.helpers import clear_warning_registry
from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
from pymongo.compression_support import _have_snappy
from pymongo.uri_parser import parse_uri
from pymongo.synchronous.uri_parser import parse_uri
CONN_STRING_TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test")

View File

@ -127,6 +127,7 @@ replacements = {
"async_create_barrier": "create_barrier",
"async_barrier_wait": "barrier_wait",
"async_joinall": "joinall",
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
}
docstring_replacements: dict[tuple[str, str], str] = {