From 9ad421a58aee2dac47c728566b2b49b683e8885e Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Sat, 19 Jul 2014 12:03:01 -0400 Subject: [PATCH] PYTHON-736 Fix exhaust cursor error-handling. Connection-pool semaphore leak on server error when creating or iterating an exhaust cursor. --- pymongo/cursor.py | 20 ++++- pymongo/mongo_client.py | 34 +++++--- pymongo/mongo_replica_set_client.py | 7 +- test/test_client.py | 6 ++ test/test_replica_set_client.py | 9 +- test/utils.py | 128 +++++++++++++++++++++++++++- 6 files changed, 186 insertions(+), 18 deletions(-) diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 908b02c3f..b5059afba 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -56,6 +56,15 @@ class _SocketManager: self.pool.maybe_return_socket(self.sock) self.sock, self.pool = None, None + def error(self): + """Clean up after an error on the managed socket. + """ + if self.sock: + self.sock.close() + + # Return the closed socket to avoid a semaphore leak in the pool. + self.close() + # TODO might be cool to be able to do find().include("foo") or # find().exclude(["bar", "baz"]) or find().slice("a", 1, 2) as an @@ -914,8 +923,14 @@ class Cursor(object): # due to a socket timeout. self.__killed = True raise - else: # exhaust cursor - no getMore message - response = client._exhaust_next(self.__exhaust_mgr.sock) + else: + # Exhaust cursor - no getMore message. + try: + response = client._exhaust_next(self.__exhaust_mgr.sock) + except: + self.__killed = True + self.__exhaust_mgr.error() + raise try: response = helpers._unpack_response(response, self.__id, @@ -938,6 +953,7 @@ class Cursor(object): self.__killed = True client.disconnect() raise + self.__id = response["cursor_id"] # starting from doesn't get set on getmore's for tailable cursors diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 3075ad887..1c4f2e451 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1192,27 +1192,35 @@ class MongoClient(common.BaseObject): sock_info = self.__socket(member) exhaust = kwargs.get('exhaust') try: - try: - if not exhaust and "network_timeout" in kwargs: - sock_info.sock.settimeout(kwargs["network_timeout"]) - response = self.__send_and_receive(message, sock_info) + if not exhaust and "network_timeout" in kwargs: + sock_info.sock.settimeout(kwargs["network_timeout"]) - if not exhaust: - if "network_timeout" in kwargs: - sock_info.sock.settimeout(self.__net_timeout) + response = self.__send_and_receive(message, sock_info) - return (None, (response, sock_info, member.pool)) - except (ConnectionFailure, socket.error), e: - self.disconnect() - raise AutoReconnect(str(e)) - finally: if not exhaust: + if "network_timeout" in kwargs: + sock_info.sock.settimeout(self.__net_timeout) + member.pool.maybe_return_socket(sock_info) + return (None, (response, sock_info, member.pool)) + except (ConnectionFailure, socket.error), e: + self.disconnect() + member.pool.maybe_return_socket(sock_info) + raise AutoReconnect(str(e)) + except: + member.pool.maybe_return_socket(sock_info) + raise + def _exhaust_next(self, sock_info): """Used with exhaust cursors to get the next batch off the socket. + + Can raise AutoReconnect. """ - return self.__receive_message_on_socket(1, None, sock_info) + try: + return self.__receive_message_on_socket(1, None, sock_info) + except socket.error, e: + raise AutoReconnect(str(e)) def start_request(self): """Ensure the current thread or greenlet always uses the same socket diff --git a/pymongo/mongo_replica_set_client.py b/pymongo/mongo_replica_set_client.py index bd7deaa12..40ec23e0a 100644 --- a/pymongo/mongo_replica_set_client.py +++ b/pymongo/mongo_replica_set_client.py @@ -1711,8 +1711,13 @@ class MongoReplicaSetClient(common.BaseObject): def _exhaust_next(self, sock_info): """Used with exhaust cursors to get the next batch off the socket. + + Can raise AutoReconnect. """ - return self.__recv_msg(1, None, sock_info) + try: + return self.__recv_msg(1, None, sock_info) + except socket.error, e: + raise AutoReconnect(str(e)) def start_request(self): """Ensure the current thread or greenlet always uses the same socket diff --git a/test/test_client.py b/test/test_client.py index 6563ed700..01cb00d1c 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -52,6 +52,7 @@ from test.utils import (assertRaisesExactly, server_started_with_auth, TestRequestMixin, _TestLazyConnectMixin, + _TestExhaustCursorMixin, lazy_client_trial, NTHREADS, get_pool, @@ -1147,5 +1148,10 @@ class TestMongoClientFailover(unittest.TestCase): c.db.collection.find_one() +class TestExhaustCursor(_TestExhaustCursorMixin, unittest.TestCase): + def _get_client(self, **kwargs): + return get_client(**kwargs) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_replica_set_client.py b/test/test_replica_set_client.py index 712f432a6..7f58a3465 100644 --- a/test/test_replica_set_client.py +++ b/test/test_replica_set_client.py @@ -52,7 +52,7 @@ from test.utils import ( delay, assertReadFrom, assertReadFromAll, read_from_which_host, remove_all_users, assertRaisesExactly, TestRequestMixin, one, server_started_with_auth, pools_from_rs_client, get_pool, - _TestLazyConnectMixin) + _TestLazyConnectMixin, _TestExhaustCursorMixin) class TestReplicaSetClientAgainstStandalone(unittest.TestCase): @@ -1274,5 +1274,12 @@ class TestReplicaSetClientMaxWriteBatchSize(unittest.TestCase): self.assertEqual(c.max_write_batch_size, 2) +class TestReplicaSetClientExhaustCursor( + _TestExhaustCursorMixin, + TestReplicaSetClientBase): + + # Base class implements _get_client already. + pass + if __name__ == "__main__": unittest.main() diff --git a/test/utils.py b/test/utils.py index 2ccb8c6a8..c303e5b62 100644 --- a/test/utils.py +++ b/test/utils.py @@ -15,14 +15,16 @@ """Utilities for testing pymongo """ +import gc import os import struct import sys import threading +import time from nose.plugins.skip import SkipTest from pymongo import MongoClient, MongoReplicaSetClient -from pymongo.errors import AutoReconnect +from pymongo.errors import AutoReconnect, ConnectionFailure, OperationFailure from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo from test import host, port, version @@ -586,6 +588,130 @@ class _TestLazyConnectMixin(object): c.max_message_size) +def collect_until(fn): + start = time.time() + while not fn(): + if (time.time() - start) > 5: + raise AssertionError("timed out") + + gc.collect() + + +class _TestExhaustCursorMixin(object): + """Test that clients properly handle errors from exhaust cursors. + + Inherit from this class and from unittest.TestCase, and override + _get_client(self, **kwargs). + """ + def test_exhaust_query_server_error(self): + # When doing an exhaust query, the socket stays checked out on success + # but must be checked in on error to avoid semaphore leaks. + client = self._get_client(max_pool_size=1) + if is_mongos(client): + raise SkipTest("Can't use exhaust cursors with mongos") + + collection = client.pymongo_test.test + pool = get_pool(client) + + sock_info = one(pool.sockets) + cursor = collection.find({'$bad_query_operator': 1}, exhaust=True) + self.assertRaises(OperationFailure, cursor.next) + del cursor + collect_until(lambda: sock_info in pool.sockets) + self.assertFalse(sock_info.closed) + + # The semaphore was decremented despite the error. + self.assertTrue(pool._socket_semaphore.acquire(blocking=False)) + + def test_exhaust_getmore_server_error(self): + # When doing a getmore on an exhaust cursor, the socket stays checked + # out on success but must be checked in on error to avoid semaphore + # leaks. + client = self._get_client(max_pool_size=1) + if is_mongos(client): + raise SkipTest("Can't use exhaust cursors with mongos") + + # A separate client that doesn't affect the test client's pool. + client2 = self._get_client() + + collection = client.pymongo_test.test + collection.remove() + + # Enough data to ensure it streams down for a few milliseconds. + long_str = 'a' * (256 * 1024) + collection.insert([{'a': long_str} for _ in range(1000)]) + + pool = get_pool(client) + pool._check_interval_seconds = None # Never check. + sock_info = one(pool.sockets) + + cursor = collection.find(exhaust=True) + + # Initial query succeeds. + cursor.next() + + # Cause a server error on getmore. + client2.pymongo_test.test.drop() + self.assertRaises(OperationFailure, list, cursor) + del cursor + collect_until(lambda: sock_info.closed) + self.assertFalse(sock_info in pool.sockets) + + # The semaphore was decremented despite the error. + self.assertTrue(pool._socket_semaphore.acquire(blocking=False)) + + def test_exhaust_query_network_error(self): + # When doing an exhaust query, the socket stays checked out on success + # but must be checked in on error to avoid semaphore leaks. + client = self._get_client(max_pool_size=1) + if is_mongos(client): + raise SkipTest("Can't use exhaust cursors with mongos") + + collection = client.pymongo_test.test + pool = get_pool(client) + pool._check_interval_seconds = None # Never check. + + # Cause a network error. + sock_info = one(pool.sockets) + sock_info.sock.close() + cursor = collection.find(exhaust=True) + self.assertRaises(ConnectionFailure, cursor.next) + self.assertTrue(sock_info.closed) + + # The semaphore was decremented despite the error. + self.assertTrue(pool._socket_semaphore.acquire(blocking=False)) + + def test_exhaust_getmore_network_error(self): + # When doing a getmore on an exhaust cursor, the socket stays checked + # out on success but must be checked in on error to avoid semaphore + # leaks. + client = self._get_client(max_pool_size=1) + if is_mongos(client): + raise SkipTest("Can't use exhaust cursors with mongos") + + collection = client.pymongo_test.test + collection.remove() + collection.insert([{} for _ in range(200)]) # More than one batch. + pool = get_pool(client) + pool._check_interval_seconds = None # Never check. + + cursor = collection.find(exhaust=True) + + # Initial query succeeds. + cursor.next() + + # Cause a network error. + sock_info = cursor._Cursor__exhaust_mgr.sock + sock_info.sock.close() + + # A getmore fails. + self.assertRaises(ConnectionFailure, list, cursor) + self.assertTrue(sock_info.closed) + + # The semaphore was decremented despite the error. + self.assertTrue(pool._socket_semaphore.acquire(blocking=False)) + + # Backport of WarningMessage from python 2.6, with fixed syntax for python 2.4. class WarningMessage(object):