From 38848d47f32ea425720580f3147d5f83996380e4 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Fri, 31 Oct 2014 21:24:41 -0400 Subject: [PATCH] MongoClient.close() stops monitors. Subsequent operations restart them. --- pymongo/monitor.py | 5 +++-- pymongo/server.py | 3 ++- pymongo/topology.py | 2 +- test/test_monitor.py | 43 +++++++++++++++++++---------------- test/test_threads.py | 41 +++++++++++++++++++++++++++++++++- test/utils.py | 53 +++++++++++++++++++++++++++----------------- 6 files changed, 103 insertions(+), 44 deletions(-) diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 78132eb68..616e7345a 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -56,6 +56,7 @@ class Monitor(object): Multiple calls have no effect. """ + self._stopped = False started = False try: started = self._thread and self._thread.is_alive() @@ -73,12 +74,12 @@ class Monitor(object): def close(self): """Disconnect and stop monitoring. - The Monitor cannot be used after closing. + open() restarts the monitor after closing. """ self._stopped = True self._pool.reset() - # Awake the thread so it notices that _stopped is True. + # Wake the thread so it notices that _stopped is True. self.request_check() def join(self, timeout=None): diff --git a/pymongo/server.py b/pymongo/server.py index 4d73eafe2..f56dcb335 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -42,8 +42,9 @@ class Server(object): self._pool.reset() def reset(self): - """Clear the connection pool.""" + """Clear the connection pool and stop the monitor.""" self._pool.reset() + self._monitor.close() def request_check(self): """Check the server's state soon.""" diff --git a/pymongo/topology.py b/pymongo/topology.py index 99bb7debc..9026020dc 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -207,7 +207,7 @@ class Topology(object): """ with self._lock: for server in self._servers.values(): - server.pool.reset() + server.reset() # Mark all servers Unknown. self._description = self._description.reset() diff --git a/test/test_monitor.py b/test/test_monitor.py index eaa38b418..cb426775c 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -17,38 +17,43 @@ import gc import sys import time +from functools import partial sys.path[0:0] = [""] from pymongo.monitor import MONITORS from test import unittest, port, host, IntegrationTest -from test.utils import single_client, wait_until, one +from test.utils import single_client, one, connected, wait_until + + +def find_monitor_ref(monitor): + for ref in MONITORS.copy(): + if ref() is monitor: + return ref + + return None + + +def unregistered(ref): + gc.collect() + return ref not in MONITORS class TestMonitor(IntegrationTest): def test_atexit_hook(self): - # Weakrefs to currently running Monitor instances. - prior_monitors = MONITORS.copy() client = single_client(host, port) - wait_until(lambda: MONITORS - prior_monitors, - 'register new monitor') + monitor = one(client._topology._servers.values())._monitor + connected(client) - # Just one new monitor should have been registered. - new_monitor_refs = MONITORS - prior_monitors - self.assertEqual(1, len(new_monitor_refs)) - monitor_ref = one(new_monitor_refs) + # The client registers a weakref to the monitor. + ref = wait_until(partial(find_monitor_ref, monitor), + 'register monitor') + + client.close() + del monitor del client - start = time.time() - while time.time() - start < 30: - gc.collect() - if monitor_ref not in MONITORS: - # New monitor was unregistered. - break - - time.sleep(0.1) - else: - self.fail("Didn't ever unregister monitor") + wait_until(partial(unregistered, ref), 'unregister monitor') if __name__ == "__main__": diff --git a/test/test_threads.py b/test/test_threads.py index cf68f719b..d8bfb3ec3 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -18,7 +18,8 @@ import threading import traceback from test import unittest, client_context, IntegrationTest -from test.utils import (joinall, +from test.utils import (frequent_thread_switches, + joinall, remove_all_users, RendezvousThread, rs_or_single_client) @@ -57,12 +58,15 @@ class SaveAndFind(threading.Thread): threading.Thread.__init__(self) self.collection = collection self.setDaemon(True) + self.passed = False def run(self): sum = 0 for document in self.collection.find(): sum += document["x"] + assert sum == 499500, "sum was %d not 499500" % sum + self.passed = True class Insert(threading.Thread): @@ -114,6 +118,21 @@ class Update(threading.Thread): assert error +class Disconnect(threading.Thread): + + def __init__(self, client, n): + threading.Thread.__init__(self) + self.client = client + self.n = n + self.passed = False + + def run(self): + for _ in range(self.n): + self.client.disconnect() + + self.passed = True + + class IgnoreAutoReconnect(threading.Thread): def __init__(self, collection, n): @@ -285,6 +304,26 @@ class TestThreads(IntegrationTest): for t in threads: self.assertTrue(t.passed, "%s threw exception" % t) + def test_client_disconnect(self): + self.db.drop_collection("test") + for i in range(1000): + self.db.test.save({"x": i}) + + # Start 10 threads that execute a query, and 10 threads that call + # client.disconnect() 10 times in a row. + threads = [SaveAndFind(self.db.test) for _ in range(10)] + threads.extend(Disconnect(self.db.connection, 10) for _ in range(10)) + + with frequent_thread_switches(): + for t in threads: + t.start() + + for t in threads: + t.join(30) + + for t in threads: + self.assertTrue(t.passed) + class TestThreadsAuth(IntegrationTest): @classmethod diff --git a/test/utils.py b/test/utils.py index 45ca9ae90..ce6f8920b 100644 --- a/test/utils.py +++ b/test/utils.py @@ -210,7 +210,7 @@ def connected(client): return client def wait_until(predicate, success_description, timeout=10): - """Wait up to 10 seconds (by default) for predicate to be True. + """Wait up to 10 seconds (by default) for predicate to be true. E.g.: @@ -219,12 +219,20 @@ def wait_until(predicate, success_description, timeout=10): If the lambda-expression isn't true after 10 seconds, we raise AssertionError("Didn't ever connect to the primary"). + + Returns the predicate's first true value. """ start = time.time() - while not predicate(): + while True: + retval = predicate() + if retval: + return retval + if time.time() - start > timeout: raise AssertionError("Didn't ever %s" % success_description) + time.sleep(0.1) + def is_mongos(client): res = client.admin.command('ismaster') return res.get('msg', '') == 'isdbgrid' @@ -505,6 +513,28 @@ def run_threads(collection, target): assert not t.isAlive() +@contextlib.contextmanager +def frequent_thread_switches(): + """Make concurrency bugs more likely to manifest.""" + interval = None + if not sys.platform.startswith('java'): + if hasattr(sys, 'getswitchinterval'): + interval = sys.getswitchinterval() + sys.setswitchinterval(1e-6) + else: + interval = sys.getcheckinterval() + sys.setcheckinterval(1) + + try: + yield + finally: + if not sys.platform.startswith('java'): + if hasattr(sys, 'setswitchinterval'): + sys.setswitchinterval(interval) + else: + sys.setcheckinterval(interval) + + def lazy_client_trial(reset, target, test, get_client): """Test concurrent operations on a lazily-connecting client. @@ -518,27 +548,10 @@ def lazy_client_trial(reset, target, test, get_client): """ collection = client_context.client.pymongo_test.test - # Make concurrency bugs more likely to manifest. - interval = None - if not sys.platform.startswith('java'): - if hasattr(sys, 'getswitchinterval'): - interval = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - else: - interval = sys.getcheckinterval() - sys.setcheckinterval(1) - - try: + with frequent_thread_switches(): for i in range(NTRIALS): reset(collection) lazy_client = get_client() lazy_collection = lazy_client.pymongo_test.test run_threads(lazy_collection, target) test(lazy_collection) - - finally: - if not sys.platform.startswith('java'): - if hasattr(sys, 'setswitchinterval'): - sys.setswitchinterval(interval) - else: - sys.setcheckinterval(interval)