PYTHON-3636 AsyncMongoClient should perform SRV resolution lazily (#2191)
Co-authored-by: Noah Stapp <noah@noahstapp.com> Co-authored-by: Shane Harvey <shane.harvey@mongodb.com>
This commit is contained in:
parent
38ceda4c09
commit
eea8a37257
@ -9,6 +9,8 @@ PyMongo 4.12 brings a number of changes including:
|
||||
- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
|
||||
:class:`~pymongo.encryption_options.AutoEncryptionOpts`.
|
||||
- Support for $lookup in CSFLE and QE supported on MongoDB 8.1+.
|
||||
- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation.
|
||||
To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected.
|
||||
- Added index hinting support to the
|
||||
:meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct` and
|
||||
:meth:`~pymongo.collection.Collection.distinct` commands.
|
||||
|
||||
@ -87,7 +87,7 @@ from pymongo.read_concern import ReadConcern
|
||||
from pymongo.results import BulkWriteResult, DeleteResult
|
||||
from pymongo.ssl_support import get_ssl_context
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser import parse_host
|
||||
from pymongo.uri_parser_shared import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -44,6 +44,7 @@ from typing import (
|
||||
AsyncContextManager,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Collection,
|
||||
Coroutine,
|
||||
FrozenSet,
|
||||
Generic,
|
||||
@ -60,8 +61,8 @@ from typing import (
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||
from pymongo.asynchronous import client_session, database
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.asynchronous import client_session, database, uri_parser
|
||||
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||
@ -113,11 +114,14 @@ from pymongo.typings import (
|
||||
_DocumentTypeArg,
|
||||
_Pipeline,
|
||||
)
|
||||
from pymongo.uri_parser import (
|
||||
from pymongo.uri_parser_shared import (
|
||||
SRV_SCHEME,
|
||||
_check_options,
|
||||
_handle_option_deprecations,
|
||||
_handle_security_options,
|
||||
_normalize_options,
|
||||
_validate_uri,
|
||||
split_hosts,
|
||||
)
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
||||
|
||||
@ -128,6 +132,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.bulk import _AsyncBulk
|
||||
from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession
|
||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||
from pymongo.asynchronous.encryption import _Encrypter
|
||||
from pymongo.asynchronous.pool import AsyncConnection
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.read_concern import ReadConcern
|
||||
@ -750,6 +755,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
port = self.PORT
|
||||
if not isinstance(port, int):
|
||||
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._topology: Topology = None # type: ignore[assignment]
|
||||
|
||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||
# customization of PyMongo, e.g. Motor.
|
||||
@ -760,8 +768,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Parse options passed as kwargs.
|
||||
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
|
||||
keyword_opts["document_class"] = doc_class
|
||||
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||
|
||||
seeds = set()
|
||||
is_srv = False
|
||||
username = None
|
||||
password = None
|
||||
dbase = None
|
||||
@ -769,29 +779,22 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
fqdn = None
|
||||
srv_service_name = keyword_opts.get("srvservicename")
|
||||
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||
if len([h for h in host if "/" in h]) > 1:
|
||||
if len([h for h in self._host if "/" in h]) > 1:
|
||||
raise ConfigurationError("host must not contain multiple MongoDB URIs")
|
||||
for entity in host:
|
||||
for entity in self._host:
|
||||
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||
# it must be a URI,
|
||||
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||
if "/" in entity:
|
||||
# Determine connection timeout from kwargs.
|
||||
timeout = keyword_opts.get("connecttimeoutms")
|
||||
if timeout is not None:
|
||||
timeout = common.validate_timeout_or_none_or_zero(
|
||||
keyword_opts.cased_key("connecttimeoutms"), timeout
|
||||
)
|
||||
res = uri_parser.parse_uri(
|
||||
res = _validate_uri(
|
||||
entity,
|
||||
port,
|
||||
validate=True,
|
||||
warn=True,
|
||||
normalize=False,
|
||||
connect_timeout=timeout,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
is_srv = entity.startswith(SRV_SCHEME)
|
||||
seeds.update(res["nodelist"])
|
||||
username = res["username"] or username
|
||||
password = res["password"] or password
|
||||
@ -799,7 +802,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
opts = res["options"]
|
||||
fqdn = res["fqdn"]
|
||||
else:
|
||||
seeds.update(uri_parser.split_hosts(entity, port))
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
if not seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
@ -820,80 +823,179 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
keyword_opts["tz_aware"] = tz_aware
|
||||
keyword_opts["connect"] = connect
|
||||
|
||||
# Handle deprecated options in kwarg options.
|
||||
keyword_opts = _handle_option_deprecations(keyword_opts)
|
||||
# Validate kwarg options.
|
||||
keyword_opts = common._CaseInsensitiveDictionary(
|
||||
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
||||
)
|
||||
|
||||
# Override connection string options with kwarg options.
|
||||
opts.update(keyword_opts)
|
||||
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||
|
||||
if srv_service_name is None:
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
# Handle security-option conflicts in combined options.
|
||||
opts = _handle_security_options(opts)
|
||||
# Normalize combined options.
|
||||
opts = _normalize_options(opts)
|
||||
_check_options(seeds, opts)
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", username)
|
||||
password = opts.get("password", password)
|
||||
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||
self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||
|
||||
self._default_database_name = dbase
|
||||
self._lock = _async_create_lock()
|
||||
self._kill_cursors_queue: list = []
|
||||
|
||||
self._event_listeners = options.pool_options._event_listeners
|
||||
super().__init__(
|
||||
options.codec_options,
|
||||
options.read_preference,
|
||||
options.write_concern,
|
||||
options.read_concern,
|
||||
self._encrypter: Optional[_Encrypter] = None
|
||||
|
||||
self._resolve_srv_info.update(
|
||||
{
|
||||
"is_srv": is_srv,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"dbase": dbase,
|
||||
"seeds": seeds,
|
||||
"fqdn": fqdn,
|
||||
"srv_service_name": srv_service_name,
|
||||
"pool_class": pool_class,
|
||||
"monitor_class": monitor_class,
|
||||
"condition_class": condition_class,
|
||||
}
|
||||
)
|
||||
|
||||
self._topology_settings = TopologySettings(
|
||||
seeds=seeds,
|
||||
replica_set_name=options.replica_set_name,
|
||||
pool_class=pool_class,
|
||||
pool_options=options.pool_options,
|
||||
monitor_class=monitor_class,
|
||||
condition_class=condition_class,
|
||||
local_threshold_ms=options.local_threshold_ms,
|
||||
server_selection_timeout=options.server_selection_timeout,
|
||||
server_selector=options.server_selector,
|
||||
heartbeat_frequency=options.heartbeat_frequency,
|
||||
fqdn=fqdn,
|
||||
direct_connection=options.direct_connection,
|
||||
load_balanced=options.load_balanced,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=options.server_monitoring_mode,
|
||||
super().__init__(
|
||||
self._options.codec_options,
|
||||
self._options.read_preference,
|
||||
self._options.write_concern,
|
||||
self._options.read_concern,
|
||||
)
|
||||
|
||||
if not is_srv:
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._init_background()
|
||||
if not is_srv:
|
||||
self._init_background()
|
||||
|
||||
if _IS_SYNC and connect:
|
||||
self._get_topology() # type: ignore[unused-coroutine]
|
||||
|
||||
self._encrypter = None
|
||||
async def _resolve_srv(self) -> None:
|
||||
keyword_opts = self._resolve_srv_info["keyword_opts"]
|
||||
seeds = set()
|
||||
opts = common._CaseInsensitiveDictionary()
|
||||
srv_service_name = keyword_opts.get("srvservicename")
|
||||
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||
for entity in self._host:
|
||||
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||
# it must be a URI,
|
||||
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||
if "/" in entity:
|
||||
# Determine connection timeout from kwargs.
|
||||
timeout = keyword_opts.get("connecttimeoutms")
|
||||
if timeout is not None:
|
||||
timeout = common.validate_timeout_or_none_or_zero(
|
||||
keyword_opts.cased_key("connecttimeoutms"), timeout
|
||||
)
|
||||
res = await uri_parser._parse_srv(
|
||||
entity,
|
||||
self._port,
|
||||
validate=True,
|
||||
warn=True,
|
||||
normalize=False,
|
||||
connect_timeout=timeout,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
seeds.update(res["nodelist"])
|
||||
opts = res["options"]
|
||||
else:
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
|
||||
if not seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
for hostname in [node[0] for node in seeds]:
|
||||
if _detect_external_db(hostname):
|
||||
break
|
||||
|
||||
# Add options with named keyword arguments to the parsed kwarg options.
|
||||
tz_aware = keyword_opts["tz_aware"]
|
||||
connect = keyword_opts["connect"]
|
||||
if tz_aware is None:
|
||||
tz_aware = opts.get("tz_aware", False)
|
||||
if connect is None:
|
||||
# Default to connect=True unless on a FaaS system, which might use fork.
|
||||
from pymongo.pool_options import _is_faas
|
||||
|
||||
connect = opts.get("connect", not _is_faas())
|
||||
keyword_opts["tz_aware"] = tz_aware
|
||||
keyword_opts["connect"] = connect
|
||||
|
||||
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||
|
||||
if srv_service_name is None:
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", self._resolve_srv_info["username"])
|
||||
password = opts.get("password", self._resolve_srv_info["password"])
|
||||
self._options = ClientOptions(
|
||||
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
|
||||
)
|
||||
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
def _init_based_on_options(
|
||||
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
|
||||
) -> None:
|
||||
self._event_listeners = self._options.pool_options._event_listeners
|
||||
self._topology_settings = TopologySettings(
|
||||
seeds=seeds,
|
||||
replica_set_name=self._options.replica_set_name,
|
||||
pool_class=self._resolve_srv_info["pool_class"],
|
||||
pool_options=self._options.pool_options,
|
||||
monitor_class=self._resolve_srv_info["monitor_class"],
|
||||
condition_class=self._resolve_srv_info["condition_class"],
|
||||
local_threshold_ms=self._options.local_threshold_ms,
|
||||
server_selection_timeout=self._options.server_selection_timeout,
|
||||
server_selector=self._options.server_selector,
|
||||
heartbeat_frequency=self._options.heartbeat_frequency,
|
||||
fqdn=self._resolve_srv_info["fqdn"],
|
||||
direct_connection=self._options.direct_connection,
|
||||
load_balanced=self._options.load_balanced,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||
)
|
||||
if self._options.auto_encryption_opts:
|
||||
from pymongo.asynchronous.encryption import _Encrypter
|
||||
|
||||
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
|
||||
self._timeout = self._options.timeout
|
||||
|
||||
if _HAS_REGISTER_AT_FORK:
|
||||
# Add this client to the list of weakly referenced items.
|
||||
# This will be used later if we fork.
|
||||
AsyncMongoClient._clients[self._topology._topology_id] = self
|
||||
def _normalize_and_validate_options(
|
||||
self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]]
|
||||
) -> common._CaseInsensitiveDictionary:
|
||||
# Handle security-option conflicts in combined options.
|
||||
opts = _handle_security_options(opts)
|
||||
# Normalize combined options.
|
||||
opts = _normalize_options(opts)
|
||||
_check_options(seeds, opts)
|
||||
return opts
|
||||
|
||||
def _validate_kwargs_and_update_opts(
|
||||
self,
|
||||
keyword_opts: common._CaseInsensitiveDictionary,
|
||||
opts: common._CaseInsensitiveDictionary,
|
||||
) -> common._CaseInsensitiveDictionary:
|
||||
# Handle deprecated options in kwarg options.
|
||||
keyword_opts = _handle_option_deprecations(keyword_opts)
|
||||
# Validate kwarg options.
|
||||
keyword_opts = common._CaseInsensitiveDictionary(
|
||||
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
||||
)
|
||||
# Override connection string options with kwarg options.
|
||||
opts.update(keyword_opts)
|
||||
return opts
|
||||
|
||||
async def aconnect(self) -> None:
|
||||
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
|
||||
@ -901,6 +1003,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||
self._topology = Topology(self._topology_settings)
|
||||
if _HAS_REGISTER_AT_FORK:
|
||||
# Add this client to the list of weakly referenced items.
|
||||
# This will be used later if we fork.
|
||||
AsyncMongoClient._clients[self._topology._topology_id] = self
|
||||
# Seed the topology with the old one's pid so we can detect clients
|
||||
# that are opened before a fork and used after.
|
||||
self._topology._pid = old_pid
|
||||
@ -1115,16 +1221,24 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
return self._options
|
||||
|
||||
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
|
||||
return (
|
||||
tuple(sorted(self._resolve_srv_info["seeds"])),
|
||||
self._options.replica_set_name,
|
||||
self._resolve_srv_info["fqdn"],
|
||||
self._resolve_srv_info["srv_service_name"],
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return self._topology == other._topology
|
||||
return self.eq_props() == other.eq_props()
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._topology)
|
||||
return hash(self.eq_props())
|
||||
|
||||
def _repr_helper(self) -> str:
|
||||
def option_repr(option: str, value: Any) -> str:
|
||||
@ -1140,13 +1254,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return f"{option}={value!r}"
|
||||
|
||||
# Host first...
|
||||
options = [
|
||||
"host=%r"
|
||||
% [
|
||||
"%s:%d" % (host, port) if port is not None else host
|
||||
for host, port in self._topology_settings.seeds
|
||||
if self._topology is None:
|
||||
options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"]
|
||||
else:
|
||||
options = [
|
||||
"host=%r"
|
||||
% [
|
||||
"%s:%d" % (host, port) if port is not None else host
|
||||
for host, port in self._topology_settings.seeds
|
||||
]
|
||||
]
|
||||
]
|
||||
# ... then everything in self._constructor_args...
|
||||
options.extend(
|
||||
option_repr(key, self._options._options[key]) for key in self._constructor_args
|
||||
@ -1552,6 +1669,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionchanged:: 3.6
|
||||
End all server sessions created by this client.
|
||||
"""
|
||||
if self._topology is None:
|
||||
return
|
||||
session_ids = self._topology.pop_all_sessions()
|
||||
if session_ids:
|
||||
await self._end_sessions(session_ids)
|
||||
@ -1582,6 +1701,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
launches the connection process in the background.
|
||||
"""
|
||||
if not self._opened:
|
||||
if self._resolve_srv_info["is_srv"]:
|
||||
await self._resolve_srv()
|
||||
self._init_background()
|
||||
await self._topology.open()
|
||||
async with self._lock:
|
||||
self._kill_cursors_executor.open()
|
||||
@ -2511,6 +2633,7 @@ class _MongoClientErrorHandler:
|
||||
self.completed_handshake,
|
||||
self.service_id,
|
||||
)
|
||||
assert self.client._topology is not None
|
||||
await self.client._topology.handle_error(self.server_address, err_ctx)
|
||||
|
||||
async def __aenter__(self) -> _MongoClientErrorHandler:
|
||||
|
||||
@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pymongo import common, periodic_executor
|
||||
from pymongo._csot import MovingMinimum
|
||||
from pymongo.asynchronous.srv_resolver import _SrvResolver
|
||||
from pymongo.errors import NetworkTimeout, _OperationCancelled
|
||||
from pymongo.hello import Hello
|
||||
from pymongo.lock import _async_create_lock
|
||||
@ -33,7 +34,6 @@ from pymongo.periodic_executor import _shutdown_executors
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
|
||||
@ -395,7 +395,7 @@ class SrvMonitor(MonitorBase):
|
||||
# Don't poll right after creation, wait 60 seconds first
|
||||
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
|
||||
return
|
||||
seedlist = self._get_seedlist()
|
||||
seedlist = await self._get_seedlist()
|
||||
if seedlist:
|
||||
self._seedlist = seedlist
|
||||
try:
|
||||
@ -404,7 +404,7 @@ class SrvMonitor(MonitorBase):
|
||||
# Topology was garbage-collected.
|
||||
await self.close()
|
||||
|
||||
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
|
||||
async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
|
||||
"""Poll SRV records for a seedlist.
|
||||
|
||||
Returns a list of ServerDescriptions.
|
||||
@ -415,7 +415,7 @@ class SrvMonitor(MonitorBase):
|
||||
self._settings.pool_options.connect_timeout,
|
||||
self._settings.srv_service_name,
|
||||
)
|
||||
seedlist, ttl = resolver.get_hosts_and_min_ttl()
|
||||
seedlist, ttl = await resolver.get_hosts_and_min_ttl()
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
|
||||
160
pymongo/asynchronous/srv_resolver.py
Normal file
160
pymongo/asynchronous/srv_resolver.py
Normal file
@ -0,0 +1,160 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo.common import CONNECT_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dns import resolver
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
def _have_dnspython() -> bool:
|
||||
try:
|
||||
import dns # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
# dnspython can return bytes or str from various parts
|
||||
# of its API depending on version. We always want str.
|
||||
def maybe_decode(text: Union[str, bytes]) -> str:
|
||||
if isinstance(text, bytes):
|
||||
return text.decode()
|
||||
return text
|
||||
|
||||
|
||||
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
|
||||
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
|
||||
if _IS_SYNC:
|
||||
from dns import resolver
|
||||
|
||||
if hasattr(resolver, "resolve"):
|
||||
# dnspython >= 2
|
||||
return resolver.resolve(*args, **kwargs)
|
||||
# dnspython 1.X
|
||||
return resolver.query(*args, **kwargs)
|
||||
else:
|
||||
from dns import asyncresolver
|
||||
|
||||
if hasattr(asyncresolver, "resolve"):
|
||||
# dnspython >= 2
|
||||
return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
|
||||
raise ConfigurationError(
|
||||
"Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections."
|
||||
)
|
||||
|
||||
|
||||
_INVALID_HOST_MSG = (
|
||||
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
|
||||
"Did you mean to use 'mongodb://'?"
|
||||
)
|
||||
|
||||
|
||||
class _SrvResolver:
|
||||
def __init__(
|
||||
self,
|
||||
fqdn: str,
|
||||
connect_timeout: Optional[float],
|
||||
srv_service_name: str,
|
||||
srv_max_hosts: int = 0,
|
||||
):
|
||||
self.__fqdn = fqdn
|
||||
self.__srv = srv_service_name
|
||||
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
|
||||
self.__srv_max_hosts = srv_max_hosts or 0
|
||||
# Validate the fully qualified domain name.
|
||||
try:
|
||||
ipaddress.ip_address(fqdn)
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.__plist = self.__fqdn.split(".")[1:]
|
||||
except Exception:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
|
||||
self.__slen = len(self.__plist)
|
||||
if self.__slen < 2:
|
||||
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
|
||||
|
||||
async def get_options(self) -> Optional[str]:
|
||||
from dns import resolver
|
||||
|
||||
try:
|
||||
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
|
||||
except (resolver.NoAnswer, resolver.NXDOMAIN):
|
||||
# No TXT records
|
||||
return None
|
||||
except Exception as exc:
|
||||
raise ConfigurationError(str(exc)) from None
|
||||
if len(results) > 1:
|
||||
raise ConfigurationError("Only one TXT record is supported")
|
||||
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined]
|
||||
|
||||
async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
|
||||
try:
|
||||
results = await _resolve(
|
||||
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
|
||||
)
|
||||
except Exception as exc:
|
||||
if not encapsulate_errors:
|
||||
# Raise the original error.
|
||||
raise
|
||||
# Else, raise all errors as ConfigurationError.
|
||||
raise ConfigurationError(str(exc)) from None
|
||||
return results
|
||||
|
||||
async def _get_srv_response_and_hosts(
|
||||
self, encapsulate_errors: bool
|
||||
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
|
||||
results = await self._resolve_uri(encapsulate_errors)
|
||||
|
||||
# Construct address tuples
|
||||
nodes = [
|
||||
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined]
|
||||
for res in results
|
||||
]
|
||||
|
||||
# Validate hosts
|
||||
for node in nodes:
|
||||
try:
|
||||
nlist = node[0].lower().split(".")[1:][-self.__slen :]
|
||||
except Exception:
|
||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
|
||||
if self.__plist != nlist:
|
||||
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
|
||||
if self.__srv_max_hosts:
|
||||
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
|
||||
return results, nodes
|
||||
|
||||
async def get_hosts(self) -> list[tuple[str, Any]]:
|
||||
_, nodes = await self._get_srv_response_and_hosts(True)
|
||||
return nodes
|
||||
|
||||
async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
|
||||
results, nodes = await self._get_srv_response_and_hosts(False)
|
||||
rrset = results.rrset
|
||||
ttl = rrset.ttl if rrset else 0
|
||||
return nodes, ttl
|
||||
188
pymongo/asynchronous/uri_parser.py
Normal file
188
pymongo/asynchronous/uri_parser.py
Normal file
@ -0,0 +1,188 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Tools to parse and validate a MongoDB URI."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from pymongo.asynchronous.srv_resolver import _SrvResolver
|
||||
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
from pymongo.uri_parser_shared import (
|
||||
_ALLOWED_TXT_OPTS,
|
||||
DEFAULT_PORT,
|
||||
SCHEME,
|
||||
SCHEME_LEN,
|
||||
SRV_SCHEME_LEN,
|
||||
_check_options,
|
||||
_validate_uri,
|
||||
split_hosts,
|
||||
split_options,
|
||||
)
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
async def parse_uri(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
connect_timeout: Optional[float] = None,
|
||||
srv_service_name: Optional[str] = None,
|
||||
srv_max_hosts: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse and validate a MongoDB URI.
|
||||
|
||||
Returns a dict of the form::
|
||||
|
||||
{
|
||||
'nodelist': <list of (host, port) tuples>,
|
||||
'username': <username> or None,
|
||||
'password': <password> or None,
|
||||
'database': <database name> or None,
|
||||
'collection': <collection name> or None,
|
||||
'options': <dict of MongoDB URI options>,
|
||||
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
||||
}
|
||||
|
||||
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
||||
to build nodelist and options.
|
||||
|
||||
:param uri: The MongoDB URI to parse.
|
||||
:param default_port: The port number to use when one wasn't specified
|
||||
for a host in the URI.
|
||||
:param validate: If ``True`` (the default), validate and
|
||||
normalize all options. Default: ``True``.
|
||||
:param warn: When validating, if ``True`` then will warn
|
||||
the user then ignore any invalid options or values. If ``False``,
|
||||
validation will error when options are unsupported or values are
|
||||
invalid. Default: ``False``.
|
||||
:param normalize: If ``True``, convert names of URI options
|
||||
to their internally-used names. Default: ``True``.
|
||||
:param connect_timeout: The maximum time in milliseconds to
|
||||
wait for a response from the DNS server.
|
||||
:param srv_service_name: A custom SRV service name
|
||||
|
||||
.. versionchanged:: 4.6
|
||||
The delimiting slash (``/``) between hosts and connection options is now optional.
|
||||
For example, "mongodb://example.com?tls=true" is now a valid URI.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
To better follow RFC 3986, unquoted percent signs ("%") are no longer
|
||||
supported.
|
||||
|
||||
.. versionchanged:: 3.9
|
||||
Added the ``normalize`` parameter.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added support for mongodb+srv:// URIs.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
Return the original value of the ``readPreference`` MongoDB URI option
|
||||
instead of the validated read preference mode.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
``warn`` added so invalid options can be ignored.
|
||||
"""
|
||||
result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
|
||||
result.update(
|
||||
await _parse_srv(
|
||||
uri,
|
||||
default_port,
|
||||
validate,
|
||||
warn,
|
||||
normalize,
|
||||
connect_timeout,
|
||||
srv_service_name,
|
||||
srv_max_hosts,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def _parse_srv(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
connect_timeout: Optional[float] = None,
|
||||
srv_service_name: Optional[str] = None,
|
||||
srv_max_hosts: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
if uri.startswith(SCHEME):
|
||||
is_srv = False
|
||||
scheme_free = uri[SCHEME_LEN:]
|
||||
else:
|
||||
is_srv = True
|
||||
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||
|
||||
options = _CaseInsensitiveDictionary()
|
||||
|
||||
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||
if "/" in host_plus_db_part:
|
||||
host_part, _, _ = host_plus_db_part.partition("/")
|
||||
else:
|
||||
host_part = host_plus_db_part
|
||||
|
||||
if opts:
|
||||
options.update(split_options(opts, validate, warn, normalize))
|
||||
if srv_service_name is None:
|
||||
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
|
||||
if "@" in host_part:
|
||||
_, _, hosts = host_part.rpartition("@")
|
||||
else:
|
||||
hosts = host_part
|
||||
|
||||
hosts = unquote_plus(hosts)
|
||||
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||
if is_srv:
|
||||
nodes = split_hosts(hosts, default_port=None)
|
||||
fqdn, port = nodes[0]
|
||||
|
||||
# Use the connection timeout. connectTimeoutMS passed as a keyword
|
||||
# argument overrides the same option passed in the connection string.
|
||||
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
|
||||
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
|
||||
nodes = await dns_resolver.get_hosts()
|
||||
dns_options = await dns_resolver.get_options()
|
||||
if dns_options:
|
||||
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
|
||||
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
|
||||
raise ConfigurationError(
|
||||
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
|
||||
)
|
||||
for opt, val in parsed_dns_options.items():
|
||||
if opt not in options:
|
||||
options[opt] = val
|
||||
if options.get("loadBalanced") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
|
||||
if options.get("replicaSet") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
|
||||
if "tls" not in options and "ssl" not in options:
|
||||
options["tls"] = True if validate else "true"
|
||||
else:
|
||||
nodes = split_hosts(hosts, default_port=default_port)
|
||||
|
||||
_check_options(nodes, options)
|
||||
|
||||
return {
|
||||
"nodelist": nodes,
|
||||
"options": options,
|
||||
}
|
||||
@ -32,7 +32,7 @@ except ImportError:
|
||||
from bson import int64
|
||||
from pymongo.common import validate_is_mapping
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import _parse_kms_tls_options
|
||||
from pymongo.uri_parser_shared import _parse_kms_tls_options
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
|
||||
|
||||
@ -86,7 +86,7 @@ from pymongo.synchronous.pool import (
|
||||
_raise_connection_failure,
|
||||
)
|
||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||
from pymongo.uri_parser import parse_host
|
||||
from pymongo.uri_parser_shared import parse_host
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -42,6 +42,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
ContextManager,
|
||||
FrozenSet,
|
||||
Generator,
|
||||
@ -59,7 +60,7 @@ from typing import (
|
||||
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||
from bson.timestamp import Timestamp
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.client_options import ClientOptions
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
@ -96,7 +97,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.results import ClientBulkWriteResult
|
||||
from pymongo.server_selectors import writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous import client_session, database
|
||||
from pymongo.synchronous import client_session, database, uri_parser
|
||||
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||
from pymongo.synchronous.client_session import _EmptyServerSession
|
||||
@ -112,11 +113,14 @@ from pymongo.typings import (
|
||||
_DocumentTypeArg,
|
||||
_Pipeline,
|
||||
)
|
||||
from pymongo.uri_parser import (
|
||||
from pymongo.uri_parser_shared import (
|
||||
SRV_SCHEME,
|
||||
_check_options,
|
||||
_handle_option_deprecations,
|
||||
_handle_security_options,
|
||||
_normalize_options,
|
||||
_validate_uri,
|
||||
split_hosts,
|
||||
)
|
||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
||||
|
||||
@ -130,6 +134,7 @@ if TYPE_CHECKING:
|
||||
from pymongo.synchronous.bulk import _Bulk
|
||||
from pymongo.synchronous.client_session import ClientSession, _ServerSession
|
||||
from pymongo.synchronous.cursor import _ConnectionManager
|
||||
from pymongo.synchronous.encryption import _Encrypter
|
||||
from pymongo.synchronous.pool import Connection
|
||||
from pymongo.synchronous.server import Server
|
||||
|
||||
@ -748,6 +753,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
port = self.PORT
|
||||
if not isinstance(port, int):
|
||||
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._topology: Topology = None # type: ignore[assignment]
|
||||
|
||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||
# customization of PyMongo, e.g. Motor.
|
||||
@ -758,8 +766,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# Parse options passed as kwargs.
|
||||
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
|
||||
keyword_opts["document_class"] = doc_class
|
||||
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||
|
||||
seeds = set()
|
||||
is_srv = False
|
||||
username = None
|
||||
password = None
|
||||
dbase = None
|
||||
@ -767,29 +777,22 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
fqdn = None
|
||||
srv_service_name = keyword_opts.get("srvservicename")
|
||||
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||
if len([h for h in host if "/" in h]) > 1:
|
||||
if len([h for h in self._host if "/" in h]) > 1:
|
||||
raise ConfigurationError("host must not contain multiple MongoDB URIs")
|
||||
for entity in host:
|
||||
for entity in self._host:
|
||||
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||
# it must be a URI,
|
||||
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||
if "/" in entity:
|
||||
# Determine connection timeout from kwargs.
|
||||
timeout = keyword_opts.get("connecttimeoutms")
|
||||
if timeout is not None:
|
||||
timeout = common.validate_timeout_or_none_or_zero(
|
||||
keyword_opts.cased_key("connecttimeoutms"), timeout
|
||||
)
|
||||
res = uri_parser.parse_uri(
|
||||
res = _validate_uri(
|
||||
entity,
|
||||
port,
|
||||
validate=True,
|
||||
warn=True,
|
||||
normalize=False,
|
||||
connect_timeout=timeout,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
is_srv = entity.startswith(SRV_SCHEME)
|
||||
seeds.update(res["nodelist"])
|
||||
username = res["username"] or username
|
||||
password = res["password"] or password
|
||||
@ -797,7 +800,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
opts = res["options"]
|
||||
fqdn = res["fqdn"]
|
||||
else:
|
||||
seeds.update(uri_parser.split_hosts(entity, port))
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
if not seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
@ -818,80 +821,179 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
keyword_opts["tz_aware"] = tz_aware
|
||||
keyword_opts["connect"] = connect
|
||||
|
||||
# Handle deprecated options in kwarg options.
|
||||
keyword_opts = _handle_option_deprecations(keyword_opts)
|
||||
# Validate kwarg options.
|
||||
keyword_opts = common._CaseInsensitiveDictionary(
|
||||
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
||||
)
|
||||
|
||||
# Override connection string options with kwarg options.
|
||||
opts.update(keyword_opts)
|
||||
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||
|
||||
if srv_service_name is None:
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
# Handle security-option conflicts in combined options.
|
||||
opts = _handle_security_options(opts)
|
||||
# Normalize combined options.
|
||||
opts = _normalize_options(opts)
|
||||
_check_options(seeds, opts)
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", username)
|
||||
password = opts.get("password", password)
|
||||
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||
self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||
|
||||
self._default_database_name = dbase
|
||||
self._lock = _create_lock()
|
||||
self._kill_cursors_queue: list = []
|
||||
|
||||
self._event_listeners = options.pool_options._event_listeners
|
||||
super().__init__(
|
||||
options.codec_options,
|
||||
options.read_preference,
|
||||
options.write_concern,
|
||||
options.read_concern,
|
||||
self._encrypter: Optional[_Encrypter] = None
|
||||
|
||||
self._resolve_srv_info.update(
|
||||
{
|
||||
"is_srv": is_srv,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"dbase": dbase,
|
||||
"seeds": seeds,
|
||||
"fqdn": fqdn,
|
||||
"srv_service_name": srv_service_name,
|
||||
"pool_class": pool_class,
|
||||
"monitor_class": monitor_class,
|
||||
"condition_class": condition_class,
|
||||
}
|
||||
)
|
||||
|
||||
self._topology_settings = TopologySettings(
|
||||
seeds=seeds,
|
||||
replica_set_name=options.replica_set_name,
|
||||
pool_class=pool_class,
|
||||
pool_options=options.pool_options,
|
||||
monitor_class=monitor_class,
|
||||
condition_class=condition_class,
|
||||
local_threshold_ms=options.local_threshold_ms,
|
||||
server_selection_timeout=options.server_selection_timeout,
|
||||
server_selector=options.server_selector,
|
||||
heartbeat_frequency=options.heartbeat_frequency,
|
||||
fqdn=fqdn,
|
||||
direct_connection=options.direct_connection,
|
||||
load_balanced=options.load_balanced,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=options.server_monitoring_mode,
|
||||
super().__init__(
|
||||
self._options.codec_options,
|
||||
self._options.read_preference,
|
||||
self._options.write_concern,
|
||||
self._options.read_concern,
|
||||
)
|
||||
|
||||
if not is_srv:
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
self._opened = False
|
||||
self._closed = False
|
||||
self._init_background()
|
||||
if not is_srv:
|
||||
self._init_background()
|
||||
|
||||
if _IS_SYNC and connect:
|
||||
self._get_topology() # type: ignore[unused-coroutine]
|
||||
|
||||
self._encrypter = None
|
||||
def _resolve_srv(self) -> None:
|
||||
keyword_opts = self._resolve_srv_info["keyword_opts"]
|
||||
seeds = set()
|
||||
opts = common._CaseInsensitiveDictionary()
|
||||
srv_service_name = keyword_opts.get("srvservicename")
|
||||
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||
for entity in self._host:
|
||||
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||
# it must be a URI,
|
||||
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||
if "/" in entity:
|
||||
# Determine connection timeout from kwargs.
|
||||
timeout = keyword_opts.get("connecttimeoutms")
|
||||
if timeout is not None:
|
||||
timeout = common.validate_timeout_or_none_or_zero(
|
||||
keyword_opts.cased_key("connecttimeoutms"), timeout
|
||||
)
|
||||
res = uri_parser._parse_srv(
|
||||
entity,
|
||||
self._port,
|
||||
validate=True,
|
||||
warn=True,
|
||||
normalize=False,
|
||||
connect_timeout=timeout,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
)
|
||||
seeds.update(res["nodelist"])
|
||||
opts = res["options"]
|
||||
else:
|
||||
seeds.update(split_hosts(entity, self._port))
|
||||
|
||||
if not seeds:
|
||||
raise ConfigurationError("need to specify at least one host")
|
||||
|
||||
for hostname in [node[0] for node in seeds]:
|
||||
if _detect_external_db(hostname):
|
||||
break
|
||||
|
||||
# Add options with named keyword arguments to the parsed kwarg options.
|
||||
tz_aware = keyword_opts["tz_aware"]
|
||||
connect = keyword_opts["connect"]
|
||||
if tz_aware is None:
|
||||
tz_aware = opts.get("tz_aware", False)
|
||||
if connect is None:
|
||||
# Default to connect=True unless on a FaaS system, which might use fork.
|
||||
from pymongo.pool_options import _is_faas
|
||||
|
||||
connect = opts.get("connect", not _is_faas())
|
||||
keyword_opts["tz_aware"] = tz_aware
|
||||
keyword_opts["connect"] = connect
|
||||
|
||||
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||
|
||||
if srv_service_name is None:
|
||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||
|
||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||
opts = self._normalize_and_validate_options(opts, seeds)
|
||||
|
||||
# Username and password passed as kwargs override user info in URI.
|
||||
username = opts.get("username", self._resolve_srv_info["username"])
|
||||
password = opts.get("password", self._resolve_srv_info["password"])
|
||||
self._options = ClientOptions(
|
||||
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
|
||||
)
|
||||
|
||||
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||
|
||||
def _init_based_on_options(
|
||||
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
|
||||
) -> None:
|
||||
self._event_listeners = self._options.pool_options._event_listeners
|
||||
self._topology_settings = TopologySettings(
|
||||
seeds=seeds,
|
||||
replica_set_name=self._options.replica_set_name,
|
||||
pool_class=self._resolve_srv_info["pool_class"],
|
||||
pool_options=self._options.pool_options,
|
||||
monitor_class=self._resolve_srv_info["monitor_class"],
|
||||
condition_class=self._resolve_srv_info["condition_class"],
|
||||
local_threshold_ms=self._options.local_threshold_ms,
|
||||
server_selection_timeout=self._options.server_selection_timeout,
|
||||
server_selector=self._options.server_selector,
|
||||
heartbeat_frequency=self._options.heartbeat_frequency,
|
||||
fqdn=self._resolve_srv_info["fqdn"],
|
||||
direct_connection=self._options.direct_connection,
|
||||
load_balanced=self._options.load_balanced,
|
||||
srv_service_name=srv_service_name,
|
||||
srv_max_hosts=srv_max_hosts,
|
||||
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||
)
|
||||
if self._options.auto_encryption_opts:
|
||||
from pymongo.synchronous.encryption import _Encrypter
|
||||
|
||||
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
|
||||
self._timeout = self._options.timeout
|
||||
|
||||
if _HAS_REGISTER_AT_FORK:
|
||||
# Add this client to the list of weakly referenced items.
|
||||
# This will be used later if we fork.
|
||||
MongoClient._clients[self._topology._topology_id] = self
|
||||
def _normalize_and_validate_options(
|
||||
self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]]
|
||||
) -> common._CaseInsensitiveDictionary:
|
||||
# Handle security-option conflicts in combined options.
|
||||
opts = _handle_security_options(opts)
|
||||
# Normalize combined options.
|
||||
opts = _normalize_options(opts)
|
||||
_check_options(seeds, opts)
|
||||
return opts
|
||||
|
||||
def _validate_kwargs_and_update_opts(
|
||||
self,
|
||||
keyword_opts: common._CaseInsensitiveDictionary,
|
||||
opts: common._CaseInsensitiveDictionary,
|
||||
) -> common._CaseInsensitiveDictionary:
|
||||
# Handle deprecated options in kwarg options.
|
||||
keyword_opts = _handle_option_deprecations(keyword_opts)
|
||||
# Validate kwarg options.
|
||||
keyword_opts = common._CaseInsensitiveDictionary(
|
||||
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
||||
)
|
||||
# Override connection string options with kwarg options.
|
||||
opts.update(keyword_opts)
|
||||
return opts
|
||||
|
||||
def _connect(self) -> None:
|
||||
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
|
||||
@ -899,6 +1001,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
|
||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||
self._topology = Topology(self._topology_settings)
|
||||
if _HAS_REGISTER_AT_FORK:
|
||||
# Add this client to the list of weakly referenced items.
|
||||
# This will be used later if we fork.
|
||||
MongoClient._clients[self._topology._topology_id] = self
|
||||
# Seed the topology with the old one's pid so we can detect clients
|
||||
# that are opened before a fork and used after.
|
||||
self._topology._pid = old_pid
|
||||
@ -1113,16 +1219,24 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
return self._options
|
||||
|
||||
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
|
||||
return (
|
||||
tuple(sorted(self._resolve_srv_info["seeds"])),
|
||||
self._options.replica_set_name,
|
||||
self._resolve_srv_info["fqdn"],
|
||||
self._resolve_srv_info["srv_service_name"],
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return self._topology == other._topology
|
||||
return self.eq_props() == other.eq_props()
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._topology)
|
||||
return hash(self.eq_props())
|
||||
|
||||
def _repr_helper(self) -> str:
|
||||
def option_repr(option: str, value: Any) -> str:
|
||||
@ -1138,13 +1252,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return f"{option}={value!r}"
|
||||
|
||||
# Host first...
|
||||
options = [
|
||||
"host=%r"
|
||||
% [
|
||||
"%s:%d" % (host, port) if port is not None else host
|
||||
for host, port in self._topology_settings.seeds
|
||||
if self._topology is None:
|
||||
options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"]
|
||||
else:
|
||||
options = [
|
||||
"host=%r"
|
||||
% [
|
||||
"%s:%d" % (host, port) if port is not None else host
|
||||
for host, port in self._topology_settings.seeds
|
||||
]
|
||||
]
|
||||
]
|
||||
# ... then everything in self._constructor_args...
|
||||
options.extend(
|
||||
option_repr(key, self._options._options[key]) for key in self._constructor_args
|
||||
@ -1546,6 +1663,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
.. versionchanged:: 3.6
|
||||
End all server sessions created by this client.
|
||||
"""
|
||||
if self._topology is None:
|
||||
return
|
||||
session_ids = self._topology.pop_all_sessions()
|
||||
if session_ids:
|
||||
self._end_sessions(session_ids)
|
||||
@ -1576,6 +1695,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
launches the connection process in the background.
|
||||
"""
|
||||
if not self._opened:
|
||||
if self._resolve_srv_info["is_srv"]:
|
||||
self._resolve_srv()
|
||||
self._init_background()
|
||||
self._topology.open()
|
||||
with self._lock:
|
||||
self._kill_cursors_executor.open()
|
||||
@ -2497,6 +2619,7 @@ class _MongoClientErrorHandler:
|
||||
self.completed_handshake,
|
||||
self.service_id,
|
||||
)
|
||||
assert self.client._topology is not None
|
||||
self.client._topology.handle_error(self.server_address, err_ctx)
|
||||
|
||||
def __enter__(self) -> _MongoClientErrorHandler:
|
||||
|
||||
@ -33,7 +33,7 @@ from pymongo.periodic_executor import _shutdown_executors
|
||||
from pymongo.pool_options import _is_faas
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
from pymongo.synchronous.srv_resolver import _SrvResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
||||
|
||||
@ -25,6 +25,8 @@ from pymongo.errors import ConfigurationError
|
||||
if TYPE_CHECKING:
|
||||
from dns import resolver
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def _have_dnspython() -> bool:
|
||||
try:
|
||||
@ -45,13 +47,23 @@ def maybe_decode(text: Union[str, bytes]) -> str:
|
||||
|
||||
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
|
||||
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
|
||||
from dns import resolver
|
||||
if _IS_SYNC:
|
||||
from dns import resolver
|
||||
|
||||
if hasattr(resolver, "resolve"):
|
||||
# dnspython >= 2
|
||||
return resolver.resolve(*args, **kwargs)
|
||||
# dnspython 1.X
|
||||
return resolver.query(*args, **kwargs)
|
||||
if hasattr(resolver, "resolve"):
|
||||
# dnspython >= 2
|
||||
return resolver.resolve(*args, **kwargs)
|
||||
# dnspython 1.X
|
||||
return resolver.query(*args, **kwargs)
|
||||
else:
|
||||
from dns import asyncresolver
|
||||
|
||||
if hasattr(asyncresolver, "resolve"):
|
||||
# dnspython >= 2
|
||||
return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
|
||||
raise ConfigurationError(
|
||||
"Upgrade to dnspython version >= 2.0 to use MongoClient with mongodb+srv:// connections."
|
||||
)
|
||||
|
||||
|
||||
_INVALID_HOST_MSG = (
|
||||
188
pymongo/synchronous/uri_parser.py
Normal file
188
pymongo/synchronous/uri_parser.py
Normal file
@ -0,0 +1,188 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Tools to parse and validate a MongoDB URI."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
from pymongo.synchronous.srv_resolver import _SrvResolver
|
||||
from pymongo.uri_parser_shared import (
|
||||
_ALLOWED_TXT_OPTS,
|
||||
DEFAULT_PORT,
|
||||
SCHEME,
|
||||
SCHEME_LEN,
|
||||
SRV_SCHEME_LEN,
|
||||
_check_options,
|
||||
_validate_uri,
|
||||
split_hosts,
|
||||
split_options,
|
||||
)
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def parse_uri(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
connect_timeout: Optional[float] = None,
|
||||
srv_service_name: Optional[str] = None,
|
||||
srv_max_hosts: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse and validate a MongoDB URI.
|
||||
|
||||
Returns a dict of the form::
|
||||
|
||||
{
|
||||
'nodelist': <list of (host, port) tuples>,
|
||||
'username': <username> or None,
|
||||
'password': <password> or None,
|
||||
'database': <database name> or None,
|
||||
'collection': <collection name> or None,
|
||||
'options': <dict of MongoDB URI options>,
|
||||
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
||||
}
|
||||
|
||||
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
||||
to build nodelist and options.
|
||||
|
||||
:param uri: The MongoDB URI to parse.
|
||||
:param default_port: The port number to use when one wasn't specified
|
||||
for a host in the URI.
|
||||
:param validate: If ``True`` (the default), validate and
|
||||
normalize all options. Default: ``True``.
|
||||
:param warn: When validating, if ``True`` then will warn
|
||||
the user then ignore any invalid options or values. If ``False``,
|
||||
validation will error when options are unsupported or values are
|
||||
invalid. Default: ``False``.
|
||||
:param normalize: If ``True``, convert names of URI options
|
||||
to their internally-used names. Default: ``True``.
|
||||
:param connect_timeout: The maximum time in milliseconds to
|
||||
wait for a response from the DNS server.
|
||||
:param srv_service_name: A custom SRV service name
|
||||
|
||||
.. versionchanged:: 4.6
|
||||
The delimiting slash (``/``) between hosts and connection options is now optional.
|
||||
For example, "mongodb://example.com?tls=true" is now a valid URI.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
To better follow RFC 3986, unquoted percent signs ("%") are no longer
|
||||
supported.
|
||||
|
||||
.. versionchanged:: 3.9
|
||||
Added the ``normalize`` parameter.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added support for mongodb+srv:// URIs.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
Return the original value of the ``readPreference`` MongoDB URI option
|
||||
instead of the validated read preference mode.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
``warn`` added so invalid options can be ignored.
|
||||
"""
|
||||
result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
|
||||
result.update(
|
||||
_parse_srv(
|
||||
uri,
|
||||
default_port,
|
||||
validate,
|
||||
warn,
|
||||
normalize,
|
||||
connect_timeout,
|
||||
srv_service_name,
|
||||
srv_max_hosts,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _parse_srv(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
connect_timeout: Optional[float] = None,
|
||||
srv_service_name: Optional[str] = None,
|
||||
srv_max_hosts: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
if uri.startswith(SCHEME):
|
||||
is_srv = False
|
||||
scheme_free = uri[SCHEME_LEN:]
|
||||
else:
|
||||
is_srv = True
|
||||
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||
|
||||
options = _CaseInsensitiveDictionary()
|
||||
|
||||
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||
if "/" in host_plus_db_part:
|
||||
host_part, _, _ = host_plus_db_part.partition("/")
|
||||
else:
|
||||
host_part = host_plus_db_part
|
||||
|
||||
if opts:
|
||||
options.update(split_options(opts, validate, warn, normalize))
|
||||
if srv_service_name is None:
|
||||
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
|
||||
if "@" in host_part:
|
||||
_, _, hosts = host_part.rpartition("@")
|
||||
else:
|
||||
hosts = host_part
|
||||
|
||||
hosts = unquote_plus(hosts)
|
||||
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||
if is_srv:
|
||||
nodes = split_hosts(hosts, default_port=None)
|
||||
fqdn, port = nodes[0]
|
||||
|
||||
# Use the connection timeout. connectTimeoutMS passed as a keyword
|
||||
# argument overrides the same option passed in the connection string.
|
||||
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
|
||||
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
|
||||
nodes = dns_resolver.get_hosts()
|
||||
dns_options = dns_resolver.get_options()
|
||||
if dns_options:
|
||||
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
|
||||
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
|
||||
raise ConfigurationError(
|
||||
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
|
||||
)
|
||||
for opt, val in parsed_dns_options.items():
|
||||
if opt not in options:
|
||||
options[opt] = val
|
||||
if options.get("loadBalanced") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
|
||||
if options.get("replicaSet") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
|
||||
if "tls" not in options and "ssl" not in options:
|
||||
options["tls"] = True if validate else "true"
|
||||
else:
|
||||
nodes = split_hosts(hosts, default_port=default_port)
|
||||
|
||||
_check_options(nodes, options)
|
||||
|
||||
return {
|
||||
"nodelist": nodes,
|
||||
"options": options,
|
||||
}
|
||||
@ -13,627 +13,32 @@
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Tools to parse and validate a MongoDB URI.
|
||||
|
||||
.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs.
|
||||
"""
|
||||
"""Re-import of synchronous URI Parser API for compatibility."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sized,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from pymongo.client_options import _parse_ssl_options
|
||||
from pymongo.common import (
|
||||
INTERNAL_URI_OPTION_NAME_MAP,
|
||||
SRV_SERVICE_NAME,
|
||||
URI_OPTIONS_DEPRECATION_MAP,
|
||||
_CaseInsensitiveDictionary,
|
||||
get_validated_options,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
from pymongo.srv_resolver import _have_dnspython, _SrvResolver
|
||||
from pymongo.typings import _Address
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
|
||||
SCHEME = "mongodb://"
|
||||
SCHEME_LEN = len(SCHEME)
|
||||
SRV_SCHEME = "mongodb+srv://"
|
||||
SRV_SCHEME_LEN = len(SRV_SCHEME)
|
||||
DEFAULT_PORT = 27017
|
||||
|
||||
|
||||
def _unquoted_percent(s: str) -> bool:
|
||||
"""Check for unescaped percent signs.
|
||||
|
||||
:param s: A string. `s` can have things like '%25', '%2525',
|
||||
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
|
||||
"""
|
||||
for i in range(len(s)):
|
||||
if s[i] == "%":
|
||||
sub = s[i : i + 3]
|
||||
# If unquoting yields the same string this means there was an
|
||||
# unquoted %.
|
||||
if unquote_plus(sub) == sub:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def parse_userinfo(userinfo: str) -> tuple[str, str]:
|
||||
"""Validates the format of user information in a MongoDB URI.
|
||||
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
|
||||
"]", "@") as per RFC 3986 must be escaped.
|
||||
|
||||
Returns a 2-tuple containing the unescaped username followed
|
||||
by the unescaped password.
|
||||
|
||||
:param userinfo: A string of the form <username>:<password>
|
||||
"""
|
||||
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
|
||||
raise InvalidURI(
|
||||
"Username and password must be escaped according to "
|
||||
"RFC 3986, use urllib.parse.quote_plus"
|
||||
)
|
||||
|
||||
user, _, passwd = userinfo.partition(":")
|
||||
# No password is expected with GSSAPI authentication.
|
||||
if not user:
|
||||
raise InvalidURI("The empty string is not valid username")
|
||||
|
||||
return unquote_plus(user), unquote_plus(passwd)
|
||||
|
||||
|
||||
def parse_ipv6_literal_host(
|
||||
entity: str, default_port: Optional[int]
|
||||
) -> tuple[str, Optional[Union[str, int]]]:
|
||||
"""Validates an IPv6 literal host:port string.
|
||||
|
||||
Returns a 2-tuple of IPv6 literal followed by port where
|
||||
port is default_port if it wasn't specified in entity.
|
||||
|
||||
:param entity: A string that represents an IPv6 literal enclosed
|
||||
in braces (e.g. '[::1]' or '[::1]:27017').
|
||||
:param default_port: The port number to use when one wasn't
|
||||
specified in entity.
|
||||
"""
|
||||
if entity.find("]") == -1:
|
||||
raise ValueError(
|
||||
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
|
||||
)
|
||||
i = entity.find("]:")
|
||||
if i == -1:
|
||||
return entity[1:-1], default_port
|
||||
return entity[1:i], entity[i + 2 :]
|
||||
|
||||
|
||||
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
|
||||
"""Validates a host string
|
||||
|
||||
Returns a 2-tuple of host followed by port where port is default_port
|
||||
if it wasn't specified in the string.
|
||||
|
||||
:param entity: A host or host:port string where host could be a
|
||||
hostname or IP address.
|
||||
:param default_port: The port number to use when one wasn't
|
||||
specified in entity.
|
||||
"""
|
||||
host = entity
|
||||
port: Optional[Union[str, int]] = default_port
|
||||
if entity[0] == "[":
|
||||
host, port = parse_ipv6_literal_host(entity, default_port)
|
||||
elif entity.endswith(".sock"):
|
||||
return entity, default_port
|
||||
elif entity.find(":") != -1:
|
||||
if entity.count(":") > 1:
|
||||
raise ValueError(
|
||||
"Reserved characters such as ':' must be "
|
||||
"escaped according RFC 2396. An IPv6 "
|
||||
"address literal must be enclosed in '[' "
|
||||
"and ']' according to RFC 2732."
|
||||
)
|
||||
host, port = host.split(":", 1)
|
||||
if isinstance(port, str):
|
||||
if not port.isdigit():
|
||||
# Special case check for mistakes like "mongodb://localhost:27017 ".
|
||||
if all(c.isspace() or c.isdigit() for c in port):
|
||||
for c in port:
|
||||
if c.isspace():
|
||||
raise ValueError(f"Port contains whitespace character: {c!r}")
|
||||
|
||||
# A non-digit port indicates that the URI is invalid, likely because the password
|
||||
# or username were not escaped.
|
||||
raise ValueError(
|
||||
"Port contains non-digit characters. Hint: username and password must be escaped according to "
|
||||
"RFC 3986, use urllib.parse.quote_plus"
|
||||
)
|
||||
if int(port) > 65535 or int(port) <= 0:
|
||||
raise ValueError("Port must be an integer between 0 and 65535")
|
||||
port = int(port)
|
||||
|
||||
# Normalize hostname to lowercase, since DNS is case-insensitive:
|
||||
# https://tools.ietf.org/html/rfc4343
|
||||
# This prevents useless rediscovery if "foo.com" is in the seed list but
|
||||
# "FOO.com" is in the hello response.
|
||||
return host.lower(), port
|
||||
|
||||
|
||||
# Options whose values are implicitly determined by tlsInsecure.
|
||||
_IMPLICIT_TLSINSECURE_OPTS = {
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
}
|
||||
|
||||
|
||||
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
|
||||
"""Helper method for split_options which creates the options dict.
|
||||
Also handles the creation of a list for the URI tag_sets/
|
||||
readpreferencetags portion, and the use of a unicode options string.
|
||||
"""
|
||||
options = _CaseInsensitiveDictionary()
|
||||
for uriopt in opts.split(delim):
|
||||
key, value = uriopt.split("=")
|
||||
if key.lower() == "readpreferencetags":
|
||||
options.setdefault(key, []).append(value)
|
||||
else:
|
||||
if key in options:
|
||||
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
|
||||
if key.lower() == "authmechanismproperties":
|
||||
val = value
|
||||
else:
|
||||
val = unquote_plus(value)
|
||||
options[key] = val
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Raise appropriate errors when conflicting TLS options are present in
|
||||
the options dictionary.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
# Implicitly defined options must not be explicitly specified.
|
||||
tlsinsecure = options.get("tlsinsecure")
|
||||
if tlsinsecure is not None:
|
||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||
if opt in options:
|
||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||
raise InvalidURI(
|
||||
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
|
||||
)
|
||||
|
||||
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
|
||||
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
|
||||
if tlsallowinvalidcerts is not None:
|
||||
if "tlsdisableocspendpointcheck" in options:
|
||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||
raise InvalidURI(
|
||||
err_msg
|
||||
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
|
||||
)
|
||||
if tlsallowinvalidcerts is True:
|
||||
options["tlsdisableocspendpointcheck"] = True
|
||||
|
||||
# Handle co-occurence of CRL and OCSP-related options.
|
||||
tlscrlfile = options.get("tlscrlfile")
|
||||
if tlscrlfile is not None:
|
||||
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
|
||||
if options.get(opt) is True:
|
||||
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
|
||||
raise InvalidURI(err_msg % (opt,))
|
||||
|
||||
if "ssl" in options and "tls" in options:
|
||||
|
||||
def truth_value(val: Any) -> Any:
|
||||
if val in ("true", "false"):
|
||||
return val == "true"
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
return val
|
||||
|
||||
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
|
||||
err_msg = "Can not specify conflicting values for URI options %s and %s."
|
||||
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Issue appropriate warnings when deprecated options are present in the
|
||||
options dictionary. Removes deprecated option key, value pairs if the
|
||||
options dictionary is found to also have the renamed option.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
for optname in list(options):
|
||||
if optname in URI_OPTIONS_DEPRECATION_MAP:
|
||||
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
|
||||
if mode == "renamed":
|
||||
newoptname = message
|
||||
if newoptname in options:
|
||||
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
options.pop(optname)
|
||||
continue
|
||||
warn_msg = "Option '%s' is deprecated, use '%s' instead."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), newoptname),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif mode == "removed":
|
||||
warn_msg = "Option '%s' is deprecated. %s."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), message),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Normalizes option names in the options dictionary by converting them to
|
||||
their internally-used names.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
# Expand the tlsInsecure option.
|
||||
tlsinsecure = options.get("tlsinsecure")
|
||||
if tlsinsecure is not None:
|
||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||
# Implicit options are logically the same as tlsInsecure.
|
||||
options[opt] = tlsinsecure
|
||||
|
||||
for optname in list(options):
|
||||
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
|
||||
if intname is not None:
|
||||
options[intname] = options.pop(optname)
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
|
||||
"""Validates and normalizes options passed in a MongoDB URI.
|
||||
|
||||
Returns a new dictionary of validated and normalized options. If warn is
|
||||
False then errors will be thrown for invalid options, otherwise they will
|
||||
be ignored and a warning will be issued.
|
||||
|
||||
:param opts: A dict of MongoDB URI options.
|
||||
:param warn: If ``True`` then warnings will be logged and
|
||||
invalid options will be ignored. Otherwise invalid options will
|
||||
cause errors.
|
||||
"""
|
||||
return get_validated_options(opts, warn)
|
||||
|
||||
|
||||
def split_options(
|
||||
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Takes the options portion of a MongoDB URI, validates each option
|
||||
and returns the options in a dictionary.
|
||||
|
||||
:param opt: A string representing MongoDB URI options.
|
||||
:param validate: If ``True`` (the default), validate and normalize all
|
||||
options.
|
||||
:param warn: If ``False`` (the default), suppress all warnings raised
|
||||
during validation of options.
|
||||
:param normalize: If ``True`` (the default), renames all options to their
|
||||
internally-used names.
|
||||
"""
|
||||
and_idx = opts.find("&")
|
||||
semi_idx = opts.find(";")
|
||||
try:
|
||||
if and_idx >= 0 and semi_idx >= 0:
|
||||
raise InvalidURI("Can not mix '&' and ';' for option separators")
|
||||
elif and_idx >= 0:
|
||||
options = _parse_options(opts, "&")
|
||||
elif semi_idx >= 0:
|
||||
options = _parse_options(opts, ";")
|
||||
elif opts.find("=") != -1:
|
||||
options = _parse_options(opts, None)
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidURI("MongoDB URI options are key=value pairs") from None
|
||||
|
||||
options = _handle_security_options(options)
|
||||
|
||||
options = _handle_option_deprecations(options)
|
||||
|
||||
if normalize:
|
||||
options = _normalize_options(options)
|
||||
|
||||
if validate:
|
||||
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
|
||||
if options.get("authsource") == "":
|
||||
raise InvalidURI("the authSource database cannot be an empty string")
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
|
||||
"""Takes a string of the form host1[:port],host2[:port]... and
|
||||
splits it into (host, port) tuples. If [:port] isn't present the
|
||||
default_port is used.
|
||||
|
||||
Returns a set of 2-tuples containing the host name (or IP) followed by
|
||||
port number.
|
||||
|
||||
:param hosts: A string of the form host1[:port],host2[:port],...
|
||||
:param default_port: The port number to use when one wasn't specified
|
||||
for a host.
|
||||
"""
|
||||
nodes = []
|
||||
for entity in hosts.split(","):
|
||||
if not entity:
|
||||
raise ConfigurationError("Empty host (or extra comma in host list)")
|
||||
port = default_port
|
||||
# Unix socket entities don't have ports
|
||||
if entity.endswith(".sock"):
|
||||
port = None
|
||||
nodes.append(parse_host(entity, port))
|
||||
return nodes
|
||||
|
||||
|
||||
# Prohibited characters in database name. DB names also can't have ".", but for
|
||||
# backward-compat we allow "db.collection" in URI.
|
||||
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
|
||||
|
||||
_ALLOWED_TXT_OPTS = frozenset(
|
||||
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
|
||||
)
|
||||
|
||||
|
||||
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
|
||||
# Ensure directConnection was not True if there are multiple seeds.
|
||||
if len(nodes) > 1 and options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
|
||||
|
||||
if options.get("loadbalanced"):
|
||||
if len(nodes) > 1:
|
||||
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
|
||||
if options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
|
||||
if options.get("replicaset"):
|
||||
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
|
||||
|
||||
|
||||
def parse_uri(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
connect_timeout: Optional[float] = None,
|
||||
srv_service_name: Optional[str] = None,
|
||||
srv_max_hosts: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Parse and validate a MongoDB URI.
|
||||
|
||||
Returns a dict of the form::
|
||||
|
||||
{
|
||||
'nodelist': <list of (host, port) tuples>,
|
||||
'username': <username> or None,
|
||||
'password': <password> or None,
|
||||
'database': <database name> or None,
|
||||
'collection': <collection name> or None,
|
||||
'options': <dict of MongoDB URI options>,
|
||||
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
||||
}
|
||||
|
||||
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
||||
to build nodelist and options.
|
||||
|
||||
:param uri: The MongoDB URI to parse.
|
||||
:param default_port: The port number to use when one wasn't specified
|
||||
for a host in the URI.
|
||||
:param validate: If ``True`` (the default), validate and
|
||||
normalize all options. Default: ``True``.
|
||||
:param warn: When validating, if ``True`` then will warn
|
||||
the user then ignore any invalid options or values. If ``False``,
|
||||
validation will error when options are unsupported or values are
|
||||
invalid. Default: ``False``.
|
||||
:param normalize: If ``True``, convert names of URI options
|
||||
to their internally-used names. Default: ``True``.
|
||||
:param connect_timeout: The maximum time in milliseconds to
|
||||
wait for a response from the DNS server.
|
||||
:param srv_service_name: A custom SRV service name
|
||||
|
||||
.. versionchanged:: 4.6
|
||||
The delimiting slash (``/``) between hosts and connection options is now optional.
|
||||
For example, "mongodb://example.com?tls=true" is now a valid URI.
|
||||
|
||||
.. versionchanged:: 4.0
|
||||
To better follow RFC 3986, unquoted percent signs ("%") are no longer
|
||||
supported.
|
||||
|
||||
.. versionchanged:: 3.9
|
||||
Added the ``normalize`` parameter.
|
||||
|
||||
.. versionchanged:: 3.6
|
||||
Added support for mongodb+srv:// URIs.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
Return the original value of the ``readPreference`` MongoDB URI option
|
||||
instead of the validated read preference mode.
|
||||
|
||||
.. versionchanged:: 3.1
|
||||
``warn`` added so invalid options can be ignored.
|
||||
"""
|
||||
if uri.startswith(SCHEME):
|
||||
is_srv = False
|
||||
scheme_free = uri[SCHEME_LEN:]
|
||||
elif uri.startswith(SRV_SCHEME):
|
||||
if not _have_dnspython():
|
||||
python_path = sys.executable or "python"
|
||||
raise ConfigurationError(
|
||||
'The "dnspython" module must be '
|
||||
"installed to use mongodb+srv:// URIs. "
|
||||
"To fix this error install pymongo again:\n "
|
||||
"%s -m pip install pymongo>=4.3" % (python_path)
|
||||
)
|
||||
is_srv = True
|
||||
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||
else:
|
||||
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
|
||||
|
||||
if not scheme_free:
|
||||
raise InvalidURI("Must provide at least one hostname or IP")
|
||||
|
||||
user = None
|
||||
passwd = None
|
||||
dbase = None
|
||||
collection = None
|
||||
options = _CaseInsensitiveDictionary()
|
||||
|
||||
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||
if "/" in host_plus_db_part:
|
||||
host_part, _, dbase = host_plus_db_part.partition("/")
|
||||
else:
|
||||
host_part = host_plus_db_part
|
||||
|
||||
if dbase:
|
||||
dbase = unquote_plus(dbase)
|
||||
if "." in dbase:
|
||||
dbase, collection = dbase.split(".", 1)
|
||||
if _BAD_DB_CHARS.search(dbase):
|
||||
raise InvalidURI('Bad database name "%s"' % dbase)
|
||||
else:
|
||||
dbase = None
|
||||
|
||||
if opts:
|
||||
options.update(split_options(opts, validate, warn, normalize))
|
||||
if srv_service_name is None:
|
||||
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
|
||||
if "@" in host_part:
|
||||
userinfo, _, hosts = host_part.rpartition("@")
|
||||
user, passwd = parse_userinfo(userinfo)
|
||||
else:
|
||||
hosts = host_part
|
||||
|
||||
if "/" in hosts:
|
||||
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
|
||||
|
||||
hosts = unquote_plus(hosts)
|
||||
fqdn = None
|
||||
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||
if is_srv:
|
||||
if options.get("directConnection"):
|
||||
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
|
||||
nodes = split_hosts(hosts, default_port=None)
|
||||
if len(nodes) != 1:
|
||||
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
|
||||
fqdn, port = nodes[0]
|
||||
if port is not None:
|
||||
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
|
||||
|
||||
# Use the connection timeout. connectTimeoutMS passed as a keyword
|
||||
# argument overrides the same option passed in the connection string.
|
||||
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
|
||||
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
|
||||
nodes = dns_resolver.get_hosts()
|
||||
dns_options = dns_resolver.get_options()
|
||||
if dns_options:
|
||||
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
|
||||
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
|
||||
raise ConfigurationError(
|
||||
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
|
||||
)
|
||||
for opt, val in parsed_dns_options.items():
|
||||
if opt not in options:
|
||||
options[opt] = val
|
||||
if options.get("loadBalanced") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
|
||||
if options.get("replicaSet") and srv_max_hosts:
|
||||
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
|
||||
if "tls" not in options and "ssl" not in options:
|
||||
options["tls"] = True if validate else "true"
|
||||
elif not is_srv and options.get("srvServiceName") is not None:
|
||||
raise ConfigurationError(
|
||||
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
|
||||
)
|
||||
elif not is_srv and srv_max_hosts:
|
||||
raise ConfigurationError(
|
||||
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
|
||||
)
|
||||
else:
|
||||
nodes = split_hosts(hosts, default_port=default_port)
|
||||
|
||||
_check_options(nodes, options)
|
||||
|
||||
return {
|
||||
"nodelist": nodes,
|
||||
"username": user,
|
||||
"password": passwd,
|
||||
"database": dbase,
|
||||
"collection": collection,
|
||||
"options": options,
|
||||
"fqdn": fqdn,
|
||||
}
|
||||
|
||||
|
||||
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
|
||||
"""Parse KMS TLS connection options."""
|
||||
if not kms_tls_options:
|
||||
return {}
|
||||
if not isinstance(kms_tls_options, dict):
|
||||
raise TypeError("kms_tls_options must be a dict")
|
||||
contexts = {}
|
||||
for provider, options in kms_tls_options.items():
|
||||
if not isinstance(options, dict):
|
||||
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
|
||||
options.setdefault("tls", True)
|
||||
opts = _CaseInsensitiveDictionary(options)
|
||||
opts = _handle_security_options(opts)
|
||||
opts = _normalize_options(opts)
|
||||
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
|
||||
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
|
||||
if ssl_context is None:
|
||||
raise ConfigurationError("TLS is required for KMS providers")
|
||||
if allow_invalid_hostnames:
|
||||
raise ConfigurationError("Insecure TLS options prohibited")
|
||||
|
||||
for n in [
|
||||
"tlsInsecure",
|
||||
"tlsAllowInvalidCertificates",
|
||||
"tlsAllowInvalidHostnames",
|
||||
"tlsDisableCertificateRevocationCheck",
|
||||
]:
|
||||
if n in opts:
|
||||
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
|
||||
contexts[provider] = ssl_context
|
||||
return contexts
|
||||
|
||||
from pymongo.errors import InvalidURI
|
||||
from pymongo.synchronous.uri_parser import * # noqa: F403
|
||||
from pymongo.synchronous.uri_parser import __doc__ as original_doc
|
||||
from pymongo.uri_parser_shared import * # noqa: F403
|
||||
|
||||
__doc__ = original_doc
|
||||
__all__ = [ # noqa: F405
|
||||
"parse_userinfo",
|
||||
"parse_ipv6_literal_host",
|
||||
"parse_host",
|
||||
"validate_options",
|
||||
"split_options",
|
||||
"split_hosts",
|
||||
"parse_uri",
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pprint
|
||||
|
||||
try:
|
||||
pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203
|
||||
pprint.pprint(parse_uri(sys.argv[1])) # noqa: F405, T203
|
||||
except InvalidURI as exc:
|
||||
print(exc) # noqa: T201
|
||||
sys.exit(0)
|
||||
|
||||
549
pymongo/uri_parser_shared.py
Normal file
549
pymongo/uri_parser_shared.py
Normal file
@ -0,0 +1,549 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
|
||||
"""Tools to parse and validate a MongoDB URI.
|
||||
|
||||
.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sized,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from pymongo.asynchronous.srv_resolver import _have_dnspython
|
||||
from pymongo.client_options import _parse_ssl_options
|
||||
from pymongo.common import (
|
||||
INTERNAL_URI_OPTION_NAME_MAP,
|
||||
URI_OPTIONS_DEPRECATION_MAP,
|
||||
_CaseInsensitiveDictionary,
|
||||
get_validated_options,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
from pymongo.typings import _Address
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.pyopenssl_context import SSLContext
|
||||
|
||||
SCHEME = "mongodb://"
|
||||
SCHEME_LEN = len(SCHEME)
|
||||
SRV_SCHEME = "mongodb+srv://"
|
||||
SRV_SCHEME_LEN = len(SRV_SCHEME)
|
||||
DEFAULT_PORT = 27017
|
||||
|
||||
|
||||
def _unquoted_percent(s: str) -> bool:
|
||||
"""Check for unescaped percent signs.
|
||||
|
||||
:param s: A string. `s` can have things like '%25', '%2525',
|
||||
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
|
||||
"""
|
||||
for i in range(len(s)):
|
||||
if s[i] == "%":
|
||||
sub = s[i : i + 3]
|
||||
# If unquoting yields the same string this means there was an
|
||||
# unquoted %.
|
||||
if unquote_plus(sub) == sub:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def parse_userinfo(userinfo: str) -> tuple[str, str]:
|
||||
"""Validates the format of user information in a MongoDB URI.
|
||||
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
|
||||
"]", "@") as per RFC 3986 must be escaped.
|
||||
|
||||
Returns a 2-tuple containing the unescaped username followed
|
||||
by the unescaped password.
|
||||
|
||||
:param userinfo: A string of the form <username>:<password>
|
||||
"""
|
||||
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
|
||||
raise InvalidURI(
|
||||
"Username and password must be escaped according to "
|
||||
"RFC 3986, use urllib.parse.quote_plus"
|
||||
)
|
||||
|
||||
user, _, passwd = userinfo.partition(":")
|
||||
# No password is expected with GSSAPI authentication.
|
||||
if not user:
|
||||
raise InvalidURI("The empty string is not valid username")
|
||||
|
||||
return unquote_plus(user), unquote_plus(passwd)
|
||||
|
||||
|
||||
def parse_ipv6_literal_host(
|
||||
entity: str, default_port: Optional[int]
|
||||
) -> tuple[str, Optional[Union[str, int]]]:
|
||||
"""Validates an IPv6 literal host:port string.
|
||||
|
||||
Returns a 2-tuple of IPv6 literal followed by port where
|
||||
port is default_port if it wasn't specified in entity.
|
||||
|
||||
:param entity: A string that represents an IPv6 literal enclosed
|
||||
in braces (e.g. '[::1]' or '[::1]:27017').
|
||||
:param default_port: The port number to use when one wasn't
|
||||
specified in entity.
|
||||
"""
|
||||
if entity.find("]") == -1:
|
||||
raise ValueError(
|
||||
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
|
||||
)
|
||||
i = entity.find("]:")
|
||||
if i == -1:
|
||||
return entity[1:-1], default_port
|
||||
return entity[1:i], entity[i + 2 :]
|
||||
|
||||
|
||||
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
|
||||
"""Validates a host string
|
||||
|
||||
Returns a 2-tuple of host followed by port where port is default_port
|
||||
if it wasn't specified in the string.
|
||||
|
||||
:param entity: A host or host:port string where host could be a
|
||||
hostname or IP address.
|
||||
:param default_port: The port number to use when one wasn't
|
||||
specified in entity.
|
||||
"""
|
||||
host = entity
|
||||
port: Optional[Union[str, int]] = default_port
|
||||
if entity[0] == "[":
|
||||
host, port = parse_ipv6_literal_host(entity, default_port)
|
||||
elif entity.endswith(".sock"):
|
||||
return entity, default_port
|
||||
elif entity.find(":") != -1:
|
||||
if entity.count(":") > 1:
|
||||
raise ValueError(
|
||||
"Reserved characters such as ':' must be "
|
||||
"escaped according RFC 2396. An IPv6 "
|
||||
"address literal must be enclosed in '[' "
|
||||
"and ']' according to RFC 2732."
|
||||
)
|
||||
host, port = host.split(":", 1)
|
||||
if isinstance(port, str):
|
||||
if not port.isdigit():
|
||||
# Special case check for mistakes like "mongodb://localhost:27017 ".
|
||||
if all(c.isspace() or c.isdigit() for c in port):
|
||||
for c in port:
|
||||
if c.isspace():
|
||||
raise ValueError(f"Port contains whitespace character: {c!r}")
|
||||
|
||||
# A non-digit port indicates that the URI is invalid, likely because the password
|
||||
# or username were not escaped.
|
||||
raise ValueError(
|
||||
"Port contains non-digit characters. Hint: username and password must be escaped according to "
|
||||
"RFC 3986, use urllib.parse.quote_plus"
|
||||
)
|
||||
if int(port) > 65535 or int(port) <= 0:
|
||||
raise ValueError("Port must be an integer between 0 and 65535")
|
||||
port = int(port)
|
||||
|
||||
# Normalize hostname to lowercase, since DNS is case-insensitive:
|
||||
# https://tools.ietf.org/html/rfc4343
|
||||
# This prevents useless rediscovery if "foo.com" is in the seed list but
|
||||
# "FOO.com" is in the hello response.
|
||||
return host.lower(), port
|
||||
|
||||
|
||||
# Options whose values are implicitly determined by tlsInsecure.
|
||||
_IMPLICIT_TLSINSECURE_OPTS = {
|
||||
"tlsallowinvalidcertificates",
|
||||
"tlsallowinvalidhostnames",
|
||||
"tlsdisableocspendpointcheck",
|
||||
}
|
||||
|
||||
|
||||
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
|
||||
"""Helper method for split_options which creates the options dict.
|
||||
Also handles the creation of a list for the URI tag_sets/
|
||||
readpreferencetags portion, and the use of a unicode options string.
|
||||
"""
|
||||
options = _CaseInsensitiveDictionary()
|
||||
for uriopt in opts.split(delim):
|
||||
key, value = uriopt.split("=")
|
||||
if key.lower() == "readpreferencetags":
|
||||
options.setdefault(key, []).append(value)
|
||||
else:
|
||||
if key in options:
|
||||
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
|
||||
if key.lower() == "authmechanismproperties":
|
||||
val = value
|
||||
else:
|
||||
val = unquote_plus(value)
|
||||
options[key] = val
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Raise appropriate errors when conflicting TLS options are present in
|
||||
the options dictionary.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
# Implicitly defined options must not be explicitly specified.
|
||||
tlsinsecure = options.get("tlsinsecure")
|
||||
if tlsinsecure is not None:
|
||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||
if opt in options:
|
||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||
raise InvalidURI(
|
||||
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
|
||||
)
|
||||
|
||||
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
|
||||
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
|
||||
if tlsallowinvalidcerts is not None:
|
||||
if "tlsdisableocspendpointcheck" in options:
|
||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||
raise InvalidURI(
|
||||
err_msg
|
||||
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
|
||||
)
|
||||
if tlsallowinvalidcerts is True:
|
||||
options["tlsdisableocspendpointcheck"] = True
|
||||
|
||||
# Handle co-occurence of CRL and OCSP-related options.
|
||||
tlscrlfile = options.get("tlscrlfile")
|
||||
if tlscrlfile is not None:
|
||||
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
|
||||
if options.get(opt) is True:
|
||||
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
|
||||
raise InvalidURI(err_msg % (opt,))
|
||||
|
||||
if "ssl" in options and "tls" in options:
|
||||
|
||||
def truth_value(val: Any) -> Any:
|
||||
if val in ("true", "false"):
|
||||
return val == "true"
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
return val
|
||||
|
||||
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
|
||||
err_msg = "Can not specify conflicting values for URI options %s and %s."
|
||||
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Issue appropriate warnings when deprecated options are present in the
|
||||
options dictionary. Removes deprecated option key, value pairs if the
|
||||
options dictionary is found to also have the renamed option.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
for optname in list(options):
|
||||
if optname in URI_OPTIONS_DEPRECATION_MAP:
|
||||
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
|
||||
if mode == "renamed":
|
||||
newoptname = message
|
||||
if newoptname in options:
|
||||
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
options.pop(optname)
|
||||
continue
|
||||
warn_msg = "Option '%s' is deprecated, use '%s' instead."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), newoptname),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif mode == "removed":
|
||||
warn_msg = "Option '%s' is deprecated. %s."
|
||||
warnings.warn(
|
||||
warn_msg % (options.cased_key(optname), message),
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||
"""Normalizes option names in the options dictionary by converting them to
|
||||
their internally-used names.
|
||||
|
||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||
MongoDB URI options.
|
||||
"""
|
||||
# Expand the tlsInsecure option.
|
||||
tlsinsecure = options.get("tlsinsecure")
|
||||
if tlsinsecure is not None:
|
||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||
# Implicit options are logically the same as tlsInsecure.
|
||||
options[opt] = tlsinsecure
|
||||
|
||||
for optname in list(options):
|
||||
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
|
||||
if intname is not None:
|
||||
options[intname] = options.pop(optname)
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
|
||||
"""Validates and normalizes options passed in a MongoDB URI.
|
||||
|
||||
Returns a new dictionary of validated and normalized options. If warn is
|
||||
False then errors will be thrown for invalid options, otherwise they will
|
||||
be ignored and a warning will be issued.
|
||||
|
||||
:param opts: A dict of MongoDB URI options.
|
||||
:param warn: If ``True`` then warnings will be logged and
|
||||
invalid options will be ignored. Otherwise invalid options will
|
||||
cause errors.
|
||||
"""
|
||||
return get_validated_options(opts, warn)
|
||||
|
||||
|
||||
def split_options(
|
||||
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
|
||||
) -> MutableMapping[str, Any]:
|
||||
"""Takes the options portion of a MongoDB URI, validates each option
|
||||
and returns the options in a dictionary.
|
||||
|
||||
:param opt: A string representing MongoDB URI options.
|
||||
:param validate: If ``True`` (the default), validate and normalize all
|
||||
options.
|
||||
:param warn: If ``False`` (the default), suppress all warnings raised
|
||||
during validation of options.
|
||||
:param normalize: If ``True`` (the default), renames all options to their
|
||||
internally-used names.
|
||||
"""
|
||||
and_idx = opts.find("&")
|
||||
semi_idx = opts.find(";")
|
||||
try:
|
||||
if and_idx >= 0 and semi_idx >= 0:
|
||||
raise InvalidURI("Can not mix '&' and ';' for option separators")
|
||||
elif and_idx >= 0:
|
||||
options = _parse_options(opts, "&")
|
||||
elif semi_idx >= 0:
|
||||
options = _parse_options(opts, ";")
|
||||
elif opts.find("=") != -1:
|
||||
options = _parse_options(opts, None)
|
||||
else:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise InvalidURI("MongoDB URI options are key=value pairs") from None
|
||||
|
||||
options = _handle_security_options(options)
|
||||
|
||||
options = _handle_option_deprecations(options)
|
||||
|
||||
if normalize:
|
||||
options = _normalize_options(options)
|
||||
|
||||
if validate:
|
||||
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
|
||||
if options.get("authsource") == "":
|
||||
raise InvalidURI("the authSource database cannot be an empty string")
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
|
||||
"""Takes a string of the form host1[:port],host2[:port]... and
|
||||
splits it into (host, port) tuples. If [:port] isn't present the
|
||||
default_port is used.
|
||||
|
||||
Returns a set of 2-tuples containing the host name (or IP) followed by
|
||||
port number.
|
||||
|
||||
:param hosts: A string of the form host1[:port],host2[:port],...
|
||||
:param default_port: The port number to use when one wasn't specified
|
||||
for a host.
|
||||
"""
|
||||
nodes = []
|
||||
for entity in hosts.split(","):
|
||||
if not entity:
|
||||
raise ConfigurationError("Empty host (or extra comma in host list)")
|
||||
port = default_port
|
||||
# Unix socket entities don't have ports
|
||||
if entity.endswith(".sock"):
|
||||
port = None
|
||||
nodes.append(parse_host(entity, port))
|
||||
return nodes
|
||||
|
||||
|
||||
# Prohibited characters in database name. DB names also can't have ".", but for
|
||||
# backward-compat we allow "db.collection" in URI.
|
||||
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
|
||||
|
||||
_ALLOWED_TXT_OPTS = frozenset(
|
||||
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
|
||||
)
|
||||
|
||||
|
||||
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
|
||||
# Ensure directConnection was not True if there are multiple seeds.
|
||||
if len(nodes) > 1 and options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
|
||||
|
||||
if options.get("loadbalanced"):
|
||||
if len(nodes) > 1:
|
||||
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
|
||||
if options.get("directconnection"):
|
||||
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
|
||||
if options.get("replicaset"):
|
||||
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
|
||||
|
||||
|
||||
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
|
||||
"""Parse KMS TLS connection options."""
|
||||
if not kms_tls_options:
|
||||
return {}
|
||||
if not isinstance(kms_tls_options, dict):
|
||||
raise TypeError("kms_tls_options must be a dict")
|
||||
contexts = {}
|
||||
for provider, options in kms_tls_options.items():
|
||||
if not isinstance(options, dict):
|
||||
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
|
||||
options.setdefault("tls", True)
|
||||
opts = _CaseInsensitiveDictionary(options)
|
||||
opts = _handle_security_options(opts)
|
||||
opts = _normalize_options(opts)
|
||||
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
|
||||
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
|
||||
if ssl_context is None:
|
||||
raise ConfigurationError("TLS is required for KMS providers")
|
||||
if allow_invalid_hostnames:
|
||||
raise ConfigurationError("Insecure TLS options prohibited")
|
||||
|
||||
for n in [
|
||||
"tlsInsecure",
|
||||
"tlsAllowInvalidCertificates",
|
||||
"tlsAllowInvalidHostnames",
|
||||
"tlsDisableCertificateRevocationCheck",
|
||||
]:
|
||||
if n in opts:
|
||||
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
|
||||
contexts[provider] = ssl_context
|
||||
return contexts
|
||||
|
||||
|
||||
def _validate_uri(
|
||||
uri: str,
|
||||
default_port: Optional[int] = DEFAULT_PORT,
|
||||
validate: bool = True,
|
||||
warn: bool = False,
|
||||
normalize: bool = True,
|
||||
srv_max_hosts: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
if uri.startswith(SCHEME):
|
||||
is_srv = False
|
||||
scheme_free = uri[SCHEME_LEN:]
|
||||
elif uri.startswith(SRV_SCHEME):
|
||||
if not _have_dnspython():
|
||||
python_path = sys.executable or "python"
|
||||
raise ConfigurationError(
|
||||
'The "dnspython" module must be '
|
||||
"installed to use mongodb+srv:// URIs. "
|
||||
"To fix this error install pymongo again:\n "
|
||||
"%s -m pip install pymongo>=4.3" % (python_path)
|
||||
)
|
||||
is_srv = True
|
||||
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||
else:
|
||||
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
|
||||
|
||||
if not scheme_free:
|
||||
raise InvalidURI("Must provide at least one hostname or IP")
|
||||
|
||||
user = None
|
||||
passwd = None
|
||||
dbase = None
|
||||
collection = None
|
||||
options = _CaseInsensitiveDictionary()
|
||||
|
||||
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||
if "/" in host_plus_db_part:
|
||||
host_part, _, dbase = host_plus_db_part.partition("/")
|
||||
else:
|
||||
host_part = host_plus_db_part
|
||||
|
||||
if dbase:
|
||||
dbase = unquote_plus(dbase)
|
||||
if "." in dbase:
|
||||
dbase, collection = dbase.split(".", 1)
|
||||
if _BAD_DB_CHARS.search(dbase):
|
||||
raise InvalidURI('Bad database name "%s"' % dbase)
|
||||
else:
|
||||
dbase = None
|
||||
|
||||
if opts:
|
||||
options.update(split_options(opts, validate, warn, normalize))
|
||||
if "@" in host_part:
|
||||
userinfo, _, hosts = host_part.rpartition("@")
|
||||
user, passwd = parse_userinfo(userinfo)
|
||||
else:
|
||||
hosts = host_part
|
||||
|
||||
if "/" in hosts:
|
||||
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
|
||||
|
||||
hosts = unquote_plus(hosts)
|
||||
fqdn = None
|
||||
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||
if is_srv:
|
||||
if options.get("directConnection"):
|
||||
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
|
||||
nodes = split_hosts(hosts, default_port=None)
|
||||
if len(nodes) != 1:
|
||||
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
|
||||
fqdn, port = nodes[0]
|
||||
if port is not None:
|
||||
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
|
||||
elif not is_srv and options.get("srvServiceName") is not None:
|
||||
raise ConfigurationError(
|
||||
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
|
||||
)
|
||||
elif not is_srv and srv_max_hosts:
|
||||
raise ConfigurationError(
|
||||
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
|
||||
)
|
||||
else:
|
||||
nodes = split_hosts(hosts, default_port=default_port)
|
||||
|
||||
_check_options(nodes, options)
|
||||
|
||||
return {
|
||||
"nodelist": nodes,
|
||||
"username": user,
|
||||
"password": passwd,
|
||||
"database": dbase,
|
||||
"collection": collection,
|
||||
"options": options,
|
||||
"fqdn": fqdn,
|
||||
}
|
||||
@ -32,7 +32,7 @@ import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
@ -32,7 +32,7 @@ import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
from pymongo.asynchronous.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
@ -1027,7 +1027,7 @@ class AsyncPyMongoTestCase(unittest.TestCase):
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
res = await parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
@ -1058,7 +1058,7 @@ class AsyncPyMongoTestCase(unittest.TestCase):
|
||||
auth_mech = kwargs.get("authMechanism", "")
|
||||
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||
# Only add the default username or password if one is not provided.
|
||||
res = parse_uri(uri)
|
||||
res = await parse_uri(uri)
|
||||
if (
|
||||
not res["username"]
|
||||
and not res["password"]
|
||||
|
||||
@ -47,7 +47,7 @@ from bson.son import SON
|
||||
from pymongo import common, message
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.uri_parser import parse_uri
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
|
||||
if HAVE_SSL:
|
||||
import ssl
|
||||
|
||||
@ -512,13 +512,13 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
|
||||
async def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
|
||||
# Patch the resolver.
|
||||
from pymongo.srv_resolver import _resolve
|
||||
from pymongo.asynchronous.srv_resolver import _resolve
|
||||
|
||||
patched_resolver = FunctionCallRecorder(_resolve)
|
||||
pymongo.srv_resolver._resolve = patched_resolver
|
||||
pymongo.asynchronous.srv_resolver._resolve = patched_resolver
|
||||
|
||||
def reset_resolver():
|
||||
pymongo.srv_resolver._resolve = _resolve
|
||||
pymongo.asynchronous.srv_resolver._resolve = _resolve
|
||||
|
||||
self.addCleanup(reset_resolver)
|
||||
|
||||
@ -607,7 +607,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
with self.assertRaisesRegex(ConfigurationError, expected):
|
||||
AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type]
|
||||
|
||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
||||
@patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts")
|
||||
def test_detected_environment_logging(self, mock_get_hosts):
|
||||
normal_hosts = [
|
||||
"normal.host.com",
|
||||
@ -629,7 +629,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
|
||||
self.assertEqual(len(logs), 7)
|
||||
|
||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
||||
@patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts")
|
||||
async def test_detected_environment_warning(self, mock_get_hosts):
|
||||
with self._caplog.at_level(logging.WARN):
|
||||
normal_hosts = [
|
||||
@ -933,6 +933,15 @@ class TestClient(AsyncIntegrationTest):
|
||||
async with eval(the_repr) as client_two:
|
||||
self.assertEqual(client_two, client)
|
||||
|
||||
async def test_repr_srv_host(self):
|
||||
client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
|
||||
# before srv resolution
|
||||
self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client))
|
||||
await client.aconnect()
|
||||
# after srv resolution
|
||||
self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client))
|
||||
await client.close()
|
||||
|
||||
async def test_getters(self):
|
||||
await async_wait_until(
|
||||
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
|
||||
@ -1911,28 +1920,37 @@ class TestClient(AsyncIntegrationTest):
|
||||
srvServiceName="customname",
|
||||
connect=False,
|
||||
)
|
||||
await client.aconnect()
|
||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||
await client.close()
|
||||
client = AsyncMongoClient(
|
||||
"mongodb+srv://user:password@test22.test.build.10gen.cc"
|
||||
"/?srvServiceName=shouldbeoverriden",
|
||||
srvServiceName="customname",
|
||||
connect=False,
|
||||
)
|
||||
await client.aconnect()
|
||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||
await client.close()
|
||||
client = AsyncMongoClient(
|
||||
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
|
||||
connect=False,
|
||||
)
|
||||
await client.aconnect()
|
||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||
await client.close()
|
||||
|
||||
async def test_srv_max_hosts_kwarg(self):
|
||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
|
||||
await client.aconnect()
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
|
||||
await client.aconnect()
|
||||
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
|
||||
client = self.simple_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
|
||||
)
|
||||
await client.aconnect()
|
||||
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
|
||||
|
||||
@unittest.skipIf(
|
||||
|
||||
@ -54,6 +54,7 @@ from bson import Timestamp, json_util
|
||||
from pymongo import common, monitoring
|
||||
from pymongo.asynchronous.settings import TopologySettings
|
||||
from pymongo.asynchronous.topology import Topology, _ErrorContext
|
||||
from pymongo.asynchronous.uri_parser import parse_uri
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
ConfigurationError,
|
||||
@ -66,7 +67,6 @@ from pymongo.helpers_shared import _check_command_response, _check_write_command
|
||||
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
|
||||
from pymongo.server_description import SERVER_TYPE, ServerDescription
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
async def create_mock_topology(uri, monitor_class=DummyMonitor):
|
||||
parsed_uri = parse_uri(uri)
|
||||
parsed_uri = await parse_uri(uri)
|
||||
replica_set_name = None
|
||||
direct_connection = None
|
||||
load_balanced = None
|
||||
|
||||
@ -31,9 +31,10 @@ from test.asynchronous import (
|
||||
)
|
||||
from test.utils_shared import async_wait_until
|
||||
|
||||
from pymongo.asynchronous.uri_parser import parse_uri
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import parse_uri, split_hosts
|
||||
from pymongo.uri_parser_shared import split_hosts
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -109,7 +110,7 @@ def create_test(test_case):
|
||||
hosts = frozenset(split_hosts(",".join(hosts)))
|
||||
|
||||
if seeds or num_seeds:
|
||||
result = parse_uri(uri, validate=True)
|
||||
result = await parse_uri(uri, validate=True)
|
||||
if seeds is not None:
|
||||
self.assertEqual(sorted(result["nodelist"]), sorted(seeds))
|
||||
if num_seeds is not None:
|
||||
@ -161,7 +162,7 @@ def create_test(test_case):
|
||||
# and re-run these assertions.
|
||||
else:
|
||||
try:
|
||||
parse_uri(uri)
|
||||
await parse_uri(uri)
|
||||
except (ConfigurationError, ValueError):
|
||||
pass
|
||||
else:
|
||||
@ -185,35 +186,24 @@ create_tests(TestDNSSharded)
|
||||
|
||||
class TestParsingErrors(AsyncPyMongoTestCase):
|
||||
async def test_invalid_host(self):
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb.com is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb.com",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://127.0.0.1",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://[::1]",
|
||||
)
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb")
|
||||
await client.aconnect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb.com")
|
||||
await client.aconnect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||
client = self.simple_client("mongodb+srv://127.0.0.1")
|
||||
await client.aconnect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||
client = self.simple_client("mongodb+srv://[::1]")
|
||||
await client.aconnect()
|
||||
|
||||
|
||||
class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
|
||||
async def test_connect_case_insensitive(self):
|
||||
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||
await client.aconnect()
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
|
||||
|
||||
|
||||
@ -28,8 +28,8 @@ from test.asynchronous.utils import async_wait_until
|
||||
|
||||
import pymongo
|
||||
from pymongo import common
|
||||
from pymongo.asynchronous.srv_resolver import _have_dnspython
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.srv_resolver import _have_dnspython
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -54,14 +54,16 @@ class SrvPollingKnobs:
|
||||
|
||||
def enable(self):
|
||||
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
||||
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||
self.old_dns_resolver_response = (
|
||||
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||
)
|
||||
|
||||
if self.min_srv_rescan_interval is not None:
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
||||
|
||||
def mock_get_hosts_and_min_ttl(resolver, *args):
|
||||
async def mock_get_hosts_and_min_ttl(resolver, *args):
|
||||
assert self.old_dns_resolver_response is not None
|
||||
nodes, ttl = self.old_dns_resolver_response(resolver)
|
||||
nodes, ttl = await self.old_dns_resolver_response(resolver)
|
||||
if self.nodelist_callback is not None:
|
||||
nodes = self.nodelist_callback()
|
||||
if self.ttl_time is not None:
|
||||
@ -74,14 +76,14 @@ class SrvPollingKnobs:
|
||||
else:
|
||||
patch_func = mock_get_hosts_and_min_ttl
|
||||
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||
|
||||
def __enter__(self):
|
||||
self.enable()
|
||||
|
||||
def disable(self):
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||
self.old_dns_resolver_response
|
||||
)
|
||||
|
||||
@ -134,7 +136,10 @@ class TestSrvPolling(AsyncPyMongoTestCase):
|
||||
|
||||
def predicate():
|
||||
if set(expected_nodelist) == set(self.get_nodelist(client)):
|
||||
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
|
||||
return (
|
||||
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count
|
||||
>= 1
|
||||
)
|
||||
return False
|
||||
|
||||
await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
|
||||
@ -144,7 +149,7 @@ class TestSrvPolling(AsyncPyMongoTestCase):
|
||||
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
||||
raise self.fail(msg % (nodelist, expected_nodelist))
|
||||
self.assertGreaterEqual(
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||
1,
|
||||
"resolver was never called",
|
||||
)
|
||||
|
||||
@ -32,7 +32,7 @@ except ImportError:
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.uri_parser import parse_uri
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
|
||||
pytestmark = pytest.mark.auth_aws
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ from pymongo.synchronous.auth_oidc import (
|
||||
OIDCCallbackResult,
|
||||
_get_authenticator,
|
||||
)
|
||||
from pymongo.uri_parser import parse_uri
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
|
||||
ROOT = Path(__file__).parent.parent.resolve()
|
||||
TEST_PATH = ROOT / "auth" / "unified"
|
||||
|
||||
@ -47,7 +47,7 @@ from bson.son import SON
|
||||
from pymongo import common, message
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.uri_parser import parse_uri
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
|
||||
if HAVE_SSL:
|
||||
import ssl
|
||||
|
||||
@ -505,13 +505,13 @@ class ClientUnitTest(UnitTest):
|
||||
|
||||
def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
|
||||
# Patch the resolver.
|
||||
from pymongo.srv_resolver import _resolve
|
||||
from pymongo.synchronous.srv_resolver import _resolve
|
||||
|
||||
patched_resolver = FunctionCallRecorder(_resolve)
|
||||
pymongo.srv_resolver._resolve = patched_resolver
|
||||
pymongo.synchronous.srv_resolver._resolve = patched_resolver
|
||||
|
||||
def reset_resolver():
|
||||
pymongo.srv_resolver._resolve = _resolve
|
||||
pymongo.synchronous.srv_resolver._resolve = _resolve
|
||||
|
||||
self.addCleanup(reset_resolver)
|
||||
|
||||
@ -600,7 +600,7 @@ class ClientUnitTest(UnitTest):
|
||||
with self.assertRaisesRegex(ConfigurationError, expected):
|
||||
MongoClient(**{typo: "standard"}) # type: ignore[arg-type]
|
||||
|
||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
||||
@patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
|
||||
def test_detected_environment_logging(self, mock_get_hosts):
|
||||
normal_hosts = [
|
||||
"normal.host.com",
|
||||
@ -622,7 +622,7 @@ class ClientUnitTest(UnitTest):
|
||||
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
|
||||
self.assertEqual(len(logs), 7)
|
||||
|
||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
||||
@patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
|
||||
def test_detected_environment_warning(self, mock_get_hosts):
|
||||
with self._caplog.at_level(logging.WARN):
|
||||
normal_hosts = [
|
||||
@ -908,6 +908,15 @@ class TestClient(IntegrationTest):
|
||||
with eval(the_repr) as client_two:
|
||||
self.assertEqual(client_two, client)
|
||||
|
||||
def test_repr_srv_host(self):
|
||||
client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
|
||||
# before srv resolution
|
||||
self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client))
|
||||
client._connect()
|
||||
# after srv resolution
|
||||
self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client))
|
||||
client.close()
|
||||
|
||||
def test_getters(self):
|
||||
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
|
||||
|
||||
@ -1868,28 +1877,37 @@ class TestClient(IntegrationTest):
|
||||
srvServiceName="customname",
|
||||
connect=False,
|
||||
)
|
||||
client._connect()
|
||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||
client.close()
|
||||
client = MongoClient(
|
||||
"mongodb+srv://user:password@test22.test.build.10gen.cc"
|
||||
"/?srvServiceName=shouldbeoverriden",
|
||||
srvServiceName="customname",
|
||||
connect=False,
|
||||
)
|
||||
client._connect()
|
||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||
client.close()
|
||||
client = MongoClient(
|
||||
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
|
||||
connect=False,
|
||||
)
|
||||
client._connect()
|
||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||
client.close()
|
||||
|
||||
def test_srv_max_hosts_kwarg(self):
|
||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
|
||||
client._connect()
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
|
||||
client._connect()
|
||||
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
|
||||
client = self.simple_client(
|
||||
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
|
||||
)
|
||||
client._connect()
|
||||
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
|
||||
|
||||
@unittest.skipIf(
|
||||
|
||||
@ -65,8 +65,8 @@ from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStarte
|
||||
from pymongo.server_description import SERVER_TYPE, ServerDescription
|
||||
from pymongo.synchronous.settings import TopologySettings
|
||||
from pymongo.synchronous.topology import Topology, _ErrorContext
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
@ -33,7 +33,8 @@ from test.utils_shared import wait_until
|
||||
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.uri_parser import parse_uri, split_hosts
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
from pymongo.uri_parser_shared import split_hosts
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -183,35 +184,24 @@ create_tests(TestDNSSharded)
|
||||
|
||||
class TestParsingErrors(PyMongoTestCase):
|
||||
def test_invalid_host(self):
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: mongodb.com is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://mongodb.com",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://127.0.0.1",
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
ConfigurationError,
|
||||
"Invalid URI host: an IP address is not",
|
||||
self.simple_client,
|
||||
"mongodb+srv://[::1]",
|
||||
)
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb")
|
||||
client._connect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
|
||||
client = self.simple_client("mongodb+srv://mongodb.com")
|
||||
client._connect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||
client = self.simple_client("mongodb+srv://127.0.0.1")
|
||||
client._connect()
|
||||
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||
client = self.simple_client("mongodb+srv://[::1]")
|
||||
client._connect()
|
||||
|
||||
|
||||
class TestCaseInsensitive(IntegrationTest):
|
||||
def test_connect_case_insensitive(self):
|
||||
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||
client._connect()
|
||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ from test.utils import wait_until
|
||||
import pymongo
|
||||
from pymongo import common
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.srv_resolver import _have_dnspython
|
||||
from pymongo.synchronous.srv_resolver import _have_dnspython
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -54,7 +54,9 @@ class SrvPollingKnobs:
|
||||
|
||||
def enable(self):
|
||||
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
||||
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||
self.old_dns_resolver_response = (
|
||||
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||
)
|
||||
|
||||
if self.min_srv_rescan_interval is not None:
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
||||
@ -74,14 +76,14 @@ class SrvPollingKnobs:
|
||||
else:
|
||||
patch_func = mock_get_hosts_and_min_ttl
|
||||
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||
|
||||
def __enter__(self):
|
||||
self.enable()
|
||||
|
||||
def disable(self):
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||
self.old_dns_resolver_response
|
||||
)
|
||||
|
||||
@ -134,7 +136,10 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
|
||||
def predicate():
|
||||
if set(expected_nodelist) == set(self.get_nodelist(client)):
|
||||
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
|
||||
return (
|
||||
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count
|
||||
>= 1
|
||||
)
|
||||
return False
|
||||
|
||||
wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
|
||||
@ -144,7 +149,7 @@ class TestSrvPolling(PyMongoTestCase):
|
||||
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
||||
raise self.fail(msg % (nodelist, expected_nodelist))
|
||||
self.assertGreaterEqual(
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||
1,
|
||||
"resolver was never called",
|
||||
)
|
||||
|
||||
@ -28,8 +28,8 @@ from test import unittest
|
||||
from bson.binary import JAVA_LEGACY
|
||||
from pymongo import ReadPreference
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
from pymongo.uri_parser import (
|
||||
parse_uri,
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
from pymongo.uri_parser_shared import (
|
||||
parse_userinfo,
|
||||
split_hosts,
|
||||
split_options,
|
||||
|
||||
@ -29,7 +29,7 @@ from test.helpers import clear_warning_registry
|
||||
|
||||
from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
|
||||
from pymongo.compression_support import _have_snappy
|
||||
from pymongo.uri_parser import parse_uri
|
||||
from pymongo.synchronous.uri_parser import parse_uri
|
||||
|
||||
CONN_STRING_TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test")
|
||||
|
||||
@ -127,6 +127,7 @@ replacements = {
|
||||
"async_create_barrier": "create_barrier",
|
||||
"async_barrier_wait": "barrier_wait",
|
||||
"async_joinall": "joinall",
|
||||
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
|
||||
}
|
||||
|
||||
docstring_replacements: dict[tuple[str, str], str] = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user