PYTHON-3636 AsyncMongoClient should perform SRV resolution lazily (#2191)
Co-authored-by: Noah Stapp <noah@noahstapp.com> Co-authored-by: Shane Harvey <shane.harvey@mongodb.com>
This commit is contained in:
parent
38ceda4c09
commit
eea8a37257
@ -9,6 +9,8 @@ PyMongo 4.12 brings a number of changes including:
|
|||||||
- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
|
- Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to
|
||||||
:class:`~pymongo.encryption_options.AutoEncryptionOpts`.
|
:class:`~pymongo.encryption_options.AutoEncryptionOpts`.
|
||||||
- Support for $lookup in CSFLE and QE supported on MongoDB 8.1+.
|
- Support for $lookup in CSFLE and QE supported on MongoDB 8.1+.
|
||||||
|
- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation.
|
||||||
|
To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected.
|
||||||
- Added index hinting support to the
|
- Added index hinting support to the
|
||||||
:meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct` and
|
:meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct` and
|
||||||
:meth:`~pymongo.collection.Collection.distinct` commands.
|
:meth:`~pymongo.collection.Collection.distinct` commands.
|
||||||
|
|||||||
@ -87,7 +87,7 @@ from pymongo.read_concern import ReadConcern
|
|||||||
from pymongo.results import BulkWriteResult, DeleteResult
|
from pymongo.results import BulkWriteResult, DeleteResult
|
||||||
from pymongo.ssl_support import get_ssl_context
|
from pymongo.ssl_support import get_ssl_context
|
||||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||||
from pymongo.uri_parser import parse_host
|
from pymongo.uri_parser_shared import parse_host
|
||||||
from pymongo.write_concern import WriteConcern
|
from pymongo.write_concern import WriteConcern
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@ -44,6 +44,7 @@ from typing import (
|
|||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
Callable,
|
||||||
|
Collection,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
Generic,
|
Generic,
|
||||||
@ -60,8 +61,8 @@ from typing import (
|
|||||||
|
|
||||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||||
from bson.timestamp import Timestamp
|
from bson.timestamp import Timestamp
|
||||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||||
from pymongo.asynchronous import client_session, database
|
from pymongo.asynchronous import client_session, database, uri_parser
|
||||||
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
|
||||||
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
|
||||||
from pymongo.asynchronous.client_session import _EmptyServerSession
|
from pymongo.asynchronous.client_session import _EmptyServerSession
|
||||||
@ -113,11 +114,14 @@ from pymongo.typings import (
|
|||||||
_DocumentTypeArg,
|
_DocumentTypeArg,
|
||||||
_Pipeline,
|
_Pipeline,
|
||||||
)
|
)
|
||||||
from pymongo.uri_parser import (
|
from pymongo.uri_parser_shared import (
|
||||||
|
SRV_SCHEME,
|
||||||
_check_options,
|
_check_options,
|
||||||
_handle_option_deprecations,
|
_handle_option_deprecations,
|
||||||
_handle_security_options,
|
_handle_security_options,
|
||||||
_normalize_options,
|
_normalize_options,
|
||||||
|
_validate_uri,
|
||||||
|
split_hosts,
|
||||||
)
|
)
|
||||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
||||||
|
|
||||||
@ -128,6 +132,7 @@ if TYPE_CHECKING:
|
|||||||
from pymongo.asynchronous.bulk import _AsyncBulk
|
from pymongo.asynchronous.bulk import _AsyncBulk
|
||||||
from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession
|
from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession
|
||||||
from pymongo.asynchronous.cursor import _ConnectionManager
|
from pymongo.asynchronous.cursor import _ConnectionManager
|
||||||
|
from pymongo.asynchronous.encryption import _Encrypter
|
||||||
from pymongo.asynchronous.pool import AsyncConnection
|
from pymongo.asynchronous.pool import AsyncConnection
|
||||||
from pymongo.asynchronous.server import Server
|
from pymongo.asynchronous.server import Server
|
||||||
from pymongo.read_concern import ReadConcern
|
from pymongo.read_concern import ReadConcern
|
||||||
@ -750,6 +755,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
port = self.PORT
|
port = self.PORT
|
||||||
if not isinstance(port, int):
|
if not isinstance(port, int):
|
||||||
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
self._topology: Topology = None # type: ignore[assignment]
|
||||||
|
|
||||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||||
# customization of PyMongo, e.g. Motor.
|
# customization of PyMongo, e.g. Motor.
|
||||||
@ -760,8 +768,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
# Parse options passed as kwargs.
|
# Parse options passed as kwargs.
|
||||||
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
|
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
|
||||||
keyword_opts["document_class"] = doc_class
|
keyword_opts["document_class"] = doc_class
|
||||||
|
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||||
|
|
||||||
seeds = set()
|
seeds = set()
|
||||||
|
is_srv = False
|
||||||
username = None
|
username = None
|
||||||
password = None
|
password = None
|
||||||
dbase = None
|
dbase = None
|
||||||
@ -769,29 +779,22 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
fqdn = None
|
fqdn = None
|
||||||
srv_service_name = keyword_opts.get("srvservicename")
|
srv_service_name = keyword_opts.get("srvservicename")
|
||||||
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||||
if len([h for h in host if "/" in h]) > 1:
|
if len([h for h in self._host if "/" in h]) > 1:
|
||||||
raise ConfigurationError("host must not contain multiple MongoDB URIs")
|
raise ConfigurationError("host must not contain multiple MongoDB URIs")
|
||||||
for entity in host:
|
for entity in self._host:
|
||||||
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||||
# it must be a URI,
|
# it must be a URI,
|
||||||
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||||
if "/" in entity:
|
if "/" in entity:
|
||||||
# Determine connection timeout from kwargs.
|
res = _validate_uri(
|
||||||
timeout = keyword_opts.get("connecttimeoutms")
|
|
||||||
if timeout is not None:
|
|
||||||
timeout = common.validate_timeout_or_none_or_zero(
|
|
||||||
keyword_opts.cased_key("connecttimeoutms"), timeout
|
|
||||||
)
|
|
||||||
res = uri_parser.parse_uri(
|
|
||||||
entity,
|
entity,
|
||||||
port,
|
port,
|
||||||
validate=True,
|
validate=True,
|
||||||
warn=True,
|
warn=True,
|
||||||
normalize=False,
|
normalize=False,
|
||||||
connect_timeout=timeout,
|
|
||||||
srv_service_name=srv_service_name,
|
|
||||||
srv_max_hosts=srv_max_hosts,
|
srv_max_hosts=srv_max_hosts,
|
||||||
)
|
)
|
||||||
|
is_srv = entity.startswith(SRV_SCHEME)
|
||||||
seeds.update(res["nodelist"])
|
seeds.update(res["nodelist"])
|
||||||
username = res["username"] or username
|
username = res["username"] or username
|
||||||
password = res["password"] or password
|
password = res["password"] or password
|
||||||
@ -799,7 +802,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
opts = res["options"]
|
opts = res["options"]
|
||||||
fqdn = res["fqdn"]
|
fqdn = res["fqdn"]
|
||||||
else:
|
else:
|
||||||
seeds.update(uri_parser.split_hosts(entity, port))
|
seeds.update(split_hosts(entity, self._port))
|
||||||
if not seeds:
|
if not seeds:
|
||||||
raise ConfigurationError("need to specify at least one host")
|
raise ConfigurationError("need to specify at least one host")
|
||||||
|
|
||||||
@ -820,80 +823,179 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
keyword_opts["tz_aware"] = tz_aware
|
keyword_opts["tz_aware"] = tz_aware
|
||||||
keyword_opts["connect"] = connect
|
keyword_opts["connect"] = connect
|
||||||
|
|
||||||
# Handle deprecated options in kwarg options.
|
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||||
keyword_opts = _handle_option_deprecations(keyword_opts)
|
|
||||||
# Validate kwarg options.
|
|
||||||
keyword_opts = common._CaseInsensitiveDictionary(
|
|
||||||
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
|
||||||
)
|
|
||||||
|
|
||||||
# Override connection string options with kwarg options.
|
|
||||||
opts.update(keyword_opts)
|
|
||||||
|
|
||||||
if srv_service_name is None:
|
if srv_service_name is None:
|
||||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||||
|
|
||||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||||
# Handle security-option conflicts in combined options.
|
opts = self._normalize_and_validate_options(opts, seeds)
|
||||||
opts = _handle_security_options(opts)
|
|
||||||
# Normalize combined options.
|
|
||||||
opts = _normalize_options(opts)
|
|
||||||
_check_options(seeds, opts)
|
|
||||||
|
|
||||||
# Username and password passed as kwargs override user info in URI.
|
# Username and password passed as kwargs override user info in URI.
|
||||||
username = opts.get("username", username)
|
username = opts.get("username", username)
|
||||||
password = opts.get("password", password)
|
password = opts.get("password", password)
|
||||||
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||||
|
|
||||||
self._default_database_name = dbase
|
self._default_database_name = dbase
|
||||||
self._lock = _async_create_lock()
|
self._lock = _async_create_lock()
|
||||||
self._kill_cursors_queue: list = []
|
self._kill_cursors_queue: list = []
|
||||||
|
|
||||||
self._event_listeners = options.pool_options._event_listeners
|
self._encrypter: Optional[_Encrypter] = None
|
||||||
super().__init__(
|
|
||||||
options.codec_options,
|
self._resolve_srv_info.update(
|
||||||
options.read_preference,
|
{
|
||||||
options.write_concern,
|
"is_srv": is_srv,
|
||||||
options.read_concern,
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
"dbase": dbase,
|
||||||
|
"seeds": seeds,
|
||||||
|
"fqdn": fqdn,
|
||||||
|
"srv_service_name": srv_service_name,
|
||||||
|
"pool_class": pool_class,
|
||||||
|
"monitor_class": monitor_class,
|
||||||
|
"condition_class": condition_class,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self._topology_settings = TopologySettings(
|
super().__init__(
|
||||||
seeds=seeds,
|
self._options.codec_options,
|
||||||
replica_set_name=options.replica_set_name,
|
self._options.read_preference,
|
||||||
pool_class=pool_class,
|
self._options.write_concern,
|
||||||
pool_options=options.pool_options,
|
self._options.read_concern,
|
||||||
monitor_class=monitor_class,
|
|
||||||
condition_class=condition_class,
|
|
||||||
local_threshold_ms=options.local_threshold_ms,
|
|
||||||
server_selection_timeout=options.server_selection_timeout,
|
|
||||||
server_selector=options.server_selector,
|
|
||||||
heartbeat_frequency=options.heartbeat_frequency,
|
|
||||||
fqdn=fqdn,
|
|
||||||
direct_connection=options.direct_connection,
|
|
||||||
load_balanced=options.load_balanced,
|
|
||||||
srv_service_name=srv_service_name,
|
|
||||||
srv_max_hosts=srv_max_hosts,
|
|
||||||
server_monitoring_mode=options.server_monitoring_mode,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not is_srv:
|
||||||
|
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||||
|
|
||||||
self._opened = False
|
self._opened = False
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._init_background()
|
if not is_srv:
|
||||||
|
self._init_background()
|
||||||
|
|
||||||
if _IS_SYNC and connect:
|
if _IS_SYNC and connect:
|
||||||
self._get_topology() # type: ignore[unused-coroutine]
|
self._get_topology() # type: ignore[unused-coroutine]
|
||||||
|
|
||||||
self._encrypter = None
|
async def _resolve_srv(self) -> None:
|
||||||
|
keyword_opts = self._resolve_srv_info["keyword_opts"]
|
||||||
|
seeds = set()
|
||||||
|
opts = common._CaseInsensitiveDictionary()
|
||||||
|
srv_service_name = keyword_opts.get("srvservicename")
|
||||||
|
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||||
|
for entity in self._host:
|
||||||
|
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||||
|
# it must be a URI,
|
||||||
|
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||||
|
if "/" in entity:
|
||||||
|
# Determine connection timeout from kwargs.
|
||||||
|
timeout = keyword_opts.get("connecttimeoutms")
|
||||||
|
if timeout is not None:
|
||||||
|
timeout = common.validate_timeout_or_none_or_zero(
|
||||||
|
keyword_opts.cased_key("connecttimeoutms"), timeout
|
||||||
|
)
|
||||||
|
res = await uri_parser._parse_srv(
|
||||||
|
entity,
|
||||||
|
self._port,
|
||||||
|
validate=True,
|
||||||
|
warn=True,
|
||||||
|
normalize=False,
|
||||||
|
connect_timeout=timeout,
|
||||||
|
srv_service_name=srv_service_name,
|
||||||
|
srv_max_hosts=srv_max_hosts,
|
||||||
|
)
|
||||||
|
seeds.update(res["nodelist"])
|
||||||
|
opts = res["options"]
|
||||||
|
else:
|
||||||
|
seeds.update(split_hosts(entity, self._port))
|
||||||
|
|
||||||
|
if not seeds:
|
||||||
|
raise ConfigurationError("need to specify at least one host")
|
||||||
|
|
||||||
|
for hostname in [node[0] for node in seeds]:
|
||||||
|
if _detect_external_db(hostname):
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add options with named keyword arguments to the parsed kwarg options.
|
||||||
|
tz_aware = keyword_opts["tz_aware"]
|
||||||
|
connect = keyword_opts["connect"]
|
||||||
|
if tz_aware is None:
|
||||||
|
tz_aware = opts.get("tz_aware", False)
|
||||||
|
if connect is None:
|
||||||
|
# Default to connect=True unless on a FaaS system, which might use fork.
|
||||||
|
from pymongo.pool_options import _is_faas
|
||||||
|
|
||||||
|
connect = opts.get("connect", not _is_faas())
|
||||||
|
keyword_opts["tz_aware"] = tz_aware
|
||||||
|
keyword_opts["connect"] = connect
|
||||||
|
|
||||||
|
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||||
|
|
||||||
|
if srv_service_name is None:
|
||||||
|
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||||
|
|
||||||
|
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||||
|
opts = self._normalize_and_validate_options(opts, seeds)
|
||||||
|
|
||||||
|
# Username and password passed as kwargs override user info in URI.
|
||||||
|
username = opts.get("username", self._resolve_srv_info["username"])
|
||||||
|
password = opts.get("password", self._resolve_srv_info["password"])
|
||||||
|
self._options = ClientOptions(
|
||||||
|
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
|
||||||
|
)
|
||||||
|
|
||||||
|
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||||
|
|
||||||
|
def _init_based_on_options(
|
||||||
|
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
|
||||||
|
) -> None:
|
||||||
|
self._event_listeners = self._options.pool_options._event_listeners
|
||||||
|
self._topology_settings = TopologySettings(
|
||||||
|
seeds=seeds,
|
||||||
|
replica_set_name=self._options.replica_set_name,
|
||||||
|
pool_class=self._resolve_srv_info["pool_class"],
|
||||||
|
pool_options=self._options.pool_options,
|
||||||
|
monitor_class=self._resolve_srv_info["monitor_class"],
|
||||||
|
condition_class=self._resolve_srv_info["condition_class"],
|
||||||
|
local_threshold_ms=self._options.local_threshold_ms,
|
||||||
|
server_selection_timeout=self._options.server_selection_timeout,
|
||||||
|
server_selector=self._options.server_selector,
|
||||||
|
heartbeat_frequency=self._options.heartbeat_frequency,
|
||||||
|
fqdn=self._resolve_srv_info["fqdn"],
|
||||||
|
direct_connection=self._options.direct_connection,
|
||||||
|
load_balanced=self._options.load_balanced,
|
||||||
|
srv_service_name=srv_service_name,
|
||||||
|
srv_max_hosts=srv_max_hosts,
|
||||||
|
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||||
|
)
|
||||||
if self._options.auto_encryption_opts:
|
if self._options.auto_encryption_opts:
|
||||||
from pymongo.asynchronous.encryption import _Encrypter
|
from pymongo.asynchronous.encryption import _Encrypter
|
||||||
|
|
||||||
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
|
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
|
||||||
self._timeout = self._options.timeout
|
self._timeout = self._options.timeout
|
||||||
|
|
||||||
if _HAS_REGISTER_AT_FORK:
|
def _normalize_and_validate_options(
|
||||||
# Add this client to the list of weakly referenced items.
|
self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]]
|
||||||
# This will be used later if we fork.
|
) -> common._CaseInsensitiveDictionary:
|
||||||
AsyncMongoClient._clients[self._topology._topology_id] = self
|
# Handle security-option conflicts in combined options.
|
||||||
|
opts = _handle_security_options(opts)
|
||||||
|
# Normalize combined options.
|
||||||
|
opts = _normalize_options(opts)
|
||||||
|
_check_options(seeds, opts)
|
||||||
|
return opts
|
||||||
|
|
||||||
|
def _validate_kwargs_and_update_opts(
|
||||||
|
self,
|
||||||
|
keyword_opts: common._CaseInsensitiveDictionary,
|
||||||
|
opts: common._CaseInsensitiveDictionary,
|
||||||
|
) -> common._CaseInsensitiveDictionary:
|
||||||
|
# Handle deprecated options in kwarg options.
|
||||||
|
keyword_opts = _handle_option_deprecations(keyword_opts)
|
||||||
|
# Validate kwarg options.
|
||||||
|
keyword_opts = common._CaseInsensitiveDictionary(
|
||||||
|
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
||||||
|
)
|
||||||
|
# Override connection string options with kwarg options.
|
||||||
|
opts.update(keyword_opts)
|
||||||
|
return opts
|
||||||
|
|
||||||
async def aconnect(self) -> None:
|
async def aconnect(self) -> None:
|
||||||
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
|
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
|
||||||
@ -901,6 +1003,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
|
|
||||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||||
self._topology = Topology(self._topology_settings)
|
self._topology = Topology(self._topology_settings)
|
||||||
|
if _HAS_REGISTER_AT_FORK:
|
||||||
|
# Add this client to the list of weakly referenced items.
|
||||||
|
# This will be used later if we fork.
|
||||||
|
AsyncMongoClient._clients[self._topology._topology_id] = self
|
||||||
# Seed the topology with the old one's pid so we can detect clients
|
# Seed the topology with the old one's pid so we can detect clients
|
||||||
# that are opened before a fork and used after.
|
# that are opened before a fork and used after.
|
||||||
self._topology._pid = old_pid
|
self._topology._pid = old_pid
|
||||||
@ -1115,16 +1221,24 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
"""
|
"""
|
||||||
return self._options
|
return self._options
|
||||||
|
|
||||||
|
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
|
||||||
|
return (
|
||||||
|
tuple(sorted(self._resolve_srv_info["seeds"])),
|
||||||
|
self._options.replica_set_name,
|
||||||
|
self._resolve_srv_info["fqdn"],
|
||||||
|
self._resolve_srv_info["srv_service_name"],
|
||||||
|
)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
return self._topology == other._topology
|
return self.eq_props() == other.eq_props()
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __ne__(self, other: Any) -> bool:
|
def __ne__(self, other: Any) -> bool:
|
||||||
return not self == other
|
return not self == other
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(self._topology)
|
return hash(self.eq_props())
|
||||||
|
|
||||||
def _repr_helper(self) -> str:
|
def _repr_helper(self) -> str:
|
||||||
def option_repr(option: str, value: Any) -> str:
|
def option_repr(option: str, value: Any) -> str:
|
||||||
@ -1140,13 +1254,16 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
return f"{option}={value!r}"
|
return f"{option}={value!r}"
|
||||||
|
|
||||||
# Host first...
|
# Host first...
|
||||||
options = [
|
if self._topology is None:
|
||||||
"host=%r"
|
options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"]
|
||||||
% [
|
else:
|
||||||
"%s:%d" % (host, port) if port is not None else host
|
options = [
|
||||||
for host, port in self._topology_settings.seeds
|
"host=%r"
|
||||||
|
% [
|
||||||
|
"%s:%d" % (host, port) if port is not None else host
|
||||||
|
for host, port in self._topology_settings.seeds
|
||||||
|
]
|
||||||
]
|
]
|
||||||
]
|
|
||||||
# ... then everything in self._constructor_args...
|
# ... then everything in self._constructor_args...
|
||||||
options.extend(
|
options.extend(
|
||||||
option_repr(key, self._options._options[key]) for key in self._constructor_args
|
option_repr(key, self._options._options[key]) for key in self._constructor_args
|
||||||
@ -1552,6 +1669,8 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
.. versionchanged:: 3.6
|
.. versionchanged:: 3.6
|
||||||
End all server sessions created by this client.
|
End all server sessions created by this client.
|
||||||
"""
|
"""
|
||||||
|
if self._topology is None:
|
||||||
|
return
|
||||||
session_ids = self._topology.pop_all_sessions()
|
session_ids = self._topology.pop_all_sessions()
|
||||||
if session_ids:
|
if session_ids:
|
||||||
await self._end_sessions(session_ids)
|
await self._end_sessions(session_ids)
|
||||||
@ -1582,6 +1701,9 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
launches the connection process in the background.
|
launches the connection process in the background.
|
||||||
"""
|
"""
|
||||||
if not self._opened:
|
if not self._opened:
|
||||||
|
if self._resolve_srv_info["is_srv"]:
|
||||||
|
await self._resolve_srv()
|
||||||
|
self._init_background()
|
||||||
await self._topology.open()
|
await self._topology.open()
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._kill_cursors_executor.open()
|
self._kill_cursors_executor.open()
|
||||||
@ -2511,6 +2633,7 @@ class _MongoClientErrorHandler:
|
|||||||
self.completed_handshake,
|
self.completed_handshake,
|
||||||
self.service_id,
|
self.service_id,
|
||||||
)
|
)
|
||||||
|
assert self.client._topology is not None
|
||||||
await self.client._topology.handle_error(self.server_address, err_ctx)
|
await self.client._topology.handle_error(self.server_address, err_ctx)
|
||||||
|
|
||||||
async def __aenter__(self) -> _MongoClientErrorHandler:
|
async def __aenter__(self) -> _MongoClientErrorHandler:
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||||||
|
|
||||||
from pymongo import common, periodic_executor
|
from pymongo import common, periodic_executor
|
||||||
from pymongo._csot import MovingMinimum
|
from pymongo._csot import MovingMinimum
|
||||||
|
from pymongo.asynchronous.srv_resolver import _SrvResolver
|
||||||
from pymongo.errors import NetworkTimeout, _OperationCancelled
|
from pymongo.errors import NetworkTimeout, _OperationCancelled
|
||||||
from pymongo.hello import Hello
|
from pymongo.hello import Hello
|
||||||
from pymongo.lock import _async_create_lock
|
from pymongo.lock import _async_create_lock
|
||||||
@ -33,7 +34,6 @@ from pymongo.periodic_executor import _shutdown_executors
|
|||||||
from pymongo.pool_options import _is_faas
|
from pymongo.pool_options import _is_faas
|
||||||
from pymongo.read_preferences import MovingAverage
|
from pymongo.read_preferences import MovingAverage
|
||||||
from pymongo.server_description import ServerDescription
|
from pymongo.server_description import ServerDescription
|
||||||
from pymongo.srv_resolver import _SrvResolver
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
|
from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext
|
||||||
@ -395,7 +395,7 @@ class SrvMonitor(MonitorBase):
|
|||||||
# Don't poll right after creation, wait 60 seconds first
|
# Don't poll right after creation, wait 60 seconds first
|
||||||
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
|
if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL:
|
||||||
return
|
return
|
||||||
seedlist = self._get_seedlist()
|
seedlist = await self._get_seedlist()
|
||||||
if seedlist:
|
if seedlist:
|
||||||
self._seedlist = seedlist
|
self._seedlist = seedlist
|
||||||
try:
|
try:
|
||||||
@ -404,7 +404,7 @@ class SrvMonitor(MonitorBase):
|
|||||||
# Topology was garbage-collected.
|
# Topology was garbage-collected.
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
|
async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
|
||||||
"""Poll SRV records for a seedlist.
|
"""Poll SRV records for a seedlist.
|
||||||
|
|
||||||
Returns a list of ServerDescriptions.
|
Returns a list of ServerDescriptions.
|
||||||
@ -415,7 +415,7 @@ class SrvMonitor(MonitorBase):
|
|||||||
self._settings.pool_options.connect_timeout,
|
self._settings.pool_options.connect_timeout,
|
||||||
self._settings.srv_service_name,
|
self._settings.srv_service_name,
|
||||||
)
|
)
|
||||||
seedlist, ttl = resolver.get_hosts_and_min_ttl()
|
seedlist, ttl = await resolver.get_hosts_and_min_ttl()
|
||||||
if len(seedlist) == 0:
|
if len(seedlist) == 0:
|
||||||
# As per the spec: this should be treated as a failure.
|
# As per the spec: this should be treated as a failure.
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|||||||
160
pymongo/asynchronous/srv_resolver.py
Normal file
160
pymongo/asynchronous/srv_resolver.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
# Copyright 2019-present MongoDB, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||||
|
# may not use this file except in compliance with the License. You
|
||||||
|
# may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import random
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
|
from pymongo.common import CONNECT_TIMEOUT
|
||||||
|
from pymongo.errors import ConfigurationError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from dns import resolver
|
||||||
|
|
||||||
|
_IS_SYNC = False
|
||||||
|
|
||||||
|
|
||||||
|
def _have_dnspython() -> bool:
|
||||||
|
try:
|
||||||
|
import dns # noqa: F401
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# dnspython can return bytes or str from various parts
|
||||||
|
# of its API depending on version. We always want str.
|
||||||
|
def maybe_decode(text: Union[str, bytes]) -> str:
|
||||||
|
if isinstance(text, bytes):
|
||||||
|
return text.decode()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
|
||||||
|
async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
|
||||||
|
if _IS_SYNC:
|
||||||
|
from dns import resolver
|
||||||
|
|
||||||
|
if hasattr(resolver, "resolve"):
|
||||||
|
# dnspython >= 2
|
||||||
|
return resolver.resolve(*args, **kwargs)
|
||||||
|
# dnspython 1.X
|
||||||
|
return resolver.query(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
from dns import asyncresolver
|
||||||
|
|
||||||
|
if hasattr(asyncresolver, "resolve"):
|
||||||
|
# dnspython >= 2
|
||||||
|
return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_INVALID_HOST_MSG = (
|
||||||
|
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
|
||||||
|
"Did you mean to use 'mongodb://'?"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _SrvResolver:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fqdn: str,
|
||||||
|
connect_timeout: Optional[float],
|
||||||
|
srv_service_name: str,
|
||||||
|
srv_max_hosts: int = 0,
|
||||||
|
):
|
||||||
|
self.__fqdn = fqdn
|
||||||
|
self.__srv = srv_service_name
|
||||||
|
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
|
||||||
|
self.__srv_max_hosts = srv_max_hosts or 0
|
||||||
|
# Validate the fully qualified domain name.
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(fqdn)
|
||||||
|
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.__plist = self.__fqdn.split(".")[1:]
|
||||||
|
except Exception:
|
||||||
|
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None
|
||||||
|
self.__slen = len(self.__plist)
|
||||||
|
if self.__slen < 2:
|
||||||
|
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
|
||||||
|
|
||||||
|
async def get_options(self) -> Optional[str]:
|
||||||
|
from dns import resolver
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout)
|
||||||
|
except (resolver.NoAnswer, resolver.NXDOMAIN):
|
||||||
|
# No TXT records
|
||||||
|
return None
|
||||||
|
except Exception as exc:
|
||||||
|
raise ConfigurationError(str(exc)) from None
|
||||||
|
if len(results) > 1:
|
||||||
|
raise ConfigurationError("Only one TXT record is supported")
|
||||||
|
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer:
|
||||||
|
try:
|
||||||
|
results = await _resolve(
|
||||||
|
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
if not encapsulate_errors:
|
||||||
|
# Raise the original error.
|
||||||
|
raise
|
||||||
|
# Else, raise all errors as ConfigurationError.
|
||||||
|
raise ConfigurationError(str(exc)) from None
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _get_srv_response_and_hosts(
|
||||||
|
self, encapsulate_errors: bool
|
||||||
|
) -> tuple[resolver.Answer, list[tuple[str, Any]]]:
|
||||||
|
results = await self._resolve_uri(encapsulate_errors)
|
||||||
|
|
||||||
|
# Construct address tuples
|
||||||
|
nodes = [
|
||||||
|
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined]
|
||||||
|
for res in results
|
||||||
|
]
|
||||||
|
|
||||||
|
# Validate hosts
|
||||||
|
for node in nodes:
|
||||||
|
try:
|
||||||
|
nlist = node[0].lower().split(".")[1:][-self.__slen :]
|
||||||
|
except Exception:
|
||||||
|
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None
|
||||||
|
if self.__plist != nlist:
|
||||||
|
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
|
||||||
|
if self.__srv_max_hosts:
|
||||||
|
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
|
||||||
|
return results, nodes
|
||||||
|
|
||||||
|
async def get_hosts(self) -> list[tuple[str, Any]]:
|
||||||
|
_, nodes = await self._get_srv_response_and_hosts(True)
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]:
|
||||||
|
results, nodes = await self._get_srv_response_and_hosts(False)
|
||||||
|
rrset = results.rrset
|
||||||
|
ttl = rrset.ttl if rrset else 0
|
||||||
|
return nodes, ttl
|
||||||
188
pymongo/asynchronous/uri_parser.py
Normal file
188
pymongo/asynchronous/uri_parser.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
# Copyright 2011-present MongoDB, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||||
|
# may not use this file except in compliance with the License. You
|
||||||
|
# may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""Tools to parse and validate a MongoDB URI."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
from urllib.parse import unquote_plus
|
||||||
|
|
||||||
|
from pymongo.asynchronous.srv_resolver import _SrvResolver
|
||||||
|
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
|
||||||
|
from pymongo.errors import ConfigurationError, InvalidURI
|
||||||
|
from pymongo.uri_parser_shared import (
|
||||||
|
_ALLOWED_TXT_OPTS,
|
||||||
|
DEFAULT_PORT,
|
||||||
|
SCHEME,
|
||||||
|
SCHEME_LEN,
|
||||||
|
SRV_SCHEME_LEN,
|
||||||
|
_check_options,
|
||||||
|
_validate_uri,
|
||||||
|
split_hosts,
|
||||||
|
split_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
_IS_SYNC = False
|
||||||
|
|
||||||
|
|
||||||
|
async def parse_uri(
|
||||||
|
uri: str,
|
||||||
|
default_port: Optional[int] = DEFAULT_PORT,
|
||||||
|
validate: bool = True,
|
||||||
|
warn: bool = False,
|
||||||
|
normalize: bool = True,
|
||||||
|
connect_timeout: Optional[float] = None,
|
||||||
|
srv_service_name: Optional[str] = None,
|
||||||
|
srv_max_hosts: Optional[int] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Parse and validate a MongoDB URI.
|
||||||
|
|
||||||
|
Returns a dict of the form::
|
||||||
|
|
||||||
|
{
|
||||||
|
'nodelist': <list of (host, port) tuples>,
|
||||||
|
'username': <username> or None,
|
||||||
|
'password': <password> or None,
|
||||||
|
'database': <database name> or None,
|
||||||
|
'collection': <collection name> or None,
|
||||||
|
'options': <dict of MongoDB URI options>,
|
||||||
|
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
||||||
|
}
|
||||||
|
|
||||||
|
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
||||||
|
to build nodelist and options.
|
||||||
|
|
||||||
|
:param uri: The MongoDB URI to parse.
|
||||||
|
:param default_port: The port number to use when one wasn't specified
|
||||||
|
for a host in the URI.
|
||||||
|
:param validate: If ``True`` (the default), validate and
|
||||||
|
normalize all options. Default: ``True``.
|
||||||
|
:param warn: When validating, if ``True`` then will warn
|
||||||
|
the user then ignore any invalid options or values. If ``False``,
|
||||||
|
validation will error when options are unsupported or values are
|
||||||
|
invalid. Default: ``False``.
|
||||||
|
:param normalize: If ``True``, convert names of URI options
|
||||||
|
to their internally-used names. Default: ``True``.
|
||||||
|
:param connect_timeout: The maximum time in milliseconds to
|
||||||
|
wait for a response from the DNS server.
|
||||||
|
:param srv_service_name: A custom SRV service name
|
||||||
|
|
||||||
|
.. versionchanged:: 4.6
|
||||||
|
The delimiting slash (``/``) between hosts and connection options is now optional.
|
||||||
|
For example, "mongodb://example.com?tls=true" is now a valid URI.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.0
|
||||||
|
To better follow RFC 3986, unquoted percent signs ("%") are no longer
|
||||||
|
supported.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.9
|
||||||
|
Added the ``normalize`` parameter.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.6
|
||||||
|
Added support for mongodb+srv:// URIs.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.5
|
||||||
|
Return the original value of the ``readPreference`` MongoDB URI option
|
||||||
|
instead of the validated read preference mode.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.1
|
||||||
|
``warn`` added so invalid options can be ignored.
|
||||||
|
"""
|
||||||
|
result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
|
||||||
|
result.update(
|
||||||
|
await _parse_srv(
|
||||||
|
uri,
|
||||||
|
default_port,
|
||||||
|
validate,
|
||||||
|
warn,
|
||||||
|
normalize,
|
||||||
|
connect_timeout,
|
||||||
|
srv_service_name,
|
||||||
|
srv_max_hosts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def _parse_srv(
|
||||||
|
uri: str,
|
||||||
|
default_port: Optional[int] = DEFAULT_PORT,
|
||||||
|
validate: bool = True,
|
||||||
|
warn: bool = False,
|
||||||
|
normalize: bool = True,
|
||||||
|
connect_timeout: Optional[float] = None,
|
||||||
|
srv_service_name: Optional[str] = None,
|
||||||
|
srv_max_hosts: Optional[int] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if uri.startswith(SCHEME):
|
||||||
|
is_srv = False
|
||||||
|
scheme_free = uri[SCHEME_LEN:]
|
||||||
|
else:
|
||||||
|
is_srv = True
|
||||||
|
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||||
|
|
||||||
|
options = _CaseInsensitiveDictionary()
|
||||||
|
|
||||||
|
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||||
|
if "/" in host_plus_db_part:
|
||||||
|
host_part, _, _ = host_plus_db_part.partition("/")
|
||||||
|
else:
|
||||||
|
host_part = host_plus_db_part
|
||||||
|
|
||||||
|
if opts:
|
||||||
|
options.update(split_options(opts, validate, warn, normalize))
|
||||||
|
if srv_service_name is None:
|
||||||
|
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
|
||||||
|
if "@" in host_part:
|
||||||
|
_, _, hosts = host_part.rpartition("@")
|
||||||
|
else:
|
||||||
|
hosts = host_part
|
||||||
|
|
||||||
|
hosts = unquote_plus(hosts)
|
||||||
|
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||||
|
if is_srv:
|
||||||
|
nodes = split_hosts(hosts, default_port=None)
|
||||||
|
fqdn, port = nodes[0]
|
||||||
|
|
||||||
|
# Use the connection timeout. connectTimeoutMS passed as a keyword
|
||||||
|
# argument overrides the same option passed in the connection string.
|
||||||
|
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
|
||||||
|
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
|
||||||
|
nodes = await dns_resolver.get_hosts()
|
||||||
|
dns_options = await dns_resolver.get_options()
|
||||||
|
if dns_options:
|
||||||
|
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
|
||||||
|
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
|
||||||
|
)
|
||||||
|
for opt, val in parsed_dns_options.items():
|
||||||
|
if opt not in options:
|
||||||
|
options[opt] = val
|
||||||
|
if options.get("loadBalanced") and srv_max_hosts:
|
||||||
|
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
|
||||||
|
if options.get("replicaSet") and srv_max_hosts:
|
||||||
|
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
|
||||||
|
if "tls" not in options and "ssl" not in options:
|
||||||
|
options["tls"] = True if validate else "true"
|
||||||
|
else:
|
||||||
|
nodes = split_hosts(hosts, default_port=default_port)
|
||||||
|
|
||||||
|
_check_options(nodes, options)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"nodelist": nodes,
|
||||||
|
"options": options,
|
||||||
|
}
|
||||||
@ -32,7 +32,7 @@ except ImportError:
|
|||||||
from bson import int64
|
from bson import int64
|
||||||
from pymongo.common import validate_is_mapping
|
from pymongo.common import validate_is_mapping
|
||||||
from pymongo.errors import ConfigurationError
|
from pymongo.errors import ConfigurationError
|
||||||
from pymongo.uri_parser import _parse_kms_tls_options
|
from pymongo.uri_parser_shared import _parse_kms_tls_options
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
|
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
|
||||||
|
|||||||
@ -86,7 +86,7 @@ from pymongo.synchronous.pool import (
|
|||||||
_raise_connection_failure,
|
_raise_connection_failure,
|
||||||
)
|
)
|
||||||
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
from pymongo.typings import _DocumentType, _DocumentTypeArg
|
||||||
from pymongo.uri_parser import parse_host
|
from pymongo.uri_parser_shared import parse_host
|
||||||
from pymongo.write_concern import WriteConcern
|
from pymongo.write_concern import WriteConcern
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@ -42,6 +42,7 @@ from typing import (
|
|||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
Collection,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
Generator,
|
Generator,
|
||||||
@ -59,7 +60,7 @@ from typing import (
|
|||||||
|
|
||||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry
|
||||||
from bson.timestamp import Timestamp
|
from bson.timestamp import Timestamp
|
||||||
from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser
|
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||||
from pymongo.client_options import ClientOptions
|
from pymongo.client_options import ClientOptions
|
||||||
from pymongo.errors import (
|
from pymongo.errors import (
|
||||||
AutoReconnect,
|
AutoReconnect,
|
||||||
@ -96,7 +97,7 @@ from pymongo.read_preferences import ReadPreference, _ServerMode
|
|||||||
from pymongo.results import ClientBulkWriteResult
|
from pymongo.results import ClientBulkWriteResult
|
||||||
from pymongo.server_selectors import writable_server_selector
|
from pymongo.server_selectors import writable_server_selector
|
||||||
from pymongo.server_type import SERVER_TYPE
|
from pymongo.server_type import SERVER_TYPE
|
||||||
from pymongo.synchronous import client_session, database
|
from pymongo.synchronous import client_session, database, uri_parser
|
||||||
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
|
||||||
from pymongo.synchronous.client_bulk import _ClientBulk
|
from pymongo.synchronous.client_bulk import _ClientBulk
|
||||||
from pymongo.synchronous.client_session import _EmptyServerSession
|
from pymongo.synchronous.client_session import _EmptyServerSession
|
||||||
@ -112,11 +113,14 @@ from pymongo.typings import (
|
|||||||
_DocumentTypeArg,
|
_DocumentTypeArg,
|
||||||
_Pipeline,
|
_Pipeline,
|
||||||
)
|
)
|
||||||
from pymongo.uri_parser import (
|
from pymongo.uri_parser_shared import (
|
||||||
|
SRV_SCHEME,
|
||||||
_check_options,
|
_check_options,
|
||||||
_handle_option_deprecations,
|
_handle_option_deprecations,
|
||||||
_handle_security_options,
|
_handle_security_options,
|
||||||
_normalize_options,
|
_normalize_options,
|
||||||
|
_validate_uri,
|
||||||
|
split_hosts,
|
||||||
)
|
)
|
||||||
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern
|
||||||
|
|
||||||
@ -130,6 +134,7 @@ if TYPE_CHECKING:
|
|||||||
from pymongo.synchronous.bulk import _Bulk
|
from pymongo.synchronous.bulk import _Bulk
|
||||||
from pymongo.synchronous.client_session import ClientSession, _ServerSession
|
from pymongo.synchronous.client_session import ClientSession, _ServerSession
|
||||||
from pymongo.synchronous.cursor import _ConnectionManager
|
from pymongo.synchronous.cursor import _ConnectionManager
|
||||||
|
from pymongo.synchronous.encryption import _Encrypter
|
||||||
from pymongo.synchronous.pool import Connection
|
from pymongo.synchronous.pool import Connection
|
||||||
from pymongo.synchronous.server import Server
|
from pymongo.synchronous.server import Server
|
||||||
|
|
||||||
@ -748,6 +753,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
port = self.PORT
|
port = self.PORT
|
||||||
if not isinstance(port, int):
|
if not isinstance(port, int):
|
||||||
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
raise TypeError(f"port must be an instance of int, not {type(port)}")
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
self._topology: Topology = None # type: ignore[assignment]
|
||||||
|
|
||||||
# _pool_class, _monitor_class, and _condition_class are for deep
|
# _pool_class, _monitor_class, and _condition_class are for deep
|
||||||
# customization of PyMongo, e.g. Motor.
|
# customization of PyMongo, e.g. Motor.
|
||||||
@ -758,8 +766,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
# Parse options passed as kwargs.
|
# Parse options passed as kwargs.
|
||||||
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
|
keyword_opts = common._CaseInsensitiveDictionary(kwargs)
|
||||||
keyword_opts["document_class"] = doc_class
|
keyword_opts["document_class"] = doc_class
|
||||||
|
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
|
||||||
|
|
||||||
seeds = set()
|
seeds = set()
|
||||||
|
is_srv = False
|
||||||
username = None
|
username = None
|
||||||
password = None
|
password = None
|
||||||
dbase = None
|
dbase = None
|
||||||
@ -767,29 +777,22 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
fqdn = None
|
fqdn = None
|
||||||
srv_service_name = keyword_opts.get("srvservicename")
|
srv_service_name = keyword_opts.get("srvservicename")
|
||||||
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||||
if len([h for h in host if "/" in h]) > 1:
|
if len([h for h in self._host if "/" in h]) > 1:
|
||||||
raise ConfigurationError("host must not contain multiple MongoDB URIs")
|
raise ConfigurationError("host must not contain multiple MongoDB URIs")
|
||||||
for entity in host:
|
for entity in self._host:
|
||||||
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||||
# it must be a URI,
|
# it must be a URI,
|
||||||
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||||
if "/" in entity:
|
if "/" in entity:
|
||||||
# Determine connection timeout from kwargs.
|
res = _validate_uri(
|
||||||
timeout = keyword_opts.get("connecttimeoutms")
|
|
||||||
if timeout is not None:
|
|
||||||
timeout = common.validate_timeout_or_none_or_zero(
|
|
||||||
keyword_opts.cased_key("connecttimeoutms"), timeout
|
|
||||||
)
|
|
||||||
res = uri_parser.parse_uri(
|
|
||||||
entity,
|
entity,
|
||||||
port,
|
port,
|
||||||
validate=True,
|
validate=True,
|
||||||
warn=True,
|
warn=True,
|
||||||
normalize=False,
|
normalize=False,
|
||||||
connect_timeout=timeout,
|
|
||||||
srv_service_name=srv_service_name,
|
|
||||||
srv_max_hosts=srv_max_hosts,
|
srv_max_hosts=srv_max_hosts,
|
||||||
)
|
)
|
||||||
|
is_srv = entity.startswith(SRV_SCHEME)
|
||||||
seeds.update(res["nodelist"])
|
seeds.update(res["nodelist"])
|
||||||
username = res["username"] or username
|
username = res["username"] or username
|
||||||
password = res["password"] or password
|
password = res["password"] or password
|
||||||
@ -797,7 +800,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
opts = res["options"]
|
opts = res["options"]
|
||||||
fqdn = res["fqdn"]
|
fqdn = res["fqdn"]
|
||||||
else:
|
else:
|
||||||
seeds.update(uri_parser.split_hosts(entity, port))
|
seeds.update(split_hosts(entity, self._port))
|
||||||
if not seeds:
|
if not seeds:
|
||||||
raise ConfigurationError("need to specify at least one host")
|
raise ConfigurationError("need to specify at least one host")
|
||||||
|
|
||||||
@ -818,80 +821,179 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
keyword_opts["tz_aware"] = tz_aware
|
keyword_opts["tz_aware"] = tz_aware
|
||||||
keyword_opts["connect"] = connect
|
keyword_opts["connect"] = connect
|
||||||
|
|
||||||
# Handle deprecated options in kwarg options.
|
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||||
keyword_opts = _handle_option_deprecations(keyword_opts)
|
|
||||||
# Validate kwarg options.
|
|
||||||
keyword_opts = common._CaseInsensitiveDictionary(
|
|
||||||
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
|
||||||
)
|
|
||||||
|
|
||||||
# Override connection string options with kwarg options.
|
|
||||||
opts.update(keyword_opts)
|
|
||||||
|
|
||||||
if srv_service_name is None:
|
if srv_service_name is None:
|
||||||
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||||
|
|
||||||
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||||
# Handle security-option conflicts in combined options.
|
opts = self._normalize_and_validate_options(opts, seeds)
|
||||||
opts = _handle_security_options(opts)
|
|
||||||
# Normalize combined options.
|
|
||||||
opts = _normalize_options(opts)
|
|
||||||
_check_options(seeds, opts)
|
|
||||||
|
|
||||||
# Username and password passed as kwargs override user info in URI.
|
# Username and password passed as kwargs override user info in URI.
|
||||||
username = opts.get("username", username)
|
username = opts.get("username", username)
|
||||||
password = opts.get("password", password)
|
password = opts.get("password", password)
|
||||||
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC)
|
||||||
|
|
||||||
self._default_database_name = dbase
|
self._default_database_name = dbase
|
||||||
self._lock = _create_lock()
|
self._lock = _create_lock()
|
||||||
self._kill_cursors_queue: list = []
|
self._kill_cursors_queue: list = []
|
||||||
|
|
||||||
self._event_listeners = options.pool_options._event_listeners
|
self._encrypter: Optional[_Encrypter] = None
|
||||||
super().__init__(
|
|
||||||
options.codec_options,
|
self._resolve_srv_info.update(
|
||||||
options.read_preference,
|
{
|
||||||
options.write_concern,
|
"is_srv": is_srv,
|
||||||
options.read_concern,
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
"dbase": dbase,
|
||||||
|
"seeds": seeds,
|
||||||
|
"fqdn": fqdn,
|
||||||
|
"srv_service_name": srv_service_name,
|
||||||
|
"pool_class": pool_class,
|
||||||
|
"monitor_class": monitor_class,
|
||||||
|
"condition_class": condition_class,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self._topology_settings = TopologySettings(
|
super().__init__(
|
||||||
seeds=seeds,
|
self._options.codec_options,
|
||||||
replica_set_name=options.replica_set_name,
|
self._options.read_preference,
|
||||||
pool_class=pool_class,
|
self._options.write_concern,
|
||||||
pool_options=options.pool_options,
|
self._options.read_concern,
|
||||||
monitor_class=monitor_class,
|
|
||||||
condition_class=condition_class,
|
|
||||||
local_threshold_ms=options.local_threshold_ms,
|
|
||||||
server_selection_timeout=options.server_selection_timeout,
|
|
||||||
server_selector=options.server_selector,
|
|
||||||
heartbeat_frequency=options.heartbeat_frequency,
|
|
||||||
fqdn=fqdn,
|
|
||||||
direct_connection=options.direct_connection,
|
|
||||||
load_balanced=options.load_balanced,
|
|
||||||
srv_service_name=srv_service_name,
|
|
||||||
srv_max_hosts=srv_max_hosts,
|
|
||||||
server_monitoring_mode=options.server_monitoring_mode,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not is_srv:
|
||||||
|
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||||
|
|
||||||
self._opened = False
|
self._opened = False
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._init_background()
|
if not is_srv:
|
||||||
|
self._init_background()
|
||||||
|
|
||||||
if _IS_SYNC and connect:
|
if _IS_SYNC and connect:
|
||||||
self._get_topology() # type: ignore[unused-coroutine]
|
self._get_topology() # type: ignore[unused-coroutine]
|
||||||
|
|
||||||
self._encrypter = None
|
def _resolve_srv(self) -> None:
|
||||||
|
keyword_opts = self._resolve_srv_info["keyword_opts"]
|
||||||
|
seeds = set()
|
||||||
|
opts = common._CaseInsensitiveDictionary()
|
||||||
|
srv_service_name = keyword_opts.get("srvservicename")
|
||||||
|
srv_max_hosts = keyword_opts.get("srvmaxhosts")
|
||||||
|
for entity in self._host:
|
||||||
|
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
|
||||||
|
# it must be a URI,
|
||||||
|
# https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names
|
||||||
|
if "/" in entity:
|
||||||
|
# Determine connection timeout from kwargs.
|
||||||
|
timeout = keyword_opts.get("connecttimeoutms")
|
||||||
|
if timeout is not None:
|
||||||
|
timeout = common.validate_timeout_or_none_or_zero(
|
||||||
|
keyword_opts.cased_key("connecttimeoutms"), timeout
|
||||||
|
)
|
||||||
|
res = uri_parser._parse_srv(
|
||||||
|
entity,
|
||||||
|
self._port,
|
||||||
|
validate=True,
|
||||||
|
warn=True,
|
||||||
|
normalize=False,
|
||||||
|
connect_timeout=timeout,
|
||||||
|
srv_service_name=srv_service_name,
|
||||||
|
srv_max_hosts=srv_max_hosts,
|
||||||
|
)
|
||||||
|
seeds.update(res["nodelist"])
|
||||||
|
opts = res["options"]
|
||||||
|
else:
|
||||||
|
seeds.update(split_hosts(entity, self._port))
|
||||||
|
|
||||||
|
if not seeds:
|
||||||
|
raise ConfigurationError("need to specify at least one host")
|
||||||
|
|
||||||
|
for hostname in [node[0] for node in seeds]:
|
||||||
|
if _detect_external_db(hostname):
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add options with named keyword arguments to the parsed kwarg options.
|
||||||
|
tz_aware = keyword_opts["tz_aware"]
|
||||||
|
connect = keyword_opts["connect"]
|
||||||
|
if tz_aware is None:
|
||||||
|
tz_aware = opts.get("tz_aware", False)
|
||||||
|
if connect is None:
|
||||||
|
# Default to connect=True unless on a FaaS system, which might use fork.
|
||||||
|
from pymongo.pool_options import _is_faas
|
||||||
|
|
||||||
|
connect = opts.get("connect", not _is_faas())
|
||||||
|
keyword_opts["tz_aware"] = tz_aware
|
||||||
|
keyword_opts["connect"] = connect
|
||||||
|
|
||||||
|
opts = self._validate_kwargs_and_update_opts(keyword_opts, opts)
|
||||||
|
|
||||||
|
if srv_service_name is None:
|
||||||
|
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
|
||||||
|
|
||||||
|
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
|
||||||
|
opts = self._normalize_and_validate_options(opts, seeds)
|
||||||
|
|
||||||
|
# Username and password passed as kwargs override user info in URI.
|
||||||
|
username = opts.get("username", self._resolve_srv_info["username"])
|
||||||
|
password = opts.get("password", self._resolve_srv_info["password"])
|
||||||
|
self._options = ClientOptions(
|
||||||
|
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
|
||||||
|
)
|
||||||
|
|
||||||
|
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
|
||||||
|
|
||||||
|
def _init_based_on_options(
|
||||||
|
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
|
||||||
|
) -> None:
|
||||||
|
self._event_listeners = self._options.pool_options._event_listeners
|
||||||
|
self._topology_settings = TopologySettings(
|
||||||
|
seeds=seeds,
|
||||||
|
replica_set_name=self._options.replica_set_name,
|
||||||
|
pool_class=self._resolve_srv_info["pool_class"],
|
||||||
|
pool_options=self._options.pool_options,
|
||||||
|
monitor_class=self._resolve_srv_info["monitor_class"],
|
||||||
|
condition_class=self._resolve_srv_info["condition_class"],
|
||||||
|
local_threshold_ms=self._options.local_threshold_ms,
|
||||||
|
server_selection_timeout=self._options.server_selection_timeout,
|
||||||
|
server_selector=self._options.server_selector,
|
||||||
|
heartbeat_frequency=self._options.heartbeat_frequency,
|
||||||
|
fqdn=self._resolve_srv_info["fqdn"],
|
||||||
|
direct_connection=self._options.direct_connection,
|
||||||
|
load_balanced=self._options.load_balanced,
|
||||||
|
srv_service_name=srv_service_name,
|
||||||
|
srv_max_hosts=srv_max_hosts,
|
||||||
|
server_monitoring_mode=self._options.server_monitoring_mode,
|
||||||
|
)
|
||||||
if self._options.auto_encryption_opts:
|
if self._options.auto_encryption_opts:
|
||||||
from pymongo.synchronous.encryption import _Encrypter
|
from pymongo.synchronous.encryption import _Encrypter
|
||||||
|
|
||||||
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
|
self._encrypter = _Encrypter(self, self._options.auto_encryption_opts)
|
||||||
self._timeout = self._options.timeout
|
self._timeout = self._options.timeout
|
||||||
|
|
||||||
if _HAS_REGISTER_AT_FORK:
|
def _normalize_and_validate_options(
|
||||||
# Add this client to the list of weakly referenced items.
|
self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]]
|
||||||
# This will be used later if we fork.
|
) -> common._CaseInsensitiveDictionary:
|
||||||
MongoClient._clients[self._topology._topology_id] = self
|
# Handle security-option conflicts in combined options.
|
||||||
|
opts = _handle_security_options(opts)
|
||||||
|
# Normalize combined options.
|
||||||
|
opts = _normalize_options(opts)
|
||||||
|
_check_options(seeds, opts)
|
||||||
|
return opts
|
||||||
|
|
||||||
|
def _validate_kwargs_and_update_opts(
|
||||||
|
self,
|
||||||
|
keyword_opts: common._CaseInsensitiveDictionary,
|
||||||
|
opts: common._CaseInsensitiveDictionary,
|
||||||
|
) -> common._CaseInsensitiveDictionary:
|
||||||
|
# Handle deprecated options in kwarg options.
|
||||||
|
keyword_opts = _handle_option_deprecations(keyword_opts)
|
||||||
|
# Validate kwarg options.
|
||||||
|
keyword_opts = common._CaseInsensitiveDictionary(
|
||||||
|
dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())
|
||||||
|
)
|
||||||
|
# Override connection string options with kwarg options.
|
||||||
|
opts.update(keyword_opts)
|
||||||
|
return opts
|
||||||
|
|
||||||
def _connect(self) -> None:
|
def _connect(self) -> None:
|
||||||
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
|
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
|
||||||
@ -899,6 +1001,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
|
|
||||||
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
def _init_background(self, old_pid: Optional[int] = None) -> None:
|
||||||
self._topology = Topology(self._topology_settings)
|
self._topology = Topology(self._topology_settings)
|
||||||
|
if _HAS_REGISTER_AT_FORK:
|
||||||
|
# Add this client to the list of weakly referenced items.
|
||||||
|
# This will be used later if we fork.
|
||||||
|
MongoClient._clients[self._topology._topology_id] = self
|
||||||
# Seed the topology with the old one's pid so we can detect clients
|
# Seed the topology with the old one's pid so we can detect clients
|
||||||
# that are opened before a fork and used after.
|
# that are opened before a fork and used after.
|
||||||
self._topology._pid = old_pid
|
self._topology._pid = old_pid
|
||||||
@ -1113,16 +1219,24 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
"""
|
"""
|
||||||
return self._options
|
return self._options
|
||||||
|
|
||||||
|
def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]:
|
||||||
|
return (
|
||||||
|
tuple(sorted(self._resolve_srv_info["seeds"])),
|
||||||
|
self._options.replica_set_name,
|
||||||
|
self._resolve_srv_info["fqdn"],
|
||||||
|
self._resolve_srv_info["srv_service_name"],
|
||||||
|
)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
return self._topology == other._topology
|
return self.eq_props() == other.eq_props()
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __ne__(self, other: Any) -> bool:
|
def __ne__(self, other: Any) -> bool:
|
||||||
return not self == other
|
return not self == other
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(self._topology)
|
return hash(self.eq_props())
|
||||||
|
|
||||||
def _repr_helper(self) -> str:
|
def _repr_helper(self) -> str:
|
||||||
def option_repr(option: str, value: Any) -> str:
|
def option_repr(option: str, value: Any) -> str:
|
||||||
@ -1138,13 +1252,16 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
return f"{option}={value!r}"
|
return f"{option}={value!r}"
|
||||||
|
|
||||||
# Host first...
|
# Host first...
|
||||||
options = [
|
if self._topology is None:
|
||||||
"host=%r"
|
options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"]
|
||||||
% [
|
else:
|
||||||
"%s:%d" % (host, port) if port is not None else host
|
options = [
|
||||||
for host, port in self._topology_settings.seeds
|
"host=%r"
|
||||||
|
% [
|
||||||
|
"%s:%d" % (host, port) if port is not None else host
|
||||||
|
for host, port in self._topology_settings.seeds
|
||||||
|
]
|
||||||
]
|
]
|
||||||
]
|
|
||||||
# ... then everything in self._constructor_args...
|
# ... then everything in self._constructor_args...
|
||||||
options.extend(
|
options.extend(
|
||||||
option_repr(key, self._options._options[key]) for key in self._constructor_args
|
option_repr(key, self._options._options[key]) for key in self._constructor_args
|
||||||
@ -1546,6 +1663,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
.. versionchanged:: 3.6
|
.. versionchanged:: 3.6
|
||||||
End all server sessions created by this client.
|
End all server sessions created by this client.
|
||||||
"""
|
"""
|
||||||
|
if self._topology is None:
|
||||||
|
return
|
||||||
session_ids = self._topology.pop_all_sessions()
|
session_ids = self._topology.pop_all_sessions()
|
||||||
if session_ids:
|
if session_ids:
|
||||||
self._end_sessions(session_ids)
|
self._end_sessions(session_ids)
|
||||||
@ -1576,6 +1695,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
|||||||
launches the connection process in the background.
|
launches the connection process in the background.
|
||||||
"""
|
"""
|
||||||
if not self._opened:
|
if not self._opened:
|
||||||
|
if self._resolve_srv_info["is_srv"]:
|
||||||
|
self._resolve_srv()
|
||||||
|
self._init_background()
|
||||||
self._topology.open()
|
self._topology.open()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._kill_cursors_executor.open()
|
self._kill_cursors_executor.open()
|
||||||
@ -2497,6 +2619,7 @@ class _MongoClientErrorHandler:
|
|||||||
self.completed_handshake,
|
self.completed_handshake,
|
||||||
self.service_id,
|
self.service_id,
|
||||||
)
|
)
|
||||||
|
assert self.client._topology is not None
|
||||||
self.client._topology.handle_error(self.server_address, err_ctx)
|
self.client._topology.handle_error(self.server_address, err_ctx)
|
||||||
|
|
||||||
def __enter__(self) -> _MongoClientErrorHandler:
|
def __enter__(self) -> _MongoClientErrorHandler:
|
||||||
|
|||||||
@ -33,7 +33,7 @@ from pymongo.periodic_executor import _shutdown_executors
|
|||||||
from pymongo.pool_options import _is_faas
|
from pymongo.pool_options import _is_faas
|
||||||
from pymongo.read_preferences import MovingAverage
|
from pymongo.read_preferences import MovingAverage
|
||||||
from pymongo.server_description import ServerDescription
|
from pymongo.server_description import ServerDescription
|
||||||
from pymongo.srv_resolver import _SrvResolver
|
from pymongo.synchronous.srv_resolver import _SrvResolver
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
from pymongo.synchronous.pool import Connection, Pool, _CancellationContext
|
||||||
|
|||||||
@ -25,6 +25,8 @@ from pymongo.errors import ConfigurationError
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from dns import resolver
|
from dns import resolver
|
||||||
|
|
||||||
|
_IS_SYNC = True
|
||||||
|
|
||||||
|
|
||||||
def _have_dnspython() -> bool:
|
def _have_dnspython() -> bool:
|
||||||
try:
|
try:
|
||||||
@ -45,13 +47,23 @@ def maybe_decode(text: Union[str, bytes]) -> str:
|
|||||||
|
|
||||||
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
|
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
|
||||||
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
|
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer:
|
||||||
from dns import resolver
|
if _IS_SYNC:
|
||||||
|
from dns import resolver
|
||||||
|
|
||||||
if hasattr(resolver, "resolve"):
|
if hasattr(resolver, "resolve"):
|
||||||
# dnspython >= 2
|
# dnspython >= 2
|
||||||
return resolver.resolve(*args, **kwargs)
|
return resolver.resolve(*args, **kwargs)
|
||||||
# dnspython 1.X
|
# dnspython 1.X
|
||||||
return resolver.query(*args, **kwargs)
|
return resolver.query(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
from dns import asyncresolver
|
||||||
|
|
||||||
|
if hasattr(asyncresolver, "resolve"):
|
||||||
|
# dnspython >= 2
|
||||||
|
return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value]
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Upgrade to dnspython version >= 2.0 to use MongoClient with mongodb+srv:// connections."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_INVALID_HOST_MSG = (
|
_INVALID_HOST_MSG = (
|
||||||
188
pymongo/synchronous/uri_parser.py
Normal file
188
pymongo/synchronous/uri_parser.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
# Copyright 2011-present MongoDB, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||||
|
# may not use this file except in compliance with the License. You
|
||||||
|
# may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""Tools to parse and validate a MongoDB URI."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
from urllib.parse import unquote_plus
|
||||||
|
|
||||||
|
from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary
|
||||||
|
from pymongo.errors import ConfigurationError, InvalidURI
|
||||||
|
from pymongo.synchronous.srv_resolver import _SrvResolver
|
||||||
|
from pymongo.uri_parser_shared import (
|
||||||
|
_ALLOWED_TXT_OPTS,
|
||||||
|
DEFAULT_PORT,
|
||||||
|
SCHEME,
|
||||||
|
SCHEME_LEN,
|
||||||
|
SRV_SCHEME_LEN,
|
||||||
|
_check_options,
|
||||||
|
_validate_uri,
|
||||||
|
split_hosts,
|
||||||
|
split_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
_IS_SYNC = True
|
||||||
|
|
||||||
|
|
||||||
|
def parse_uri(
|
||||||
|
uri: str,
|
||||||
|
default_port: Optional[int] = DEFAULT_PORT,
|
||||||
|
validate: bool = True,
|
||||||
|
warn: bool = False,
|
||||||
|
normalize: bool = True,
|
||||||
|
connect_timeout: Optional[float] = None,
|
||||||
|
srv_service_name: Optional[str] = None,
|
||||||
|
srv_max_hosts: Optional[int] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Parse and validate a MongoDB URI.
|
||||||
|
|
||||||
|
Returns a dict of the form::
|
||||||
|
|
||||||
|
{
|
||||||
|
'nodelist': <list of (host, port) tuples>,
|
||||||
|
'username': <username> or None,
|
||||||
|
'password': <password> or None,
|
||||||
|
'database': <database name> or None,
|
||||||
|
'collection': <collection name> or None,
|
||||||
|
'options': <dict of MongoDB URI options>,
|
||||||
|
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
||||||
|
}
|
||||||
|
|
||||||
|
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
||||||
|
to build nodelist and options.
|
||||||
|
|
||||||
|
:param uri: The MongoDB URI to parse.
|
||||||
|
:param default_port: The port number to use when one wasn't specified
|
||||||
|
for a host in the URI.
|
||||||
|
:param validate: If ``True`` (the default), validate and
|
||||||
|
normalize all options. Default: ``True``.
|
||||||
|
:param warn: When validating, if ``True`` then will warn
|
||||||
|
the user then ignore any invalid options or values. If ``False``,
|
||||||
|
validation will error when options are unsupported or values are
|
||||||
|
invalid. Default: ``False``.
|
||||||
|
:param normalize: If ``True``, convert names of URI options
|
||||||
|
to their internally-used names. Default: ``True``.
|
||||||
|
:param connect_timeout: The maximum time in milliseconds to
|
||||||
|
wait for a response from the DNS server.
|
||||||
|
:param srv_service_name: A custom SRV service name
|
||||||
|
|
||||||
|
.. versionchanged:: 4.6
|
||||||
|
The delimiting slash (``/``) between hosts and connection options is now optional.
|
||||||
|
For example, "mongodb://example.com?tls=true" is now a valid URI.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.0
|
||||||
|
To better follow RFC 3986, unquoted percent signs ("%") are no longer
|
||||||
|
supported.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.9
|
||||||
|
Added the ``normalize`` parameter.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.6
|
||||||
|
Added support for mongodb+srv:// URIs.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.5
|
||||||
|
Return the original value of the ``readPreference`` MongoDB URI option
|
||||||
|
instead of the validated read preference mode.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.1
|
||||||
|
``warn`` added so invalid options can be ignored.
|
||||||
|
"""
|
||||||
|
result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts)
|
||||||
|
result.update(
|
||||||
|
_parse_srv(
|
||||||
|
uri,
|
||||||
|
default_port,
|
||||||
|
validate,
|
||||||
|
warn,
|
||||||
|
normalize,
|
||||||
|
connect_timeout,
|
||||||
|
srv_service_name,
|
||||||
|
srv_max_hosts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_srv(
|
||||||
|
uri: str,
|
||||||
|
default_port: Optional[int] = DEFAULT_PORT,
|
||||||
|
validate: bool = True,
|
||||||
|
warn: bool = False,
|
||||||
|
normalize: bool = True,
|
||||||
|
connect_timeout: Optional[float] = None,
|
||||||
|
srv_service_name: Optional[str] = None,
|
||||||
|
srv_max_hosts: Optional[int] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if uri.startswith(SCHEME):
|
||||||
|
is_srv = False
|
||||||
|
scheme_free = uri[SCHEME_LEN:]
|
||||||
|
else:
|
||||||
|
is_srv = True
|
||||||
|
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||||
|
|
||||||
|
options = _CaseInsensitiveDictionary()
|
||||||
|
|
||||||
|
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||||
|
if "/" in host_plus_db_part:
|
||||||
|
host_part, _, _ = host_plus_db_part.partition("/")
|
||||||
|
else:
|
||||||
|
host_part = host_plus_db_part
|
||||||
|
|
||||||
|
if opts:
|
||||||
|
options.update(split_options(opts, validate, warn, normalize))
|
||||||
|
if srv_service_name is None:
|
||||||
|
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
|
||||||
|
if "@" in host_part:
|
||||||
|
_, _, hosts = host_part.rpartition("@")
|
||||||
|
else:
|
||||||
|
hosts = host_part
|
||||||
|
|
||||||
|
hosts = unquote_plus(hosts)
|
||||||
|
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||||
|
if is_srv:
|
||||||
|
nodes = split_hosts(hosts, default_port=None)
|
||||||
|
fqdn, port = nodes[0]
|
||||||
|
|
||||||
|
# Use the connection timeout. connectTimeoutMS passed as a keyword
|
||||||
|
# argument overrides the same option passed in the connection string.
|
||||||
|
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
|
||||||
|
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
|
||||||
|
nodes = dns_resolver.get_hosts()
|
||||||
|
dns_options = dns_resolver.get_options()
|
||||||
|
if dns_options:
|
||||||
|
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
|
||||||
|
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
|
||||||
|
)
|
||||||
|
for opt, val in parsed_dns_options.items():
|
||||||
|
if opt not in options:
|
||||||
|
options[opt] = val
|
||||||
|
if options.get("loadBalanced") and srv_max_hosts:
|
||||||
|
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
|
||||||
|
if options.get("replicaSet") and srv_max_hosts:
|
||||||
|
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
|
||||||
|
if "tls" not in options and "ssl" not in options:
|
||||||
|
options["tls"] = True if validate else "true"
|
||||||
|
else:
|
||||||
|
nodes = split_hosts(hosts, default_port=default_port)
|
||||||
|
|
||||||
|
_check_options(nodes, options)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"nodelist": nodes,
|
||||||
|
"options": options,
|
||||||
|
}
|
||||||
@ -13,627 +13,32 @@
|
|||||||
# permissions and limitations under the License.
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
"""Tools to parse and validate a MongoDB URI.
|
"""Re-import of synchronous URI Parser API for compatibility."""
|
||||||
|
|
||||||
.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Mapping,
|
|
||||||
MutableMapping,
|
|
||||||
Optional,
|
|
||||||
Sized,
|
|
||||||
Union,
|
|
||||||
cast,
|
|
||||||
)
|
|
||||||
from urllib.parse import unquote_plus
|
|
||||||
|
|
||||||
from pymongo.client_options import _parse_ssl_options
|
|
||||||
from pymongo.common import (
|
|
||||||
INTERNAL_URI_OPTION_NAME_MAP,
|
|
||||||
SRV_SERVICE_NAME,
|
|
||||||
URI_OPTIONS_DEPRECATION_MAP,
|
|
||||||
_CaseInsensitiveDictionary,
|
|
||||||
get_validated_options,
|
|
||||||
)
|
|
||||||
from pymongo.errors import ConfigurationError, InvalidURI
|
|
||||||
from pymongo.srv_resolver import _have_dnspython, _SrvResolver
|
|
||||||
from pymongo.typings import _Address
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from pymongo.pyopenssl_context import SSLContext
|
|
||||||
|
|
||||||
SCHEME = "mongodb://"
|
|
||||||
SCHEME_LEN = len(SCHEME)
|
|
||||||
SRV_SCHEME = "mongodb+srv://"
|
|
||||||
SRV_SCHEME_LEN = len(SRV_SCHEME)
|
|
||||||
DEFAULT_PORT = 27017
|
|
||||||
|
|
||||||
|
|
||||||
def _unquoted_percent(s: str) -> bool:
|
|
||||||
"""Check for unescaped percent signs.
|
|
||||||
|
|
||||||
:param s: A string. `s` can have things like '%25', '%2525',
|
|
||||||
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
|
|
||||||
"""
|
|
||||||
for i in range(len(s)):
|
|
||||||
if s[i] == "%":
|
|
||||||
sub = s[i : i + 3]
|
|
||||||
# If unquoting yields the same string this means there was an
|
|
||||||
# unquoted %.
|
|
||||||
if unquote_plus(sub) == sub:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def parse_userinfo(userinfo: str) -> tuple[str, str]:
|
|
||||||
"""Validates the format of user information in a MongoDB URI.
|
|
||||||
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
|
|
||||||
"]", "@") as per RFC 3986 must be escaped.
|
|
||||||
|
|
||||||
Returns a 2-tuple containing the unescaped username followed
|
|
||||||
by the unescaped password.
|
|
||||||
|
|
||||||
:param userinfo: A string of the form <username>:<password>
|
|
||||||
"""
|
|
||||||
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
|
|
||||||
raise InvalidURI(
|
|
||||||
"Username and password must be escaped according to "
|
|
||||||
"RFC 3986, use urllib.parse.quote_plus"
|
|
||||||
)
|
|
||||||
|
|
||||||
user, _, passwd = userinfo.partition(":")
|
|
||||||
# No password is expected with GSSAPI authentication.
|
|
||||||
if not user:
|
|
||||||
raise InvalidURI("The empty string is not valid username")
|
|
||||||
|
|
||||||
return unquote_plus(user), unquote_plus(passwd)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_ipv6_literal_host(
|
|
||||||
entity: str, default_port: Optional[int]
|
|
||||||
) -> tuple[str, Optional[Union[str, int]]]:
|
|
||||||
"""Validates an IPv6 literal host:port string.
|
|
||||||
|
|
||||||
Returns a 2-tuple of IPv6 literal followed by port where
|
|
||||||
port is default_port if it wasn't specified in entity.
|
|
||||||
|
|
||||||
:param entity: A string that represents an IPv6 literal enclosed
|
|
||||||
in braces (e.g. '[::1]' or '[::1]:27017').
|
|
||||||
:param default_port: The port number to use when one wasn't
|
|
||||||
specified in entity.
|
|
||||||
"""
|
|
||||||
if entity.find("]") == -1:
|
|
||||||
raise ValueError(
|
|
||||||
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
|
|
||||||
)
|
|
||||||
i = entity.find("]:")
|
|
||||||
if i == -1:
|
|
||||||
return entity[1:-1], default_port
|
|
||||||
return entity[1:i], entity[i + 2 :]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
|
|
||||||
"""Validates a host string
|
|
||||||
|
|
||||||
Returns a 2-tuple of host followed by port where port is default_port
|
|
||||||
if it wasn't specified in the string.
|
|
||||||
|
|
||||||
:param entity: A host or host:port string where host could be a
|
|
||||||
hostname or IP address.
|
|
||||||
:param default_port: The port number to use when one wasn't
|
|
||||||
specified in entity.
|
|
||||||
"""
|
|
||||||
host = entity
|
|
||||||
port: Optional[Union[str, int]] = default_port
|
|
||||||
if entity[0] == "[":
|
|
||||||
host, port = parse_ipv6_literal_host(entity, default_port)
|
|
||||||
elif entity.endswith(".sock"):
|
|
||||||
return entity, default_port
|
|
||||||
elif entity.find(":") != -1:
|
|
||||||
if entity.count(":") > 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Reserved characters such as ':' must be "
|
|
||||||
"escaped according RFC 2396. An IPv6 "
|
|
||||||
"address literal must be enclosed in '[' "
|
|
||||||
"and ']' according to RFC 2732."
|
|
||||||
)
|
|
||||||
host, port = host.split(":", 1)
|
|
||||||
if isinstance(port, str):
|
|
||||||
if not port.isdigit():
|
|
||||||
# Special case check for mistakes like "mongodb://localhost:27017 ".
|
|
||||||
if all(c.isspace() or c.isdigit() for c in port):
|
|
||||||
for c in port:
|
|
||||||
if c.isspace():
|
|
||||||
raise ValueError(f"Port contains whitespace character: {c!r}")
|
|
||||||
|
|
||||||
# A non-digit port indicates that the URI is invalid, likely because the password
|
|
||||||
# or username were not escaped.
|
|
||||||
raise ValueError(
|
|
||||||
"Port contains non-digit characters. Hint: username and password must be escaped according to "
|
|
||||||
"RFC 3986, use urllib.parse.quote_plus"
|
|
||||||
)
|
|
||||||
if int(port) > 65535 or int(port) <= 0:
|
|
||||||
raise ValueError("Port must be an integer between 0 and 65535")
|
|
||||||
port = int(port)
|
|
||||||
|
|
||||||
# Normalize hostname to lowercase, since DNS is case-insensitive:
|
|
||||||
# https://tools.ietf.org/html/rfc4343
|
|
||||||
# This prevents useless rediscovery if "foo.com" is in the seed list but
|
|
||||||
# "FOO.com" is in the hello response.
|
|
||||||
return host.lower(), port
|
|
||||||
|
|
||||||
|
|
||||||
# Options whose values are implicitly determined by tlsInsecure.
|
|
||||||
_IMPLICIT_TLSINSECURE_OPTS = {
|
|
||||||
"tlsallowinvalidcertificates",
|
|
||||||
"tlsallowinvalidhostnames",
|
|
||||||
"tlsdisableocspendpointcheck",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
|
|
||||||
"""Helper method for split_options which creates the options dict.
|
|
||||||
Also handles the creation of a list for the URI tag_sets/
|
|
||||||
readpreferencetags portion, and the use of a unicode options string.
|
|
||||||
"""
|
|
||||||
options = _CaseInsensitiveDictionary()
|
|
||||||
for uriopt in opts.split(delim):
|
|
||||||
key, value = uriopt.split("=")
|
|
||||||
if key.lower() == "readpreferencetags":
|
|
||||||
options.setdefault(key, []).append(value)
|
|
||||||
else:
|
|
||||||
if key in options:
|
|
||||||
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
|
|
||||||
if key.lower() == "authmechanismproperties":
|
|
||||||
val = value
|
|
||||||
else:
|
|
||||||
val = unquote_plus(value)
|
|
||||||
options[key] = val
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
|
||||||
"""Raise appropriate errors when conflicting TLS options are present in
|
|
||||||
the options dictionary.
|
|
||||||
|
|
||||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
|
||||||
MongoDB URI options.
|
|
||||||
"""
|
|
||||||
# Implicitly defined options must not be explicitly specified.
|
|
||||||
tlsinsecure = options.get("tlsinsecure")
|
|
||||||
if tlsinsecure is not None:
|
|
||||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
|
||||||
if opt in options:
|
|
||||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
|
||||||
raise InvalidURI(
|
|
||||||
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
|
|
||||||
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
|
|
||||||
if tlsallowinvalidcerts is not None:
|
|
||||||
if "tlsdisableocspendpointcheck" in options:
|
|
||||||
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
|
||||||
raise InvalidURI(
|
|
||||||
err_msg
|
|
||||||
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
|
|
||||||
)
|
|
||||||
if tlsallowinvalidcerts is True:
|
|
||||||
options["tlsdisableocspendpointcheck"] = True
|
|
||||||
|
|
||||||
# Handle co-occurence of CRL and OCSP-related options.
|
|
||||||
tlscrlfile = options.get("tlscrlfile")
|
|
||||||
if tlscrlfile is not None:
|
|
||||||
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
|
|
||||||
if options.get(opt) is True:
|
|
||||||
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
|
|
||||||
raise InvalidURI(err_msg % (opt,))
|
|
||||||
|
|
||||||
if "ssl" in options and "tls" in options:
|
|
||||||
|
|
||||||
def truth_value(val: Any) -> Any:
|
|
||||||
if val in ("true", "false"):
|
|
||||||
return val == "true"
|
|
||||||
if isinstance(val, bool):
|
|
||||||
return val
|
|
||||||
return val
|
|
||||||
|
|
||||||
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
|
|
||||||
err_msg = "Can not specify conflicting values for URI options %s and %s."
|
|
||||||
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
|
||||||
"""Issue appropriate warnings when deprecated options are present in the
|
|
||||||
options dictionary. Removes deprecated option key, value pairs if the
|
|
||||||
options dictionary is found to also have the renamed option.
|
|
||||||
|
|
||||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
|
||||||
MongoDB URI options.
|
|
||||||
"""
|
|
||||||
for optname in list(options):
|
|
||||||
if optname in URI_OPTIONS_DEPRECATION_MAP:
|
|
||||||
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
|
|
||||||
if mode == "renamed":
|
|
||||||
newoptname = message
|
|
||||||
if newoptname in options:
|
|
||||||
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
|
|
||||||
warnings.warn(
|
|
||||||
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
options.pop(optname)
|
|
||||||
continue
|
|
||||||
warn_msg = "Option '%s' is deprecated, use '%s' instead."
|
|
||||||
warnings.warn(
|
|
||||||
warn_msg % (options.cased_key(optname), newoptname),
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
elif mode == "removed":
|
|
||||||
warn_msg = "Option '%s' is deprecated. %s."
|
|
||||||
warnings.warn(
|
|
||||||
warn_msg % (options.cased_key(optname), message),
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
|
||||||
"""Normalizes option names in the options dictionary by converting them to
|
|
||||||
their internally-used names.
|
|
||||||
|
|
||||||
:param options: Instance of _CaseInsensitiveDictionary containing
|
|
||||||
MongoDB URI options.
|
|
||||||
"""
|
|
||||||
# Expand the tlsInsecure option.
|
|
||||||
tlsinsecure = options.get("tlsinsecure")
|
|
||||||
if tlsinsecure is not None:
|
|
||||||
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
|
||||||
# Implicit options are logically the same as tlsInsecure.
|
|
||||||
options[opt] = tlsinsecure
|
|
||||||
|
|
||||||
for optname in list(options):
|
|
||||||
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
|
|
||||||
if intname is not None:
|
|
||||||
options[intname] = options.pop(optname)
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
|
|
||||||
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
|
|
||||||
"""Validates and normalizes options passed in a MongoDB URI.
|
|
||||||
|
|
||||||
Returns a new dictionary of validated and normalized options. If warn is
|
|
||||||
False then errors will be thrown for invalid options, otherwise they will
|
|
||||||
be ignored and a warning will be issued.
|
|
||||||
|
|
||||||
:param opts: A dict of MongoDB URI options.
|
|
||||||
:param warn: If ``True`` then warnings will be logged and
|
|
||||||
invalid options will be ignored. Otherwise invalid options will
|
|
||||||
cause errors.
|
|
||||||
"""
|
|
||||||
return get_validated_options(opts, warn)
|
|
||||||
|
|
||||||
|
|
||||||
def split_options(
|
|
||||||
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
|
|
||||||
) -> MutableMapping[str, Any]:
|
|
||||||
"""Takes the options portion of a MongoDB URI, validates each option
|
|
||||||
and returns the options in a dictionary.
|
|
||||||
|
|
||||||
:param opt: A string representing MongoDB URI options.
|
|
||||||
:param validate: If ``True`` (the default), validate and normalize all
|
|
||||||
options.
|
|
||||||
:param warn: If ``False`` (the default), suppress all warnings raised
|
|
||||||
during validation of options.
|
|
||||||
:param normalize: If ``True`` (the default), renames all options to their
|
|
||||||
internally-used names.
|
|
||||||
"""
|
|
||||||
and_idx = opts.find("&")
|
|
||||||
semi_idx = opts.find(";")
|
|
||||||
try:
|
|
||||||
if and_idx >= 0 and semi_idx >= 0:
|
|
||||||
raise InvalidURI("Can not mix '&' and ';' for option separators")
|
|
||||||
elif and_idx >= 0:
|
|
||||||
options = _parse_options(opts, "&")
|
|
||||||
elif semi_idx >= 0:
|
|
||||||
options = _parse_options(opts, ";")
|
|
||||||
elif opts.find("=") != -1:
|
|
||||||
options = _parse_options(opts, None)
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
except ValueError:
|
|
||||||
raise InvalidURI("MongoDB URI options are key=value pairs") from None
|
|
||||||
|
|
||||||
options = _handle_security_options(options)
|
|
||||||
|
|
||||||
options = _handle_option_deprecations(options)
|
|
||||||
|
|
||||||
if normalize:
|
|
||||||
options = _normalize_options(options)
|
|
||||||
|
|
||||||
if validate:
|
|
||||||
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
|
|
||||||
if options.get("authsource") == "":
|
|
||||||
raise InvalidURI("the authSource database cannot be an empty string")
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
|
|
||||||
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
|
|
||||||
"""Takes a string of the form host1[:port],host2[:port]... and
|
|
||||||
splits it into (host, port) tuples. If [:port] isn't present the
|
|
||||||
default_port is used.
|
|
||||||
|
|
||||||
Returns a set of 2-tuples containing the host name (or IP) followed by
|
|
||||||
port number.
|
|
||||||
|
|
||||||
:param hosts: A string of the form host1[:port],host2[:port],...
|
|
||||||
:param default_port: The port number to use when one wasn't specified
|
|
||||||
for a host.
|
|
||||||
"""
|
|
||||||
nodes = []
|
|
||||||
for entity in hosts.split(","):
|
|
||||||
if not entity:
|
|
||||||
raise ConfigurationError("Empty host (or extra comma in host list)")
|
|
||||||
port = default_port
|
|
||||||
# Unix socket entities don't have ports
|
|
||||||
if entity.endswith(".sock"):
|
|
||||||
port = None
|
|
||||||
nodes.append(parse_host(entity, port))
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
|
|
||||||
# Prohibited characters in database name. DB names also can't have ".", but for
|
|
||||||
# backward-compat we allow "db.collection" in URI.
|
|
||||||
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
|
|
||||||
|
|
||||||
_ALLOWED_TXT_OPTS = frozenset(
|
|
||||||
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
|
|
||||||
# Ensure directConnection was not True if there are multiple seeds.
|
|
||||||
if len(nodes) > 1 and options.get("directconnection"):
|
|
||||||
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
|
|
||||||
|
|
||||||
if options.get("loadbalanced"):
|
|
||||||
if len(nodes) > 1:
|
|
||||||
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
|
|
||||||
if options.get("directconnection"):
|
|
||||||
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
|
|
||||||
if options.get("replicaset"):
|
|
||||||
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_uri(
|
|
||||||
uri: str,
|
|
||||||
default_port: Optional[int] = DEFAULT_PORT,
|
|
||||||
validate: bool = True,
|
|
||||||
warn: bool = False,
|
|
||||||
normalize: bool = True,
|
|
||||||
connect_timeout: Optional[float] = None,
|
|
||||||
srv_service_name: Optional[str] = None,
|
|
||||||
srv_max_hosts: Optional[int] = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Parse and validate a MongoDB URI.
|
|
||||||
|
|
||||||
Returns a dict of the form::
|
|
||||||
|
|
||||||
{
|
|
||||||
'nodelist': <list of (host, port) tuples>,
|
|
||||||
'username': <username> or None,
|
|
||||||
'password': <password> or None,
|
|
||||||
'database': <database name> or None,
|
|
||||||
'collection': <collection name> or None,
|
|
||||||
'options': <dict of MongoDB URI options>,
|
|
||||||
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
|
||||||
}
|
|
||||||
|
|
||||||
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
|
||||||
to build nodelist and options.
|
|
||||||
|
|
||||||
:param uri: The MongoDB URI to parse.
|
|
||||||
:param default_port: The port number to use when one wasn't specified
|
|
||||||
for a host in the URI.
|
|
||||||
:param validate: If ``True`` (the default), validate and
|
|
||||||
normalize all options. Default: ``True``.
|
|
||||||
:param warn: When validating, if ``True`` then will warn
|
|
||||||
the user then ignore any invalid options or values. If ``False``,
|
|
||||||
validation will error when options are unsupported or values are
|
|
||||||
invalid. Default: ``False``.
|
|
||||||
:param normalize: If ``True``, convert names of URI options
|
|
||||||
to their internally-used names. Default: ``True``.
|
|
||||||
:param connect_timeout: The maximum time in milliseconds to
|
|
||||||
wait for a response from the DNS server.
|
|
||||||
:param srv_service_name: A custom SRV service name
|
|
||||||
|
|
||||||
.. versionchanged:: 4.6
|
|
||||||
The delimiting slash (``/``) between hosts and connection options is now optional.
|
|
||||||
For example, "mongodb://example.com?tls=true" is now a valid URI.
|
|
||||||
|
|
||||||
.. versionchanged:: 4.0
|
|
||||||
To better follow RFC 3986, unquoted percent signs ("%") are no longer
|
|
||||||
supported.
|
|
||||||
|
|
||||||
.. versionchanged:: 3.9
|
|
||||||
Added the ``normalize`` parameter.
|
|
||||||
|
|
||||||
.. versionchanged:: 3.6
|
|
||||||
Added support for mongodb+srv:// URIs.
|
|
||||||
|
|
||||||
.. versionchanged:: 3.5
|
|
||||||
Return the original value of the ``readPreference`` MongoDB URI option
|
|
||||||
instead of the validated read preference mode.
|
|
||||||
|
|
||||||
.. versionchanged:: 3.1
|
|
||||||
``warn`` added so invalid options can be ignored.
|
|
||||||
"""
|
|
||||||
if uri.startswith(SCHEME):
|
|
||||||
is_srv = False
|
|
||||||
scheme_free = uri[SCHEME_LEN:]
|
|
||||||
elif uri.startswith(SRV_SCHEME):
|
|
||||||
if not _have_dnspython():
|
|
||||||
python_path = sys.executable or "python"
|
|
||||||
raise ConfigurationError(
|
|
||||||
'The "dnspython" module must be '
|
|
||||||
"installed to use mongodb+srv:// URIs. "
|
|
||||||
"To fix this error install pymongo again:\n "
|
|
||||||
"%s -m pip install pymongo>=4.3" % (python_path)
|
|
||||||
)
|
|
||||||
is_srv = True
|
|
||||||
scheme_free = uri[SRV_SCHEME_LEN:]
|
|
||||||
else:
|
|
||||||
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
|
|
||||||
|
|
||||||
if not scheme_free:
|
|
||||||
raise InvalidURI("Must provide at least one hostname or IP")
|
|
||||||
|
|
||||||
user = None
|
|
||||||
passwd = None
|
|
||||||
dbase = None
|
|
||||||
collection = None
|
|
||||||
options = _CaseInsensitiveDictionary()
|
|
||||||
|
|
||||||
host_plus_db_part, _, opts = scheme_free.partition("?")
|
|
||||||
if "/" in host_plus_db_part:
|
|
||||||
host_part, _, dbase = host_plus_db_part.partition("/")
|
|
||||||
else:
|
|
||||||
host_part = host_plus_db_part
|
|
||||||
|
|
||||||
if dbase:
|
|
||||||
dbase = unquote_plus(dbase)
|
|
||||||
if "." in dbase:
|
|
||||||
dbase, collection = dbase.split(".", 1)
|
|
||||||
if _BAD_DB_CHARS.search(dbase):
|
|
||||||
raise InvalidURI('Bad database name "%s"' % dbase)
|
|
||||||
else:
|
|
||||||
dbase = None
|
|
||||||
|
|
||||||
if opts:
|
|
||||||
options.update(split_options(opts, validate, warn, normalize))
|
|
||||||
if srv_service_name is None:
|
|
||||||
srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
|
|
||||||
if "@" in host_part:
|
|
||||||
userinfo, _, hosts = host_part.rpartition("@")
|
|
||||||
user, passwd = parse_userinfo(userinfo)
|
|
||||||
else:
|
|
||||||
hosts = host_part
|
|
||||||
|
|
||||||
if "/" in hosts:
|
|
||||||
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
|
|
||||||
|
|
||||||
hosts = unquote_plus(hosts)
|
|
||||||
fqdn = None
|
|
||||||
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
|
||||||
if is_srv:
|
|
||||||
if options.get("directConnection"):
|
|
||||||
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
|
|
||||||
nodes = split_hosts(hosts, default_port=None)
|
|
||||||
if len(nodes) != 1:
|
|
||||||
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
|
|
||||||
fqdn, port = nodes[0]
|
|
||||||
if port is not None:
|
|
||||||
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
|
|
||||||
|
|
||||||
# Use the connection timeout. connectTimeoutMS passed as a keyword
|
|
||||||
# argument overrides the same option passed in the connection string.
|
|
||||||
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
|
|
||||||
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
|
|
||||||
nodes = dns_resolver.get_hosts()
|
|
||||||
dns_options = dns_resolver.get_options()
|
|
||||||
if dns_options:
|
|
||||||
parsed_dns_options = split_options(dns_options, validate, warn, normalize)
|
|
||||||
if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
|
|
||||||
raise ConfigurationError(
|
|
||||||
"Only authSource, replicaSet, and loadBalanced are supported from DNS"
|
|
||||||
)
|
|
||||||
for opt, val in parsed_dns_options.items():
|
|
||||||
if opt not in options:
|
|
||||||
options[opt] = val
|
|
||||||
if options.get("loadBalanced") and srv_max_hosts:
|
|
||||||
raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
|
|
||||||
if options.get("replicaSet") and srv_max_hosts:
|
|
||||||
raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
|
|
||||||
if "tls" not in options and "ssl" not in options:
|
|
||||||
options["tls"] = True if validate else "true"
|
|
||||||
elif not is_srv and options.get("srvServiceName") is not None:
|
|
||||||
raise ConfigurationError(
|
|
||||||
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
|
|
||||||
)
|
|
||||||
elif not is_srv and srv_max_hosts:
|
|
||||||
raise ConfigurationError(
|
|
||||||
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
nodes = split_hosts(hosts, default_port=default_port)
|
|
||||||
|
|
||||||
_check_options(nodes, options)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"nodelist": nodes,
|
|
||||||
"username": user,
|
|
||||||
"password": passwd,
|
|
||||||
"database": dbase,
|
|
||||||
"collection": collection,
|
|
||||||
"options": options,
|
|
||||||
"fqdn": fqdn,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
|
|
||||||
"""Parse KMS TLS connection options."""
|
|
||||||
if not kms_tls_options:
|
|
||||||
return {}
|
|
||||||
if not isinstance(kms_tls_options, dict):
|
|
||||||
raise TypeError("kms_tls_options must be a dict")
|
|
||||||
contexts = {}
|
|
||||||
for provider, options in kms_tls_options.items():
|
|
||||||
if not isinstance(options, dict):
|
|
||||||
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
|
|
||||||
options.setdefault("tls", True)
|
|
||||||
opts = _CaseInsensitiveDictionary(options)
|
|
||||||
opts = _handle_security_options(opts)
|
|
||||||
opts = _normalize_options(opts)
|
|
||||||
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
|
|
||||||
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
|
|
||||||
if ssl_context is None:
|
|
||||||
raise ConfigurationError("TLS is required for KMS providers")
|
|
||||||
if allow_invalid_hostnames:
|
|
||||||
raise ConfigurationError("Insecure TLS options prohibited")
|
|
||||||
|
|
||||||
for n in [
|
|
||||||
"tlsInsecure",
|
|
||||||
"tlsAllowInvalidCertificates",
|
|
||||||
"tlsAllowInvalidHostnames",
|
|
||||||
"tlsDisableCertificateRevocationCheck",
|
|
||||||
]:
|
|
||||||
if n in opts:
|
|
||||||
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
|
|
||||||
contexts[provider] = ssl_context
|
|
||||||
return contexts
|
|
||||||
|
|
||||||
|
from pymongo.errors import InvalidURI
|
||||||
|
from pymongo.synchronous.uri_parser import * # noqa: F403
|
||||||
|
from pymongo.synchronous.uri_parser import __doc__ as original_doc
|
||||||
|
from pymongo.uri_parser_shared import * # noqa: F403
|
||||||
|
|
||||||
|
__doc__ = original_doc
|
||||||
|
__all__ = [ # noqa: F405
|
||||||
|
"parse_userinfo",
|
||||||
|
"parse_ipv6_literal_host",
|
||||||
|
"parse_host",
|
||||||
|
"validate_options",
|
||||||
|
"split_options",
|
||||||
|
"split_hosts",
|
||||||
|
"parse_uri",
|
||||||
|
]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pprint
|
import pprint
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203
|
pprint.pprint(parse_uri(sys.argv[1])) # noqa: F405, T203
|
||||||
except InvalidURI as exc:
|
except InvalidURI as exc:
|
||||||
print(exc) # noqa: T201
|
print(exc) # noqa: T201
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|||||||
549
pymongo/uri_parser_shared.py
Normal file
549
pymongo/uri_parser_shared.py
Normal file
@ -0,0 +1,549 @@
|
|||||||
|
# Copyright 2011-present MongoDB, Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||||
|
# may not use this file except in compliance with the License. You
|
||||||
|
# may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied. See the License for the specific language governing
|
||||||
|
# permissions and limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""Tools to parse and validate a MongoDB URI.
|
||||||
|
|
||||||
|
.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Mapping,
|
||||||
|
MutableMapping,
|
||||||
|
Optional,
|
||||||
|
Sized,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
from urllib.parse import unquote_plus
|
||||||
|
|
||||||
|
from pymongo.asynchronous.srv_resolver import _have_dnspython
|
||||||
|
from pymongo.client_options import _parse_ssl_options
|
||||||
|
from pymongo.common import (
|
||||||
|
INTERNAL_URI_OPTION_NAME_MAP,
|
||||||
|
URI_OPTIONS_DEPRECATION_MAP,
|
||||||
|
_CaseInsensitiveDictionary,
|
||||||
|
get_validated_options,
|
||||||
|
)
|
||||||
|
from pymongo.errors import ConfigurationError, InvalidURI
|
||||||
|
from pymongo.typings import _Address
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pymongo.pyopenssl_context import SSLContext
|
||||||
|
|
||||||
|
SCHEME = "mongodb://"
|
||||||
|
SCHEME_LEN = len(SCHEME)
|
||||||
|
SRV_SCHEME = "mongodb+srv://"
|
||||||
|
SRV_SCHEME_LEN = len(SRV_SCHEME)
|
||||||
|
DEFAULT_PORT = 27017
|
||||||
|
|
||||||
|
|
||||||
|
def _unquoted_percent(s: str) -> bool:
|
||||||
|
"""Check for unescaped percent signs.
|
||||||
|
|
||||||
|
:param s: A string. `s` can have things like '%25', '%2525',
|
||||||
|
and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
|
||||||
|
"""
|
||||||
|
for i in range(len(s)):
|
||||||
|
if s[i] == "%":
|
||||||
|
sub = s[i : i + 3]
|
||||||
|
# If unquoting yields the same string this means there was an
|
||||||
|
# unquoted %.
|
||||||
|
if unquote_plus(sub) == sub:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_userinfo(userinfo: str) -> tuple[str, str]:
|
||||||
|
"""Validates the format of user information in a MongoDB URI.
|
||||||
|
Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
|
||||||
|
"]", "@") as per RFC 3986 must be escaped.
|
||||||
|
|
||||||
|
Returns a 2-tuple containing the unescaped username followed
|
||||||
|
by the unescaped password.
|
||||||
|
|
||||||
|
:param userinfo: A string of the form <username>:<password>
|
||||||
|
"""
|
||||||
|
if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
|
||||||
|
raise InvalidURI(
|
||||||
|
"Username and password must be escaped according to "
|
||||||
|
"RFC 3986, use urllib.parse.quote_plus"
|
||||||
|
)
|
||||||
|
|
||||||
|
user, _, passwd = userinfo.partition(":")
|
||||||
|
# No password is expected with GSSAPI authentication.
|
||||||
|
if not user:
|
||||||
|
raise InvalidURI("The empty string is not valid username")
|
||||||
|
|
||||||
|
return unquote_plus(user), unquote_plus(passwd)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ipv6_literal_host(
|
||||||
|
entity: str, default_port: Optional[int]
|
||||||
|
) -> tuple[str, Optional[Union[str, int]]]:
|
||||||
|
"""Validates an IPv6 literal host:port string.
|
||||||
|
|
||||||
|
Returns a 2-tuple of IPv6 literal followed by port where
|
||||||
|
port is default_port if it wasn't specified in entity.
|
||||||
|
|
||||||
|
:param entity: A string that represents an IPv6 literal enclosed
|
||||||
|
in braces (e.g. '[::1]' or '[::1]:27017').
|
||||||
|
:param default_port: The port number to use when one wasn't
|
||||||
|
specified in entity.
|
||||||
|
"""
|
||||||
|
if entity.find("]") == -1:
|
||||||
|
raise ValueError(
|
||||||
|
"an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732."
|
||||||
|
)
|
||||||
|
i = entity.find("]:")
|
||||||
|
if i == -1:
|
||||||
|
return entity[1:-1], default_port
|
||||||
|
return entity[1:i], entity[i + 2 :]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address:
|
||||||
|
"""Validates a host string
|
||||||
|
|
||||||
|
Returns a 2-tuple of host followed by port where port is default_port
|
||||||
|
if it wasn't specified in the string.
|
||||||
|
|
||||||
|
:param entity: A host or host:port string where host could be a
|
||||||
|
hostname or IP address.
|
||||||
|
:param default_port: The port number to use when one wasn't
|
||||||
|
specified in entity.
|
||||||
|
"""
|
||||||
|
host = entity
|
||||||
|
port: Optional[Union[str, int]] = default_port
|
||||||
|
if entity[0] == "[":
|
||||||
|
host, port = parse_ipv6_literal_host(entity, default_port)
|
||||||
|
elif entity.endswith(".sock"):
|
||||||
|
return entity, default_port
|
||||||
|
elif entity.find(":") != -1:
|
||||||
|
if entity.count(":") > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Reserved characters such as ':' must be "
|
||||||
|
"escaped according RFC 2396. An IPv6 "
|
||||||
|
"address literal must be enclosed in '[' "
|
||||||
|
"and ']' according to RFC 2732."
|
||||||
|
)
|
||||||
|
host, port = host.split(":", 1)
|
||||||
|
if isinstance(port, str):
|
||||||
|
if not port.isdigit():
|
||||||
|
# Special case check for mistakes like "mongodb://localhost:27017 ".
|
||||||
|
if all(c.isspace() or c.isdigit() for c in port):
|
||||||
|
for c in port:
|
||||||
|
if c.isspace():
|
||||||
|
raise ValueError(f"Port contains whitespace character: {c!r}")
|
||||||
|
|
||||||
|
# A non-digit port indicates that the URI is invalid, likely because the password
|
||||||
|
# or username were not escaped.
|
||||||
|
raise ValueError(
|
||||||
|
"Port contains non-digit characters. Hint: username and password must be escaped according to "
|
||||||
|
"RFC 3986, use urllib.parse.quote_plus"
|
||||||
|
)
|
||||||
|
if int(port) > 65535 or int(port) <= 0:
|
||||||
|
raise ValueError("Port must be an integer between 0 and 65535")
|
||||||
|
port = int(port)
|
||||||
|
|
||||||
|
# Normalize hostname to lowercase, since DNS is case-insensitive:
|
||||||
|
# https://tools.ietf.org/html/rfc4343
|
||||||
|
# This prevents useless rediscovery if "foo.com" is in the seed list but
|
||||||
|
# "FOO.com" is in the hello response.
|
||||||
|
return host.lower(), port
|
||||||
|
|
||||||
|
|
||||||
|
# Options whose values are implicitly determined by tlsInsecure.
|
||||||
|
_IMPLICIT_TLSINSECURE_OPTS = {
|
||||||
|
"tlsallowinvalidcertificates",
|
||||||
|
"tlsallowinvalidhostnames",
|
||||||
|
"tlsdisableocspendpointcheck",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary:
|
||||||
|
"""Helper method for split_options which creates the options dict.
|
||||||
|
Also handles the creation of a list for the URI tag_sets/
|
||||||
|
readpreferencetags portion, and the use of a unicode options string.
|
||||||
|
"""
|
||||||
|
options = _CaseInsensitiveDictionary()
|
||||||
|
for uriopt in opts.split(delim):
|
||||||
|
key, value = uriopt.split("=")
|
||||||
|
if key.lower() == "readpreferencetags":
|
||||||
|
options.setdefault(key, []).append(value)
|
||||||
|
else:
|
||||||
|
if key in options:
|
||||||
|
warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2)
|
||||||
|
if key.lower() == "authmechanismproperties":
|
||||||
|
val = value
|
||||||
|
else:
|
||||||
|
val = unquote_plus(value)
|
||||||
|
options[key] = val
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||||
|
"""Raise appropriate errors when conflicting TLS options are present in
|
||||||
|
the options dictionary.
|
||||||
|
|
||||||
|
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||||
|
MongoDB URI options.
|
||||||
|
"""
|
||||||
|
# Implicitly defined options must not be explicitly specified.
|
||||||
|
tlsinsecure = options.get("tlsinsecure")
|
||||||
|
if tlsinsecure is not None:
|
||||||
|
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||||
|
if opt in options:
|
||||||
|
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||||
|
raise InvalidURI(
|
||||||
|
err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
|
||||||
|
tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
|
||||||
|
if tlsallowinvalidcerts is not None:
|
||||||
|
if "tlsdisableocspendpointcheck" in options:
|
||||||
|
err_msg = "URI options %s and %s cannot be specified simultaneously."
|
||||||
|
raise InvalidURI(
|
||||||
|
err_msg
|
||||||
|
% ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
|
||||||
|
)
|
||||||
|
if tlsallowinvalidcerts is True:
|
||||||
|
options["tlsdisableocspendpointcheck"] = True
|
||||||
|
|
||||||
|
# Handle co-occurence of CRL and OCSP-related options.
|
||||||
|
tlscrlfile = options.get("tlscrlfile")
|
||||||
|
if tlscrlfile is not None:
|
||||||
|
for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
|
||||||
|
if options.get(opt) is True:
|
||||||
|
err_msg = "URI option %s=True cannot be specified when CRL checking is enabled."
|
||||||
|
raise InvalidURI(err_msg % (opt,))
|
||||||
|
|
||||||
|
if "ssl" in options and "tls" in options:
|
||||||
|
|
||||||
|
def truth_value(val: Any) -> Any:
|
||||||
|
if val in ("true", "false"):
|
||||||
|
return val == "true"
|
||||||
|
if isinstance(val, bool):
|
||||||
|
return val
|
||||||
|
return val
|
||||||
|
|
||||||
|
if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
|
||||||
|
err_msg = "Can not specify conflicting values for URI options %s and %s."
|
||||||
|
raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||||
|
"""Issue appropriate warnings when deprecated options are present in the
|
||||||
|
options dictionary. Removes deprecated option key, value pairs if the
|
||||||
|
options dictionary is found to also have the renamed option.
|
||||||
|
|
||||||
|
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||||
|
MongoDB URI options.
|
||||||
|
"""
|
||||||
|
for optname in list(options):
|
||||||
|
if optname in URI_OPTIONS_DEPRECATION_MAP:
|
||||||
|
mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
|
||||||
|
if mode == "renamed":
|
||||||
|
newoptname = message
|
||||||
|
if newoptname in options:
|
||||||
|
warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
|
||||||
|
warnings.warn(
|
||||||
|
warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
options.pop(optname)
|
||||||
|
continue
|
||||||
|
warn_msg = "Option '%s' is deprecated, use '%s' instead."
|
||||||
|
warnings.warn(
|
||||||
|
warn_msg % (options.cased_key(optname), newoptname),
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
elif mode == "removed":
|
||||||
|
warn_msg = "Option '%s' is deprecated. %s."
|
||||||
|
warnings.warn(
|
||||||
|
warn_msg % (options.cased_key(optname), message),
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
|
||||||
|
"""Normalizes option names in the options dictionary by converting them to
|
||||||
|
their internally-used names.
|
||||||
|
|
||||||
|
:param options: Instance of _CaseInsensitiveDictionary containing
|
||||||
|
MongoDB URI options.
|
||||||
|
"""
|
||||||
|
# Expand the tlsInsecure option.
|
||||||
|
tlsinsecure = options.get("tlsinsecure")
|
||||||
|
if tlsinsecure is not None:
|
||||||
|
for opt in _IMPLICIT_TLSINSECURE_OPTS:
|
||||||
|
# Implicit options are logically the same as tlsInsecure.
|
||||||
|
options[opt] = tlsinsecure
|
||||||
|
|
||||||
|
for optname in list(options):
|
||||||
|
intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
|
||||||
|
if intname is not None:
|
||||||
|
options[intname] = options.pop(optname)
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
|
||||||
|
"""Validates and normalizes options passed in a MongoDB URI.
|
||||||
|
|
||||||
|
Returns a new dictionary of validated and normalized options. If warn is
|
||||||
|
False then errors will be thrown for invalid options, otherwise they will
|
||||||
|
be ignored and a warning will be issued.
|
||||||
|
|
||||||
|
:param opts: A dict of MongoDB URI options.
|
||||||
|
:param warn: If ``True`` then warnings will be logged and
|
||||||
|
invalid options will be ignored. Otherwise invalid options will
|
||||||
|
cause errors.
|
||||||
|
"""
|
||||||
|
return get_validated_options(opts, warn)
|
||||||
|
|
||||||
|
|
||||||
|
def split_options(
|
||||||
|
opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
|
||||||
|
) -> MutableMapping[str, Any]:
|
||||||
|
"""Takes the options portion of a MongoDB URI, validates each option
|
||||||
|
and returns the options in a dictionary.
|
||||||
|
|
||||||
|
:param opt: A string representing MongoDB URI options.
|
||||||
|
:param validate: If ``True`` (the default), validate and normalize all
|
||||||
|
options.
|
||||||
|
:param warn: If ``False`` (the default), suppress all warnings raised
|
||||||
|
during validation of options.
|
||||||
|
:param normalize: If ``True`` (the default), renames all options to their
|
||||||
|
internally-used names.
|
||||||
|
"""
|
||||||
|
and_idx = opts.find("&")
|
||||||
|
semi_idx = opts.find(";")
|
||||||
|
try:
|
||||||
|
if and_idx >= 0 and semi_idx >= 0:
|
||||||
|
raise InvalidURI("Can not mix '&' and ';' for option separators")
|
||||||
|
elif and_idx >= 0:
|
||||||
|
options = _parse_options(opts, "&")
|
||||||
|
elif semi_idx >= 0:
|
||||||
|
options = _parse_options(opts, ";")
|
||||||
|
elif opts.find("=") != -1:
|
||||||
|
options = _parse_options(opts, None)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidURI("MongoDB URI options are key=value pairs") from None
|
||||||
|
|
||||||
|
options = _handle_security_options(options)
|
||||||
|
|
||||||
|
options = _handle_option_deprecations(options)
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
options = _normalize_options(options)
|
||||||
|
|
||||||
|
if validate:
|
||||||
|
options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
|
||||||
|
if options.get("authsource") == "":
|
||||||
|
raise InvalidURI("the authSource database cannot be an empty string")
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
|
||||||
|
"""Takes a string of the form host1[:port],host2[:port]... and
|
||||||
|
splits it into (host, port) tuples. If [:port] isn't present the
|
||||||
|
default_port is used.
|
||||||
|
|
||||||
|
Returns a set of 2-tuples containing the host name (or IP) followed by
|
||||||
|
port number.
|
||||||
|
|
||||||
|
:param hosts: A string of the form host1[:port],host2[:port],...
|
||||||
|
:param default_port: The port number to use when one wasn't specified
|
||||||
|
for a host.
|
||||||
|
"""
|
||||||
|
nodes = []
|
||||||
|
for entity in hosts.split(","):
|
||||||
|
if not entity:
|
||||||
|
raise ConfigurationError("Empty host (or extra comma in host list)")
|
||||||
|
port = default_port
|
||||||
|
# Unix socket entities don't have ports
|
||||||
|
if entity.endswith(".sock"):
|
||||||
|
port = None
|
||||||
|
nodes.append(parse_host(entity, port))
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
|
||||||
|
# Prohibited characters in database name. DB names also can't have ".", but for
|
||||||
|
# backward-compat we allow "db.collection" in URI.
|
||||||
|
_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
|
||||||
|
|
||||||
|
_ALLOWED_TXT_OPTS = frozenset(
|
||||||
|
["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
|
||||||
|
# Ensure directConnection was not True if there are multiple seeds.
|
||||||
|
if len(nodes) > 1 and options.get("directconnection"):
|
||||||
|
raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
|
||||||
|
|
||||||
|
if options.get("loadbalanced"):
|
||||||
|
if len(nodes) > 1:
|
||||||
|
raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
|
||||||
|
if options.get("directconnection"):
|
||||||
|
raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
|
||||||
|
if options.get("replicaset"):
|
||||||
|
raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]:
|
||||||
|
"""Parse KMS TLS connection options."""
|
||||||
|
if not kms_tls_options:
|
||||||
|
return {}
|
||||||
|
if not isinstance(kms_tls_options, dict):
|
||||||
|
raise TypeError("kms_tls_options must be a dict")
|
||||||
|
contexts = {}
|
||||||
|
for provider, options in kms_tls_options.items():
|
||||||
|
if not isinstance(options, dict):
|
||||||
|
raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
|
||||||
|
options.setdefault("tls", True)
|
||||||
|
opts = _CaseInsensitiveDictionary(options)
|
||||||
|
opts = _handle_security_options(opts)
|
||||||
|
opts = _normalize_options(opts)
|
||||||
|
opts = cast(_CaseInsensitiveDictionary, validate_options(opts))
|
||||||
|
ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
|
||||||
|
if ssl_context is None:
|
||||||
|
raise ConfigurationError("TLS is required for KMS providers")
|
||||||
|
if allow_invalid_hostnames:
|
||||||
|
raise ConfigurationError("Insecure TLS options prohibited")
|
||||||
|
|
||||||
|
for n in [
|
||||||
|
"tlsInsecure",
|
||||||
|
"tlsAllowInvalidCertificates",
|
||||||
|
"tlsAllowInvalidHostnames",
|
||||||
|
"tlsDisableCertificateRevocationCheck",
|
||||||
|
]:
|
||||||
|
if n in opts:
|
||||||
|
raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
|
||||||
|
contexts[provider] = ssl_context
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_uri(
|
||||||
|
uri: str,
|
||||||
|
default_port: Optional[int] = DEFAULT_PORT,
|
||||||
|
validate: bool = True,
|
||||||
|
warn: bool = False,
|
||||||
|
normalize: bool = True,
|
||||||
|
srv_max_hosts: Optional[int] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if uri.startswith(SCHEME):
|
||||||
|
is_srv = False
|
||||||
|
scheme_free = uri[SCHEME_LEN:]
|
||||||
|
elif uri.startswith(SRV_SCHEME):
|
||||||
|
if not _have_dnspython():
|
||||||
|
python_path = sys.executable or "python"
|
||||||
|
raise ConfigurationError(
|
||||||
|
'The "dnspython" module must be '
|
||||||
|
"installed to use mongodb+srv:// URIs. "
|
||||||
|
"To fix this error install pymongo again:\n "
|
||||||
|
"%s -m pip install pymongo>=4.3" % (python_path)
|
||||||
|
)
|
||||||
|
is_srv = True
|
||||||
|
scheme_free = uri[SRV_SCHEME_LEN:]
|
||||||
|
else:
|
||||||
|
raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
|
||||||
|
|
||||||
|
if not scheme_free:
|
||||||
|
raise InvalidURI("Must provide at least one hostname or IP")
|
||||||
|
|
||||||
|
user = None
|
||||||
|
passwd = None
|
||||||
|
dbase = None
|
||||||
|
collection = None
|
||||||
|
options = _CaseInsensitiveDictionary()
|
||||||
|
|
||||||
|
host_plus_db_part, _, opts = scheme_free.partition("?")
|
||||||
|
if "/" in host_plus_db_part:
|
||||||
|
host_part, _, dbase = host_plus_db_part.partition("/")
|
||||||
|
else:
|
||||||
|
host_part = host_plus_db_part
|
||||||
|
|
||||||
|
if dbase:
|
||||||
|
dbase = unquote_plus(dbase)
|
||||||
|
if "." in dbase:
|
||||||
|
dbase, collection = dbase.split(".", 1)
|
||||||
|
if _BAD_DB_CHARS.search(dbase):
|
||||||
|
raise InvalidURI('Bad database name "%s"' % dbase)
|
||||||
|
else:
|
||||||
|
dbase = None
|
||||||
|
|
||||||
|
if opts:
|
||||||
|
options.update(split_options(opts, validate, warn, normalize))
|
||||||
|
if "@" in host_part:
|
||||||
|
userinfo, _, hosts = host_part.rpartition("@")
|
||||||
|
user, passwd = parse_userinfo(userinfo)
|
||||||
|
else:
|
||||||
|
hosts = host_part
|
||||||
|
|
||||||
|
if "/" in hosts:
|
||||||
|
raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
|
||||||
|
|
||||||
|
hosts = unquote_plus(hosts)
|
||||||
|
fqdn = None
|
||||||
|
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
|
||||||
|
if is_srv:
|
||||||
|
if options.get("directConnection"):
|
||||||
|
raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
|
||||||
|
nodes = split_hosts(hosts, default_port=None)
|
||||||
|
if len(nodes) != 1:
|
||||||
|
raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
|
||||||
|
fqdn, port = nodes[0]
|
||||||
|
if port is not None:
|
||||||
|
raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
|
||||||
|
elif not is_srv and options.get("srvServiceName") is not None:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"The srvServiceName option is only allowed with 'mongodb+srv://' URIs"
|
||||||
|
)
|
||||||
|
elif not is_srv and srv_max_hosts:
|
||||||
|
raise ConfigurationError(
|
||||||
|
"The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
nodes = split_hosts(hosts, default_port=default_port)
|
||||||
|
|
||||||
|
_check_options(nodes, options)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"nodelist": nodes,
|
||||||
|
"username": user,
|
||||||
|
"password": passwd,
|
||||||
|
"database": dbase,
|
||||||
|
"collection": collection,
|
||||||
|
"options": options,
|
||||||
|
"fqdn": fqdn,
|
||||||
|
}
|
||||||
@ -32,7 +32,7 @@ import unittest
|
|||||||
import warnings
|
import warnings
|
||||||
from asyncio import iscoroutinefunction
|
from asyncio import iscoroutinefunction
|
||||||
|
|
||||||
from pymongo.uri_parser import parse_uri
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
|||||||
@ -32,7 +32,7 @@ import unittest
|
|||||||
import warnings
|
import warnings
|
||||||
from asyncio import iscoroutinefunction
|
from asyncio import iscoroutinefunction
|
||||||
|
|
||||||
from pymongo.uri_parser import parse_uri
|
from pymongo.asynchronous.uri_parser import parse_uri
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ipaddress
|
import ipaddress
|
||||||
@ -1027,7 +1027,7 @@ class AsyncPyMongoTestCase(unittest.TestCase):
|
|||||||
auth_mech = kwargs.get("authMechanism", "")
|
auth_mech = kwargs.get("authMechanism", "")
|
||||||
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||||
# Only add the default username or password if one is not provided.
|
# Only add the default username or password if one is not provided.
|
||||||
res = parse_uri(uri)
|
res = await parse_uri(uri)
|
||||||
if (
|
if (
|
||||||
not res["username"]
|
not res["username"]
|
||||||
and not res["password"]
|
and not res["password"]
|
||||||
@ -1058,7 +1058,7 @@ class AsyncPyMongoTestCase(unittest.TestCase):
|
|||||||
auth_mech = kwargs.get("authMechanism", "")
|
auth_mech = kwargs.get("authMechanism", "")
|
||||||
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
|
||||||
# Only add the default username or password if one is not provided.
|
# Only add the default username or password if one is not provided.
|
||||||
res = parse_uri(uri)
|
res = await parse_uri(uri)
|
||||||
if (
|
if (
|
||||||
not res["username"]
|
not res["username"]
|
||||||
and not res["password"]
|
and not res["password"]
|
||||||
|
|||||||
@ -47,7 +47,7 @@ from bson.son import SON
|
|||||||
from pymongo import common, message
|
from pymongo import common, message
|
||||||
from pymongo.read_preferences import ReadPreference
|
from pymongo.read_preferences import ReadPreference
|
||||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||||
from pymongo.uri_parser import parse_uri
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
|
|
||||||
if HAVE_SSL:
|
if HAVE_SSL:
|
||||||
import ssl
|
import ssl
|
||||||
|
|||||||
@ -512,13 +512,13 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
|||||||
|
|
||||||
async def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
|
async def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
|
||||||
# Patch the resolver.
|
# Patch the resolver.
|
||||||
from pymongo.srv_resolver import _resolve
|
from pymongo.asynchronous.srv_resolver import _resolve
|
||||||
|
|
||||||
patched_resolver = FunctionCallRecorder(_resolve)
|
patched_resolver = FunctionCallRecorder(_resolve)
|
||||||
pymongo.srv_resolver._resolve = patched_resolver
|
pymongo.asynchronous.srv_resolver._resolve = patched_resolver
|
||||||
|
|
||||||
def reset_resolver():
|
def reset_resolver():
|
||||||
pymongo.srv_resolver._resolve = _resolve
|
pymongo.asynchronous.srv_resolver._resolve = _resolve
|
||||||
|
|
||||||
self.addCleanup(reset_resolver)
|
self.addCleanup(reset_resolver)
|
||||||
|
|
||||||
@ -607,7 +607,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
|||||||
with self.assertRaisesRegex(ConfigurationError, expected):
|
with self.assertRaisesRegex(ConfigurationError, expected):
|
||||||
AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type]
|
AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type]
|
||||||
|
|
||||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
@patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts")
|
||||||
def test_detected_environment_logging(self, mock_get_hosts):
|
def test_detected_environment_logging(self, mock_get_hosts):
|
||||||
normal_hosts = [
|
normal_hosts = [
|
||||||
"normal.host.com",
|
"normal.host.com",
|
||||||
@ -629,7 +629,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
|||||||
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
|
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
|
||||||
self.assertEqual(len(logs), 7)
|
self.assertEqual(len(logs), 7)
|
||||||
|
|
||||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
@patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts")
|
||||||
async def test_detected_environment_warning(self, mock_get_hosts):
|
async def test_detected_environment_warning(self, mock_get_hosts):
|
||||||
with self._caplog.at_level(logging.WARN):
|
with self._caplog.at_level(logging.WARN):
|
||||||
normal_hosts = [
|
normal_hosts = [
|
||||||
@ -933,6 +933,15 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
async with eval(the_repr) as client_two:
|
async with eval(the_repr) as client_two:
|
||||||
self.assertEqual(client_two, client)
|
self.assertEqual(client_two, client)
|
||||||
|
|
||||||
|
async def test_repr_srv_host(self):
|
||||||
|
client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
|
||||||
|
# before srv resolution
|
||||||
|
self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client))
|
||||||
|
await client.aconnect()
|
||||||
|
# after srv resolution
|
||||||
|
self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client))
|
||||||
|
await client.close()
|
||||||
|
|
||||||
async def test_getters(self):
|
async def test_getters(self):
|
||||||
await async_wait_until(
|
await async_wait_until(
|
||||||
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
|
lambda: async_client_context.nodes == self.client.nodes, "find all nodes"
|
||||||
@ -1911,28 +1920,37 @@ class TestClient(AsyncIntegrationTest):
|
|||||||
srvServiceName="customname",
|
srvServiceName="customname",
|
||||||
connect=False,
|
connect=False,
|
||||||
)
|
)
|
||||||
|
await client.aconnect()
|
||||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||||
|
await client.close()
|
||||||
client = AsyncMongoClient(
|
client = AsyncMongoClient(
|
||||||
"mongodb+srv://user:password@test22.test.build.10gen.cc"
|
"mongodb+srv://user:password@test22.test.build.10gen.cc"
|
||||||
"/?srvServiceName=shouldbeoverriden",
|
"/?srvServiceName=shouldbeoverriden",
|
||||||
srvServiceName="customname",
|
srvServiceName="customname",
|
||||||
connect=False,
|
connect=False,
|
||||||
)
|
)
|
||||||
|
await client.aconnect()
|
||||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||||
|
await client.close()
|
||||||
client = AsyncMongoClient(
|
client = AsyncMongoClient(
|
||||||
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
|
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
|
||||||
connect=False,
|
connect=False,
|
||||||
)
|
)
|
||||||
|
await client.aconnect()
|
||||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||||
|
await client.close()
|
||||||
|
|
||||||
async def test_srv_max_hosts_kwarg(self):
|
async def test_srv_max_hosts_kwarg(self):
|
||||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
|
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
|
||||||
|
await client.aconnect()
|
||||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
|
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
|
||||||
|
await client.aconnect()
|
||||||
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
|
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
|
||||||
client = self.simple_client(
|
client = self.simple_client(
|
||||||
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
|
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
|
||||||
)
|
)
|
||||||
|
await client.aconnect()
|
||||||
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
|
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
|
|||||||
@ -54,6 +54,7 @@ from bson import Timestamp, json_util
|
|||||||
from pymongo import common, monitoring
|
from pymongo import common, monitoring
|
||||||
from pymongo.asynchronous.settings import TopologySettings
|
from pymongo.asynchronous.settings import TopologySettings
|
||||||
from pymongo.asynchronous.topology import Topology, _ErrorContext
|
from pymongo.asynchronous.topology import Topology, _ErrorContext
|
||||||
|
from pymongo.asynchronous.uri_parser import parse_uri
|
||||||
from pymongo.errors import (
|
from pymongo.errors import (
|
||||||
AutoReconnect,
|
AutoReconnect,
|
||||||
ConfigurationError,
|
ConfigurationError,
|
||||||
@ -66,7 +67,6 @@ from pymongo.helpers_shared import _check_command_response, _check_write_command
|
|||||||
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
|
from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
|
||||||
from pymongo.server_description import SERVER_TYPE, ServerDescription
|
from pymongo.server_description import SERVER_TYPE, ServerDescription
|
||||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||||
from pymongo.uri_parser import parse_uri
|
|
||||||
|
|
||||||
_IS_SYNC = False
|
_IS_SYNC = False
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
async def create_mock_topology(uri, monitor_class=DummyMonitor):
|
async def create_mock_topology(uri, monitor_class=DummyMonitor):
|
||||||
parsed_uri = parse_uri(uri)
|
parsed_uri = await parse_uri(uri)
|
||||||
replica_set_name = None
|
replica_set_name = None
|
||||||
direct_connection = None
|
direct_connection = None
|
||||||
load_balanced = None
|
load_balanced = None
|
||||||
|
|||||||
@ -31,9 +31,10 @@ from test.asynchronous import (
|
|||||||
)
|
)
|
||||||
from test.utils_shared import async_wait_until
|
from test.utils_shared import async_wait_until
|
||||||
|
|
||||||
|
from pymongo.asynchronous.uri_parser import parse_uri
|
||||||
from pymongo.common import validate_read_preference_tags
|
from pymongo.common import validate_read_preference_tags
|
||||||
from pymongo.errors import ConfigurationError
|
from pymongo.errors import ConfigurationError
|
||||||
from pymongo.uri_parser import parse_uri, split_hosts
|
from pymongo.uri_parser_shared import split_hosts
|
||||||
|
|
||||||
_IS_SYNC = False
|
_IS_SYNC = False
|
||||||
|
|
||||||
@ -109,7 +110,7 @@ def create_test(test_case):
|
|||||||
hosts = frozenset(split_hosts(",".join(hosts)))
|
hosts = frozenset(split_hosts(",".join(hosts)))
|
||||||
|
|
||||||
if seeds or num_seeds:
|
if seeds or num_seeds:
|
||||||
result = parse_uri(uri, validate=True)
|
result = await parse_uri(uri, validate=True)
|
||||||
if seeds is not None:
|
if seeds is not None:
|
||||||
self.assertEqual(sorted(result["nodelist"]), sorted(seeds))
|
self.assertEqual(sorted(result["nodelist"]), sorted(seeds))
|
||||||
if num_seeds is not None:
|
if num_seeds is not None:
|
||||||
@ -161,7 +162,7 @@ def create_test(test_case):
|
|||||||
# and re-run these assertions.
|
# and re-run these assertions.
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
parse_uri(uri)
|
await parse_uri(uri)
|
||||||
except (ConfigurationError, ValueError):
|
except (ConfigurationError, ValueError):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -185,35 +186,24 @@ create_tests(TestDNSSharded)
|
|||||||
|
|
||||||
class TestParsingErrors(AsyncPyMongoTestCase):
|
class TestParsingErrors(AsyncPyMongoTestCase):
|
||||||
async def test_invalid_host(self):
|
async def test_invalid_host(self):
|
||||||
self.assertRaisesRegex(
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
|
||||||
ConfigurationError,
|
client = self.simple_client("mongodb+srv://mongodb")
|
||||||
"Invalid URI host: mongodb is not",
|
await client.aconnect()
|
||||||
self.simple_client,
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
|
||||||
"mongodb+srv://mongodb",
|
client = self.simple_client("mongodb+srv://mongodb.com")
|
||||||
)
|
await client.aconnect()
|
||||||
self.assertRaisesRegex(
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||||
ConfigurationError,
|
client = self.simple_client("mongodb+srv://127.0.0.1")
|
||||||
"Invalid URI host: mongodb.com is not",
|
await client.aconnect()
|
||||||
self.simple_client,
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||||
"mongodb+srv://mongodb.com",
|
client = self.simple_client("mongodb+srv://[::1]")
|
||||||
)
|
await client.aconnect()
|
||||||
self.assertRaisesRegex(
|
|
||||||
ConfigurationError,
|
|
||||||
"Invalid URI host: an IP address is not",
|
|
||||||
self.simple_client,
|
|
||||||
"mongodb+srv://127.0.0.1",
|
|
||||||
)
|
|
||||||
self.assertRaisesRegex(
|
|
||||||
ConfigurationError,
|
|
||||||
"Invalid URI host: an IP address is not",
|
|
||||||
self.simple_client,
|
|
||||||
"mongodb+srv://[::1]",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
|
class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest):
|
||||||
async def test_connect_case_insensitive(self):
|
async def test_connect_case_insensitive(self):
|
||||||
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||||
|
await client.aconnect()
|
||||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,8 +28,8 @@ from test.asynchronous.utils import async_wait_until
|
|||||||
|
|
||||||
import pymongo
|
import pymongo
|
||||||
from pymongo import common
|
from pymongo import common
|
||||||
|
from pymongo.asynchronous.srv_resolver import _have_dnspython
|
||||||
from pymongo.errors import ConfigurationError
|
from pymongo.errors import ConfigurationError
|
||||||
from pymongo.srv_resolver import _have_dnspython
|
|
||||||
|
|
||||||
_IS_SYNC = False
|
_IS_SYNC = False
|
||||||
|
|
||||||
@ -54,14 +54,16 @@ class SrvPollingKnobs:
|
|||||||
|
|
||||||
def enable(self):
|
def enable(self):
|
||||||
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
||||||
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
self.old_dns_resolver_response = (
|
||||||
|
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||||
|
)
|
||||||
|
|
||||||
if self.min_srv_rescan_interval is not None:
|
if self.min_srv_rescan_interval is not None:
|
||||||
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
||||||
|
|
||||||
def mock_get_hosts_and_min_ttl(resolver, *args):
|
async def mock_get_hosts_and_min_ttl(resolver, *args):
|
||||||
assert self.old_dns_resolver_response is not None
|
assert self.old_dns_resolver_response is not None
|
||||||
nodes, ttl = self.old_dns_resolver_response(resolver)
|
nodes, ttl = await self.old_dns_resolver_response(resolver)
|
||||||
if self.nodelist_callback is not None:
|
if self.nodelist_callback is not None:
|
||||||
nodes = self.nodelist_callback()
|
nodes = self.nodelist_callback()
|
||||||
if self.ttl_time is not None:
|
if self.ttl_time is not None:
|
||||||
@ -74,14 +76,14 @@ class SrvPollingKnobs:
|
|||||||
else:
|
else:
|
||||||
patch_func = mock_get_hosts_and_min_ttl
|
patch_func = mock_get_hosts_and_min_ttl
|
||||||
|
|
||||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.enable()
|
self.enable()
|
||||||
|
|
||||||
def disable(self):
|
def disable(self):
|
||||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
||||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||||
self.old_dns_resolver_response
|
self.old_dns_resolver_response
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -134,7 +136,10 @@ class TestSrvPolling(AsyncPyMongoTestCase):
|
|||||||
|
|
||||||
def predicate():
|
def predicate():
|
||||||
if set(expected_nodelist) == set(self.get_nodelist(client)):
|
if set(expected_nodelist) == set(self.get_nodelist(client)):
|
||||||
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
|
return (
|
||||||
|
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count
|
||||||
|
>= 1
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
|
await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
|
||||||
@ -144,7 +149,7 @@ class TestSrvPolling(AsyncPyMongoTestCase):
|
|||||||
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
||||||
raise self.fail(msg % (nodelist, expected_nodelist))
|
raise self.fail(msg % (nodelist, expected_nodelist))
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||||
1,
|
1,
|
||||||
"resolver was never called",
|
"resolver was never called",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -32,7 +32,7 @@ except ImportError:
|
|||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from pymongo.errors import OperationFailure
|
from pymongo.errors import OperationFailure
|
||||||
from pymongo.uri_parser import parse_uri
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
|
|
||||||
pytestmark = pytest.mark.auth_aws
|
pytestmark = pytest.mark.auth_aws
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ from pymongo.synchronous.auth_oidc import (
|
|||||||
OIDCCallbackResult,
|
OIDCCallbackResult,
|
||||||
_get_authenticator,
|
_get_authenticator,
|
||||||
)
|
)
|
||||||
from pymongo.uri_parser import parse_uri
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
|
|
||||||
ROOT = Path(__file__).parent.parent.resolve()
|
ROOT = Path(__file__).parent.parent.resolve()
|
||||||
TEST_PATH = ROOT / "auth" / "unified"
|
TEST_PATH = ROOT / "auth" / "unified"
|
||||||
|
|||||||
@ -47,7 +47,7 @@ from bson.son import SON
|
|||||||
from pymongo import common, message
|
from pymongo import common, message
|
||||||
from pymongo.read_preferences import ReadPreference
|
from pymongo.read_preferences import ReadPreference
|
||||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||||
from pymongo.uri_parser import parse_uri
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
|
|
||||||
if HAVE_SSL:
|
if HAVE_SSL:
|
||||||
import ssl
|
import ssl
|
||||||
|
|||||||
@ -505,13 +505,13 @@ class ClientUnitTest(UnitTest):
|
|||||||
|
|
||||||
def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
|
def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
|
||||||
# Patch the resolver.
|
# Patch the resolver.
|
||||||
from pymongo.srv_resolver import _resolve
|
from pymongo.synchronous.srv_resolver import _resolve
|
||||||
|
|
||||||
patched_resolver = FunctionCallRecorder(_resolve)
|
patched_resolver = FunctionCallRecorder(_resolve)
|
||||||
pymongo.srv_resolver._resolve = patched_resolver
|
pymongo.synchronous.srv_resolver._resolve = patched_resolver
|
||||||
|
|
||||||
def reset_resolver():
|
def reset_resolver():
|
||||||
pymongo.srv_resolver._resolve = _resolve
|
pymongo.synchronous.srv_resolver._resolve = _resolve
|
||||||
|
|
||||||
self.addCleanup(reset_resolver)
|
self.addCleanup(reset_resolver)
|
||||||
|
|
||||||
@ -600,7 +600,7 @@ class ClientUnitTest(UnitTest):
|
|||||||
with self.assertRaisesRegex(ConfigurationError, expected):
|
with self.assertRaisesRegex(ConfigurationError, expected):
|
||||||
MongoClient(**{typo: "standard"}) # type: ignore[arg-type]
|
MongoClient(**{typo: "standard"}) # type: ignore[arg-type]
|
||||||
|
|
||||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
@patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
|
||||||
def test_detected_environment_logging(self, mock_get_hosts):
|
def test_detected_environment_logging(self, mock_get_hosts):
|
||||||
normal_hosts = [
|
normal_hosts = [
|
||||||
"normal.host.com",
|
"normal.host.com",
|
||||||
@ -622,7 +622,7 @@ class ClientUnitTest(UnitTest):
|
|||||||
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
|
logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"]
|
||||||
self.assertEqual(len(logs), 7)
|
self.assertEqual(len(logs), 7)
|
||||||
|
|
||||||
@patch("pymongo.srv_resolver._SrvResolver.get_hosts")
|
@patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
|
||||||
def test_detected_environment_warning(self, mock_get_hosts):
|
def test_detected_environment_warning(self, mock_get_hosts):
|
||||||
with self._caplog.at_level(logging.WARN):
|
with self._caplog.at_level(logging.WARN):
|
||||||
normal_hosts = [
|
normal_hosts = [
|
||||||
@ -908,6 +908,15 @@ class TestClient(IntegrationTest):
|
|||||||
with eval(the_repr) as client_two:
|
with eval(the_repr) as client_two:
|
||||||
self.assertEqual(client_two, client)
|
self.assertEqual(client_two, client)
|
||||||
|
|
||||||
|
def test_repr_srv_host(self):
|
||||||
|
client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False)
|
||||||
|
# before srv resolution
|
||||||
|
self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client))
|
||||||
|
client._connect()
|
||||||
|
# after srv resolution
|
||||||
|
self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client))
|
||||||
|
client.close()
|
||||||
|
|
||||||
def test_getters(self):
|
def test_getters(self):
|
||||||
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
|
wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes")
|
||||||
|
|
||||||
@ -1868,28 +1877,37 @@ class TestClient(IntegrationTest):
|
|||||||
srvServiceName="customname",
|
srvServiceName="customname",
|
||||||
connect=False,
|
connect=False,
|
||||||
)
|
)
|
||||||
|
client._connect()
|
||||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||||
|
client.close()
|
||||||
client = MongoClient(
|
client = MongoClient(
|
||||||
"mongodb+srv://user:password@test22.test.build.10gen.cc"
|
"mongodb+srv://user:password@test22.test.build.10gen.cc"
|
||||||
"/?srvServiceName=shouldbeoverriden",
|
"/?srvServiceName=shouldbeoverriden",
|
||||||
srvServiceName="customname",
|
srvServiceName="customname",
|
||||||
connect=False,
|
connect=False,
|
||||||
)
|
)
|
||||||
|
client._connect()
|
||||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||||
|
client.close()
|
||||||
client = MongoClient(
|
client = MongoClient(
|
||||||
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
|
"mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname",
|
||||||
connect=False,
|
connect=False,
|
||||||
)
|
)
|
||||||
|
client._connect()
|
||||||
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
self.assertEqual(client._topology_settings.srv_service_name, "customname")
|
||||||
|
client.close()
|
||||||
|
|
||||||
def test_srv_max_hosts_kwarg(self):
|
def test_srv_max_hosts_kwarg(self):
|
||||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
|
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/")
|
||||||
|
client._connect()
|
||||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||||
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
|
client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1)
|
||||||
|
client._connect()
|
||||||
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
|
self.assertEqual(len(client.topology_description.server_descriptions()), 1)
|
||||||
client = self.simple_client(
|
client = self.simple_client(
|
||||||
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
|
"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2
|
||||||
)
|
)
|
||||||
|
client._connect()
|
||||||
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
|
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
|
|||||||
@ -65,8 +65,8 @@ from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStarte
|
|||||||
from pymongo.server_description import SERVER_TYPE, ServerDescription
|
from pymongo.server_description import SERVER_TYPE, ServerDescription
|
||||||
from pymongo.synchronous.settings import TopologySettings
|
from pymongo.synchronous.settings import TopologySettings
|
||||||
from pymongo.synchronous.topology import Topology, _ErrorContext
|
from pymongo.synchronous.topology import Topology, _ErrorContext
|
||||||
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||||
from pymongo.uri_parser import parse_uri
|
|
||||||
|
|
||||||
_IS_SYNC = True
|
_IS_SYNC = True
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,8 @@ from test.utils_shared import wait_until
|
|||||||
|
|
||||||
from pymongo.common import validate_read_preference_tags
|
from pymongo.common import validate_read_preference_tags
|
||||||
from pymongo.errors import ConfigurationError
|
from pymongo.errors import ConfigurationError
|
||||||
from pymongo.uri_parser import parse_uri, split_hosts
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
|
from pymongo.uri_parser_shared import split_hosts
|
||||||
|
|
||||||
_IS_SYNC = True
|
_IS_SYNC = True
|
||||||
|
|
||||||
@ -183,35 +184,24 @@ create_tests(TestDNSSharded)
|
|||||||
|
|
||||||
class TestParsingErrors(PyMongoTestCase):
|
class TestParsingErrors(PyMongoTestCase):
|
||||||
def test_invalid_host(self):
|
def test_invalid_host(self):
|
||||||
self.assertRaisesRegex(
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"):
|
||||||
ConfigurationError,
|
client = self.simple_client("mongodb+srv://mongodb")
|
||||||
"Invalid URI host: mongodb is not",
|
client._connect()
|
||||||
self.simple_client,
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"):
|
||||||
"mongodb+srv://mongodb",
|
client = self.simple_client("mongodb+srv://mongodb.com")
|
||||||
)
|
client._connect()
|
||||||
self.assertRaisesRegex(
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||||
ConfigurationError,
|
client = self.simple_client("mongodb+srv://127.0.0.1")
|
||||||
"Invalid URI host: mongodb.com is not",
|
client._connect()
|
||||||
self.simple_client,
|
with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"):
|
||||||
"mongodb+srv://mongodb.com",
|
client = self.simple_client("mongodb+srv://[::1]")
|
||||||
)
|
client._connect()
|
||||||
self.assertRaisesRegex(
|
|
||||||
ConfigurationError,
|
|
||||||
"Invalid URI host: an IP address is not",
|
|
||||||
self.simple_client,
|
|
||||||
"mongodb+srv://127.0.0.1",
|
|
||||||
)
|
|
||||||
self.assertRaisesRegex(
|
|
||||||
ConfigurationError,
|
|
||||||
"Invalid URI host: an IP address is not",
|
|
||||||
self.simple_client,
|
|
||||||
"mongodb+srv://[::1]",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCaseInsensitive(IntegrationTest):
|
class TestCaseInsensitive(IntegrationTest):
|
||||||
def test_connect_case_insensitive(self):
|
def test_connect_case_insensitive(self):
|
||||||
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/")
|
||||||
|
client._connect()
|
||||||
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
self.assertGreater(len(client.topology_description.server_descriptions()), 1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from test.utils import wait_until
|
|||||||
import pymongo
|
import pymongo
|
||||||
from pymongo import common
|
from pymongo import common
|
||||||
from pymongo.errors import ConfigurationError
|
from pymongo.errors import ConfigurationError
|
||||||
from pymongo.srv_resolver import _have_dnspython
|
from pymongo.synchronous.srv_resolver import _have_dnspython
|
||||||
|
|
||||||
_IS_SYNC = True
|
_IS_SYNC = True
|
||||||
|
|
||||||
@ -54,7 +54,9 @@ class SrvPollingKnobs:
|
|||||||
|
|
||||||
def enable(self):
|
def enable(self):
|
||||||
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
||||||
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
self.old_dns_resolver_response = (
|
||||||
|
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||||
|
)
|
||||||
|
|
||||||
if self.min_srv_rescan_interval is not None:
|
if self.min_srv_rescan_interval is not None:
|
||||||
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
||||||
@ -74,14 +76,14 @@ class SrvPollingKnobs:
|
|||||||
else:
|
else:
|
||||||
patch_func = mock_get_hosts_and_min_ttl
|
patch_func = mock_get_hosts_and_min_ttl
|
||||||
|
|
||||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.enable()
|
self.enable()
|
||||||
|
|
||||||
def disable(self):
|
def disable(self):
|
||||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
|
||||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
|
||||||
self.old_dns_resolver_response
|
self.old_dns_resolver_response
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -134,7 +136,10 @@ class TestSrvPolling(PyMongoTestCase):
|
|||||||
|
|
||||||
def predicate():
|
def predicate():
|
||||||
if set(expected_nodelist) == set(self.get_nodelist(client)):
|
if set(expected_nodelist) == set(self.get_nodelist(client)):
|
||||||
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
|
return (
|
||||||
|
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count
|
||||||
|
>= 1
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
|
wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
|
||||||
@ -144,7 +149,7 @@ class TestSrvPolling(PyMongoTestCase):
|
|||||||
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
||||||
raise self.fail(msg % (nodelist, expected_nodelist))
|
raise self.fail(msg % (nodelist, expected_nodelist))
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
|
||||||
1,
|
1,
|
||||||
"resolver was never called",
|
"resolver was never called",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -28,8 +28,8 @@ from test import unittest
|
|||||||
from bson.binary import JAVA_LEGACY
|
from bson.binary import JAVA_LEGACY
|
||||||
from pymongo import ReadPreference
|
from pymongo import ReadPreference
|
||||||
from pymongo.errors import ConfigurationError, InvalidURI
|
from pymongo.errors import ConfigurationError, InvalidURI
|
||||||
from pymongo.uri_parser import (
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
parse_uri,
|
from pymongo.uri_parser_shared import (
|
||||||
parse_userinfo,
|
parse_userinfo,
|
||||||
split_hosts,
|
split_hosts,
|
||||||
split_options,
|
split_options,
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from test.helpers import clear_warning_registry
|
|||||||
|
|
||||||
from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
|
from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
|
||||||
from pymongo.compression_support import _have_snappy
|
from pymongo.compression_support import _have_snappy
|
||||||
from pymongo.uri_parser import parse_uri
|
from pymongo.synchronous.uri_parser import parse_uri
|
||||||
|
|
||||||
CONN_STRING_TEST_PATH = os.path.join(
|
CONN_STRING_TEST_PATH = os.path.join(
|
||||||
os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test")
|
os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test")
|
||||||
|
|||||||
@ -127,6 +127,7 @@ replacements = {
|
|||||||
"async_create_barrier": "create_barrier",
|
"async_create_barrier": "create_barrier",
|
||||||
"async_barrier_wait": "barrier_wait",
|
"async_barrier_wait": "barrier_wait",
|
||||||
"async_joinall": "joinall",
|
"async_joinall": "joinall",
|
||||||
|
"pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts",
|
||||||
}
|
}
|
||||||
|
|
||||||
docstring_replacements: dict[tuple[str, str], str] = {
|
docstring_replacements: dict[tuple[str, str], str] = {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user