MongoClient.close() stops monitors.
Subsequent operations restart them.
This commit is contained in:
parent
7226dab318
commit
38848d47f3
@ -56,6 +56,7 @@ class Monitor(object):
|
||||
|
||||
Multiple calls have no effect.
|
||||
"""
|
||||
self._stopped = False
|
||||
started = False
|
||||
try:
|
||||
started = self._thread and self._thread.is_alive()
|
||||
@ -73,12 +74,12 @@ class Monitor(object):
|
||||
def close(self):
|
||||
"""Disconnect and stop monitoring.
|
||||
|
||||
The Monitor cannot be used after closing.
|
||||
open() restarts the monitor after closing.
|
||||
"""
|
||||
self._stopped = True
|
||||
self._pool.reset()
|
||||
|
||||
# Awake the thread so it notices that _stopped is True.
|
||||
# Wake the thread so it notices that _stopped is True.
|
||||
self.request_check()
|
||||
|
||||
def join(self, timeout=None):
|
||||
|
||||
@ -42,8 +42,9 @@ class Server(object):
|
||||
self._pool.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Clear the connection pool."""
|
||||
"""Clear the connection pool and stop the monitor."""
|
||||
self._pool.reset()
|
||||
self._monitor.close()
|
||||
|
||||
def request_check(self):
|
||||
"""Check the server's state soon."""
|
||||
|
||||
@ -207,7 +207,7 @@ class Topology(object):
|
||||
"""
|
||||
with self._lock:
|
||||
for server in self._servers.values():
|
||||
server.pool.reset()
|
||||
server.reset()
|
||||
|
||||
# Mark all servers Unknown.
|
||||
self._description = self._description.reset()
|
||||
|
||||
@ -17,38 +17,43 @@
|
||||
import gc
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from pymongo.monitor import MONITORS
|
||||
from test import unittest, port, host, IntegrationTest
|
||||
from test.utils import single_client, wait_until, one
|
||||
from test.utils import single_client, one, connected, wait_until
|
||||
|
||||
|
||||
def find_monitor_ref(monitor):
|
||||
for ref in MONITORS.copy():
|
||||
if ref() is monitor:
|
||||
return ref
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def unregistered(ref):
|
||||
gc.collect()
|
||||
return ref not in MONITORS
|
||||
|
||||
|
||||
class TestMonitor(IntegrationTest):
|
||||
def test_atexit_hook(self):
|
||||
# Weakrefs to currently running Monitor instances.
|
||||
prior_monitors = MONITORS.copy()
|
||||
client = single_client(host, port)
|
||||
wait_until(lambda: MONITORS - prior_monitors,
|
||||
'register new monitor')
|
||||
monitor = one(client._topology._servers.values())._monitor
|
||||
connected(client)
|
||||
|
||||
# Just one new monitor should have been registered.
|
||||
new_monitor_refs = MONITORS - prior_monitors
|
||||
self.assertEqual(1, len(new_monitor_refs))
|
||||
monitor_ref = one(new_monitor_refs)
|
||||
# The client registers a weakref to the monitor.
|
||||
ref = wait_until(partial(find_monitor_ref, monitor),
|
||||
'register monitor')
|
||||
|
||||
client.close()
|
||||
del monitor
|
||||
del client
|
||||
|
||||
start = time.time()
|
||||
while time.time() - start < 30:
|
||||
gc.collect()
|
||||
if monitor_ref not in MONITORS:
|
||||
# New monitor was unregistered.
|
||||
break
|
||||
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
self.fail("Didn't ever unregister monitor")
|
||||
wait_until(partial(unregistered, ref), 'unregister monitor')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -18,7 +18,8 @@ import threading
|
||||
import traceback
|
||||
|
||||
from test import unittest, client_context, IntegrationTest
|
||||
from test.utils import (joinall,
|
||||
from test.utils import (frequent_thread_switches,
|
||||
joinall,
|
||||
remove_all_users,
|
||||
RendezvousThread,
|
||||
rs_or_single_client)
|
||||
@ -57,12 +58,15 @@ class SaveAndFind(threading.Thread):
|
||||
threading.Thread.__init__(self)
|
||||
self.collection = collection
|
||||
self.setDaemon(True)
|
||||
self.passed = False
|
||||
|
||||
def run(self):
|
||||
sum = 0
|
||||
for document in self.collection.find():
|
||||
sum += document["x"]
|
||||
|
||||
assert sum == 499500, "sum was %d not 499500" % sum
|
||||
self.passed = True
|
||||
|
||||
|
||||
class Insert(threading.Thread):
|
||||
@ -114,6 +118,21 @@ class Update(threading.Thread):
|
||||
assert error
|
||||
|
||||
|
||||
class Disconnect(threading.Thread):
|
||||
|
||||
def __init__(self, client, n):
|
||||
threading.Thread.__init__(self)
|
||||
self.client = client
|
||||
self.n = n
|
||||
self.passed = False
|
||||
|
||||
def run(self):
|
||||
for _ in range(self.n):
|
||||
self.client.disconnect()
|
||||
|
||||
self.passed = True
|
||||
|
||||
|
||||
class IgnoreAutoReconnect(threading.Thread):
|
||||
|
||||
def __init__(self, collection, n):
|
||||
@ -285,6 +304,26 @@ class TestThreads(IntegrationTest):
|
||||
for t in threads:
|
||||
self.assertTrue(t.passed, "%s threw exception" % t)
|
||||
|
||||
def test_client_disconnect(self):
|
||||
self.db.drop_collection("test")
|
||||
for i in range(1000):
|
||||
self.db.test.save({"x": i})
|
||||
|
||||
# Start 10 threads that execute a query, and 10 threads that call
|
||||
# client.disconnect() 10 times in a row.
|
||||
threads = [SaveAndFind(self.db.test) for _ in range(10)]
|
||||
threads.extend(Disconnect(self.db.connection, 10) for _ in range(10))
|
||||
|
||||
with frequent_thread_switches():
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join(30)
|
||||
|
||||
for t in threads:
|
||||
self.assertTrue(t.passed)
|
||||
|
||||
|
||||
class TestThreadsAuth(IntegrationTest):
|
||||
@classmethod
|
||||
|
||||
@ -210,7 +210,7 @@ def connected(client):
|
||||
return client
|
||||
|
||||
def wait_until(predicate, success_description, timeout=10):
|
||||
"""Wait up to 10 seconds (by default) for predicate to be True.
|
||||
"""Wait up to 10 seconds (by default) for predicate to be true.
|
||||
|
||||
E.g.:
|
||||
|
||||
@ -219,12 +219,20 @@ def wait_until(predicate, success_description, timeout=10):
|
||||
|
||||
If the lambda-expression isn't true after 10 seconds, we raise
|
||||
AssertionError("Didn't ever connect to the primary").
|
||||
|
||||
Returns the predicate's first true value.
|
||||
"""
|
||||
start = time.time()
|
||||
while not predicate():
|
||||
while True:
|
||||
retval = predicate()
|
||||
if retval:
|
||||
return retval
|
||||
|
||||
if time.time() - start > timeout:
|
||||
raise AssertionError("Didn't ever %s" % success_description)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
def is_mongos(client):
|
||||
res = client.admin.command('ismaster')
|
||||
return res.get('msg', '') == 'isdbgrid'
|
||||
@ -505,6 +513,28 @@ def run_threads(collection, target):
|
||||
assert not t.isAlive()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def frequent_thread_switches():
|
||||
"""Make concurrency bugs more likely to manifest."""
|
||||
interval = None
|
||||
if not sys.platform.startswith('java'):
|
||||
if hasattr(sys, 'getswitchinterval'):
|
||||
interval = sys.getswitchinterval()
|
||||
sys.setswitchinterval(1e-6)
|
||||
else:
|
||||
interval = sys.getcheckinterval()
|
||||
sys.setcheckinterval(1)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if not sys.platform.startswith('java'):
|
||||
if hasattr(sys, 'setswitchinterval'):
|
||||
sys.setswitchinterval(interval)
|
||||
else:
|
||||
sys.setcheckinterval(interval)
|
||||
|
||||
|
||||
def lazy_client_trial(reset, target, test, get_client):
|
||||
"""Test concurrent operations on a lazily-connecting client.
|
||||
|
||||
@ -518,27 +548,10 @@ def lazy_client_trial(reset, target, test, get_client):
|
||||
"""
|
||||
collection = client_context.client.pymongo_test.test
|
||||
|
||||
# Make concurrency bugs more likely to manifest.
|
||||
interval = None
|
||||
if not sys.platform.startswith('java'):
|
||||
if hasattr(sys, 'getswitchinterval'):
|
||||
interval = sys.getswitchinterval()
|
||||
sys.setswitchinterval(1e-6)
|
||||
else:
|
||||
interval = sys.getcheckinterval()
|
||||
sys.setcheckinterval(1)
|
||||
|
||||
try:
|
||||
with frequent_thread_switches():
|
||||
for i in range(NTRIALS):
|
||||
reset(collection)
|
||||
lazy_client = get_client()
|
||||
lazy_collection = lazy_client.pymongo_test.test
|
||||
run_threads(lazy_collection, target)
|
||||
test(lazy_collection)
|
||||
|
||||
finally:
|
||||
if not sys.platform.startswith('java'):
|
||||
if hasattr(sys, 'setswitchinterval'):
|
||||
sys.setswitchinterval(interval)
|
||||
else:
|
||||
sys.setcheckinterval(interval)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user