MongoClient.close() stops monitors.

Subsequent operations restart them.
This commit is contained in:
A. Jesse Jiryu Davis 2014-10-31 21:24:41 -04:00
parent 7226dab318
commit 38848d47f3
6 changed files with 103 additions and 44 deletions

View File

@ -56,6 +56,7 @@ class Monitor(object):
Multiple calls have no effect.
"""
self._stopped = False
started = False
try:
started = self._thread and self._thread.is_alive()
@ -73,12 +74,12 @@ class Monitor(object):
def close(self):
"""Disconnect and stop monitoring.
The Monitor cannot be used after closing.
open() restarts the monitor after closing.
"""
self._stopped = True
self._pool.reset()
# Awake the thread so it notices that _stopped is True.
# Wake the thread so it notices that _stopped is True.
self.request_check()
def join(self, timeout=None):

View File

@ -42,8 +42,9 @@ class Server(object):
self._pool.reset()
def reset(self):
"""Clear the connection pool."""
"""Clear the connection pool and stop the monitor."""
self._pool.reset()
self._monitor.close()
def request_check(self):
"""Check the server's state soon."""

View File

@ -207,7 +207,7 @@ class Topology(object):
"""
with self._lock:
for server in self._servers.values():
server.pool.reset()
server.reset()
# Mark all servers Unknown.
self._description = self._description.reset()

View File

@ -17,38 +17,43 @@
import gc
import sys
import time
from functools import partial
sys.path[0:0] = [""]
from pymongo.monitor import MONITORS
from test import unittest, port, host, IntegrationTest
from test.utils import single_client, wait_until, one
from test.utils import single_client, one, connected, wait_until
def find_monitor_ref(monitor):
for ref in MONITORS.copy():
if ref() is monitor:
return ref
return None
def unregistered(ref):
gc.collect()
return ref not in MONITORS
class TestMonitor(IntegrationTest):
def test_atexit_hook(self):
# Weakrefs to currently running Monitor instances.
prior_monitors = MONITORS.copy()
client = single_client(host, port)
wait_until(lambda: MONITORS - prior_monitors,
'register new monitor')
monitor = one(client._topology._servers.values())._monitor
connected(client)
# Just one new monitor should have been registered.
new_monitor_refs = MONITORS - prior_monitors
self.assertEqual(1, len(new_monitor_refs))
monitor_ref = one(new_monitor_refs)
# The client registers a weakref to the monitor.
ref = wait_until(partial(find_monitor_ref, monitor),
'register monitor')
client.close()
del monitor
del client
start = time.time()
while time.time() - start < 30:
gc.collect()
if monitor_ref not in MONITORS:
# New monitor was unregistered.
break
time.sleep(0.1)
else:
self.fail("Didn't ever unregister monitor")
wait_until(partial(unregistered, ref), 'unregister monitor')
if __name__ == "__main__":

View File

@ -18,7 +18,8 @@ import threading
import traceback
from test import unittest, client_context, IntegrationTest
from test.utils import (joinall,
from test.utils import (frequent_thread_switches,
joinall,
remove_all_users,
RendezvousThread,
rs_or_single_client)
@ -57,12 +58,15 @@ class SaveAndFind(threading.Thread):
threading.Thread.__init__(self)
self.collection = collection
self.setDaemon(True)
self.passed = False
def run(self):
sum = 0
for document in self.collection.find():
sum += document["x"]
assert sum == 499500, "sum was %d not 499500" % sum
self.passed = True
class Insert(threading.Thread):
@ -114,6 +118,21 @@ class Update(threading.Thread):
assert error
class Disconnect(threading.Thread):
def __init__(self, client, n):
threading.Thread.__init__(self)
self.client = client
self.n = n
self.passed = False
def run(self):
for _ in range(self.n):
self.client.disconnect()
self.passed = True
class IgnoreAutoReconnect(threading.Thread):
def __init__(self, collection, n):
@ -285,6 +304,26 @@ class TestThreads(IntegrationTest):
for t in threads:
self.assertTrue(t.passed, "%s threw exception" % t)
def test_client_disconnect(self):
self.db.drop_collection("test")
for i in range(1000):
self.db.test.save({"x": i})
# Start 10 threads that execute a query, and 10 threads that call
# client.disconnect() 10 times in a row.
threads = [SaveAndFind(self.db.test) for _ in range(10)]
threads.extend(Disconnect(self.db.connection, 10) for _ in range(10))
with frequent_thread_switches():
for t in threads:
t.start()
for t in threads:
t.join(30)
for t in threads:
self.assertTrue(t.passed)
class TestThreadsAuth(IntegrationTest):
@classmethod

View File

@ -210,7 +210,7 @@ def connected(client):
return client
def wait_until(predicate, success_description, timeout=10):
"""Wait up to 10 seconds (by default) for predicate to be True.
"""Wait up to 10 seconds (by default) for predicate to be true.
E.g.:
@ -219,12 +219,20 @@ def wait_until(predicate, success_description, timeout=10):
If the lambda-expression isn't true after 10 seconds, we raise
AssertionError("Didn't ever connect to the primary").
Returns the predicate's first true value.
"""
start = time.time()
while not predicate():
while True:
retval = predicate()
if retval:
return retval
if time.time() - start > timeout:
raise AssertionError("Didn't ever %s" % success_description)
time.sleep(0.1)
def is_mongos(client):
res = client.admin.command('ismaster')
return res.get('msg', '') == 'isdbgrid'
@ -505,6 +513,28 @@ def run_threads(collection, target):
assert not t.isAlive()
@contextlib.contextmanager
def frequent_thread_switches():
"""Make concurrency bugs more likely to manifest."""
interval = None
if not sys.platform.startswith('java'):
if hasattr(sys, 'getswitchinterval'):
interval = sys.getswitchinterval()
sys.setswitchinterval(1e-6)
else:
interval = sys.getcheckinterval()
sys.setcheckinterval(1)
try:
yield
finally:
if not sys.platform.startswith('java'):
if hasattr(sys, 'setswitchinterval'):
sys.setswitchinterval(interval)
else:
sys.setcheckinterval(interval)
def lazy_client_trial(reset, target, test, get_client):
"""Test concurrent operations on a lazily-connecting client.
@ -518,27 +548,10 @@ def lazy_client_trial(reset, target, test, get_client):
"""
collection = client_context.client.pymongo_test.test
# Make concurrency bugs more likely to manifest.
interval = None
if not sys.platform.startswith('java'):
if hasattr(sys, 'getswitchinterval'):
interval = sys.getswitchinterval()
sys.setswitchinterval(1e-6)
else:
interval = sys.getcheckinterval()
sys.setcheckinterval(1)
try:
with frequent_thread_switches():
for i in range(NTRIALS):
reset(collection)
lazy_client = get_client()
lazy_collection = lazy_client.pymongo_test.test
run_threads(lazy_collection, target)
test(lazy_collection)
finally:
if not sys.platform.startswith('java'):
if hasattr(sys, 'setswitchinterval'):
sys.setswitchinterval(interval)
else:
sys.setcheckinterval(interval)