PYTHON-2673 Connection pinning behavior for load balanced clusters (#630)

Tweak spec test because pymongo unpins cursors eagerly after errors.
Tweak spec test for PoolClearedEvent ordering when MongoDB handshake fails (see DRIVERS-1785).
Only skip killCursors for some error codes.
Rely on SDAM error handling to close the connection after a state change error.
Add service_id to various events.
Retain reference to pinned sockets to prevent premptive closure by CPython's cyclic GC.
This commit is contained in:
Shane Harvey 2021-05-24 17:49:44 -07:00
parent 7a48831124
commit c8f32a7a37
23 changed files with 465 additions and 296 deletions

View File

@ -161,11 +161,13 @@ class _AggregationCommand(object):
}
# Create and return cursor instance.
return self._cursor_class(
cmd_cursor = self._cursor_class(
self._cursor_collection(cursor), cursor, sock_info.address,
batch_size=self._batch_size or 0,
max_await_time_ms=self._max_await_time_ms,
session=session, explicit_session=self._explicit_session)
cmd_cursor._maybe_pin_connection(sock_info)
return cmd_cursor
class _CollectionAggregationCommand(_AggregationCommand):

View File

@ -181,7 +181,7 @@ class ChangeStream(object):
return self._client._retryable_read(
cmd.get_cursor, self._target._read_preference_for(session),
session)
session, pin=self._client._should_pin_cursor(session))
def _create_cursor(self):
with self._client._tmp_session(self._session, close=False) as s:

View File

@ -108,6 +108,7 @@ from bson.int64 import Int64
from bson.son import SON
from bson.timestamp import Timestamp
from pymongo.cursor import _SocketManager
from pymongo.errors import (ConfigurationError,
ConnectionFailure,
InvalidOperation,
@ -117,6 +118,7 @@ from pymongo.errors import (ConfigurationError,
from pymongo.helpers import _RETRYABLE_ERROR_CODES
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.server_type import SERVER_TYPE
from pymongo.write_concern import WriteConcern
@ -292,6 +294,7 @@ class _Transaction(object):
self.state = _TxnState.NONE
self.sharded = False
self.pinned_address = None
self.sock_mgr = None
self.recovery_token = None
self.attempt = 0
@ -301,10 +304,29 @@ class _Transaction(object):
def starting(self):
return self.state == _TxnState.STARTING
@property
def pinned_conn(self):
if self.active() and self.sock_mgr:
return self.sock_mgr.sock
return None
def pin(self, server, sock_info):
self.sharded = True
self.pinned_address = server.description.address
if server.description.server_type == SERVER_TYPE.LoadBalancer:
sock_info.pinned = True
self.sock_mgr = _SocketManager(sock_info, False)
def unpin(self):
self.pinned_address = None
if self.sock_mgr:
self.sock_mgr.close()
self.sock_mgr = None
def reset(self):
self.unpin()
self.state = _TxnState.NONE
self.sharded = False
self.pinned_address = None
self.recovery_token = None
self.attempt = 0
@ -374,6 +396,9 @@ class ClientSession(object):
try:
if self.in_transaction:
self.abort_transaction()
# It's possible we're still pinned here when the transaction
# is in the committed state when the session is discarded.
self._unpin()
finally:
self._client._return_server_session(self._server_session, lock)
self._server_session = None
@ -779,14 +804,18 @@ class ClientSession(object):
return self._transaction.pinned_address
return None
def _pin(self, server):
"""Pin this session to the given Server."""
self._transaction.sharded = True
self._transaction.pinned_address = server.description.address
@property
def _pinned_connection(self):
"""The connection this transaction was started on."""
return self._transaction.pinned_conn
def _pin(self, server, sock_info):
"""Pin this session to the given Server or to the given connection."""
self._transaction.pin(server, sock_info)
def _unpin(self):
"""Unpin this session from any pinned Server."""
self._transaction.pinned_address = None
self._transaction.unpin()
def _txn_read_preference(self):
"""Return read preference of this transaction or None."""
@ -800,9 +829,6 @@ class ClientSession(object):
self._server_session.last_use = time.monotonic()
command['lsid'] = self._server_session.session_id
if not self.in_transaction:
self._transaction.reset()
if is_retryable:
command['txnNumber'] = self._server_session.transaction_id
return

View File

@ -520,7 +520,8 @@ class Collection(common.BaseObject):
if publish:
duration = datetime.datetime.now() - start
listeners.publish_command_start(
cmd, self.__database.name, rqst_id, sock_info.address, op_id)
cmd, self.__database.name, rqst_id, sock_info.address, op_id,
sock_info.service_id)
start = datetime.datetime.now()
try:
result = sock_info.legacy_write(rqst_id, msg, max_size, False)
@ -534,12 +535,14 @@ class Collection(common.BaseObject):
reply = message._convert_write_result(
name, cmd, details)
listeners.publish_command_success(
dur, reply, name, rqst_id, sock_info.address, op_id)
dur, reply, name, rqst_id, sock_info.address,
op_id, sock_info.service_id)
raise
else:
details = message._convert_exception(exc)
listeners.publish_command_failure(
dur, details, name, rqst_id, sock_info.address, op_id)
dur, details, name, rqst_id, sock_info.address, op_id,
sock_info.service_id)
raise
if publish:
if result is not None:
@ -549,7 +552,8 @@ class Collection(common.BaseObject):
reply = {'ok': 1}
duration = (datetime.datetime.now() - start) + duration
listeners.publish_command_success(
duration, reply, name, rqst_id, sock_info.address, op_id)
duration, reply, name, rqst_id, sock_info.address, op_id,
sock_info.service_id)
return result
def _insert_one(
@ -2072,9 +2076,9 @@ class Collection(common.BaseObject):
if exc.code != 26:
raise
cursor = {'id': 0, 'firstBatch': []}
return CommandCursor(coll, cursor, sock_info.address,
session=s,
explicit_session=session is not None)
cmd_cursor = CommandCursor(
coll, cursor, sock_info.address, session=s,
explicit_session=session is not None)
else:
res = message._first_batch(
sock_info, self.__database.name, "system.indexes",
@ -2084,10 +2088,13 @@ class Collection(common.BaseObject):
cursor = res["cursor"]
# Note that a collection can only have 64 indexes, so there
# will never be a getMore call.
return CommandCursor(coll, cursor, sock_info.address)
cmd_cursor = CommandCursor(coll, cursor, sock_info.address)
cmd_cursor._maybe_pin_connection(sock_info)
return cmd_cursor
return self.__database.client._retryable_read(
_cmd, read_pref, session)
_cmd, read_pref, session,
pin=self.__database.client._should_pin_cursor(session))
def index_information(self, session=None):
"""Get information on this collection's indexes.
@ -2168,7 +2175,8 @@ class Collection(common.BaseObject):
user_fields={'cursor': {'firstBatch': 1}})
return self.__database.client._retryable_read(
cmd.get_cursor, cmd.get_read_preference(session), session,
retryable=not cmd._performs_write)
retryable=not cmd._performs_write,
pin=self.database.client._should_pin_cursor(session))
def aggregate(self, pipeline, session=None, **kwargs):
"""Perform an aggregation using the aggregation framework on this

View File

@ -17,13 +17,14 @@
from collections import deque
from bson import _convert_raw_document_lists_to_streams
from pymongo.cursor import _SocketManager, _CURSOR_CLOSED_ERRORS
from pymongo.errors import (ConnectionFailure,
InvalidOperation,
NotMasterError,
OperationFailure)
from pymongo.message import (_CursorAddress,
_GetMore,
_RawBatchGetMore)
from pymongo.response import PinnedResponse
class CommandCursor(object):
@ -37,6 +38,7 @@ class CommandCursor(object):
The parameter 'retrieved' is unused.
"""
self.__sock_mgr = None
self.__collection = collection
self.__id = cursor_info['id']
self.__data = deque(cursor_info['firstBatch'])
@ -75,11 +77,15 @@ class CommandCursor(object):
self.__address, self.__collection.full_name)
if synchronous:
self.__collection.database.client._close_cursor_now(
self.__id, address, session=self.__session)
self.__id, address, session=self.__session,
sock_mgr=self.__sock_mgr)
else:
# The cursor will be closed later in a different session.
self.__collection.database.client._close_cursor(
self.__id, address)
if self.__sock_mgr:
self.__sock_mgr.close()
self.__sock_mgr = None
self.__end_session(synchronous)
def __end_session(self, synchronous):
@ -127,52 +133,58 @@ class CommandCursor(object):
changeStream aggregate or getMore."""
return self.__postbatchresumetoken
def _maybe_pin_connection(self, sock_info):
client = self.__collection.database.client
if not client._should_pin_cursor(self.__session):
return
if not self.__sock_mgr:
sock_mgr = _SocketManager(sock_info, False)
# Ensure the connection gets returned when the entire result is
# returned in the first batch.
if self.__id == 0:
sock_mgr.close()
else:
self.__sock_mgr = sock_mgr
def __send_message(self, operation):
"""Send a getmore message and handle the response.
"""
def kill():
self.__killed = True
self.__end_session(True)
client = self.__collection.database.client
try:
response = client._run_operation_with_response(
response = client._run_operation(
operation, self._unpack_response, address=self.__address)
except OperationFailure:
kill()
raise
except NotMasterError:
# Don't send kill cursors to another server after a "not master"
# error. It's completely pointless.
kill()
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS:
# Don't send killCursors because the cursor is already closed.
self.__killed = True
# Return the session and pinned connection, if necessary.
self.close()
raise
except ConnectionFailure:
# Don't try to send kill cursors on another socket
# or to another server. It can cause a _pinValue
# assertion on some server releases if we get here
# due to a socket timeout.
kill()
# Don't send killCursors because the cursor is already closed.
self.__killed = True
# Return the session and pinned connection, if necessary.
self.close()
raise
except Exception:
# Close the cursor
self.__die()
self.close()
raise
from_command = response.from_command
reply = response.data
docs = response.docs
if from_command:
cursor = docs[0]['cursor']
if isinstance(response, PinnedResponse):
if not self.__sock_mgr:
self.__sock_mgr = _SocketManager(response.socket_info,
response.more_to_come)
if response.from_command:
cursor = response.docs[0]['cursor']
documents = cursor['nextBatch']
self.__postbatchresumetoken = cursor.get('postBatchResumeToken')
self.__id = cursor['id']
else:
documents = docs
self.__id = reply.cursor_id
documents = response.docs
self.__id = response.data.cursor_id
if self.__id == 0:
kill()
self.__die(True)
self.__data = deque(documents)
def _unpack_response(self, response, cursor_id, codec_options,
@ -203,10 +215,9 @@ class CommandCursor(object):
self.__session,
self.__collection.database.client,
self.__max_await_time_ms,
False))
self.__sock_mgr, False))
else: # Cursor id is zero nothing else to return
self.__killed = True
self.__end_session(True)
self.__die(True)
return len(self.__data)

View File

@ -15,6 +15,7 @@
"""Cursor class to iterate over Mongo query results."""
import copy
import threading
import warnings
from collections import deque
@ -27,7 +28,6 @@ from pymongo.common import validate_boolean, validate_is_mapping
from pymongo.collation import validate_collation_or_none
from pymongo.errors import (ConnectionFailure,
InvalidOperation,
NotMasterError,
OperationFailure)
from pymongo.message import (_CursorAddress,
_GetMore,
@ -35,7 +35,37 @@ from pymongo.message import (_CursorAddress,
_Query,
_RawBatchQuery)
from pymongo.monitoring import ConnectionClosedReason
from pymongo.response import PinnedResponse
# These errors mean that the server has already killed the cursor so there is
# no need to send killCursors.
_CURSOR_CLOSED_ERRORS = frozenset([
43, # CursorNotFound
50, # MaxTimeMSExpired
175, # QueryPlanKilled
237, # CursorKilled
# On a tailable cursor, the following errors mean the capped collection
# rolled over.
# MongoDB 2.6:
# {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0}
28617,
# MongoDB 3.0:
# {'$err': 'getMore executor error: UnknownError no details available',
# 'code': 17406, 'ok': 0}
17406,
# MongoDB 3.2 + 3.4:
# {'ok': 0.0, 'errmsg': 'GetMore command executor error:
# CappedPositionLost: CollectionScan died due to failure to restore
# tailable cursor position. Last seen record id: RecordId(3)',
# 'code': 96}
96,
# MongoDB 3.6+:
# {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to
# restore tailable cursor position. Last seen record id: RecordId(3)"',
# 'code': 136, 'codeName': 'CappedPositionLost'}
136,
])
_QUERY_OPTIONS = {
"tailable_cursor": 2,
@ -78,14 +108,14 @@ class CursorType(object):
# This has to be an old style class due to
# http://bugs.jython.org/issue1057
class _ExhaustManager:
class _SocketManager:
"""Used with exhaust cursors to ensure the socket is returned.
"""
def __init__(self, sock, pool, more_to_come):
def __init__(self, sock, more_to_come):
self.sock = sock
self.pool = pool
self.more_to_come = more_to_come
self.__closed = False
self.lock = threading.Lock()
def __del__(self):
self.close()
@ -98,8 +128,8 @@ class _ExhaustManager:
"""
if not self.__closed:
self.__closed = True
self.pool.return_socket(self.sock)
self.sock, self.pool = None, None
self.sock.unpin()
self.sock = None
class Cursor(object):
@ -128,7 +158,7 @@ class Cursor(object):
# an error to avoid attribute errors during garbage collection.
self.__id = None
self.__exhaust = False
self.__exhaust_mgr = None
self.__sock_mgr = None
self.__killed = False
if session:
@ -319,24 +349,26 @@ class Cursor(object):
self.__killed = True
if self.__id and not already_killed:
if self.__exhaust and self.__exhaust_mgr:
if self.__exhaust and self.__sock_mgr:
# If this is an exhaust cursor and we haven't completely
# exhausted the result set we *must* close the socket
# to stop the server from sending more data.
self.__exhaust_mgr.sock.close_socket(
self.__sock_mgr.sock.close_socket(
ConnectionClosedReason.ERROR)
else:
address = _CursorAddress(
self.__address, self.__collection.full_name)
if synchronous:
self.__collection.database.client._close_cursor_now(
self.__id, address, session=self.__session)
self.__id, address, session=self.__session,
sock_mgr=self.__sock_mgr)
else:
# The cursor will be closed later in a different session.
self.__collection.database.client._close_cursor(
self.__id, address)
if self.__exhaust and self.__exhaust_mgr:
self.__exhaust_mgr.close()
if self.__sock_mgr:
self.__sock_mgr.close()
self.__sock_mgr = None
if self.__session and not self.__explicit_session:
self.__session._end_session(lock=synchronous)
self.__session = None
@ -1004,53 +1036,35 @@ class Cursor(object):
"exhaust cursors do not support auto encryption")
try:
response = client._run_operation_with_response(
operation, self._unpack_response, exhaust=self.__exhaust,
address=self.__address)
except OperationFailure:
self.__killed = True
# Make sure exhaust socket is returned immediately, if necessary.
self.__die()
response = client._run_operation(
operation, self._unpack_response, address=self.__address)
except OperationFailure as exc:
if exc.code in _CURSOR_CLOSED_ERRORS or self.__exhaust:
# Don't send killCursors because the cursor is already closed.
self.__killed = True
self.close()
# If this is a tailable cursor the error is likely
# due to capped collection roll over. Setting
# self.__killed to True ensures Cursor.alive will be
# False. No need to re-raise.
if self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]:
if (exc.code in _CURSOR_CLOSED_ERRORS and
self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]):
return
raise
except NotMasterError:
# Don't send kill cursors to another server after a "not master"
# error. It's completely pointless.
self.__killed = True
# Make sure exhaust socket is returned immediately, if necessary.
self.__die()
raise
except ConnectionFailure:
# Don't try to send kill cursors on another socket
# or to another server. It can cause a _pinValue
# assertion on some server releases if we get here
# due to a socket timeout.
# Don't send killCursors because the cursor is already closed.
self.__killed = True
self.__die()
self.close()
raise
except Exception:
# Close the cursor
self.__die()
self.close()
raise
self.__address = response.address
if self.__exhaust:
# 'response' is an ExhaustResponse.
if not self.__exhaust_mgr:
self.__exhaust_mgr = _ExhaustManager(response.socket_info,
response.pool,
response.more_to_come)
else:
self.__exhaust_mgr.update_exhaust(response.more_to_come)
if isinstance(response, PinnedResponse):
if not self.__sock_mgr:
self.__sock_mgr = _SocketManager(response.socket_info,
response.more_to_come)
cmd_name = operation.name
docs = response.docs
@ -1078,7 +1092,6 @@ class Cursor(object):
self.__retrieved += response.data.number_returned
if self.__id == 0:
self.__killed = True
# Don't wait for garbage collection to call __del__, return the
# socket and the session to the pool now.
self.__die()
@ -1132,7 +1145,8 @@ class Cursor(object):
self.__collation,
self.__session,
self.__collection.database.client,
self.__allow_disk_use)
self.__allow_disk_use,
self.__exhaust)
self.__send_message(q)
elif self.__id: # Get More
if self.__limit:
@ -1151,7 +1165,8 @@ class Cursor(object):
self.__session,
self.__collection.database.client,
self.__max_await_time_ms,
self.__exhaust_mgr)
self.__sock_mgr,
self.__exhaust)
self.__send_message(g)
return len(self.__data)

View File

@ -377,7 +377,8 @@ class Database(common.BaseObject):
user_fields={'cursor': {'firstBatch': 1}})
return self.client._retryable_read(
cmd.get_cursor, cmd.get_read_preference(s), s,
retryable=not cmd._performs_write)
retryable=not cmd._performs_write,
pin=self.client._should_pin_cursor(s))
def watch(self, pipeline=None, full_document=None, resume_after=None,
max_await_time_ms=None, batch_size=None, collation=None,
@ -636,7 +637,7 @@ class Database(common.BaseObject):
sock_info, cmd, slave_okay,
read_preference=read_preference,
session=tmp_session)["cursor"]
return CommandCursor(
cmd_cursor = CommandCursor(
coll,
cursor,
sock_info.address,
@ -656,7 +657,9 @@ class Database(common.BaseObject):
("pipeline", pipeline),
("cursor", kwargs.get("cursor", {}))])
cursor = self._command(sock_info, cmd, slave_okay)["cursor"]
return CommandCursor(coll, cursor, sock_info.address)
cmd_cursor = CommandCursor(coll, cursor, sock_info.address)
cmd_cursor._maybe_pin_connection(sock_info)
return cmd_cursor
def list_collections(self, session=None, filter=None, **kwargs):
"""Get a cursor over the collectons of this database.
@ -688,7 +691,8 @@ class Database(common.BaseObject):
**kwargs)
return self.__client._retryable_read(
_cmd, read_pref, session)
_cmd, read_pref, session,
pin=self.client._should_pin_cursor(session))
def list_collection_names(self, session=None, filter=None, **kwargs):
"""Get a list of all the collection names in this database.

View File

@ -234,16 +234,17 @@ class _Query(object):
__slots__ = ('flags', 'db', 'coll', 'ntoskip', 'spec',
'fields', 'codec_options', 'read_preference', 'limit',
'batch_size', 'name', 'read_concern', 'collation',
'session', 'client', 'allow_disk_use', '_as_command')
'session', 'client', 'allow_disk_use', '_as_command',
'exhaust')
# For compatibility with the _GetMore class.
exhaust_mgr = None
sock_mgr = None
cursor_id = None
def __init__(self, flags, db, coll, ntoskip, spec, fields,
codec_options, read_preference, limit,
batch_size, read_concern, collation, session, client,
allow_disk_use):
allow_disk_use, exhaust):
self.flags = flags
self.db = db
self.coll = coll
@ -261,13 +262,14 @@ class _Query(object):
self.allow_disk_use = allow_disk_use
self.name = 'find'
self._as_command = None
self.exhaust = exhaust
def namespace(self):
return "%s.%s" % (self.db, self.coll)
def use_command(self, sock_info, exhaust):
def use_command(self, sock_info):
use_find_cmd = False
if sock_info.max_wire_version >= 4 and not exhaust:
if sock_info.max_wire_version >= 4 and not self.exhaust:
use_find_cmd = True
elif sock_info.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+
@ -375,13 +377,13 @@ class _GetMore(object):
__slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms',
'codec_options', 'read_preference', 'session', 'client',
'exhaust_mgr', '_as_command')
'sock_mgr', '_as_command', 'exhaust')
name = 'getMore'
def __init__(self, db, coll, ntoreturn, cursor_id, codec_options,
read_preference, session, client, max_await_time_ms,
exhaust_mgr):
sock_mgr, exhaust):
self.db = db
self.coll = coll
self.ntoreturn = ntoreturn
@ -391,15 +393,16 @@ class _GetMore(object):
self.session = session
self.client = client
self.max_await_time_ms = max_await_time_ms
self.exhaust_mgr = exhaust_mgr
self.sock_mgr = sock_mgr
self._as_command = None
self.exhaust = exhaust
def namespace(self):
return "%s.%s" % (self.db, self.coll)
def use_command(self, sock_info, exhaust):
def use_command(self, sock_info):
use_cmd = False
if sock_info.max_wire_version >= 4 and not exhaust:
if sock_info.max_wire_version >= 4 and not self.exhaust:
use_cmd = True
elif sock_info.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+
@ -440,7 +443,7 @@ class _GetMore(object):
if use_cmd:
spec = self.as_command(sock_info)[0]
if sock_info.op_msg_enabled:
if self.exhaust_mgr:
if self.sock_mgr:
flags = _OpMsg.EXHAUST_ALLOWED
else:
flags = 0
@ -456,23 +459,23 @@ class _GetMore(object):
class _RawBatchQuery(_Query):
def use_command(self, socket_info, exhaust):
def use_command(self, sock_info):
# Compatibility checks.
super(_RawBatchQuery, self).use_command(socket_info, exhaust)
if socket_info.max_wire_version >= 8:
super(_RawBatchQuery, self).use_command(sock_info)
if sock_info.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG
return True
elif socket_info.op_msg_enabled and not exhaust:
elif sock_info.op_msg_enabled and not self.exhaust:
return True
return False
class _RawBatchGetMore(_GetMore):
def use_command(self, socket_info, exhaust):
if socket_info.max_wire_version >= 8:
def use_command(self, sock_info):
if sock_info.max_wire_version >= 8:
# MongoDB 4.2+ supports exhaust over OP_MSG
return True
elif socket_info.op_msg_enabled and not exhaust:
elif sock_info.op_msg_enabled and not self.exhaust:
return True
return False
@ -1033,20 +1036,23 @@ class _BulkWriteContext(object):
cmd[self.field] = docs
self.listeners.publish_command_start(
cmd, self.db_name,
request_id, self.sock_info.address, self.op_id)
request_id, self.sock_info.address, self.op_id,
self.sock_info.service_id)
return cmd
def _succeed(self, request_id, reply, duration):
"""Publish a CommandSucceededEvent."""
self.listeners.publish_command_success(
duration, reply, self.name,
request_id, self.sock_info.address, self.op_id)
request_id, self.sock_info.address, self.op_id,
self.sock_info.service_id)
def _fail(self, request_id, failure, duration):
"""Publish a CommandFailedEvent."""
self.listeners.publish_command_failure(
duration, failure, self.name,
request_id, self.sock_info.address, self.op_id)
request_id, self.sock_info.address, self.op_id,
self.sock_info.service_id)
# From the Client Side Encryption spec:
@ -1686,7 +1692,7 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
query = _Query(
0, db, coll, 0, query, None, codec_options,
read_preference, ntoreturn, 0, DEFAULT_READ_CONCERN, None, None,
None, None)
None, None, False)
name = next(iter(cmd))
publish = listeners.enabled_for_commands
@ -1698,7 +1704,8 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
if publish:
encoding_duration = datetime.datetime.now() - start
listeners.publish_command_start(
cmd, db, request_id, sock_info.address)
cmd, db, request_id, sock_info.address,
service_id=sock_info.service_id)
start = datetime.datetime.now()
sock_info.send_message(msg, max_doc_size)
@ -1713,7 +1720,8 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
else:
failure = _convert_exception(exc)
listeners.publish_command_failure(
duration, failure, name, request_id, sock_info.address)
duration, failure, name, request_id, sock_info.address,
service_id=sock_info.service_id)
raise
# listIndexes
if 'cursor' in cmd:
@ -1732,6 +1740,7 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
if publish:
duration = (datetime.datetime.now() - start) + encoding_duration
listeners.publish_command_success(
duration, result, name, request_id, sock_info.address)
duration, result, name, request_id, sock_info.address,
service_id=sock_info.service_id)
return result

View File

@ -60,6 +60,7 @@ from pymongo.errors import (AutoReconnect,
OperationFailure,
PyMongoError,
ServerSelectionTimeoutError)
from pymongo.pool import ConnectionClosedReason
from pymongo.read_preferences import ReadPreference
from pymongo.server_selectors import (writable_preferred_server_selector,
writable_server_selector)
@ -1160,10 +1161,20 @@ class MongoClient(common.BaseObject):
return self._topology
@contextlib.contextmanager
def _get_socket(self, server, session, exhaust=False):
def _get_socket(self, server, session, pin=False):
in_txn = session and session.in_transaction
with _MongoClientErrorHandler(self, server, session) as err_handler:
# Reuse the pinned connection, if it exists.
if in_txn and session._pinned_connection:
yield session._pinned_connection
return
with server.get_socket(
self.__all_credentials, checkout=exhaust) as sock_info:
self.__all_credentials, checkout=pin,
handler=err_handler) as sock_info:
# Pin this session to the selected server or connection.
if (in_txn and server.description.server_type in (
SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)):
session._pin(server, sock_info)
err_handler.contribute_socket(sock_info)
if (self._encrypter and
not self._encrypter._bypass_auto_encryption and
@ -1186,6 +1197,8 @@ class MongoClient(common.BaseObject):
"""
try:
topology = self._get_topology()
if session and not session.in_transaction:
session._transaction.reset()
address = address or (session and session._pinned_address)
if address:
# We're running a getMore or this session is pinned to a mongos.
@ -1195,12 +1208,6 @@ class MongoClient(common.BaseObject):
% address)
else:
server = topology.select_server(server_selector)
# Pin this session to the selected server if it's performing a
# sharded transaction.
if (server.description.server_type in (
SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)
and session and session.in_transaction):
session._pin(server)
return server
except PyMongoError as exc:
# Server selection errors in a transaction are transient.
@ -1214,8 +1221,7 @@ class MongoClient(common.BaseObject):
return self._get_socket(server, session)
@contextlib.contextmanager
def _slaveok_for_server(self, read_preference, server, session,
exhaust=False):
def _slaveok_for_server(self, read_preference, server, session, pin=False):
assert read_preference is not None, "read_preference must not be None"
# Get a socket for a server matching the read preference, and yield
# sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to
@ -1226,7 +1232,7 @@ class MongoClient(common.BaseObject):
topology = self._get_topology()
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
with self._get_socket(server, session, exhaust=exhaust) as sock_info:
with self._get_socket(server, session, pin=pin) as sock_info:
slave_ok = (single and not sock_info.is_mongos) or (
read_preference != ReadPreference.PRIMARY)
yield sock_info, slave_ok
@ -1249,47 +1255,41 @@ class MongoClient(common.BaseObject):
read_preference != ReadPreference.PRIMARY)
yield sock_info, slave_ok
def _run_operation_with_response(self, operation, unpack_res,
exhaust=False, address=None):
def _should_pin_cursor(self, session):
return (self.__options.load_balanced and
not (session and session.in_transaction))
def _run_operation(self, operation, unpack_res, pin=False, address=None):
"""Run a _Query/_GetMore operation and return a Response.
:Parameters:
- `operation`: a _Query or _GetMore object.
- `unpack_res`: A callable that decodes the wire protocol response.
- `exhaust` (optional): If True, the socket used stays checked out.
It is returned along with its Pool in the Response.
- `address` (optional): Optional address when sending a message
to a specific server, used for getMore.
"""
if operation.exhaust_mgr:
pin = self._should_pin_cursor(operation.session) or operation.exhaust
if operation.sock_mgr:
server = self._select_server(
operation.read_preference, operation.session, address=address)
with _MongoClientErrorHandler(
self, server, operation.session) as err_handler:
err_handler.contribute_socket(operation.exhaust_mgr.sock)
return server.run_operation_with_response(
operation.exhaust_mgr.sock,
operation,
True,
self._event_listeners,
exhaust,
unpack_res)
with operation.sock_mgr.lock:
with _MongoClientErrorHandler(
self, server, operation.session) as err_handler:
err_handler.contribute_socket(operation.sock_mgr.sock)
return server.run_operation(
operation.sock_mgr.sock, operation, True,
self._event_listeners, pin, unpack_res)
def _cmd(session, server, sock_info, slave_ok):
return server.run_operation_with_response(
sock_info,
operation,
slave_ok,
self._event_listeners,
exhaust,
return server.run_operation(
sock_info, operation, slave_ok, self._event_listeners, pin,
unpack_res)
return self._retryable_read(
_cmd, operation.read_preference, operation.session,
address=address,
retryable=isinstance(operation, message._Query),
exhaust=exhaust)
address=address, retryable=isinstance(operation, message._Query),
pin=pin)
def _retry_with_session(self, retryable, func, session, bulk):
"""Execute an operation with at most one consecutive retries
@ -1361,7 +1361,7 @@ class MongoClient(common.BaseObject):
last_error = exc
def _retryable_read(self, func, read_pref, session, address=None,
retryable=True, exhaust=False):
retryable=True, pin=False):
"""Execute an operation with at most one consecutive retries
Returns func()'s return value on success. On error retries the same
@ -1381,9 +1381,9 @@ class MongoClient(common.BaseObject):
read_pref, session, address=address)
if not server.description.retryable_reads_supported:
retryable = False
with self._slaveok_for_server(read_pref, server, session,
exhaust=exhaust) as (sock_info,
slave_ok):
with self._slaveok_for_server(
read_pref, server, session, pin=pin) as (
sock_info, slave_ok):
if retrying and not retryable:
# A retry is not possible because this server does
# not support retryable reads, raise the last error.
@ -1496,7 +1496,8 @@ class MongoClient(common.BaseObject):
"""
self.__kill_cursors_queue.append((address, [cursor_id]))
def _close_cursor_now(self, cursor_id, address=None, session=None):
def _close_cursor_now(self, cursor_id, address=None, session=None,
sock_mgr=None):
"""Send a kill cursors message with the given id.
The cursor is closed synchronously on the current thread.
@ -1505,16 +1506,20 @@ class MongoClient(common.BaseObject):
raise TypeError("cursor_id must be an instance of int")
try:
self._kill_cursors(
[cursor_id], address, self._get_topology(), session)
if sock_mgr:
with sock_mgr.lock:
# Cursor is pinned to LB outside of a transaction.
self._kill_cursor_impl(
[cursor_id], address, session, sock_mgr.sock)
else:
self._kill_cursors(
[cursor_id], address, self._get_topology(), session)
except PyMongoError:
# Make another attempt to kill the cursor later.
self.__kill_cursors_queue.append((address, [cursor_id]))
def _kill_cursors(self, cursor_ids, address, topology, session):
"""Send a kill cursors message with the given ids."""
listeners = self._event_listeners
publish = listeners.enabled_for_commands
if address:
# address could be a tuple or _CursorAddress, but
# select_server_by_address needs (host, port).
@ -1523,49 +1528,55 @@ class MongoClient(common.BaseObject):
# Application called close_cursor() with no address.
server = topology.select_server(writable_server_selector)
with self._get_socket(server, session) as sock_info:
self._kill_cursor_impl(cursor_ids, address, session, sock_info)
def _kill_cursor_impl(self, cursor_ids, address, session, sock_info):
listeners = self._event_listeners
publish = listeners.enabled_for_commands
try:
namespace = address.namespace
db, coll = namespace.split('.', 1)
except AttributeError:
namespace = None
db = coll = "OP_KILL_CURSORS"
spec = SON([('killCursors', coll), ('cursors', cursor_ids)])
with server.get_socket(self.__all_credentials) as sock_info:
if sock_info.max_wire_version >= 4 and namespace is not None:
sock_info.command(db, spec, session=session, client=self)
else:
if publish:
start = datetime.datetime.now()
request_id, msg = message.kill_cursors(cursor_ids)
if publish:
duration = datetime.datetime.now() - start
# Here and below, address could be a tuple or
# _CursorAddress. We always want to publish a
# tuple to match the rest of the monitoring
# API.
listeners.publish_command_start(
spec, db, request_id, tuple(address))
start = datetime.datetime.now()
try:
sock_info.send_message(msg, 0)
except Exception as exc:
if publish:
dur = ((datetime.datetime.now() - start) + duration)
listeners.publish_command_failure(
dur, message._convert_exception(exc),
'killCursors', request_id,
tuple(address))
raise
if sock_info.max_wire_version >= 4 and namespace is not None:
sock_info.command(db, spec, session=session, client=self)
else:
if publish:
start = datetime.datetime.now()
request_id, msg = message.kill_cursors(cursor_ids)
if publish:
duration = datetime.datetime.now() - start
# Here and below, address could be a tuple or
# _CursorAddress. We always want to publish a
# tuple to match the rest of the monitoring
# API.
listeners.publish_command_start(
spec, db, request_id, tuple(address),
service_id=sock_info.service_id)
start = datetime.datetime.now()
try:
sock_info.send_message(msg, 0)
except Exception as exc:
if publish:
duration = ((datetime.datetime.now() - start) + duration)
# OP_KILL_CURSORS returns no reply, fake one.
reply = {'cursorsUnknown': cursor_ids, 'ok': 1}
listeners.publish_command_success(
duration, reply, 'killCursors', request_id,
tuple(address))
dur = ((datetime.datetime.now() - start) + duration)
listeners.publish_command_failure(
dur, message._convert_exception(exc),
'killCursors', request_id,
tuple(address), service_id=sock_info.service_id)
raise
if publish:
duration = ((datetime.datetime.now() - start) + duration)
# OP_KILL_CURSORS returns no reply, fake one.
reply = {'cursorsUnknown': cursor_ids, 'ok': 1}
listeners.publish_command_success(
duration, reply, 'killCursors', request_id,
tuple(address), service_id=sock_info.service_id)
def _process_kill_cursors(self):
"""Process any pending kill cursors requests."""
@ -1966,7 +1977,8 @@ def _add_retryable_write_error(exc, max_wire_version):
class _MongoClientErrorHandler(object):
"""Handle errors raised when executing an operation."""
__slots__ = ('client', 'server_address', 'session', 'max_wire_version',
'sock_generation', 'completed_handshake', 'service_id')
'sock_generation', 'completed_handshake', 'service_id',
'handled')
def __init__(self, client, server, session):
self.client = client
@ -1980,6 +1992,7 @@ class _MongoClientErrorHandler(object):
self.sock_generation = server.pool.gen.get_overall()
self.completed_handshake = False
self.service_id = None
self.handled = False
def contribute_socket(self, sock_info):
"""Provide socket information to the error handler."""
@ -1988,13 +2001,10 @@ class _MongoClientErrorHandler(object):
self.service_id = sock_info.service_id
self.completed_handshake = True
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
def handle(self, exc_type, exc_val):
if self.handled or exc_type is None:
return
self.handled = True
if self.session:
if issubclass(exc_type, ConnectionFailure):
if self.session.in_transaction:
@ -2010,3 +2020,9 @@ class _MongoClientErrorHandler(object):
exc_val, self.max_wire_version, self.sock_generation,
self.completed_handshake, self.service_id)
self.client._topology.handle_error(self.server_address, err_ctx)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return self.handle(exc_type, exc_val)

View File

@ -1310,7 +1310,8 @@ class _EventListeners(object):
self.__topology_listeners[:])
def publish_command_start(self, command, database_name,
request_id, connection_id, op_id=None):
request_id, connection_id, op_id=None,
service_id=None):
"""Publish a CommandStartedEvent to all command listeners.
:Parameters:
@ -1321,11 +1322,13 @@ class _EventListeners(object):
- `connection_id`: The address (host, port) of the server this
command was sent to.
- `op_id`: The (optional) operation id for this operation.
- `service_id`: The service_id this command was sent to, or ``None``.
"""
if op_id is None:
op_id = request_id
event = CommandStartedEvent(
command, database_name, request_id, connection_id, op_id)
command, database_name, request_id, connection_id, op_id,
service_id=service_id)
for subscriber in self.__command_listeners:
try:
subscriber.started(event)
@ -1333,7 +1336,8 @@ class _EventListeners(object):
_handle_exception()
def publish_command_success(self, duration, reply, command_name,
request_id, connection_id, op_id=None):
request_id, connection_id, op_id=None,
service_id=None):
"""Publish a CommandSucceededEvent to all command listeners.
:Parameters:
@ -1344,11 +1348,13 @@ class _EventListeners(object):
- `connection_id`: The address (host, port) of the server this
command was sent to.
- `op_id`: The (optional) operation id for this operation.
- `service_id`: The service_id this command was sent to, or ``None``.
"""
if op_id is None:
op_id = request_id
event = CommandSucceededEvent(
duration, reply, command_name, request_id, connection_id, op_id)
duration, reply, command_name, request_id, connection_id, op_id,
service_id)
for subscriber in self.__command_listeners:
try:
subscriber.succeeded(event)
@ -1356,7 +1362,8 @@ class _EventListeners(object):
_handle_exception()
def publish_command_failure(self, duration, failure, command_name,
request_id, connection_id, op_id=None):
request_id, connection_id, op_id=None,
service_id=None):
"""Publish a CommandFailedEvent to all command listeners.
:Parameters:
@ -1368,11 +1375,13 @@ class _EventListeners(object):
- `connection_id`: The address (host, port) of the server this
command was sent to.
- `op_id`: The (optional) operation id for this operation.
- `service_id`: The service_id this command was sent to, or ``None``.
"""
if op_id is None:
op_id = request_id
event = CommandFailedEvent(
duration, failure, command_name, request_id, connection_id, op_id)
duration, failure, command_name, request_id, connection_id, op_id,
service_id=service_id)
for subscriber in self.__command_listeners:
try:
subscriber.failed(event)

View File

@ -134,7 +134,8 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos,
if publish:
encoding_duration = datetime.datetime.now() - start
listeners.publish_command_start(orig, dbname, request_id, address)
listeners.publish_command_start(orig, dbname, request_id, address,
service_id=sock_info.service_id)
start = datetime.datetime.now()
try:
@ -164,12 +165,14 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos,
else:
failure = message._convert_exception(exc)
listeners.publish_command_failure(
duration, failure, name, request_id, address)
duration, failure, name, request_id, address,
service_id=sock_info.service_id)
raise
if publish:
duration = (datetime.datetime.now() - start) + encoding_duration
listeners.publish_command_success(
duration, response_doc, name, request_id, address)
duration, response_doc, name, request_id, address,
service_id=sock_info.service_id)
if client and client._encrypter and reply:
decrypted = client._encrypter.decrypt(reply.raw_command_response())

View File

@ -23,6 +23,7 @@ import socket
import sys
import threading
import time
import weakref
from bson import DEFAULT_CODEC_OPTIONS
from bson.son import SON
@ -503,6 +504,7 @@ class SocketInfo(object):
- `id`: the id of this socket in it's pool
"""
def __init__(self, sock, pool, address, id):
self.pool_ref = weakref.ref(pool)
self.sock = sock
self.address = address
self.id = id
@ -541,6 +543,17 @@ class SocketInfo(object):
self.more_to_come = False
# For load balancer support.
self.service_id = None
# When executing a transaction in load balancing mode, this flag is
# set to true to indicate that the session now owns the connection.
self.pinned = False
def unpin(self):
self.pinned = False
pool = self.pool_ref()
if pool:
pool.return_socket(self)
else:
self.close_socket(ConnectionClosedReason.STALE)
def hello_cmd(self):
if self.opts.server_api:
@ -712,7 +725,7 @@ class SocketInfo(object):
unacknowledged=unacknowledged,
user_fields=user_fields,
exhaust_allowed=exhaust_allowed)
except OperationFailure:
except (OperationFailure, NotMasterError):
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
except BaseException as error:
@ -898,7 +911,13 @@ class SocketInfo(object):
# ...) is called in Python code, which experiences the signal as a
# KeyboardInterrupt from the start, rather than as an initial
# socket.error, so we catch that, close the socket, and reraise it.
self.close_socket(ConnectionClosedReason.ERROR)
#
# The connection closed event will be emitted later in return_socket.
if self.ready:
reason = None
else:
reason = ConnectionClosedReason.ERROR
self.close_socket(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, _SSLError)):
_raise_connection_failure(self.address, error)
@ -1152,6 +1171,10 @@ class Pool:
self.address, self.opts.non_default_options)
# Similar to active_sockets but includes threads in the wait queue.
self.operation_count = 0
# Retain references to pinned connections to prevent the CPython GC
# from thinking that a cursor's pinned connection can be GC'd when the
# cursor is GC'd (see PYTHON-2751).
self.__pinned_sockets = set()
def ready(self):
old_state, self.state = self.state, PoolState.READY
@ -1328,7 +1351,7 @@ class Pool:
return sock_info
@contextlib.contextmanager
def get_socket(self, all_credentials, checkout=False):
def get_socket(self, all_credentials, checkout=False, handler=None):
"""Get a socket from the pool. Use with a "with" statement.
Returns a :class:`SocketInfo` object wrapping a connected
@ -1349,12 +1372,15 @@ class Pool:
:Parameters:
- `all_credentials`: dict, maps auth source to MongoCredential.
- `checkout` (optional): keep socket checked out.
- `handler` (optional): A _MongoClientErrorHandler.
"""
listeners = self.opts.event_listeners
if self.enabled_for_cmap:
listeners.publish_connection_check_out_started(self.address)
sock_info = self._get_socket(all_credentials)
if checkout:
self.__pinned_sockets.add(sock_info)
if self.enabled_for_cmap:
listeners.publish_connection_checked_out(
@ -1362,11 +1388,23 @@ class Pool:
try:
yield sock_info
except:
# Exception in caller. Decrement semaphore.
self.return_socket(sock_info)
# Exception in caller. Ensure the connection gets returned.
# Note that when pinned is True, the session owns the
# connection and it is responsible for checking the connection
# back into the pool.
pinned = sock_info.pinned
if handler:
# Perform SDAM error handling rules while the connection is
# still checked out.
exc_type, exc_val, _ = sys.exc_info()
handler.handle(exc_type, exc_val)
if not pinned:
self.return_socket(sock_info)
raise
else:
if not checkout:
if sock_info.pinned:
self.__pinned_sockets.add(sock_info)
elif not checkout:
self.return_socket(sock_info)
def _raise_if_not_ready(self, emit_event):
@ -1487,6 +1525,8 @@ class Pool:
:Parameters:
- `sock_info`: The socket to check into the pool.
"""
self.__pinned_sockets.discard(sock_info)
sock_info.pinned = False
listeners = self.opts.event_listeners
if self.enabled_for_cmap:
listeners.publish_connection_checked_in(self.address, sock_info.id)
@ -1495,7 +1535,13 @@ class Pool:
else:
if self.closed:
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
elif not sock_info.closed:
elif sock_info.closed:
# CMAP requires the closed event be emitted after the check in.
if self.enabled_for_cmap:
listeners.publish_connection_closed(
self.address, sock_info.id,
ConnectionClosedReason.ERROR)
else:
with self.lock:
# Hold the lock to ensure this section does not race with
# Pool.reset().

View File

@ -68,10 +68,10 @@ class Response(object):
return self._docs
class ExhaustResponse(Response):
__slots__ = ('_socket_info', '_pool', '_more_to_come')
class PinnedResponse(Response):
__slots__ = ('_socket_info', '_more_to_come')
def __init__(self, data, address, socket_info, pool, request_id, duration,
def __init__(self, data, address, socket_info, request_id, duration,
from_command, docs, more_to_come):
"""Represent a response to an exhaust cursor's initial query.
@ -87,13 +87,12 @@ class ExhaustResponse(Response):
- `more_to_come`: Bool indicating whether cursor is ready to be
exhausted.
"""
super(ExhaustResponse, self).__init__(data,
address,
request_id,
duration,
from_command, docs)
super(PinnedResponse, self).__init__(data,
address,
request_id,
duration,
from_command, docs)
self._socket_info = socket_info
self._pool = pool
self._more_to_come = more_to_come
@property
@ -106,11 +105,6 @@ class ExhaustResponse(Response):
"""
return self._socket_info
@property
def pool(self):
"""The Pool from which the SocketInfo came."""
return self._pool
@property
def more_to_come(self):
"""If true, server is ready to send batches on the socket until the

View File

@ -21,7 +21,7 @@ from bson import _decode_all_selective
from pymongo.errors import NotMasterError, OperationFailure
from pymongo.helpers import _check_command_response
from pymongo.message import _convert_exception, _OpMsg
from pymongo.response import Response, ExhaustResponse
from pymongo.response import Response, PinnedResponse
from pymongo.server_type import SERVER_TYPE
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
@ -68,14 +68,8 @@ class Server(object):
"""Check the server's state soon."""
self._monitor.request_check()
def run_operation_with_response(
self,
sock_info,
operation,
set_slave_okay,
listeners,
exhaust,
unpack_res):
def run_operation(self, sock_info, operation, set_slave_okay, listeners,
pin, unpack_res):
"""Run a _Query or _GetMore operation and return a Response object.
This method is used only to run _Query/_GetMore operations from
@ -87,7 +81,7 @@ class Server(object):
- `set_slave_okay`: Pass to operation.get_message.
- `all_credentials`: dict, maps auth source to MongoCredential.
- `listeners`: Instance of _EventListeners or None.
- `exhaust`: If True, then this is an exhaust cursor operation.
- `pin`: If True, then this is a pinned cursor operation.
- `unpack_res`: A callable that decodes the wire protocol response.
"""
duration = None
@ -95,9 +89,9 @@ class Server(object):
if publish:
start = datetime.now()
use_cmd = operation.use_command(sock_info, exhaust)
more_to_come = (operation.exhaust_mgr
and operation.exhaust_mgr.more_to_come)
use_cmd = operation.use_command(sock_info)
more_to_come = (operation.sock_mgr
and operation.sock_mgr.more_to_come)
if more_to_come:
request_id = 0
else:
@ -108,7 +102,8 @@ class Server(object):
if publish:
cmd, dbn = operation.as_command(sock_info)
listeners.publish_command_start(
cmd, dbn, request_id, sock_info.address)
cmd, dbn, request_id, sock_info.address,
service_id=sock_info.service_id)
start = datetime.now()
try:
@ -142,7 +137,8 @@ class Server(object):
failure = _convert_exception(exc)
listeners.publish_command_failure(
duration, failure, operation.name,
request_id, sock_info.address)
request_id, sock_info.address,
service_id=sock_info.service_id)
raise
if publish:
@ -163,7 +159,7 @@ class Server(object):
res["cursor"]["nextBatch"] = docs
listeners.publish_command_success(
duration, res, operation.name, request_id,
sock_info.address)
sock_info.address, service_id=sock_info.service_id)
# Decrypt response.
client = operation.client
@ -174,19 +170,20 @@ class Server(object):
docs = _decode_all_selective(
decrypted, operation.codec_options, user_fields)
if exhaust:
if pin:
if isinstance(reply, _OpMsg):
# In OP_MSG, the server keeps sending only if the
# more_to_come flag is set.
more_to_come = reply.more_to_come
else:
# In OP_REPLY, the server keeps sending until cursor_id is 0.
more_to_come = bool(reply.cursor_id)
response = ExhaustResponse(
more_to_come = bool(operation.exhaust and reply.cursor_id)
if operation.sock_mgr:
operation.sock_mgr.update_exhaust(more_to_come)
response = PinnedResponse(
data=reply,
address=self._description.address,
socket_info=sock_info,
pool=self._pool,
duration=duration,
request_id=request_id,
from_command=use_cmd,
@ -203,8 +200,8 @@ class Server(object):
return response
def get_socket(self, all_credentials, checkout=False):
return self.pool.get_socket(all_credentials, checkout)
def get_socket(self, all_credentials, checkout=False, handler=None):
return self.pool.get_socket(all_credentials, checkout, handler)
@property
def description(self):

View File

@ -19,8 +19,8 @@ import sys
sys.path[0:0] = [""]
from test import unittest
from test import unittest, IntegrationTest, client_context
from test.utils import get_pool
from test.unified_format import generate_test_classes
# Location of JSON test specifications.
@ -30,5 +30,23 @@ TEST_PATH = os.path.join(
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__))
class TestLB(IntegrationTest):
@client_context.require_load_balancer
def test_unpin_committed_transaction(self):
pool = get_pool(self.client)
with self.client.start_session() as session:
with session.start_transaction():
self.assertEqual(pool.active_sockets, 0)
self.db.test.insert_one({}, session=session)
self.assertEqual(pool.active_sockets, 1) # Pinned.
self.assertEqual(pool.active_sockets, 1) # Still pinned.
self.assertEqual(pool.active_sockets, 0) # Unpinned.
def test_client_can_be_reopened(self):
self.client.close()
self.db.test.find_one({})
if __name__ == "__main__":
unittest.main()

View File

@ -376,7 +376,7 @@
]
},
{
"description": "pinned connections are not returned after an network error during getMore",
"description": "pinned connections are returned after an network error during getMore",
"operations": [
{
"name": "failPoint",
@ -440,7 +440,7 @@
"object": "testRunner",
"arguments": {
"client": "client0",
"connections": 1
"connections": 0
}
},
{
@ -659,7 +659,7 @@
]
},
{
"description": "pinned connections are not returned to the pool after a non-network error on getMore",
"description": "pinned connections are returned to the pool after a non-network error on getMore",
"operations": [
{
"name": "failPoint",
@ -715,7 +715,7 @@
"object": "testRunner",
"arguments": {
"client": "client0",
"connections": 1
"connections": 0
}
},
{

View File

@ -366,9 +366,6 @@
{
"connectionCreatedEvent": {}
},
{
"poolClearedEvent": {}
},
{
"connectionClosedEvent": {
"reason": "error"
@ -378,6 +375,9 @@
"connectionCheckOutFailedEvent": {
"reason": "connectionError"
}
},
{
"poolClearedEvent": {}
}
]
}

View File

@ -40,7 +40,7 @@ class MockPool(Pool):
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
@contextlib.contextmanager
def get_socket(self, all_credentials, checkout=False):
def get_socket(self, all_credentials, checkout=False, handler=None):
client = self.client
host_and_port = '%s:%s' % (self.mock_host, self.mock_port)
if host_and_port in client.mock_down_hosts:
@ -51,7 +51,8 @@ class MockPool(Pool):
+ client.mock_members
+ client.mock_mongoses), "bad host: %s" % host_and_port
with Pool.get_socket(self, all_credentials, checkout) as sock_info:
with Pool.get_socket(
self, all_credentials, checkout, handler) as sock_info:
sock_info.mock_host = self.mock_host
sock_info.mock_port = self.mock_port
yield sock_info

View File

@ -1323,11 +1323,11 @@ class TestClient(IntegrationTest):
with self.assertRaises(AutoReconnect):
client = rs_client(connect=False,
serverSelectionTimeoutMS=100)
client._run_operation_with_response(
client._run_operation(
operation=message._GetMore('pymongo_test', 'collection',
101, 1234, client.codec_options,
ReadPreference.PRIMARY,
None, client, None, None),
None, client, None, None, False),
unpack_res=Cursor(
client.pymongo_test.collection)._unpack_response,
address=('not-a-member', 27017))
@ -1708,7 +1708,7 @@ class TestExhaustCursor(IntegrationTest):
cursor.next()
# Cause a network error.
sock_info = cursor._Cursor__exhaust_mgr.sock
sock_info = cursor._Cursor__sock_mgr.sock
sock_info.sock.close()
# A getmore fails.

View File

@ -20,6 +20,7 @@ import warnings
sys.path[0:0] = [""]
from bson.int64 import Int64
from bson.objectid import ObjectId
from bson.son import SON
from pymongo import CursorType, monitoring, InsertOne, UpdateOne, DeleteOne
@ -397,7 +398,8 @@ class TestCommandMonitoring(PyMongoTestCase):
def test_get_more_failure(self):
address = self.client.address
coll = self.client.pymongo_test.test
cursor_doc = {"id": 12345, "firstBatch": [], "ns": coll.full_name}
cursor_id = Int64(12345)
cursor_doc = {"id": cursor_id, "firstBatch": [], "ns": coll.full_name}
cursor = CommandCursor(coll, cursor_doc, address)
try:
next(cursor)
@ -410,7 +412,7 @@ class TestCommandMonitoring(PyMongoTestCase):
self.assertTrue(
isinstance(started, monitoring.CommandStartedEvent))
self.assertEqualCommand(
SON([('getMore', 12345),
SON([('getMore', cursor_id),
('collection', 'test')]),
started.command)
self.assertEqual('getMore', started.command_name)

View File

@ -322,10 +322,9 @@ class ReadPrefTester(MongoClient):
yield sock_info, slave_ok
@contextlib.contextmanager
def _slaveok_for_server(self, read_preference, server, session,
exhaust=False):
def _slaveok_for_server(self, read_preference, server, session, pin=False):
context = super(ReadPrefTester, self)._slaveok_for_server(
read_preference, server, session, exhaust=exhaust)
read_preference, server, session, pin=pin)
with context as (sock_info, slave_ok):
self.record_a_read(sock_info.address)
yield sock_info, slave_ok

View File

@ -519,6 +519,14 @@ class MatchEvaluatorUtil(object):
self.test.assertIsInstance(actual, type(expectation))
self.test.assertEqual(expectation, actual)
def assertHasServiceId(self, spec, actual):
if 'hasServiceId' in spec:
if spec.get('hasServiceId'):
self.test.assertIsNotNone(actual.service_id)
self.test.assertIsInstance(actual.service_id, ObjectId)
else:
self.test.assertIsNone(actual.service_id)
def match_event(self, event_type, expectation, actual):
name, spec = next(iter(expectation.items()))
@ -543,24 +551,23 @@ class MatchEvaluatorUtil(object):
if database_name:
self.test.assertEqual(
database_name, actual.database_name)
self.assertHasServiceId(spec, actual)
elif name == 'commandSucceededEvent':
self.test.assertIsInstance(actual, CommandSucceededEvent)
reply = spec.get('reply')
if reply:
self.match_result(reply, actual.reply)
self.assertHasServiceId(spec, actual)
elif name == 'commandFailedEvent':
self.test.assertIsInstance(actual, CommandFailedEvent)
self.assertHasServiceId(spec, actual)
elif name == 'poolCreatedEvent':
self.test.assertIsInstance(actual, PoolCreatedEvent)
elif name == 'poolReadyEvent':
self.test.assertIsInstance(actual, PoolReadyEvent)
elif name == 'poolClearedEvent':
self.test.assertIsInstance(actual, PoolClearedEvent)
if spec.get('hasServiceId'):
self.test.assertIsNotNone(actual.service_id)
self.test.assertIsInstance(actual.service_id, ObjectId)
else:
self.test.assertIsNone(actual.service_id)
self.assertHasServiceId(spec, actual)
elif name == 'poolClosedEvent':
self.test.assertIsInstance(actual, PoolClosedEvent)
elif name == 'connectionCreatedEvent':
@ -569,12 +576,14 @@ class MatchEvaluatorUtil(object):
self.test.assertIsInstance(actual, ConnectionReadyEvent)
elif name == 'connectionClosedEvent':
self.test.assertIsInstance(actual, ConnectionClosedEvent)
self.test.assertEqual(actual.reason, spec['reason'])
if 'reason' in spec:
self.test.assertEqual(actual.reason, spec['reason'])
elif name == 'connectionCheckOutStartedEvent':
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
elif name == 'connectionCheckOutFailedEvent':
self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
self.test.assertEqual(actual.reason, spec['reason'])
if 'reason' in spec:
self.test.assertEqual(actual.reason, spec['reason'])
elif name == 'connectionCheckedOutEvent':
self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
elif name == 'connectionCheckedInEvent':

View File

@ -267,7 +267,7 @@ class MockPool(object):
def stale_generation(self, gen, service_id):
return self.gen.stale(gen, service_id)
def get_socket(self, all_credentials, checkout=False):
def get_socket(self, all_credentials, checkout=False, handler=None):
return MockSocketInfo()
def return_socket(self, *args, **kwargs):