PYTHON-1675 SRV polling for mongos discovery
This commit is contained in:
parent
afbf18b0ad
commit
f85a9f9450
@ -39,6 +39,8 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include:
|
||||
enabled by default. See the :class:`~pymongo.mongo_client.MongoClient`
|
||||
documentation for details.
|
||||
- Support zstandard for wire protocol compression.
|
||||
- Support for periodically polling DNS SRV records to update the mongos proxy
|
||||
list without having to change client configuration.
|
||||
|
||||
Now that supported operations are retried automatically and transparently,
|
||||
users should consider adjusting any custom retry logic to prevent
|
||||
|
||||
@ -73,6 +73,9 @@ SERVER_SELECTION_TIMEOUT = 30
|
||||
# Spec requires at least 500ms between ismaster calls.
|
||||
MIN_HEARTBEAT_INTERVAL = 0.5
|
||||
|
||||
# Spec requires at least 60s between SRV rescans.
|
||||
MIN_SRV_RESCAN_INTERVAL = 60
|
||||
|
||||
# Default connectTimeout in seconds.
|
||||
CONNECT_TIMEOUT = 20.0
|
||||
|
||||
|
||||
@ -585,6 +585,7 @@ class MongoClient(common.BaseObject):
|
||||
password = None
|
||||
dbase = None
|
||||
opts = {}
|
||||
fqdn = None
|
||||
for entity in host:
|
||||
if "://" in entity:
|
||||
res = uri_parser.parse_uri(
|
||||
@ -594,6 +595,7 @@ class MongoClient(common.BaseObject):
|
||||
password = res["password"] or password
|
||||
dbase = res["database"] or dbase
|
||||
opts = res["options"]
|
||||
fqdn = res["fqdn"]
|
||||
else:
|
||||
seeds.update(uri_parser.split_hosts(entity, port))
|
||||
if not seeds:
|
||||
@ -673,7 +675,8 @@ class MongoClient(common.BaseObject):
|
||||
local_threshold_ms=options.local_threshold_ms,
|
||||
server_selection_timeout=options.server_selection_timeout,
|
||||
server_selector=options.server_selector,
|
||||
heartbeat_frequency=options.heartbeat_frequency)
|
||||
heartbeat_frequency=options.heartbeat_frequency,
|
||||
fqdn=fqdn)
|
||||
|
||||
self._topology = Topology(self._topology_settings)
|
||||
if connect:
|
||||
|
||||
@ -18,13 +18,42 @@ import weakref
|
||||
|
||||
from pymongo import common, periodic_executor
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.monotonic import time as _time
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.srv_resolver import _SrvResolver
|
||||
|
||||
|
||||
class Monitor(object):
|
||||
class MonitorBase(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Override this method to create an executor."""
|
||||
raise NotImplementedError
|
||||
|
||||
def open(self):
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
Multiple calls have no effect.
|
||||
"""
|
||||
self._executor.open()
|
||||
|
||||
def close(self):
|
||||
"""Close and stop monitoring.
|
||||
|
||||
open() restarts the monitor after closing.
|
||||
"""
|
||||
self._executor.close()
|
||||
|
||||
def join(self, timeout=None):
|
||||
"""Wait for the monitor to stop."""
|
||||
self._executor.join(timeout)
|
||||
|
||||
def request_check(self):
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
self._executor.wake()
|
||||
|
||||
|
||||
class Monitor(MonitorBase):
|
||||
def __init__(
|
||||
self,
|
||||
server_description,
|
||||
@ -68,31 +97,13 @@ class Monitor(object):
|
||||
self_ref = weakref.ref(self, executor.close)
|
||||
self._topology = weakref.proxy(topology, executor.close)
|
||||
|
||||
def open(self):
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
Multiple calls have no effect.
|
||||
"""
|
||||
self._executor.open()
|
||||
|
||||
def close(self):
|
||||
"""Close and stop monitoring.
|
||||
|
||||
open() restarts the monitor after closing.
|
||||
"""
|
||||
self._executor.close()
|
||||
super(Monitor, self).close()
|
||||
|
||||
# Increment the pool_id and maybe close the socket. If the executor
|
||||
# thread has the socket checked out, it will be closed when checked in.
|
||||
self._pool.reset()
|
||||
|
||||
def join(self, timeout=None):
|
||||
self._executor.join(timeout)
|
||||
|
||||
def request_check(self):
|
||||
"""If the monitor is sleeping, wake and check the server soon."""
|
||||
self._executor.wake()
|
||||
|
||||
def _run(self):
|
||||
try:
|
||||
self._server_description = self._check_with_retry()
|
||||
@ -182,3 +193,66 @@ class Monitor(object):
|
||||
self._topology.receive_cluster_time(
|
||||
exc.details.get('$clusterTime'))
|
||||
raise
|
||||
|
||||
|
||||
class SrvMonitor(MonitorBase):
|
||||
def __init__(self, topology, topology_settings):
|
||||
"""Class to poll SRV records on a background thread.
|
||||
|
||||
Pass a Topology and a TopologySettings.
|
||||
|
||||
The Topology is weakly referenced.
|
||||
"""
|
||||
self._settings = topology_settings
|
||||
self._fqdn = self._settings.fqdn
|
||||
|
||||
# We strongly reference the executor and it weakly references us via
|
||||
# this closure. When the monitor is freed, stop the executor soon.
|
||||
def target():
|
||||
monitor = self_ref()
|
||||
if monitor is None:
|
||||
return False # Stop the executor.
|
||||
SrvMonitor._run(monitor)
|
||||
return True
|
||||
|
||||
executor = periodic_executor.PeriodicExecutor(
|
||||
interval=common.MIN_SRV_RESCAN_INTERVAL,
|
||||
min_interval=self._settings.heartbeat_frequency,
|
||||
target=target,
|
||||
name="pymongo_srv_polling_thread")
|
||||
|
||||
self._executor = executor
|
||||
|
||||
# Avoid cycles. When self or topology is freed, stop executor soon.
|
||||
self_ref = weakref.ref(self, executor.close)
|
||||
self._topology = weakref.proxy(topology, executor.close)
|
||||
|
||||
def _run(self):
|
||||
try:
|
||||
self._seedlist = self._get_seedlist()
|
||||
self._topology.on_srv_update(self._seedlist)
|
||||
except ReferenceError:
|
||||
# Topology was garbage-collected.
|
||||
self.close()
|
||||
|
||||
def _get_seedlist(self):
|
||||
"""Poll SRV records for a seedlist.
|
||||
|
||||
Returns a list of ServerDescriptions.
|
||||
"""
|
||||
try:
|
||||
seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl()
|
||||
if len(seedlist) == 0:
|
||||
# As per the spec: this should be treated as a failure.
|
||||
raise Exception
|
||||
except Exception:
|
||||
# As per the spec, upon encountering an error:
|
||||
# - An error must not be raised
|
||||
# - SRV records must be rescanned every heartbeatFrequencyMS
|
||||
# - Topology must be left unchanged
|
||||
self.request_check()
|
||||
return self._seedlist
|
||||
else:
|
||||
self._executor.update_interval(
|
||||
max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
|
||||
return seedlist
|
||||
|
||||
@ -102,6 +102,9 @@ class PeriodicExecutor(object):
|
||||
"""Execute the target function soon."""
|
||||
self._event = True
|
||||
|
||||
def update_interval(self, new_interval):
|
||||
self._interval = new_interval
|
||||
|
||||
def __should_stop(self):
|
||||
with self._lock:
|
||||
if self._stopped:
|
||||
|
||||
@ -20,9 +20,9 @@ from bson.objectid import ObjectId
|
||||
from pymongo import common, monitor, pool
|
||||
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
from pymongo.pool import PoolOptions
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
|
||||
|
||||
class TopologySettings(object):
|
||||
@ -36,7 +36,8 @@ class TopologySettings(object):
|
||||
local_threshold_ms=LOCAL_THRESHOLD_MS,
|
||||
server_selection_timeout=SERVER_SELECTION_TIMEOUT,
|
||||
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
|
||||
server_selector=None):
|
||||
server_selector=None,
|
||||
fqdn=None):
|
||||
"""Represent MongoClient's configuration.
|
||||
|
||||
Take a list of (host, port) pairs and optional replica set name.
|
||||
@ -55,6 +56,7 @@ class TopologySettings(object):
|
||||
self._local_threshold_ms = local_threshold_ms
|
||||
self._server_selection_timeout = server_selection_timeout
|
||||
self._server_selector = server_selector
|
||||
self._fqdn = fqdn
|
||||
self._heartbeat_frequency = heartbeat_frequency
|
||||
self._direct = (len(self._seeds) == 1 and not replica_set_name)
|
||||
self._topology_id = ObjectId()
|
||||
@ -100,6 +102,10 @@ class TopologySettings(object):
|
||||
def heartbeat_frequency(self):
|
||||
return self._heartbeat_frequency
|
||||
|
||||
@property
|
||||
def fqdn(self):
|
||||
return self._fqdn
|
||||
|
||||
@property
|
||||
def direct(self):
|
||||
"""Connect directly to a single server, or use a set of servers?
|
||||
|
||||
103
pymongo/srv_resolver.py
Normal file
103
pymongo/srv_resolver.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
|
||||
|
||||
try:
|
||||
from dns import resolver
|
||||
_HAVE_DNSPYTHON = True
|
||||
except ImportError:
|
||||
_HAVE_DNSPYTHON = False
|
||||
|
||||
from bson.py3compat import PY3
|
||||
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
|
||||
if PY3:
|
||||
# dnspython can return bytes or str from various parts
|
||||
# of its API depending on version. We always want str.
|
||||
def maybe_decode(text):
|
||||
if isinstance(text, bytes):
|
||||
return text.decode()
|
||||
return text
|
||||
else:
|
||||
def maybe_decode(text):
|
||||
return text
|
||||
|
||||
|
||||
class _SrvResolver(object):
|
||||
def __init__(self, fqdn):
|
||||
self.__fqdn = fqdn
|
||||
|
||||
# Validate the fully qualified domain name.
|
||||
try:
|
||||
self.__plist = self.__fqdn.split(".")[1:]
|
||||
except Exception:
|
||||
raise ConfigurationError("Invalid URI host")
|
||||
self.__slen = len(self.__plist)
|
||||
if self.__slen < 2:
|
||||
raise ConfigurationError("Invalid URI host")
|
||||
|
||||
def get_options(self):
|
||||
try:
|
||||
results = resolver.query(self.__fqdn, 'TXT')
|
||||
except (resolver.NoAnswer, resolver.NXDOMAIN):
|
||||
# No TXT records
|
||||
return None
|
||||
except Exception as exc:
|
||||
raise ConfigurationError(str(exc))
|
||||
if len(results) > 1:
|
||||
raise ConfigurationError('Only one TXT record is supported')
|
||||
return (
|
||||
b'&'.join([b''.join(res.strings) for res in results])).decode(
|
||||
'utf-8')
|
||||
|
||||
def _resolve_uri(self, encapsulate_errors):
|
||||
try:
|
||||
results = resolver.query('_mongodb._tcp.' + self.__fqdn, 'SRV')
|
||||
except Exception as exc:
|
||||
if not encapsulate_errors:
|
||||
# Raise the original error.
|
||||
raise
|
||||
# Else, raise all errors as ConfigurationError.
|
||||
raise ConfigurationError(str(exc))
|
||||
return results
|
||||
|
||||
def _get_srv_response_and_hosts(self, encapsulate_errors):
|
||||
results = self._resolve_uri(encapsulate_errors)
|
||||
|
||||
# Construct address tuples
|
||||
nodes = [
|
||||
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)
|
||||
for res in results]
|
||||
|
||||
# Validate hosts
|
||||
for node in nodes:
|
||||
try:
|
||||
nlist = node[0].split(".")[1:][-self.__slen:]
|
||||
except Exception:
|
||||
raise ConfigurationError("Invalid SRV host")
|
||||
if self.__plist != nlist:
|
||||
raise ConfigurationError("Invalid SRV host")
|
||||
|
||||
return results, nodes
|
||||
|
||||
def get_hosts(self):
|
||||
_, nodes = self._get_srv_response_and_hosts(True)
|
||||
return nodes
|
||||
|
||||
def get_hosts_and_min_ttl(self):
|
||||
results, nodes = self._get_srv_response_and_hosts(False)
|
||||
return nodes, results.rrset.ttl
|
||||
@ -30,9 +30,11 @@ from pymongo import common
|
||||
from pymongo import periodic_executor
|
||||
from pymongo.pool import PoolOptions
|
||||
from pymongo.topology_description import (updated_topology_description,
|
||||
TOPOLOGY_TYPE,
|
||||
TopologyDescription)
|
||||
_updated_topology_description_srv_polling,
|
||||
TopologyDescription,
|
||||
SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE)
|
||||
from pymongo.errors import ServerSelectionTimeoutError, ConfigurationError
|
||||
from pymongo.monitor import SrvMonitor
|
||||
from pymongo.monotonic import time as _time
|
||||
from pymongo.server import Server
|
||||
from pymongo.server_selectors import (any_server_selector,
|
||||
@ -129,6 +131,10 @@ class Topology(object):
|
||||
self.__events_executor = executor
|
||||
executor.open()
|
||||
|
||||
self._srv_monitor = None
|
||||
if self._settings.fqdn is not None:
|
||||
self._srv_monitor = SrvMonitor(self, self._settings)
|
||||
|
||||
def open(self):
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
@ -272,6 +278,14 @@ class Topology(object):
|
||||
self._listeners.publish_topology_description_changed,
|
||||
(td_old, self._description, self._topology_id)))
|
||||
|
||||
# Shutdown SRV polling for unsupported cluster types.
|
||||
# This is only applicable if the old topology was Unknown, and the
|
||||
# new one is something other than Unknown or Sharded.
|
||||
if self._srv_monitor and (td_old.topology_type == TOPOLOGY_TYPE.Unknown
|
||||
and self._description.topology_type not in
|
||||
SRV_POLLING_TOPOLOGIES):
|
||||
self._srv_monitor.close()
|
||||
|
||||
# Wake waiters in select_servers().
|
||||
self._condition.notify_all()
|
||||
|
||||
@ -291,6 +305,28 @@ class Topology(object):
|
||||
self._description.has_server(server_description.address)):
|
||||
self._process_change(server_description)
|
||||
|
||||
def _process_srv_update(self, seedlist):
|
||||
"""Process a new seedlist on an opened topology.
|
||||
Hold the lock when calling this.
|
||||
"""
|
||||
td_old = self._description
|
||||
self._description = _updated_topology_description_srv_polling(
|
||||
self._description, seedlist)
|
||||
|
||||
self._update_servers()
|
||||
|
||||
if self._publish_tp:
|
||||
self._events.put((
|
||||
self._listeners.publish_topology_description_changed,
|
||||
(td_old, self._description, self._topology_id)))
|
||||
|
||||
def on_srv_update(self, seedlist):
|
||||
"""Process a new list of nodes obtained from scanning SRV records."""
|
||||
# We do no I/O holding the lock.
|
||||
with self._lock:
|
||||
if self._opened:
|
||||
self._process_srv_update(seedlist)
|
||||
|
||||
def get_server_by_address(self, address):
|
||||
"""Get a Server or None.
|
||||
|
||||
@ -396,6 +432,11 @@ class Topology(object):
|
||||
# Mark all servers Unknown.
|
||||
self._description = self._description.reset()
|
||||
self._update_servers()
|
||||
|
||||
# Stop SRV polling thread.
|
||||
if self._srv_monitor:
|
||||
self._srv_monitor.close()
|
||||
|
||||
self._opened = False
|
||||
|
||||
# Publish only after releasing the lock.
|
||||
@ -471,6 +512,11 @@ class Topology(object):
|
||||
if self._publish_tp or self._publish_server:
|
||||
self.__events_executor.open()
|
||||
|
||||
# Start the SRV polling thread.
|
||||
if self._srv_monitor and (self.description.topology_type in
|
||||
SRV_POLLING_TOPOLOGIES):
|
||||
self._srv_monitor.open()
|
||||
|
||||
# Ensure that the monitors are open.
|
||||
for server in itervalues(self._servers):
|
||||
server.open()
|
||||
|
||||
@ -24,10 +24,14 @@ from pymongo.server_selectors import Selection
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
|
||||
# Enumeration for various kinds of MongoDB cluster topologies.
|
||||
TOPOLOGY_TYPE = namedtuple('TopologyType', ['Single', 'ReplicaSetNoPrimary',
|
||||
'ReplicaSetWithPrimary', 'Sharded',
|
||||
'Unknown'])(*range(5))
|
||||
|
||||
# Topologies compatible with SRV record polling.
|
||||
SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
|
||||
|
||||
|
||||
class TopologyDescription(object):
|
||||
def __init__(self,
|
||||
@ -400,6 +404,40 @@ def updated_topology_description(topology_description, server_description):
|
||||
topology_description._topology_settings)
|
||||
|
||||
|
||||
def _updated_topology_description_srv_polling(topology_description, seedlist):
|
||||
"""Return an updated copy of a TopologyDescription.
|
||||
|
||||
:Parameters:
|
||||
- `topology_description`: the current TopologyDescription
|
||||
- `seedlist`: a list of new seeds new ServerDescription that resulted from
|
||||
an ismaster call
|
||||
"""
|
||||
# Create a copy of the server descriptions.
|
||||
sds = topology_description.server_descriptions()
|
||||
|
||||
# If seeds haven't changed, don't do anything.
|
||||
if set(sds.keys()) == set(seedlist):
|
||||
return topology_description
|
||||
|
||||
# Add SDs corresponding to servers recently added to the SRV record.
|
||||
for address in seedlist:
|
||||
if address not in sds:
|
||||
sds[address] = ServerDescription(address)
|
||||
|
||||
# Remove SDs corresponding to servers no longer part of the SRV record.
|
||||
for address in list(sds.keys()):
|
||||
if address not in seedlist:
|
||||
sds.pop(address)
|
||||
|
||||
return TopologyDescription(
|
||||
topology_description.topology_type,
|
||||
sds,
|
||||
topology_description.replica_set_name,
|
||||
topology_description.max_set_version,
|
||||
topology_description.max_election_id,
|
||||
topology_description._topology_settings)
|
||||
|
||||
|
||||
def _update_rs_from_primary(
|
||||
sds,
|
||||
replica_set_name,
|
||||
|
||||
@ -17,12 +17,6 @@
|
||||
import re
|
||||
import warnings
|
||||
|
||||
try:
|
||||
from dns import resolver
|
||||
_HAVE_DNSPYTHON = True
|
||||
except ImportError:
|
||||
_HAVE_DNSPYTHON = False
|
||||
|
||||
from bson.py3compat import string_type, PY3
|
||||
|
||||
if PY3:
|
||||
@ -34,6 +28,7 @@ from pymongo.common import (
|
||||
get_validated_options, INTERNAL_URI_OPTION_NAME_MAP,
|
||||
URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary)
|
||||
from pymongo.errors import ConfigurationError, InvalidURI
|
||||
from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver
|
||||
|
||||
|
||||
SCHEME = 'mongodb://'
|
||||
@ -325,46 +320,10 @@ def split_hosts(hosts, default_port=DEFAULT_PORT):
|
||||
# backward-compat we allow "db.collection" in URI.
|
||||
_BAD_DB_CHARS = re.compile('[' + re.escape(r'/ "$') + ']')
|
||||
|
||||
|
||||
if PY3:
|
||||
# dnspython can return bytes or str from various parts
|
||||
# of its API depending on version. We always want str.
|
||||
def maybe_decode(text):
|
||||
if isinstance(text, bytes):
|
||||
return text.decode()
|
||||
return text
|
||||
else:
|
||||
def maybe_decode(text):
|
||||
return text
|
||||
|
||||
|
||||
_ALLOWED_TXT_OPTS = frozenset(
|
||||
['authsource', 'authSource', 'replicaset', 'replicaSet'])
|
||||
|
||||
|
||||
def _get_dns_srv_hosts(hostname):
|
||||
try:
|
||||
results = resolver.query('_mongodb._tcp.' + hostname, 'SRV')
|
||||
except Exception as exc:
|
||||
raise ConfigurationError(str(exc))
|
||||
return [(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)
|
||||
for res in results]
|
||||
|
||||
|
||||
def _get_dns_txt_options(hostname):
|
||||
try:
|
||||
results = resolver.query(hostname, 'TXT')
|
||||
except (resolver.NoAnswer, resolver.NXDOMAIN):
|
||||
# No TXT records
|
||||
return None
|
||||
except Exception as exc:
|
||||
raise ConfigurationError(str(exc))
|
||||
if len(results) > 1:
|
||||
raise ConfigurationError('Only one TXT record is supported')
|
||||
return (
|
||||
b'&'.join([b''.join(res.strings) for res in results])).decode('utf-8')
|
||||
|
||||
|
||||
def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
|
||||
normalize=True):
|
||||
"""Parse and validate a MongoDB URI.
|
||||
@ -377,7 +336,8 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
|
||||
'password': <password> or None,
|
||||
'database': <database name> or None,
|
||||
'collection': <collection name> or None,
|
||||
'options': <dict of MongoDB URI options>
|
||||
'options': <dict of MongoDB URI options>,
|
||||
'fqdn': <fqdn of the MongoDB+SRV URI> or None
|
||||
}
|
||||
|
||||
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
|
||||
@ -451,6 +411,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
|
||||
" percent-encoded: %s" % host_part)
|
||||
|
||||
hosts = unquote_plus(hosts)
|
||||
fqdn = None
|
||||
|
||||
if is_srv:
|
||||
nodes = split_hosts(hosts, default_port=None)
|
||||
@ -462,24 +423,10 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
|
||||
if port is not None:
|
||||
raise InvalidURI(
|
||||
"%s URIs must not include a port number" % (SRV_SCHEME,))
|
||||
nodes = _get_dns_srv_hosts(fqdn)
|
||||
|
||||
try:
|
||||
plist = fqdn.split(".")[1:]
|
||||
except Exception:
|
||||
raise ConfigurationError("Invalid URI host")
|
||||
slen = len(plist)
|
||||
if slen < 2:
|
||||
raise ConfigurationError("Invalid URI host")
|
||||
for node in nodes:
|
||||
try:
|
||||
nlist = node[0].split(".")[1:][-slen:]
|
||||
except Exception:
|
||||
raise ConfigurationError("Invalid SRV host")
|
||||
if plist != nlist:
|
||||
raise ConfigurationError("Invalid SRV host")
|
||||
|
||||
dns_options = _get_dns_txt_options(fqdn)
|
||||
dns_resolver = _SrvResolver(fqdn)
|
||||
nodes = dns_resolver.get_hosts()
|
||||
dns_options = dns_resolver.get_options()
|
||||
if dns_options:
|
||||
options = split_options(dns_options, validate, warn, normalize)
|
||||
if set(options) - _ALLOWED_TXT_OPTS:
|
||||
@ -514,7 +461,8 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
|
||||
'password': passwd,
|
||||
'database': dbase,
|
||||
'collection': collection,
|
||||
'options': options
|
||||
'options': options,
|
||||
'fqdn': fqdn
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -22,9 +22,10 @@ import sys
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.common import validate_read_preference_tags
|
||||
from pymongo.srv_resolver import _HAVE_DNSPYTHON
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from pymongo.uri_parser import parse_uri, split_hosts, _HAVE_DNSPYTHON
|
||||
from pymongo.uri_parser import parse_uri, split_hosts
|
||||
from test import client_context, unittest
|
||||
from test.utils import wait_until
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ from pymongo.monitor import Monitor
|
||||
from pymongo.read_preferences import MovingAverage
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.topology import TOPOLOGY_TYPE
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
from test import unittest, client_context, client_knobs
|
||||
from test.utils import (ServerAndTopologyEventListener,
|
||||
single_client,
|
||||
|
||||
@ -27,7 +27,8 @@ from pymongo.topology import Topology
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import client_context, unittest, IntegrationTest
|
||||
from test.utils import rs_or_single_client, wait_until, EventListener
|
||||
from test.utils import (rs_or_single_client, wait_until, EventListener,
|
||||
FunctionCallCounter)
|
||||
from test.utils_selection_tests import (
|
||||
create_selection_tests, get_addresses, get_topology_settings_dict,
|
||||
make_server_description)
|
||||
@ -39,16 +40,6 @@ _TEST_PATH = os.path.join(
|
||||
os.path.join('server_selection', 'server_selection'))
|
||||
|
||||
|
||||
class CallCountSelector(object):
|
||||
"""No-op selector that keeps track of how many times it is called."""
|
||||
def __init__(self):
|
||||
self.call_count = 0
|
||||
|
||||
def __call__(self, servers):
|
||||
self.call_count += 1
|
||||
return servers
|
||||
|
||||
|
||||
class SelectionStoreSelector(object):
|
||||
"""No-op selector that keeps track of what was passed to it."""
|
||||
def __init__(self):
|
||||
@ -114,7 +105,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
|
||||
@client_context.require_replica_set
|
||||
def test_selector_called(self):
|
||||
selector = CallCountSelector()
|
||||
selector = FunctionCallCounter(lambda x: x)
|
||||
|
||||
# Client setup.
|
||||
mongo_client = rs_or_single_client(server_selector=selector)
|
||||
@ -178,7 +169,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
|
||||
|
||||
@client_context.require_replica_set
|
||||
def test_server_selector_bypassed(self):
|
||||
selector = CallCountSelector()
|
||||
selector = FunctionCallCounter(lambda x: x)
|
||||
|
||||
scenario_def = {
|
||||
'topology_description': {
|
||||
|
||||
201
test/test_srv_polling.py
Normal file
201
test/test_srv_polling.py
Normal file
@ -0,0 +1,201 @@
|
||||
# Copyright 2019-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the SRV support tests."""
|
||||
|
||||
import sys
|
||||
|
||||
from time import sleep
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
import pymongo
|
||||
|
||||
from pymongo import common
|
||||
from pymongo.srv_resolver import _HAVE_DNSPYTHON
|
||||
from pymongo.mongo_client import MongoClient
|
||||
from test import client_knobs, unittest
|
||||
from test.utils import wait_until, FunctionCallCounter
|
||||
|
||||
|
||||
WAIT_TIME = 0.1
|
||||
|
||||
|
||||
class SRVPollingKnobs(object):
|
||||
def __init__(self, ttl_time=None, min_srv_rescan_interval=None,
|
||||
dns_resolver_nodelist_response=None,
|
||||
count_resolver_calls=False):
|
||||
self.ttl_time = ttl_time
|
||||
self.min_srv_rescan_interval = min_srv_rescan_interval
|
||||
self.dns_resolver_nodelist_response = dns_resolver_nodelist_response
|
||||
self.count_resolver_calls = count_resolver_calls
|
||||
|
||||
self.old_min_srv_rescan_interval = None
|
||||
self.old_dns_resolver_response = None
|
||||
|
||||
def enable(self):
|
||||
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
|
||||
self.old_dns_resolver_response = \
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
|
||||
|
||||
if self.min_srv_rescan_interval is not None:
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
|
||||
|
||||
def mock_get_hosts_and_min_ttl(resolver, *args):
|
||||
nodes, ttl = self.old_dns_resolver_response(resolver)
|
||||
if self.dns_resolver_nodelist_response is not None:
|
||||
nodes = self.dns_resolver_nodelist_response()
|
||||
if self.ttl_time is not None:
|
||||
ttl = self.ttl_time
|
||||
return nodes, ttl
|
||||
|
||||
if self.count_resolver_calls:
|
||||
patch_func = FunctionCallCounter(mock_get_hosts_and_min_ttl)
|
||||
else:
|
||||
patch_func = mock_get_hosts_and_min_ttl
|
||||
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func
|
||||
|
||||
def __enter__(self):
|
||||
self.enable()
|
||||
|
||||
def disable(self):
|
||||
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \
|
||||
self.old_dns_resolver_response
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.disable()
|
||||
|
||||
|
||||
class TestSRVPolling(unittest.TestCase):
|
||||
|
||||
BASE_SRV_RESPONSE = [
|
||||
("localhost.test.build.10gen.cc", 27017),
|
||||
("localhost.test.build.10gen.cc", 27018)]
|
||||
|
||||
CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"
|
||||
|
||||
def setUp(self):
|
||||
if not _HAVE_DNSPYTHON:
|
||||
raise unittest.SkipTest("SRV polling tests require the dnspython "
|
||||
"module")
|
||||
|
||||
def get_nodelist(self, client):
|
||||
return client._topology.description.server_descriptions().keys()
|
||||
|
||||
def assert_nodelist_change(self, expected_nodelist, client):
|
||||
"""Check if the client._topology eventually sees all nodes in the
|
||||
expected_nodelist.
|
||||
"""
|
||||
def predicate():
|
||||
nodelist = self.get_nodelist(client)
|
||||
if set(expected_nodelist) == set(nodelist):
|
||||
return True
|
||||
return False
|
||||
wait_until(predicate, "see expected nodelist", timeout=10*WAIT_TIME)
|
||||
|
||||
def assert_nodelist_nochange(self, expected_nodelist, client):
|
||||
"""Check if the client._topology ever deviates from seeing all nodes
|
||||
in the expected_nodelist. Consistency is checked after sleeping for
|
||||
(WAIT_TIME * 10) seconds. Also check that the resolver is called at
|
||||
least once.
|
||||
"""
|
||||
sleep(WAIT_TIME*10)
|
||||
nodelist = self.get_nodelist(client)
|
||||
if set(expected_nodelist) != set(nodelist):
|
||||
msg = "Client nodelist %s changed unexpectedly (expected %s)"
|
||||
raise self.fail(msg % (nodelist, expected_nodelist))
|
||||
self.assertGreaterEqual(
|
||||
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,
|
||||
1, "resolver was never called")
|
||||
return True
|
||||
|
||||
def _run_scenario(self, dns_response, expect_change):
|
||||
if callable(dns_response):
|
||||
dns_resolver_response = dns_response
|
||||
else:
|
||||
def dns_resolver_response():
|
||||
return dns_response
|
||||
|
||||
if expect_change:
|
||||
assertion_method = self.assert_nodelist_change
|
||||
count_resolver_calls = False
|
||||
expected_response = dns_response
|
||||
else:
|
||||
assertion_method = self.assert_nodelist_nochange
|
||||
count_resolver_calls = True
|
||||
expected_response = self.BASE_SRV_RESPONSE
|
||||
|
||||
# Patch timeouts to ensure short test running times.
|
||||
with SRVPollingKnobs(
|
||||
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
|
||||
mc = MongoClient(self.CONNECTION_STRING)
|
||||
self.assert_nodelist_change(self.BASE_SRV_RESPONSE, mc)
|
||||
# Patch list of hosts returned by DNS query.
|
||||
with SRVPollingKnobs(
|
||||
dns_resolver_nodelist_response=dns_resolver_response,
|
||||
count_resolver_calls=count_resolver_calls):
|
||||
assertion_method(expected_response, mc)
|
||||
|
||||
def run_scenario(self, dns_response, expect_change):
|
||||
# Patch timeouts to ensure short rescan SRV interval.
|
||||
with client_knobs(heartbeat_frequency=WAIT_TIME,
|
||||
min_heartbeat_interval=WAIT_TIME,
|
||||
events_queue_frequency=WAIT_TIME):
|
||||
self._run_scenario(dns_response, expect_change)
|
||||
|
||||
def test_addition(self):
|
||||
response = self.BASE_SRV_RESPONSE[:]
|
||||
response.append(
|
||||
("localhost.test.build.10gen.cc", 27019))
|
||||
self.run_scenario(response, True)
|
||||
|
||||
def test_removal(self):
|
||||
response = self.BASE_SRV_RESPONSE[:]
|
||||
response.remove(
|
||||
("localhost.test.build.10gen.cc", 27018))
|
||||
self.run_scenario(response, True)
|
||||
|
||||
def test_replace_one(self):
|
||||
response = self.BASE_SRV_RESPONSE[:]
|
||||
response.remove(
|
||||
("localhost.test.build.10gen.cc", 27018))
|
||||
response.append(
|
||||
("localhost.test.build.10gen.cc", 27019))
|
||||
self.run_scenario(response, True)
|
||||
|
||||
def test_replace_both_with_one(self):
|
||||
response = [("localhost.test.build.10gen.cc", 27019)]
|
||||
self.run_scenario(response, True)
|
||||
|
||||
def test_replace_both_with_two(self):
|
||||
response = [("localhost.test.build.10gen.cc", 27019),
|
||||
("localhost.test.build.10gen.cc", 27020)]
|
||||
self.run_scenario(response, True)
|
||||
|
||||
def test_dns_failures(self):
|
||||
from dns import exception
|
||||
for exc in (exception.FormError, exception.TooBig, exception.Timeout):
|
||||
def response_callback(*args):
|
||||
raise exc("DNS Failure!")
|
||||
self.run_scenario(response_callback, False)
|
||||
|
||||
def test_dns_record_lookup_empty(self):
|
||||
response = []
|
||||
self.run_scenario(response, False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@ -193,7 +193,8 @@ class TestURI(unittest.TestCase):
|
||||
'password': None,
|
||||
'database': None,
|
||||
'collection': None,
|
||||
'options': {}
|
||||
'options': {},
|
||||
'fqdn': None
|
||||
}
|
||||
|
||||
res = copy.deepcopy(orig)
|
||||
@ -442,7 +443,8 @@ class TestURI(unittest.TestCase):
|
||||
'nodelist': [('/MongoDB.sock', None)],
|
||||
'options': {'ssl_certfile': '/a/b'},
|
||||
'password': 'foo/bar',
|
||||
'username': 'jesse'},
|
||||
'username': 'jesse',
|
||||
'fqdn': None},
|
||||
parse_uri(
|
||||
'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=/a/b',
|
||||
validate=False))
|
||||
@ -453,7 +455,8 @@ class TestURI(unittest.TestCase):
|
||||
'nodelist': [('/MongoDB.sock', None)],
|
||||
'options': {'ssl_certfile': 'a/b'},
|
||||
'password': 'foo/bar',
|
||||
'username': 'jesse'},
|
||||
'username': 'jesse',
|
||||
'fqdn': None},
|
||||
parse_uri(
|
||||
'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=a/b',
|
||||
validate=False))
|
||||
|
||||
@ -170,6 +170,21 @@ class CompareType(object):
|
||||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class FunctionCallCounter(object):
|
||||
"""Class that wraps a function and keeps count of invocations."""
|
||||
def __init__(self, function):
|
||||
self._function = function
|
||||
self._call_count = 0
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._call_count += 1
|
||||
return self._function(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def call_count(self):
|
||||
return self._call_count
|
||||
|
||||
|
||||
class TestCreator(object):
|
||||
"""Class to create test cases from specifications."""
|
||||
def __init__(self, create_test, test_class, test_path):
|
||||
@ -502,6 +517,7 @@ def wait_until(predicate, success_description, timeout=10):
|
||||
Returns the predicate's first true value.
|
||||
"""
|
||||
start = time.time()
|
||||
interval = min(float(timeout)/100, 0.1)
|
||||
while True:
|
||||
retval = predicate()
|
||||
if retval:
|
||||
@ -510,7 +526,7 @@ def wait_until(predicate, success_description, timeout=10):
|
||||
if time.time() - start > timeout:
|
||||
raise AssertionError("Didn't ever %s" % success_description)
|
||||
|
||||
time.sleep(0.1)
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
def is_mongos(client):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user