PYTHON-436 Change max_pool_size to limit the maximum concurrent connections rather than just the idle connections in the pool. Also add support for waitQueueTimeoutMS and waitQueueMultiple.

This commit is contained in:
Justin Patrin 2013-04-15 17:35:45 -07:00 committed by A. Jesse Jiryu Davis
parent ba8f71cf50
commit 8fe3154138
10 changed files with 607 additions and 145 deletions

View File

@ -11,3 +11,7 @@ services:
- mongodb
script: python setup.py test
install:
#Temporary solution for Travis CI mutiprocessing issue #155
- sudo rm -rf /dev/shm && sudo ln -s /run/shm /dev/shm

View File

@ -62,3 +62,4 @@ The following is a list of people who have contributed to
- Craig Hobbs (craigahobbs)
- Emily Stolfo (estolfo)
- Sam Helman (shelman)
- Justin Patrin (reversefold)

View File

@ -218,6 +218,8 @@ VALIDATORS = {
'journal': validate_boolean,
'connecttimeoutms': validate_timeout_or_none,
'sockettimeoutms': validate_timeout_or_none,
'waitqueuetimeoutms': validate_timeout_or_none,
'waitqueuemultiple': validate_positive_integer_or_none,
'ssl': validate_boolean,
'ssl_keyfile': validate_readable,
'ssl_certfile': validate_readable,

View File

@ -85,8 +85,9 @@ class MongoClient(common.BaseObject):
__max_bson_size = 4 * 1024 * 1024
def __init__(self, host=None, port=None, max_pool_size=10,
document_class=dict, tz_aware=False, _connect=True, **kwargs):
def __init__(self, host=None, port=None, max_pool_size=100,
document_class=dict, tz_aware=False, _connect=True,
**kwargs):
"""Create a new connection to a single MongoDB instance at *host:port*.
The resultant client object has connection-pooling built
@ -120,8 +121,10 @@ class MongoClient(common.BaseObject):
it must be enclosed in '[' and ']' characters following
the RFC2732 URL syntax (e.g. '[::1]' for localhost)
- `port` (optional): port number on which to connect
- `max_pool_size` (optional): The maximum number of idle connections
to keep open in the pool for future use
- `max_pool_size` (optional): The maximum number of connections
that the pool will open simultaneously. If this is set, operations
will block if there are `max_pool_size` outstanding connections
from the pool. Defaults to 100.
- `document_class` (optional): default class to use for
documents returned from queries on this client
- `tz_aware` (optional): if ``True``,
@ -135,6 +138,11 @@ class MongoClient(common.BaseObject):
receive on a socket can take before timing out.
- `connectTimeoutMS`: (integer) How long (in milliseconds) a
connection can take to be opened before timing out.
- `waitQueueTimeoutMS`: (integer) How long (in milliseconds) a
thread will wait for a socket from the pool if the pool has no
free sockets.
- `waitQueueMultiple`: (integer) Multiplied by max_pool_size to give
the number of threads allowed to wait for a socket at one time.
- `auto_start_request`: If ``True``, each thread that accesses
this :class:`MongoClient` has a socket allocated to it for the
thread's lifetime. This ensures consistent reads, even if you
@ -273,6 +281,9 @@ class MongoClient(common.BaseObject):
self.__net_timeout = options.get('sockettimeoutms')
self.__conn_timeout = options.get('connecttimeoutms')
self.__wait_queue_timeout = options.get('waitqueuetimeoutms')
self.__wait_queue_multiple = options.get('waitqueuemultiple')
self.__use_ssl = options.get('ssl', None)
self.__ssl_keyfile = options.get('ssl_keyfile', None)
self.__ssl_certfile = options.get('ssl_certfile', None)
@ -313,7 +324,9 @@ class MongoClient(common.BaseObject):
ssl_keyfile=self.__ssl_keyfile,
ssl_certfile=self.__ssl_certfile,
ssl_cert_reqs=self.__ssl_cert_reqs,
ssl_ca_certs=self.__ssl_ca_certs)
ssl_ca_certs=self.__ssl_ca_certs,
wait_queue_timeout=self.__wait_queue_timeout,
wait_queue_multiple=self.__wait_queue_multiple)
self.__document_class = document_class
self.__tz_aware = common.validate_boolean('tz_aware', tz_aware)
@ -492,14 +505,19 @@ class MongoClient(common.BaseObject):
@property
def max_pool_size(self):
"""The maximum number of idle connections kept open in the pool for
future use.
"""The maximum number of sockets the pool will open concurrently.
.. note:: ``max_pool_size`` does not cap the number of concurrent
connections to the server; there is currently no way to limit the
number of connections. ``max_pool_size`` only limits the number of
**idle** connections kept open when they are returned to the pool.
.. warning:: SIGNIFICANT BEHAVIOR CHANGE in 2.5+. Previously, this
parameter would limit only the idle connections the pool would hold
onto, not the number of open sockets. The default has also changed
to 100.
.. note:: ``max_pool_size`` caps the number of concurrent
connections to the server. Connection or query attempts when the pool
has reached `max_pool_size` will block until conn_timeout or a
connection has been returned to the pool.
.. versionchanged:: 2.5+
.. versionadded:: 1.11
"""
return self.__max_pool_size
@ -591,10 +609,12 @@ class MongoClient(common.BaseObject):
# Call 'ismaster' directly so we can get a response time.
sock_info = self.__socket()
response, res_time = self.__simple_command(sock_info,
'admin',
{'ismaster': 1})
self.__pool.maybe_return_socket(sock_info)
try:
response, res_time = self.__simple_command(sock_info,
'admin',
{'ismaster': 1})
finally:
self.__pool.maybe_return_socket(sock_info)
# Are we talking to a mongos?
isdbgrid = response.get('msg', '') == 'isdbgrid'
@ -794,11 +814,15 @@ class MongoClient(common.BaseObject):
# calls select() if the socket hasn't been checked in the last second,
# or it may create a new socket, in which case calling select() is
# redundant.
sock_info = None
try:
sock_info = self.__socket()
return not pool._closed(sock_info.sock)
except (socket.error, ConnectionFailure):
return False
try:
sock_info = self.__socket()
return not pool._closed(sock_info.sock)
except (socket.error, ConnectionFailure):
return False
finally:
self.__pool.maybe_return_socket(sock_info)
def set_cursor_manager(self, manager_class):
"""Set this client's cursor manager.
@ -907,29 +931,30 @@ class MongoClient(common.BaseObject):
sock_info = self.__socket()
try:
(request_id, data) = self.__check_bson_size(message)
sock_info.sock.sendall(data)
# Safe mode. We pack the message together with a lastError
# message and send both. We then get the response (to the
# lastError) and raise OperationFailure if it is an error
# response.
rv = None
if with_last_error:
response = self.__receive_message_on_socket(1, request_id,
sock_info)
rv = self.__check_response_to_last_error(response)
try:
(request_id, data) = self.__check_bson_size(message)
sock_info.sock.sendall(data)
# Safe mode. We pack the message together with a lastError
# message and send both. We then get the response (to the
# lastError) and raise OperationFailure if it is an error
# response.
rv = None
if with_last_error:
response = self.__receive_message_on_socket(1, request_id,
sock_info)
rv = self.__check_response_to_last_error(response)
return rv
except OperationFailure:
raise
except (ConnectionFailure, socket.error), e:
self.disconnect()
raise AutoReconnect(str(e))
except:
sock_info.close()
raise
finally:
self.__pool.maybe_return_socket(sock_info)
return rv
except OperationFailure:
self.__pool.maybe_return_socket(sock_info)
raise
except (ConnectionFailure, socket.error), e:
self.disconnect()
raise AutoReconnect(str(e))
except:
sock_info.close()
raise
def __receive_data_on_socket(self, length, sock_info):
"""Lowest level receive operation.
@ -953,9 +978,9 @@ class MongoClient(common.BaseObject):
"""
header = self.__receive_data_on_socket(16, sock_info)
length = struct.unpack("<i", header[:4])[0]
assert request_id == struct.unpack("<i", header[8:12])[0], \
"ids don't match %r %r" % (request_id,
struct.unpack("<i", header[8:12])[0])
msg_req_id = struct.unpack("<i", header[8:12])[0]
assert request_id == msg_req_id, \
"ids don't match %r %r" % (request_id, msg_req_id)
assert operation == struct.unpack("<i", header[12:])[0]
return self.__receive_data_on_socket(length - 16, sock_info)

View File

@ -302,7 +302,7 @@ class Monitor(object):
try:
try:
self.rsc.refresh()
self.rsc.refresh(force=True)
finally:
self.refreshed.set()
except AutoReconnect:
@ -1026,7 +1026,7 @@ class MongoReplicaSetClient(common.BaseObject):
else:
return threading.local()
def refresh(self):
def refresh(self, force=False):
"""Iterate through the existing host list, or possibly the
seed list, to update the list of hosts and arbiters in this
replica set.
@ -1055,7 +1055,7 @@ class MongoReplicaSetClient(common.BaseObject):
member, sock_info = rs_state.get(node), None
try:
if member:
sock_info = self.__socket(member)
sock_info = self.__socket(member, force=force)
response, ping_time = self.__simple_command(
sock_info, 'admin', {'ismaster': 1})
member.pool.maybe_return_socket(sock_info)
@ -1112,7 +1112,7 @@ class MongoReplicaSetClient(common.BaseObject):
member, sock_info = rs_state.get(host), None
try:
if member:
sock_info = self.__socket(member)
sock_info = self.__socket(member, force=force)
res, ping_time = self.__simple_command(
sock_info, 'admin', {'ismaster': 1})
member.pool.maybe_return_socket(sock_info)
@ -1163,13 +1163,13 @@ class MongoReplicaSetClient(common.BaseObject):
# Couldn't find the primary.
raise AutoReconnect(rs_state.error_message)
def __socket(self, member):
def __socket(self, member, force=False):
"""Get a SocketInfo from the pool.
"""
if self.auto_start_request and not self.in_request():
self.start_request()
sock_info = member.pool.get_socket()
sock_info = member.pool.get_socket(force=force)
try:
self.__check_auth(sock_info)
@ -1236,11 +1236,17 @@ class MongoReplicaSetClient(common.BaseObject):
# calls select() if the socket hasn't been checked in the last second,
# or it may create a new socket, in which case calling select() is
# redundant.
member, sock_info = None, None
try:
sock_info = self.__socket(self.__find_primary())
return not pool._closed(sock_info.sock)
except (socket.error, ConnectionFailure):
return False
try:
member = self.__find_primary()
sock_info = self.__socket(member)
return not pool._closed(sock_info.sock)
except (socket.error, ConnectionFailure):
return False
finally:
if sock_info is not None:
member.pool.maybe_return_socket(sock_info)
def __check_response_to_last_error(self, response):
"""Check a response to a lastError message for errors.
@ -1347,30 +1353,32 @@ class MongoReplicaSetClient(common.BaseObject):
sock_info = None
try:
sock_info = self.__socket(member)
rqst_id, data = self.__check_bson_size(msg, member.max_bson_size)
sock_info.sock.sendall(data)
# Safe mode. We pack the message together with a lastError
# message and send both. We then get the response (to the
# lastError) and raise OperationFailure if it is an error
# response.
rv = None
if with_last_error:
response = self.__recv_msg(1, rqst_id, sock_info)
rv = self.__check_response_to_last_error(response)
member.pool.maybe_return_socket(sock_info)
return rv
except OperationFailure:
member.pool.maybe_return_socket(sock_info)
raise
except(ConnectionFailure, socket.error), why:
member.pool.discard_socket(sock_info)
if _connection_to_use in (None, -1):
self.disconnect()
raise AutoReconnect(str(why))
except:
sock_info.close()
raise
try:
sock_info = self.__socket(member)
rqst_id, data = self.__check_bson_size(msg, member.max_bson_size)
sock_info.sock.sendall(data)
# Safe mode. We pack the message together with a lastError
# message and send both. We then get the response (to the
# lastError) and raise OperationFailure if it is an error
# response.
rv = None
if with_last_error:
response = self.__recv_msg(1, rqst_id, sock_info)
rv = self.__check_response_to_last_error(response)
return rv
except OperationFailure:
raise
except(ConnectionFailure, socket.error), why:
member.pool.discard_socket(sock_info)
if _connection_to_use in (None, -1):
self.disconnect()
raise AutoReconnect(str(why))
except:
sock_info.close()
raise
finally:
if sock_info is not None:
member.pool.maybe_return_socket(sock_info)
def __send_and_receive(self, member, msg, **kwargs):
"""Send a message on the given socket and return the response data.
@ -1379,23 +1387,31 @@ class MongoReplicaSetClient(common.BaseObject):
"""
sock_info = None
try:
sock_info = self.__socket(member)
try:
sock_info = self.__socket(member)
if "network_timeout" in kwargs:
sock_info.sock.settimeout(kwargs['network_timeout'])
if "network_timeout" in kwargs:
sock_info.sock.settimeout(kwargs['network_timeout'])
rqst_id, data = self.__check_bson_size(msg, member.max_bson_size)
sock_info.sock.sendall(data)
response = self.__recv_msg(1, rqst_id, sock_info)
rqst_id, data = self.__check_bson_size(msg, member.max_bson_size)
sock_info.sock.sendall(data)
response = self.__recv_msg(1, rqst_id, sock_info)
if "network_timeout" in kwargs:
sock_info.sock.settimeout(self.__net_timeout)
member.pool.maybe_return_socket(sock_info)
if "network_timeout" in kwargs:
sock_info.sock.settimeout(self.__net_timeout)
return response
except:
member.pool.discard_socket(sock_info)
raise
return response
except (ConnectionFailure, socket.error), why:
host, port = member.pool.pair
member.pool.discard_socket(sock_info)
raise AutoReconnect("%s:%d: %s" % (host, port, str(why)))
except:
if sock_info is not None:
sock_info.close()
raise
finally:
if sock_info is not None:
member.pool.maybe_return_socket(sock_info)
def __try_read(self, member, msg, **kwargs):
"""Attempt a read from a member; on failure mark the member "down" and

View File

@ -21,8 +21,7 @@ import weakref
from pymongo import thread_util
from pymongo.common import HAS_SSL
from pymongo.errors import (CertificateError, ConnectionFailure,
ConfigurationError)
from pymongo.errors import ConnectionFailure, ConfigurationError
try:
from ssl import match_hostname
@ -37,6 +36,11 @@ if sys.platform.startswith('java'):
else:
from select import select
try:
import gevent.coros
except ImportError:
pass
NO_REQUEST = None
NO_SOCKET_YET = -1
@ -62,6 +66,7 @@ class SocketInfo(object):
self.authset = set()
self.closed = False
self.last_checkout = time.time()
self.forced = False
# The pool's pool_id changes with each reset() so we can close sockets
# created before the last reset.
@ -99,11 +104,15 @@ class SocketInfo(object):
class Pool:
def __init__(self, pair, max_size, net_timeout, conn_timeout, use_ssl,
use_greenlets, ssl_keyfile=None, ssl_certfile=None,
ssl_cert_reqs=None, ssl_ca_certs=None):
ssl_cert_reqs=None, ssl_ca_certs=None,
wait_queue_timeout=None, wait_queue_multiple=None):
"""
:Parameters:
- `pair`: a (hostname, port) tuple
- `max_size`: approximate number of idle connections to keep open
- `max_size`: The maximum number of open sockets. Calls to
`get_socket` will block if this is set, this pool has opened
`max_size` sockets, and there are none idle. Set to `None` to
disable.
- `net_timeout`: timeout in seconds for operations on open connection
- `conn_timeout`: timeout in seconds for establishing connection
- `use_ssl`: bool, if True use an encrypted connection
@ -127,6 +136,11 @@ class Pool:
"certification authority" certificates, which are used to validate
certificates passed from the other end of the connection.
Implies ``ssl=True``.
- `wait_queue_timeout`: (integer) How long (in milliseconds) a
thread will wait for a socket from the pool if the pool has no
free sockets.
- `wait_queue_multiple`: (integer) Multiplied by max_pool_size to give
the number of threads allowed to wait for a socket at one time.
"""
if use_greenlets and not thread_util.have_greenlet:
raise ConfigurationError(
@ -134,6 +148,7 @@ class Pool:
"Install the greenlet package from PyPI."
)
self.use_greenlets = use_greenlets
self.sockets = set()
self.lock = threading.Lock()
@ -145,6 +160,8 @@ class Pool:
self.max_size = max_size
self.net_timeout = net_timeout
self.conn_timeout = conn_timeout
self.wait_queue_timeout = wait_queue_timeout
self.wait_queue_multiple = wait_queue_multiple
self.use_ssl = use_ssl
self.ssl_keyfile = ssl_keyfile
self.ssl_certfile = ssl_certfile
@ -154,13 +171,37 @@ class Pool:
if HAS_SSL and use_ssl and not ssl_cert_reqs:
self.ssl_cert_reqs = ssl.CERT_NONE
self._ident = thread_util.create_ident(use_greenlets)
self._ident = thread_util.create_ident(self.use_greenlets)
# Map self._ident.get() -> request socket
self._tid_to_sock = {}
# Count the number of calls to start_request() per thread or greenlet
self._request_counter = thread_util.Counter(use_greenlets)
self._request_counter = thread_util.Counter(self.use_greenlets)
if self.wait_queue_multiple is None:
max_waiters = None
else:
max_waiters = self.max_size * self.wait_queue_multiple
if self.max_size is None:
self._socket_semaphore = thread_util.DummySemaphore()
elif self.use_greenlets:
if max_waiters is None:
self._socket_semaphore = gevent.coros.BoundedSemaphore(
self.max_size)
else:
self._socket_semaphore = (
thread_util.MaxWaitersBoundedSemaphoreGevent(
self.max_size, max_waiters))
else:
if max_waiters is None:
self._socket_semaphore = thread_util.BoundedSemaphore(
self.max_size)
else:
self._socket_semaphore = (
thread_util.MaxWaitersBoundedSemaphoreThread(
self.max_size, max_waiters))
def reset(self):
# Ignore this race condition -- if many threads are resetting at once,
@ -260,7 +301,7 @@ class Pool:
sock.settimeout(self.net_timeout)
return SocketInfo(sock, self.pool_id, hostname)
def get_socket(self, pair=None):
def get_socket(self, pair=None, force=False):
"""Get a socket from the pool.
Returns a :class:`SocketInfo` object wrapping a connected
@ -269,6 +310,8 @@ class Pool:
:Parameters:
- `pair`: optional (hostname, port) tuple
- `force`: optional boolean, forces a connection to be returned
without blocking, even if `max_size` has been reached.
"""
# We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of
@ -280,14 +323,24 @@ class Pool:
req_state = self._get_request_state()
if req_state not in (NO_SOCKET_YET, NO_REQUEST):
# There's a socket for this request, check it and return it
checked_sock = self._check(req_state, pair)
checked_sock = self._check(req_state, pair, acquire_on_connect=True)
if checked_sock != req_state:
self._set_request_state(checked_sock)
checked_sock.last_checkout = time.time()
return checked_sock
forced = False
# We're not in a request, just get any free socket or create one
if force:
# If we're doing an internal operation, attempt to play nicely with
# max_size, but if there is no open "slot" force the connection
# and mark it as forced so we don't release the semaphore without
# having acquired it for this socket.
if not self._socket_semaphore.acquire(False):
forced = True
elif not self._socket_semaphore.acquire(True, self.wait_queue_timeout):
raise socket.timeout()
sock_info, from_pool = None, None
try:
try:
@ -303,6 +356,8 @@ class Pool:
if from_pool:
sock_info = self._check(sock_info, pair)
sock_info.forced = forced
if req_state == NO_SOCKET_YET:
# start_request has been called but we haven't assigned a socket to
# the request yet. Let's use this socket for this request until
@ -350,9 +405,15 @@ class Pool:
"""Return the socket to the pool unless it's the request socket.
"""
if self.pid != os.getpid():
if not sock_info.forced:
self._socket_semaphore.release()
self.reset()
elif sock_info not in (NO_REQUEST, NO_SOCKET_YET):
if sock_info.closed:
if sock_info.forced:
sock_info.forced = False
else:
self._socket_semaphore.release()
return
if sock_info != self._get_request_state():
@ -363,14 +424,21 @@ class Pool:
"""
try:
self.lock.acquire()
if len(self.sockets) < self.max_size:
if (len(self.sockets) < self.max_size
and sock_info.pool_id == self.pool_id
):
self.sockets.add(sock_info)
else:
sock_info.close()
finally:
self.lock.release()
def _check(self, sock_info, pair):
if sock_info.forced:
sock_info.forced = False
else:
self._socket_semaphore.release()
def _check(self, sock_info, pair, acquire_on_connect=False):
"""This side-effecty function checks if this pool has been reset since
the last time this socket was used, or if the socket has been closed by
some external network error, and if so, attempts to create a new socket.
@ -401,24 +469,28 @@ class Pool:
return sock_info
else:
try:
if acquire_on_connect:
if not self._socket_semaphore.acquire(True, self.wait_queue_timeout):
raise socket.timeout()
return self.connect(pair)
except socket.error:
self.reset()
raise
def _set_request_state(self, sock_info):
tid = self._ident.get()
ident = self._ident
tid = ident.get()
if sock_info == NO_REQUEST:
# Ending a request
self._ident.unwatch()
ident.unwatch()
self._tid_to_sock.pop(tid, None)
else:
self._tid_to_sock[tid] = sock_info
if not self._ident.watching():
# Closure over tid and poolref. Don't refer directly to self,
# otherwise there's a cycle.
if not ident.watching():
# Closure over tid, poolref, and ident. Don't refer directly to
# self, otherwise there's a cycle.
# Do not access threadlocals in this function, or any
# function it calls! In the case of the Pool subclass and
@ -430,6 +502,7 @@ class Pool:
poolref = weakref.ref(self)
def on_thread_died(ref):
try:
ident.unwatch(tid)
pool = poolref()
if pool:
# End the request
@ -442,7 +515,7 @@ class Pool:
# Random exceptions on interpreter shutdown.
pass
self._ident.watch(on_thread_died)
ident.watch(on_thread_died)
def _get_request_state(self):
tid = self._ident.get()

View File

@ -17,10 +17,16 @@
import threading
import sys
import weakref
try:
from time import monotonic as _time
except ImportError:
from time import time as _time
have_greenlet = True
try:
import greenlet
import gevent.coros
import gevent.thread
except ImportError:
have_greenlet = False
@ -37,8 +43,10 @@ class Ident(object):
"""Is the current thread or greenlet being watched for death?"""
return self.get() in self._refs
def unwatch(self):
self._refs.pop(self.get(), None)
def unwatch(self, tid=None):
if tid is None:
tid = self.get()
self._refs.pop(tid, None)
def get(self):
"""An id for this thread or greenlet"""
@ -93,6 +101,15 @@ class ThreadIdent(Ident):
vigil = self._make_vigil()
self._refs[id(vigil)] = weakref.ref(vigil, callback)
def watching(self):
"""Is the current thread being watched for death?"""
tid = self.get()
if tid not in self._refs:
return False
# Check that the weakref is active, if not the thread has died
# This fixes the case where a thread id gets reused
return self._refs[tid]()
class GreenletIdent(Ident):
def get(self):
@ -144,3 +161,111 @@ class Counter(object):
def get(self):
return self._counters.get(self.ident.get(), 0)
### Begin backport from CPython 3.2 for timeout support for Semaphore.acquire
class Semaphore:
# After Tim Peters' semaphore class, but not quite the same (no maximum)
def __init__(self, value=1):
if value < 0:
raise ValueError("semaphore initial value must be >= 0")
self._cond = threading.Condition(threading.Lock())
self._value = value
def acquire(self, blocking=True, timeout=None):
if not blocking and timeout is not None:
raise ValueError("can't specify timeout for non-blocking acquire")
rc = False
endtime = None
self._cond.acquire()
while self._value == 0:
if not blocking:
break
if timeout is not None:
if endtime is None:
endtime = _time() + timeout
else:
timeout = endtime - _time()
if timeout <= 0:
break
self._cond.wait(timeout)
else:
self._value = self._value - 1
rc = True
self._cond.release()
return rc
__enter__ = acquire
def release(self):
self._cond.acquire()
self._value = self._value + 1
self._cond.notify()
self._cond.release()
def __exit__(self, t, v, tb):
self.release()
@property
def counter(self):
return self._value
class BoundedSemaphore(Semaphore):
"""Semaphore that checks that # releases is <= # acquires"""
def __init__(self, value=1):
Semaphore.__init__(self, value)
self._initial_value = value
def release(self):
if self._value >= self._initial_value:
raise ValueError("Semaphore released too many times")
return Semaphore.release(self)
### End backport from CPython 3.2
class DummySemaphore(object):
def __init__(self, value=None):
pass
def acquire(self, blocking=True, timeout=None):
return True
def release(self):
pass
class ExceededMaxWaiters(Exception):
pass
class MaxWaitersBoundedSemaphore(object):
def __init__(self, semaphore_class, value=1, max_waiters=1):
self.waiter_semaphore = semaphore_class(max_waiters)
self.semaphore = semaphore_class(value)
def acquire(self, blocking=True, timeout=None):
if not self.waiter_semaphore.acquire(False):
raise ExceededMaxWaiters()
try:
return self.semaphore.acquire(blocking, timeout)
finally:
self.waiter_semaphore.release()
def __getattr__(self, name):
return getattr(self.semaphore, name)
class MaxWaitersBoundedSemaphoreThread(MaxWaitersBoundedSemaphore):
def __init__(self, value=1, max_waiters=1):
MaxWaitersBoundedSemaphore.__init__(
self, BoundedSemaphore, value, max_waiters)
if have_greenlet:
class MaxWaitersBoundedSemaphoreGevent(MaxWaitersBoundedSemaphore):
def __init__(self, value=1, max_waiters=1):
MaxWaitersBoundedSemaphore.__init__(
self, gevent.coros.BoundedSemaphore, value, max_waiters)

View File

@ -25,7 +25,8 @@ from nose.plugins.skip import SkipTest
from test import host, port
from test.test_pooling_base import (
_TestPooling, _TestMaxPoolSize, _TestPoolSocketSharing, one)
_TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets,
_TestPoolSocketSharing, _TestWaitQueueMultiple, one)
class TestPoolingThreads(_TestPooling, unittest.TestCase):
@ -172,5 +173,13 @@ class TestPoolSocketSharingThreads(_TestPoolSocketSharing, unittest.TestCase):
use_greenlets = False
class TestMaxOpenSocketsThreads(_TestMaxOpenSockets, unittest.TestCase):
use_greenlets = False
class TestWaitQueueMultipleThreads(_TestWaitQueueMultiple, unittest.TestCase):
use_greenlets = False
if __name__ == "__main__":
unittest.main()

View File

@ -31,6 +31,7 @@ import pymongo.pool
from pymongo.mongo_client import MongoClient
from pymongo.pool import Pool, NO_REQUEST, NO_SOCKET_YET, SocketInfo
from pymongo.errors import ConfigurationError
from pymongo.thread_util import ExceededMaxWaiters
from test import version, host, port
from test.test_client import get_client
from test.utils import delay, is_mongos, one
@ -46,7 +47,7 @@ if sys.version_info[0] >= 3:
try:
import gevent
from gevent import Greenlet, monkey, hub
import gevent.coros, gevent.event
import gevent.event, gevent.thread
has_gevent = True
except ImportError:
has_gevent = False
@ -210,7 +211,7 @@ class CreateAndReleaseSocket(MongoThread):
self.lock = threading.Lock()
self.ready = threading.Event()
def __init__(self, ut, client, start_request, end_request, rendevous):
def __init__(self, ut, client, start_request, end_request, rendevous=None):
super(CreateAndReleaseSocket, self).__init__(ut)
self.client = client
self.start_request = start_request
@ -230,16 +231,17 @@ class CreateAndReleaseSocket(MongoThread):
# Don't finish until all threads reach this point
r = self.rendevous
r.lock.acquire()
r.nthreads_run += 1
if r.nthreads_run == r.nthreads:
# Everyone's here, let them finish
r.ready.set()
r.lock.release()
else:
r.lock.release()
r.ready.wait(timeout=60)
assert r.ready.isSet(), "Rendezvous timed out"
if r is not None:
r.lock.acquire()
r.nthreads_run += 1
if r.nthreads_run == r.nthreads:
# Everyone's here, let them finish
r.ready.set()
r.lock.release()
else:
r.lock.release()
r.ready.wait(2) # Wait two seconds
assert r.ready.isSet(), "Rendezvous timed out"
for i in range(self.end_request):
self.client.end_request()
@ -583,7 +585,6 @@ class _TestPooling(_TestPoolingBase):
# Kill old request socket
sock_info.sock.close()
cx_pool.maybe_return_socket(sock_info)
time.sleep(1.1) # trigger _check_closed
# Dead socket detected and removed
@ -673,7 +674,8 @@ class _TestMaxPoolSize(_TestPoolingBase):
with greenlets.
"""
def _test_max_pool_size(
self, start_request, end_request, max_pool_size=4, nthreads=10):
self, start_request, end_request, max_pool_size=4, nthreads=10,
use_rendezvous=True):
"""Start `nthreads` threads. Each calls start_request `start_request`
times, then find_one and waits at a barrier; once all reach the barrier
each calls end_request `end_request` times. The test asserts that the
@ -693,13 +695,21 @@ class _TestMaxPoolSize(_TestPoolingBase):
c = self.get_client(
max_pool_size=max_pool_size, auto_start_request=False)
rendevous = CreateAndReleaseSocket.Rendezvous(
nthreads, self.use_greenlets)
# If you increase nthreads over about 35, note a
# Gevent 0.13.6 bug on Mac, Greenlet.join() hangs if more than
# about 35 Greenlets share a MongoClient. Apparently fixed in
# recent Gevent development.
if use_rendezvous:
rendezvous = CreateAndReleaseSocket.Rendezvous(
nthreads, self.use_greenlets)
else:
rendezvous = None
threads = []
for i in range(nthreads):
t = CreateAndReleaseSocket(
self, c, start_request, end_request, rendevous)
self, c, start_request, end_request, rendezvous)
threads.append(t)
for t in threads:
@ -732,26 +742,30 @@ class _TestMaxPoolSize(_TestPoolingBase):
# Gevent 0.13 and less
the_hub.shutdown()
if start_request:
# Trigger final cleanup in Python <= 2.7.0.
cx_pool._ident.get()
if use_rendezvous:
if start_request:
# Trigger final cleanup in Python <= 2.7.0.
cx_pool._ident.get()
time.sleep(0.1)
self.assertEqual(10, len(cx_pool.sockets))
expected_idle = min(max_pool_size, nthreads)
message = (
'%d idle sockets (expected %d) and %d request sockets'
' (expected 0)' % (
len(cx_pool.sockets), expected_idle,
len(cx_pool._tid_to_sock)))
expected_idle = min(max_pool_size, nthreads)
message = (
'%d idle sockets (expected %d) and %d request sockets'
' (expected 0)' % (
len(cx_pool.sockets), expected_idle,
len(cx_pool._tid_to_sock)))
self.assertEqual(
expected_idle, len(cx_pool.sockets), message)
else:
# Without calling start_request(), threads can safely share
# sockets; the number running concurrently, and hence the number
# of sockets needed, is between 1 and
# min(max_pool_size, nthreads), depending on thread-scheduling.
self.assertTrue(len(cx_pool.sockets) >= 1)
self.assertEqual(
expected_idle, len(cx_pool.sockets), message)
else:
# Without calling start_request(), threads can safely share
# sockets; the number running concurrently, and hence the number
# of sockets needed, is between 1 and 10, depending on thread-
# scheduling.
self.assertTrue(len(cx_pool.sockets) >= 1)
time.sleep(0.1)
self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter)
self.assertEqual(0, len(cx_pool._tid_to_sock))
def test_max_pool_size(self):
@ -760,22 +774,206 @@ class _TestMaxPoolSize(_TestPoolingBase):
def test_max_pool_size_with_request(self):
self._test_max_pool_size(1, 1)
def test_max_pool_size_with_multiple_request(self):
self._test_max_pool_size(10, 10)
def test_max_pool_size_with_redundant_request(self):
self._test_max_pool_size(2, 1)
def test_max_pool_size_with_redundant_request2(self):
self._test_max_pool_size(20, 1)
def test_max_pool_size_with_redundant_request_no_rendezvous(self):
try:
self._test_max_pool_size(2, 1, False)
self._test_max_pool_size(20, 1, False)
except AssertionError:
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
# Python < 2.7 has a threadlocal bug which sometimes leaks
# sockets due to the threadlocal being accessed too close
# to thread death, causing on_thread_died not to get called.
#
# This is fixable with a monitor thread which checks for stale
# request sockets, but this situation is hopefully unlikely to
# happen in the real world.
raise SkipTest('Python < 2.7 threadlocal race condition usually'
' breaks this test')
else:
raise
def test_max_pool_size_with_leaked_request(self):
# Call start_request() but not end_request() -- when threads die, they
# should return their request sockets to the pool.
self._test_max_pool_size(1, 0)
def test_max_pool_size_with_leaked_request_no_rendezvous(self):
try:
self._test_max_pool_size(1, 0, False)
except AssertionError:
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
# Python < 2.7 has a threadlocal bug which sometimes leaks
# sockets due to the threadlocal being accessed too close
# to thread death, causing on_thread_died not to get called.
#
# This is fixable with a monitor thread which checks for stale
# request sockets, but this situation is hopefully unlikely to
# happen in the real world.
raise SkipTest('Python < 2.7 threadlocal race condition usually'
' breaks this test')
else:
raise
def test_max_pool_size_with_end_request_only(self):
# Call end_request() but not start_request()
self._test_max_pool_size(0, 1)
class _TestMaxOpenSockets(_TestPoolingBase):
"""Test that connection pool doesn't open more than max_size sockets.
To be run both with threads and with greenlets.
"""
def get_pool(self, conn_timeout, net_timeout, wait_queue_timeout):
return pymongo.pool.Pool(('127.0.0.1', 27017),
2, net_timeout, conn_timeout,
False, False,
wait_queue_timeout=wait_queue_timeout)
def test_over_max_times_out(self):
conn_timeout = 2
pool = self.get_pool(conn_timeout, conn_timeout + 5, conn_timeout)
s1 = pool.get_socket()
self.assertTrue(None is not s1)
s2 = pool.get_socket()
self.assertTrue(None is not s2)
self.assertNotEqual(s1, s2)
start = time.time()
self.assertRaises(socket.timeout, pool.get_socket)
end = time.time()
self.assertTrue(end - start > conn_timeout)
self.assertTrue(end - start < conn_timeout + 5)
def test_over_max_no_timeout_blocks(self):
class Thread(threading.Thread):
def __init__(self, pool):
super(Thread, self).__init__()
self.state = 'init'
self.pool = pool
self.sock = None
def run(self):
self.state = 'get_socket'
self.sock = self.pool.get_socket()
self.state = 'sock'
pool = self.get_pool(None, 2, None)
s1 = pool.get_socket()
self.assertTrue(None is not s1)
s2 = pool.get_socket()
self.assertTrue(None is not s2)
self.assertNotEqual(s1, s2)
t = Thread(pool)
t.start()
while t.state != 'get_socket':
time.sleep(0.1)
self.assertEqual(t.state, 'get_socket')
time.sleep(5)
self.assertEqual(t.state, 'get_socket')
pool.maybe_return_socket(s1)
while t.state != 'sock':
time.sleep(0.1)
self.assertEqual(t.state, 'sock')
self.assertEqual(t.sock, s1)
class _TestWaitQueueMultiple(_TestPoolingBase):
"""Test that connection pool doesn't allow more than
waitQueueMultiple * max_size waiters.
To be run both with threads and with greenlets.
"""
def get_pool(self, conn_timeout, net_timeout, wait_queue_timeout,
wait_queue_multiple):
return pymongo.pool.Pool(('127.0.0.1', 27017),
2, net_timeout, conn_timeout,
False, False,
wait_queue_timeout=wait_queue_timeout,
wait_queue_multiple=wait_queue_multiple)
def test_wait_queue_multiple(self):
class Thread(threading.Thread):
def __init__(self, pool):
super(Thread, self).__init__()
self.state = 'init'
self.pool = pool
self.sock = None
def run(self):
self.state = 'get_socket'
self.sock = self.pool.get_socket()
self.state = 'sock'
pool = self.get_pool(None, None, None, 3)
socks = []
for _ in xrange(2):
sock = pool.get_socket()
self.assertTrue(sock is not None)
socks.append(sock)
threads = []
for _ in xrange(6):
thread = Thread(pool)
thread.start()
threads.append(thread)
time.sleep(1)
for thread in threads:
self.assertEqual(thread.state, 'get_socket')
self.assertRaises(ExceededMaxWaiters, pool.get_socket)
while threads:
for sock in socks:
pool.maybe_return_socket(sock)
socks = []
for thread in list(threads):
if thread.sock is not None:
socks.append(thread.sock)
thread.join()
threads.remove(thread)
def test_wait_queue_multiple_unset(self):
class Thread(threading.Thread):
def __init__(self, pool):
super(Thread, self).__init__()
self.state = 'init'
self.pool = pool
self.sock = None
def run(self):
self.state = 'get_socket'
self.sock = self.pool.get_socket()
self.state = 'sock'
pool = self.get_pool(None, None, None, None)
socks = []
for _ in xrange(2):
sock = pool.get_socket()
self.assertTrue(sock is not None)
socks.append(sock)
threads = []
for _ in xrange(30):
thread = Thread(pool)
thread.start()
threads.append(thread)
time.sleep(1)
for thread in threads:
self.assertEqual(thread.state, 'get_socket')
while threads:
for sock in socks:
pool.maybe_return_socket(sock)
socks = []
for thread in list(threads):
if thread.sock is not None:
socks.append(thread.sock)
thread.join()
threads.remove(thread)
class _TestPoolSocketSharing(_TestPoolingBase):
"""Directly test that two simultaneous operations don't share a socket. To
be run both with threads and with greenlets.

View File

@ -22,7 +22,8 @@ from pymongo import pool
from test import host, port
from test.utils import looplet
from test.test_pooling_base import (
_TestPooling, _TestMaxPoolSize, _TestPoolSocketSharing)
_TestPooling, _TestMaxPoolSize, _TestMaxOpenSockets,
_TestPoolSocketSharing, _TestWaitQueueMultiple)
class TestPoolingGevent(_TestPooling, unittest.TestCase):
@ -178,5 +179,13 @@ class TestPoolSocketSharingGevent(_TestPoolSocketSharing, unittest.TestCase):
use_greenlets = True
class TestMaxOpenSocketsGevent(_TestMaxOpenSockets, unittest.TestCase):
use_greenlets = True
class TestWaitQueueMultipleGevent(_TestWaitQueueMultiple, unittest.TestCase):
use_greenlets = True
if __name__ == '__main__':
unittest.main()