From f85a9f9450a9784c048a20f6d39948b3e6326e71 Mon Sep 17 00:00:00 2001 From: Prashant Mital Date: Tue, 21 May 2019 15:37:35 -0700 Subject: [PATCH] PYTHON-1675 SRV polling for mongos discovery --- doc/changelog.rst | 2 + pymongo/common.py | 3 + pymongo/mongo_client.py | 5 +- pymongo/monitor.py | 116 +++++++++++++---- pymongo/periodic_executor.py | 3 + pymongo/settings.py | 10 +- pymongo/srv_resolver.py | 103 +++++++++++++++ pymongo/topology.py | 50 +++++++- pymongo/topology_description.py | 38 ++++++ pymongo/uri_parser.py | 70 ++--------- test/test_dns.py | 3 +- test/test_sdam_monitoring_spec.py | 2 +- test/test_server_selection.py | 17 +-- test/test_srv_polling.py | 201 ++++++++++++++++++++++++++++++ test/test_uri_parser.py | 9 +- test/utils.py | 18 ++- 16 files changed, 544 insertions(+), 106 deletions(-) create mode 100644 pymongo/srv_resolver.py create mode 100644 test/test_srv_polling.py diff --git a/doc/changelog.rst b/doc/changelog.rst index cd9615e11..dc9895364 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -39,6 +39,8 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include: enabled by default. See the :class:`~pymongo.mongo_client.MongoClient` documentation for details. - Support zstandard for wire protocol compression. +- Support for periodically polling DNS SRV records to update the mongos proxy + list without having to change client configuration. Now that supported operations are retried automatically and transparently, users should consider adjusting any custom retry logic to prevent diff --git a/pymongo/common.py b/pymongo/common.py index 7f4334264..62ef6761c 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -73,6 +73,9 @@ SERVER_SELECTION_TIMEOUT = 30 # Spec requires at least 500ms between ismaster calls. MIN_HEARTBEAT_INTERVAL = 0.5 +# Spec requires at least 60s between SRV rescans. +MIN_SRV_RESCAN_INTERVAL = 60 + # Default connectTimeout in seconds. CONNECT_TIMEOUT = 20.0 diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 034aa8a54..62e2f6ff7 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -585,6 +585,7 @@ class MongoClient(common.BaseObject): password = None dbase = None opts = {} + fqdn = None for entity in host: if "://" in entity: res = uri_parser.parse_uri( @@ -594,6 +595,7 @@ class MongoClient(common.BaseObject): password = res["password"] or password dbase = res["database"] or dbase opts = res["options"] + fqdn = res["fqdn"] else: seeds.update(uri_parser.split_hosts(entity, port)) if not seeds: @@ -673,7 +675,8 @@ class MongoClient(common.BaseObject): local_threshold_ms=options.local_threshold_ms, server_selection_timeout=options.server_selection_timeout, server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency) + heartbeat_frequency=options.heartbeat_frequency, + fqdn=fqdn) self._topology = Topology(self._topology_settings) if connect: diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 4c0f3182e..d77d4e450 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -18,13 +18,42 @@ import weakref from pymongo import common, periodic_executor from pymongo.errors import OperationFailure -from pymongo.server_type import SERVER_TYPE from pymongo.monotonic import time as _time from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription +from pymongo.server_type import SERVER_TYPE +from pymongo.srv_resolver import _SrvResolver -class Monitor(object): +class MonitorBase(object): + def __init__(self, *args, **kwargs): + """Override this method to create an executor.""" + raise NotImplementedError + + def open(self): + """Start monitoring, or restart after a fork. + + Multiple calls have no effect. + """ + self._executor.open() + + def close(self): + """Close and stop monitoring. + + open() restarts the monitor after closing. + """ + self._executor.close() + + def join(self, timeout=None): + """Wait for the monitor to stop.""" + self._executor.join(timeout) + + def request_check(self): + """If the monitor is sleeping, wake it soon.""" + self._executor.wake() + + +class Monitor(MonitorBase): def __init__( self, server_description, @@ -68,31 +97,13 @@ class Monitor(object): self_ref = weakref.ref(self, executor.close) self._topology = weakref.proxy(topology, executor.close) - def open(self): - """Start monitoring, or restart after a fork. - - Multiple calls have no effect. - """ - self._executor.open() - def close(self): - """Close and stop monitoring. - - open() restarts the monitor after closing. - """ - self._executor.close() + super(Monitor, self).close() # Increment the pool_id and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. self._pool.reset() - def join(self, timeout=None): - self._executor.join(timeout) - - def request_check(self): - """If the monitor is sleeping, wake and check the server soon.""" - self._executor.wake() - def _run(self): try: self._server_description = self._check_with_retry() @@ -182,3 +193,66 @@ class Monitor(object): self._topology.receive_cluster_time( exc.details.get('$clusterTime')) raise + + +class SrvMonitor(MonitorBase): + def __init__(self, topology, topology_settings): + """Class to poll SRV records on a background thread. + + Pass a Topology and a TopologySettings. + + The Topology is weakly referenced. + """ + self._settings = topology_settings + self._fqdn = self._settings.fqdn + + # We strongly reference the executor and it weakly references us via + # this closure. When the monitor is freed, stop the executor soon. + def target(): + monitor = self_ref() + if monitor is None: + return False # Stop the executor. + SrvMonitor._run(monitor) + return True + + executor = periodic_executor.PeriodicExecutor( + interval=common.MIN_SRV_RESCAN_INTERVAL, + min_interval=self._settings.heartbeat_frequency, + target=target, + name="pymongo_srv_polling_thread") + + self._executor = executor + + # Avoid cycles. When self or topology is freed, stop executor soon. + self_ref = weakref.ref(self, executor.close) + self._topology = weakref.proxy(topology, executor.close) + + def _run(self): + try: + self._seedlist = self._get_seedlist() + self._topology.on_srv_update(self._seedlist) + except ReferenceError: + # Topology was garbage-collected. + self.close() + + def _get_seedlist(self): + """Poll SRV records for a seedlist. + + Returns a list of ServerDescriptions. + """ + try: + seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl() + if len(seedlist) == 0: + # As per the spec: this should be treated as a failure. + raise Exception + except Exception: + # As per the spec, upon encountering an error: + # - An error must not be raised + # - SRV records must be rescanned every heartbeatFrequencyMS + # - Topology must be left unchanged + self.request_check() + return self._seedlist + else: + self._executor.update_interval( + max(ttl, common.MIN_SRV_RESCAN_INTERVAL)) + return seedlist diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 2a325446e..ba9664fa7 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -102,6 +102,9 @@ class PeriodicExecutor(object): """Execute the target function soon.""" self._event = True + def update_interval(self, new_interval): + self._interval = new_interval + def __should_stop(self): with self._lock: if self._stopped: diff --git a/pymongo/settings.py b/pymongo/settings.py index a03bbc98a..2a02f05d5 100644 --- a/pymongo/settings.py +++ b/pymongo/settings.py @@ -20,9 +20,9 @@ from bson.objectid import ObjectId from pymongo import common, monitor, pool from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT from pymongo.errors import ConfigurationError -from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.pool import PoolOptions from pymongo.server_description import ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE class TopologySettings(object): @@ -36,7 +36,8 @@ class TopologySettings(object): local_threshold_ms=LOCAL_THRESHOLD_MS, server_selection_timeout=SERVER_SELECTION_TIMEOUT, heartbeat_frequency=common.HEARTBEAT_FREQUENCY, - server_selector=None): + server_selector=None, + fqdn=None): """Represent MongoClient's configuration. Take a list of (host, port) pairs and optional replica set name. @@ -55,6 +56,7 @@ class TopologySettings(object): self._local_threshold_ms = local_threshold_ms self._server_selection_timeout = server_selection_timeout self._server_selector = server_selector + self._fqdn = fqdn self._heartbeat_frequency = heartbeat_frequency self._direct = (len(self._seeds) == 1 and not replica_set_name) self._topology_id = ObjectId() @@ -100,6 +102,10 @@ class TopologySettings(object): def heartbeat_frequency(self): return self._heartbeat_frequency + @property + def fqdn(self): + return self._fqdn + @property def direct(self): """Connect directly to a single server, or use a set of servers? diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py new file mode 100644 index 000000000..a5e4b9bb4 --- /dev/null +++ b/pymongo/srv_resolver.py @@ -0,0 +1,103 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Support for resolving hosts and options from mongodb+srv:// URIs.""" + +try: + from dns import resolver + _HAVE_DNSPYTHON = True +except ImportError: + _HAVE_DNSPYTHON = False + +from bson.py3compat import PY3 + +from pymongo.errors import ConfigurationError + + +if PY3: + # dnspython can return bytes or str from various parts + # of its API depending on version. We always want str. + def maybe_decode(text): + if isinstance(text, bytes): + return text.decode() + return text +else: + def maybe_decode(text): + return text + + +class _SrvResolver(object): + def __init__(self, fqdn): + self.__fqdn = fqdn + + # Validate the fully qualified domain name. + try: + self.__plist = self.__fqdn.split(".")[1:] + except Exception: + raise ConfigurationError("Invalid URI host") + self.__slen = len(self.__plist) + if self.__slen < 2: + raise ConfigurationError("Invalid URI host") + + def get_options(self): + try: + results = resolver.query(self.__fqdn, 'TXT') + except (resolver.NoAnswer, resolver.NXDOMAIN): + # No TXT records + return None + except Exception as exc: + raise ConfigurationError(str(exc)) + if len(results) > 1: + raise ConfigurationError('Only one TXT record is supported') + return ( + b'&'.join([b''.join(res.strings) for res in results])).decode( + 'utf-8') + + def _resolve_uri(self, encapsulate_errors): + try: + results = resolver.query('_mongodb._tcp.' + self.__fqdn, 'SRV') + except Exception as exc: + if not encapsulate_errors: + # Raise the original error. + raise + # Else, raise all errors as ConfigurationError. + raise ConfigurationError(str(exc)) + return results + + def _get_srv_response_and_hosts(self, encapsulate_errors): + results = self._resolve_uri(encapsulate_errors) + + # Construct address tuples + nodes = [ + (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) + for res in results] + + # Validate hosts + for node in nodes: + try: + nlist = node[0].split(".")[1:][-self.__slen:] + except Exception: + raise ConfigurationError("Invalid SRV host") + if self.__plist != nlist: + raise ConfigurationError("Invalid SRV host") + + return results, nodes + + def get_hosts(self): + _, nodes = self._get_srv_response_and_hosts(True) + return nodes + + def get_hosts_and_min_ttl(self): + results, nodes = self._get_srv_response_and_hosts(False) + return nodes, results.rrset.ttl diff --git a/pymongo/topology.py b/pymongo/topology.py index c69c127ff..eecc4a7cf 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -30,9 +30,11 @@ from pymongo import common from pymongo import periodic_executor from pymongo.pool import PoolOptions from pymongo.topology_description import (updated_topology_description, - TOPOLOGY_TYPE, - TopologyDescription) + _updated_topology_description_srv_polling, + TopologyDescription, + SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE) from pymongo.errors import ServerSelectionTimeoutError, ConfigurationError +from pymongo.monitor import SrvMonitor from pymongo.monotonic import time as _time from pymongo.server import Server from pymongo.server_selectors import (any_server_selector, @@ -129,6 +131,10 @@ class Topology(object): self.__events_executor = executor executor.open() + self._srv_monitor = None + if self._settings.fqdn is not None: + self._srv_monitor = SrvMonitor(self, self._settings) + def open(self): """Start monitoring, or restart after a fork. @@ -272,6 +278,14 @@ class Topology(object): self._listeners.publish_topology_description_changed, (td_old, self._description, self._topology_id))) + # Shutdown SRV polling for unsupported cluster types. + # This is only applicable if the old topology was Unknown, and the + # new one is something other than Unknown or Sharded. + if self._srv_monitor and (td_old.topology_type == TOPOLOGY_TYPE.Unknown + and self._description.topology_type not in + SRV_POLLING_TOPOLOGIES): + self._srv_monitor.close() + # Wake waiters in select_servers(). self._condition.notify_all() @@ -291,6 +305,28 @@ class Topology(object): self._description.has_server(server_description.address)): self._process_change(server_description) + def _process_srv_update(self, seedlist): + """Process a new seedlist on an opened topology. + Hold the lock when calling this. + """ + td_old = self._description + self._description = _updated_topology_description_srv_polling( + self._description, seedlist) + + self._update_servers() + + if self._publish_tp: + self._events.put(( + self._listeners.publish_topology_description_changed, + (td_old, self._description, self._topology_id))) + + def on_srv_update(self, seedlist): + """Process a new list of nodes obtained from scanning SRV records.""" + # We do no I/O holding the lock. + with self._lock: + if self._opened: + self._process_srv_update(seedlist) + def get_server_by_address(self, address): """Get a Server or None. @@ -396,6 +432,11 @@ class Topology(object): # Mark all servers Unknown. self._description = self._description.reset() self._update_servers() + + # Stop SRV polling thread. + if self._srv_monitor: + self._srv_monitor.close() + self._opened = False # Publish only after releasing the lock. @@ -471,6 +512,11 @@ class Topology(object): if self._publish_tp or self._publish_server: self.__events_executor.open() + # Start the SRV polling thread. + if self._srv_monitor and (self.description.topology_type in + SRV_POLLING_TOPOLOGIES): + self._srv_monitor.open() + # Ensure that the monitors are open. for server in itervalues(self._servers): server.open() diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index eca761520..b3b912f8d 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -24,10 +24,14 @@ from pymongo.server_selectors import Selection from pymongo.server_type import SERVER_TYPE +# Enumeration for various kinds of MongoDB cluster topologies. TOPOLOGY_TYPE = namedtuple('TopologyType', ['Single', 'ReplicaSetNoPrimary', 'ReplicaSetWithPrimary', 'Sharded', 'Unknown'])(*range(5)) +# Topologies compatible with SRV record polling. +SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) + class TopologyDescription(object): def __init__(self, @@ -400,6 +404,40 @@ def updated_topology_description(topology_description, server_description): topology_description._topology_settings) +def _updated_topology_description_srv_polling(topology_description, seedlist): + """Return an updated copy of a TopologyDescription. + + :Parameters: + - `topology_description`: the current TopologyDescription + - `seedlist`: a list of new seeds new ServerDescription that resulted from + an ismaster call + """ + # Create a copy of the server descriptions. + sds = topology_description.server_descriptions() + + # If seeds haven't changed, don't do anything. + if set(sds.keys()) == set(seedlist): + return topology_description + + # Add SDs corresponding to servers recently added to the SRV record. + for address in seedlist: + if address not in sds: + sds[address] = ServerDescription(address) + + # Remove SDs corresponding to servers no longer part of the SRV record. + for address in list(sds.keys()): + if address not in seedlist: + sds.pop(address) + + return TopologyDescription( + topology_description.topology_type, + sds, + topology_description.replica_set_name, + topology_description.max_set_version, + topology_description.max_election_id, + topology_description._topology_settings) + + def _update_rs_from_primary( sds, replica_set_name, diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index e44a86023..0834c3bf8 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -17,12 +17,6 @@ import re import warnings -try: - from dns import resolver - _HAVE_DNSPYTHON = True -except ImportError: - _HAVE_DNSPYTHON = False - from bson.py3compat import string_type, PY3 if PY3: @@ -34,6 +28,7 @@ from pymongo.common import ( get_validated_options, INTERNAL_URI_OPTION_NAME_MAP, URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary) from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver SCHEME = 'mongodb://' @@ -325,46 +320,10 @@ def split_hosts(hosts, default_port=DEFAULT_PORT): # backward-compat we allow "db.collection" in URI. _BAD_DB_CHARS = re.compile('[' + re.escape(r'/ "$') + ']') - -if PY3: - # dnspython can return bytes or str from various parts - # of its API depending on version. We always want str. - def maybe_decode(text): - if isinstance(text, bytes): - return text.decode() - return text -else: - def maybe_decode(text): - return text - - _ALLOWED_TXT_OPTS = frozenset( ['authsource', 'authSource', 'replicaset', 'replicaSet']) -def _get_dns_srv_hosts(hostname): - try: - results = resolver.query('_mongodb._tcp.' + hostname, 'SRV') - except Exception as exc: - raise ConfigurationError(str(exc)) - return [(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) - for res in results] - - -def _get_dns_txt_options(hostname): - try: - results = resolver.query(hostname, 'TXT') - except (resolver.NoAnswer, resolver.NXDOMAIN): - # No TXT records - return None - except Exception as exc: - raise ConfigurationError(str(exc)) - if len(results) > 1: - raise ConfigurationError('Only one TXT record is supported') - return ( - b'&'.join([b''.join(res.strings) for res in results])).decode('utf-8') - - def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, normalize=True): """Parse and validate a MongoDB URI. @@ -377,7 +336,8 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, 'password': or None, 'database': or None, 'collection': or None, - 'options': + 'options': , + 'fqdn': or None } If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done @@ -451,6 +411,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, " percent-encoded: %s" % host_part) hosts = unquote_plus(hosts) + fqdn = None if is_srv: nodes = split_hosts(hosts, default_port=None) @@ -462,24 +423,10 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, if port is not None: raise InvalidURI( "%s URIs must not include a port number" % (SRV_SCHEME,)) - nodes = _get_dns_srv_hosts(fqdn) - try: - plist = fqdn.split(".")[1:] - except Exception: - raise ConfigurationError("Invalid URI host") - slen = len(plist) - if slen < 2: - raise ConfigurationError("Invalid URI host") - for node in nodes: - try: - nlist = node[0].split(".")[1:][-slen:] - except Exception: - raise ConfigurationError("Invalid SRV host") - if plist != nlist: - raise ConfigurationError("Invalid SRV host") - - dns_options = _get_dns_txt_options(fqdn) + dns_resolver = _SrvResolver(fqdn) + nodes = dns_resolver.get_hosts() + dns_options = dns_resolver.get_options() if dns_options: options = split_options(dns_options, validate, warn, normalize) if set(options) - _ALLOWED_TXT_OPTS: @@ -514,7 +461,8 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, 'password': passwd, 'database': dbase, 'collection': collection, - 'options': options + 'options': options, + 'fqdn': fqdn } diff --git a/test/test_dns.py b/test/test_dns.py index dc659be0a..a80664442 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -22,9 +22,10 @@ import sys sys.path[0:0] = [""] from pymongo.common import validate_read_preference_tags +from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.errors import ConfigurationError from pymongo.mongo_client import MongoClient -from pymongo.uri_parser import parse_uri, split_hosts, _HAVE_DNSPYTHON +from pymongo.uri_parser import parse_uri, split_hosts from test import client_context, unittest from test.utils import wait_until diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index a26880220..a84f2f7f7 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -29,7 +29,7 @@ from pymongo.monitor import Monitor from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription from pymongo.server_type import SERVER_TYPE -from pymongo.topology import TOPOLOGY_TYPE +from pymongo.topology_description import TOPOLOGY_TYPE from test import unittest, client_context, client_knobs from test.utils import (ServerAndTopologyEventListener, single_client, diff --git a/test/test_server_selection.py b/test/test_server_selection.py index ff2729dc7..a86239791 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -27,7 +27,8 @@ from pymongo.topology import Topology sys.path[0:0] = [""] from test import client_context, unittest, IntegrationTest -from test.utils import rs_or_single_client, wait_until, EventListener +from test.utils import (rs_or_single_client, wait_until, EventListener, + FunctionCallCounter) from test.utils_selection_tests import ( create_selection_tests, get_addresses, get_topology_settings_dict, make_server_description) @@ -39,16 +40,6 @@ _TEST_PATH = os.path.join( os.path.join('server_selection', 'server_selection')) -class CallCountSelector(object): - """No-op selector that keeps track of how many times it is called.""" - def __init__(self): - self.call_count = 0 - - def __call__(self, servers): - self.call_count += 1 - return servers - - class SelectionStoreSelector(object): """No-op selector that keeps track of what was passed to it.""" def __init__(self): @@ -114,7 +105,7 @@ class TestCustomServerSelectorFunction(IntegrationTest): @client_context.require_replica_set def test_selector_called(self): - selector = CallCountSelector() + selector = FunctionCallCounter(lambda x: x) # Client setup. mongo_client = rs_or_single_client(server_selector=selector) @@ -178,7 +169,7 @@ class TestCustomServerSelectorFunction(IntegrationTest): @client_context.require_replica_set def test_server_selector_bypassed(self): - selector = CallCountSelector() + selector = FunctionCallCounter(lambda x: x) scenario_def = { 'topology_description': { diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py new file mode 100644 index 000000000..6794e8e54 --- /dev/null +++ b/test/test_srv_polling.py @@ -0,0 +1,201 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the SRV support tests.""" + +import sys + +from time import sleep + +sys.path[0:0] = [""] + +import pymongo + +from pymongo import common +from pymongo.srv_resolver import _HAVE_DNSPYTHON +from pymongo.mongo_client import MongoClient +from test import client_knobs, unittest +from test.utils import wait_until, FunctionCallCounter + + +WAIT_TIME = 0.1 + + +class SRVPollingKnobs(object): + def __init__(self, ttl_time=None, min_srv_rescan_interval=None, + dns_resolver_nodelist_response=None, + count_resolver_calls=False): + self.ttl_time = ttl_time + self.min_srv_rescan_interval = min_srv_rescan_interval + self.dns_resolver_nodelist_response = dns_resolver_nodelist_response + self.count_resolver_calls = count_resolver_calls + + self.old_min_srv_rescan_interval = None + self.old_dns_resolver_response = None + + def enable(self): + self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL + self.old_dns_resolver_response = \ + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + + if self.min_srv_rescan_interval is not None: + common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval + + def mock_get_hosts_and_min_ttl(resolver, *args): + nodes, ttl = self.old_dns_resolver_response(resolver) + if self.dns_resolver_nodelist_response is not None: + nodes = self.dns_resolver_nodelist_response() + if self.ttl_time is not None: + ttl = self.ttl_time + return nodes, ttl + + if self.count_resolver_calls: + patch_func = FunctionCallCounter(mock_get_hosts_and_min_ttl) + else: + patch_func = mock_get_hosts_and_min_ttl + + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func + + def __enter__(self): + self.enable() + + def disable(self): + common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \ + self.old_dns_resolver_response + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + +class TestSRVPolling(unittest.TestCase): + + BASE_SRV_RESPONSE = [ + ("localhost.test.build.10gen.cc", 27017), + ("localhost.test.build.10gen.cc", 27018)] + + CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc" + + def setUp(self): + if not _HAVE_DNSPYTHON: + raise unittest.SkipTest("SRV polling tests require the dnspython " + "module") + + def get_nodelist(self, client): + return client._topology.description.server_descriptions().keys() + + def assert_nodelist_change(self, expected_nodelist, client): + """Check if the client._topology eventually sees all nodes in the + expected_nodelist. + """ + def predicate(): + nodelist = self.get_nodelist(client) + if set(expected_nodelist) == set(nodelist): + return True + return False + wait_until(predicate, "see expected nodelist", timeout=10*WAIT_TIME) + + def assert_nodelist_nochange(self, expected_nodelist, client): + """Check if the client._topology ever deviates from seeing all nodes + in the expected_nodelist. Consistency is checked after sleeping for + (WAIT_TIME * 10) seconds. Also check that the resolver is called at + least once. + """ + sleep(WAIT_TIME*10) + nodelist = self.get_nodelist(client) + if set(expected_nodelist) != set(nodelist): + msg = "Client nodelist %s changed unexpectedly (expected %s)" + raise self.fail(msg % (nodelist, expected_nodelist)) + self.assertGreaterEqual( + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, + 1, "resolver was never called") + return True + + def _run_scenario(self, dns_response, expect_change): + if callable(dns_response): + dns_resolver_response = dns_response + else: + def dns_resolver_response(): + return dns_response + + if expect_change: + assertion_method = self.assert_nodelist_change + count_resolver_calls = False + expected_response = dns_response + else: + assertion_method = self.assert_nodelist_nochange + count_resolver_calls = True + expected_response = self.BASE_SRV_RESPONSE + + # Patch timeouts to ensure short test running times. + with SRVPollingKnobs( + ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + mc = MongoClient(self.CONNECTION_STRING) + self.assert_nodelist_change(self.BASE_SRV_RESPONSE, mc) + # Patch list of hosts returned by DNS query. + with SRVPollingKnobs( + dns_resolver_nodelist_response=dns_resolver_response, + count_resolver_calls=count_resolver_calls): + assertion_method(expected_response, mc) + + def run_scenario(self, dns_response, expect_change): + # Patch timeouts to ensure short rescan SRV interval. + with client_knobs(heartbeat_frequency=WAIT_TIME, + min_heartbeat_interval=WAIT_TIME, + events_queue_frequency=WAIT_TIME): + self._run_scenario(dns_response, expect_change) + + def test_addition(self): + response = self.BASE_SRV_RESPONSE[:] + response.append( + ("localhost.test.build.10gen.cc", 27019)) + self.run_scenario(response, True) + + def test_removal(self): + response = self.BASE_SRV_RESPONSE[:] + response.remove( + ("localhost.test.build.10gen.cc", 27018)) + self.run_scenario(response, True) + + def test_replace_one(self): + response = self.BASE_SRV_RESPONSE[:] + response.remove( + ("localhost.test.build.10gen.cc", 27018)) + response.append( + ("localhost.test.build.10gen.cc", 27019)) + self.run_scenario(response, True) + + def test_replace_both_with_one(self): + response = [("localhost.test.build.10gen.cc", 27019)] + self.run_scenario(response, True) + + def test_replace_both_with_two(self): + response = [("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020)] + self.run_scenario(response, True) + + def test_dns_failures(self): + from dns import exception + for exc in (exception.FormError, exception.TooBig, exception.Timeout): + def response_callback(*args): + raise exc("DNS Failure!") + self.run_scenario(response_callback, False) + + def test_dns_record_lookup_empty(self): + response = [] + self.run_scenario(response, False) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index 870953595..8f626633f 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -193,7 +193,8 @@ class TestURI(unittest.TestCase): 'password': None, 'database': None, 'collection': None, - 'options': {} + 'options': {}, + 'fqdn': None } res = copy.deepcopy(orig) @@ -442,7 +443,8 @@ class TestURI(unittest.TestCase): 'nodelist': [('/MongoDB.sock', None)], 'options': {'ssl_certfile': '/a/b'}, 'password': 'foo/bar', - 'username': 'jesse'}, + 'username': 'jesse', + 'fqdn': None}, parse_uri( 'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=/a/b', validate=False)) @@ -453,7 +455,8 @@ class TestURI(unittest.TestCase): 'nodelist': [('/MongoDB.sock', None)], 'options': {'ssl_certfile': 'a/b'}, 'password': 'foo/bar', - 'username': 'jesse'}, + 'username': 'jesse', + 'fqdn': None}, parse_uri( 'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=a/b', validate=False)) diff --git a/test/utils.py b/test/utils.py index a197364fc..fc8042a4e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -170,6 +170,21 @@ class CompareType(object): return not self.__eq__(other) +class FunctionCallCounter(object): + """Class that wraps a function and keeps count of invocations.""" + def __init__(self, function): + self._function = function + self._call_count = 0 + + def __call__(self, *args, **kwargs): + self._call_count += 1 + return self._function(*args, **kwargs) + + @property + def call_count(self): + return self._call_count + + class TestCreator(object): """Class to create test cases from specifications.""" def __init__(self, create_test, test_class, test_path): @@ -502,6 +517,7 @@ def wait_until(predicate, success_description, timeout=10): Returns the predicate's first true value. """ start = time.time() + interval = min(float(timeout)/100, 0.1) while True: retval = predicate() if retval: @@ -510,7 +526,7 @@ def wait_until(predicate, success_description, timeout=10): if time.time() - start > timeout: raise AssertionError("Didn't ever %s" % success_description) - time.sleep(0.1) + time.sleep(interval) def is_mongos(client):