PYTHON-1272 Fix deadlock when garbage collecting pinned cursors and sessions (#642)

It's not safe to return the pinned connection to the pool from within
Cursor.del because the Pool's lock may be held by a python thread
while the cyclic garbage collector runs. Instead we send the cursor
cleanup request to the client's background thread. The thread will
send killCursors on the pinned socket and then return the socket to
the pool.
Also fixed a similar bug when garbage collecting a pinned session.
This commit is contained in:
Shane Harvey 2021-06-22 17:29:26 -07:00 committed by GitHub
parent 3ef01179a2
commit 6bc5e088af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 219 additions and 67 deletions

View File

@ -289,7 +289,7 @@ class _TxnState(object):
class _Transaction(object):
"""Internal class to hold transaction information in a ClientSession."""
def __init__(self, opts):
def __init__(self, opts, client):
self.opts = opts
self.state = _TxnState.NONE
self.sharded = False
@ -297,6 +297,7 @@ class _Transaction(object):
self.sock_mgr = None
self.recovery_token = None
self.attempt = 0
self.client = client
def active(self):
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@ -330,6 +331,13 @@ class _Transaction(object):
self.recovery_token = None
self.attempt = 0
def __del__(self):
if self.sock_mgr:
# Reuse the cursor closing machinery to return the socket to the
# pool soon.
self.client._close_cursor_soon(0, None, self.sock_mgr)
self.sock_mgr = None
def _reraise_with_unknown_commit(exc):
"""Re-raise an exception with the UnknownTransactionCommitResult label."""
@ -382,7 +390,7 @@ class ClientSession(object):
self._operation_time = None
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None)
self._transaction = _Transaction(None, client)
def end_session(self):
"""Finish this session. If a transaction has started, abort it.

View File

@ -64,8 +64,7 @@ class CommandCursor(object):
raise TypeError("max_await_time_ms must be an integer or None")
def __del__(self):
if self.__id and not self.__killed:
self.__die()
self.__die()
def __die(self, synchronous=False):
"""Closes this cursor.
@ -73,20 +72,23 @@ class CommandCursor(object):
already_killed = self.__killed
self.__killed = True
if self.__id and not already_killed:
cursor_id = self.__id
address = _CursorAddress(
self.__address, self.__collection.full_name)
if synchronous:
self.__collection.database.client._close_cursor_now(
self.__id, address, session=self.__session,
sock_mgr=self.__sock_mgr)
else:
# The cursor will be closed later in a different session.
self.__collection.database.client._close_cursor(
self.__id, address)
if self.__sock_mgr:
self.__sock_mgr.close()
self.__sock_mgr = None
self.__end_session(synchronous)
else:
# Skip killCursors.
cursor_id = 0
address = None
self.__collection.database.client._cleanup_cursor(
synchronous,
cursor_id,
address,
self.__sock_mgr,
self.__session,
self.__explicit_session)
if not self.__explicit_session:
self.__session = None
self.__sock_mgr = None
def __end_session(self, synchronous):
if self.__session and not self.__explicit_session:
@ -185,7 +187,7 @@ class CommandCursor(object):
self.__id = response.data.cursor_id
if self.__id == 0:
self.__die(True)
self.close()
self.__data = deque(documents)
def _unpack_response(self, response, cursor_id, codec_options,

View File

@ -34,7 +34,6 @@ from pymongo.message import (_CursorAddress,
_RawBatchGetMore,
_Query,
_RawBatchQuery)
from pymongo.monitoring import ConnectionClosedReason
from pymongo.response import PinnedResponse
# These errors mean that the server has already killed the cursor so there is
@ -106,28 +105,23 @@ class CursorType(object):
"""
# This has to be an old style class due to
# http://bugs.jython.org/issue1057
class _SocketManager:
class _SocketManager(object):
"""Used with exhaust cursors to ensure the socket is returned.
"""
def __init__(self, sock, more_to_come):
self.sock = sock
self.more_to_come = more_to_come
self.__closed = False
self.closed = False
self.lock = threading.Lock()
def __del__(self):
self.close()
def update_exhaust(self, more_to_come):
self.more_to_come = more_to_come
def close(self):
"""Return this instance's socket to the connection pool.
"""
if not self.__closed:
self.__closed = True
if not self.closed:
self.closed = True
self.sock.unpin()
self.sock = None
@ -156,6 +150,7 @@ class Cursor(object):
"""
# Initialize all attributes used in __del__ before possibly raising
# an error to avoid attribute errors during garbage collection.
self.__collection = collection
self.__id = None
self.__exhaust = False
self.__sock_mgr = None
@ -208,7 +203,6 @@ class Cursor(object):
projection = {"_id": 1}
projection = helpers._fields_list_to_dict(projection, "projection")
self.__collection = collection
self.__spec = spec
self.__projection = projection
self.__skip = skip
@ -293,6 +287,7 @@ class Cursor(object):
be sent to the server, even if the resultant data has already been
retrieved by this cursor.
"""
self.close()
self.__data = deque()
self.__id = None
self.__address = None
@ -349,29 +344,23 @@ class Cursor(object):
self.__killed = True
if self.__id and not already_killed:
if self.__exhaust and self.__sock_mgr:
# If this is an exhaust cursor and we haven't completely
# exhausted the result set we *must* close the socket
# to stop the server from sending more data.
self.__sock_mgr.sock.close_socket(
ConnectionClosedReason.ERROR)
else:
address = _CursorAddress(
self.__address, self.__collection.full_name)
if synchronous:
self.__collection.database.client._close_cursor_now(
self.__id, address, session=self.__session,
sock_mgr=self.__sock_mgr)
else:
# The cursor will be closed later in a different session.
self.__collection.database.client._close_cursor(
self.__id, address)
if self.__sock_mgr:
self.__sock_mgr.close()
self.__sock_mgr = None
if self.__session and not self.__explicit_session:
self.__session._end_session(lock=synchronous)
cursor_id = self.__id
address = _CursorAddress(
self.__address, self.__collection.full_name)
else:
# Skip killCursors.
cursor_id = 0
address = None
self.__collection.database.client._cleanup_cursor(
synchronous,
cursor_id,
address,
self.__sock_mgr,
self.__session,
self.__explicit_session)
if not self.__explicit_session:
self.__session = None
self.__sock_mgr = None
def close(self):
"""Explicitly close / kill this cursor.
@ -1094,10 +1083,10 @@ class Cursor(object):
if self.__id == 0:
# Don't wait for garbage collection to call __del__, return the
# socket and the session to the pool now.
self.__die()
self.close()
if self.__limit and self.__id and self.__limit <= self.__retrieved:
self.__die()
self.close()
def _unpack_response(self, response, cursor_id, codec_options,
user_fields=None, legacy_response=False):

View File

@ -1483,14 +1483,46 @@ class MongoClient(common.BaseObject):
"""
return database.Database(self, name)
def _close_cursor(self, cursor_id, address):
"""Send a kill cursors message with the given id.
def _cleanup_cursor(self, locks_allowed, cursor_id, address, sock_mgr,
session, explicit_session):
"""Cleanup a cursor from cursor.close() or __del__.
What closing the cursor actually means depends on this client's
cursor manager. If there is none, the cursor is closed asynchronously
on a background thread.
This method handles cleanup for Cursors/CommandCursors including any
pinned connection or implicit session attached at the time the cursor
was closed or garbage collected.
:Parameters:
- `locks_allowed`: True if we are allowed to acquire locks.
- `cursor_id`: The cursor id which may be 0.
- `address`: The _CursorAddress.
- `sock_mgr`: The _SocketManager for the pinned connection or None.
- `session`: The cursor's session.
- `explicit_session`: True if the session was passed explicitly.
"""
self.__kill_cursors_queue.append((address, [cursor_id]))
if locks_allowed:
if cursor_id:
if sock_mgr and sock_mgr.more_to_come:
# If this is an exhaust cursor and we haven't completely
# exhausted the result set we *must* close the socket
# to stop the server from sending more data.
sock_mgr.sock.close_socket(
ConnectionClosedReason.ERROR)
else:
self._close_cursor_now(
cursor_id, address, session=session,
sock_mgr=sock_mgr)
if sock_mgr:
sock_mgr.close()
else:
# The cursor will be closed later in a different session.
if cursor_id or sock_mgr:
self._close_cursor_soon(cursor_id, address, sock_mgr)
if session and not explicit_session:
session._end_session(lock=locks_allowed)
def _close_cursor_soon(self, cursor_id, address, sock_mgr=None):
"""Request that a cursor and/or connection be cleaned up soon."""
self.__kill_cursors_queue.append((address, cursor_id, sock_mgr))
def _close_cursor_now(self, cursor_id, address=None, session=None,
sock_mgr=None):
@ -1512,7 +1544,7 @@ class MongoClient(common.BaseObject):
[cursor_id], address, self._get_topology(), session)
except PyMongoError:
# Make another attempt to kill the cursor later.
self.__kill_cursors_queue.append((address, [cursor_id]))
self._close_cursor_soon(cursor_id, address)
def _kill_cursors(self, cursor_ids, address, topology, session):
"""Send a kill cursors message with the given ids."""
@ -1577,15 +1609,26 @@ class MongoClient(common.BaseObject):
def _process_kill_cursors(self):
"""Process any pending kill cursors requests."""
address_to_cursor_ids = defaultdict(list)
pinned_cursors = []
# Other threads or the GC may append to the queue concurrently.
while True:
try:
address, cursor_ids = self.__kill_cursors_queue.pop()
address, cursor_id, sock_mgr = self.__kill_cursors_queue.pop()
except IndexError:
break
address_to_cursor_ids[address].extend(cursor_ids)
if sock_mgr:
pinned_cursors.append((address, cursor_id, sock_mgr))
else:
address_to_cursor_ids[address].append(cursor_id)
for address, cursor_id, sock_mgr in pinned_cursors:
try:
self._cleanup_cursor(True, cursor_id, address, sock_mgr,
None, False)
except Exception:
helpers._handle_exception()
# Don't re-open topology if it's closed and there's no pending cursors.
if address_to_cursor_ids:

View File

@ -185,8 +185,8 @@ class client_knobs(object):
def __del__(self):
if self._enabled:
print(
'\nERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, '
msg = (
'ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, '
'MIN_HEARTBEAT_INTERVAL=%s, KILL_CURSOR_FREQUENCY=%s, '
'EVENTS_QUEUE_FREQUENCY=%s, stack:\n%s' % (
common.HEARTBEAT_FREQUENCY,
@ -195,6 +195,7 @@ class client_knobs(object):
common.EVENTS_QUEUE_FREQUENCY,
self._stack))
self.disable()
raise Exception(msg)
def _all_users(db):

View File

@ -34,6 +34,7 @@ from pymongo.errors import (ConfigurationError,
from pymongo.read_preferences import ReadPreference
import gridfs
from gridfs.errors import CorruptGridFile, FileExists, NoFile
from gridfs.grid_file import GridOutCursor
from test import (client_context,
unittest,
IntegrationTest)
@ -445,6 +446,14 @@ class TestGridfs(IntegrationTest):
cursor.close()
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
def test_delete_not_initialized(self):
# Creating a cursor with invalid arguments will not run __init__
# but will still call __del__.
cursor = GridOutCursor.__new__(GridOutCursor) # Skip calling __init__
with self.assertRaises(TypeError):
cursor.__init__(self.db.fs.files, {}, {"_id": True})
cursor.__del__() # no error
def test_gridfs_find_one(self):
self.assertEqual(None, self.fs.find_one())

View File

@ -14,13 +14,15 @@
"""Test the Load Balancer unified spec tests."""
import gc
import os
import sys
import threading
sys.path[0:0] = [""]
from test import unittest, IntegrationTest, client_context
from test.utils import get_pool
from test.utils import get_pool, wait_until, ExceptionCatchingThread
from test.unified_format import generate_test_classes
# Location of JSON test specifications.
@ -56,6 +58,106 @@ class TestLB(IntegrationTest):
self.client.close()
self.db.test.find_one({})
@client_context.require_failCommand_fail_point
def test_cursor_gc(self):
def create_resource(coll):
cursor = coll.find({}, batch_size=3)
next(cursor)
return cursor
self._test_no_gc_deadlock(create_resource)
@client_context.require_failCommand_fail_point
def test_command_cursor_gc(self):
def create_resource(coll):
cursor = coll.aggregate([], batchSize=3)
next(cursor)
return cursor
self._test_no_gc_deadlock(create_resource)
def _test_no_gc_deadlock(self, create_resource):
pool = get_pool(self.client)
self.assertEqual(pool.active_sockets, 0)
self.db.test.insert_many([{} for _ in range(10)])
# Cause the initial find attempt to fail to induce a reference cycle.
args = {
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find", "aggregate"
],
"errorCode": 91,
"closeConnection": True,
}
}
with self.fail_point(args):
resource = create_resource(self.db.test)
if client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.
thread = PoolLocker(pool)
thread.start()
self.assertTrue(thread.locked.wait(5), 'timed out')
# Garbage collect the resource while the pool is locked to ensure we
# don't deadlock.
del resource
gc.collect()
thread.unlock.set()
thread.join(5)
self.assertFalse(thread.is_alive())
self.assertIsNone(thread.exc)
wait_until(lambda: pool.active_sockets == 0, 'return socket')
# Run another operation to ensure the socket still works.
self.db.test.delete_many({})
@client_context.require_transactions
def test_session_gc(self):
pool = get_pool(self.client)
self.assertEqual(pool.active_sockets, 0)
session = self.client.start_session()
session.start_transaction()
self.client.test_session_gc.test.find_one({}, session=session)
if client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.
thread = PoolLocker(pool)
thread.start()
self.assertTrue(thread.locked.wait(5), 'timed out')
# Garbage collect the session while the pool is locked to ensure we
# don't deadlock.
del session
# On PyPy it can take a few rounds to collect the session.
for _ in range(3):
gc.collect()
thread.unlock.set()
thread.join(5)
self.assertFalse(thread.is_alive())
self.assertIsNone(thread.exc)
wait_until(lambda: pool.active_sockets == 0, 'return socket')
# Run another operation to ensure the socket still works.
self.db.test.delete_many({})
class PoolLocker(ExceptionCatchingThread):
def __init__(self, pool):
super(PoolLocker, self).__init__(target=self.lock_pool)
self.pool = pool
self.daemon = True
self.locked = threading.Event()
self.unlock = threading.Event()
def lock_pool(self):
with self.pool.lock:
self.locked.set()
# Wait for the unlock flag.
unlock_pool = self.unlock.wait(10)
if not unlock_pool:
raise Exception('timed out waiting for unlock signal:'
' deadlock?')
if __name__ == "__main__":
unittest.main()

View File

@ -53,13 +53,11 @@ class TestClientOptions(PyMongoTestCase):
class TestSpec(SpecRunner):
@classmethod
@client_context.require_version_min(4, 0)
@client_context.require_failCommand_fail_point
# TODO: remove this once PYTHON-1948 is done.
@client_context.require_no_mmap
def setUpClass(cls):
super(TestSpec, cls).setUpClass()
if client_context.is_mongos and client_context.version[:2] <= (4, 0):
raise unittest.SkipTest("4.0 mongos does not support failCommand")
def maybe_skip_scenario(self, test):
super(TestSpec, self).maybe_skip_scenario(test)