From 4c77d7c8552a4ae21edb6503f879340691ffed8c Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 21 Jun 2021 18:29:36 -0700 Subject: [PATCH] PYTHON-2677 Better wait queue timeout errors for load balanced clusters (#639) Remove checkout argument in favor of SocketInfo.pin_txn/pin_cursor() --- pymongo/change_stream.py | 2 +- pymongo/client_session.py | 2 +- pymongo/collection.py | 6 +- pymongo/command_cursor.py | 1 + pymongo/database.py | 6 +- pymongo/mongo_client.py | 24 ++- pymongo/pool.py | 59 +++++-- pymongo/server.py | 10 +- .../unified/wait-queue-timeouts.json | 153 ++++++++++++++++++ test/pymongo_mocks.py | 5 +- test/test_cmap.py | 4 +- test/test_pooling.py | 16 +- test/test_read_preferences.py | 4 +- test/utils.py | 2 +- 14 files changed, 234 insertions(+), 60 deletions(-) create mode 100644 test/load_balancer/unified/wait-queue-timeouts.json diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index fcb9d9f52..f742e126c 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -181,7 +181,7 @@ class ChangeStream(object): return self._client._retryable_read( cmd.get_cursor, self._target._read_preference_for(session), - session, pin=self._client._should_pin_cursor(session)) + session) def _create_cursor(self): with self._client._tmp_session(self._session, close=False) as s: diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 9db818482..8e5b2a597 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -314,7 +314,7 @@ class _Transaction(object): self.sharded = True self.pinned_address = server.description.address if server.description.server_type == SERVER_TYPE.LoadBalancer: - sock_info.pinned = True + sock_info.pin_txn() self.sock_mgr = _SocketManager(sock_info, False) def unpin(self): diff --git a/pymongo/collection.py b/pymongo/collection.py index cf8f679be..7943e08aa 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -2093,8 +2093,7 @@ class Collection(common.BaseObject): return cmd_cursor return self.__database.client._retryable_read( - _cmd, read_pref, session, - pin=self.__database.client._should_pin_cursor(session)) + _cmd, read_pref, session) def index_information(self, session=None): """Get information on this collection's indexes. @@ -2175,8 +2174,7 @@ class Collection(common.BaseObject): user_fields={'cursor': {'firstBatch': 1}}) return self.__database.client._retryable_read( cmd.get_cursor, cmd.get_read_preference(session), session, - retryable=not cmd._performs_write, - pin=self.database.client._should_pin_cursor(session)) + retryable=not cmd._performs_write) def aggregate(self, pipeline, session=None, **kwargs): """Perform an aggregation using the aggregation framework on this diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index fa219d038..aabec3999 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -138,6 +138,7 @@ class CommandCursor(object): if not client._should_pin_cursor(self.__session): return if not self.__sock_mgr: + sock_info.pin_cursor() sock_mgr = _SocketManager(sock_info, False) # Ensure the connection gets returned when the entire result is # returned in the first batch. diff --git a/pymongo/database.py b/pymongo/database.py index 8928287ee..0a814fb74 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -381,8 +381,7 @@ class Database(common.BaseObject): user_fields={'cursor': {'firstBatch': 1}}) return self.client._retryable_read( cmd.get_cursor, cmd.get_read_preference(s), s, - retryable=not cmd._performs_write, - pin=self.client._should_pin_cursor(s)) + retryable=not cmd._performs_write) def watch(self, pipeline=None, full_document=None, resume_after=None, max_await_time_ms=None, batch_size=None, collation=None, @@ -695,8 +694,7 @@ class Database(common.BaseObject): **kwargs) return self.__client._retryable_read( - _cmd, read_pref, session, - pin=self.client._should_pin_cursor(session)) + _cmd, read_pref, session) def list_collection_names(self, session=None, filter=None, **kwargs): """Get a list of all the collection names in this database. diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 4eb54e1c5..ad82306b8 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1161,7 +1161,7 @@ class MongoClient(common.BaseObject): return self._topology @contextlib.contextmanager - def _get_socket(self, server, session, pin=False): + def _get_socket(self, server, session): in_txn = session and session.in_transaction with _MongoClientErrorHandler(self, server, session) as err_handler: # Reuse the pinned connection, if it exists. @@ -1169,8 +1169,7 @@ class MongoClient(common.BaseObject): yield session._pinned_connection return with server.get_socket( - self.__all_credentials, checkout=pin, - handler=err_handler) as sock_info: + self.__all_credentials, handler=err_handler) as sock_info: # Pin this session to the selected server or connection. if (in_txn and server.description.server_type in ( SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)): @@ -1221,7 +1220,7 @@ class MongoClient(common.BaseObject): return self._get_socket(server, session) @contextlib.contextmanager - def _slaveok_for_server(self, read_preference, server, session, pin=False): + def _slaveok_for_server(self, read_preference, server, session): assert read_preference is not None, "read_preference must not be None" # Get a socket for a server matching the read preference, and yield # sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to @@ -1232,7 +1231,7 @@ class MongoClient(common.BaseObject): topology = self._get_topology() single = topology.description.topology_type == TOPOLOGY_TYPE.Single - with self._get_socket(server, session, pin=pin) as sock_info: + with self._get_socket(server, session) as sock_info: slave_ok = (single and not sock_info.is_mongos) or ( read_preference != ReadPreference.PRIMARY) yield sock_info, slave_ok @@ -1259,7 +1258,7 @@ class MongoClient(common.BaseObject): return (self.__options.load_balanced and not (session and session.in_transaction)) - def _run_operation(self, operation, unpack_res, pin=False, address=None): + def _run_operation(self, operation, unpack_res, address=None): """Run a _Query/_GetMore operation and return a Response. :Parameters: @@ -1268,7 +1267,6 @@ class MongoClient(common.BaseObject): - `address` (optional): Optional address when sending a message to a specific server, used for getMore. """ - pin = self._should_pin_cursor(operation.session) or operation.exhaust if operation.sock_mgr: server = self._select_server( operation.read_preference, operation.session, address=address) @@ -1279,17 +1277,16 @@ class MongoClient(common.BaseObject): err_handler.contribute_socket(operation.sock_mgr.sock) return server.run_operation( operation.sock_mgr.sock, operation, True, - self._event_listeners, pin, unpack_res) + self._event_listeners, unpack_res) def _cmd(session, server, sock_info, slave_ok): return server.run_operation( - sock_info, operation, slave_ok, self._event_listeners, pin, + sock_info, operation, slave_ok, self._event_listeners, unpack_res) return self._retryable_read( _cmd, operation.read_preference, operation.session, - address=address, retryable=isinstance(operation, message._Query), - pin=pin) + address=address, retryable=isinstance(operation, message._Query)) def _retry_with_session(self, retryable, func, session, bulk): """Execute an operation with at most one consecutive retries @@ -1361,7 +1358,7 @@ class MongoClient(common.BaseObject): last_error = exc def _retryable_read(self, func, read_pref, session, address=None, - retryable=True, pin=False): + retryable=True): """Execute an operation with at most one consecutive retries Returns func()'s return value on success. On error retries the same @@ -1381,8 +1378,7 @@ class MongoClient(common.BaseObject): read_pref, session, address=address) if not server.description.retryable_reads_supported: retryable = False - with self._slaveok_for_server( - read_pref, server, session, pin=pin) as ( + with self._slaveok_for_server(read_pref, server, session) as ( sock_info, slave_ok): if retrying and not retryable: # A retry is not possible because this server does diff --git a/pymongo/pool.py b/pymongo/pool.py index 15e5b4873..335f8d07f 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -548,10 +548,18 @@ class SocketInfo(object): self.service_id = None # When executing a transaction in load balancing mode, this flag is # set to true to indicate that the session now owns the connection. - self.pinned = False + self.pinned_txn = False + self.pinned_cursor = False + + def pin_txn(self): + self.pinned_txn = True + assert not self.pinned_cursor + + def pin_cursor(self): + self.pinned_cursor = True + assert not self.pinned_txn def unpin(self): - self.pinned = False pool = self.pool_ref() if pool: pool.return_socket(self) @@ -1178,6 +1186,8 @@ class Pool: # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). self.__pinned_sockets = set() + self.ncursors = 0 + self.ntxns = 0 def ready(self): old_state, self.state = self.state, PoolState.READY @@ -1354,7 +1364,7 @@ class Pool: return sock_info @contextlib.contextmanager - def get_socket(self, all_credentials, checkout=False, handler=None): + def get_socket(self, all_credentials, handler=None): """Get a socket from the pool. Use with a "with" statement. Returns a :class:`SocketInfo` object wrapping a connected @@ -1362,7 +1372,7 @@ class Pool: This method should always be used in a with-statement:: - with pool.get_socket(credentials, checkout) as socket_info: + with pool.get_socket(credentials) as socket_info: socket_info.send_message(msg) data = socket_info.receive_message(op_code, request_id) @@ -1374,7 +1384,6 @@ class Pool: :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. - - `checkout` (optional): keep socket checked out. - `handler` (optional): A _MongoClientErrorHandler. """ listeners = self.opts.event_listeners @@ -1382,9 +1391,6 @@ class Pool: listeners.publish_connection_check_out_started(self.address) sock_info = self._get_socket(all_credentials) - if checkout: - self.__pinned_sockets.add(sock_info) - if self.enabled_for_cmap: listeners.publish_connection_checked_out( self.address, sock_info.id) @@ -1395,7 +1401,7 @@ class Pool: # Note that when pinned is True, the session owns the # connection and it is responsible for checking the connection # back into the pool. - pinned = sock_info.pinned + pinned = sock_info.pinned_txn or sock_info.pinned_cursor if handler: # Perform SDAM error handling rules while the connection is # still checked out. @@ -1404,11 +1410,16 @@ class Pool: if not pinned: self.return_socket(sock_info) raise - else: - if sock_info.pinned: + if sock_info.pinned_txn: + with self.lock: self.__pinned_sockets.add(sock_info) - elif not checkout: - self.return_socket(sock_info) + self.ntxns += 1 + elif sock_info.pinned_cursor: + with self.lock: + self.__pinned_sockets.add(sock_info) + self.ncursors += 1 + else: + self.return_socket(sock_info) def _raise_if_not_ready(self, emit_event): if self.state != PoolState.READY: @@ -1528,8 +1539,11 @@ class Pool: :Parameters: - `sock_info`: The socket to check into the pool. """ + txn = sock_info.pinned_txn + cursor = sock_info.pinned_cursor + sock_info.pinned_txn = False + sock_info.pinned_cursor = False self.__pinned_sockets.discard(sock_info) - sock_info.pinned = False listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_checked_in(self.address, sock_info.id) @@ -1559,6 +1573,10 @@ class Pool: self._max_connecting_cond.notify() with self.size_cond: + if txn: + self.ntxns -= 1 + elif cursor: + self.ncursors -= 1 self.requests -= 1 self.active_sockets -= 1 self.operation_count -= 1 @@ -1603,9 +1621,18 @@ class Pool: if self.enabled_for_cmap: listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.TIMEOUT) + if self.opts.load_balanced: + other_ops = self.active_sockets - self.ncursors - self.ntxns + raise ConnectionFailure( + 'Timeout waiting for connection from the connection pool. ' + 'maxPoolSize: %s, connections in use by cursors: %s, ' + 'connections in use by transactions: %s, connections in use ' + 'by other operations: %s, wait_queue_timeout: %s' % ( + self.opts.max_pool_size, self.ncursors, self.ntxns, + other_ops, self.opts.wait_queue_timeout)) raise ConnectionFailure( - 'Timed out while checking out a connection from connection pool ' - 'with max_size %r and wait_queue_timeout %r' % ( + 'Timed out while checking out a connection from connection pool. ' + 'maxPoolSize: %s, wait_queue_timeout: %s' % ( self.opts.max_pool_size, self.opts.wait_queue_timeout)) def __del__(self): diff --git a/pymongo/server.py b/pymongo/server.py index 672a3b1c1..19529237d 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -69,7 +69,7 @@ class Server(object): self._monitor.request_check() def run_operation(self, sock_info, operation, set_slave_okay, listeners, - pin, unpack_res): + unpack_res): """Run a _Query or _GetMore operation and return a Response object. This method is used only to run _Query/_GetMore operations from @@ -81,7 +81,6 @@ class Server(object): - `set_slave_okay`: Pass to operation.get_message. - `all_credentials`: dict, maps auth source to MongoCredential. - `listeners`: Instance of _EventListeners or None. - - `pin`: If True, then this is a pinned cursor operation. - `unpack_res`: A callable that decodes the wire protocol response. """ duration = None @@ -170,7 +169,8 @@ class Server(object): docs = _decode_all_selective( decrypted, operation.codec_options, user_fields) - if pin: + if client._should_pin_cursor(operation.session) or operation.exhaust: + sock_info.pin_cursor() if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the # more_to_come flag is set. @@ -200,8 +200,8 @@ class Server(object): return response - def get_socket(self, all_credentials, checkout=False, handler=None): - return self.pool.get_socket(all_credentials, checkout, handler) + def get_socket(self, all_credentials, handler=None): + return self.pool.get_socket(all_credentials, handler) @property def description(self): diff --git a/test/load_balancer/unified/wait-queue-timeouts.json b/test/load_balancer/unified/wait-queue-timeouts.json new file mode 100644 index 000000000..61575d670 --- /dev/null +++ b/test/load_balancer/unified/wait-queue-timeouts.json @@ -0,0 +1,153 @@ +{ + "description": "wait queue timeout errors include details about checked out connections", + "schemaVersion": "1.3", + "runOnRequirements": [ + { + "topologies": [ + "load-balanced" + ] + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "useMultipleMongoses": true, + "uriOptions": { + "maxPoolSize": 1, + "waitQueueTimeoutMS": 5 + }, + "observeEvents": [ + "connectionCheckedOutEvent", + "connectionCheckOutFailedEvent" + ] + } + }, + { + "session": { + "id": "session0", + "client": "client0" + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "database0Name" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "coll0" + } + } + ], + "initialData": [ + { + "collectionName": "coll0", + "databaseName": "database0Name", + "documents": [ + { + "_id": 1 + }, + { + "_id": 2 + }, + { + "_id": 3 + } + ] + } + ], + "tests": [ + { + "description": "wait queue timeout errors include cursor statistics", + "operations": [ + { + "name": "createFindCursor", + "object": "collection0", + "arguments": { + "filter": {}, + "batchSize": 2 + }, + "saveResultAsEntity": "cursor0" + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "x": 1 + } + }, + "expectError": { + "isClientError": true, + "errorContains": "maxPoolSize: 1, connections in use by cursors: 1, connections in use by transactions: 0, connections in use by other operations: 0" + } + } + ], + "expectEvents": [ + { + "client": "client0", + "eventType": "cmap", + "events": [ + { + "connectionCheckedOutEvent": {} + }, + { + "connectionCheckOutFailedEvent": {} + } + ] + } + ] + }, + { + "description": "wait queue timeout errors include transaction statistics", + "operations": [ + { + "name": "startTransaction", + "object": "session0" + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "x": 1 + }, + "session": "session0" + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "x": 1 + } + }, + "expectError": { + "isClientError": true, + "errorContains": "maxPoolSize: 1, connections in use by cursors: 0, connections in use by transactions: 1, connections in use by other operations: 0" + } + } + ], + "expectEvents": [ + { + "client": "client0", + "eventType": "cmap", + "events": [ + { + "connectionCheckedOutEvent": {} + }, + { + "connectionCheckOutFailedEvent": {} + } + ] + } + ] + } + ] +} diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 540dd68e3..af59e0cbe 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -40,7 +40,7 @@ class MockPool(Pool): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager - def get_socket(self, all_credentials, checkout=False, handler=None): + def get_socket(self, all_credentials, handler=None): client = self.client host_and_port = '%s:%s' % (self.mock_host, self.mock_port) if host_and_port in client.mock_down_hosts: @@ -51,8 +51,7 @@ class MockPool(Pool): + client.mock_members + client.mock_mongoses), "bad host: %s" % host_and_port - with Pool.get_socket( - self, all_credentials, checkout, handler) as sock_info: + with Pool.get_socket(self, all_credentials, handler) as sock_info: sock_info.mock_host = self.mock_host sock_info.mock_port = self.mock_port yield sock_info diff --git a/test/test_cmap.py b/test/test_cmap.py index 053f27ba7..cfa11e9bc 100644 --- a/test/test_cmap.py +++ b/test/test_cmap.py @@ -119,7 +119,9 @@ class TestCMAP(IntegrationTest): def check_out(self, op): """Run the 'checkOut' operation.""" label = op['label'] - with self.pool.get_socket({}, checkout=True) as sock_info: + with self.pool.get_socket({}) as sock_info: + # Call 'pin_cursor' so we can hold the socket. + sock_info.pin_cursor() if label: self.labels[label] = sock_info else: diff --git a/test/test_pooling.py b/test/test_pooling.py index 57338a965..11fcb1ce7 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -125,8 +125,9 @@ class SocketGetter(MongoThread): def run_mongo_thread(self): self.state = 'get_socket' - # Pass 'checkout' so we can hold the socket. - with self.pool.get_socket({}, checkout=True) as sock: + # Call 'pin_cursor' so we can hold the socket. + with self.pool.get_socket({}) as sock: + sock.pin_cursor() self.sock = sock self.state = 'sock' @@ -326,11 +327,9 @@ class TestPooling(_TestPoolingBase): # Back to normal, semaphore was correctly released. cx_pool.address = address - with cx_pool.get_socket({}, checkout=True) as sock_info: + with cx_pool.get_socket({}): pass - sock_info.close_socket(None) - def test_wait_queue_timeout(self): wait_queue_timeout = 2 # Seconds pool = self.create_pool( @@ -400,8 +399,9 @@ class TestPooling(_TestPoolingBase): socks = [] for _ in range(2): - # Pass 'checkout' so we can hold the socket. - with pool.get_socket({}, checkout=True) as sock: + # Call 'pin_cursor' so we can hold the socket. + with pool.get_socket({}) as sock: + sock.pin_cursor() socks.append(sock) threads = [] @@ -539,7 +539,7 @@ class TestPoolMaxSize(_TestPoolingBase): # socket from pool" instead of AutoReconnect. for i in range(2): with self.assertRaises(AutoReconnect) as context: - with test_pool.get_socket({}, checkout=True): + with test_pool.get_socket({}): pass # Testing for AutoReconnect instead of ConnectionFailure, above, diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 773aed3f5..d02c1cacc 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -322,9 +322,9 @@ class ReadPrefTester(MongoClient): yield sock_info, slave_ok @contextlib.contextmanager - def _slaveok_for_server(self, read_preference, server, session, pin=False): + def _slaveok_for_server(self, read_preference, server, session): context = super(ReadPrefTester, self)._slaveok_for_server( - read_preference, server, session, pin=pin) + read_preference, server, session) with context as (sock_info, slave_ok): self.record_a_read(sock_info.address) yield sock_info, slave_ok diff --git a/test/utils.py b/test/utils.py index fb29bb37e..02a0a5883 100644 --- a/test/utils.py +++ b/test/utils.py @@ -267,7 +267,7 @@ class MockPool(object): def stale_generation(self, gen, service_id): return self.gen.stale(gen, service_id) - def get_socket(self, all_credentials, checkout=False, handler=None): + def get_socket(self, all_credentials, handler=None): return MockSocketInfo() def return_socket(self, *args, **kwargs):