From b834e312a3232e971bc622111acee5ff002a836c Mon Sep 17 00:00:00 2001 From: Prashant Mital Date: Fri, 14 Jun 2019 13:24:55 -0700 Subject: [PATCH] PYTHON-1872 Fix SrvMonitor related test failures --- pymongo/monitor.py | 17 ++++++++++------- test/test_srv_polling.py | 36 ++++++++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/pymongo/monitor.py b/pymongo/monitor.py index d77d4e450..23af967ff 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -204,6 +204,7 @@ class SrvMonitor(MonitorBase): The Topology is weakly referenced. """ self._settings = topology_settings + self._seedlist = self._settings._seeds self._fqdn = self._settings.fqdn # We strongly reference the executor and it weakly references us via @@ -228,12 +229,14 @@ class SrvMonitor(MonitorBase): self._topology = weakref.proxy(topology, executor.close) def _run(self): - try: - self._seedlist = self._get_seedlist() - self._topology.on_srv_update(self._seedlist) - except ReferenceError: - # Topology was garbage-collected. - self.close() + seedlist = self._get_seedlist() + if seedlist: + self._seedlist = seedlist + try: + self._topology.on_srv_update(self._seedlist) + except ReferenceError: + # Topology was garbage-collected. + self.close() def _get_seedlist(self): """Poll SRV records for a seedlist. @@ -251,7 +254,7 @@ class SrvMonitor(MonitorBase): # - SRV records must be rescanned every heartbeatFrequencyMS # - Topology must be left unchanged self.request_check() - return self._seedlist + return None else: self._executor.update_interval( max(ttl, common.MIN_SRV_RESCAN_INTERVAL)) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 6794e8e54..e7d62ca00 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -23,6 +23,7 @@ sys.path[0:0] = [""] import pymongo from pymongo import common +from pymongo.errors import ConfigurationError from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.mongo_client import MongoClient from test import client_knobs, unittest @@ -91,6 +92,14 @@ class TestSRVPolling(unittest.TestCase): if not _HAVE_DNSPYTHON: raise unittest.SkipTest("SRV polling tests require the dnspython " "module") + # Patch timeouts to ensure short rescan SRV interval. + self.client_knobs = client_knobs( + heartbeat_frequency=WAIT_TIME, min_heartbeat_interval=WAIT_TIME, + events_queue_frequency=WAIT_TIME) + self.client_knobs.enable() + + def tearDown(self): + self.client_knobs.disable() def get_nodelist(self, client): return client._topology.description.server_descriptions().keys() @@ -122,7 +131,7 @@ class TestSRVPolling(unittest.TestCase): 1, "resolver was never called") return True - def _run_scenario(self, dns_response, expect_change): + def run_scenario(self, dns_response, expect_change): if callable(dns_response): dns_resolver_response = dns_response else: @@ -149,13 +158,6 @@ class TestSRVPolling(unittest.TestCase): count_resolver_calls=count_resolver_calls): assertion_method(expected_response, mc) - def run_scenario(self, dns_response, expect_change): - # Patch timeouts to ensure short rescan SRV interval. - with client_knobs(heartbeat_frequency=WAIT_TIME, - min_heartbeat_interval=WAIT_TIME, - events_queue_frequency=WAIT_TIME): - self._run_scenario(dns_response, expect_change) - def test_addition(self): response = self.BASE_SRV_RESPONSE[:] response.append( @@ -196,6 +198,24 @@ class TestSRVPolling(unittest.TestCase): response = [] self.run_scenario(response, False) + def _test_recover_from_initial(self, response_callback): + with SRVPollingKnobs( + ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, + dns_resolver_nodelist_response=response_callback, + count_resolver_calls=True): + mc = MongoClient(self.CONNECTION_STRING) + self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, mc) + + def test_recover_from_initially_empty_seedlist(self): + def empty_seedlist(): + return [] + self._test_recover_from_initial(empty_seedlist) + + def test_recover_from_initially_erroring_seedlist(self): + def erroring_seedlist(): + raise ConfigurationError + self._test_recover_from_initial(erroring_seedlist) + if __name__ == '__main__': unittest.main()