Fix connection leak in mod_wsgi 2.x and Python <= 2.6: Instead of storing SocketInfos in threadlocals, watch for thread death with a weakref callback to a generic threadlocal PYTHON-353

This commit is contained in:
A. Jesse Jiryu Davis 2012-05-14 11:32:51 -04:00
parent 58b3d646bd
commit eeef9179ec
10 changed files with 507 additions and 213 deletions

View File

@ -741,7 +741,7 @@ class Connection(common.BaseObject):
sock_info)
rv = self.__check_response_to_last_error(response)
self.__pool.return_socket(sock_info)
self.__pool.maybe_return_socket(sock_info)
return rv
except (ConnectionFailure, socket.error), e:
self.disconnect()
@ -758,8 +758,7 @@ class Connection(common.BaseObject):
try:
chunk = sock_info.sock.recv(length)
except:
# If recv was interrupted, discard the socket
# and re-raise the exception.
# recv was interrupted
self.__pool.discard_socket(sock_info)
raise
if chunk == EMPTY:
@ -816,12 +815,12 @@ class Connection(common.BaseObject):
# Restore the socket's original timeout and return it to
# the pool
sock_info.sock.settimeout(self.__net_timeout)
self.__pool.return_socket(sock_info)
self.__pool.maybe_return_socket(sock_info)
except socket.error:
# There was an exception and we've closed the socket
pass
else:
self.__pool.return_socket(sock_info)
self.__pool.maybe_return_socket(sock_info)
def start_request(self):
"""Ensure the current thread or greenlet always uses the same socket

View File

@ -15,6 +15,7 @@
import os
import socket
import sys
import thread
import time
import threading
import weakref
@ -63,17 +64,15 @@ def _closed(sock):
class SocketInfo(object):
"""Store a socket with some metadata
"""
def __init__(self, sock, poolref):
def __init__(self, sock, pool_id):
self.sock = sock
# We can't strongly reference the Pool, because the Pool
# references this SocketInfo as long as it's in pool
self.poolref = poolref
self.authset = set()
self.closed = False
self.last_checkout = time.time()
self.pool_id = poolref().pool_id
# The pool's pool_id changes with each reset() so we can close sockets
# created before the last reset.
self.pool_id = pool_id
def close(self):
self.closed = True
@ -83,32 +82,17 @@ class SocketInfo(object):
except:
pass
def __del__(self):
if not self.closed:
# This socket was given out, but not explicitly returned. Perhaps
# the socket was assigned to a thread local for a request, but the
# request wasn't ended before the thread died. Reclaim the socket
# for the pool.
pool = self.poolref()
if pool:
# Return a copy of self rather than self -- the Python docs
# discourage postponing deletion by adding a reference to self.
copy = SocketInfo(self.sock, self.poolref)
copy.authset = self.authset
pool.return_socket(copy)
else:
# Close socket now rather than awaiting garbage collector
self.close()
def __eq__(self, other):
# Need to check if other is NO_REQUEST or NO_SOCKET_YET, and then check
# if its sock is the same as ours
return hasattr(other, 'sock') and self.sock == other.sock
def __hash__(self):
return hash(self.sock)
def __repr__(self):
return "SocketInfo(%s, %s)%s at %s" % (
repr(self.sock), repr(self.poolref()),
return "SocketInfo(%s)%s at %s" % (
repr(self.sock),
self.closed and " CLOSED" or "",
id(self)
)
@ -136,21 +120,22 @@ class BasePool(object):
self.net_timeout = net_timeout
self.conn_timeout = conn_timeout
self.use_ssl = use_ssl
# Map self._get_thread_ident() -> request socket
self._tid_to_sock = {}
# Weakrefs used by subclasses to watch for dead threads or greenlets.
# We must keep a reference to the weakref to keep it alive for at least
# as long as what it references, otherwise its delete-callback won't
# fire.
self._refs = {}
def reset(self):
# Ignore this race condition -- if many threads are resetting at once,
# the pool_id will definitely change, which is all we care about.
self.pool_id += 1
request_state = self._get_request_state()
self.pid = os.getpid()
# Close sockets before deleting them, otherwise they'll come
# running back.
if request_state not in (NO_REQUEST, NO_SOCKET_YET):
# request_state is a SocketInfo for this request
request_state.close()
sockets = None
try:
# Swapping variables is not atomic. We need to ensure no other
@ -163,15 +148,6 @@ class BasePool(object):
for sock_info in sockets: sock_info.close()
# Reset subclass's data structures
self._reset()
# If we were in a request before the reset, then delete the request
# socket, but resume the request with a new socket the next time
# get_socket() is called.
if request_state != NO_REQUEST:
self._set_request_state(NO_SOCKET_YET)
def create_connection(self, pair):
"""Connect to *pair* and return the socket object.
@ -227,7 +203,7 @@ class BasePool(object):
"not be configured with SSL support.")
sock.settimeout(self.net_timeout)
return SocketInfo(sock, weakref.ref(self))
return SocketInfo(sock, self.pool_id)
def get_socket(self, pair=None):
"""Get a socket from the pool.
@ -260,7 +236,7 @@ class BasePool(object):
sock_info, from_pool = None, None
try:
try:
# set.pop() isn't atomic in Jython, see
# set.pop() isn't atomic in Jython less than 2.7, see
# http://bugs.jython.org/issue1854
self.lock.acquire()
sock_info, from_pool = self.sockets.pop(), True
@ -293,20 +269,22 @@ class BasePool(object):
def end_request(self):
sock_info = self._get_request_state()
self._set_request_state(NO_REQUEST)
self.return_socket(sock_info)
if sock_info not in (NO_REQUEST, NO_SOCKET_YET):
self._return_socket(sock_info)
def discard_socket(self, sock_info):
"""Close and discard the active socket.
"""
if sock_info:
if sock_info not in (NO_REQUEST, NO_SOCKET_YET):
sock_info.close()
if sock_info == self._get_request_state():
# Discarding request socket; prepare to use a new request
# socket on next get_socket().
self._set_request_state(NO_SOCKET_YET)
def return_socket(self, sock_info):
"""Return the socket currently in use to the pool. If the
pool is full the socket will be discarded.
def maybe_return_socket(self, sock_info):
"""Return the socket to the pool unless it's the request socket.
"""
if self.pid != os.getpid():
self.reset()
@ -315,24 +293,26 @@ class BasePool(object):
return
if sock_info != self._get_request_state():
added = False
try:
self.lock.acquire()
if len(self.sockets) < self.max_size:
self.sockets.add(sock_info)
added = True
finally:
self.lock.release()
self._return_socket(sock_info)
if not added:
self.discard_socket(sock_info)
def _return_socket(self, sock_info):
"""Return socket to the pool. If pool is full the socket is discarded.
"""
try:
self.lock.acquire()
if len(self.sockets) < self.max_size:
self.sockets.add(sock_info)
else:
sock_info.close()
finally:
self.lock.release()
def _check(self, sock_info, pair):
"""This side-effecty function checks if this pool has been reset since
the last time this socket was used, or if the socket has been closed by
some external network error if it's been > 1 second since the last time
we used it, and if so, attempts to create a new socket. If this
connection attempt fails we reset the pool and reraise the error.
some external network error, and if so, attempts to create a new socket.
If this connection attempt fails we reset the pool and reraise the
error.
Checking sockets lets us avoid seeing *some*
:class:`~pymongo.errors.AutoReconnect` exceptions on server
@ -342,13 +322,16 @@ class BasePool(object):
"""
error = False
if self.pool_id != sock_info.pool_id:
self.discard_socket(sock_info)
if sock_info.closed:
error = True
elif self.pool_id != sock_info.pool_id:
sock_info.close()
error = True
elif time.time() - sock_info.last_checkout > 1:
if _closed(sock_info.sock):
self.discard_socket(sock_info)
sock_info.close()
error = True
if not error:
@ -360,25 +343,64 @@ class BasePool(object):
self.reset()
raise
# Overridable methods for Pools. These methods must simply set and get an
# arbitrary value associated with the execution context (thread, greenlet,
# Tornado StackContext, ...) in which we want to use a single socket.
def _set_request_state(self, sock_info):
raise NotImplementedError
tid = self._get_thread_ident()
if sock_info == NO_REQUEST:
# Ending a request
self._refs.pop(tid, None)
self._tid_to_sock.pop(tid, None)
else:
self._tid_to_sock[tid] = sock_info
if tid not in self._refs:
# Closure over tid and poolref. Don't refer directly to self,
# otherwise there's a cycle.
# Do not access threadlocals in this function, or any
# function it calls! In the case of the Pool subclass and
# mod_wsgi 2.x, on_thread_died() is triggered when mod_wsgi
# calls PyThreadState_Clear(), which deferences the
# ThreadVigil and triggers the weakref callback. Accessing
# thread locals in this function, while PyThreadState_Clear()
# is in progress can cause leaks, see PYTHON-353.
poolref = weakref.ref(self)
def on_thread_died(ref):
try:
pool = poolref()
if pool:
# End the request
pool._refs.pop(tid, None)
request_sock = pool._tid_to_sock.pop(tid, None)
# Was thread ever assigned a socket before it died?
if request_sock not in (NO_REQUEST, NO_SOCKET_YET):
pool._return_socket(request_sock)
except:
# Random exceptions on interpreter shutdown.
pass
self._watch_current_thread(on_thread_died)
def _get_request_state(self):
tid = self._get_thread_ident()
return self._tid_to_sock.get(tid, NO_REQUEST)
# Overridable methods for pools.
def _get_thread_ident(self):
raise NotImplementedError
def _reset(self):
pass
def _watch_current_thread(self, callback):
raise NotImplementedError
def __del__(self):
# Avoid ResourceWarnings in Python 3
for sock_info in self.sockets:
sock_info.close()
# This thread-local will hold a Pool's per-thread request state. sock_info
# defaults to NO_REQUEST each time it's accessed from a new thread. It's
# much simpler to make a separate thread-local class rather than having Pool
# inherit both from BasePool and threading.local.
class _Local(threading.local):
sock_info = NO_REQUEST
for request_sock in self._tid_to_sock.values():
if request_sock not in (NO_REQUEST, NO_SOCKET_YET):
request_sock.close()
class Pool(BasePool):
@ -388,17 +410,45 @@ class Pool(BasePool):
to the pool when the thread calls end_request() or dies.
"""
def __init__(self, *args, **kwargs):
self.local = _Local()
super(Pool, self).__init__(*args, **kwargs)
self._local = threading.local()
def _set_request_state(self, sock_info):
self.local.sock_info = sock_info
# Overrides
def _get_request_state(self):
return self.local.sock_info
# In Python <= 2.6, a dead thread's locals aren't cleaned up until the
# next access. That can lead to a nasty race where a new thread with
# the same ident as a previous one does _get_request_state() and thinks
# it's still in the previous thread's request. Only when some thread
# next accesses self._local.vigil does the dead thread's vigil get
# destroyed, triggered on_thread_died and returning the request socket
# to self.sockets. At that point a different thread can acquire that
# socket, and with two threads using the same socket they'll read
# each other's data. A symptom is an AssertionError in
# Connection.__receive_message_on_socket().
def _reset(self):
self.local.sock_info = NO_REQUEST
# Accessing the thread local here guarantees that a previous thread's
# locals are cleaned up before we check request state, and so even if
# this thread has the same ident as a previous one, we don't think we're
# in the same request.
getattr(self._local, 'vigil', None)
return super(Pool, self)._get_request_state()
def _get_thread_ident(self):
return thread.get_ident()
# After a thread calls start_request() and we assign it a socket, we must
# watch the thread to know if it dies without calling end_request so we can
# return its socket to the idle pool, self.sockets. We watch for
# thread-death using a weakref callback to a thread local. The weakref is
# permitted on subclasses of object but not object() itself, so we make
# this class.
class ThreadVigil(object):
pass
def _watch_current_thread(self, callback):
tid = self._get_thread_ident()
self._local.vigil = vigil = Pool.ThreadVigil()
self._refs[tid] = weakref.ref(vigil, callback)
class GreenletPool(BasePool):
@ -407,52 +457,24 @@ class GreenletPool(BasePool):
Calling start_request() acquires a greenlet-local socket, which is returned
to the pool when the greenlet calls end_request() or dies.
"""
def __init__(self, *args, **kwargs):
self._gr_id_to_sock = {}
# Weakrefs to non-Gevent greenlets
self._refs = {}
super(GreenletPool, self).__init__(*args, **kwargs)
# Overrides
def _set_request_state(self, sock_info):
def _get_thread_ident(self):
return id(greenlet.getcurrent())
def _watch_current_thread(self, callback):
current = greenlet.getcurrent()
gr_id = id(current)
tid = self._get_thread_ident()
if sock_info == NO_REQUEST:
self._refs.pop(gr_id, None)
self._gr_id_to_sock.pop(gr_id, None)
if hasattr(current, 'link'):
# This is a Gevent Greenlet (capital G), which inherits from
# greenlet and provides a 'link' method to detect when the
# Greenlet exits.
current.link(callback)
self._refs[tid] = None
else:
self._gr_id_to_sock[gr_id] = sock_info
def delete_callback(dummy):
# End the request
self._refs.pop(gr_id, None)
request_sock = self._gr_id_to_sock.pop(gr_id, None)
self.return_socket(request_sock)
if gr_id not in self._refs:
if hasattr(current, 'link'):
# This is a Gevent Greenlet (capital G), which inherits from
# greenlet and provides a 'link' method to detect when the
# Greenlet exits
current.link(delete_callback)
self._refs[gr_id] = None
else:
# This is a non-Gevent greenlet (small g), or it's the main
# greenlet. Since there's no link() method, we use a weakref
# to detect when the greenlet is garbage-collected. Garbage-
# collection is a later-firing and less reliable event than
# Greenlet.link() so we prefer link() if available.
self._refs[gr_id] = weakref.ref(current, delete_callback)
def _get_request_state(self):
gr_id = id(greenlet.getcurrent())
return self._gr_id_to_sock.get(gr_id, NO_REQUEST)
def _reset(self):
self._gr_id_to_sock.clear()
self._refs.clear()
# This is a non-Gevent greenlet (small g), or it's the main
# greenlet.
self._refs[tid] = weakref.ref(current, callback)
class Request(object):

View File

@ -533,7 +533,7 @@ class ReplicaSetConnection(common.BaseObject):
sock_info, 'admin', {'ismaster': 1}
)
pool.return_socket(sock_info)
pool.maybe_return_socket(sock_info)
return response, pool
def __update_pools(self):
@ -547,7 +547,7 @@ class ReplicaSetConnection(common.BaseObject):
mongo = self.__pools[host]
sock_info = self.__socket(mongo)
res = self.__simple_command(sock_info, 'admin', {'ismaster': 1})
mongo['pool'].return_socket(sock_info)
mongo['pool'].maybe_return_socket(sock_info)
else:
res, conn = self.__is_master(host)
bson_max = res.get('maxBsonObjectSize', MAX_BSON_SIZE)
@ -555,7 +555,7 @@ class ReplicaSetConnection(common.BaseObject):
'last_checkout': time.time(),
'max_bson_size': bson_max}
except (ConnectionFailure, socket.error):
if mongo and sock_info:
if mongo:
mongo['pool'].discard_socket(sock_info)
continue
# Only use hosts that are currently in 'secondary' state
@ -583,7 +583,7 @@ class ReplicaSetConnection(common.BaseObject):
sock_info = self.__socket(mongo)
response = self.__simple_command(sock_info, 'admin',
{'ismaster': 1})
mongo['pool'].return_socket(sock_info)
mongo['pool'].maybe_return_socket(sock_info)
else:
response, conn = self.__is_master(node)
@ -607,7 +607,7 @@ class ReplicaSetConnection(common.BaseObject):
hosts.update([_partition_node(h)
for h in response["passives"]])
except (ConnectionFailure, socket.error), why:
if mongo and sock_info:
if mongo:
mongo['pool'].discard_socket(sock_info)
errors.append("%s:%d: %s" % (node[0], node[1], str(why)))
if hosts:
@ -638,12 +638,12 @@ class ReplicaSetConnection(common.BaseObject):
'last_checkout': time.time(),
'max_bson_size': bson_max}
except (ConnectionFailure, socket.error), why:
if mongo and sock_info:
if mongo:
mongo['pool'].discard_socket(sock_info)
raise ConnectionFailure("%s:%d: %s" % (host[0], host[1], str(why)))
if mongo and sock_info:
mongo['pool'].return_socket(sock_info)
mongo['pool'].maybe_return_socket(sock_info)
if res["ismaster"]:
return host
@ -814,7 +814,7 @@ class ReplicaSetConnection(common.BaseObject):
if safe:
response = self.__recv_msg(1, rqst_id, sock_info)
rv = self.__check_response_to_last_error(response)
mongo['pool'].return_socket(sock_info)
mongo['pool'].maybe_return_socket(sock_info)
return rv
except(ConnectionFailure, socket.error), why:
mongo['pool'].discard_socket(sock_info)
@ -842,7 +842,7 @@ class ReplicaSetConnection(common.BaseObject):
if "network_timeout" in kwargs:
sock_info.sock.settimeout(self.__net_timeout)
mongo['pool'].return_socket(sock_info)
mongo['pool'].maybe_return_socket(sock_info)
return response
except (ConnectionFailure, socket.error), why:

View File

@ -0,0 +1,27 @@
# Copyright 2012 10gen, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Minimal test of PyMongo in a WSGI application, see bug PYTHON-353
<VirtualHost *>
ServerName localhost
WSGIDaemonProcess mod_wsgi_test processes=1 threads=15 display-name=mod_wsgi_test
WSGIProcessGroup mod_wsgi_test
# For the convienience of unittests, rather than hard-code the location of
# mod_wsgi_test.wsgi, include it in the URL, so
# http://localhost/location-of-pymongo-checkout will work.
WSGIScriptAliasMatch ^(.+) $1/test/mod_wsgi_test/mod_wsgi_test.wsgi
</VirtualHost>

View File

@ -0,0 +1,52 @@
# Copyright 2012 10gen, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Minimal test of PyMongo in a WSGI application, see bug PYTHON-353
"""
import os
import sys
this_path = os.path.dirname(os.path.join(os.getcwd(), __file__))
# Location of PyMongo checkout
repository_path = os.path.normpath(os.path.join(this_path, '..', '..'))
sys.path.insert(0, repository_path)
import pymongo
from pymongo.connection import Connection
connection = Connection()
collection = connection.test.test
ndocs = 20
collection.drop()
collection.insert([{'i': i} for i in range(ndocs)], safe=True)
connection.disconnect() # discard main thread's request socket
try:
from mod_wsgi import version as mod_wsgi_version
except:
mod_wsgi_version = None
def application(environ, start_response):
results = list(collection.find().batch_size(10))
assert len(results) == ndocs
output = 'python %s, mod_wsgi %s, pymongo %s' % (
sys.version, mod_wsgi_version, pymongo.version)
response_headers = [('Content-Length', str(len(output)))]
start_response('200 OK', response_headers)
return [output]

View File

@ -0,0 +1,165 @@
# Copyright 2012 10gen, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test client for mod_wsgi application, see bug PYTHON-353.
"""
import sys
import urllib2
import thread
import threading
import time
from optparse import OptionParser
def parse_args():
parser = OptionParser("""usage: %prog [options] mode url
mode:\tparallel or serial""")
# Should be enough that any collection leak will exhaust available file
# descriptors.
parser.add_option("-n", "--nrequests", type="int",
dest="nrequests", default=50 * 1000,
help="Number of times to GET the URL, in total")
parser.add_option("-t", "--nthreads", type="int",
dest="nthreads", default=100,
help="Number of threads with mode 'parallel'")
parser.add_option("-q", "--quiet",
action="store_false", dest="verbose", default=True,
help="Don't print status messages to stdout")
parser.add_option("-c", "--continue",
action="store_true", dest="continue_", default=False,
help="Continue after HTTP errors")
try:
options, (mode, url) = parser.parse_args()
except ValueError:
parser.print_usage()
sys.exit(1)
if mode not in ('parallel', 'serial'):
parser.print_usage()
sys.exit(1)
return options, mode, url
def get(url):
urllib2.urlopen(url).read().strip()
class URLGetterThread(threading.Thread):
# class variables
counter_lock = threading.Lock()
counter = 0
def __init__(self, options, url, nrequests_per_thread):
super(URLGetterThread, self).__init__()
self.options = options
self.url = url
self.nrequests_per_thread = nrequests_per_thread
self.errors = 0
def run(self):
for i in range(self.nrequests_per_thread):
try:
get(url)
except Exception, e:
print e
if not options.continue_:
thread.interrupt_main()
thread.exit()
self.errors += 1
URLGetterThread.counter_lock.acquire()
URLGetterThread.counter += 1
counter = URLGetterThread.counter
URLGetterThread.counter_lock.release()
should_print = options.verbose and not counter % 1000
if should_print:
print counter
def main(options, mode, url):
start_time = time.time()
errors = 0
if mode == 'parallel':
nrequests_per_thread = options.nrequests / options.nthreads
if options.verbose:
print ('Getting %s %s times total in %s threads, '
'%s times per thread' % (
url, nrequests_per_thread * options.nthreads, options.nthreads,
nrequests_per_thread
))
threads = [
URLGetterThread(options, url, nrequests_per_thread)
for _ in range(options.nthreads)
]
for t in threads:
t.start()
for t in threads:
t.join()
errors = sum([t.errors for t in threads])
nthreads_with_errors = len([t for t in threads if t.errors])
if nthreads_with_errors:
print '%d threads had errors! %d errors in total' % (
nthreads_with_errors, errors)
else:
assert mode == 'serial'
if options.verbose:
print 'Getting %s %s times in one thread' % (
url, options.nrequests
)
for i in range(1, options.nrequests + 1):
try:
get(url)
except Exception, e:
print e
if not options.continue_:
sys.exit(1)
errors += 1
if options.verbose and not i % 1000:
print i
if errors:
print '%d errors!' % errors
if options.verbose:
print 'Completed in %.2f seconds' % (time.time() - start_time)
if errors:
# Failure
sys.exit(1)
if __name__ == '__main__':
options, mode, url = parse_args()
main(options, mode, url)

View File

@ -502,6 +502,8 @@ with get_connection() as connection:
sock_info0 = self.get_sock(pool)
sock_info1 = self.get_sock(pool)
self.assertEqual(sock_info0, sock_info1)
pool.maybe_return_socket(sock_info0)
pool.maybe_return_socket(sock_info1)
def assertDifferentSock(self, pool):
# We have to hold both SocketInfos at the same time, otherwise the
@ -511,6 +513,8 @@ with get_connection() as connection:
sock_info0 = self.get_sock(pool)
sock_info1 = self.get_sock(pool)
self.assertNotEqual(sock_info0, sock_info1)
pool.maybe_return_socket(sock_info0)
pool.maybe_return_socket(sock_info1)
def assertNoRequest(self, pool):
self.assertEqual(NO_REQUEST, pool._get_request_state())

View File

@ -76,9 +76,15 @@ class TestPoolingThreads(_TestPooling, unittest.TestCase):
lock = thread.allocate_lock()
lock.acquire()
sock_ids = []
def run_in_request():
p.start_request()
p.get_socket()
sock0 = p.get_socket()
sock1 = p.get_socket()
sock_ids.extend([id(sock0), id(sock1)])
p.maybe_return_socket(sock0)
p.maybe_return_socket(sock1)
p.end_request()
lock.release()
@ -93,6 +99,7 @@ class TestPoolingThreads(_TestPooling, unittest.TestCase):
break
self.assertTrue(acquired, "Thread is hung")
self.assertEqual(sock_ids[0], sock_ids[1])
def test_pool_with_fork(self):
# Test that separate Connections have separate Pools, and that the
@ -145,8 +152,11 @@ class TestPoolingThreads(_TestPooling, unittest.TestCase):
self.assertTrue(a_sock.sock.getsockname() != b_sock)
self.assertTrue(a_sock.sock.getsockname() != c_sock)
self.assertTrue(b_sock != c_sock)
self.assertEqual(a_sock,
a._Connection__pool.get_socket((a.host, a.port)))
# a_sock, created by parent process, is still in the pool
d_sock = a._Connection__pool.get_socket((a.host, a.port))
self.assertEqual(a_sock, d_sock)
d_sock.close()
class TestMaxPoolSizeThreads(_TestMaxPoolSize, unittest.TestCase):

View File

@ -58,38 +58,6 @@ def one(s):
return iter(s).next()
def force_reclaim_sockets(cx_pool, n_expected):
# When a thread dies without ending its request, the SocketInfo it was
# using is deleted, and in its __del__ it returns the socket to the
# pool. However, when exactly that happens is unpredictable. Try
# various ways of forcing the issue.
if sys.platform.startswith('java'):
raise SkipTest("Jython can't reclaim sockets")
if 'PyPy' in sys.version:
raise SkipTest("Socket reclamation happens at unpredictable time in PyPy")
# Bizarre behavior in CPython 2.4, and possibly other CPython versions
# less than 2.7: the last dead thread's locals aren't cleaned up until
# the local attribute with the same name is accessed from a different
# thread. This assert checks that the thread-local is indeed local, and
# also triggers the cleanup so the socket is reclaimed.
if isinstance(cx_pool, Pool):
assert cx_pool.local.sock_info is None
# Try for a while to make garbage-collection call SocketInfo.__del__
start = time.time()
while len(cx_pool.sockets) < n_expected and time.time() - start < 5:
try:
gc.collect(2)
except TypeError:
# collect() didn't support 'generation' arg until 2.5
gc.collect()
time.sleep(0.5)
class MongoThread(object):
"""A thread, or a greenlet, that uses a Connection"""
def __init__(self, test_case):
@ -379,10 +347,14 @@ class _TestPooling(_TestPoolingBase):
run_cases(self, [SaveAndFind, Disconnect, Unique])
def test_independent_pools(self):
# Test for regression of very early PyMongo bug: separate pools shared
# state.
p = self.get_pool((host, port), 10, None, None, False)
self.c.start_request()
self.c.pymongo_test.test.find_one()
self.assertEqual(set(), p.sockets)
self.c.end_request()
self.assert_pool_size(1)
self.assertEqual(set(), p.sockets)
def test_dependent_pools(self):
@ -437,6 +409,9 @@ class _TestPooling(_TestPoolingBase):
self.assertEqual(a_sock,
a._Connection__pool.get_socket((a.host, a.port)))
a_sock.close()
b_sock.close()
def test_request(self):
# Check that Pool gives two different sockets in two calls to
# get_socket() -- doesn't automatically put us in a request any more
@ -476,6 +451,9 @@ class _TestPooling(_TestPoolingBase):
# end_request() returned sock2 to pool
self.assertEqual(sock4, sock2)
for s in [sock0, sock1, sock2, sock3, sock4, sock5]:
s.close()
def test_reset_and_request(self):
# reset() is called after a fork, or after a socket error. Ensure that
# a new request is begun if a request was in progress when the reset()
@ -495,17 +473,14 @@ class _TestPooling(_TestPoolingBase):
# Test Pool's _check_closed() method doesn't close a healthy socket
cx_pool = self.get_pool((host,port), 10, None, None, False)
sock_info = cx_pool.get_socket()
cx_pool.return_socket(sock_info)
cx_pool.maybe_return_socket(sock_info)
# trigger _check_closed, which only runs on sockets that haven't been
# used in a second
time.sleep(1.1)
new_sock_info = cx_pool.get_socket()
self.assertEqual(sock_info, new_sock_info)
del sock_info, new_sock_info
# Assert sock_info was returned to the pool *once*
force_reclaim_sockets(cx_pool, 1)
cx_pool.maybe_return_socket(new_sock_info)
self.assertEqual(1, len(cx_pool.sockets))
def test_pool_removes_dead_socket(self):
@ -517,15 +492,38 @@ class _TestPooling(_TestPoolingBase):
# Simulate a closed socket without telling the SocketInfo it's closed
sock_info.sock.close()
self.assertTrue(pymongo.pool._closed(sock_info.sock))
cx_pool.return_socket(sock_info)
cx_pool.maybe_return_socket(sock_info)
time.sleep(1.1) # trigger _check_closed
new_sock_info = cx_pool.get_socket()
self.assertEqual(0, len(cx_pool.sockets))
self.assertNotEqual(sock_info, new_sock_info)
del sock_info, new_sock_info
cx_pool.maybe_return_socket(new_sock_info)
self.assertEqual(1, len(cx_pool.sockets))
# new_sock_info returned to the pool, but not the closed sock_info
force_reclaim_sockets(cx_pool, 1)
def test_pool_removes_dead_request_socket_after_1_sec(self):
# Test that Pool keeps request going even if a socket dies in request
cx_pool = self.get_pool((host,port), 10, None, None, False)
cx_pool.start_request()
# Get the request socket
sock_info = cx_pool.get_socket()
self.assertEqual(0, len(cx_pool.sockets))
self.assertEqual(sock_info, cx_pool._get_request_state())
sock_info.sock.close()
cx_pool.maybe_return_socket(sock_info)
time.sleep(1.1) # trigger _check_closed
# Although the request socket died, we're still in a request with a
# new socket
new_sock_info = cx_pool.get_socket()
self.assertTrue(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
cx_pool.maybe_return_socket(new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
self.assertEqual(0, len(cx_pool.sockets))
cx_pool.end_request()
self.assertEqual(1, len(cx_pool.sockets))
def test_pool_removes_dead_request_socket(self):
@ -537,16 +535,19 @@ class _TestPooling(_TestPoolingBase):
sock_info = cx_pool.get_socket()
self.assertEqual(0, len(cx_pool.sockets))
self.assertEqual(sock_info, cx_pool._get_request_state())
sock_info.sock.close()
cx_pool.return_socket(sock_info)
time.sleep(1.1) # trigger _check_closed
# Unlike in test_pool_removes_dead_request_socket_after_1_sec, we
# set sock_info.closed and *don't* wait 1 second
sock_info.close()
cx_pool.maybe_return_socket(sock_info)
# Although the request socket died, we're still in a request with a
# new socket
new_sock_info = cx_pool.get_socket()
self.assertTrue(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
cx_pool.return_socket(new_sock_info)
cx_pool.maybe_return_socket(new_sock_info)
self.assertEqual(new_sock_info, cx_pool._get_request_state())
self.assertEqual(0, len(cx_pool.sockets))
@ -569,17 +570,22 @@ class _TestPooling(_TestPoolingBase):
# Kill old request socket
sock_info.sock.close()
old_sock_info_id = id(sock_info)
del sock_info
cx_pool.maybe_return_socket(sock_info)
time.sleep(1.1) # trigger _check_closed
# Dead socket detected and removed
new_sock_info = cx_pool.get_socket()
self.assertNotEqual(id(new_sock_info), old_sock_info_id)
self.assertFalse(cx_pool.in_request())
self.assertNotEqual(sock_info, new_sock_info)
self.assertEqual(0, len(cx_pool.sockets))
self.assertFalse(pymongo.pool._closed(new_sock_info.sock))
cx_pool.maybe_return_socket(new_sock_info)
self.assertEqual(1, len(cx_pool.sockets))
def test_socket_reclamation(self):
if sys.platform.startswith('java'):
raise SkipTest("Jython can't do socket reclamation")
# Check that if a thread starts a request and dies without ending
# the request, that the socket is reclaimed into the pool.
cx_pool = self.get_pool(
@ -602,6 +608,7 @@ class _TestPooling(_TestPoolingBase):
sock_info = cx_pool.get_socket()
self.assertEqual(sock_info, cx_pool._get_request_state())
the_sock[0] = id(sock_info.sock)
cx_pool.maybe_return_socket(sock_info)
if not self.use_greenlets:
lock.release()
@ -623,7 +630,23 @@ class _TestPooling(_TestPoolingBase):
acquired = lock.acquire()
self.assertTrue(acquired, "Thread is hung")
force_reclaim_sockets(cx_pool, 1)
# Make sure thread is really gone
time.sleep(1)
if 'PyPy' in sys.version:
gc.collect()
# Access the thread local from the main thread to trigger the
# ThreadVigil's delete callback, returning the request socket to
# the pool.
# In Python 2.6 and lesser, a dead thread's locals are deleted
# and those locals' weakref callbacks are fired only when another
# thread accesses the locals and finds the thread state is stale.
# This is more or less a bug in Python <= 2.6. Accessing the thread
# local from the main thread is a necessary part of this test, and
# realistic: in a multithreaded web server a new thread will access
# Pool._local soon after an old thread has died.
getattr(cx_pool._local, 'whatever', None)
# Pool reclaimed the socket
self.assertEqual(1, len(cx_pool.sockets))
@ -637,17 +660,12 @@ class _TestMaxPoolSize(_TestPoolingBase):
"""
def _test_max_pool_size(self, start_request, end_request):
c = self.get_connection(max_pool_size=4, auto_start_request=False)
# If you increase nthreads over about 35, note a
# Gevent 0.13.6 bug on Mac, Greenlet.join() hangs if more than
# about 35 Greenlets share a Connection. Apparently fixed in
# recent Gevent development.
nthreads = 10
if (
self.use_greenlets and sys.platform == 'darwin'
and gevent.version_info[0] < 1
):
# Gevent 0.13.6 bug on Mac, Greenlet.join() hangs if more than
# about 35 Greenlets share a Connection. Apparently fixed in
# recent Gevent development.
nthreads = 30
threads = []
for i in range(nthreads):
t = CreateAndReleaseSocket(self, c, start_request, end_request)
@ -662,19 +680,16 @@ class _TestMaxPoolSize(_TestPoolingBase):
for t in threads:
self.assertTrue(t.passed)
# Critical: release refs to threads, so SocketInfo.__del__() executes
# and reclaims sockets.
del threads
t = None
cx_pool = c._Connection__pool
force_reclaim_sockets(cx_pool, 4)
nsock = len(cx_pool.sockets)
# Socket-reclamation depends on timely garbage-collection, so be lenient
self.assertTrue(2 <= nsock <= 4,
msg="Expected between 2 and 4 sockets in the pool, got %d" % nsock)
# Socket-reclamation depends on timely garbage-collection
if 'PyPy' in sys.version:
import gc
gc.collect()
if not sys.platform.startswith('java'):
self.assertEqual(4, len(cx_pool.sockets))
def test_max_pool_size(self):
self._test_max_pool_size(0, 0)

View File

@ -116,7 +116,7 @@ class TestPoolingGeventSpecial(unittest.TestCase):
for _ in range(2):
sock = cx_pool.get_socket()
cx_pool.return_socket(sock)
cx_pool.maybe_return_socket(sock)
greenlet2socks.setdefault(
greenlet.getcurrent(), []
).append(id(sock))