From 8fe3154138a1a81963286af27a823a7060ca47af Mon Sep 17 00:00:00 2001 From: Justin Patrin Date: Mon, 15 Apr 2013 17:35:45 -0700 Subject: [PATCH] PYTHON-436 Change max_pool_size to limit the maximum concurrent connections rather than just the idle connections in the pool. Also add support for waitQueueTimeoutMS and waitQueueMultiple. --- .travis.yml | 4 + doc/contributors.rst | 1 + pymongo/common.py | 2 + pymongo/mongo_client.py | 111 +++++++----- pymongo/mongo_replica_set_client.py | 110 +++++++----- pymongo/pool.py | 105 +++++++++-- pymongo/thread_util.py | 129 ++++++++++++- test/test_pooling.py | 11 +- test/test_pooling_base.py | 268 ++++++++++++++++++++++++---- test/test_pooling_gevent.py | 11 +- 10 files changed, 607 insertions(+), 145 deletions(-) diff --git a/.travis.yml b/.travis.yml index e77713247..60528be6d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,3 +11,7 @@ services: - mongodb script: python setup.py test + +install: + #Temporary solution for Travis CI mutiprocessing issue #155 + - sudo rm -rf /dev/shm && sudo ln -s /run/shm /dev/shm diff --git a/doc/contributors.rst b/doc/contributors.rst index c04b352a6..eb4bba8e3 100644 --- a/doc/contributors.rst +++ b/doc/contributors.rst @@ -62,3 +62,4 @@ The following is a list of people who have contributed to - Craig Hobbs (craigahobbs) - Emily Stolfo (estolfo) - Sam Helman (shelman) +- Justin Patrin (reversefold) diff --git a/pymongo/common.py b/pymongo/common.py index 9332ff2f4..a0ee7413a 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -218,6 +218,8 @@ VALIDATORS = { 'journal': validate_boolean, 'connecttimeoutms': validate_timeout_or_none, 'sockettimeoutms': validate_timeout_or_none, + 'waitqueuetimeoutms': validate_timeout_or_none, + 'waitqueuemultiple': validate_positive_integer_or_none, 'ssl': validate_boolean, 'ssl_keyfile': validate_readable, 'ssl_certfile': validate_readable, diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 88a52ee95..e3f1a3697 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -85,8 +85,9 @@ class MongoClient(common.BaseObject): __max_bson_size = 4 * 1024 * 1024 - def __init__(self, host=None, port=None, max_pool_size=10, - document_class=dict, tz_aware=False, _connect=True, **kwargs): + def __init__(self, host=None, port=None, max_pool_size=100, + document_class=dict, tz_aware=False, _connect=True, + **kwargs): """Create a new connection to a single MongoDB instance at *host:port*. The resultant client object has connection-pooling built @@ -120,8 +121,10 @@ class MongoClient(common.BaseObject): it must be enclosed in '[' and ']' characters following the RFC2732 URL syntax (e.g. '[::1]' for localhost) - `port` (optional): port number on which to connect - - `max_pool_size` (optional): The maximum number of idle connections - to keep open in the pool for future use + - `max_pool_size` (optional): The maximum number of connections + that the pool will open simultaneously. If this is set, operations + will block if there are `max_pool_size` outstanding connections + from the pool. Defaults to 100. - `document_class` (optional): default class to use for documents returned from queries on this client - `tz_aware` (optional): if ``True``, @@ -135,6 +138,11 @@ class MongoClient(common.BaseObject): receive on a socket can take before timing out. - `connectTimeoutMS`: (integer) How long (in milliseconds) a connection can take to be opened before timing out. + - `waitQueueTimeoutMS`: (integer) How long (in milliseconds) a + thread will wait for a socket from the pool if the pool has no + free sockets. + - `waitQueueMultiple`: (integer) Multiplied by max_pool_size to give + the number of threads allowed to wait for a socket at one time. - `auto_start_request`: If ``True``, each thread that accesses this :class:`MongoClient` has a socket allocated to it for the thread's lifetime. This ensures consistent reads, even if you @@ -273,6 +281,9 @@ class MongoClient(common.BaseObject): self.__net_timeout = options.get('sockettimeoutms') self.__conn_timeout = options.get('connecttimeoutms') + self.__wait_queue_timeout = options.get('waitqueuetimeoutms') + self.__wait_queue_multiple = options.get('waitqueuemultiple') + self.__use_ssl = options.get('ssl', None) self.__ssl_keyfile = options.get('ssl_keyfile', None) self.__ssl_certfile = options.get('ssl_certfile', None) @@ -313,7 +324,9 @@ class MongoClient(common.BaseObject): ssl_keyfile=self.__ssl_keyfile, ssl_certfile=self.__ssl_certfile, ssl_cert_reqs=self.__ssl_cert_reqs, - ssl_ca_certs=self.__ssl_ca_certs) + ssl_ca_certs=self.__ssl_ca_certs, + wait_queue_timeout=self.__wait_queue_timeout, + wait_queue_multiple=self.__wait_queue_multiple) self.__document_class = document_class self.__tz_aware = common.validate_boolean('tz_aware', tz_aware) @@ -492,14 +505,19 @@ class MongoClient(common.BaseObject): @property def max_pool_size(self): - """The maximum number of idle connections kept open in the pool for - future use. + """The maximum number of sockets the pool will open concurrently. - .. note:: ``max_pool_size`` does not cap the number of concurrent - connections to the server; there is currently no way to limit the - number of connections. ``max_pool_size`` only limits the number of - **idle** connections kept open when they are returned to the pool. + .. warning:: SIGNIFICANT BEHAVIOR CHANGE in 2.5+. Previously, this + parameter would limit only the idle connections the pool would hold + onto, not the number of open sockets. The default has also changed + to 100. + .. note:: ``max_pool_size`` caps the number of concurrent + connections to the server. Connection or query attempts when the pool + has reached `max_pool_size` will block until conn_timeout or a + connection has been returned to the pool. + + .. versionchanged:: 2.5+ .. versionadded:: 1.11 """ return self.__max_pool_size @@ -591,10 +609,12 @@ class MongoClient(common.BaseObject): # Call 'ismaster' directly so we can get a response time. sock_info = self.__socket() - response, res_time = self.__simple_command(sock_info, - 'admin', - {'ismaster': 1}) - self.__pool.maybe_return_socket(sock_info) + try: + response, res_time = self.__simple_command(sock_info, + 'admin', + {'ismaster': 1}) + finally: + self.__pool.maybe_return_socket(sock_info) # Are we talking to a mongos? isdbgrid = response.get('msg', '') == 'isdbgrid' @@ -794,11 +814,15 @@ class MongoClient(common.BaseObject): # calls select() if the socket hasn't been checked in the last second, # or it may create a new socket, in which case calling select() is # redundant. + sock_info = None try: - sock_info = self.__socket() - return not pool._closed(sock_info.sock) - except (socket.error, ConnectionFailure): - return False + try: + sock_info = self.__socket() + return not pool._closed(sock_info.sock) + except (socket.error, ConnectionFailure): + return False + finally: + self.__pool.maybe_return_socket(sock_info) def set_cursor_manager(self, manager_class): """Set this client's cursor manager. @@ -907,29 +931,30 @@ class MongoClient(common.BaseObject): sock_info = self.__socket() try: - (request_id, data) = self.__check_bson_size(message) - sock_info.sock.sendall(data) - # Safe mode. We pack the message together with a lastError - # message and send both. We then get the response (to the - # lastError) and raise OperationFailure if it is an error - # response. - rv = None - if with_last_error: - response = self.__receive_message_on_socket(1, request_id, - sock_info) - rv = self.__check_response_to_last_error(response) + try: + (request_id, data) = self.__check_bson_size(message) + sock_info.sock.sendall(data) + # Safe mode. We pack the message together with a lastError + # message and send both. We then get the response (to the + # lastError) and raise OperationFailure if it is an error + # response. + rv = None + if with_last_error: + response = self.__receive_message_on_socket(1, request_id, + sock_info) + rv = self.__check_response_to_last_error(response) + return rv + except OperationFailure: + raise + except (ConnectionFailure, socket.error), e: + self.disconnect() + raise AutoReconnect(str(e)) + except: + sock_info.close() + raise + finally: self.__pool.maybe_return_socket(sock_info) - return rv - except OperationFailure: - self.__pool.maybe_return_socket(sock_info) - raise - except (ConnectionFailure, socket.error), e: - self.disconnect() - raise AutoReconnect(str(e)) - except: - sock_info.close() - raise def __receive_data_on_socket(self, length, sock_info): """Lowest level receive operation. @@ -953,9 +978,9 @@ class MongoClient(common.BaseObject): """ header = self.__receive_data_on_socket(16, sock_info) length = struct.unpack(" request socket self._tid_to_sock = {} # Count the number of calls to start_request() per thread or greenlet - self._request_counter = thread_util.Counter(use_greenlets) + self._request_counter = thread_util.Counter(self.use_greenlets) + + if self.wait_queue_multiple is None: + max_waiters = None + else: + max_waiters = self.max_size * self.wait_queue_multiple + + if self.max_size is None: + self._socket_semaphore = thread_util.DummySemaphore() + elif self.use_greenlets: + if max_waiters is None: + self._socket_semaphore = gevent.coros.BoundedSemaphore( + self.max_size) + else: + self._socket_semaphore = ( + thread_util.MaxWaitersBoundedSemaphoreGevent( + self.max_size, max_waiters)) + else: + if max_waiters is None: + self._socket_semaphore = thread_util.BoundedSemaphore( + self.max_size) + else: + self._socket_semaphore = ( + thread_util.MaxWaitersBoundedSemaphoreThread( + self.max_size, max_waiters)) def reset(self): # Ignore this race condition -- if many threads are resetting at once, @@ -260,7 +301,7 @@ class Pool: sock.settimeout(self.net_timeout) return SocketInfo(sock, self.pool_id, hostname) - def get_socket(self, pair=None): + def get_socket(self, pair=None, force=False): """Get a socket from the pool. Returns a :class:`SocketInfo` object wrapping a connected @@ -269,6 +310,8 @@ class Pool: :Parameters: - `pair`: optional (hostname, port) tuple + - `force`: optional boolean, forces a connection to be returned + without blocking, even if `max_size` has been reached. """ # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of @@ -280,14 +323,24 @@ class Pool: req_state = self._get_request_state() if req_state not in (NO_SOCKET_YET, NO_REQUEST): # There's a socket for this request, check it and return it - checked_sock = self._check(req_state, pair) + checked_sock = self._check(req_state, pair, acquire_on_connect=True) if checked_sock != req_state: self._set_request_state(checked_sock) checked_sock.last_checkout = time.time() return checked_sock + forced = False # We're not in a request, just get any free socket or create one + if force: + # If we're doing an internal operation, attempt to play nicely with + # max_size, but if there is no open "slot" force the connection + # and mark it as forced so we don't release the semaphore without + # having acquired it for this socket. + if not self._socket_semaphore.acquire(False): + forced = True + elif not self._socket_semaphore.acquire(True, self.wait_queue_timeout): + raise socket.timeout() sock_info, from_pool = None, None try: try: @@ -303,6 +356,8 @@ class Pool: if from_pool: sock_info = self._check(sock_info, pair) + sock_info.forced = forced + if req_state == NO_SOCKET_YET: # start_request has been called but we haven't assigned a socket to # the request yet. Let's use this socket for this request until @@ -350,9 +405,15 @@ class Pool: """Return the socket to the pool unless it's the request socket. """ if self.pid != os.getpid(): + if not sock_info.forced: + self._socket_semaphore.release() self.reset() elif sock_info not in (NO_REQUEST, NO_SOCKET_YET): if sock_info.closed: + if sock_info.forced: + sock_info.forced = False + else: + self._socket_semaphore.release() return if sock_info != self._get_request_state(): @@ -363,14 +424,21 @@ class Pool: """ try: self.lock.acquire() - if len(self.sockets) < self.max_size: + if (len(self.sockets) < self.max_size + and sock_info.pool_id == self.pool_id + ): self.sockets.add(sock_info) else: sock_info.close() finally: self.lock.release() - def _check(self, sock_info, pair): + if sock_info.forced: + sock_info.forced = False + else: + self._socket_semaphore.release() + + def _check(self, sock_info, pair, acquire_on_connect=False): """This side-effecty function checks if this pool has been reset since the last time this socket was used, or if the socket has been closed by some external network error, and if so, attempts to create a new socket. @@ -401,24 +469,28 @@ class Pool: return sock_info else: try: + if acquire_on_connect: + if not self._socket_semaphore.acquire(True, self.wait_queue_timeout): + raise socket.timeout() return self.connect(pair) except socket.error: self.reset() raise def _set_request_state(self, sock_info): - tid = self._ident.get() + ident = self._ident + tid = ident.get() if sock_info == NO_REQUEST: # Ending a request - self._ident.unwatch() + ident.unwatch() self._tid_to_sock.pop(tid, None) else: self._tid_to_sock[tid] = sock_info - if not self._ident.watching(): - # Closure over tid and poolref. Don't refer directly to self, - # otherwise there's a cycle. + if not ident.watching(): + # Closure over tid, poolref, and ident. Don't refer directly to + # self, otherwise there's a cycle. # Do not access threadlocals in this function, or any # function it calls! In the case of the Pool subclass and @@ -430,6 +502,7 @@ class Pool: poolref = weakref.ref(self) def on_thread_died(ref): try: + ident.unwatch(tid) pool = poolref() if pool: # End the request @@ -442,7 +515,7 @@ class Pool: # Random exceptions on interpreter shutdown. pass - self._ident.watch(on_thread_died) + ident.watch(on_thread_died) def _get_request_state(self): tid = self._ident.get() diff --git a/pymongo/thread_util.py b/pymongo/thread_util.py index 4aa20d1c1..dd50adf49 100644 --- a/pymongo/thread_util.py +++ b/pymongo/thread_util.py @@ -17,10 +17,16 @@ import threading import sys import weakref +try: + from time import monotonic as _time +except ImportError: + from time import time as _time have_greenlet = True try: import greenlet + import gevent.coros + import gevent.thread except ImportError: have_greenlet = False @@ -37,8 +43,10 @@ class Ident(object): """Is the current thread or greenlet being watched for death?""" return self.get() in self._refs - def unwatch(self): - self._refs.pop(self.get(), None) + def unwatch(self, tid=None): + if tid is None: + tid = self.get() + self._refs.pop(tid, None) def get(self): """An id for this thread or greenlet""" @@ -93,6 +101,15 @@ class ThreadIdent(Ident): vigil = self._make_vigil() self._refs[id(vigil)] = weakref.ref(vigil, callback) + def watching(self): + """Is the current thread being watched for death?""" + tid = self.get() + if tid not in self._refs: + return False + # Check that the weakref is active, if not the thread has died + # This fixes the case where a thread id gets reused + return self._refs[tid]() + class GreenletIdent(Ident): def get(self): @@ -144,3 +161,111 @@ class Counter(object): def get(self): return self._counters.get(self.ident.get(), 0) + + +### Begin backport from CPython 3.2 for timeout support for Semaphore.acquire +class Semaphore: + + # After Tim Peters' semaphore class, but not quite the same (no maximum) + + def __init__(self, value=1): + if value < 0: + raise ValueError("semaphore initial value must be >= 0") + self._cond = threading.Condition(threading.Lock()) + self._value = value + + def acquire(self, blocking=True, timeout=None): + if not blocking and timeout is not None: + raise ValueError("can't specify timeout for non-blocking acquire") + rc = False + endtime = None + self._cond.acquire() + while self._value == 0: + if not blocking: + break + if timeout is not None: + if endtime is None: + endtime = _time() + timeout + else: + timeout = endtime - _time() + if timeout <= 0: + break + self._cond.wait(timeout) + else: + self._value = self._value - 1 + rc = True + self._cond.release() + return rc + + __enter__ = acquire + + def release(self): + self._cond.acquire() + self._value = self._value + 1 + self._cond.notify() + self._cond.release() + + def __exit__(self, t, v, tb): + self.release() + + @property + def counter(self): + return self._value + + +class BoundedSemaphore(Semaphore): + """Semaphore that checks that # releases is <= # acquires""" + def __init__(self, value=1): + Semaphore.__init__(self, value) + self._initial_value = value + + def release(self): + if self._value >= self._initial_value: + raise ValueError("Semaphore released too many times") + return Semaphore.release(self) +### End backport from CPython 3.2 + + +class DummySemaphore(object): + def __init__(self, value=None): + pass + + def acquire(self, blocking=True, timeout=None): + return True + + def release(self): + pass + + +class ExceededMaxWaiters(Exception): + pass + + +class MaxWaitersBoundedSemaphore(object): + def __init__(self, semaphore_class, value=1, max_waiters=1): + self.waiter_semaphore = semaphore_class(max_waiters) + self.semaphore = semaphore_class(value) + + def acquire(self, blocking=True, timeout=None): + if not self.waiter_semaphore.acquire(False): + raise ExceededMaxWaiters() + try: + return self.semaphore.acquire(blocking, timeout) + finally: + self.waiter_semaphore.release() + + def __getattr__(self, name): + return getattr(self.semaphore, name) + + +class MaxWaitersBoundedSemaphoreThread(MaxWaitersBoundedSemaphore): + def __init__(self, value=1, max_waiters=1): + MaxWaitersBoundedSemaphore.__init__( + self, BoundedSemaphore, value, max_waiters) + + +if have_greenlet: + class MaxWaitersBoundedSemaphoreGevent(MaxWaitersBoundedSemaphore): + def __init__(self, value=1, max_waiters=1): + MaxWaitersBoundedSemaphore.__init__( + self, gevent.coros.BoundedSemaphore, value, max_waiters) diff --git a/test/test_pooling.py b/test/test_pooling.py index 5786baafd..5acea6a85 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -25,7 +25,8 @@ from nose.plugins.skip import SkipTest from test import host, port from test.test_pooling_base import ( - _TestPooling, _TestMaxPoolSize, _TestPoolSocketSharing, one) + _TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets, + _TestPoolSocketSharing, _TestWaitQueueMultiple, one) class TestPoolingThreads(_TestPooling, unittest.TestCase): @@ -172,5 +173,13 @@ class TestPoolSocketSharingThreads(_TestPoolSocketSharing, unittest.TestCase): use_greenlets = False +class TestMaxOpenSocketsThreads(_TestMaxOpenSockets, unittest.TestCase): + use_greenlets = False + + +class TestWaitQueueMultipleThreads(_TestWaitQueueMultiple, unittest.TestCase): + use_greenlets = False + + if __name__ == "__main__": unittest.main() diff --git a/test/test_pooling_base.py b/test/test_pooling_base.py index e40bc6fae..cf46a8c70 100644 --- a/test/test_pooling_base.py +++ b/test/test_pooling_base.py @@ -31,6 +31,7 @@ import pymongo.pool from pymongo.mongo_client import MongoClient from pymongo.pool import Pool, NO_REQUEST, NO_SOCKET_YET, SocketInfo from pymongo.errors import ConfigurationError +from pymongo.thread_util import ExceededMaxWaiters from test import version, host, port from test.test_client import get_client from test.utils import delay, is_mongos, one @@ -46,7 +47,7 @@ if sys.version_info[0] >= 3: try: import gevent from gevent import Greenlet, monkey, hub - import gevent.coros, gevent.event + import gevent.event, gevent.thread has_gevent = True except ImportError: has_gevent = False @@ -210,7 +211,7 @@ class CreateAndReleaseSocket(MongoThread): self.lock = threading.Lock() self.ready = threading.Event() - def __init__(self, ut, client, start_request, end_request, rendevous): + def __init__(self, ut, client, start_request, end_request, rendevous=None): super(CreateAndReleaseSocket, self).__init__(ut) self.client = client self.start_request = start_request @@ -230,16 +231,17 @@ class CreateAndReleaseSocket(MongoThread): # Don't finish until all threads reach this point r = self.rendevous - r.lock.acquire() - r.nthreads_run += 1 - if r.nthreads_run == r.nthreads: - # Everyone's here, let them finish - r.ready.set() - r.lock.release() - else: - r.lock.release() - r.ready.wait(timeout=60) - assert r.ready.isSet(), "Rendezvous timed out" + if r is not None: + r.lock.acquire() + r.nthreads_run += 1 + if r.nthreads_run == r.nthreads: + # Everyone's here, let them finish + r.ready.set() + r.lock.release() + else: + r.lock.release() + r.ready.wait(2) # Wait two seconds + assert r.ready.isSet(), "Rendezvous timed out" for i in range(self.end_request): self.client.end_request() @@ -583,7 +585,6 @@ class _TestPooling(_TestPoolingBase): # Kill old request socket sock_info.sock.close() - cx_pool.maybe_return_socket(sock_info) time.sleep(1.1) # trigger _check_closed # Dead socket detected and removed @@ -673,7 +674,8 @@ class _TestMaxPoolSize(_TestPoolingBase): with greenlets. """ def _test_max_pool_size( - self, start_request, end_request, max_pool_size=4, nthreads=10): + self, start_request, end_request, max_pool_size=4, nthreads=10, + use_rendezvous=True): """Start `nthreads` threads. Each calls start_request `start_request` times, then find_one and waits at a barrier; once all reach the barrier each calls end_request `end_request` times. The test asserts that the @@ -693,13 +695,21 @@ class _TestMaxPoolSize(_TestPoolingBase): c = self.get_client( max_pool_size=max_pool_size, auto_start_request=False) - rendevous = CreateAndReleaseSocket.Rendezvous( - nthreads, self.use_greenlets) + # If you increase nthreads over about 35, note a + # Gevent 0.13.6 bug on Mac, Greenlet.join() hangs if more than + # about 35 Greenlets share a MongoClient. Apparently fixed in + # recent Gevent development. + + if use_rendezvous: + rendezvous = CreateAndReleaseSocket.Rendezvous( + nthreads, self.use_greenlets) + else: + rendezvous = None threads = [] for i in range(nthreads): t = CreateAndReleaseSocket( - self, c, start_request, end_request, rendevous) + self, c, start_request, end_request, rendezvous) threads.append(t) for t in threads: @@ -732,26 +742,30 @@ class _TestMaxPoolSize(_TestPoolingBase): # Gevent 0.13 and less the_hub.shutdown() - if start_request: - # Trigger final cleanup in Python <= 2.7.0. - cx_pool._ident.get() + if use_rendezvous: + if start_request: + # Trigger final cleanup in Python <= 2.7.0. + cx_pool._ident.get() + time.sleep(0.1) + self.assertEqual(10, len(cx_pool.sockets)) + expected_idle = min(max_pool_size, nthreads) + message = ( + '%d idle sockets (expected %d) and %d request sockets' + ' (expected 0)' % ( + len(cx_pool.sockets), expected_idle, + len(cx_pool._tid_to_sock))) - expected_idle = min(max_pool_size, nthreads) - message = ( - '%d idle sockets (expected %d) and %d request sockets' - ' (expected 0)' % ( - len(cx_pool.sockets), expected_idle, - len(cx_pool._tid_to_sock))) - - self.assertEqual( - expected_idle, len(cx_pool.sockets), message) - else: - # Without calling start_request(), threads can safely share - # sockets; the number running concurrently, and hence the number - # of sockets needed, is between 1 and - # min(max_pool_size, nthreads), depending on thread-scheduling. - self.assertTrue(len(cx_pool.sockets) >= 1) + self.assertEqual( + expected_idle, len(cx_pool.sockets), message) + else: + # Without calling start_request(), threads can safely share + # sockets; the number running concurrently, and hence the number + # of sockets needed, is between 1 and 10, depending on thread- + # scheduling. + self.assertTrue(len(cx_pool.sockets) >= 1) + time.sleep(0.1) + self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter) self.assertEqual(0, len(cx_pool._tid_to_sock)) def test_max_pool_size(self): @@ -760,22 +774,206 @@ class _TestMaxPoolSize(_TestPoolingBase): def test_max_pool_size_with_request(self): self._test_max_pool_size(1, 1) + def test_max_pool_size_with_multiple_request(self): + self._test_max_pool_size(10, 10) + def test_max_pool_size_with_redundant_request(self): self._test_max_pool_size(2, 1) def test_max_pool_size_with_redundant_request2(self): self._test_max_pool_size(20, 1) + def test_max_pool_size_with_redundant_request_no_rendezvous(self): + try: + self._test_max_pool_size(2, 1, False) + self._test_max_pool_size(20, 1, False) + except AssertionError: + if sys.version_info[0] == 2 and sys.version_info[1] < 7: + # Python < 2.7 has a threadlocal bug which sometimes leaks + # sockets due to the threadlocal being accessed too close + # to thread death, causing on_thread_died not to get called. + # + # This is fixable with a monitor thread which checks for stale + # request sockets, but this situation is hopefully unlikely to + # happen in the real world. + raise SkipTest('Python < 2.7 threadlocal race condition usually' + ' breaks this test') + else: + raise + def test_max_pool_size_with_leaked_request(self): # Call start_request() but not end_request() -- when threads die, they # should return their request sockets to the pool. self._test_max_pool_size(1, 0) + def test_max_pool_size_with_leaked_request_no_rendezvous(self): + try: + self._test_max_pool_size(1, 0, False) + except AssertionError: + if sys.version_info[0] == 2 and sys.version_info[1] < 7: + # Python < 2.7 has a threadlocal bug which sometimes leaks + # sockets due to the threadlocal being accessed too close + # to thread death, causing on_thread_died not to get called. + # + # This is fixable with a monitor thread which checks for stale + # request sockets, but this situation is hopefully unlikely to + # happen in the real world. + raise SkipTest('Python < 2.7 threadlocal race condition usually' + ' breaks this test') + else: + raise + def test_max_pool_size_with_end_request_only(self): # Call end_request() but not start_request() self._test_max_pool_size(0, 1) +class _TestMaxOpenSockets(_TestPoolingBase): + """Test that connection pool doesn't open more than max_size sockets. + To be run both with threads and with greenlets. + """ + def get_pool(self, conn_timeout, net_timeout, wait_queue_timeout): + return pymongo.pool.Pool(('127.0.0.1', 27017), + 2, net_timeout, conn_timeout, + False, False, + wait_queue_timeout=wait_queue_timeout) + + def test_over_max_times_out(self): + conn_timeout = 2 + pool = self.get_pool(conn_timeout, conn_timeout + 5, conn_timeout) + s1 = pool.get_socket() + self.assertTrue(None is not s1) + s2 = pool.get_socket() + self.assertTrue(None is not s2) + self.assertNotEqual(s1, s2) + start = time.time() + self.assertRaises(socket.timeout, pool.get_socket) + end = time.time() + self.assertTrue(end - start > conn_timeout) + self.assertTrue(end - start < conn_timeout + 5) + + def test_over_max_no_timeout_blocks(self): + class Thread(threading.Thread): + def __init__(self, pool): + super(Thread, self).__init__() + self.state = 'init' + self.pool = pool + self.sock = None + + def run(self): + self.state = 'get_socket' + self.sock = self.pool.get_socket() + self.state = 'sock' + + pool = self.get_pool(None, 2, None) + s1 = pool.get_socket() + self.assertTrue(None is not s1) + s2 = pool.get_socket() + self.assertTrue(None is not s2) + self.assertNotEqual(s1, s2) + t = Thread(pool) + t.start() + while t.state != 'get_socket': + time.sleep(0.1) + self.assertEqual(t.state, 'get_socket') + time.sleep(5) + self.assertEqual(t.state, 'get_socket') + pool.maybe_return_socket(s1) + while t.state != 'sock': + time.sleep(0.1) + self.assertEqual(t.state, 'sock') + self.assertEqual(t.sock, s1) + + +class _TestWaitQueueMultiple(_TestPoolingBase): + """Test that connection pool doesn't allow more than + waitQueueMultiple * max_size waiters. + To be run both with threads and with greenlets. + """ + def get_pool(self, conn_timeout, net_timeout, wait_queue_timeout, + wait_queue_multiple): + return pymongo.pool.Pool(('127.0.0.1', 27017), + 2, net_timeout, conn_timeout, + False, False, + wait_queue_timeout=wait_queue_timeout, + wait_queue_multiple=wait_queue_multiple) + + def test_wait_queue_multiple(self): + class Thread(threading.Thread): + def __init__(self, pool): + super(Thread, self).__init__() + self.state = 'init' + self.pool = pool + self.sock = None + + def run(self): + self.state = 'get_socket' + self.sock = self.pool.get_socket() + self.state = 'sock' + + pool = self.get_pool(None, None, None, 3) + socks = [] + for _ in xrange(2): + sock = pool.get_socket() + self.assertTrue(sock is not None) + socks.append(sock) + threads = [] + for _ in xrange(6): + thread = Thread(pool) + thread.start() + threads.append(thread) + time.sleep(1) + for thread in threads: + self.assertEqual(thread.state, 'get_socket') + self.assertRaises(ExceededMaxWaiters, pool.get_socket) + while threads: + for sock in socks: + pool.maybe_return_socket(sock) + socks = [] + for thread in list(threads): + if thread.sock is not None: + socks.append(thread.sock) + thread.join() + threads.remove(thread) + + def test_wait_queue_multiple_unset(self): + class Thread(threading.Thread): + def __init__(self, pool): + super(Thread, self).__init__() + self.state = 'init' + self.pool = pool + self.sock = None + + def run(self): + self.state = 'get_socket' + self.sock = self.pool.get_socket() + self.state = 'sock' + + pool = self.get_pool(None, None, None, None) + socks = [] + for _ in xrange(2): + sock = pool.get_socket() + self.assertTrue(sock is not None) + socks.append(sock) + threads = [] + for _ in xrange(30): + thread = Thread(pool) + thread.start() + threads.append(thread) + time.sleep(1) + for thread in threads: + self.assertEqual(thread.state, 'get_socket') + while threads: + for sock in socks: + pool.maybe_return_socket(sock) + socks = [] + for thread in list(threads): + if thread.sock is not None: + socks.append(thread.sock) + thread.join() + threads.remove(thread) + + class _TestPoolSocketSharing(_TestPoolingBase): """Directly test that two simultaneous operations don't share a socket. To be run both with threads and with greenlets. diff --git a/test/test_pooling_gevent.py b/test/test_pooling_gevent.py index 2ec85bbfb..f01c7ac0e 100644 --- a/test/test_pooling_gevent.py +++ b/test/test_pooling_gevent.py @@ -22,7 +22,8 @@ from pymongo import pool from test import host, port from test.utils import looplet from test.test_pooling_base import ( - _TestPooling, _TestMaxPoolSize, _TestPoolSocketSharing) + _TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets, + _TestPoolSocketSharing, _TestWaitQueueMultiple) class TestPoolingGevent(_TestPooling, unittest.TestCase): @@ -178,5 +179,13 @@ class TestPoolSocketSharingGevent(_TestPoolSocketSharing, unittest.TestCase): use_greenlets = True +class TestMaxOpenSocketsGevent(_TestMaxOpenSockets, unittest.TestCase): + use_greenlets = True + + +class TestWaitQueueMultipleGevent(_TestWaitQueueMultiple, unittest.TestCase): + use_greenlets = True + + if __name__ == '__main__': unittest.main()