PYTHON-2673 Connection pinning behavior for load balanced clusters (#630)
Tweak spec test because pymongo unpins cursors eagerly after errors. Tweak spec test for PoolClearedEvent ordering when MongoDB handshake fails (see DRIVERS-1785). Only skip killCursors for some error codes. Rely on SDAM error handling to close the connection after a state change error. Add service_id to various events. Retain reference to pinned sockets to prevent premptive closure by CPython's cyclic GC.
This commit is contained in:
parent
7a48831124
commit
c8f32a7a37
@ -161,11 +161,13 @@ class _AggregationCommand(object):
|
||||
}
|
||||
|
||||
# Create and return cursor instance.
|
||||
return self._cursor_class(
|
||||
cmd_cursor = self._cursor_class(
|
||||
self._cursor_collection(cursor), cursor, sock_info.address,
|
||||
batch_size=self._batch_size or 0,
|
||||
max_await_time_ms=self._max_await_time_ms,
|
||||
session=session, explicit_session=self._explicit_session)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
return cmd_cursor
|
||||
|
||||
|
||||
class _CollectionAggregationCommand(_AggregationCommand):
|
||||
|
||||
@ -181,7 +181,7 @@ class ChangeStream(object):
|
||||
|
||||
return self._client._retryable_read(
|
||||
cmd.get_cursor, self._target._read_preference_for(session),
|
||||
session)
|
||||
session, pin=self._client._should_pin_cursor(session))
|
||||
|
||||
def _create_cursor(self):
|
||||
with self._client._tmp_session(self._session, close=False) as s:
|
||||
|
||||
@ -108,6 +108,7 @@ from bson.int64 import Int64
|
||||
from bson.son import SON
|
||||
from bson.timestamp import Timestamp
|
||||
|
||||
from pymongo.cursor import _SocketManager
|
||||
from pymongo.errors import (ConfigurationError,
|
||||
ConnectionFailure,
|
||||
InvalidOperation,
|
||||
@ -117,6 +118,7 @@ from pymongo.errors import (ConfigurationError,
|
||||
from pymongo.helpers import _RETRYABLE_ERROR_CODES
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference, _ServerMode
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
@ -292,6 +294,7 @@ class _Transaction(object):
|
||||
self.state = _TxnState.NONE
|
||||
self.sharded = False
|
||||
self.pinned_address = None
|
||||
self.sock_mgr = None
|
||||
self.recovery_token = None
|
||||
self.attempt = 0
|
||||
|
||||
@ -301,10 +304,29 @@ class _Transaction(object):
|
||||
def starting(self):
|
||||
return self.state == _TxnState.STARTING
|
||||
|
||||
@property
|
||||
def pinned_conn(self):
|
||||
if self.active() and self.sock_mgr:
|
||||
return self.sock_mgr.sock
|
||||
return None
|
||||
|
||||
def pin(self, server, sock_info):
|
||||
self.sharded = True
|
||||
self.pinned_address = server.description.address
|
||||
if server.description.server_type == SERVER_TYPE.LoadBalancer:
|
||||
sock_info.pinned = True
|
||||
self.sock_mgr = _SocketManager(sock_info, False)
|
||||
|
||||
def unpin(self):
|
||||
self.pinned_address = None
|
||||
if self.sock_mgr:
|
||||
self.sock_mgr.close()
|
||||
self.sock_mgr = None
|
||||
|
||||
def reset(self):
|
||||
self.unpin()
|
||||
self.state = _TxnState.NONE
|
||||
self.sharded = False
|
||||
self.pinned_address = None
|
||||
self.recovery_token = None
|
||||
self.attempt = 0
|
||||
|
||||
@ -374,6 +396,9 @@ class ClientSession(object):
|
||||
try:
|
||||
if self.in_transaction:
|
||||
self.abort_transaction()
|
||||
# It's possible we're still pinned here when the transaction
|
||||
# is in the committed state when the session is discarded.
|
||||
self._unpin()
|
||||
finally:
|
||||
self._client._return_server_session(self._server_session, lock)
|
||||
self._server_session = None
|
||||
@ -779,14 +804,18 @@ class ClientSession(object):
|
||||
return self._transaction.pinned_address
|
||||
return None
|
||||
|
||||
def _pin(self, server):
|
||||
"""Pin this session to the given Server."""
|
||||
self._transaction.sharded = True
|
||||
self._transaction.pinned_address = server.description.address
|
||||
@property
|
||||
def _pinned_connection(self):
|
||||
"""The connection this transaction was started on."""
|
||||
return self._transaction.pinned_conn
|
||||
|
||||
def _pin(self, server, sock_info):
|
||||
"""Pin this session to the given Server or to the given connection."""
|
||||
self._transaction.pin(server, sock_info)
|
||||
|
||||
def _unpin(self):
|
||||
"""Unpin this session from any pinned Server."""
|
||||
self._transaction.pinned_address = None
|
||||
self._transaction.unpin()
|
||||
|
||||
def _txn_read_preference(self):
|
||||
"""Return read preference of this transaction or None."""
|
||||
@ -800,9 +829,6 @@ class ClientSession(object):
|
||||
self._server_session.last_use = time.monotonic()
|
||||
command['lsid'] = self._server_session.session_id
|
||||
|
||||
if not self.in_transaction:
|
||||
self._transaction.reset()
|
||||
|
||||
if is_retryable:
|
||||
command['txnNumber'] = self._server_session.transaction_id
|
||||
return
|
||||
|
||||
@ -520,7 +520,8 @@ class Collection(common.BaseObject):
|
||||
if publish:
|
||||
duration = datetime.datetime.now() - start
|
||||
listeners.publish_command_start(
|
||||
cmd, self.__database.name, rqst_id, sock_info.address, op_id)
|
||||
cmd, self.__database.name, rqst_id, sock_info.address, op_id,
|
||||
sock_info.service_id)
|
||||
start = datetime.datetime.now()
|
||||
try:
|
||||
result = sock_info.legacy_write(rqst_id, msg, max_size, False)
|
||||
@ -534,12 +535,14 @@ class Collection(common.BaseObject):
|
||||
reply = message._convert_write_result(
|
||||
name, cmd, details)
|
||||
listeners.publish_command_success(
|
||||
dur, reply, name, rqst_id, sock_info.address, op_id)
|
||||
dur, reply, name, rqst_id, sock_info.address,
|
||||
op_id, sock_info.service_id)
|
||||
raise
|
||||
else:
|
||||
details = message._convert_exception(exc)
|
||||
listeners.publish_command_failure(
|
||||
dur, details, name, rqst_id, sock_info.address, op_id)
|
||||
dur, details, name, rqst_id, sock_info.address, op_id,
|
||||
sock_info.service_id)
|
||||
raise
|
||||
if publish:
|
||||
if result is not None:
|
||||
@ -549,7 +552,8 @@ class Collection(common.BaseObject):
|
||||
reply = {'ok': 1}
|
||||
duration = (datetime.datetime.now() - start) + duration
|
||||
listeners.publish_command_success(
|
||||
duration, reply, name, rqst_id, sock_info.address, op_id)
|
||||
duration, reply, name, rqst_id, sock_info.address, op_id,
|
||||
sock_info.service_id)
|
||||
return result
|
||||
|
||||
def _insert_one(
|
||||
@ -2072,9 +2076,9 @@ class Collection(common.BaseObject):
|
||||
if exc.code != 26:
|
||||
raise
|
||||
cursor = {'id': 0, 'firstBatch': []}
|
||||
return CommandCursor(coll, cursor, sock_info.address,
|
||||
session=s,
|
||||
explicit_session=session is not None)
|
||||
cmd_cursor = CommandCursor(
|
||||
coll, cursor, sock_info.address, session=s,
|
||||
explicit_session=session is not None)
|
||||
else:
|
||||
res = message._first_batch(
|
||||
sock_info, self.__database.name, "system.indexes",
|
||||
@ -2084,10 +2088,13 @@ class Collection(common.BaseObject):
|
||||
cursor = res["cursor"]
|
||||
# Note that a collection can only have 64 indexes, so there
|
||||
# will never be a getMore call.
|
||||
return CommandCursor(coll, cursor, sock_info.address)
|
||||
cmd_cursor = CommandCursor(coll, cursor, sock_info.address)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
return cmd_cursor
|
||||
|
||||
return self.__database.client._retryable_read(
|
||||
_cmd, read_pref, session)
|
||||
_cmd, read_pref, session,
|
||||
pin=self.__database.client._should_pin_cursor(session))
|
||||
|
||||
def index_information(self, session=None):
|
||||
"""Get information on this collection's indexes.
|
||||
@ -2168,7 +2175,8 @@ class Collection(common.BaseObject):
|
||||
user_fields={'cursor': {'firstBatch': 1}})
|
||||
return self.__database.client._retryable_read(
|
||||
cmd.get_cursor, cmd.get_read_preference(session), session,
|
||||
retryable=not cmd._performs_write)
|
||||
retryable=not cmd._performs_write,
|
||||
pin=self.database.client._should_pin_cursor(session))
|
||||
|
||||
def aggregate(self, pipeline, session=None, **kwargs):
|
||||
"""Perform an aggregation using the aggregation framework on this
|
||||
|
||||
@ -17,13 +17,14 @@
|
||||
from collections import deque
|
||||
|
||||
from bson import _convert_raw_document_lists_to_streams
|
||||
from pymongo.cursor import _SocketManager, _CURSOR_CLOSED_ERRORS
|
||||
from pymongo.errors import (ConnectionFailure,
|
||||
InvalidOperation,
|
||||
NotMasterError,
|
||||
OperationFailure)
|
||||
from pymongo.message import (_CursorAddress,
|
||||
_GetMore,
|
||||
_RawBatchGetMore)
|
||||
from pymongo.response import PinnedResponse
|
||||
|
||||
|
||||
class CommandCursor(object):
|
||||
@ -37,6 +38,7 @@ class CommandCursor(object):
|
||||
|
||||
The parameter 'retrieved' is unused.
|
||||
"""
|
||||
self.__sock_mgr = None
|
||||
self.__collection = collection
|
||||
self.__id = cursor_info['id']
|
||||
self.__data = deque(cursor_info['firstBatch'])
|
||||
@ -75,11 +77,15 @@ class CommandCursor(object):
|
||||
self.__address, self.__collection.full_name)
|
||||
if synchronous:
|
||||
self.__collection.database.client._close_cursor_now(
|
||||
self.__id, address, session=self.__session)
|
||||
self.__id, address, session=self.__session,
|
||||
sock_mgr=self.__sock_mgr)
|
||||
else:
|
||||
# The cursor will be closed later in a different session.
|
||||
self.__collection.database.client._close_cursor(
|
||||
self.__id, address)
|
||||
if self.__sock_mgr:
|
||||
self.__sock_mgr.close()
|
||||
self.__sock_mgr = None
|
||||
self.__end_session(synchronous)
|
||||
|
||||
def __end_session(self, synchronous):
|
||||
@ -127,52 +133,58 @@ class CommandCursor(object):
|
||||
changeStream aggregate or getMore."""
|
||||
return self.__postbatchresumetoken
|
||||
|
||||
def _maybe_pin_connection(self, sock_info):
|
||||
client = self.__collection.database.client
|
||||
if not client._should_pin_cursor(self.__session):
|
||||
return
|
||||
if not self.__sock_mgr:
|
||||
sock_mgr = _SocketManager(sock_info, False)
|
||||
# Ensure the connection gets returned when the entire result is
|
||||
# returned in the first batch.
|
||||
if self.__id == 0:
|
||||
sock_mgr.close()
|
||||
else:
|
||||
self.__sock_mgr = sock_mgr
|
||||
|
||||
def __send_message(self, operation):
|
||||
"""Send a getmore message and handle the response.
|
||||
"""
|
||||
def kill():
|
||||
self.__killed = True
|
||||
self.__end_session(True)
|
||||
|
||||
client = self.__collection.database.client
|
||||
try:
|
||||
response = client._run_operation_with_response(
|
||||
response = client._run_operation(
|
||||
operation, self._unpack_response, address=self.__address)
|
||||
except OperationFailure:
|
||||
kill()
|
||||
raise
|
||||
except NotMasterError:
|
||||
# Don't send kill cursors to another server after a "not master"
|
||||
# error. It's completely pointless.
|
||||
kill()
|
||||
except OperationFailure as exc:
|
||||
if exc.code in _CURSOR_CLOSED_ERRORS:
|
||||
# Don't send killCursors because the cursor is already closed.
|
||||
self.__killed = True
|
||||
# Return the session and pinned connection, if necessary.
|
||||
self.close()
|
||||
raise
|
||||
except ConnectionFailure:
|
||||
# Don't try to send kill cursors on another socket
|
||||
# or to another server. It can cause a _pinValue
|
||||
# assertion on some server releases if we get here
|
||||
# due to a socket timeout.
|
||||
kill()
|
||||
# Don't send killCursors because the cursor is already closed.
|
||||
self.__killed = True
|
||||
# Return the session and pinned connection, if necessary.
|
||||
self.close()
|
||||
raise
|
||||
except Exception:
|
||||
# Close the cursor
|
||||
self.__die()
|
||||
self.close()
|
||||
raise
|
||||
|
||||
from_command = response.from_command
|
||||
reply = response.data
|
||||
docs = response.docs
|
||||
|
||||
if from_command:
|
||||
cursor = docs[0]['cursor']
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self.__sock_mgr:
|
||||
self.__sock_mgr = _SocketManager(response.socket_info,
|
||||
response.more_to_come)
|
||||
if response.from_command:
|
||||
cursor = response.docs[0]['cursor']
|
||||
documents = cursor['nextBatch']
|
||||
self.__postbatchresumetoken = cursor.get('postBatchResumeToken')
|
||||
self.__id = cursor['id']
|
||||
else:
|
||||
documents = docs
|
||||
self.__id = reply.cursor_id
|
||||
documents = response.docs
|
||||
self.__id = response.data.cursor_id
|
||||
|
||||
if self.__id == 0:
|
||||
kill()
|
||||
self.__die(True)
|
||||
self.__data = deque(documents)
|
||||
|
||||
def _unpack_response(self, response, cursor_id, codec_options,
|
||||
@ -203,10 +215,9 @@ class CommandCursor(object):
|
||||
self.__session,
|
||||
self.__collection.database.client,
|
||||
self.__max_await_time_ms,
|
||||
False))
|
||||
self.__sock_mgr, False))
|
||||
else: # Cursor id is zero nothing else to return
|
||||
self.__killed = True
|
||||
self.__end_session(True)
|
||||
self.__die(True)
|
||||
|
||||
return len(self.__data)
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Cursor class to iterate over Mongo query results."""
|
||||
|
||||
import copy
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from collections import deque
|
||||
@ -27,7 +28,6 @@ from pymongo.common import validate_boolean, validate_is_mapping
|
||||
from pymongo.collation import validate_collation_or_none
|
||||
from pymongo.errors import (ConnectionFailure,
|
||||
InvalidOperation,
|
||||
NotMasterError,
|
||||
OperationFailure)
|
||||
from pymongo.message import (_CursorAddress,
|
||||
_GetMore,
|
||||
@ -35,7 +35,37 @@ from pymongo.message import (_CursorAddress,
|
||||
_Query,
|
||||
_RawBatchQuery)
|
||||
from pymongo.monitoring import ConnectionClosedReason
|
||||
from pymongo.response import PinnedResponse
|
||||
|
||||
# These errors mean that the server has already killed the cursor so there is
|
||||
# no need to send killCursors.
|
||||
_CURSOR_CLOSED_ERRORS = frozenset([
|
||||
43, # CursorNotFound
|
||||
50, # MaxTimeMSExpired
|
||||
175, # QueryPlanKilled
|
||||
237, # CursorKilled
|
||||
|
||||
# On a tailable cursor, the following errors mean the capped collection
|
||||
# rolled over.
|
||||
# MongoDB 2.6:
|
||||
# {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0}
|
||||
28617,
|
||||
# MongoDB 3.0:
|
||||
# {'$err': 'getMore executor error: UnknownError no details available',
|
||||
# 'code': 17406, 'ok': 0}
|
||||
17406,
|
||||
# MongoDB 3.2 + 3.4:
|
||||
# {'ok': 0.0, 'errmsg': 'GetMore command executor error:
|
||||
# CappedPositionLost: CollectionScan died due to failure to restore
|
||||
# tailable cursor position. Last seen record id: RecordId(3)',
|
||||
# 'code': 96}
|
||||
96,
|
||||
# MongoDB 3.6+:
|
||||
# {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to
|
||||
# restore tailable cursor position. Last seen record id: RecordId(3)"',
|
||||
# 'code': 136, 'codeName': 'CappedPositionLost'}
|
||||
136,
|
||||
])
|
||||
|
||||
_QUERY_OPTIONS = {
|
||||
"tailable_cursor": 2,
|
||||
@ -78,14 +108,14 @@ class CursorType(object):
|
||||
|
||||
# This has to be an old style class due to
|
||||
# http://bugs.jython.org/issue1057
|
||||
class _ExhaustManager:
|
||||
class _SocketManager:
|
||||
"""Used with exhaust cursors to ensure the socket is returned.
|
||||
"""
|
||||
def __init__(self, sock, pool, more_to_come):
|
||||
def __init__(self, sock, more_to_come):
|
||||
self.sock = sock
|
||||
self.pool = pool
|
||||
self.more_to_come = more_to_come
|
||||
self.__closed = False
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
@ -98,8 +128,8 @@ class _ExhaustManager:
|
||||
"""
|
||||
if not self.__closed:
|
||||
self.__closed = True
|
||||
self.pool.return_socket(self.sock)
|
||||
self.sock, self.pool = None, None
|
||||
self.sock.unpin()
|
||||
self.sock = None
|
||||
|
||||
|
||||
class Cursor(object):
|
||||
@ -128,7 +158,7 @@ class Cursor(object):
|
||||
# an error to avoid attribute errors during garbage collection.
|
||||
self.__id = None
|
||||
self.__exhaust = False
|
||||
self.__exhaust_mgr = None
|
||||
self.__sock_mgr = None
|
||||
self.__killed = False
|
||||
|
||||
if session:
|
||||
@ -319,24 +349,26 @@ class Cursor(object):
|
||||
|
||||
self.__killed = True
|
||||
if self.__id and not already_killed:
|
||||
if self.__exhaust and self.__exhaust_mgr:
|
||||
if self.__exhaust and self.__sock_mgr:
|
||||
# If this is an exhaust cursor and we haven't completely
|
||||
# exhausted the result set we *must* close the socket
|
||||
# to stop the server from sending more data.
|
||||
self.__exhaust_mgr.sock.close_socket(
|
||||
self.__sock_mgr.sock.close_socket(
|
||||
ConnectionClosedReason.ERROR)
|
||||
else:
|
||||
address = _CursorAddress(
|
||||
self.__address, self.__collection.full_name)
|
||||
if synchronous:
|
||||
self.__collection.database.client._close_cursor_now(
|
||||
self.__id, address, session=self.__session)
|
||||
self.__id, address, session=self.__session,
|
||||
sock_mgr=self.__sock_mgr)
|
||||
else:
|
||||
# The cursor will be closed later in a different session.
|
||||
self.__collection.database.client._close_cursor(
|
||||
self.__id, address)
|
||||
if self.__exhaust and self.__exhaust_mgr:
|
||||
self.__exhaust_mgr.close()
|
||||
if self.__sock_mgr:
|
||||
self.__sock_mgr.close()
|
||||
self.__sock_mgr = None
|
||||
if self.__session and not self.__explicit_session:
|
||||
self.__session._end_session(lock=synchronous)
|
||||
self.__session = None
|
||||
@ -1004,53 +1036,35 @@ class Cursor(object):
|
||||
"exhaust cursors do not support auto encryption")
|
||||
|
||||
try:
|
||||
response = client._run_operation_with_response(
|
||||
operation, self._unpack_response, exhaust=self.__exhaust,
|
||||
address=self.__address)
|
||||
except OperationFailure:
|
||||
self.__killed = True
|
||||
|
||||
# Make sure exhaust socket is returned immediately, if necessary.
|
||||
self.__die()
|
||||
|
||||
response = client._run_operation(
|
||||
operation, self._unpack_response, address=self.__address)
|
||||
except OperationFailure as exc:
|
||||
if exc.code in _CURSOR_CLOSED_ERRORS or self.__exhaust:
|
||||
# Don't send killCursors because the cursor is already closed.
|
||||
self.__killed = True
|
||||
self.close()
|
||||
# If this is a tailable cursor the error is likely
|
||||
# due to capped collection roll over. Setting
|
||||
# self.__killed to True ensures Cursor.alive will be
|
||||
# False. No need to re-raise.
|
||||
if self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]:
|
||||
if (exc.code in _CURSOR_CLOSED_ERRORS and
|
||||
self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]):
|
||||
return
|
||||
raise
|
||||
except NotMasterError:
|
||||
# Don't send kill cursors to another server after a "not master"
|
||||
# error. It's completely pointless.
|
||||
self.__killed = True
|
||||
|
||||
# Make sure exhaust socket is returned immediately, if necessary.
|
||||
self.__die()
|
||||
|
||||
raise
|
||||
except ConnectionFailure:
|
||||
# Don't try to send kill cursors on another socket
|
||||
# or to another server. It can cause a _pinValue
|
||||
# assertion on some server releases if we get here
|
||||
# due to a socket timeout.
|
||||
# Don't send killCursors because the cursor is already closed.
|
||||
self.__killed = True
|
||||
self.__die()
|
||||
self.close()
|
||||
raise
|
||||
except Exception:
|
||||
# Close the cursor
|
||||
self.__die()
|
||||
self.close()
|
||||
raise
|
||||
|
||||
self.__address = response.address
|
||||
if self.__exhaust:
|
||||
# 'response' is an ExhaustResponse.
|
||||
if not self.__exhaust_mgr:
|
||||
self.__exhaust_mgr = _ExhaustManager(response.socket_info,
|
||||
response.pool,
|
||||
response.more_to_come)
|
||||
else:
|
||||
self.__exhaust_mgr.update_exhaust(response.more_to_come)
|
||||
if isinstance(response, PinnedResponse):
|
||||
if not self.__sock_mgr:
|
||||
self.__sock_mgr = _SocketManager(response.socket_info,
|
||||
response.more_to_come)
|
||||
|
||||
cmd_name = operation.name
|
||||
docs = response.docs
|
||||
@ -1078,7 +1092,6 @@ class Cursor(object):
|
||||
self.__retrieved += response.data.number_returned
|
||||
|
||||
if self.__id == 0:
|
||||
self.__killed = True
|
||||
# Don't wait for garbage collection to call __del__, return the
|
||||
# socket and the session to the pool now.
|
||||
self.__die()
|
||||
@ -1132,7 +1145,8 @@ class Cursor(object):
|
||||
self.__collation,
|
||||
self.__session,
|
||||
self.__collection.database.client,
|
||||
self.__allow_disk_use)
|
||||
self.__allow_disk_use,
|
||||
self.__exhaust)
|
||||
self.__send_message(q)
|
||||
elif self.__id: # Get More
|
||||
if self.__limit:
|
||||
@ -1151,7 +1165,8 @@ class Cursor(object):
|
||||
self.__session,
|
||||
self.__collection.database.client,
|
||||
self.__max_await_time_ms,
|
||||
self.__exhaust_mgr)
|
||||
self.__sock_mgr,
|
||||
self.__exhaust)
|
||||
self.__send_message(g)
|
||||
|
||||
return len(self.__data)
|
||||
|
||||
@ -377,7 +377,8 @@ class Database(common.BaseObject):
|
||||
user_fields={'cursor': {'firstBatch': 1}})
|
||||
return self.client._retryable_read(
|
||||
cmd.get_cursor, cmd.get_read_preference(s), s,
|
||||
retryable=not cmd._performs_write)
|
||||
retryable=not cmd._performs_write,
|
||||
pin=self.client._should_pin_cursor(s))
|
||||
|
||||
def watch(self, pipeline=None, full_document=None, resume_after=None,
|
||||
max_await_time_ms=None, batch_size=None, collation=None,
|
||||
@ -636,7 +637,7 @@ class Database(common.BaseObject):
|
||||
sock_info, cmd, slave_okay,
|
||||
read_preference=read_preference,
|
||||
session=tmp_session)["cursor"]
|
||||
return CommandCursor(
|
||||
cmd_cursor = CommandCursor(
|
||||
coll,
|
||||
cursor,
|
||||
sock_info.address,
|
||||
@ -656,7 +657,9 @@ class Database(common.BaseObject):
|
||||
("pipeline", pipeline),
|
||||
("cursor", kwargs.get("cursor", {}))])
|
||||
cursor = self._command(sock_info, cmd, slave_okay)["cursor"]
|
||||
return CommandCursor(coll, cursor, sock_info.address)
|
||||
cmd_cursor = CommandCursor(coll, cursor, sock_info.address)
|
||||
cmd_cursor._maybe_pin_connection(sock_info)
|
||||
return cmd_cursor
|
||||
|
||||
def list_collections(self, session=None, filter=None, **kwargs):
|
||||
"""Get a cursor over the collectons of this database.
|
||||
@ -688,7 +691,8 @@ class Database(common.BaseObject):
|
||||
**kwargs)
|
||||
|
||||
return self.__client._retryable_read(
|
||||
_cmd, read_pref, session)
|
||||
_cmd, read_pref, session,
|
||||
pin=self.client._should_pin_cursor(session))
|
||||
|
||||
def list_collection_names(self, session=None, filter=None, **kwargs):
|
||||
"""Get a list of all the collection names in this database.
|
||||
|
||||
@ -234,16 +234,17 @@ class _Query(object):
|
||||
__slots__ = ('flags', 'db', 'coll', 'ntoskip', 'spec',
|
||||
'fields', 'codec_options', 'read_preference', 'limit',
|
||||
'batch_size', 'name', 'read_concern', 'collation',
|
||||
'session', 'client', 'allow_disk_use', '_as_command')
|
||||
'session', 'client', 'allow_disk_use', '_as_command',
|
||||
'exhaust')
|
||||
|
||||
# For compatibility with the _GetMore class.
|
||||
exhaust_mgr = None
|
||||
sock_mgr = None
|
||||
cursor_id = None
|
||||
|
||||
def __init__(self, flags, db, coll, ntoskip, spec, fields,
|
||||
codec_options, read_preference, limit,
|
||||
batch_size, read_concern, collation, session, client,
|
||||
allow_disk_use):
|
||||
allow_disk_use, exhaust):
|
||||
self.flags = flags
|
||||
self.db = db
|
||||
self.coll = coll
|
||||
@ -261,13 +262,14 @@ class _Query(object):
|
||||
self.allow_disk_use = allow_disk_use
|
||||
self.name = 'find'
|
||||
self._as_command = None
|
||||
self.exhaust = exhaust
|
||||
|
||||
def namespace(self):
|
||||
return "%s.%s" % (self.db, self.coll)
|
||||
|
||||
def use_command(self, sock_info, exhaust):
|
||||
def use_command(self, sock_info):
|
||||
use_find_cmd = False
|
||||
if sock_info.max_wire_version >= 4 and not exhaust:
|
||||
if sock_info.max_wire_version >= 4 and not self.exhaust:
|
||||
use_find_cmd = True
|
||||
elif sock_info.max_wire_version >= 8:
|
||||
# OP_MSG supports exhaust on MongoDB 4.2+
|
||||
@ -375,13 +377,13 @@ class _GetMore(object):
|
||||
|
||||
__slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms',
|
||||
'codec_options', 'read_preference', 'session', 'client',
|
||||
'exhaust_mgr', '_as_command')
|
||||
'sock_mgr', '_as_command', 'exhaust')
|
||||
|
||||
name = 'getMore'
|
||||
|
||||
def __init__(self, db, coll, ntoreturn, cursor_id, codec_options,
|
||||
read_preference, session, client, max_await_time_ms,
|
||||
exhaust_mgr):
|
||||
sock_mgr, exhaust):
|
||||
self.db = db
|
||||
self.coll = coll
|
||||
self.ntoreturn = ntoreturn
|
||||
@ -391,15 +393,16 @@ class _GetMore(object):
|
||||
self.session = session
|
||||
self.client = client
|
||||
self.max_await_time_ms = max_await_time_ms
|
||||
self.exhaust_mgr = exhaust_mgr
|
||||
self.sock_mgr = sock_mgr
|
||||
self._as_command = None
|
||||
self.exhaust = exhaust
|
||||
|
||||
def namespace(self):
|
||||
return "%s.%s" % (self.db, self.coll)
|
||||
|
||||
def use_command(self, sock_info, exhaust):
|
||||
def use_command(self, sock_info):
|
||||
use_cmd = False
|
||||
if sock_info.max_wire_version >= 4 and not exhaust:
|
||||
if sock_info.max_wire_version >= 4 and not self.exhaust:
|
||||
use_cmd = True
|
||||
elif sock_info.max_wire_version >= 8:
|
||||
# OP_MSG supports exhaust on MongoDB 4.2+
|
||||
@ -440,7 +443,7 @@ class _GetMore(object):
|
||||
if use_cmd:
|
||||
spec = self.as_command(sock_info)[0]
|
||||
if sock_info.op_msg_enabled:
|
||||
if self.exhaust_mgr:
|
||||
if self.sock_mgr:
|
||||
flags = _OpMsg.EXHAUST_ALLOWED
|
||||
else:
|
||||
flags = 0
|
||||
@ -456,23 +459,23 @@ class _GetMore(object):
|
||||
|
||||
|
||||
class _RawBatchQuery(_Query):
|
||||
def use_command(self, socket_info, exhaust):
|
||||
def use_command(self, sock_info):
|
||||
# Compatibility checks.
|
||||
super(_RawBatchQuery, self).use_command(socket_info, exhaust)
|
||||
if socket_info.max_wire_version >= 8:
|
||||
super(_RawBatchQuery, self).use_command(sock_info)
|
||||
if sock_info.max_wire_version >= 8:
|
||||
# MongoDB 4.2+ supports exhaust over OP_MSG
|
||||
return True
|
||||
elif socket_info.op_msg_enabled and not exhaust:
|
||||
elif sock_info.op_msg_enabled and not self.exhaust:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class _RawBatchGetMore(_GetMore):
|
||||
def use_command(self, socket_info, exhaust):
|
||||
if socket_info.max_wire_version >= 8:
|
||||
def use_command(self, sock_info):
|
||||
if sock_info.max_wire_version >= 8:
|
||||
# MongoDB 4.2+ supports exhaust over OP_MSG
|
||||
return True
|
||||
elif socket_info.op_msg_enabled and not exhaust:
|
||||
elif sock_info.op_msg_enabled and not self.exhaust:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -1033,20 +1036,23 @@ class _BulkWriteContext(object):
|
||||
cmd[self.field] = docs
|
||||
self.listeners.publish_command_start(
|
||||
cmd, self.db_name,
|
||||
request_id, self.sock_info.address, self.op_id)
|
||||
request_id, self.sock_info.address, self.op_id,
|
||||
self.sock_info.service_id)
|
||||
return cmd
|
||||
|
||||
def _succeed(self, request_id, reply, duration):
|
||||
"""Publish a CommandSucceededEvent."""
|
||||
self.listeners.publish_command_success(
|
||||
duration, reply, self.name,
|
||||
request_id, self.sock_info.address, self.op_id)
|
||||
request_id, self.sock_info.address, self.op_id,
|
||||
self.sock_info.service_id)
|
||||
|
||||
def _fail(self, request_id, failure, duration):
|
||||
"""Publish a CommandFailedEvent."""
|
||||
self.listeners.publish_command_failure(
|
||||
duration, failure, self.name,
|
||||
request_id, self.sock_info.address, self.op_id)
|
||||
request_id, self.sock_info.address, self.op_id,
|
||||
self.sock_info.service_id)
|
||||
|
||||
|
||||
# From the Client Side Encryption spec:
|
||||
@ -1686,7 +1692,7 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
|
||||
query = _Query(
|
||||
0, db, coll, 0, query, None, codec_options,
|
||||
read_preference, ntoreturn, 0, DEFAULT_READ_CONCERN, None, None,
|
||||
None, None)
|
||||
None, None, False)
|
||||
|
||||
name = next(iter(cmd))
|
||||
publish = listeners.enabled_for_commands
|
||||
@ -1698,7 +1704,8 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
|
||||
if publish:
|
||||
encoding_duration = datetime.datetime.now() - start
|
||||
listeners.publish_command_start(
|
||||
cmd, db, request_id, sock_info.address)
|
||||
cmd, db, request_id, sock_info.address,
|
||||
service_id=sock_info.service_id)
|
||||
start = datetime.datetime.now()
|
||||
|
||||
sock_info.send_message(msg, max_doc_size)
|
||||
@ -1713,7 +1720,8 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
|
||||
else:
|
||||
failure = _convert_exception(exc)
|
||||
listeners.publish_command_failure(
|
||||
duration, failure, name, request_id, sock_info.address)
|
||||
duration, failure, name, request_id, sock_info.address,
|
||||
service_id=sock_info.service_id)
|
||||
raise
|
||||
# listIndexes
|
||||
if 'cursor' in cmd:
|
||||
@ -1732,6 +1740,7 @@ def _first_batch(sock_info, db, coll, query, ntoreturn,
|
||||
if publish:
|
||||
duration = (datetime.datetime.now() - start) + encoding_duration
|
||||
listeners.publish_command_success(
|
||||
duration, result, name, request_id, sock_info.address)
|
||||
duration, result, name, request_id, sock_info.address,
|
||||
service_id=sock_info.service_id)
|
||||
|
||||
return result
|
||||
|
||||
@ -60,6 +60,7 @@ from pymongo.errors import (AutoReconnect,
|
||||
OperationFailure,
|
||||
PyMongoError,
|
||||
ServerSelectionTimeoutError)
|
||||
from pymongo.pool import ConnectionClosedReason
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_selectors import (writable_preferred_server_selector,
|
||||
writable_server_selector)
|
||||
@ -1160,10 +1161,20 @@ class MongoClient(common.BaseObject):
|
||||
return self._topology
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _get_socket(self, server, session, exhaust=False):
|
||||
def _get_socket(self, server, session, pin=False):
|
||||
in_txn = session and session.in_transaction
|
||||
with _MongoClientErrorHandler(self, server, session) as err_handler:
|
||||
# Reuse the pinned connection, if it exists.
|
||||
if in_txn and session._pinned_connection:
|
||||
yield session._pinned_connection
|
||||
return
|
||||
with server.get_socket(
|
||||
self.__all_credentials, checkout=exhaust) as sock_info:
|
||||
self.__all_credentials, checkout=pin,
|
||||
handler=err_handler) as sock_info:
|
||||
# Pin this session to the selected server or connection.
|
||||
if (in_txn and server.description.server_type in (
|
||||
SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)):
|
||||
session._pin(server, sock_info)
|
||||
err_handler.contribute_socket(sock_info)
|
||||
if (self._encrypter and
|
||||
not self._encrypter._bypass_auto_encryption and
|
||||
@ -1186,6 +1197,8 @@ class MongoClient(common.BaseObject):
|
||||
"""
|
||||
try:
|
||||
topology = self._get_topology()
|
||||
if session and not session.in_transaction:
|
||||
session._transaction.reset()
|
||||
address = address or (session and session._pinned_address)
|
||||
if address:
|
||||
# We're running a getMore or this session is pinned to a mongos.
|
||||
@ -1195,12 +1208,6 @@ class MongoClient(common.BaseObject):
|
||||
% address)
|
||||
else:
|
||||
server = topology.select_server(server_selector)
|
||||
# Pin this session to the selected server if it's performing a
|
||||
# sharded transaction.
|
||||
if (server.description.server_type in (
|
||||
SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)
|
||||
and session and session.in_transaction):
|
||||
session._pin(server)
|
||||
return server
|
||||
except PyMongoError as exc:
|
||||
# Server selection errors in a transaction are transient.
|
||||
@ -1214,8 +1221,7 @@ class MongoClient(common.BaseObject):
|
||||
return self._get_socket(server, session)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _slaveok_for_server(self, read_preference, server, session,
|
||||
exhaust=False):
|
||||
def _slaveok_for_server(self, read_preference, server, session, pin=False):
|
||||
assert read_preference is not None, "read_preference must not be None"
|
||||
# Get a socket for a server matching the read preference, and yield
|
||||
# sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to
|
||||
@ -1226,7 +1232,7 @@ class MongoClient(common.BaseObject):
|
||||
topology = self._get_topology()
|
||||
single = topology.description.topology_type == TOPOLOGY_TYPE.Single
|
||||
|
||||
with self._get_socket(server, session, exhaust=exhaust) as sock_info:
|
||||
with self._get_socket(server, session, pin=pin) as sock_info:
|
||||
slave_ok = (single and not sock_info.is_mongos) or (
|
||||
read_preference != ReadPreference.PRIMARY)
|
||||
yield sock_info, slave_ok
|
||||
@ -1249,47 +1255,41 @@ class MongoClient(common.BaseObject):
|
||||
read_preference != ReadPreference.PRIMARY)
|
||||
yield sock_info, slave_ok
|
||||
|
||||
def _run_operation_with_response(self, operation, unpack_res,
|
||||
exhaust=False, address=None):
|
||||
def _should_pin_cursor(self, session):
|
||||
return (self.__options.load_balanced and
|
||||
not (session and session.in_transaction))
|
||||
|
||||
def _run_operation(self, operation, unpack_res, pin=False, address=None):
|
||||
"""Run a _Query/_GetMore operation and return a Response.
|
||||
|
||||
:Parameters:
|
||||
- `operation`: a _Query or _GetMore object.
|
||||
- `unpack_res`: A callable that decodes the wire protocol response.
|
||||
- `exhaust` (optional): If True, the socket used stays checked out.
|
||||
It is returned along with its Pool in the Response.
|
||||
- `address` (optional): Optional address when sending a message
|
||||
to a specific server, used for getMore.
|
||||
"""
|
||||
if operation.exhaust_mgr:
|
||||
pin = self._should_pin_cursor(operation.session) or operation.exhaust
|
||||
if operation.sock_mgr:
|
||||
server = self._select_server(
|
||||
operation.read_preference, operation.session, address=address)
|
||||
|
||||
with _MongoClientErrorHandler(
|
||||
self, server, operation.session) as err_handler:
|
||||
err_handler.contribute_socket(operation.exhaust_mgr.sock)
|
||||
return server.run_operation_with_response(
|
||||
operation.exhaust_mgr.sock,
|
||||
operation,
|
||||
True,
|
||||
self._event_listeners,
|
||||
exhaust,
|
||||
unpack_res)
|
||||
with operation.sock_mgr.lock:
|
||||
with _MongoClientErrorHandler(
|
||||
self, server, operation.session) as err_handler:
|
||||
err_handler.contribute_socket(operation.sock_mgr.sock)
|
||||
return server.run_operation(
|
||||
operation.sock_mgr.sock, operation, True,
|
||||
self._event_listeners, pin, unpack_res)
|
||||
|
||||
def _cmd(session, server, sock_info, slave_ok):
|
||||
return server.run_operation_with_response(
|
||||
sock_info,
|
||||
operation,
|
||||
slave_ok,
|
||||
self._event_listeners,
|
||||
exhaust,
|
||||
return server.run_operation(
|
||||
sock_info, operation, slave_ok, self._event_listeners, pin,
|
||||
unpack_res)
|
||||
|
||||
return self._retryable_read(
|
||||
_cmd, operation.read_preference, operation.session,
|
||||
address=address,
|
||||
retryable=isinstance(operation, message._Query),
|
||||
exhaust=exhaust)
|
||||
address=address, retryable=isinstance(operation, message._Query),
|
||||
pin=pin)
|
||||
|
||||
def _retry_with_session(self, retryable, func, session, bulk):
|
||||
"""Execute an operation with at most one consecutive retries
|
||||
@ -1361,7 +1361,7 @@ class MongoClient(common.BaseObject):
|
||||
last_error = exc
|
||||
|
||||
def _retryable_read(self, func, read_pref, session, address=None,
|
||||
retryable=True, exhaust=False):
|
||||
retryable=True, pin=False):
|
||||
"""Execute an operation with at most one consecutive retries
|
||||
|
||||
Returns func()'s return value on success. On error retries the same
|
||||
@ -1381,9 +1381,9 @@ class MongoClient(common.BaseObject):
|
||||
read_pref, session, address=address)
|
||||
if not server.description.retryable_reads_supported:
|
||||
retryable = False
|
||||
with self._slaveok_for_server(read_pref, server, session,
|
||||
exhaust=exhaust) as (sock_info,
|
||||
slave_ok):
|
||||
with self._slaveok_for_server(
|
||||
read_pref, server, session, pin=pin) as (
|
||||
sock_info, slave_ok):
|
||||
if retrying and not retryable:
|
||||
# A retry is not possible because this server does
|
||||
# not support retryable reads, raise the last error.
|
||||
@ -1496,7 +1496,8 @@ class MongoClient(common.BaseObject):
|
||||
"""
|
||||
self.__kill_cursors_queue.append((address, [cursor_id]))
|
||||
|
||||
def _close_cursor_now(self, cursor_id, address=None, session=None):
|
||||
def _close_cursor_now(self, cursor_id, address=None, session=None,
|
||||
sock_mgr=None):
|
||||
"""Send a kill cursors message with the given id.
|
||||
|
||||
The cursor is closed synchronously on the current thread.
|
||||
@ -1505,16 +1506,20 @@ class MongoClient(common.BaseObject):
|
||||
raise TypeError("cursor_id must be an instance of int")
|
||||
|
||||
try:
|
||||
self._kill_cursors(
|
||||
[cursor_id], address, self._get_topology(), session)
|
||||
if sock_mgr:
|
||||
with sock_mgr.lock:
|
||||
# Cursor is pinned to LB outside of a transaction.
|
||||
self._kill_cursor_impl(
|
||||
[cursor_id], address, session, sock_mgr.sock)
|
||||
else:
|
||||
self._kill_cursors(
|
||||
[cursor_id], address, self._get_topology(), session)
|
||||
except PyMongoError:
|
||||
# Make another attempt to kill the cursor later.
|
||||
self.__kill_cursors_queue.append((address, [cursor_id]))
|
||||
|
||||
def _kill_cursors(self, cursor_ids, address, topology, session):
|
||||
"""Send a kill cursors message with the given ids."""
|
||||
listeners = self._event_listeners
|
||||
publish = listeners.enabled_for_commands
|
||||
if address:
|
||||
# address could be a tuple or _CursorAddress, but
|
||||
# select_server_by_address needs (host, port).
|
||||
@ -1523,49 +1528,55 @@ class MongoClient(common.BaseObject):
|
||||
# Application called close_cursor() with no address.
|
||||
server = topology.select_server(writable_server_selector)
|
||||
|
||||
with self._get_socket(server, session) as sock_info:
|
||||
self._kill_cursor_impl(cursor_ids, address, session, sock_info)
|
||||
|
||||
def _kill_cursor_impl(self, cursor_ids, address, session, sock_info):
|
||||
listeners = self._event_listeners
|
||||
publish = listeners.enabled_for_commands
|
||||
|
||||
try:
|
||||
namespace = address.namespace
|
||||
db, coll = namespace.split('.', 1)
|
||||
except AttributeError:
|
||||
namespace = None
|
||||
db = coll = "OP_KILL_CURSORS"
|
||||
|
||||
spec = SON([('killCursors', coll), ('cursors', cursor_ids)])
|
||||
with server.get_socket(self.__all_credentials) as sock_info:
|
||||
if sock_info.max_wire_version >= 4 and namespace is not None:
|
||||
sock_info.command(db, spec, session=session, client=self)
|
||||
else:
|
||||
if publish:
|
||||
start = datetime.datetime.now()
|
||||
request_id, msg = message.kill_cursors(cursor_ids)
|
||||
if publish:
|
||||
duration = datetime.datetime.now() - start
|
||||
# Here and below, address could be a tuple or
|
||||
# _CursorAddress. We always want to publish a
|
||||
# tuple to match the rest of the monitoring
|
||||
# API.
|
||||
listeners.publish_command_start(
|
||||
spec, db, request_id, tuple(address))
|
||||
start = datetime.datetime.now()
|
||||
|
||||
try:
|
||||
sock_info.send_message(msg, 0)
|
||||
except Exception as exc:
|
||||
if publish:
|
||||
dur = ((datetime.datetime.now() - start) + duration)
|
||||
listeners.publish_command_failure(
|
||||
dur, message._convert_exception(exc),
|
||||
'killCursors', request_id,
|
||||
tuple(address))
|
||||
raise
|
||||
if sock_info.max_wire_version >= 4 and namespace is not None:
|
||||
sock_info.command(db, spec, session=session, client=self)
|
||||
else:
|
||||
if publish:
|
||||
start = datetime.datetime.now()
|
||||
request_id, msg = message.kill_cursors(cursor_ids)
|
||||
if publish:
|
||||
duration = datetime.datetime.now() - start
|
||||
# Here and below, address could be a tuple or
|
||||
# _CursorAddress. We always want to publish a
|
||||
# tuple to match the rest of the monitoring
|
||||
# API.
|
||||
listeners.publish_command_start(
|
||||
spec, db, request_id, tuple(address),
|
||||
service_id=sock_info.service_id)
|
||||
start = datetime.datetime.now()
|
||||
|
||||
try:
|
||||
sock_info.send_message(msg, 0)
|
||||
except Exception as exc:
|
||||
if publish:
|
||||
duration = ((datetime.datetime.now() - start) + duration)
|
||||
# OP_KILL_CURSORS returns no reply, fake one.
|
||||
reply = {'cursorsUnknown': cursor_ids, 'ok': 1}
|
||||
listeners.publish_command_success(
|
||||
duration, reply, 'killCursors', request_id,
|
||||
tuple(address))
|
||||
dur = ((datetime.datetime.now() - start) + duration)
|
||||
listeners.publish_command_failure(
|
||||
dur, message._convert_exception(exc),
|
||||
'killCursors', request_id,
|
||||
tuple(address), service_id=sock_info.service_id)
|
||||
raise
|
||||
|
||||
if publish:
|
||||
duration = ((datetime.datetime.now() - start) + duration)
|
||||
# OP_KILL_CURSORS returns no reply, fake one.
|
||||
reply = {'cursorsUnknown': cursor_ids, 'ok': 1}
|
||||
listeners.publish_command_success(
|
||||
duration, reply, 'killCursors', request_id,
|
||||
tuple(address), service_id=sock_info.service_id)
|
||||
|
||||
def _process_kill_cursors(self):
|
||||
"""Process any pending kill cursors requests."""
|
||||
@ -1966,7 +1977,8 @@ def _add_retryable_write_error(exc, max_wire_version):
|
||||
class _MongoClientErrorHandler(object):
|
||||
"""Handle errors raised when executing an operation."""
|
||||
__slots__ = ('client', 'server_address', 'session', 'max_wire_version',
|
||||
'sock_generation', 'completed_handshake', 'service_id')
|
||||
'sock_generation', 'completed_handshake', 'service_id',
|
||||
'handled')
|
||||
|
||||
def __init__(self, client, server, session):
|
||||
self.client = client
|
||||
@ -1980,6 +1992,7 @@ class _MongoClientErrorHandler(object):
|
||||
self.sock_generation = server.pool.gen.get_overall()
|
||||
self.completed_handshake = False
|
||||
self.service_id = None
|
||||
self.handled = False
|
||||
|
||||
def contribute_socket(self, sock_info):
|
||||
"""Provide socket information to the error handler."""
|
||||
@ -1988,13 +2001,10 @@ class _MongoClientErrorHandler(object):
|
||||
self.service_id = sock_info.service_id
|
||||
self.completed_handshake = True
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
def handle(self, exc_type, exc_val):
|
||||
if self.handled or exc_type is None:
|
||||
return
|
||||
|
||||
self.handled = True
|
||||
if self.session:
|
||||
if issubclass(exc_type, ConnectionFailure):
|
||||
if self.session.in_transaction:
|
||||
@ -2010,3 +2020,9 @@ class _MongoClientErrorHandler(object):
|
||||
exc_val, self.max_wire_version, self.sock_generation,
|
||||
self.completed_handshake, self.service_id)
|
||||
self.client._topology.handle_error(self.server_address, err_ctx)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return self.handle(exc_type, exc_val)
|
||||
|
||||
@ -1310,7 +1310,8 @@ class _EventListeners(object):
|
||||
self.__topology_listeners[:])
|
||||
|
||||
def publish_command_start(self, command, database_name,
|
||||
request_id, connection_id, op_id=None):
|
||||
request_id, connection_id, op_id=None,
|
||||
service_id=None):
|
||||
"""Publish a CommandStartedEvent to all command listeners.
|
||||
|
||||
:Parameters:
|
||||
@ -1321,11 +1322,13 @@ class _EventListeners(object):
|
||||
- `connection_id`: The address (host, port) of the server this
|
||||
command was sent to.
|
||||
- `op_id`: The (optional) operation id for this operation.
|
||||
- `service_id`: The service_id this command was sent to, or ``None``.
|
||||
"""
|
||||
if op_id is None:
|
||||
op_id = request_id
|
||||
event = CommandStartedEvent(
|
||||
command, database_name, request_id, connection_id, op_id)
|
||||
command, database_name, request_id, connection_id, op_id,
|
||||
service_id=service_id)
|
||||
for subscriber in self.__command_listeners:
|
||||
try:
|
||||
subscriber.started(event)
|
||||
@ -1333,7 +1336,8 @@ class _EventListeners(object):
|
||||
_handle_exception()
|
||||
|
||||
def publish_command_success(self, duration, reply, command_name,
|
||||
request_id, connection_id, op_id=None):
|
||||
request_id, connection_id, op_id=None,
|
||||
service_id=None):
|
||||
"""Publish a CommandSucceededEvent to all command listeners.
|
||||
|
||||
:Parameters:
|
||||
@ -1344,11 +1348,13 @@ class _EventListeners(object):
|
||||
- `connection_id`: The address (host, port) of the server this
|
||||
command was sent to.
|
||||
- `op_id`: The (optional) operation id for this operation.
|
||||
- `service_id`: The service_id this command was sent to, or ``None``.
|
||||
"""
|
||||
if op_id is None:
|
||||
op_id = request_id
|
||||
event = CommandSucceededEvent(
|
||||
duration, reply, command_name, request_id, connection_id, op_id)
|
||||
duration, reply, command_name, request_id, connection_id, op_id,
|
||||
service_id)
|
||||
for subscriber in self.__command_listeners:
|
||||
try:
|
||||
subscriber.succeeded(event)
|
||||
@ -1356,7 +1362,8 @@ class _EventListeners(object):
|
||||
_handle_exception()
|
||||
|
||||
def publish_command_failure(self, duration, failure, command_name,
|
||||
request_id, connection_id, op_id=None):
|
||||
request_id, connection_id, op_id=None,
|
||||
service_id=None):
|
||||
"""Publish a CommandFailedEvent to all command listeners.
|
||||
|
||||
:Parameters:
|
||||
@ -1368,11 +1375,13 @@ class _EventListeners(object):
|
||||
- `connection_id`: The address (host, port) of the server this
|
||||
command was sent to.
|
||||
- `op_id`: The (optional) operation id for this operation.
|
||||
- `service_id`: The service_id this command was sent to, or ``None``.
|
||||
"""
|
||||
if op_id is None:
|
||||
op_id = request_id
|
||||
event = CommandFailedEvent(
|
||||
duration, failure, command_name, request_id, connection_id, op_id)
|
||||
duration, failure, command_name, request_id, connection_id, op_id,
|
||||
service_id=service_id)
|
||||
for subscriber in self.__command_listeners:
|
||||
try:
|
||||
subscriber.failed(event)
|
||||
|
||||
@ -134,7 +134,8 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos,
|
||||
|
||||
if publish:
|
||||
encoding_duration = datetime.datetime.now() - start
|
||||
listeners.publish_command_start(orig, dbname, request_id, address)
|
||||
listeners.publish_command_start(orig, dbname, request_id, address,
|
||||
service_id=sock_info.service_id)
|
||||
start = datetime.datetime.now()
|
||||
|
||||
try:
|
||||
@ -164,12 +165,14 @@ def command(sock_info, dbname, spec, slave_ok, is_mongos,
|
||||
else:
|
||||
failure = message._convert_exception(exc)
|
||||
listeners.publish_command_failure(
|
||||
duration, failure, name, request_id, address)
|
||||
duration, failure, name, request_id, address,
|
||||
service_id=sock_info.service_id)
|
||||
raise
|
||||
if publish:
|
||||
duration = (datetime.datetime.now() - start) + encoding_duration
|
||||
listeners.publish_command_success(
|
||||
duration, response_doc, name, request_id, address)
|
||||
duration, response_doc, name, request_id, address,
|
||||
service_id=sock_info.service_id)
|
||||
|
||||
if client and client._encrypter and reply:
|
||||
decrypted = client._encrypter.decrypt(reply.raw_command_response())
|
||||
|
||||
@ -23,6 +23,7 @@ import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
|
||||
from bson import DEFAULT_CODEC_OPTIONS
|
||||
from bson.son import SON
|
||||
@ -503,6 +504,7 @@ class SocketInfo(object):
|
||||
- `id`: the id of this socket in it's pool
|
||||
"""
|
||||
def __init__(self, sock, pool, address, id):
|
||||
self.pool_ref = weakref.ref(pool)
|
||||
self.sock = sock
|
||||
self.address = address
|
||||
self.id = id
|
||||
@ -541,6 +543,17 @@ class SocketInfo(object):
|
||||
self.more_to_come = False
|
||||
# For load balancer support.
|
||||
self.service_id = None
|
||||
# When executing a transaction in load balancing mode, this flag is
|
||||
# set to true to indicate that the session now owns the connection.
|
||||
self.pinned = False
|
||||
|
||||
def unpin(self):
|
||||
self.pinned = False
|
||||
pool = self.pool_ref()
|
||||
if pool:
|
||||
pool.return_socket(self)
|
||||
else:
|
||||
self.close_socket(ConnectionClosedReason.STALE)
|
||||
|
||||
def hello_cmd(self):
|
||||
if self.opts.server_api:
|
||||
@ -712,7 +725,7 @@ class SocketInfo(object):
|
||||
unacknowledged=unacknowledged,
|
||||
user_fields=user_fields,
|
||||
exhaust_allowed=exhaust_allowed)
|
||||
except OperationFailure:
|
||||
except (OperationFailure, NotMasterError):
|
||||
raise
|
||||
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
|
||||
except BaseException as error:
|
||||
@ -898,7 +911,13 @@ class SocketInfo(object):
|
||||
# ...) is called in Python code, which experiences the signal as a
|
||||
# KeyboardInterrupt from the start, rather than as an initial
|
||||
# socket.error, so we catch that, close the socket, and reraise it.
|
||||
self.close_socket(ConnectionClosedReason.ERROR)
|
||||
#
|
||||
# The connection closed event will be emitted later in return_socket.
|
||||
if self.ready:
|
||||
reason = None
|
||||
else:
|
||||
reason = ConnectionClosedReason.ERROR
|
||||
self.close_socket(reason)
|
||||
# SSLError from PyOpenSSL inherits directly from Exception.
|
||||
if isinstance(error, (IOError, OSError, _SSLError)):
|
||||
_raise_connection_failure(self.address, error)
|
||||
@ -1152,6 +1171,10 @@ class Pool:
|
||||
self.address, self.opts.non_default_options)
|
||||
# Similar to active_sockets but includes threads in the wait queue.
|
||||
self.operation_count = 0
|
||||
# Retain references to pinned connections to prevent the CPython GC
|
||||
# from thinking that a cursor's pinned connection can be GC'd when the
|
||||
# cursor is GC'd (see PYTHON-2751).
|
||||
self.__pinned_sockets = set()
|
||||
|
||||
def ready(self):
|
||||
old_state, self.state = self.state, PoolState.READY
|
||||
@ -1328,7 +1351,7 @@ class Pool:
|
||||
return sock_info
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_socket(self, all_credentials, checkout=False):
|
||||
def get_socket(self, all_credentials, checkout=False, handler=None):
|
||||
"""Get a socket from the pool. Use with a "with" statement.
|
||||
|
||||
Returns a :class:`SocketInfo` object wrapping a connected
|
||||
@ -1349,12 +1372,15 @@ class Pool:
|
||||
:Parameters:
|
||||
- `all_credentials`: dict, maps auth source to MongoCredential.
|
||||
- `checkout` (optional): keep socket checked out.
|
||||
- `handler` (optional): A _MongoClientErrorHandler.
|
||||
"""
|
||||
listeners = self.opts.event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_check_out_started(self.address)
|
||||
|
||||
sock_info = self._get_socket(all_credentials)
|
||||
if checkout:
|
||||
self.__pinned_sockets.add(sock_info)
|
||||
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_out(
|
||||
@ -1362,11 +1388,23 @@ class Pool:
|
||||
try:
|
||||
yield sock_info
|
||||
except:
|
||||
# Exception in caller. Decrement semaphore.
|
||||
self.return_socket(sock_info)
|
||||
# Exception in caller. Ensure the connection gets returned.
|
||||
# Note that when pinned is True, the session owns the
|
||||
# connection and it is responsible for checking the connection
|
||||
# back into the pool.
|
||||
pinned = sock_info.pinned
|
||||
if handler:
|
||||
# Perform SDAM error handling rules while the connection is
|
||||
# still checked out.
|
||||
exc_type, exc_val, _ = sys.exc_info()
|
||||
handler.handle(exc_type, exc_val)
|
||||
if not pinned:
|
||||
self.return_socket(sock_info)
|
||||
raise
|
||||
else:
|
||||
if not checkout:
|
||||
if sock_info.pinned:
|
||||
self.__pinned_sockets.add(sock_info)
|
||||
elif not checkout:
|
||||
self.return_socket(sock_info)
|
||||
|
||||
def _raise_if_not_ready(self, emit_event):
|
||||
@ -1487,6 +1525,8 @@ class Pool:
|
||||
:Parameters:
|
||||
- `sock_info`: The socket to check into the pool.
|
||||
"""
|
||||
self.__pinned_sockets.discard(sock_info)
|
||||
sock_info.pinned = False
|
||||
listeners = self.opts.event_listeners
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_checked_in(self.address, sock_info.id)
|
||||
@ -1495,7 +1535,13 @@ class Pool:
|
||||
else:
|
||||
if self.closed:
|
||||
sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED)
|
||||
elif not sock_info.closed:
|
||||
elif sock_info.closed:
|
||||
# CMAP requires the closed event be emitted after the check in.
|
||||
if self.enabled_for_cmap:
|
||||
listeners.publish_connection_closed(
|
||||
self.address, sock_info.id,
|
||||
ConnectionClosedReason.ERROR)
|
||||
else:
|
||||
with self.lock:
|
||||
# Hold the lock to ensure this section does not race with
|
||||
# Pool.reset().
|
||||
|
||||
@ -68,10 +68,10 @@ class Response(object):
|
||||
return self._docs
|
||||
|
||||
|
||||
class ExhaustResponse(Response):
|
||||
__slots__ = ('_socket_info', '_pool', '_more_to_come')
|
||||
class PinnedResponse(Response):
|
||||
__slots__ = ('_socket_info', '_more_to_come')
|
||||
|
||||
def __init__(self, data, address, socket_info, pool, request_id, duration,
|
||||
def __init__(self, data, address, socket_info, request_id, duration,
|
||||
from_command, docs, more_to_come):
|
||||
"""Represent a response to an exhaust cursor's initial query.
|
||||
|
||||
@ -87,13 +87,12 @@ class ExhaustResponse(Response):
|
||||
- `more_to_come`: Bool indicating whether cursor is ready to be
|
||||
exhausted.
|
||||
"""
|
||||
super(ExhaustResponse, self).__init__(data,
|
||||
address,
|
||||
request_id,
|
||||
duration,
|
||||
from_command, docs)
|
||||
super(PinnedResponse, self).__init__(data,
|
||||
address,
|
||||
request_id,
|
||||
duration,
|
||||
from_command, docs)
|
||||
self._socket_info = socket_info
|
||||
self._pool = pool
|
||||
self._more_to_come = more_to_come
|
||||
|
||||
@property
|
||||
@ -106,11 +105,6 @@ class ExhaustResponse(Response):
|
||||
"""
|
||||
return self._socket_info
|
||||
|
||||
@property
|
||||
def pool(self):
|
||||
"""The Pool from which the SocketInfo came."""
|
||||
return self._pool
|
||||
|
||||
@property
|
||||
def more_to_come(self):
|
||||
"""If true, server is ready to send batches on the socket until the
|
||||
|
||||
@ -21,7 +21,7 @@ from bson import _decode_all_selective
|
||||
from pymongo.errors import NotMasterError, OperationFailure
|
||||
from pymongo.helpers import _check_command_response
|
||||
from pymongo.message import _convert_exception, _OpMsg
|
||||
from pymongo.response import Response, ExhaustResponse
|
||||
from pymongo.response import Response, PinnedResponse
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
|
||||
_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}}
|
||||
@ -68,14 +68,8 @@ class Server(object):
|
||||
"""Check the server's state soon."""
|
||||
self._monitor.request_check()
|
||||
|
||||
def run_operation_with_response(
|
||||
self,
|
||||
sock_info,
|
||||
operation,
|
||||
set_slave_okay,
|
||||
listeners,
|
||||
exhaust,
|
||||
unpack_res):
|
||||
def run_operation(self, sock_info, operation, set_slave_okay, listeners,
|
||||
pin, unpack_res):
|
||||
"""Run a _Query or _GetMore operation and return a Response object.
|
||||
|
||||
This method is used only to run _Query/_GetMore operations from
|
||||
@ -87,7 +81,7 @@ class Server(object):
|
||||
- `set_slave_okay`: Pass to operation.get_message.
|
||||
- `all_credentials`: dict, maps auth source to MongoCredential.
|
||||
- `listeners`: Instance of _EventListeners or None.
|
||||
- `exhaust`: If True, then this is an exhaust cursor operation.
|
||||
- `pin`: If True, then this is a pinned cursor operation.
|
||||
- `unpack_res`: A callable that decodes the wire protocol response.
|
||||
"""
|
||||
duration = None
|
||||
@ -95,9 +89,9 @@ class Server(object):
|
||||
if publish:
|
||||
start = datetime.now()
|
||||
|
||||
use_cmd = operation.use_command(sock_info, exhaust)
|
||||
more_to_come = (operation.exhaust_mgr
|
||||
and operation.exhaust_mgr.more_to_come)
|
||||
use_cmd = operation.use_command(sock_info)
|
||||
more_to_come = (operation.sock_mgr
|
||||
and operation.sock_mgr.more_to_come)
|
||||
if more_to_come:
|
||||
request_id = 0
|
||||
else:
|
||||
@ -108,7 +102,8 @@ class Server(object):
|
||||
if publish:
|
||||
cmd, dbn = operation.as_command(sock_info)
|
||||
listeners.publish_command_start(
|
||||
cmd, dbn, request_id, sock_info.address)
|
||||
cmd, dbn, request_id, sock_info.address,
|
||||
service_id=sock_info.service_id)
|
||||
start = datetime.now()
|
||||
|
||||
try:
|
||||
@ -142,7 +137,8 @@ class Server(object):
|
||||
failure = _convert_exception(exc)
|
||||
listeners.publish_command_failure(
|
||||
duration, failure, operation.name,
|
||||
request_id, sock_info.address)
|
||||
request_id, sock_info.address,
|
||||
service_id=sock_info.service_id)
|
||||
raise
|
||||
|
||||
if publish:
|
||||
@ -163,7 +159,7 @@ class Server(object):
|
||||
res["cursor"]["nextBatch"] = docs
|
||||
listeners.publish_command_success(
|
||||
duration, res, operation.name, request_id,
|
||||
sock_info.address)
|
||||
sock_info.address, service_id=sock_info.service_id)
|
||||
|
||||
# Decrypt response.
|
||||
client = operation.client
|
||||
@ -174,19 +170,20 @@ class Server(object):
|
||||
docs = _decode_all_selective(
|
||||
decrypted, operation.codec_options, user_fields)
|
||||
|
||||
if exhaust:
|
||||
if pin:
|
||||
if isinstance(reply, _OpMsg):
|
||||
# In OP_MSG, the server keeps sending only if the
|
||||
# more_to_come flag is set.
|
||||
more_to_come = reply.more_to_come
|
||||
else:
|
||||
# In OP_REPLY, the server keeps sending until cursor_id is 0.
|
||||
more_to_come = bool(reply.cursor_id)
|
||||
response = ExhaustResponse(
|
||||
more_to_come = bool(operation.exhaust and reply.cursor_id)
|
||||
if operation.sock_mgr:
|
||||
operation.sock_mgr.update_exhaust(more_to_come)
|
||||
response = PinnedResponse(
|
||||
data=reply,
|
||||
address=self._description.address,
|
||||
socket_info=sock_info,
|
||||
pool=self._pool,
|
||||
duration=duration,
|
||||
request_id=request_id,
|
||||
from_command=use_cmd,
|
||||
@ -203,8 +200,8 @@ class Server(object):
|
||||
|
||||
return response
|
||||
|
||||
def get_socket(self, all_credentials, checkout=False):
|
||||
return self.pool.get_socket(all_credentials, checkout)
|
||||
def get_socket(self, all_credentials, checkout=False, handler=None):
|
||||
return self.pool.get_socket(all_credentials, checkout, handler)
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
|
||||
@ -19,8 +19,8 @@ import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
|
||||
from test import unittest, IntegrationTest, client_context
|
||||
from test.utils import get_pool
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
# Location of JSON test specifications.
|
||||
@ -30,5 +30,23 @@ TEST_PATH = os.path.join(
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
|
||||
class TestLB(IntegrationTest):
|
||||
@client_context.require_load_balancer
|
||||
def test_unpin_committed_transaction(self):
|
||||
pool = get_pool(self.client)
|
||||
with self.client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
self.assertEqual(pool.active_sockets, 0)
|
||||
self.db.test.insert_one({}, session=session)
|
||||
self.assertEqual(pool.active_sockets, 1) # Pinned.
|
||||
self.assertEqual(pool.active_sockets, 1) # Still pinned.
|
||||
self.assertEqual(pool.active_sockets, 0) # Unpinned.
|
||||
|
||||
def test_client_can_be_reopened(self):
|
||||
self.client.close()
|
||||
self.db.test.find_one({})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -376,7 +376,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "pinned connections are not returned after an network error during getMore",
|
||||
"description": "pinned connections are returned after an network error during getMore",
|
||||
"operations": [
|
||||
{
|
||||
"name": "failPoint",
|
||||
@ -440,7 +440,7 @@
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "client0",
|
||||
"connections": 1
|
||||
"connections": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -659,7 +659,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"description": "pinned connections are not returned to the pool after a non-network error on getMore",
|
||||
"description": "pinned connections are returned to the pool after a non-network error on getMore",
|
||||
"operations": [
|
||||
{
|
||||
"name": "failPoint",
|
||||
@ -715,7 +715,7 @@
|
||||
"object": "testRunner",
|
||||
"arguments": {
|
||||
"client": "client0",
|
||||
"connections": 1
|
||||
"connections": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@ -366,9 +366,6 @@
|
||||
{
|
||||
"connectionCreatedEvent": {}
|
||||
},
|
||||
{
|
||||
"poolClearedEvent": {}
|
||||
},
|
||||
{
|
||||
"connectionClosedEvent": {
|
||||
"reason": "error"
|
||||
@ -378,6 +375,9 @@
|
||||
"connectionCheckOutFailedEvent": {
|
||||
"reason": "connectionError"
|
||||
}
|
||||
},
|
||||
{
|
||||
"poolClearedEvent": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ class MockPool(Pool):
|
||||
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_socket(self, all_credentials, checkout=False):
|
||||
def get_socket(self, all_credentials, checkout=False, handler=None):
|
||||
client = self.client
|
||||
host_and_port = '%s:%s' % (self.mock_host, self.mock_port)
|
||||
if host_and_port in client.mock_down_hosts:
|
||||
@ -51,7 +51,8 @@ class MockPool(Pool):
|
||||
+ client.mock_members
|
||||
+ client.mock_mongoses), "bad host: %s" % host_and_port
|
||||
|
||||
with Pool.get_socket(self, all_credentials, checkout) as sock_info:
|
||||
with Pool.get_socket(
|
||||
self, all_credentials, checkout, handler) as sock_info:
|
||||
sock_info.mock_host = self.mock_host
|
||||
sock_info.mock_port = self.mock_port
|
||||
yield sock_info
|
||||
|
||||
@ -1323,11 +1323,11 @@ class TestClient(IntegrationTest):
|
||||
with self.assertRaises(AutoReconnect):
|
||||
client = rs_client(connect=False,
|
||||
serverSelectionTimeoutMS=100)
|
||||
client._run_operation_with_response(
|
||||
client._run_operation(
|
||||
operation=message._GetMore('pymongo_test', 'collection',
|
||||
101, 1234, client.codec_options,
|
||||
ReadPreference.PRIMARY,
|
||||
None, client, None, None),
|
||||
None, client, None, None, False),
|
||||
unpack_res=Cursor(
|
||||
client.pymongo_test.collection)._unpack_response,
|
||||
address=('not-a-member', 27017))
|
||||
@ -1708,7 +1708,7 @@ class TestExhaustCursor(IntegrationTest):
|
||||
cursor.next()
|
||||
|
||||
# Cause a network error.
|
||||
sock_info = cursor._Cursor__exhaust_mgr.sock
|
||||
sock_info = cursor._Cursor__sock_mgr.sock
|
||||
sock_info.sock.close()
|
||||
|
||||
# A getmore fails.
|
||||
|
||||
@ -20,6 +20,7 @@ import warnings
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from bson.int64 import Int64
|
||||
from bson.objectid import ObjectId
|
||||
from bson.son import SON
|
||||
from pymongo import CursorType, monitoring, InsertOne, UpdateOne, DeleteOne
|
||||
@ -397,7 +398,8 @@ class TestCommandMonitoring(PyMongoTestCase):
|
||||
def test_get_more_failure(self):
|
||||
address = self.client.address
|
||||
coll = self.client.pymongo_test.test
|
||||
cursor_doc = {"id": 12345, "firstBatch": [], "ns": coll.full_name}
|
||||
cursor_id = Int64(12345)
|
||||
cursor_doc = {"id": cursor_id, "firstBatch": [], "ns": coll.full_name}
|
||||
cursor = CommandCursor(coll, cursor_doc, address)
|
||||
try:
|
||||
next(cursor)
|
||||
@ -410,7 +412,7 @@ class TestCommandMonitoring(PyMongoTestCase):
|
||||
self.assertTrue(
|
||||
isinstance(started, monitoring.CommandStartedEvent))
|
||||
self.assertEqualCommand(
|
||||
SON([('getMore', 12345),
|
||||
SON([('getMore', cursor_id),
|
||||
('collection', 'test')]),
|
||||
started.command)
|
||||
self.assertEqual('getMore', started.command_name)
|
||||
|
||||
@ -322,10 +322,9 @@ class ReadPrefTester(MongoClient):
|
||||
yield sock_info, slave_ok
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _slaveok_for_server(self, read_preference, server, session,
|
||||
exhaust=False):
|
||||
def _slaveok_for_server(self, read_preference, server, session, pin=False):
|
||||
context = super(ReadPrefTester, self)._slaveok_for_server(
|
||||
read_preference, server, session, exhaust=exhaust)
|
||||
read_preference, server, session, pin=pin)
|
||||
with context as (sock_info, slave_ok):
|
||||
self.record_a_read(sock_info.address)
|
||||
yield sock_info, slave_ok
|
||||
|
||||
@ -519,6 +519,14 @@ class MatchEvaluatorUtil(object):
|
||||
self.test.assertIsInstance(actual, type(expectation))
|
||||
self.test.assertEqual(expectation, actual)
|
||||
|
||||
def assertHasServiceId(self, spec, actual):
|
||||
if 'hasServiceId' in spec:
|
||||
if spec.get('hasServiceId'):
|
||||
self.test.assertIsNotNone(actual.service_id)
|
||||
self.test.assertIsInstance(actual.service_id, ObjectId)
|
||||
else:
|
||||
self.test.assertIsNone(actual.service_id)
|
||||
|
||||
def match_event(self, event_type, expectation, actual):
|
||||
name, spec = next(iter(expectation.items()))
|
||||
|
||||
@ -543,24 +551,23 @@ class MatchEvaluatorUtil(object):
|
||||
if database_name:
|
||||
self.test.assertEqual(
|
||||
database_name, actual.database_name)
|
||||
self.assertHasServiceId(spec, actual)
|
||||
elif name == 'commandSucceededEvent':
|
||||
self.test.assertIsInstance(actual, CommandSucceededEvent)
|
||||
reply = spec.get('reply')
|
||||
if reply:
|
||||
self.match_result(reply, actual.reply)
|
||||
self.assertHasServiceId(spec, actual)
|
||||
elif name == 'commandFailedEvent':
|
||||
self.test.assertIsInstance(actual, CommandFailedEvent)
|
||||
self.assertHasServiceId(spec, actual)
|
||||
elif name == 'poolCreatedEvent':
|
||||
self.test.assertIsInstance(actual, PoolCreatedEvent)
|
||||
elif name == 'poolReadyEvent':
|
||||
self.test.assertIsInstance(actual, PoolReadyEvent)
|
||||
elif name == 'poolClearedEvent':
|
||||
self.test.assertIsInstance(actual, PoolClearedEvent)
|
||||
if spec.get('hasServiceId'):
|
||||
self.test.assertIsNotNone(actual.service_id)
|
||||
self.test.assertIsInstance(actual.service_id, ObjectId)
|
||||
else:
|
||||
self.test.assertIsNone(actual.service_id)
|
||||
self.assertHasServiceId(spec, actual)
|
||||
elif name == 'poolClosedEvent':
|
||||
self.test.assertIsInstance(actual, PoolClosedEvent)
|
||||
elif name == 'connectionCreatedEvent':
|
||||
@ -569,12 +576,14 @@ class MatchEvaluatorUtil(object):
|
||||
self.test.assertIsInstance(actual, ConnectionReadyEvent)
|
||||
elif name == 'connectionClosedEvent':
|
||||
self.test.assertIsInstance(actual, ConnectionClosedEvent)
|
||||
self.test.assertEqual(actual.reason, spec['reason'])
|
||||
if 'reason' in spec:
|
||||
self.test.assertEqual(actual.reason, spec['reason'])
|
||||
elif name == 'connectionCheckOutStartedEvent':
|
||||
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
|
||||
elif name == 'connectionCheckOutFailedEvent':
|
||||
self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
|
||||
self.test.assertEqual(actual.reason, spec['reason'])
|
||||
if 'reason' in spec:
|
||||
self.test.assertEqual(actual.reason, spec['reason'])
|
||||
elif name == 'connectionCheckedOutEvent':
|
||||
self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
|
||||
elif name == 'connectionCheckedInEvent':
|
||||
|
||||
@ -267,7 +267,7 @@ class MockPool(object):
|
||||
def stale_generation(self, gen, service_id):
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
def get_socket(self, all_credentials, checkout=False):
|
||||
def get_socket(self, all_credentials, checkout=False, handler=None):
|
||||
return MockSocketInfo()
|
||||
|
||||
def return_socket(self, *args, **kwargs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user