mongo-python-driver/test/test_srv_polling.py
2019-11-18 17:41:42 -08:00

236 lines
8.7 KiB
Python

# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the SRV support tests."""
import sys
from time import sleep
sys.path[0:0] = [""]
import pymongo
from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.srv_resolver import _HAVE_DNSPYTHON
from pymongo.mongo_client import MongoClient
from test import client_knobs, unittest
from test.utils import wait_until, FunctionCallRecorder
WAIT_TIME = 0.1
class SrvPollingKnobs(object):
def __init__(self, ttl_time=None, min_srv_rescan_interval=None,
dns_resolver_nodelist_response=None,
count_resolver_calls=False):
self.ttl_time = ttl_time
self.min_srv_rescan_interval = min_srv_rescan_interval
self.dns_resolver_nodelist_response = dns_resolver_nodelist_response
self.count_resolver_calls = count_resolver_calls
self.old_min_srv_rescan_interval = None
self.old_dns_resolver_response = None
def enable(self):
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
self.old_dns_resolver_response = \
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
if self.min_srv_rescan_interval is not None:
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
def mock_get_hosts_and_min_ttl(resolver, *args):
nodes, ttl = self.old_dns_resolver_response(resolver)
if self.dns_resolver_nodelist_response is not None:
nodes = self.dns_resolver_nodelist_response()
if self.ttl_time is not None:
ttl = self.ttl_time
return nodes, ttl
if self.count_resolver_calls:
patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
else:
patch_func = mock_get_hosts_and_min_ttl
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func
def __enter__(self):
self.enable()
def disable(self):
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \
self.old_dns_resolver_response
def __exit__(self, exc_type, exc_val, exc_tb):
self.disable()
class TestSrvPolling(unittest.TestCase):
BASE_SRV_RESPONSE = [
("localhost.test.build.10gen.cc", 27017),
("localhost.test.build.10gen.cc", 27018)]
CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"
def setUp(self):
if not _HAVE_DNSPYTHON:
raise unittest.SkipTest("SRV polling tests require the dnspython "
"module")
# Patch timeouts to ensure short rescan SRV interval.
self.client_knobs = client_knobs(
heartbeat_frequency=WAIT_TIME, min_heartbeat_interval=WAIT_TIME,
events_queue_frequency=WAIT_TIME)
self.client_knobs.enable()
def tearDown(self):
self.client_knobs.disable()
def get_nodelist(self, client):
return client._topology.description.server_descriptions().keys()
def assert_nodelist_change(self, expected_nodelist, client):
"""Check if the client._topology eventually sees all nodes in the
expected_nodelist.
"""
def predicate():
nodelist = self.get_nodelist(client)
if set(expected_nodelist) == set(nodelist):
return True
return False
wait_until(predicate, "see expected nodelist", timeout=100*WAIT_TIME)
def assert_nodelist_nochange(self, expected_nodelist, client):
"""Check if the client._topology ever deviates from seeing all nodes
in the expected_nodelist. Consistency is checked after sleeping for
(WAIT_TIME * 10) seconds. Also check that the resolver is called at
least once.
"""
sleep(WAIT_TIME*10)
nodelist = self.get_nodelist(client)
if set(expected_nodelist) != set(nodelist):
msg = "Client nodelist %s changed unexpectedly (expected %s)"
raise self.fail(msg % (nodelist, expected_nodelist))
self.assertGreaterEqual(
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,
1, "resolver was never called")
return True
def run_scenario(self, dns_response, expect_change):
if callable(dns_response):
dns_resolver_response = dns_response
else:
def dns_resolver_response():
return dns_response
if expect_change:
assertion_method = self.assert_nodelist_change
count_resolver_calls = False
expected_response = dns_response
else:
assertion_method = self.assert_nodelist_nochange
count_resolver_calls = True
expected_response = self.BASE_SRV_RESPONSE
# Patch timeouts to ensure short test running times.
with SrvPollingKnobs(
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
client = MongoClient(self.CONNECTION_STRING)
self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
# Patch list of hosts returned by DNS query.
with SrvPollingKnobs(
dns_resolver_nodelist_response=dns_resolver_response,
count_resolver_calls=count_resolver_calls):
assertion_method(expected_response, client)
def test_addition(self):
response = self.BASE_SRV_RESPONSE[:]
response.append(
("localhost.test.build.10gen.cc", 27019))
self.run_scenario(response, True)
def test_removal(self):
response = self.BASE_SRV_RESPONSE[:]
response.remove(
("localhost.test.build.10gen.cc", 27018))
self.run_scenario(response, True)
def test_replace_one(self):
response = self.BASE_SRV_RESPONSE[:]
response.remove(
("localhost.test.build.10gen.cc", 27018))
response.append(
("localhost.test.build.10gen.cc", 27019))
self.run_scenario(response, True)
def test_replace_both_with_one(self):
response = [("localhost.test.build.10gen.cc", 27019)]
self.run_scenario(response, True)
def test_replace_both_with_two(self):
response = [("localhost.test.build.10gen.cc", 27019),
("localhost.test.build.10gen.cc", 27020)]
self.run_scenario(response, True)
def test_dns_failures(self):
from dns import exception
for exc in (exception.FormError, exception.TooBig, exception.Timeout):
def response_callback(*args):
raise exc("DNS Failure!")
self.run_scenario(response_callback, False)
def test_dns_record_lookup_empty(self):
response = []
self.run_scenario(response, False)
def _test_recover_from_initial(self, initial_callback):
# Construct a valid final response callback distinct from base.
response_final = self.BASE_SRV_RESPONSE[:]
response_final.pop()
def final_callback():
return response_final
with SrvPollingKnobs(
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME,
dns_resolver_nodelist_response=initial_callback,
count_resolver_calls=True):
# Client uses unpatched method to get initial nodelist
client = MongoClient(self.CONNECTION_STRING)
# Invalid DNS resolver response should not change nodelist.
self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
with SrvPollingKnobs(
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME,
dns_resolver_nodelist_response=final_callback):
# Nodelist should reflect new valid DNS resolver response.
self.assert_nodelist_change(response_final, client)
def test_recover_from_initially_empty_seedlist(self):
def empty_seedlist():
return []
self._test_recover_from_initial(empty_seedlist)
def test_recover_from_initially_erroring_seedlist(self):
def erroring_seedlist():
raise ConfigurationError
self._test_recover_from_initial(erroring_seedlist)
if __name__ == '__main__':
unittest.main()