PYTHON-1675 SRV polling for mongos discovery

This commit is contained in:
Prashant Mital 2019-05-21 15:37:35 -07:00
parent afbf18b0ad
commit f85a9f9450
No known key found for this signature in database
GPG Key ID: 3D2DAA9E483ABE51
16 changed files with 544 additions and 106 deletions

View File

@ -39,6 +39,8 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include:
enabled by default. See the :class:`~pymongo.mongo_client.MongoClient`
documentation for details.
- Support zstandard for wire protocol compression.
- Support for periodically polling DNS SRV records to update the mongos proxy
list without having to change client configuration.
Now that supported operations are retried automatically and transparently,
users should consider adjusting any custom retry logic to prevent

View File

@ -73,6 +73,9 @@ SERVER_SELECTION_TIMEOUT = 30
# Spec requires at least 500ms between ismaster calls.
MIN_HEARTBEAT_INTERVAL = 0.5
# Spec requires at least 60s between SRV rescans.
MIN_SRV_RESCAN_INTERVAL = 60
# Default connectTimeout in seconds.
CONNECT_TIMEOUT = 20.0

View File

@ -585,6 +585,7 @@ class MongoClient(common.BaseObject):
password = None
dbase = None
opts = {}
fqdn = None
for entity in host:
if "://" in entity:
res = uri_parser.parse_uri(
@ -594,6 +595,7 @@ class MongoClient(common.BaseObject):
password = res["password"] or password
dbase = res["database"] or dbase
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(uri_parser.split_hosts(entity, port))
if not seeds:
@ -673,7 +675,8 @@ class MongoClient(common.BaseObject):
local_threshold_ms=options.local_threshold_ms,
server_selection_timeout=options.server_selection_timeout,
server_selector=options.server_selector,
heartbeat_frequency=options.heartbeat_frequency)
heartbeat_frequency=options.heartbeat_frequency,
fqdn=fqdn)
self._topology = Topology(self._topology_settings)
if connect:

View File

@ -18,13 +18,42 @@ import weakref
from pymongo import common, periodic_executor
from pymongo.errors import OperationFailure
from pymongo.server_type import SERVER_TYPE
from pymongo.monotonic import time as _time
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.server_type import SERVER_TYPE
from pymongo.srv_resolver import _SrvResolver
class Monitor(object):
class MonitorBase(object):
def __init__(self, *args, **kwargs):
"""Override this method to create an executor."""
raise NotImplementedError
def open(self):
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
self._executor.open()
def close(self):
"""Close and stop monitoring.
open() restarts the monitor after closing.
"""
self._executor.close()
def join(self, timeout=None):
"""Wait for the monitor to stop."""
self._executor.join(timeout)
def request_check(self):
"""If the monitor is sleeping, wake it soon."""
self._executor.wake()
class Monitor(MonitorBase):
def __init__(
self,
server_description,
@ -68,31 +97,13 @@ class Monitor(object):
self_ref = weakref.ref(self, executor.close)
self._topology = weakref.proxy(topology, executor.close)
def open(self):
"""Start monitoring, or restart after a fork.
Multiple calls have no effect.
"""
self._executor.open()
def close(self):
"""Close and stop monitoring.
open() restarts the monitor after closing.
"""
self._executor.close()
super(Monitor, self).close()
# Increment the pool_id and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
self._pool.reset()
def join(self, timeout=None):
self._executor.join(timeout)
def request_check(self):
"""If the monitor is sleeping, wake and check the server soon."""
self._executor.wake()
def _run(self):
try:
self._server_description = self._check_with_retry()
@ -182,3 +193,66 @@ class Monitor(object):
self._topology.receive_cluster_time(
exc.details.get('$clusterTime'))
raise
class SrvMonitor(MonitorBase):
def __init__(self, topology, topology_settings):
"""Class to poll SRV records on a background thread.
Pass a Topology and a TopologySettings.
The Topology is weakly referenced.
"""
self._settings = topology_settings
self._fqdn = self._settings.fqdn
# We strongly reference the executor and it weakly references us via
# this closure. When the monitor is freed, stop the executor soon.
def target():
monitor = self_ref()
if monitor is None:
return False # Stop the executor.
SrvMonitor._run(monitor)
return True
executor = periodic_executor.PeriodicExecutor(
interval=common.MIN_SRV_RESCAN_INTERVAL,
min_interval=self._settings.heartbeat_frequency,
target=target,
name="pymongo_srv_polling_thread")
self._executor = executor
# Avoid cycles. When self or topology is freed, stop executor soon.
self_ref = weakref.ref(self, executor.close)
self._topology = weakref.proxy(topology, executor.close)
def _run(self):
try:
self._seedlist = self._get_seedlist()
self._topology.on_srv_update(self._seedlist)
except ReferenceError:
# Topology was garbage-collected.
self.close()
def _get_seedlist(self):
"""Poll SRV records for a seedlist.
Returns a list of ServerDescriptions.
"""
try:
seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl()
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
# - SRV records must be rescanned every heartbeatFrequencyMS
# - Topology must be left unchanged
self.request_check()
return self._seedlist
else:
self._executor.update_interval(
max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
return seedlist

View File

@ -102,6 +102,9 @@ class PeriodicExecutor(object):
"""Execute the target function soon."""
self._event = True
def update_interval(self, new_interval):
self._interval = new_interval
def __should_stop(self):
with self._lock:
if self._stopped:

View File

@ -20,9 +20,9 @@ from bson.objectid import ObjectId
from pymongo import common, monitor, pool
from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT
from pymongo.errors import ConfigurationError
from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.pool import PoolOptions
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TOPOLOGY_TYPE
class TopologySettings(object):
@ -36,7 +36,8 @@ class TopologySettings(object):
local_threshold_ms=LOCAL_THRESHOLD_MS,
server_selection_timeout=SERVER_SELECTION_TIMEOUT,
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
server_selector=None):
server_selector=None,
fqdn=None):
"""Represent MongoClient's configuration.
Take a list of (host, port) pairs and optional replica set name.
@ -55,6 +56,7 @@ class TopologySettings(object):
self._local_threshold_ms = local_threshold_ms
self._server_selection_timeout = server_selection_timeout
self._server_selector = server_selector
self._fqdn = fqdn
self._heartbeat_frequency = heartbeat_frequency
self._direct = (len(self._seeds) == 1 and not replica_set_name)
self._topology_id = ObjectId()
@ -100,6 +102,10 @@ class TopologySettings(object):
def heartbeat_frequency(self):
return self._heartbeat_frequency
@property
def fqdn(self):
return self._fqdn
@property
def direct(self):
"""Connect directly to a single server, or use a set of servers?

103
pymongo/srv_resolver.py Normal file
View File

@ -0,0 +1,103 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
try:
from dns import resolver
_HAVE_DNSPYTHON = True
except ImportError:
_HAVE_DNSPYTHON = False
from bson.py3compat import PY3
from pymongo.errors import ConfigurationError
if PY3:
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text):
if isinstance(text, bytes):
return text.decode()
return text
else:
def maybe_decode(text):
return text
class _SrvResolver(object):
def __init__(self, fqdn):
self.__fqdn = fqdn
# Validate the fully qualified domain name.
try:
self.__plist = self.__fqdn.split(".")[1:]
except Exception:
raise ConfigurationError("Invalid URI host")
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError("Invalid URI host")
def get_options(self):
try:
results = resolver.query(self.__fqdn, 'TXT')
except (resolver.NoAnswer, resolver.NXDOMAIN):
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc))
if len(results) > 1:
raise ConfigurationError('Only one TXT record is supported')
return (
b'&'.join([b''.join(res.strings) for res in results])).decode(
'utf-8')
def _resolve_uri(self, encapsulate_errors):
try:
results = resolver.query('_mongodb._tcp.' + self.__fqdn, 'SRV')
except Exception as exc:
if not encapsulate_errors:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc))
return results
def _get_srv_response_and_hosts(self, encapsulate_errors):
results = self._resolve_uri(encapsulate_errors)
# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)
for res in results]
# Validate hosts
for node in nodes:
try:
nlist = node[0].split(".")[1:][-self.__slen:]
except Exception:
raise ConfigurationError("Invalid SRV host")
if self.__plist != nlist:
raise ConfigurationError("Invalid SRV host")
return results, nodes
def get_hosts(self):
_, nodes = self._get_srv_response_and_hosts(True)
return nodes
def get_hosts_and_min_ttl(self):
results, nodes = self._get_srv_response_and_hosts(False)
return nodes, results.rrset.ttl

View File

@ -30,9 +30,11 @@ from pymongo import common
from pymongo import periodic_executor
from pymongo.pool import PoolOptions
from pymongo.topology_description import (updated_topology_description,
TOPOLOGY_TYPE,
TopologyDescription)
_updated_topology_description_srv_polling,
TopologyDescription,
SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE)
from pymongo.errors import ServerSelectionTimeoutError, ConfigurationError
from pymongo.monitor import SrvMonitor
from pymongo.monotonic import time as _time
from pymongo.server import Server
from pymongo.server_selectors import (any_server_selector,
@ -129,6 +131,10 @@ class Topology(object):
self.__events_executor = executor
executor.open()
self._srv_monitor = None
if self._settings.fqdn is not None:
self._srv_monitor = SrvMonitor(self, self._settings)
def open(self):
"""Start monitoring, or restart after a fork.
@ -272,6 +278,14 @@ class Topology(object):
self._listeners.publish_topology_description_changed,
(td_old, self._description, self._topology_id)))
# Shutdown SRV polling for unsupported cluster types.
# This is only applicable if the old topology was Unknown, and the
# new one is something other than Unknown or Sharded.
if self._srv_monitor and (td_old.topology_type == TOPOLOGY_TYPE.Unknown
and self._description.topology_type not in
SRV_POLLING_TOPOLOGIES):
self._srv_monitor.close()
# Wake waiters in select_servers().
self._condition.notify_all()
@ -291,6 +305,28 @@ class Topology(object):
self._description.has_server(server_description.address)):
self._process_change(server_description)
def _process_srv_update(self, seedlist):
"""Process a new seedlist on an opened topology.
Hold the lock when calling this.
"""
td_old = self._description
self._description = _updated_topology_description_srv_polling(
self._description, seedlist)
self._update_servers()
if self._publish_tp:
self._events.put((
self._listeners.publish_topology_description_changed,
(td_old, self._description, self._topology_id)))
def on_srv_update(self, seedlist):
"""Process a new list of nodes obtained from scanning SRV records."""
# We do no I/O holding the lock.
with self._lock:
if self._opened:
self._process_srv_update(seedlist)
def get_server_by_address(self, address):
"""Get a Server or None.
@ -396,6 +432,11 @@ class Topology(object):
# Mark all servers Unknown.
self._description = self._description.reset()
self._update_servers()
# Stop SRV polling thread.
if self._srv_monitor:
self._srv_monitor.close()
self._opened = False
# Publish only after releasing the lock.
@ -471,6 +512,11 @@ class Topology(object):
if self._publish_tp or self._publish_server:
self.__events_executor.open()
# Start the SRV polling thread.
if self._srv_monitor and (self.description.topology_type in
SRV_POLLING_TOPOLOGIES):
self._srv_monitor.open()
# Ensure that the monitors are open.
for server in itervalues(self._servers):
server.open()

View File

@ -24,10 +24,14 @@ from pymongo.server_selectors import Selection
from pymongo.server_type import SERVER_TYPE
# Enumeration for various kinds of MongoDB cluster topologies.
TOPOLOGY_TYPE = namedtuple('TopologyType', ['Single', 'ReplicaSetNoPrimary',
'ReplicaSetWithPrimary', 'Sharded',
'Unknown'])(*range(5))
# Topologies compatible with SRV record polling.
SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
class TopologyDescription(object):
def __init__(self,
@ -400,6 +404,40 @@ def updated_topology_description(topology_description, server_description):
topology_description._topology_settings)
def _updated_topology_description_srv_polling(topology_description, seedlist):
"""Return an updated copy of a TopologyDescription.
:Parameters:
- `topology_description`: the current TopologyDescription
- `seedlist`: a list of new seeds new ServerDescription that resulted from
an ismaster call
"""
# Create a copy of the server descriptions.
sds = topology_description.server_descriptions()
# If seeds haven't changed, don't do anything.
if set(sds.keys()) == set(seedlist):
return topology_description
# Add SDs corresponding to servers recently added to the SRV record.
for address in seedlist:
if address not in sds:
sds[address] = ServerDescription(address)
# Remove SDs corresponding to servers no longer part of the SRV record.
for address in list(sds.keys()):
if address not in seedlist:
sds.pop(address)
return TopologyDescription(
topology_description.topology_type,
sds,
topology_description.replica_set_name,
topology_description.max_set_version,
topology_description.max_election_id,
topology_description._topology_settings)
def _update_rs_from_primary(
sds,
replica_set_name,

View File

@ -17,12 +17,6 @@
import re
import warnings
try:
from dns import resolver
_HAVE_DNSPYTHON = True
except ImportError:
_HAVE_DNSPYTHON = False
from bson.py3compat import string_type, PY3
if PY3:
@ -34,6 +28,7 @@ from pymongo.common import (
get_validated_options, INTERNAL_URI_OPTION_NAME_MAP,
URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary)
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver
SCHEME = 'mongodb://'
@ -325,46 +320,10 @@ def split_hosts(hosts, default_port=DEFAULT_PORT):
# backward-compat we allow "db.collection" in URI.
_BAD_DB_CHARS = re.compile('[' + re.escape(r'/ "$') + ']')
if PY3:
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text):
if isinstance(text, bytes):
return text.decode()
return text
else:
def maybe_decode(text):
return text
_ALLOWED_TXT_OPTS = frozenset(
['authsource', 'authSource', 'replicaset', 'replicaSet'])
def _get_dns_srv_hosts(hostname):
try:
results = resolver.query('_mongodb._tcp.' + hostname, 'SRV')
except Exception as exc:
raise ConfigurationError(str(exc))
return [(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)
for res in results]
def _get_dns_txt_options(hostname):
try:
results = resolver.query(hostname, 'TXT')
except (resolver.NoAnswer, resolver.NXDOMAIN):
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc))
if len(results) > 1:
raise ConfigurationError('Only one TXT record is supported')
return (
b'&'.join([b''.join(res.strings) for res in results])).decode('utf-8')
def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
normalize=True):
"""Parse and validate a MongoDB URI.
@ -377,7 +336,8 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
'password': <password> or None,
'database': <database name> or None,
'collection': <collection name> or None,
'options': <dict of MongoDB URI options>
'options': <dict of MongoDB URI options>,
'fqdn': <fqdn of the MongoDB+SRV URI> or None
}
If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
@ -451,6 +411,7 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
" percent-encoded: %s" % host_part)
hosts = unquote_plus(hosts)
fqdn = None
if is_srv:
nodes = split_hosts(hosts, default_port=None)
@ -462,24 +423,10 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
if port is not None:
raise InvalidURI(
"%s URIs must not include a port number" % (SRV_SCHEME,))
nodes = _get_dns_srv_hosts(fqdn)
try:
plist = fqdn.split(".")[1:]
except Exception:
raise ConfigurationError("Invalid URI host")
slen = len(plist)
if slen < 2:
raise ConfigurationError("Invalid URI host")
for node in nodes:
try:
nlist = node[0].split(".")[1:][-slen:]
except Exception:
raise ConfigurationError("Invalid SRV host")
if plist != nlist:
raise ConfigurationError("Invalid SRV host")
dns_options = _get_dns_txt_options(fqdn)
dns_resolver = _SrvResolver(fqdn)
nodes = dns_resolver.get_hosts()
dns_options = dns_resolver.get_options()
if dns_options:
options = split_options(dns_options, validate, warn, normalize)
if set(options) - _ALLOWED_TXT_OPTS:
@ -514,7 +461,8 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
'password': passwd,
'database': dbase,
'collection': collection,
'options': options
'options': options,
'fqdn': fqdn
}

View File

@ -22,9 +22,10 @@ import sys
sys.path[0:0] = [""]
from pymongo.common import validate_read_preference_tags
from pymongo.srv_resolver import _HAVE_DNSPYTHON
from pymongo.errors import ConfigurationError
from pymongo.mongo_client import MongoClient
from pymongo.uri_parser import parse_uri, split_hosts, _HAVE_DNSPYTHON
from pymongo.uri_parser import parse_uri, split_hosts
from test import client_context, unittest
from test.utils import wait_until

View File

@ -29,7 +29,7 @@ from pymongo.monitor import Monitor
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.server_type import SERVER_TYPE
from pymongo.topology import TOPOLOGY_TYPE
from pymongo.topology_description import TOPOLOGY_TYPE
from test import unittest, client_context, client_knobs
from test.utils import (ServerAndTopologyEventListener,
single_client,

View File

@ -27,7 +27,8 @@ from pymongo.topology import Topology
sys.path[0:0] = [""]
from test import client_context, unittest, IntegrationTest
from test.utils import rs_or_single_client, wait_until, EventListener
from test.utils import (rs_or_single_client, wait_until, EventListener,
FunctionCallCounter)
from test.utils_selection_tests import (
create_selection_tests, get_addresses, get_topology_settings_dict,
make_server_description)
@ -39,16 +40,6 @@ _TEST_PATH = os.path.join(
os.path.join('server_selection', 'server_selection'))
class CallCountSelector(object):
"""No-op selector that keeps track of how many times it is called."""
def __init__(self):
self.call_count = 0
def __call__(self, servers):
self.call_count += 1
return servers
class SelectionStoreSelector(object):
"""No-op selector that keeps track of what was passed to it."""
def __init__(self):
@ -114,7 +105,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
@client_context.require_replica_set
def test_selector_called(self):
selector = CallCountSelector()
selector = FunctionCallCounter(lambda x: x)
# Client setup.
mongo_client = rs_or_single_client(server_selector=selector)
@ -178,7 +169,7 @@ class TestCustomServerSelectorFunction(IntegrationTest):
@client_context.require_replica_set
def test_server_selector_bypassed(self):
selector = CallCountSelector()
selector = FunctionCallCounter(lambda x: x)
scenario_def = {
'topology_description': {

201
test/test_srv_polling.py Normal file
View File

@ -0,0 +1,201 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the SRV support tests."""
import sys
from time import sleep
sys.path[0:0] = [""]
import pymongo
from pymongo import common
from pymongo.srv_resolver import _HAVE_DNSPYTHON
from pymongo.mongo_client import MongoClient
from test import client_knobs, unittest
from test.utils import wait_until, FunctionCallCounter
WAIT_TIME = 0.1
class SRVPollingKnobs(object):
def __init__(self, ttl_time=None, min_srv_rescan_interval=None,
dns_resolver_nodelist_response=None,
count_resolver_calls=False):
self.ttl_time = ttl_time
self.min_srv_rescan_interval = min_srv_rescan_interval
self.dns_resolver_nodelist_response = dns_resolver_nodelist_response
self.count_resolver_calls = count_resolver_calls
self.old_min_srv_rescan_interval = None
self.old_dns_resolver_response = None
def enable(self):
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
self.old_dns_resolver_response = \
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
if self.min_srv_rescan_interval is not None:
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
def mock_get_hosts_and_min_ttl(resolver, *args):
nodes, ttl = self.old_dns_resolver_response(resolver)
if self.dns_resolver_nodelist_response is not None:
nodes = self.dns_resolver_nodelist_response()
if self.ttl_time is not None:
ttl = self.ttl_time
return nodes, ttl
if self.count_resolver_calls:
patch_func = FunctionCallCounter(mock_get_hosts_and_min_ttl)
else:
patch_func = mock_get_hosts_and_min_ttl
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func
def __enter__(self):
self.enable()
def disable(self):
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \
self.old_dns_resolver_response
def __exit__(self, exc_type, exc_val, exc_tb):
self.disable()
class TestSRVPolling(unittest.TestCase):
BASE_SRV_RESPONSE = [
("localhost.test.build.10gen.cc", 27017),
("localhost.test.build.10gen.cc", 27018)]
CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"
def setUp(self):
if not _HAVE_DNSPYTHON:
raise unittest.SkipTest("SRV polling tests require the dnspython "
"module")
def get_nodelist(self, client):
return client._topology.description.server_descriptions().keys()
def assert_nodelist_change(self, expected_nodelist, client):
"""Check if the client._topology eventually sees all nodes in the
expected_nodelist.
"""
def predicate():
nodelist = self.get_nodelist(client)
if set(expected_nodelist) == set(nodelist):
return True
return False
wait_until(predicate, "see expected nodelist", timeout=10*WAIT_TIME)
def assert_nodelist_nochange(self, expected_nodelist, client):
"""Check if the client._topology ever deviates from seeing all nodes
in the expected_nodelist. Consistency is checked after sleeping for
(WAIT_TIME * 10) seconds. Also check that the resolver is called at
least once.
"""
sleep(WAIT_TIME*10)
nodelist = self.get_nodelist(client)
if set(expected_nodelist) != set(nodelist):
msg = "Client nodelist %s changed unexpectedly (expected %s)"
raise self.fail(msg % (nodelist, expected_nodelist))
self.assertGreaterEqual(
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,
1, "resolver was never called")
return True
def _run_scenario(self, dns_response, expect_change):
if callable(dns_response):
dns_resolver_response = dns_response
else:
def dns_resolver_response():
return dns_response
if expect_change:
assertion_method = self.assert_nodelist_change
count_resolver_calls = False
expected_response = dns_response
else:
assertion_method = self.assert_nodelist_nochange
count_resolver_calls = True
expected_response = self.BASE_SRV_RESPONSE
# Patch timeouts to ensure short test running times.
with SRVPollingKnobs(
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
mc = MongoClient(self.CONNECTION_STRING)
self.assert_nodelist_change(self.BASE_SRV_RESPONSE, mc)
# Patch list of hosts returned by DNS query.
with SRVPollingKnobs(
dns_resolver_nodelist_response=dns_resolver_response,
count_resolver_calls=count_resolver_calls):
assertion_method(expected_response, mc)
def run_scenario(self, dns_response, expect_change):
# Patch timeouts to ensure short rescan SRV interval.
with client_knobs(heartbeat_frequency=WAIT_TIME,
min_heartbeat_interval=WAIT_TIME,
events_queue_frequency=WAIT_TIME):
self._run_scenario(dns_response, expect_change)
def test_addition(self):
response = self.BASE_SRV_RESPONSE[:]
response.append(
("localhost.test.build.10gen.cc", 27019))
self.run_scenario(response, True)
def test_removal(self):
response = self.BASE_SRV_RESPONSE[:]
response.remove(
("localhost.test.build.10gen.cc", 27018))
self.run_scenario(response, True)
def test_replace_one(self):
response = self.BASE_SRV_RESPONSE[:]
response.remove(
("localhost.test.build.10gen.cc", 27018))
response.append(
("localhost.test.build.10gen.cc", 27019))
self.run_scenario(response, True)
def test_replace_both_with_one(self):
response = [("localhost.test.build.10gen.cc", 27019)]
self.run_scenario(response, True)
def test_replace_both_with_two(self):
response = [("localhost.test.build.10gen.cc", 27019),
("localhost.test.build.10gen.cc", 27020)]
self.run_scenario(response, True)
def test_dns_failures(self):
from dns import exception
for exc in (exception.FormError, exception.TooBig, exception.Timeout):
def response_callback(*args):
raise exc("DNS Failure!")
self.run_scenario(response_callback, False)
def test_dns_record_lookup_empty(self):
response = []
self.run_scenario(response, False)
if __name__ == '__main__':
unittest.main()

View File

@ -193,7 +193,8 @@ class TestURI(unittest.TestCase):
'password': None,
'database': None,
'collection': None,
'options': {}
'options': {},
'fqdn': None
}
res = copy.deepcopy(orig)
@ -442,7 +443,8 @@ class TestURI(unittest.TestCase):
'nodelist': [('/MongoDB.sock', None)],
'options': {'ssl_certfile': '/a/b'},
'password': 'foo/bar',
'username': 'jesse'},
'username': 'jesse',
'fqdn': None},
parse_uri(
'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=/a/b',
validate=False))
@ -453,7 +455,8 @@ class TestURI(unittest.TestCase):
'nodelist': [('/MongoDB.sock', None)],
'options': {'ssl_certfile': 'a/b'},
'password': 'foo/bar',
'username': 'jesse'},
'username': 'jesse',
'fqdn': None},
parse_uri(
'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=a/b',
validate=False))

View File

@ -170,6 +170,21 @@ class CompareType(object):
return not self.__eq__(other)
class FunctionCallCounter(object):
"""Class that wraps a function and keeps count of invocations."""
def __init__(self, function):
self._function = function
self._call_count = 0
def __call__(self, *args, **kwargs):
self._call_count += 1
return self._function(*args, **kwargs)
@property
def call_count(self):
return self._call_count
class TestCreator(object):
"""Class to create test cases from specifications."""
def __init__(self, create_test, test_class, test_path):
@ -502,6 +517,7 @@ def wait_until(predicate, success_description, timeout=10):
Returns the predicate's first true value.
"""
start = time.time()
interval = min(float(timeout)/100, 0.1)
while True:
retval = predicate()
if retval:
@ -510,7 +526,7 @@ def wait_until(predicate, success_description, timeout=10):
if time.time() - start > timeout:
raise AssertionError("Didn't ever %s" % success_description)
time.sleep(0.1)
time.sleep(interval)
def is_mongos(client):