From c8f32a7a376e027cbe30dcb8df9bf05778e61248 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 24 May 2021 17:49:44 -0700 Subject: [PATCH] PYTHON-2673 Connection pinning behavior for load balanced clusters (#630) Tweak spec test because pymongo unpins cursors eagerly after errors. Tweak spec test for PoolClearedEvent ordering when MongoDB handshake fails (see DRIVERS-1785). Only skip killCursors for some error codes. Rely on SDAM error handling to close the connection after a state change error. Add service_id to various events. Retain reference to pinned sockets to prevent preemptive closure by CPython's cyclic GC. --- pymongo/aggregation.py | 4 +- pymongo/change_stream.py | 2 +- pymongo/client_session.py | 44 ++++- pymongo/collection.py | 28 ++- pymongo/command_cursor.py | 77 ++++---- pymongo/cursor.py | 113 ++++++----- pymongo/database.py | 12 +- pymongo/message.py | 59 +++--- pymongo/mongo_client.py | 186 ++++++++++-------- pymongo/monitoring.py | 21 +- pymongo/network.py | 9 +- pymongo/pool.py | 60 +++++- pymongo/response.py | 22 +-- pymongo/server.py | 41 ++-- test/load_balancer/test_load_balancer.py | 22 ++- test/load_balancer/unified/cursors.json | 8 +- .../unified/sdam-error-handling.json | 6 +- test/pymongo_mocks.py | 5 +- test/test_client.py | 6 +- test/test_monitoring.py | 6 +- test/test_read_preferences.py | 5 +- test/unified_format.py | 23 ++- test/utils.py | 2 +- 23 files changed, 465 insertions(+), 296 deletions(-) diff --git a/pymongo/aggregation.py b/pymongo/aggregation.py index 438a3421b..ae1b9d9eb 100644 --- a/pymongo/aggregation.py +++ b/pymongo/aggregation.py @@ -161,11 +161,13 @@ class _AggregationCommand(object): } # Create and return cursor instance. 
- return self._cursor_class( + cmd_cursor = self._cursor_class( self._cursor_collection(cursor), cursor, sock_info.address, batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, session=session, explicit_session=self._explicit_session) + cmd_cursor._maybe_pin_connection(sock_info) + return cmd_cursor class _CollectionAggregationCommand(_AggregationCommand): diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index f742e126c..fcb9d9f52 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -181,7 +181,7 @@ class ChangeStream(object): return self._client._retryable_read( cmd.get_cursor, self._target._read_preference_for(session), - session) + session, pin=self._client._should_pin_cursor(session)) def _create_cursor(self): with self._client._tmp_session(self._session, close=False) as s: diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 5b6ff7524..9db818482 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -108,6 +108,7 @@ from bson.int64 import Int64 from bson.son import SON from bson.timestamp import Timestamp +from pymongo.cursor import _SocketManager from pymongo.errors import (ConfigurationError, ConnectionFailure, InvalidOperation, @@ -117,6 +118,7 @@ from pymongo.errors import (ConfigurationError, from pymongo.helpers import _RETRYABLE_ERROR_CODES from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.server_type import SERVER_TYPE from pymongo.write_concern import WriteConcern @@ -292,6 +294,7 @@ class _Transaction(object): self.state = _TxnState.NONE self.sharded = False self.pinned_address = None + self.sock_mgr = None self.recovery_token = None self.attempt = 0 @@ -301,10 +304,29 @@ class _Transaction(object): def starting(self): return self.state == _TxnState.STARTING + @property + def pinned_conn(self): + if self.active() and self.sock_mgr: + return self.sock_mgr.sock + return None + + def 
pin(self, server, sock_info): + self.sharded = True + self.pinned_address = server.description.address + if server.description.server_type == SERVER_TYPE.LoadBalancer: + sock_info.pinned = True + self.sock_mgr = _SocketManager(sock_info, False) + + def unpin(self): + self.pinned_address = None + if self.sock_mgr: + self.sock_mgr.close() + self.sock_mgr = None + def reset(self): + self.unpin() self.state = _TxnState.NONE self.sharded = False - self.pinned_address = None self.recovery_token = None self.attempt = 0 @@ -374,6 +396,9 @@ class ClientSession(object): try: if self.in_transaction: self.abort_transaction() + # It's possible we're still pinned here when the transaction + # is in the committed state when the session is discarded. + self._unpin() finally: self._client._return_server_session(self._server_session, lock) self._server_session = None @@ -779,14 +804,18 @@ class ClientSession(object): return self._transaction.pinned_address return None - def _pin(self, server): - """Pin this session to the given Server.""" - self._transaction.sharded = True - self._transaction.pinned_address = server.description.address + @property + def _pinned_connection(self): + """The connection this transaction was started on.""" + return self._transaction.pinned_conn + + def _pin(self, server, sock_info): + """Pin this session to the given Server or to the given connection.""" + self._transaction.pin(server, sock_info) def _unpin(self): """Unpin this session from any pinned Server.""" - self._transaction.pinned_address = None + self._transaction.unpin() def _txn_read_preference(self): """Return read preference of this transaction or None.""" @@ -800,9 +829,6 @@ class ClientSession(object): self._server_session.last_use = time.monotonic() command['lsid'] = self._server_session.session_id - if not self.in_transaction: - self._transaction.reset() - if is_retryable: command['txnNumber'] = self._server_session.transaction_id return diff --git a/pymongo/collection.py 
b/pymongo/collection.py index 203ea32ae..cf8f679be 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -520,7 +520,8 @@ class Collection(common.BaseObject): if publish: duration = datetime.datetime.now() - start listeners.publish_command_start( - cmd, self.__database.name, rqst_id, sock_info.address, op_id) + cmd, self.__database.name, rqst_id, sock_info.address, op_id, + sock_info.service_id) start = datetime.datetime.now() try: result = sock_info.legacy_write(rqst_id, msg, max_size, False) @@ -534,12 +535,14 @@ class Collection(common.BaseObject): reply = message._convert_write_result( name, cmd, details) listeners.publish_command_success( - dur, reply, name, rqst_id, sock_info.address, op_id) + dur, reply, name, rqst_id, sock_info.address, + op_id, sock_info.service_id) raise else: details = message._convert_exception(exc) listeners.publish_command_failure( - dur, details, name, rqst_id, sock_info.address, op_id) + dur, details, name, rqst_id, sock_info.address, op_id, + sock_info.service_id) raise if publish: if result is not None: @@ -549,7 +552,8 @@ class Collection(common.BaseObject): reply = {'ok': 1} duration = (datetime.datetime.now() - start) + duration listeners.publish_command_success( - duration, reply, name, rqst_id, sock_info.address, op_id) + duration, reply, name, rqst_id, sock_info.address, op_id, + sock_info.service_id) return result def _insert_one( @@ -2072,9 +2076,9 @@ class Collection(common.BaseObject): if exc.code != 26: raise cursor = {'id': 0, 'firstBatch': []} - return CommandCursor(coll, cursor, sock_info.address, - session=s, - explicit_session=session is not None) + cmd_cursor = CommandCursor( + coll, cursor, sock_info.address, session=s, + explicit_session=session is not None) else: res = message._first_batch( sock_info, self.__database.name, "system.indexes", @@ -2084,10 +2088,13 @@ class Collection(common.BaseObject): cursor = res["cursor"] # Note that a collection can only have 64 indexes, so there # will never be a 
getMore call. - return CommandCursor(coll, cursor, sock_info.address) + cmd_cursor = CommandCursor(coll, cursor, sock_info.address) + cmd_cursor._maybe_pin_connection(sock_info) + return cmd_cursor return self.__database.client._retryable_read( - _cmd, read_pref, session) + _cmd, read_pref, session, + pin=self.__database.client._should_pin_cursor(session)) def index_information(self, session=None): """Get information on this collection's indexes. @@ -2168,7 +2175,8 @@ class Collection(common.BaseObject): user_fields={'cursor': {'firstBatch': 1}}) return self.__database.client._retryable_read( cmd.get_cursor, cmd.get_read_preference(session), session, - retryable=not cmd._performs_write) + retryable=not cmd._performs_write, + pin=self.database.client._should_pin_cursor(session)) def aggregate(self, pipeline, session=None, **kwargs): """Perform an aggregation using the aggregation framework on this diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index f90a9267a..fa219d038 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -17,13 +17,14 @@ from collections import deque from bson import _convert_raw_document_lists_to_streams +from pymongo.cursor import _SocketManager, _CURSOR_CLOSED_ERRORS from pymongo.errors import (ConnectionFailure, InvalidOperation, - NotMasterError, OperationFailure) from pymongo.message import (_CursorAddress, _GetMore, _RawBatchGetMore) +from pymongo.response import PinnedResponse class CommandCursor(object): @@ -37,6 +38,7 @@ class CommandCursor(object): The parameter 'retrieved' is unused. 
""" + self.__sock_mgr = None self.__collection = collection self.__id = cursor_info['id'] self.__data = deque(cursor_info['firstBatch']) @@ -75,11 +77,15 @@ class CommandCursor(object): self.__address, self.__collection.full_name) if synchronous: self.__collection.database.client._close_cursor_now( - self.__id, address, session=self.__session) + self.__id, address, session=self.__session, + sock_mgr=self.__sock_mgr) else: # The cursor will be closed later in a different session. self.__collection.database.client._close_cursor( self.__id, address) + if self.__sock_mgr: + self.__sock_mgr.close() + self.__sock_mgr = None self.__end_session(synchronous) def __end_session(self, synchronous): @@ -127,52 +133,58 @@ class CommandCursor(object): changeStream aggregate or getMore.""" return self.__postbatchresumetoken + def _maybe_pin_connection(self, sock_info): + client = self.__collection.database.client + if not client._should_pin_cursor(self.__session): + return + if not self.__sock_mgr: + sock_mgr = _SocketManager(sock_info, False) + # Ensure the connection gets returned when the entire result is + # returned in the first batch. + if self.__id == 0: + sock_mgr.close() + else: + self.__sock_mgr = sock_mgr + def __send_message(self, operation): """Send a getmore message and handle the response. """ - def kill(): - self.__killed = True - self.__end_session(True) - client = self.__collection.database.client try: - response = client._run_operation_with_response( + response = client._run_operation( operation, self._unpack_response, address=self.__address) - except OperationFailure: - kill() - raise - except NotMasterError: - # Don't send kill cursors to another server after a "not master" - # error. It's completely pointless. - kill() + except OperationFailure as exc: + if exc.code in _CURSOR_CLOSED_ERRORS: + # Don't send killCursors because the cursor is already closed. + self.__killed = True + # Return the session and pinned connection, if necessary. 
+ self.close() raise except ConnectionFailure: - # Don't try to send kill cursors on another socket - # or to another server. It can cause a _pinValue - # assertion on some server releases if we get here - # due to a socket timeout. - kill() + # Don't send killCursors because the cursor is already closed. + self.__killed = True + # Return the session and pinned connection, if necessary. + self.close() raise except Exception: - # Close the cursor - self.__die() + self.close() raise - from_command = response.from_command - reply = response.data - docs = response.docs - - if from_command: - cursor = docs[0]['cursor'] + if isinstance(response, PinnedResponse): + if not self.__sock_mgr: + self.__sock_mgr = _SocketManager(response.socket_info, + response.more_to_come) + if response.from_command: + cursor = response.docs[0]['cursor'] documents = cursor['nextBatch'] self.__postbatchresumetoken = cursor.get('postBatchResumeToken') self.__id = cursor['id'] else: - documents = docs - self.__id = reply.cursor_id + documents = response.docs + self.__id = response.data.cursor_id if self.__id == 0: - kill() + self.__die(True) self.__data = deque(documents) def _unpack_response(self, response, cursor_id, codec_options, @@ -203,10 +215,9 @@ class CommandCursor(object): self.__session, self.__collection.database.client, self.__max_await_time_ms, - False)) + self.__sock_mgr, False)) else: # Cursor id is zero nothing else to return - self.__killed = True - self.__end_session(True) + self.__die(True) return len(self.__data) diff --git a/pymongo/cursor.py b/pymongo/cursor.py index c6a0246bf..22b0de20d 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -15,6 +15,7 @@ """Cursor class to iterate over Mongo query results.""" import copy +import threading import warnings from collections import deque @@ -27,7 +28,6 @@ from pymongo.common import validate_boolean, validate_is_mapping from pymongo.collation import validate_collation_or_none from pymongo.errors import (ConnectionFailure, 
InvalidOperation, - NotMasterError, OperationFailure) from pymongo.message import (_CursorAddress, _GetMore, @@ -35,7 +35,37 @@ from pymongo.message import (_CursorAddress, _Query, _RawBatchQuery) from pymongo.monitoring import ConnectionClosedReason +from pymongo.response import PinnedResponse +# These errors mean that the server has already killed the cursor so there is +# no need to send killCursors. +_CURSOR_CLOSED_ERRORS = frozenset([ + 43, # CursorNotFound + 50, # MaxTimeMSExpired + 175, # QueryPlanKilled + 237, # CursorKilled + + # On a tailable cursor, the following errors mean the capped collection + # rolled over. + # MongoDB 2.6: + # {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0} + 28617, + # MongoDB 3.0: + # {'$err': 'getMore executor error: UnknownError no details available', + # 'code': 17406, 'ok': 0} + 17406, + # MongoDB 3.2 + 3.4: + # {'ok': 0.0, 'errmsg': 'GetMore command executor error: + # CappedPositionLost: CollectionScan died due to failure to restore + # tailable cursor position. Last seen record id: RecordId(3)', + # 'code': 96} + 96, + # MongoDB 3.6+: + # {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to + # restore tailable cursor position. Last seen record id: RecordId(3)"', + # 'code': 136, 'codeName': 'CappedPositionLost'} + 136, +]) _QUERY_OPTIONS = { "tailable_cursor": 2, @@ -78,14 +108,14 @@ class CursorType(object): # This has to be an old style class due to # http://bugs.jython.org/issue1057 -class _ExhaustManager: +class _SocketManager: """Used with exhaust cursors to ensure the socket is returned. 
""" - def __init__(self, sock, pool, more_to_come): + def __init__(self, sock, more_to_come): self.sock = sock - self.pool = pool self.more_to_come = more_to_come self.__closed = False + self.lock = threading.Lock() def __del__(self): self.close() @@ -98,8 +128,8 @@ class _ExhaustManager: """ if not self.__closed: self.__closed = True - self.pool.return_socket(self.sock) - self.sock, self.pool = None, None + self.sock.unpin() + self.sock = None class Cursor(object): @@ -128,7 +158,7 @@ class Cursor(object): # an error to avoid attribute errors during garbage collection. self.__id = None self.__exhaust = False - self.__exhaust_mgr = None + self.__sock_mgr = None self.__killed = False if session: @@ -319,24 +349,26 @@ class Cursor(object): self.__killed = True if self.__id and not already_killed: - if self.__exhaust and self.__exhaust_mgr: + if self.__exhaust and self.__sock_mgr: # If this is an exhaust cursor and we haven't completely # exhausted the result set we *must* close the socket # to stop the server from sending more data. - self.__exhaust_mgr.sock.close_socket( + self.__sock_mgr.sock.close_socket( ConnectionClosedReason.ERROR) else: address = _CursorAddress( self.__address, self.__collection.full_name) if synchronous: self.__collection.database.client._close_cursor_now( - self.__id, address, session=self.__session) + self.__id, address, session=self.__session, + sock_mgr=self.__sock_mgr) else: # The cursor will be closed later in a different session. 
self.__collection.database.client._close_cursor( self.__id, address) - if self.__exhaust and self.__exhaust_mgr: - self.__exhaust_mgr.close() + if self.__sock_mgr: + self.__sock_mgr.close() + self.__sock_mgr = None if self.__session and not self.__explicit_session: self.__session._end_session(lock=synchronous) self.__session = None @@ -1004,53 +1036,35 @@ class Cursor(object): "exhaust cursors do not support auto encryption") try: - response = client._run_operation_with_response( - operation, self._unpack_response, exhaust=self.__exhaust, - address=self.__address) - except OperationFailure: - self.__killed = True - - # Make sure exhaust socket is returned immediately, if necessary. - self.__die() - + response = client._run_operation( + operation, self._unpack_response, address=self.__address) + except OperationFailure as exc: + if exc.code in _CURSOR_CLOSED_ERRORS or self.__exhaust: + # Don't send killCursors because the cursor is already closed. + self.__killed = True + self.close() # If this is a tailable cursor the error is likely # due to capped collection roll over. Setting # self.__killed to True ensures Cursor.alive will be # False. No need to re-raise. - if self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]: + if (exc.code in _CURSOR_CLOSED_ERRORS and + self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]): return raise - except NotMasterError: - # Don't send kill cursors to another server after a "not master" - # error. It's completely pointless. - self.__killed = True - - # Make sure exhaust socket is returned immediately, if necessary. - self.__die() - - raise except ConnectionFailure: - # Don't try to send kill cursors on another socket - # or to another server. It can cause a _pinValue - # assertion on some server releases if we get here - # due to a socket timeout. + # Don't send killCursors because the cursor is already closed. 
self.__killed = True - self.__die() + self.close() raise except Exception: - # Close the cursor - self.__die() + self.close() raise self.__address = response.address - if self.__exhaust: - # 'response' is an ExhaustResponse. - if not self.__exhaust_mgr: - self.__exhaust_mgr = _ExhaustManager(response.socket_info, - response.pool, - response.more_to_come) - else: - self.__exhaust_mgr.update_exhaust(response.more_to_come) + if isinstance(response, PinnedResponse): + if not self.__sock_mgr: + self.__sock_mgr = _SocketManager(response.socket_info, + response.more_to_come) cmd_name = operation.name docs = response.docs @@ -1078,7 +1092,6 @@ class Cursor(object): self.__retrieved += response.data.number_returned if self.__id == 0: - self.__killed = True # Don't wait for garbage collection to call __del__, return the # socket and the session to the pool now. self.__die() @@ -1132,7 +1145,8 @@ class Cursor(object): self.__collation, self.__session, self.__collection.database.client, - self.__allow_disk_use) + self.__allow_disk_use, + self.__exhaust) self.__send_message(q) elif self.__id: # Get More if self.__limit: @@ -1151,7 +1165,8 @@ class Cursor(object): self.__session, self.__collection.database.client, self.__max_await_time_ms, - self.__exhaust_mgr) + self.__sock_mgr, + self.__exhaust) self.__send_message(g) return len(self.__data) diff --git a/pymongo/database.py b/pymongo/database.py index 8de73161a..9e2b221fe 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -377,7 +377,8 @@ class Database(common.BaseObject): user_fields={'cursor': {'firstBatch': 1}}) return self.client._retryable_read( cmd.get_cursor, cmd.get_read_preference(s), s, - retryable=not cmd._performs_write) + retryable=not cmd._performs_write, + pin=self.client._should_pin_cursor(s)) def watch(self, pipeline=None, full_document=None, resume_after=None, max_await_time_ms=None, batch_size=None, collation=None, @@ -636,7 +637,7 @@ class Database(common.BaseObject): sock_info, cmd, slave_okay, 
read_preference=read_preference, session=tmp_session)["cursor"] - return CommandCursor( + cmd_cursor = CommandCursor( coll, cursor, sock_info.address, @@ -656,7 +657,9 @@ class Database(common.BaseObject): ("pipeline", pipeline), ("cursor", kwargs.get("cursor", {}))]) cursor = self._command(sock_info, cmd, slave_okay)["cursor"] - return CommandCursor(coll, cursor, sock_info.address) + cmd_cursor = CommandCursor(coll, cursor, sock_info.address) + cmd_cursor._maybe_pin_connection(sock_info) + return cmd_cursor def list_collections(self, session=None, filter=None, **kwargs): """Get a cursor over the collectons of this database. @@ -688,7 +691,8 @@ class Database(common.BaseObject): **kwargs) return self.__client._retryable_read( - _cmd, read_pref, session) + _cmd, read_pref, session, + pin=self.client._should_pin_cursor(session)) def list_collection_names(self, session=None, filter=None, **kwargs): """Get a list of all the collection names in this database. diff --git a/pymongo/message.py b/pymongo/message.py index 307260263..4edf8abce 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -234,16 +234,17 @@ class _Query(object): __slots__ = ('flags', 'db', 'coll', 'ntoskip', 'spec', 'fields', 'codec_options', 'read_preference', 'limit', 'batch_size', 'name', 'read_concern', 'collation', - 'session', 'client', 'allow_disk_use', '_as_command') + 'session', 'client', 'allow_disk_use', '_as_command', + 'exhaust') # For compatibility with the _GetMore class. 
- exhaust_mgr = None + sock_mgr = None cursor_id = None def __init__(self, flags, db, coll, ntoskip, spec, fields, codec_options, read_preference, limit, batch_size, read_concern, collation, session, client, - allow_disk_use): + allow_disk_use, exhaust): self.flags = flags self.db = db self.coll = coll @@ -261,13 +262,14 @@ class _Query(object): self.allow_disk_use = allow_disk_use self.name = 'find' self._as_command = None + self.exhaust = exhaust def namespace(self): return "%s.%s" % (self.db, self.coll) - def use_command(self, sock_info, exhaust): + def use_command(self, sock_info): use_find_cmd = False - if sock_info.max_wire_version >= 4 and not exhaust: + if sock_info.max_wire_version >= 4 and not self.exhaust: use_find_cmd = True elif sock_info.max_wire_version >= 8: # OP_MSG supports exhaust on MongoDB 4.2+ @@ -375,13 +377,13 @@ class _GetMore(object): __slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms', 'codec_options', 'read_preference', 'session', 'client', - 'exhaust_mgr', '_as_command') + 'sock_mgr', '_as_command', 'exhaust') name = 'getMore' def __init__(self, db, coll, ntoreturn, cursor_id, codec_options, read_preference, session, client, max_await_time_ms, - exhaust_mgr): + sock_mgr, exhaust): self.db = db self.coll = coll self.ntoreturn = ntoreturn @@ -391,15 +393,16 @@ class _GetMore(object): self.session = session self.client = client self.max_await_time_ms = max_await_time_ms - self.exhaust_mgr = exhaust_mgr + self.sock_mgr = sock_mgr self._as_command = None + self.exhaust = exhaust def namespace(self): return "%s.%s" % (self.db, self.coll) - def use_command(self, sock_info, exhaust): + def use_command(self, sock_info): use_cmd = False - if sock_info.max_wire_version >= 4 and not exhaust: + if sock_info.max_wire_version >= 4 and not self.exhaust: use_cmd = True elif sock_info.max_wire_version >= 8: # OP_MSG supports exhaust on MongoDB 4.2+ @@ -440,7 +443,7 @@ class _GetMore(object): if use_cmd: spec = 
self.as_command(sock_info)[0] if sock_info.op_msg_enabled: - if self.exhaust_mgr: + if self.sock_mgr: flags = _OpMsg.EXHAUST_ALLOWED else: flags = 0 @@ -456,23 +459,23 @@ class _GetMore(object): class _RawBatchQuery(_Query): - def use_command(self, socket_info, exhaust): + def use_command(self, sock_info): # Compatibility checks. - super(_RawBatchQuery, self).use_command(socket_info, exhaust) - if socket_info.max_wire_version >= 8: + super(_RawBatchQuery, self).use_command(sock_info) + if sock_info.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True - elif socket_info.op_msg_enabled and not exhaust: + elif sock_info.op_msg_enabled and not self.exhaust: return True return False class _RawBatchGetMore(_GetMore): - def use_command(self, socket_info, exhaust): - if socket_info.max_wire_version >= 8: + def use_command(self, sock_info): + if sock_info.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True - elif socket_info.op_msg_enabled and not exhaust: + elif sock_info.op_msg_enabled and not self.exhaust: return True return False @@ -1033,20 +1036,23 @@ class _BulkWriteContext(object): cmd[self.field] = docs self.listeners.publish_command_start( cmd, self.db_name, - request_id, self.sock_info.address, self.op_id) + request_id, self.sock_info.address, self.op_id, + self.sock_info.service_id) return cmd def _succeed(self, request_id, reply, duration): """Publish a CommandSucceededEvent.""" self.listeners.publish_command_success( duration, reply, self.name, - request_id, self.sock_info.address, self.op_id) + request_id, self.sock_info.address, self.op_id, + self.sock_info.service_id) def _fail(self, request_id, failure, duration): """Publish a CommandFailedEvent.""" self.listeners.publish_command_failure( duration, failure, self.name, - request_id, self.sock_info.address, self.op_id) + request_id, self.sock_info.address, self.op_id, + self.sock_info.service_id) # From the Client Side Encryption spec: @@ -1686,7 +1692,7 
@@ def _first_batch(sock_info, db, coll, query, ntoreturn, query = _Query( 0, db, coll, 0, query, None, codec_options, read_preference, ntoreturn, 0, DEFAULT_READ_CONCERN, None, None, - None, None) + None, None, False) name = next(iter(cmd)) publish = listeners.enabled_for_commands @@ -1698,7 +1704,8 @@ def _first_batch(sock_info, db, coll, query, ntoreturn, if publish: encoding_duration = datetime.datetime.now() - start listeners.publish_command_start( - cmd, db, request_id, sock_info.address) + cmd, db, request_id, sock_info.address, + service_id=sock_info.service_id) start = datetime.datetime.now() sock_info.send_message(msg, max_doc_size) @@ -1713,7 +1720,8 @@ def _first_batch(sock_info, db, coll, query, ntoreturn, else: failure = _convert_exception(exc) listeners.publish_command_failure( - duration, failure, name, request_id, sock_info.address) + duration, failure, name, request_id, sock_info.address, + service_id=sock_info.service_id) raise # listIndexes if 'cursor' in cmd: @@ -1732,6 +1740,7 @@ def _first_batch(sock_info, db, coll, query, ntoreturn, if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( - duration, result, name, request_id, sock_info.address) + duration, result, name, request_id, sock_info.address, + service_id=sock_info.service_id) return result diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index dde4ef242..4eb54e1c5 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -60,6 +60,7 @@ from pymongo.errors import (AutoReconnect, OperationFailure, PyMongoError, ServerSelectionTimeoutError) +from pymongo.pool import ConnectionClosedReason from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import (writable_preferred_server_selector, writable_server_selector) @@ -1160,10 +1161,20 @@ class MongoClient(common.BaseObject): return self._topology @contextlib.contextmanager - def _get_socket(self, server, session, exhaust=False): + 
def _get_socket(self, server, session, pin=False): + in_txn = session and session.in_transaction with _MongoClientErrorHandler(self, server, session) as err_handler: + # Reuse the pinned connection, if it exists. + if in_txn and session._pinned_connection: + yield session._pinned_connection + return with server.get_socket( - self.__all_credentials, checkout=exhaust) as sock_info: + self.__all_credentials, checkout=pin, + handler=err_handler) as sock_info: + # Pin this session to the selected server or connection. + if (in_txn and server.description.server_type in ( + SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)): + session._pin(server, sock_info) err_handler.contribute_socket(sock_info) if (self._encrypter and not self._encrypter._bypass_auto_encryption and @@ -1186,6 +1197,8 @@ class MongoClient(common.BaseObject): """ try: topology = self._get_topology() + if session and not session.in_transaction: + session._transaction.reset() address = address or (session and session._pinned_address) if address: # We're running a getMore or this session is pinned to a mongos. @@ -1195,12 +1208,6 @@ class MongoClient(common.BaseObject): % address) else: server = topology.select_server(server_selector) - # Pin this session to the selected server if it's performing a - # sharded transaction. - if (server.description.server_type in ( - SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer) - and session and session.in_transaction): - session._pin(server) return server except PyMongoError as exc: # Server selection errors in a transaction are transient. 
@@ -1214,8 +1221,7 @@ class MongoClient(common.BaseObject): return self._get_socket(server, session) @contextlib.contextmanager - def _slaveok_for_server(self, read_preference, server, session, - exhaust=False): + def _slaveok_for_server(self, read_preference, server, session, pin=False): assert read_preference is not None, "read_preference must not be None" # Get a socket for a server matching the read preference, and yield # sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to @@ -1226,7 +1232,7 @@ class MongoClient(common.BaseObject): topology = self._get_topology() single = topology.description.topology_type == TOPOLOGY_TYPE.Single - with self._get_socket(server, session, exhaust=exhaust) as sock_info: + with self._get_socket(server, session, pin=pin) as sock_info: slave_ok = (single and not sock_info.is_mongos) or ( read_preference != ReadPreference.PRIMARY) yield sock_info, slave_ok @@ -1249,47 +1255,41 @@ class MongoClient(common.BaseObject): read_preference != ReadPreference.PRIMARY) yield sock_info, slave_ok - def _run_operation_with_response(self, operation, unpack_res, - exhaust=False, address=None): + def _should_pin_cursor(self, session): + return (self.__options.load_balanced and + not (session and session.in_transaction)) + + def _run_operation(self, operation, unpack_res, pin=False, address=None): """Run a _Query/_GetMore operation and return a Response. :Parameters: - `operation`: a _Query or _GetMore object. - `unpack_res`: A callable that decodes the wire protocol response. - - `exhaust` (optional): If True, the socket used stays checked out. - It is returned along with its Pool in the Response. - `address` (optional): Optional address when sending a message to a specific server, used for getMore. 
""" - if operation.exhaust_mgr: + pin = self._should_pin_cursor(operation.session) or operation.exhaust + if operation.sock_mgr: server = self._select_server( operation.read_preference, operation.session, address=address) - with _MongoClientErrorHandler( - self, server, operation.session) as err_handler: - err_handler.contribute_socket(operation.exhaust_mgr.sock) - return server.run_operation_with_response( - operation.exhaust_mgr.sock, - operation, - True, - self._event_listeners, - exhaust, - unpack_res) + with operation.sock_mgr.lock: + with _MongoClientErrorHandler( + self, server, operation.session) as err_handler: + err_handler.contribute_socket(operation.sock_mgr.sock) + return server.run_operation( + operation.sock_mgr.sock, operation, True, + self._event_listeners, pin, unpack_res) def _cmd(session, server, sock_info, slave_ok): - return server.run_operation_with_response( - sock_info, - operation, - slave_ok, - self._event_listeners, - exhaust, + return server.run_operation( + sock_info, operation, slave_ok, self._event_listeners, pin, unpack_res) return self._retryable_read( _cmd, operation.read_preference, operation.session, - address=address, - retryable=isinstance(operation, message._Query), - exhaust=exhaust) + address=address, retryable=isinstance(operation, message._Query), + pin=pin) def _retry_with_session(self, retryable, func, session, bulk): """Execute an operation with at most one consecutive retries @@ -1361,7 +1361,7 @@ class MongoClient(common.BaseObject): last_error = exc def _retryable_read(self, func, read_pref, session, address=None, - retryable=True, exhaust=False): + retryable=True, pin=False): """Execute an operation with at most one consecutive retries Returns func()'s return value on success. 
On error retries the same @@ -1381,9 +1381,9 @@ class MongoClient(common.BaseObject): read_pref, session, address=address) if not server.description.retryable_reads_supported: retryable = False - with self._slaveok_for_server(read_pref, server, session, - exhaust=exhaust) as (sock_info, - slave_ok): + with self._slaveok_for_server( + read_pref, server, session, pin=pin) as ( + sock_info, slave_ok): if retrying and not retryable: # A retry is not possible because this server does # not support retryable reads, raise the last error. @@ -1496,7 +1496,8 @@ class MongoClient(common.BaseObject): """ self.__kill_cursors_queue.append((address, [cursor_id])) - def _close_cursor_now(self, cursor_id, address=None, session=None): + def _close_cursor_now(self, cursor_id, address=None, session=None, + sock_mgr=None): """Send a kill cursors message with the given id. The cursor is closed synchronously on the current thread. @@ -1505,16 +1506,20 @@ class MongoClient(common.BaseObject): raise TypeError("cursor_id must be an instance of int") try: - self._kill_cursors( - [cursor_id], address, self._get_topology(), session) + if sock_mgr: + with sock_mgr.lock: + # Cursor is pinned to LB outside of a transaction. + self._kill_cursor_impl( + [cursor_id], address, session, sock_mgr.sock) + else: + self._kill_cursors( + [cursor_id], address, self._get_topology(), session) except PyMongoError: # Make another attempt to kill the cursor later. self.__kill_cursors_queue.append((address, [cursor_id])) def _kill_cursors(self, cursor_ids, address, topology, session): """Send a kill cursors message with the given ids.""" - listeners = self._event_listeners - publish = listeners.enabled_for_commands if address: # address could be a tuple or _CursorAddress, but # select_server_by_address needs (host, port). @@ -1523,49 +1528,55 @@ class MongoClient(common.BaseObject): # Application called close_cursor() with no address. 
server = topology.select_server(writable_server_selector) + with self._get_socket(server, session) as sock_info: + self._kill_cursor_impl(cursor_ids, address, session, sock_info) + + def _kill_cursor_impl(self, cursor_ids, address, session, sock_info): + listeners = self._event_listeners + publish = listeners.enabled_for_commands + try: namespace = address.namespace db, coll = namespace.split('.', 1) except AttributeError: namespace = None db = coll = "OP_KILL_CURSORS" - spec = SON([('killCursors', coll), ('cursors', cursor_ids)]) - with server.get_socket(self.__all_credentials) as sock_info: - if sock_info.max_wire_version >= 4 and namespace is not None: - sock_info.command(db, spec, session=session, client=self) - else: - if publish: - start = datetime.datetime.now() - request_id, msg = message.kill_cursors(cursor_ids) - if publish: - duration = datetime.datetime.now() - start - # Here and below, address could be a tuple or - # _CursorAddress. We always want to publish a - # tuple to match the rest of the monitoring - # API. - listeners.publish_command_start( - spec, db, request_id, tuple(address)) - start = datetime.datetime.now() - - try: - sock_info.send_message(msg, 0) - except Exception as exc: - if publish: - dur = ((datetime.datetime.now() - start) + duration) - listeners.publish_command_failure( - dur, message._convert_exception(exc), - 'killCursors', request_id, - tuple(address)) - raise + if sock_info.max_wire_version >= 4 and namespace is not None: + sock_info.command(db, spec, session=session, client=self) + else: + if publish: + start = datetime.datetime.now() + request_id, msg = message.kill_cursors(cursor_ids) + if publish: + duration = datetime.datetime.now() - start + # Here and below, address could be a tuple or + # _CursorAddress. We always want to publish a + # tuple to match the rest of the monitoring + # API. 
+ listeners.publish_command_start( + spec, db, request_id, tuple(address), + service_id=sock_info.service_id) + start = datetime.datetime.now() + try: + sock_info.send_message(msg, 0) + except Exception as exc: if publish: - duration = ((datetime.datetime.now() - start) + duration) - # OP_KILL_CURSORS returns no reply, fake one. - reply = {'cursorsUnknown': cursor_ids, 'ok': 1} - listeners.publish_command_success( - duration, reply, 'killCursors', request_id, - tuple(address)) + dur = ((datetime.datetime.now() - start) + duration) + listeners.publish_command_failure( + dur, message._convert_exception(exc), + 'killCursors', request_id, + tuple(address), service_id=sock_info.service_id) + raise + + if publish: + duration = ((datetime.datetime.now() - start) + duration) + # OP_KILL_CURSORS returns no reply, fake one. + reply = {'cursorsUnknown': cursor_ids, 'ok': 1} + listeners.publish_command_success( + duration, reply, 'killCursors', request_id, + tuple(address), service_id=sock_info.service_id) def _process_kill_cursors(self): """Process any pending kill cursors requests.""" @@ -1966,7 +1977,8 @@ def _add_retryable_write_error(exc, max_wire_version): class _MongoClientErrorHandler(object): """Handle errors raised when executing an operation.""" __slots__ = ('client', 'server_address', 'session', 'max_wire_version', - 'sock_generation', 'completed_handshake', 'service_id') + 'sock_generation', 'completed_handshake', 'service_id', + 'handled') def __init__(self, client, server, session): self.client = client @@ -1980,6 +1992,7 @@ class _MongoClientErrorHandler(object): self.sock_generation = server.pool.gen.get_overall() self.completed_handshake = False self.service_id = None + self.handled = False def contribute_socket(self, sock_info): """Provide socket information to the error handler.""" @@ -1988,13 +2001,10 @@ class _MongoClientErrorHandler(object): self.service_id = sock_info.service_id self.completed_handshake = True - def __enter__(self): - return self - - 
def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: + def handle(self, exc_type, exc_val): + if self.handled or exc_type is None: return - + self.handled = True if self.session: if issubclass(exc_type, ConnectionFailure): if self.session.in_transaction: @@ -2010,3 +2020,9 @@ class _MongoClientErrorHandler(object): exc_val, self.max_wire_version, self.sock_generation, self.completed_handshake, self.service_id) self.client._topology.handle_error(self.server_address, err_ctx) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return self.handle(exc_type, exc_val) diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index b53629d12..5e86e93e0 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -1310,7 +1310,8 @@ class _EventListeners(object): self.__topology_listeners[:]) def publish_command_start(self, command, database_name, - request_id, connection_id, op_id=None): + request_id, connection_id, op_id=None, + service_id=None): """Publish a CommandStartedEvent to all command listeners. :Parameters: @@ -1321,11 +1322,13 @@ class _EventListeners(object): - `connection_id`: The address (host, port) of the server this command was sent to. - `op_id`: The (optional) operation id for this operation. + - `service_id`: The service_id this command was sent to, or ``None``. """ if op_id is None: op_id = request_id event = CommandStartedEvent( - command, database_name, request_id, connection_id, op_id) + command, database_name, request_id, connection_id, op_id, + service_id=service_id) for subscriber in self.__command_listeners: try: subscriber.started(event) @@ -1333,7 +1336,8 @@ class _EventListeners(object): _handle_exception() def publish_command_success(self, duration, reply, command_name, - request_id, connection_id, op_id=None): + request_id, connection_id, op_id=None, + service_id=None): """Publish a CommandSucceededEvent to all command listeners. 
:Parameters: @@ -1344,11 +1348,13 @@ class _EventListeners(object): - `connection_id`: The address (host, port) of the server this command was sent to. - `op_id`: The (optional) operation id for this operation. + - `service_id`: The service_id this command was sent to, or ``None``. """ if op_id is None: op_id = request_id event = CommandSucceededEvent( - duration, reply, command_name, request_id, connection_id, op_id) + duration, reply, command_name, request_id, connection_id, op_id, + service_id) for subscriber in self.__command_listeners: try: subscriber.succeeded(event) @@ -1356,7 +1362,8 @@ class _EventListeners(object): _handle_exception() def publish_command_failure(self, duration, failure, command_name, - request_id, connection_id, op_id=None): + request_id, connection_id, op_id=None, + service_id=None): """Publish a CommandFailedEvent to all command listeners. :Parameters: @@ -1368,11 +1375,13 @@ class _EventListeners(object): - `connection_id`: The address (host, port) of the server this command was sent to. - `op_id`: The (optional) operation id for this operation. + - `service_id`: The service_id this command was sent to, or ``None``. 
""" if op_id is None: op_id = request_id event = CommandFailedEvent( - duration, failure, command_name, request_id, connection_id, op_id) + duration, failure, command_name, request_id, connection_id, op_id, + service_id=service_id) for subscriber in self.__command_listeners: try: subscriber.failed(event) diff --git a/pymongo/network.py b/pymongo/network.py index 39f650448..813a7cc38 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -134,7 +134,8 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos, if publish: encoding_duration = datetime.datetime.now() - start - listeners.publish_command_start(orig, dbname, request_id, address) + listeners.publish_command_start(orig, dbname, request_id, address, + service_id=sock_info.service_id) start = datetime.datetime.now() try: @@ -164,12 +165,14 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos, else: failure = message._convert_exception(exc) listeners.publish_command_failure( - duration, failure, name, request_id, address) + duration, failure, name, request_id, address, + service_id=sock_info.service_id) raise if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( - duration, response_doc, name, request_id, address) + duration, response_doc, name, request_id, address, + service_id=sock_info.service_id) if client and client._encrypter and reply: decrypted = client._encrypter.decrypt(reply.raw_command_response()) diff --git a/pymongo/pool.py b/pymongo/pool.py index 53766f65f..5d62e4bee 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -23,6 +23,7 @@ import socket import sys import threading import time +import weakref from bson import DEFAULT_CODEC_OPTIONS from bson.son import SON @@ -503,6 +504,7 @@ class SocketInfo(object): - `id`: the id of this socket in it's pool """ def __init__(self, sock, pool, address, id): + self.pool_ref = weakref.ref(pool) self.sock = sock self.address = address self.id = id @@ -541,6 +543,17 @@ class 
SocketInfo(object): self.more_to_come = False # For load balancer support. self.service_id = None + # When executing a transaction in load balancing mode, this flag is + # set to true to indicate that the session now owns the connection. + self.pinned = False + + def unpin(self): + self.pinned = False + pool = self.pool_ref() + if pool: + pool.return_socket(self) + else: + self.close_socket(ConnectionClosedReason.STALE) def hello_cmd(self): if self.opts.server_api: @@ -712,7 +725,7 @@ class SocketInfo(object): unacknowledged=unacknowledged, user_fields=user_fields, exhaust_allowed=exhaust_allowed) - except OperationFailure: + except (OperationFailure, NotMasterError): raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. except BaseException as error: @@ -898,7 +911,13 @@ class SocketInfo(object): # ...) is called in Python code, which experiences the signal as a # KeyboardInterrupt from the start, rather than as an initial # socket.error, so we catch that, close the socket, and reraise it. - self.close_socket(ConnectionClosedReason.ERROR) + # + # The connection closed event will be emitted later in return_socket. + if self.ready: + reason = None + else: + reason = ConnectionClosedReason.ERROR + self.close_socket(reason) # SSLError from PyOpenSSL inherits directly from Exception. if isinstance(error, (IOError, OSError, _SSLError)): _raise_connection_failure(self.address, error) @@ -1152,6 +1171,10 @@ class Pool: self.address, self.opts.non_default_options) # Similar to active_sockets but includes threads in the wait queue. self.operation_count = 0 + # Retain references to pinned connections to prevent the CPython GC + # from thinking that a cursor's pinned connection can be GC'd when the + # cursor is GC'd (see PYTHON-2751). 
+ self.__pinned_sockets = set() def ready(self): old_state, self.state = self.state, PoolState.READY @@ -1328,7 +1351,7 @@ class Pool: return sock_info @contextlib.contextmanager - def get_socket(self, all_credentials, checkout=False): + def get_socket(self, all_credentials, checkout=False, handler=None): """Get a socket from the pool. Use with a "with" statement. Returns a :class:`SocketInfo` object wrapping a connected @@ -1349,12 +1372,15 @@ class Pool: :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. - `checkout` (optional): keep socket checked out. + - `handler` (optional): A _MongoClientErrorHandler. """ listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_check_out_started(self.address) sock_info = self._get_socket(all_credentials) + if checkout: + self.__pinned_sockets.add(sock_info) if self.enabled_for_cmap: listeners.publish_connection_checked_out( @@ -1362,11 +1388,23 @@ class Pool: try: yield sock_info except: - # Exception in caller. Decrement semaphore. - self.return_socket(sock_info) + # Exception in caller. Ensure the connection gets returned. + # Note that when pinned is True, the session owns the + # connection and it is responsible for checking the connection + # back into the pool. + pinned = sock_info.pinned + if handler: + # Perform SDAM error handling rules while the connection is + # still checked out. + exc_type, exc_val, _ = sys.exc_info() + handler.handle(exc_type, exc_val) + if not pinned: + self.return_socket(sock_info) raise else: - if not checkout: + if sock_info.pinned: + self.__pinned_sockets.add(sock_info) + elif not checkout: self.return_socket(sock_info) def _raise_if_not_ready(self, emit_event): @@ -1487,6 +1525,8 @@ class Pool: :Parameters: - `sock_info`: The socket to check into the pool. 
""" + self.__pinned_sockets.discard(sock_info) + sock_info.pinned = False listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_checked_in(self.address, sock_info.id) @@ -1495,7 +1535,13 @@ class Pool: else: if self.closed: sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) - elif not sock_info.closed: + elif sock_info.closed: + # CMAP requires the closed event be emitted after the check in. + if self.enabled_for_cmap: + listeners.publish_connection_closed( + self.address, sock_info.id, + ConnectionClosedReason.ERROR) + else: with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). diff --git a/pymongo/response.py b/pymongo/response.py index 474e2c4d3..3094399da 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -68,10 +68,10 @@ class Response(object): return self._docs -class ExhaustResponse(Response): - __slots__ = ('_socket_info', '_pool', '_more_to_come') +class PinnedResponse(Response): + __slots__ = ('_socket_info', '_more_to_come') - def __init__(self, data, address, socket_info, pool, request_id, duration, + def __init__(self, data, address, socket_info, request_id, duration, from_command, docs, more_to_come): """Represent a response to an exhaust cursor's initial query. @@ -87,13 +87,12 @@ class ExhaustResponse(Response): - `more_to_come`: Bool indicating whether cursor is ready to be exhausted. 
""" - super(ExhaustResponse, self).__init__(data, - address, - request_id, - duration, - from_command, docs) + super(PinnedResponse, self).__init__(data, + address, + request_id, + duration, + from_command, docs) self._socket_info = socket_info - self._pool = pool self._more_to_come = more_to_come @property @@ -106,11 +105,6 @@ class ExhaustResponse(Response): """ return self._socket_info - @property - def pool(self): - """The Pool from which the SocketInfo came.""" - return self._pool - @property def more_to_come(self): """If true, server is ready to send batches on the socket until the diff --git a/pymongo/server.py b/pymongo/server.py index 389f8e729..672a3b1c1 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -21,7 +21,7 @@ from bson import _decode_all_selective from pymongo.errors import NotMasterError, OperationFailure from pymongo.helpers import _check_command_response from pymongo.message import _convert_exception, _OpMsg -from pymongo.response import Response, ExhaustResponse +from pymongo.response import Response, PinnedResponse from pymongo.server_type import SERVER_TYPE _CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}} @@ -68,14 +68,8 @@ class Server(object): """Check the server's state soon.""" self._monitor.request_check() - def run_operation_with_response( - self, - sock_info, - operation, - set_slave_okay, - listeners, - exhaust, - unpack_res): + def run_operation(self, sock_info, operation, set_slave_okay, listeners, + pin, unpack_res): """Run a _Query or _GetMore operation and return a Response object. This method is used only to run _Query/_GetMore operations from @@ -87,7 +81,7 @@ class Server(object): - `set_slave_okay`: Pass to operation.get_message. - `all_credentials`: dict, maps auth source to MongoCredential. - `listeners`: Instance of _EventListeners or None. - - `exhaust`: If True, then this is an exhaust cursor operation. + - `pin`: If True, then this is a pinned cursor operation. 
- `unpack_res`: A callable that decodes the wire protocol response. """ duration = None @@ -95,9 +89,9 @@ class Server(object): if publish: start = datetime.now() - use_cmd = operation.use_command(sock_info, exhaust) - more_to_come = (operation.exhaust_mgr - and operation.exhaust_mgr.more_to_come) + use_cmd = operation.use_command(sock_info) + more_to_come = (operation.sock_mgr + and operation.sock_mgr.more_to_come) if more_to_come: request_id = 0 else: @@ -108,7 +102,8 @@ class Server(object): if publish: cmd, dbn = operation.as_command(sock_info) listeners.publish_command_start( - cmd, dbn, request_id, sock_info.address) + cmd, dbn, request_id, sock_info.address, + service_id=sock_info.service_id) start = datetime.now() try: @@ -142,7 +137,8 @@ class Server(object): failure = _convert_exception(exc) listeners.publish_command_failure( duration, failure, operation.name, - request_id, sock_info.address) + request_id, sock_info.address, + service_id=sock_info.service_id) raise if publish: @@ -163,7 +159,7 @@ class Server(object): res["cursor"]["nextBatch"] = docs listeners.publish_command_success( duration, res, operation.name, request_id, - sock_info.address) + sock_info.address, service_id=sock_info.service_id) # Decrypt response. client = operation.client @@ -174,19 +170,20 @@ class Server(object): docs = _decode_all_selective( decrypted, operation.codec_options, user_fields) - if exhaust: + if pin: if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the # more_to_come flag is set. more_to_come = reply.more_to_come else: # In OP_REPLY, the server keeps sending until cursor_id is 0. 
- more_to_come = bool(reply.cursor_id) - response = ExhaustResponse( + more_to_come = bool(operation.exhaust and reply.cursor_id) + if operation.sock_mgr: + operation.sock_mgr.update_exhaust(more_to_come) + response = PinnedResponse( data=reply, address=self._description.address, socket_info=sock_info, - pool=self._pool, duration=duration, request_id=request_id, from_command=use_cmd, @@ -203,8 +200,8 @@ class Server(object): return response - def get_socket(self, all_credentials, checkout=False): - return self.pool.get_socket(all_credentials, checkout) + def get_socket(self, all_credentials, checkout=False, handler=None): + return self.pool.get_socket(all_credentials, checkout, handler) @property def description(self): diff --git a/test/load_balancer/test_load_balancer.py b/test/load_balancer/test_load_balancer.py index c31ff58ef..99f8855ca 100644 --- a/test/load_balancer/test_load_balancer.py +++ b/test/load_balancer/test_load_balancer.py @@ -19,8 +19,8 @@ import sys sys.path[0:0] = [""] -from test import unittest - +from test import unittest, IntegrationTest, client_context +from test.utils import get_pool from test.unified_format import generate_test_classes # Location of JSON test specifications. @@ -30,5 +30,23 @@ TEST_PATH = os.path.join( # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) + +class TestLB(IntegrationTest): + @client_context.require_load_balancer + def test_unpin_committed_transaction(self): + pool = get_pool(self.client) + with self.client.start_session() as session: + with session.start_transaction(): + self.assertEqual(pool.active_sockets, 0) + self.db.test.insert_one({}, session=session) + self.assertEqual(pool.active_sockets, 1) # Pinned. + self.assertEqual(pool.active_sockets, 1) # Still pinned. + self.assertEqual(pool.active_sockets, 0) # Unpinned. 
+ + def test_client_can_be_reopened(self): + self.client.close() + self.db.test.find_one({}) + + if __name__ == "__main__": unittest.main() diff --git a/test/load_balancer/unified/cursors.json b/test/load_balancer/unified/cursors.json index 43e4fbb4f..4e2a55fd4 100644 --- a/test/load_balancer/unified/cursors.json +++ b/test/load_balancer/unified/cursors.json @@ -376,7 +376,7 @@ ] }, { - "description": "pinned connections are not returned after an network error during getMore", + "description": "pinned connections are returned after an network error during getMore", "operations": [ { "name": "failPoint", @@ -440,7 +440,7 @@ "object": "testRunner", "arguments": { "client": "client0", - "connections": 1 + "connections": 0 } }, { @@ -659,7 +659,7 @@ ] }, { - "description": "pinned connections are not returned to the pool after a non-network error on getMore", + "description": "pinned connections are returned to the pool after a non-network error on getMore", "operations": [ { "name": "failPoint", @@ -715,7 +715,7 @@ "object": "testRunner", "arguments": { "client": "client0", - "connections": 1 + "connections": 0 } }, { diff --git a/test/load_balancer/unified/sdam-error-handling.json b/test/load_balancer/unified/sdam-error-handling.json index 63aabc04d..462fa0aac 100644 --- a/test/load_balancer/unified/sdam-error-handling.json +++ b/test/load_balancer/unified/sdam-error-handling.json @@ -366,9 +366,6 @@ { "connectionCreatedEvent": {} }, - { - "poolClearedEvent": {} - }, { "connectionClosedEvent": { "reason": "error" @@ -378,6 +375,9 @@ "connectionCheckOutFailedEvent": { "reason": "connectionError" } + }, + { + "poolClearedEvent": {} } ] } diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 8a28284bf..540dd68e3 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -40,7 +40,7 @@ class MockPool(Pool): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager - def get_socket(self, all_credentials, 
checkout=False): + def get_socket(self, all_credentials, checkout=False, handler=None): client = self.client host_and_port = '%s:%s' % (self.mock_host, self.mock_port) if host_and_port in client.mock_down_hosts: @@ -51,7 +51,8 @@ class MockPool(Pool): + client.mock_members + client.mock_mongoses), "bad host: %s" % host_and_port - with Pool.get_socket(self, all_credentials, checkout) as sock_info: + with Pool.get_socket( + self, all_credentials, checkout, handler) as sock_info: sock_info.mock_host = self.mock_host sock_info.mock_port = self.mock_port yield sock_info diff --git a/test/test_client.py b/test/test_client.py index 88adaac6a..f97691176 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1323,11 +1323,11 @@ class TestClient(IntegrationTest): with self.assertRaises(AutoReconnect): client = rs_client(connect=False, serverSelectionTimeoutMS=100) - client._run_operation_with_response( + client._run_operation( operation=message._GetMore('pymongo_test', 'collection', 101, 1234, client.codec_options, ReadPreference.PRIMARY, - None, client, None, None), + None, client, None, None, False), unpack_res=Cursor( client.pymongo_test.collection)._unpack_response, address=('not-a-member', 27017)) @@ -1708,7 +1708,7 @@ class TestExhaustCursor(IntegrationTest): cursor.next() # Cause a network error. - sock_info = cursor._Cursor__exhaust_mgr.sock + sock_info = cursor._Cursor__sock_mgr.sock sock_info.sock.close() # A getmore fails. 
diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 16c0166b5..bdf24e240 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -20,6 +20,7 @@ import warnings sys.path[0:0] = [""] +from bson.int64 import Int64 from bson.objectid import ObjectId from bson.son import SON from pymongo import CursorType, monitoring, InsertOne, UpdateOne, DeleteOne @@ -397,7 +398,8 @@ class TestCommandMonitoring(PyMongoTestCase): def test_get_more_failure(self): address = self.client.address coll = self.client.pymongo_test.test - cursor_doc = {"id": 12345, "firstBatch": [], "ns": coll.full_name} + cursor_id = Int64(12345) + cursor_doc = {"id": cursor_id, "firstBatch": [], "ns": coll.full_name} cursor = CommandCursor(coll, cursor_doc, address) try: next(cursor) @@ -410,7 +412,7 @@ class TestCommandMonitoring(PyMongoTestCase): self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( - SON([('getMore', 12345), + SON([('getMore', cursor_id), ('collection', 'test')]), started.command) self.assertEqual('getMore', started.command_name) diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 93a282a6d..773aed3f5 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -322,10 +322,9 @@ class ReadPrefTester(MongoClient): yield sock_info, slave_ok @contextlib.contextmanager - def _slaveok_for_server(self, read_preference, server, session, - exhaust=False): + def _slaveok_for_server(self, read_preference, server, session, pin=False): context = super(ReadPrefTester, self)._slaveok_for_server( - read_preference, server, session, exhaust=exhaust) + read_preference, server, session, pin=pin) with context as (sock_info, slave_ok): self.record_a_read(sock_info.address) yield sock_info, slave_ok diff --git a/test/unified_format.py b/test/unified_format.py index 91e02e9e2..d2433e5b4 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -519,6 +519,14 @@ class 
MatchEvaluatorUtil(object): self.test.assertIsInstance(actual, type(expectation)) self.test.assertEqual(expectation, actual) + def assertHasServiceId(self, spec, actual): + if 'hasServiceId' in spec: + if spec.get('hasServiceId'): + self.test.assertIsNotNone(actual.service_id) + self.test.assertIsInstance(actual.service_id, ObjectId) + else: + self.test.assertIsNone(actual.service_id) + def match_event(self, event_type, expectation, actual): name, spec = next(iter(expectation.items())) @@ -543,24 +551,23 @@ class MatchEvaluatorUtil(object): if database_name: self.test.assertEqual( database_name, actual.database_name) + self.assertHasServiceId(spec, actual) elif name == 'commandSucceededEvent': self.test.assertIsInstance(actual, CommandSucceededEvent) reply = spec.get('reply') if reply: self.match_result(reply, actual.reply) + self.assertHasServiceId(spec, actual) elif name == 'commandFailedEvent': self.test.assertIsInstance(actual, CommandFailedEvent) + self.assertHasServiceId(spec, actual) elif name == 'poolCreatedEvent': self.test.assertIsInstance(actual, PoolCreatedEvent) elif name == 'poolReadyEvent': self.test.assertIsInstance(actual, PoolReadyEvent) elif name == 'poolClearedEvent': self.test.assertIsInstance(actual, PoolClearedEvent) - if spec.get('hasServiceId'): - self.test.assertIsNotNone(actual.service_id) - self.test.assertIsInstance(actual.service_id, ObjectId) - else: - self.test.assertIsNone(actual.service_id) + self.assertHasServiceId(spec, actual) elif name == 'poolClosedEvent': self.test.assertIsInstance(actual, PoolClosedEvent) elif name == 'connectionCreatedEvent': @@ -569,12 +576,14 @@ class MatchEvaluatorUtil(object): self.test.assertIsInstance(actual, ConnectionReadyEvent) elif name == 'connectionClosedEvent': self.test.assertIsInstance(actual, ConnectionClosedEvent) - self.test.assertEqual(actual.reason, spec['reason']) + if 'reason' in spec: + self.test.assertEqual(actual.reason, spec['reason']) elif name == 'connectionCheckOutStartedEvent': 
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) elif name == 'connectionCheckOutFailedEvent': self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) - self.test.assertEqual(actual.reason, spec['reason']) + if 'reason' in spec: + self.test.assertEqual(actual.reason, spec['reason']) elif name == 'connectionCheckedOutEvent': self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) elif name == 'connectionCheckedInEvent': diff --git a/test/utils.py b/test/utils.py index fa2865c83..c8336440e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -267,7 +267,7 @@ class MockPool(object): def stale_generation(self, gen, service_id): return self.gen.stale(gen, service_id) - def get_socket(self, all_credentials, checkout=False): + def get_socket(self, all_credentials, checkout=False, handler=None): return MockSocketInfo() def return_socket(self, *args, **kwargs):