PYTHON-2123 Streaming heartbeat protocol

MongoClient now requires 2 connections and 2 threads to each MongoDB 4.4+ server.
With one connection, the server streams (or pushes) updated heartbeat info.
With the other connection, the client periodically pings the server to
establish an accurate round-trip time (RTT). This change optimizes the
discovery of server state changes such as replica set elections.

Additional changes:
- Mark server Unknown before retrying isMaster check.
- Always reset the pool _after_ marking the server unknown.
- Configure fail point before creating the client in test SpecRunner.
- Unfreeze with replSetFreeze:0 to ensure a speedy elections in test suite.
This commit is contained in:
Shane Harvey 2020-06-02 12:21:39 -07:00
parent 0b375a2604
commit 1f4123e4bf
32 changed files with 2702 additions and 232 deletions

View File

@ -282,3 +282,9 @@ class EncryptionError(PyMongoError):
def cause(self):
"""The exception that caused this encryption or decryption error."""
return self.__cause
class _OperationCancelled(AutoReconnect):
"""Internal error raised when a socket operation is cancelled.
"""
pass

View File

@ -46,9 +46,10 @@ def _get_server_type(doc):
class IsMaster(object):
__slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable')
__slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable',
'_awaitable')
def __init__(self, doc):
def __init__(self, doc, awaitable=False):
"""Parse an ismaster response from the server."""
self._server_type = _get_server_type(doc)
self._doc = doc
@ -60,6 +61,7 @@ class IsMaster(object):
self._is_readable = (
self.server_type == SERVER_TYPE.RSSecondary
or self._is_writable)
self._awaitable = awaitable
@property
def document(self):
@ -177,3 +179,7 @@ class IsMaster(object):
@property
def topology_version(self):
return self._doc.get('topologyVersion')
@property
def awaitable(self):
return self._awaitable

View File

@ -1563,6 +1563,11 @@ class _OpReply(object):
# This should never be called on _OpReply.
raise NotImplementedError
@property
def more_to_come(self):
"""Is the moreToCome bit set on this response?"""
return False
@classmethod
def unpack(cls, msg):
"""Construct an _OpReply from raw bytes."""
@ -1583,6 +1588,11 @@ class _OpMsg(object):
UNPACK_FROM = struct.Struct("<IBi").unpack_from
OP_CODE = 2013
# Flag bits.
CHECKSUM_PRESENT = 1
MORE_TO_COME = 1 << 1
EXHAUST_ALLOWED = 1 << 16 # Only present on requests.
def __init__(self, flags, payload_document):
self.flags = flags
self.payload_document = payload_document
@ -1613,15 +1623,28 @@ class _OpMsg(object):
"""Return the bytes of the command response."""
return self.payload_document
@property
def more_to_come(self):
"""Is the moreToCome bit set on this response?"""
return self.flags & self.MORE_TO_COME
@classmethod
def unpack(cls, msg):
"""Construct an _OpMsg from raw bytes."""
flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
if flags != 0:
raise ProtocolError("Unsupported OP_MSG flags (%r)" % (flags,))
if flags & cls.CHECKSUM_PRESENT:
raise ProtocolError(
"Unsupported OP_MSG flag checksumPresent: "
"0x%x" % (flags,))
if flags ^ cls.MORE_TO_COME:
raise ProtocolError(
"Unsupported OP_MSG flags: 0x%x" % (flags,))
if first_payload_type != 0:
raise ProtocolError(
"Unsupported OP_MSG payload type (%r)" % (first_payload_type,))
"Unsupported OP_MSG payload type: "
"0x%x" % (first_payload_type,))
if len(msg) != first_payload_size + 5:
raise ProtocolError("Unsupported OP_MSG reply: >1 section")

View File

@ -14,14 +14,19 @@
"""Class to monitor a MongoDB server on a background thread."""
import atexit
import threading
import weakref
from pymongo import common, periodic_executor
from pymongo.errors import OperationFailure
from pymongo.errors import (NotMasterError,
OperationFailure,
_OperationCancelled)
from pymongo.ismaster import IsMaster
from pymongo.monotonic import time as _time
from pymongo.periodic_executor import _shutdown_executors
from pymongo.read_preferences import MovingAverage
from pymongo.server_description import ServerDescription
from pymongo.server_type import SERVER_TYPE
from pymongo.srv_resolver import _SrvResolver
@ -49,9 +54,17 @@ class MonitorBase(object):
self._executor = executor
def _on_topology_gc(dummy=None):
# This prevents GC from waiting 10 seconds for isMaster to complete
# See test_cleanup_executors_on_client_del.
monitor = self_ref()
if monitor:
monitor.gc_safe_close()
# Avoid cycles. When self or topology is freed, stop executor soon.
self_ref = weakref.ref(self, executor.close)
self._topology = weakref.proxy(topology, executor.close)
self._topology = weakref.proxy(topology, _on_topology_gc)
_register(self)
def open(self):
"""Start monitoring, or restart after a fork.
@ -60,12 +73,16 @@ class MonitorBase(object):
"""
self._executor.open()
def gc_safe_close(self):
"""GC safe close."""
self._executor.close()
def close(self):
"""Close and stop monitoring.
open() restarts the monitor after closing.
"""
self._executor.close()
self.gc_safe_close()
def join(self, timeout=None):
"""Wait for the monitor to stop."""
@ -99,72 +116,113 @@ class Monitor(MonitorBase):
self._server_description = server_description
self._pool = pool
self._settings = topology_settings
self._avg_round_trip_time = MovingAverage()
self._listeners = self._settings._pool_options.event_listeners
pub = self._listeners is not None
self._publish = pub and self._listeners.enabled_for_server_heartbeat
self._cancel_context = None
self._rtt_monitor = _RttMonitor(
topology, topology_settings, topology._create_pool_for_monitor(
server_description.address))
self.heartbeater = None
def cancel_check(self):
"""Cancel any concurrent isMaster check.
Note: this is called from a weakref.proxy callback and MUST NOT take
any locks.
"""
context = self._cancel_context
if context:
# Note: we cannot close the socket because doing so may cause
# concurrent reads/writes to hang until a timeout occurs
# (depending on the platform).
context.cancel()
def _start_rtt_monitor(self):
"""Start an _RttMonitor that periodically runs ping."""
# If this monitor is closed directly before (or during) this open()
# call, the _RttMonitor will not be closed. Checking if this monitor
# was closed directly after resolves the race.
self._rtt_monitor.open()
if self._executor._stopped:
self._rtt_monitor.close()
def gc_safe_close(self):
self._executor.close()
self._rtt_monitor.gc_safe_close()
self.cancel_check()
def close(self):
super(Monitor, self).close()
self.gc_safe_close()
self._rtt_monitor.close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
self._reset_connection()
def _reset_connection(self):
# Clear our pooled connection.
self._pool.reset()
def _run(self):
try:
self._server_description = self._check_with_retry()
prev_sd = self._server_description
try:
self._server_description = self._check_server()
except _OperationCancelled as exc:
# Already closed the connection, wait for the next check.
self._server_description = ServerDescription(
self._server_description.address, error=exc)
if prev_sd.is_server_type_known:
# Immediately retry since we've already waited 500ms to
# discover that we've been cancelled.
self._executor.skip_sleep()
return
self._topology.on_change(self._server_description)
if (self._server_description.is_server_type_known and
self._server_description.topology_version):
self._start_rtt_monitor()
# Immediately check for the next streaming response.
self._executor.skip_sleep()
if self._server_description.error:
# Reset the server pool only after marking the server Unknown.
self._topology.reset_pool(self._server_description.address)
if prev_sd.is_server_type_known:
# Immediately retry on network errors.
self._executor.skip_sleep()
except ReferenceError:
# Topology was garbage-collected.
self.close()
def _check_with_retry(self):
"""Call ismaster once or twice. Reset server's pool on error.
def _check_server(self):
"""Call isMaster or read the next streaming response.
Returns a ServerDescription.
"""
# According to the spec, if an ismaster call fails we reset the
# server's pool. If a server was once connected, change its type
# to Unknown only after retrying once.
address = self._server_description.address
retry = True
if self._server_description.server_type == SERVER_TYPE.Unknown:
retry = False
start = _time()
try:
return self._check_once()
try:
return self._check_once()
except (OperationFailure, NotMasterError) as exc:
# Update max cluster time even when isMaster fails.
self._topology.receive_cluster_time(
exc.details.get('$clusterTime'))
raise
except ReferenceError:
raise
except Exception as error:
error_time = _time() - start
address = self._server_description.address
duration = _time() - start
if self._publish:
self._listeners.publish_server_heartbeat_failed(
address, error_time, error)
default = ServerDescription(address, error=error)
# Reset the server pool only after marking the server Unknown.
self._topology.on_change(default)
self._topology.reset_pool(address)
self._avg_round_trip_time.reset()
if not retry:
# Server type defaults to Unknown.
return default
# Try a second and final time. If it fails return original error.
# Always send metadata: this is a new connection.
start = _time()
try:
return self._check_once()
except ReferenceError:
address, duration, error)
self._reset_connection()
if isinstance(error, _OperationCancelled):
raise
except Exception as error:
error_time = _time() - start
if self._publish:
self._listeners.publish_server_heartbeat_failed(
address, error_time, error)
self._avg_round_trip_time.reset()
return default
self._rtt_monitor.reset()
# Server type defaults to Unknown.
return ServerDescription(address, error=error)
def _check_once(self):
"""A single attempt to call ismaster.
@ -173,35 +231,46 @@ class Monitor(MonitorBase):
"""
address = self._server_description.address
if self._publish:
# PYTHON-2299: Add the "awaited" field to heartbeat events.
self._listeners.publish_server_heartbeat_started(address)
if self._cancel_context and self._cancel_context.cancelled:
self._reset_connection()
with self._pool.get_socket({}) as sock_info:
self._cancel_context = sock_info.cancel_context
response, round_trip_time = self._check_with_socket(sock_info)
self._avg_round_trip_time.add_sample(round_trip_time)
sd = ServerDescription(
address=address,
ismaster=response,
round_trip_time=self._avg_round_trip_time.get())
if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time)
sd = ServerDescription(address, response,
self._rtt_monitor.average())
if self._publish:
self._listeners.publish_server_heartbeat_succeeded(
address, round_trip_time, response)
return sd
def _check_with_socket(self, sock_info):
def _check_with_socket(self, conn):
"""Return (IsMaster, round_trip_time).
Can raise ConnectionFailure or OperationFailure.
"""
cluster_time = self._topology.max_cluster_time()
start = _time()
try:
return (sock_info.ismaster(self._pool.opts.metadata,
self._topology.max_cluster_time()),
_time() - start)
except OperationFailure as exc:
# Update max cluster time even when isMaster fails.
self._topology.receive_cluster_time(
exc.details.get('$clusterTime'))
raise
if conn.more_to_come:
# Read the next streaming isMaster (MongoDB 4.4+).
response = IsMaster(conn._next_reply(), awaitable=True)
elif (conn.performed_handshake and
self._server_description.topology_version):
# Initiate streaming isMaster (MongoDB 4.4+).
response = conn._ismaster(
cluster_time,
self._server_description.topology_version,
self._settings.heartbeat_frequency,
None)
else:
# New connection handshake or polling isMaster (MongoDB <4.4).
response = conn._ismaster(cluster_time, None, None, None)
return response, _time() - start
class SrvMonitor(MonitorBase):
@ -252,3 +321,105 @@ class SrvMonitor(MonitorBase):
self._executor.update_interval(
max(ttl, common.MIN_SRV_RESCAN_INTERVAL))
return seedlist
class _RttMonitor(MonitorBase):
def __init__(self, topology, topology_settings, pool):
"""Maintain round trip times for a server.
The Topology is weakly referenced.
"""
super(_RttMonitor, self).__init__(
topology,
"pymongo_server_rtt_thread",
topology_settings.heartbeat_frequency,
common.MIN_HEARTBEAT_INTERVAL)
self._pool = pool
self._moving_average = MovingAverage()
self._lock = threading.Lock()
def close(self):
self.gc_safe_close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
self._pool.reset()
def add_sample(self, sample):
"""Add a RTT sample."""
with self._lock:
self._moving_average.add_sample(sample)
def average(self):
"""Get the calculated average, or None if no samples yet."""
with self._lock:
return self._moving_average.get()
def reset(self):
"""Reset the average RTT."""
with self._lock:
return self._moving_average.reset()
def _run(self):
try:
# NOTE: This thread is only run when when using the streaming
# heartbeat protocol (MongoDB 4.4+).
# XXX: Skip check if the server is unknown?
rtt = self._ping()
self.add_sample(rtt)
except ReferenceError:
# Topology was garbage-collected.
self.close()
except Exception:
self._pool.reset()
def _ping(self):
"""Run an "isMaster" command and return the RTT."""
with self._pool.get_socket({}) as sock_info:
start = _time()
sock_info.ismaster()
return _time() - start
# Close monitors to cancel any in progress streaming checks before joining
# executor threads. For an explanation of how this works see the comment
# about _EXECUTORS in periodic_executor.py.
_MONITORS = set()
def _register(monitor):
ref = weakref.ref(monitor, _unregister)
_MONITORS.add(ref)
def _unregister(monitor_ref):
_MONITORS.remove(monitor_ref)
def _shutdown_monitors():
if _MONITORS is None:
return
# Copy the set. Closing monitors removes them.
monitors = list(_MONITORS)
# Close all monitors.
for ref in monitors:
monitor = ref()
if monitor:
monitor.gc_safe_close()
monitor = None
def _shutdown_resources():
# _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown.
shutdown = _shutdown_monitors
if shutdown:
shutdown()
shutdown = _shutdown_executors
if shutdown:
shutdown()
atexit.register(_shutdown_resources)

View File

@ -16,8 +16,10 @@
import datetime
import errno
import socket
import struct
from bson import _decode_all_selective
from bson.py3compat import PY3
@ -27,15 +29,18 @@ from pymongo.compression_support import decompress, _NO_COMPRESSION
from pymongo.errors import (AutoReconnect,
NotMasterError,
OperationFailure,
ProtocolError)
from pymongo.message import _UNPACK_REPLY
ProtocolError,
NetworkTimeout,
_OperationCancelled)
from pymongo.message import _UNPACK_REPLY, _OpMsg
from pymongo.monotonic import time
from pymongo.socket_checker import _errno_from_exception
_UNPACK_HEADER = struct.Struct("<iiii").unpack
def command(sock, dbname, spec, slave_ok, is_mongos,
def command(sock_info, dbname, spec, slave_ok, is_mongos,
read_preference, codec_options, session, client, check=True,
allowable_errors=None, address=None,
check_keys=False, listeners=None, max_bson_size=None,
@ -45,7 +50,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
compression_ctx=None,
use_op_msg=False,
unacknowledged=False,
user_fields=None):
user_fields=None,
exhaust_allowed=False):
"""Execute a command over the socket, or raise socket.error.
:Parameters:
@ -74,6 +80,7 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
- `user_fields` (optional): Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
- `exhaust_allowed`: True if we should enable OP_MSG exhaustAllowed.
"""
name = next(iter(spec))
ns = dbname + '.$cmd'
@ -108,7 +115,8 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
check_keys = False
if use_op_msg:
flags = 2 if unacknowledged else 0
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
request_id, msg, size, max_doc_size = message._op_msg(
flags, spec, dbname, read_preference, slave_ok, check_keys,
codec_options, ctx=compression_ctx)
@ -133,13 +141,14 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
start = datetime.datetime.now()
try:
sock.sendall(msg)
sock_info.sock.sendall(msg)
if use_op_msg and unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc = {"ok": 1}
else:
reply = receive_message(sock, request_id)
reply = receive_message(sock_info, request_id)
sock_info.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields)
@ -174,11 +183,16 @@ def command(sock, dbname, spec, slave_ok, is_mongos,
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
def receive_message(sock, request_id, max_message_size=MAX_MESSAGE_SIZE):
def receive_message(sock_info, request_id, max_message_size=MAX_MESSAGE_SIZE):
"""Receive a raw BSON message or raise socket.error."""
timeout = sock_info.sock.gettimeout()
if timeout:
deadline = time() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
_receive_data_on_socket(sock, 16))
_receive_data_on_socket(sock_info, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
@ -192,11 +206,12 @@ def receive_message(sock, request_id, max_message_size=MAX_MESSAGE_SIZE):
"message size (%r)" % (length, max_message_size))
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
_receive_data_on_socket(sock, 9))
_receive_data_on_socket(sock_info, 9, deadline))
data = decompress(
_receive_data_on_socket(sock, length - 25), compressor_id)
_receive_data_on_socket(sock_info, length - 25, deadline),
compressor_id)
else:
data = _receive_data_on_socket(sock, length - 16)
data = _receive_data_on_socket(sock_info, length - 16, deadline)
try:
unpack_reply = _UNPACK_REPLY[op_code]
@ -206,18 +221,48 @@ def receive_message(sock, request_id, max_message_size=MAX_MESSAGE_SIZE):
return unpack_reply(data)
_POLL_TIMEOUT = 0.5
def wait_for_read(sock_info, deadline):
"""Block until at least one byte is read, or a timeout, or a cancel."""
context = sock_info.cancel_context
# Only Monitor connections can be cancelled.
if context:
sock = sock_info.sock
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, 'pending') and sock.pending() > 0:
readable = True
else:
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
timeout = max(min(deadline - time(), _POLL_TIMEOUT), 0.001)
else:
timeout = _POLL_TIMEOUT
readable = sock_info.socket_checker.select(
sock, read=True, timeout=timeout)
if context.cancelled:
raise _OperationCancelled('isMaster cancelled')
if readable:
return
if deadline and time() > deadline:
raise socket.timeout("timed out")
# memoryview was introduced in Python 2.7 but we only use it on Python 3
# because before 2.7.4 the struct module did not support memoryview:
# https://bugs.python.org/issue10212.
# In Jython, using slice assignment on a memoryview results in a
# NullPointerException.
if not PY3:
def _receive_data_on_socket(sock, length):
def _receive_data_on_socket(sock_info, length, deadline):
buf = bytearray(length)
i = 0
while length:
try:
chunk = sock.recv(length)
wait_for_read(sock_info, deadline)
chunk = sock_info.sock.recv(length)
except (IOError, OSError) as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue
@ -231,13 +276,14 @@ if not PY3:
return bytes(buf)
else:
def _receive_data_on_socket(sock, length):
def _receive_data_on_socket(sock_info, length, deadline):
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
chunk_length = sock.recv_into(mv[bytes_read:])
wait_for_read(sock_info, deadline)
chunk_length = sock_info.sock.recv_into(mv[bytes_read:])
except (IOError, OSError) as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue

View File

@ -14,7 +14,6 @@
"""Run a target function on a background thread."""
import atexit
import threading
import time
import weakref
@ -46,6 +45,7 @@ class PeriodicExecutor(object):
self._stopped = False
self._thread = None
self._name = name
self._skip_sleep = False
self._thread_will_exit = False
self._lock = threading.Lock()
@ -109,6 +109,9 @@ class PeriodicExecutor(object):
def update_interval(self, new_interval):
self._interval = new_interval
def skip_sleep(self):
self._skip_sleep = True
def __should_stop(self):
with self._lock:
if self._stopped:
@ -129,12 +132,14 @@ class PeriodicExecutor(object):
raise
deadline = _time() + self._interval
while not self._stopped and _time() < deadline:
time.sleep(self._min_interval)
if self._event:
break # Early wake.
if self._skip_sleep:
self._skip_sleep = False
else:
deadline = _time() + self._interval
while not self._stopped and _time() < deadline:
time.sleep(self._min_interval)
if self._event:
break # Early wake.
self._event = False
@ -177,5 +182,3 @@ def _shutdown_executors():
executor.join(1)
executor = None
atexit.register(_shutdown_executors)

View File

@ -482,6 +482,20 @@ def _speculative_context(all_credentials):
return None
class _CancellationContext(object):
def __init__(self):
self._cancelled = False
def cancel(self):
"""Cancel this context."""
self._cancelled = True
@property
def cancelled(self):
"""Was cancel called?"""
return self._cancelled
class SocketInfo(object):
"""Store a socket with some metadata.
@ -521,13 +535,34 @@ class SocketInfo(object):
# sockets created before the last reset.
self.generation = pool.generation
self.ready = False
self.cancel_context = None
if not pool.handshake:
# This is a Monitor connection.
self.cancel_context = _CancellationContext()
self.opts = pool.opts
self.more_to_come = False
def ismaster(self, metadata, cluster_time, all_credentials=None):
def ismaster(self, all_credentials=None):
return self._ismaster(None, None, None, all_credentials)
def _ismaster(self, cluster_time, topology_version,
heartbeat_frequency, all_credentials):
cmd = SON([('ismaster', 1)])
if not self.performed_handshake:
cmd['client'] = metadata
performing_handshake = not self.performed_handshake
awaitable = False
if performing_handshake:
self.performed_handshake = True
cmd['client'] = self.opts.metadata
if self.compression_settings:
cmd['compression'] = self.compression_settings.compressors
elif topology_version is not None:
cmd['topologyVersion'] = topology_version
cmd['maxAwaitTimeMS'] = int(heartbeat_frequency*1000)
awaitable = True
# If connect_timeout is None there is no timeout.
if self.opts.connect_timeout:
self.sock.settimeout(
self.opts.connect_timeout + heartbeat_frequency)
if self.max_wire_version >= 6 and cluster_time is not None:
cmd['$clusterTime'] = cluster_time
@ -541,7 +576,9 @@ class SocketInfo(object):
if auth_ctx:
cmd['speculativeAuthenticate'] = auth_ctx.speculate_command()
ismaster = IsMaster(self.command('admin', cmd, publish_events=False))
doc = self.command('admin', cmd, publish_events=False,
exhaust_allowed=awaitable)
ismaster = IsMaster(doc, awaitable=awaitable)
self.is_writable = ismaster.is_writable
self.max_wire_version = ismaster.max_wire_version
self.max_bson_size = ismaster.max_bson_size
@ -550,12 +587,11 @@ class SocketInfo(object):
self.supports_sessions = (
ismaster.logical_session_timeout_minutes is not None)
self.is_mongos = ismaster.server_type == SERVER_TYPE.Mongos
if not self.performed_handshake and self.compression_settings:
if performing_handshake and self.compression_settings:
ctx = self.compression_settings.get_compression_context(
ismaster.compressors)
self.compression_context = ctx
self.performed_handshake = True
self.op_msg_enabled = ismaster.max_wire_version >= 6
if creds:
self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs
@ -565,6 +601,14 @@ class SocketInfo(object):
self.auth_ctx[auth_ctx.credentials] = auth_ctx
return ismaster
def _next_reply(self):
reply = self.receive_message(None)
self.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response()
response_doc = unpacked_docs[0]
helpers._check_command_response(response_doc)
return response_doc
def command(self, dbname, spec, slave_ok=False,
read_preference=ReadPreference.PRIMARY,
codec_options=DEFAULT_CODEC_OPTIONS, check=True,
@ -577,7 +621,8 @@ class SocketInfo(object):
client=None,
retryable_write=False,
publish_events=True,
user_fields=None):
user_fields=None,
exhaust_allowed=False):
"""Execute a command or raise an error.
:Parameters:
@ -635,7 +680,7 @@ class SocketInfo(object):
if self.op_msg_enabled:
self._raise_if_not_writable(unacknowledged)
try:
return command(self.sock, dbname, spec, slave_ok,
return command(self, dbname, spec, slave_ok,
self.is_mongos, read_preference, codec_options,
session, client, check, allowable_errors,
self.address, check_keys, listeners,
@ -645,7 +690,8 @@ class SocketInfo(object):
compression_ctx=self.compression_context,
use_op_msg=self.op_msg_enabled,
unacknowledged=unacknowledged,
user_fields=user_fields)
user_fields=user_fields,
exhaust_allowed=exhaust_allowed)
except OperationFailure:
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
@ -675,8 +721,7 @@ class SocketInfo(object):
If any exception is raised, the socket is closed.
"""
try:
return receive_message(self.sock, request_id,
self.max_message_size)
return receive_message(self, request_id, self.max_message_size)
except BaseException as error:
self._raise_connection_failure(error)
@ -785,19 +830,27 @@ class SocketInfo(object):
def close_socket(self, reason):
"""Close this connection with a reason."""
if self.closed:
return
self._close_socket()
if reason and self.enabled_for_cmap:
self.listeners.publish_connection_closed(
self.address, self.id, reason)
def _close_socket(self):
"""Close this connection."""
if self.closed:
return
self.closed = True
# Avoid exceptions on interpreter shutdown.
if self.cancel_context:
self.cancel_context.cancel()
# Note: We catch exceptions to avoid spurious errors on interpreter
# shutdown.
try:
self.sock.close()
except Exception:
pass
if reason and self.enabled_for_cmap:
self.listeners.publish_connection_closed(
self.address, self.id, reason)
def socket_closed(self):
"""Return True if we know socket has been closed, False otherwise."""
return self.socket_checker.socket_closed(self.sock)
@ -1134,7 +1187,7 @@ class Pool:
sock_info = SocketInfo(sock, self, self.address, conn_id)
if self.handshake:
sock_info.ismaster(self.opts.metadata, None, all_credentials)
sock_info.ismaster(all_credentials)
self.is_writable = sock_info.is_writable
return sock_info

View File

@ -235,7 +235,8 @@ class ServerDescription(object):
(self._election_id == other.election_id) and
(self._primary == other.primary) and
(self._ls_timeout_minutes ==
other.logical_session_timeout_minutes))
other.logical_session_timeout_minutes) and
(self._error == other.error))
return NotImplemented

View File

@ -38,7 +38,10 @@ class SocketChecker(object):
self._poller = None
def select(self, sock, read=False, write=False, timeout=0):
"""Select for reads or writes with a timeout in seconds."""
"""Select for reads or writes with a timeout in seconds.
Returns True if the socket is readable/writable, False on timeout.
"""
while True:
try:
if self._poller:
@ -52,23 +55,31 @@ class SocketChecker(object):
# poll() timeout is in milliseconds. select()
# timeout is in seconds.
res = self._poller.poll(timeout * 1000)
# poll returns a possibly-empty list containing
# (fd, event) 2-tuples for the descriptors that have
# events or errors to report. Return True if the list
# is not empty.
return bool(res)
finally:
self._poller.unregister(sock)
else:
rlist = [sock] if read else []
wlist = [sock] if write else []
res = select.select(rlist, wlist, [sock], timeout)
# select returns a 3-tuple of lists of objects that are
# ready: subsets of the first three arguments. Return
# True if any of the lists are not empty.
return any(res)
except (_SelectError, IOError) as exc:
if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN):
continue
raise
return res
def socket_closed(self, sock):
"""Return True if we know socket has been closed, False otherwise.
"""
try:
res = self.select(sock, read=True)
return self.select(sock, read=True)
except (RuntimeError, KeyError):
# RuntimeError is raised during a concurrent poll. KeyError
# is raised by unregister if the socket is not in the poller.
@ -84,4 +95,3 @@ class SocketChecker(object):
# Any other exceptions should be attributed to a closed
# or invalid socket.
return True
return any(res)

View File

@ -596,6 +596,10 @@ class Topology(object):
self._process_change(ServerDescription(address, error=error))
# Clear the pool.
server.reset()
# "When a client marks a server Unknown from `Network error when
# reading or writing`_, clients MUST cancel the isMaster check on
# that server and close the current monitoring connection."
server._monitor.cancel_check()
elif issubclass(exc_type, OperationFailure):
# Do not request an immediate check since the server is likely
# shutting down.

View File

@ -226,6 +226,8 @@ class ClientContext(object):
self.connection_attempts.append(
'failed to connect client %r: %s' % (client, exc))
return None
finally:
client.close()
def _init_client(self):
self.client = self._connect(host, port)
@ -602,6 +604,14 @@ class ClientContext(object):
"failCommand fail point must be supported",
func=func)
def require_failCommand_appName(self, func):
"""Run a test only if the server supports the failCommand appName."""
# SERVER-47195
return self._require(lambda: (self.test_commands_enabled and
self.version >= (4, 4, -1)),
"failCommand appName must be supported",
func=func)
def require_tls(self, func):
"""Run a test only if the client can connect over TLS."""
return self._require(lambda: self.tls,
@ -788,10 +798,14 @@ def _get_executors(topology):
executors = []
for server in topology._servers.values():
# Some MockMonitor do not have an _executor.
executors.append(getattr(server._monitor, '_executor', None))
if hasattr(server._monitor, '_executor'):
executors.append(server._monitor._executor)
if hasattr(server._monitor, '_rtt_monitor'):
executors.append(server._monitor._rtt_monitor._executor)
executors.append(topology._Topology__events_executor)
if topology._srv_monitor:
executors.append(topology._srv_monitor._executor)
return [e for e in executors if e is not None]

View File

@ -0,0 +1,130 @@
{
"runOn": [
{
"minServerVersion": "4.0",
"topology": [
"replicaset"
]
},
{
"minServerVersion": "4.2",
"topology": [
"sharded"
]
}
],
"database_name": "sdam-tests",
"collection_name": "cancel-server-check",
"data": [],
"tests": [
{
"description": "Cancel server check",
"clientOptions": {
"retryWrites": true,
"heartbeatFrequencyMS": 10000,
"serverSelectionTimeoutMS": 5000,
"appname": "cancelServerCheckTest"
},
"operations": [
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 1
}
}
},
{
"name": "configureFailPoint",
"object": "testRunner",
"arguments": {
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"closeConnection": true
}
}
}
},
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 2
}
},
"result": {
"insertedId": 2
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 3
}
},
"result": {
"insertedId": 3
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
}
]
}
}
}
]
}

View File

@ -0,0 +1,144 @@
{
"runOn": [
{
"minServerVersion": "4.4"
}
],
"database_name": "sdam-tests",
"collection_name": "find-network-error",
"data": [
{
"_id": 1
},
{
"_id": 2
}
],
"tests": [
{
"description": "Reset server and pool after network error on find",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"find"
],
"closeConnection": true,
"appName": "findNetworkErrorTest"
}
},
"clientOptions": {
"retryWrites": false,
"retryReads": false,
"appname": "findNetworkErrorTest"
},
"operations": [
{
"name": "find",
"object": "collection",
"arguments": {
"filter": {
"_id": 1
}
},
"error": true
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 5
},
{
"_id": 6
}
]
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"find": "find-network-error"
},
"command_name": "find",
"database_name": "sdam-tests"
}
},
{
"command_started_event": {
"command": {
"insert": "find-network-error",
"documents": [
{
"_id": 5
},
{
"_id": 6
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 5
},
{
"_id": 6
}
]
}
}
}
]
}

View File

@ -0,0 +1,168 @@
{
"runOn": [
{
"minServerVersion": "4.4"
}
],
"database_name": "sdam-tests",
"collection_name": "find-shutdown-error",
"data": [],
"tests": [
{
"description": "Concurrent shutdown error on find",
"clientOptions": {
"retryWrites": false,
"retryReads": false,
"heartbeatFrequencyMS": 500,
"appname": "shutdownErrorFindTest"
},
"operations": [
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 1
}
}
},
{
"name": "configureFailPoint",
"object": "testRunner",
"arguments": {
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"find"
],
"appName": "shutdownErrorFindTest",
"errorCode": 91,
"blockConnection": true,
"blockTimeMS": 500
}
}
}
},
{
"name": "startThread",
"object": "testRunner",
"arguments": {
"name": "thread1"
}
},
{
"name": "startThread",
"object": "testRunner",
"arguments": {
"name": "thread2"
}
},
{
"name": "runOnThread",
"object": "testRunner",
"arguments": {
"name": "thread1",
"operation": {
"name": "find",
"object": "collection",
"arguments": {
"filter": {
"_id": 1
}
},
"error": true
}
}
},
{
"name": "runOnThread",
"object": "testRunner",
"arguments": {
"name": "thread2",
"operation": {
"name": "find",
"object": "collection",
"arguments": {
"filter": {
"_id": 1
}
},
"error": true
}
}
},
{
"name": "waitForThread",
"object": "testRunner",
"arguments": {
"name": "thread1"
}
},
{
"name": "waitForThread",
"object": "testRunner",
"arguments": {
"name": "thread2"
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 4
}
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 4
}
]
}
}
}
]
}

View File

@ -0,0 +1,156 @@
{
"runOn": [
{
"minServerVersion": "4.4"
}
],
"database_name": "sdam-tests",
"collection_name": "insert-network-error",
"data": [
{
"_id": 1
},
{
"_id": 2
}
],
"tests": [
{
"description": "Reset server and pool after network error on insert",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 1
},
"data": {
"failCommands": [
"insert"
],
"closeConnection": true,
"appName": "insertNetworkErrorTest"
}
},
"clientOptions": {
"retryWrites": false,
"appname": "insertNetworkErrorTest"
},
"operations": [
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"error": true
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 5
},
{
"_id": 6
}
]
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "insert-network-error",
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
},
{
"command_started_event": {
"command": {
"insert": "insert-network-error",
"documents": [
{
"_id": 5
},
{
"_id": 6
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 5
},
{
"_id": 6
}
]
}
}
}
]
}

View File

@ -0,0 +1,167 @@
{
"runOn": [
{
"minServerVersion": "4.4"
}
],
"database_name": "sdam-tests",
"collection_name": "insert-shutdown-error",
"data": [],
"tests": [
{
"description": "Concurrent shutdown error on insert",
"clientOptions": {
"retryWrites": false,
"heartbeatFrequencyMS": 500,
"appname": "shutdownErrorInsertTest"
},
"operations": [
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 1
}
}
},
{
"name": "configureFailPoint",
"object": "testRunner",
"arguments": {
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"insert"
],
"appName": "shutdownErrorInsertTest",
"errorCode": 91,
"blockConnection": true,
"blockTimeMS": 500
}
}
}
},
{
"name": "startThread",
"object": "testRunner",
"arguments": {
"name": "thread1"
}
},
{
"name": "startThread",
"object": "testRunner",
"arguments": {
"name": "thread2"
}
},
{
"name": "runOnThread",
"object": "testRunner",
"arguments": {
"name": "thread1",
"operation": {
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 2
}
},
"error": true
}
}
},
{
"name": "runOnThread",
"object": "testRunner",
"arguments": {
"name": "thread2",
"operation": {
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 3
}
},
"error": true
}
}
},
{
"name": "waitForThread",
"object": "testRunner",
"arguments": {
"name": "thread1"
}
},
{
"name": "waitForThread",
"object": "testRunner",
"arguments": {
"name": "thread2"
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertOne",
"object": "collection",
"arguments": {
"document": {
"_id": 4
}
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 4
}
]
}
}
}
]
}

View File

@ -0,0 +1,245 @@
{
"runOn": [
{
"minServerVersion": "4.4"
}
],
"database_name": "sdam-tests",
"collection_name": "isMaster-command-error",
"data": [],
"tests": [
{
"description": "Command error on Monitor handshake",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"isMaster"
],
"appName": "commandErrorHandshakeTest",
"closeConnection": false,
"errorCode": 91
}
},
"clientOptions": {
"retryWrites": false,
"connectTimeoutMS": 250,
"heartbeatFrequencyMS": 500,
"appname": "commandErrorHandshakeTest"
},
"operations": [
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "isMaster-command-error",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
}
]
}
}
},
{
"description": "Command error on Monitor check",
"clientOptions": {
"retryWrites": false,
"connectTimeoutMS": 1000,
"heartbeatFrequencyMS": 500,
"appname": "commandErrorCheckTest"
},
"operations": [
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
},
{
"name": "configureFailPoint",
"object": "testRunner",
"arguments": {
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"isMaster"
],
"appName": "commandErrorCheckTest",
"closeConnection": false,
"blockConnection": true,
"blockTimeMS": 750,
"errorCode": 91
}
}
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "isMaster-command-error",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
},
{
"command_started_event": {
"command": {
"insert": "isMaster-command-error",
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
}
]
}
}
}
]
}

View File

@ -0,0 +1,225 @@
{
"runOn": [
{
"minServerVersion": "4.4"
}
],
"database_name": "sdam-tests",
"collection_name": "isMaster-network-error",
"data": [],
"tests": [
{
"description": "Network error on Monitor handshake",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"isMaster"
],
"appName": "networkErrorHandshakeTest",
"closeConnection": true
}
},
"clientOptions": {
"retryWrites": false,
"connectTimeoutMS": 250,
"heartbeatFrequencyMS": 500,
"appname": "networkErrorHandshakeTest"
},
"operations": [
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "isMaster-network-error",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
}
]
}
}
},
{
"description": "Network error on Monitor check",
"clientOptions": {
"retryWrites": false,
"connectTimeoutMS": 250,
"heartbeatFrequencyMS": 500,
"appname": "networkErrorCheckTest"
},
"operations": [
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
},
{
"name": "configureFailPoint",
"object": "testRunner",
"arguments": {
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"isMaster"
],
"appName": "networkErrorCheckTest",
"closeConnection": true
}
}
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "isMaster-network-error",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
},
{
"command_started_event": {
"command": {
"insert": "isMaster-network-error",
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
}
]
}
}
}
]
}

View File

@ -0,0 +1,359 @@
{
"runOn": [
{
"minServerVersion": "4.4"
}
],
"database_name": "sdam-tests",
"collection_name": "isMaster-timeout",
"data": [],
"tests": [
{
"description": "Network timeout on Monitor handshake",
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"isMaster"
],
"appName": "timeoutMonitorHandshakeTest",
"blockConnection": true,
"blockTimeMS": 1000
}
},
"clientOptions": {
"retryWrites": false,
"connectTimeoutMS": 250,
"heartbeatFrequencyMS": 500,
"appname": "timeoutMonitorHandshakeTest"
},
"operations": [
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "isMaster-timeout",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
}
]
}
}
},
{
"description": "Network timeout on Monitor check",
"clientOptions": {
"retryWrites": false,
"connectTimeoutMS": 750,
"heartbeatFrequencyMS": 500,
"appname": "timeoutMonitorCheckTest"
},
"operations": [
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
},
{
"name": "configureFailPoint",
"object": "testRunner",
"arguments": {
"failPoint": {
"configureFailPoint": "failCommand",
"mode": {
"times": 2
},
"data": {
"failCommands": [
"isMaster"
],
"appName": "timeoutMonitorCheckTest",
"blockConnection": true,
"blockTimeMS": 1000
}
}
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "waitForEvent",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 1
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 1
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "isMaster-timeout",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
},
{
"command_started_event": {
"command": {
"insert": "isMaster-timeout",
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
}
]
}
}
},
{
"description": "Driver extends timeout while streaming",
"clientOptions": {
"retryWrites": false,
"connectTimeoutMS": 250,
"heartbeatFrequencyMS": 500,
"appname": "extendsTimeoutTest"
},
"operations": [
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
}
},
{
"name": "wait",
"object": "testRunner",
"arguments": {
"ms": 2000
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "ServerMarkedUnknownEvent",
"count": 0
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 0
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "isMaster-timeout",
"documents": [
{
"_id": 1
},
{
"_id": 2
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
},
{
"command_started_event": {
"command": {
"insert": "isMaster-timeout",
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
}
]
}
}
}
]
}

View File

@ -0,0 +1,165 @@
{
"runOn": [
{
"minServerVersion": "4.4",
"topology": [
"replicaset"
]
}
],
"database_name": "sdam-tests",
"collection_name": "test-replSetStepDown",
"data": [
{
"_id": 1
},
{
"_id": 2
}
],
"tests": [
{
"description": "Rediscover quickly after replSetStepDown",
"clientOptions": {
"appname": "replSetStepDownTest",
"heartbeatFrequencyMS": 60000,
"serverSelectionTimeoutMS": 5000,
"w": "majority"
},
"operations": [
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
}
},
{
"name": "recordPrimary",
"object": "testRunner"
},
{
"name": "runAdminCommand",
"object": "testRunner",
"command_name": "replSetFreeze",
"arguments": {
"command": {
"replSetFreeze": 0
},
"readPreference": {
"mode": "Secondary"
}
}
},
{
"name": "runAdminCommand",
"object": "testRunner",
"command_name": "replSetStepDown",
"arguments": {
"command": {
"replSetStepDown": 20,
"secondaryCatchUpPeriodSecs": 20,
"force": false
}
}
},
{
"name": "waitForPrimaryChange",
"object": "testRunner",
"arguments": {
"timeoutMS": 5000
}
},
{
"name": "insertMany",
"object": "collection",
"arguments": {
"documents": [
{
"_id": 5
},
{
"_id": 6
}
]
}
},
{
"name": "assertEventCount",
"object": "testRunner",
"arguments": {
"event": "PoolClearedEvent",
"count": 0
}
}
],
"expectations": [
{
"command_started_event": {
"command": {
"insert": "test-replSetStepDown",
"documents": [
{
"_id": 3
},
{
"_id": 4
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
},
{
"command_started_event": {
"command": {
"insert": "test-replSetStepDown",
"documents": [
{
"_id": 5
},
{
"_id": 6
}
]
},
"command_name": "insert",
"database_name": "sdam-tests"
}
}
],
"outcome": {
"collection": {
"data": [
{
"_id": 1
},
{
"_id": 2
},
{
"_id": 3
},
{
"_id": 4
},
{
"_id": 5
},
{
"_id": 6
}
]
}
}
}
]
}

View File

@ -816,6 +816,19 @@ class TestClient(IntegrationTest):
client.close()
self.assertEqual(topology._servers, {})
def test_close_closes_sockets(self):
client = rs_client()
self.addCleanup(client.close)
client.test.test.find_one()
topology = client._topology
client.close()
for server in topology._servers.values():
self.assertFalse(server._pool.sockets)
self.assertTrue(server._monitor._executor._stopped)
self.assertTrue(server._monitor._rtt_monitor._executor._stopped)
self.assertFalse(server._monitor._pool.sockets)
self.assertFalse(server._monitor._rtt_monitor._pool.sockets)
def test_bad_uri(self):
with self.assertRaises(InvalidURI):
MongoClient("http://localhost")
@ -1636,12 +1649,12 @@ class TestExhaustCursor(IntegrationTest):
msg += encode({'$err': 'mock err', 'code': 0})
return message._OpReply.unpack(msg)
saved = sock_info.receive_message
sock_info.receive_message = receive_message
self.assertRaises(OperationFailure, list, cursor)
sock_info.receive_message = saved
# Unpatch the instance.
del sock_info.receive_message
# The socket is returned the pool and it still works.
# The socket is returned to the pool and it still works.
self.assertEqual(200, collection.count_documents({}))
self.assertIn(sock_info, pool.sockets)

View File

@ -50,6 +50,7 @@ from test.utils import (camel_to_snake,
single_client,
TestCreator,
wait_until)
from test.utils_spec_runner import SpecRunnerThread
OBJECT_TYPES = {
@ -70,40 +71,6 @@ OBJECT_TYPES = {
}
class CMAPThread(threading.Thread):
def __init__(self, name):
super(CMAPThread, self).__init__()
self.name = name
self.exc = None
self.setDaemon(True)
self.cond = threading.Condition()
self.ops = []
self.stopped = False
def schedule(self, work):
self.ops.append(work)
with self.cond:
self.cond.notify()
def stop(self):
self.stopped = True
with self.cond:
self.cond.notify()
def run(self):
while not self.stopped or self.ops:
if not self. ops:
with self.cond:
self.cond.wait(10)
if self.ops:
try:
work = self.ops.pop(0)
work()
except Exception as exc:
self.exc = exc
self.stop()
class TestCMAP(IntegrationTest):
# Location of JSON test specifications.
TEST_PATH = os.path.join(
@ -114,7 +81,7 @@ class TestCMAP(IntegrationTest):
def start(self, op):
"""Run the 'start' thread operation."""
target = op['target']
thread = CMAPThread(target)
thread = SpecRunnerThread(target)
thread.start()
self.targets[target] = thread
@ -344,6 +311,8 @@ class TestCMAP(IntegrationTest):
def mock_connect(*args, **kwargs):
raise ConnectionFailure('connect failed')
pool.connect = mock_connect
# Un-patch Pool.connect to break the cyclic reference.
self.addCleanup(delattr, pool, 'connect')
# Attempt to create a new connection.
with self.assertRaisesRegex(ConnectionFailure, 'connect failed'):
@ -374,6 +343,8 @@ class TestCMAP(IntegrationTest):
def mock_connect(*args, **kwargs):
sock_info = connect(*args, **kwargs)
sock_info.check_auth = functools.partial(mock_check_auth, sock_info)
# Un-patch to break the cyclic reference.
self.addCleanup(delattr, sock_info, 'check_auth')
return sock_info
pool.connect = mock_connect
# Un-patch Pool.connect to break the cyclic reference.

View File

@ -28,6 +28,7 @@ from test import (client_context,
IntegrationTest)
from test.utils import (CMAPListener,
ensure_all_connected,
repl_set_step_down,
rs_or_single_client)
@ -38,7 +39,8 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
super(TestConnectionsSurvivePrimaryStepDown, cls).setUpClass()
cls.listener = CMAPListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener],
retryWrites=False)
retryWrites=False,
heartbeatFrequencyMS=500)
# Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for sockets that
@ -84,9 +86,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
for _ in range(batch_size):
cursor.next()
# Force step-down the primary.
res = self.client.admin.command(
SON([("replSetStepDown", 5), ("force", True)]))
self.assertEqual(res["ok"], 1.0)
repl_set_step_down(self.client, replSetStepDown=5, force=True)
# Get next batch of results.
for _ in range(batch_size):
cursor.next()

View File

@ -17,11 +17,13 @@
import os
import sys
import threading
import time
sys.path[0:0] = [""]
from bson import json_util, Timestamp
from pymongo import common
from pymongo import (common,
monitoring)
from pymongo.errors import (AutoReconnect,
ConfigurationError,
NetworkTimeout,
@ -36,11 +38,14 @@ from pymongo.topology_description import TOPOLOGY_TYPE
from pymongo.uri_parser import parse_uri
from test import unittest, IntegrationTest
from test.utils import (assertion_context,
client_context,
Barrier,
get_pool,
server_name_to_type,
rs_or_single_client,
TestCreator,
wait_until)
from test.utils_spec_runner import SpecRunner, SpecRunnerThread
# Location of JSON test specifications.
@ -51,7 +56,9 @@ _TEST_PATH = os.path.join(
class MockMonitor(object):
def __init__(self, server_description, topology, pool, topology_settings):
self._server_description = server_description
self._topology = topology
def cancel_check(self):
pass
def open(self):
pass
@ -305,5 +312,104 @@ class TestIgnoreStaleErrors(IntegrationTest):
client.admin.command('ping')
class TestIntegration(SpecRunner):
# Location of JSON test specifications.
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
'discovery_and_monitoring_integration')
def _event_count(self, event):
if event == 'ServerMarkedUnknownEvent':
def marked_unknown(e):
return (isinstance(e, monitoring.ServerDescriptionChangedEvent)
and not e.new_description.is_server_type_known)
return len(self.server_listener.matching(marked_unknown))
# Only support CMAP events for now.
self.assertTrue(event.startswith('Pool') or event.startswith('Conn'))
event_type = getattr(monitoring, event)
return self.pool_listener.event_count(event_type)
def assert_event_count(self, event, count):
"""Run the assertEventCount test operation.
Assert the given event was published exactly `count` times.
"""
self.assertEqual(self._event_count(event), count)
def wait_for_event(self, event, count):
"""Run the waitForEvent test operation.
Wait for a number of events to be published, or fail.
"""
wait_until(lambda: self._event_count(event) >= count,
'find %s %s event(s)' % (count, event))
def configure_fail_point(self, fail_point):
"""Run the configureFailPoint test operation.
"""
self.set_fail_point(fail_point)
self.addCleanup(self.set_fail_point, {
'configureFailPoint': fail_point['configureFailPoint'],
'mode': 'off'})
def run_admin_command(self, command, **kwargs):
"""Run the runAdminCommand test operation.
"""
self.client.admin.command(command, **kwargs)
def record_primary(self):
"""Run the recordPrimary test operation.
"""
self._previous_primary = self.scenario_client.primary
def wait_for_primary_change(self, timeout_ms):
"""Run the waitForPrimaryChange test operation.
"""
def primary_changed():
primary = self.scenario_client.primary
if primary is None:
return False
return primary != self._previous_primary
timeout = timeout_ms/1000.0
wait_until(primary_changed, 'change primary', timeout=timeout)
def wait(self, ms):
"""Run the "wait" test operation.
"""
time.sleep(ms/1000.0)
def start_thread(self, name):
"""Run the 'startThread' thread operation."""
thread = SpecRunnerThread(name)
thread.start()
self.targets[name] = thread
def run_on_thread(self, sessions, collection, name, operation):
"""Run the 'runOnThread' operation."""
thread = self.targets[name]
thread.schedule(lambda: self._run_op(
sessions, collection, operation, False))
def wait_for_thread(self, name):
"""Run the 'waitForThread' operation."""
thread = self.targets[name]
thread.stop()
thread.join()
if thread.exc:
raise thread.exc
def create_spec_test(scenario_def, test, name):
@client_context.require_test_commands
def run_scenario(self):
self.run_scenario(scenario_def, test)
return run_scenario
test_creator = TestCreator(create_spec_test, TestIntegration, TestIntegration.TEST_PATH)
test_creator.create_tests()
if __name__ == "__main__":
unittest.main()

View File

@ -38,6 +38,7 @@ def get_executors(client):
executors = []
for server in client._topology._servers.values():
executors.append(server._monitor._executor)
executors.append(server._monitor._rtt_monitor._executor)
executors.append(client._kill_cursors_executor)
executors.append(client._topology._Topology__events_executor)
return [e for e in executors if e is not None]
@ -54,7 +55,7 @@ class TestMonitor(IntegrationTest):
def test_cleanup_executors_on_client_del(self):
client = create_client()
executors = get_executors(client)
self.assertEqual(len(executors), 3)
self.assertEqual(len(executors), 4)
# Each executor stores a weakref to itself in _EXECUTORS.
executor_refs = [
@ -71,7 +72,7 @@ class TestMonitor(IntegrationTest):
def test_cleanup_executors_on_client_close(self):
client = create_client()
executors = get_executors(client)
self.assertEqual(len(executors), 3)
self.assertEqual(len(executors), 4)
client.close()

View File

@ -320,7 +320,7 @@ class TestSdamMonitoring(IntegrationTest):
# Expect a single ServerDescriptionChangedEvent for the network error.
marked_unknown_events = self.listener.matching(marked_unknown)
self.assertEqual(len(marked_unknown_events), 1)
self.assertEqual(len(marked_unknown_events), 1, marked_unknown_events)
self.assertIsInstance(
marked_unknown_events[0].new_description.error, expected_error)

View File

@ -754,6 +754,7 @@ class TestSession(IntegrationTest):
# Ensure the collection exists.
self.client.pymongo_test.test_unacked_writes.insert_one({})
client = rs_or_single_client(w=0, event_listeners=[self.listener])
self.addCleanup(client.close)
db = client.pymongo_test
coll = db.test_unacked_writes
ops = [

View File

@ -0,0 +1,183 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the database module."""
import sys
import time
sys.path[0:0] = [""]
from pymongo import monitoring
from test import (client_context,
IntegrationTest,
unittest)
from test.utils import (HeartbeatEventListener,
rs_or_single_client,
ServerEventListener,
wait_until)
class TestStreamingProtocol(IntegrationTest):
@client_context.require_failCommand_appName
def test_failCommand_streaming(self):
listener = ServerEventListener()
hb_listener = HeartbeatEventListener()
client = rs_or_single_client(
event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500,
appName='failingIsMasterTest')
self.addCleanup(client.close)
# Force a connection.
client.admin.command('ping')
address = client.address
listener.reset()
fail_ismaster = {
'configureFailPoint': 'failCommand',
'mode': {'times': 4},
'data': {
'failCommands': ['isMaster'],
'closeConnection': False,
'errorCode': 10107,
'appName': 'failingIsMasterTest',
},
}
with self.fail_point(fail_ismaster):
def _marked_unknown(event):
return (event.server_address == address
and not event.new_description.is_server_type_known)
def _discovered_node(event):
return (event.server_address == address
and not event.previous_description.is_server_type_known
and event.new_description.is_server_type_known)
def marked_unknown():
return len(listener.matching(_marked_unknown)) >= 1
def rediscovered():
return len(listener.matching(_discovered_node)) >= 1
# Topology events are published asynchronously
wait_until(marked_unknown, 'mark node unknown')
wait_until(rediscovered, 'rediscover node')
# Server should be selectable.
client.admin.command('ping')
@client_context.require_failCommand_appName
def test_streaming_rtt(self):
listener = ServerEventListener()
hb_listener = HeartbeatEventListener()
# On Windows, RTT can actually be 0.0 because time.time() only has
# 1-15 millisecond resolution. We need to delay the initial isMaster
# to ensure that RTT is never zero.
name = 'streamingRttTest'
delay_ismaster = {
'configureFailPoint': 'failCommand',
'mode': {'times': 1000},
'data': {
'failCommands': ['isMaster'],
'blockConnection': True,
'blockTimeMS': 20,
# This can be uncommented after SERVER-49220 is fixed.
# 'appName': name,
},
}
with self.fail_point(delay_ismaster):
client = rs_or_single_client(
event_listeners=[listener, hb_listener],
heartbeatFrequencyMS=500,
appName=name)
self.addCleanup(client.close)
# Force a connection.
client.admin.command('ping')
address = client.address
delay_ismaster['data']['blockTimeMS'] = 500
delay_ismaster['data']['appName'] = name
with self.fail_point(delay_ismaster):
def rtt_exceeds_250_ms():
# XXX: Add a public TopologyDescription getter to MongoClient?
topology = client._topology
sd = topology.description.server_descriptions()[address]
return sd.round_trip_time > 0.250
wait_until(rtt_exceeds_250_ms, 'exceed 250ms RTT')
# Server should be selectable.
client.admin.command('ping')
def changed_event(event):
return (event.server_address == address and isinstance(
event, monitoring.ServerDescriptionChangedEvent))
# There should only be one event published, for the initial discovery.
events = listener.matching(changed_event)
self.assertEqual(1, len(events))
self.assertGreater(events[0].new_description.round_trip_time, 0)
@client_context.require_failCommand_appName
def test_monitor_waits_after_server_check_error(self):
hb_listener = HeartbeatEventListener()
client = rs_or_single_client(
event_listeners=[hb_listener], heartbeatFrequencyMS=500,
appName='waitAfterErrorTest')
self.addCleanup(client.close)
# Force a connection.
client.admin.command('ping')
address = client.address
fail_ismaster = {
'mode': {'times': 50},
'data': {
'failCommands': ['isMaster'],
'closeConnection': False,
'errorCode': 91,
# This can be uncommented after SERVER-49220 is fixed.
# 'appName': 'waitAfterErrorTest',
},
}
with self.fail_point(fail_ismaster):
time.sleep(2)
# Server should be selectable.
client.admin.command('ping')
def hb_started(event):
return (isinstance(event, monitoring.ServerHeartbeatStartedEvent)
and event.connection_id == address)
hb_started_events = hb_listener.matching(hb_started)
# Explanation of the expected heartbeat events:
# Time: event
# 0ms: create MongoClient
# 1ms: run monitor handshake, 1
# 2ms: run awaitable isMaster, 2
# 3ms: run configureFailPoint
# 502ms: isMaster fails for the first time with command error
# 1002ms: run monitor handshake, 3
# 1502ms: run monitor handshake, 4
# 2002ms: run monitor handshake, 5
# 2003ms: disable configureFailPoint
# 2004ms: isMaster succeeds, 6
# 2004ms: awaitable isMaster, 7
self.assertGreater(len(hb_started_events), 7)
# This can be reduced to ~15 after SERVER-49220 is fixed.
self.assertLess(len(hb_started_events), 40)
if __name__ == "__main__":
unittest.main()

View File

@ -42,9 +42,11 @@ from test.utils import MockPool, wait_until
class MockMonitor(object):
def __init__(self, server_description, topology, pool, topology_settings):
self._server_description = server_description
self._topology = topology
self.opened = False
def cancel_check(self):
pass
def open(self):
self.opened = True
@ -232,6 +234,7 @@ class TestSingleServerTopology(TopologyTest):
raise AutoReconnect('mock monitor error')
t = create_mock_topology(monitor_class=TestMonitor)
self.addCleanup(t.close)
s = t.select_server(writable_server_selector)
self.assertEqual(125, s.description.round_trip_time)
@ -712,6 +715,7 @@ class TestTopologyErrors(TopologyTest):
raise AutoReconnect('mock monitor error')
t = create_mock_topology(monitor_class=TestMonitor)
self.addCleanup(t.close)
server = wait_for_master(t)
self.assertEqual(1, ismaster_count[0])
generation = server.pool.generation
@ -734,18 +738,21 @@ class TestTopologyErrors(TopologyTest):
'mock monitor error #%s' % (ismaster_count[0],))
t = create_mock_topology(monitor_class=TestMonitor)
self.addCleanup(t.close)
server = wait_for_master(t)
self.assertEqual(1, ismaster_count[0])
self.assertEqual(SERVER_TYPE.Standalone,
server.description.server_type)
# Second ismaster call.
# Second ismaster call, server is marked Unknown, then the monitor
# immediately runs a retry (third ismaster).
t.request_check_all()
# The third ismaster call (the immediate retry) happens sometime soon
# after the failed check triggered by request_check_all. Wait until
# the server becomes known again.
t.select_server(writable_server_selector, 0.250)
self.assertEqual(SERVER_TYPE.Standalone, get_type(t, 'a'))
server = t.select_server(writable_server_selector, 0.250)
self.assertEqual(SERVER_TYPE.Standalone,
server.description.server_type)
self.assertEqual(3, ismaster_count[0])
def test_internal_monitor_error(self):
@ -756,6 +763,7 @@ class TestTopologyErrors(TopologyTest):
raise exception
t = create_mock_topology(monitor_class=TestMonitor)
self.addCleanup(t.close)
with self.assertRaisesRegex(ConnectionFailure, 'internal error'):
t.select_server(any_server_selector,
server_selection_timeout=0.5)

View File

@ -31,12 +31,14 @@ from functools import partial
from bson import json_util, py3compat
from bson.objectid import ObjectId
from bson.son import SON
from pymongo import (MongoClient,
monitoring, read_preferences)
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.monitoring import _SENSITIVE_COMMANDS, ConnectionPoolListener
from pymongo.pool import PoolOptions
from pymongo.pool import (_CancellationContext,
PoolOptions)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.server_selectors import (any_server_selector,
@ -160,8 +162,7 @@ class OvertCommandListener(EventListener):
super(OvertCommandListener, self).failed(event)
class ServerAndTopologyEventListener(monitoring.ServerListener,
monitoring.TopologyListener):
class _ServerEventListener(object):
"""Listens to all events."""
def __init__(self):
@ -185,6 +186,16 @@ class ServerAndTopologyEventListener(monitoring.ServerListener,
self.results = []
class ServerEventListener(_ServerEventListener,
monitoring.ServerListener):
"""Listens to Server events."""
class ServerAndTopologyEventListener(ServerEventListener,
monitoring.TopologyListener):
"""Listens to Server and Topology events."""
class HeartbeatEventListener(monitoring.ServerHeartbeatListener):
"""Listens to only server heartbeat events."""
@ -200,9 +211,18 @@ class HeartbeatEventListener(monitoring.ServerHeartbeatListener):
def failed(self, event):
self.results.append(event)
def matching(self, matcher):
"""Return the matching events."""
results = self.results[:]
return [event for event in results if matcher(event)]
class MockSocketInfo(object):
def close(self):
def __init__(self):
self.cancel_context = _CancellationContext()
self.more_to_come = False
def close_socket(self, reason):
pass
def __enter__(self):
@ -218,7 +238,7 @@ class MockPool(object):
self._lock = threading.Lock()
self.opts = PoolOptions()
def get_socket(self, all_credentials):
def get_socket(self, all_credentials, checkout=False):
return MockSocketInfo()
def return_socket(self, *args, **kwargs):
@ -677,6 +697,16 @@ def wait_until(predicate, success_description, timeout=10):
time.sleep(interval)
def repl_set_step_down(client, **kwargs):
"""Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
cmd = SON([('replSetStepDown', 1)])
cmd.update(kwargs)
# Unfreeze a secondary to ensure a speedy election.
client.admin.command(
'replSetFreeze', 0, read_preference=ReadPreference.SECONDARY)
client.admin.command(cmd)
def is_mongos(client):
res = client.admin.command('ismaster')
return res.get('msg', '') == 'isdbgrid'

View File

@ -36,6 +36,9 @@ class MockMonitor(object):
def __init__(self, server_description, topology, pool, topology_settings):
pass
def cancel_check(self):
pass
def open(self):
pass

View File

@ -15,7 +15,7 @@
"""Utilities for testing driver specs."""
import copy
import sys
import threading
from bson import decode, encode
@ -48,8 +48,46 @@ from test.utils import (camel_to_snake,
camel_to_snake_args,
camel_to_upper_camel,
CompareType,
CMAPListener,
OvertCommandListener,
rs_client, parse_read_preference)
parse_read_preference,
rs_client,
ServerAndTopologyEventListener,
HeartbeatEventListener)
class SpecRunnerThread(threading.Thread):
def __init__(self, name):
super(SpecRunnerThread, self).__init__()
self.name = name
self.exc = None
self.setDaemon(True)
self.cond = threading.Condition()
self.ops = []
self.stopped = False
def schedule(self, work):
self.ops.append(work)
with self.cond:
self.cond.notify()
def stop(self):
self.stopped = True
with self.cond:
self.cond.notify()
def run(self):
while not self.stopped or self.ops:
if not self. ops:
with self.cond:
self.cond.wait(10)
if self.ops:
try:
work = self.ops.pop(0)
work()
except Exception as exc:
self.exc = exc
self.stop()
class SpecRunner(IntegrationTest):
@ -60,7 +98,8 @@ class SpecRunner(IntegrationTest):
cls.mongos_clients = []
# Speed up the tests by decreasing the heartbeat frequency.
cls.knobs = client_knobs(min_heartbeat_interval=0.1)
cls.knobs = client_knobs(heartbeat_frequency=0.1,
min_heartbeat_interval=0.1)
cls.knobs.enable()
@classmethod
@ -70,7 +109,10 @@ class SpecRunner(IntegrationTest):
def setUp(self):
super(SpecRunner, self).setUp()
self.targets = {}
self.listener = None
self.pool_listener = None
self.server_listener = None
self.maxDiff = None
def _set_fail_point(self, client, command_args):
@ -315,7 +357,8 @@ class SpecRunner(IntegrationTest):
arguments["requests"] = requests
elif arg_name == "session":
arguments['session'] = sessions[arguments['session']]
elif name == 'command' and arg_name == 'command':
elif (name in ('command', 'run_admin_command') and
arg_name == 'command'):
# Ensure the first key is the command name.
ordered_command = SON([(operation['command_name'], 1)])
ordered_command.update(arguments['command'])
@ -343,6 +386,10 @@ class SpecRunner(IntegrationTest):
else:
arguments[c2s] = arguments.pop(arg_name)
if name == 'run_on_thread':
args = {'sessions': sessions, 'collection': collection}
args.update(arguments)
arguments = args
result = cmd(**dict(arguments))
if name == "aggregate":
@ -367,45 +414,48 @@ class SpecRunner(IntegrationTest):
"""Allow encryption spec to override expected error classes."""
return (PyMongoError,)
def _run_op(self, sessions, collection, op, in_with_transaction):
expected_result = op.get('result')
if expect_error(op):
with self.assertRaises(self.allowable_errors(op),
msg=op['name']) as context:
self.run_operation(sessions, collection, op.copy())
if expect_error_message(expected_result):
if isinstance(context.exception, BulkWriteError):
errmsg = str(context.exception.details).lower()
else:
errmsg = str(context.exception).lower()
self.assertIn(expected_result['errorContains'].lower(),
errmsg)
if expect_error_code(expected_result):
self.assertEqual(expected_result['errorCodeName'],
context.exception.details.get('codeName'))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(
context.exception,
expected_result['errorLabelsContain'])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(
context.exception,
expected_result['errorLabelsOmit'])
# Reraise the exception if we're in the with_transaction
# callback.
if in_with_transaction:
raise context.exception
else:
result = self.run_operation(sessions, collection, op.copy())
if 'result' in op:
if op['name'] == 'runCommand':
self.check_command_result(expected_result, result)
else:
self.check_result(expected_result, result)
def run_operations(self, sessions, collection, ops,
in_with_transaction=False):
for op in ops:
expected_result = op.get('result')
if expect_error(op):
with self.assertRaises(self.allowable_errors(op),
msg=op['name']) as context:
self.run_operation(sessions, collection, op.copy())
if expect_error_message(expected_result):
if isinstance(context.exception, BulkWriteError):
errmsg = str(context.exception.details).lower()
else:
errmsg = str(context.exception).lower()
self.assertIn(expected_result['errorContains'].lower(),
errmsg)
if expect_error_code(expected_result):
self.assertEqual(expected_result['errorCodeName'],
context.exception.details.get('codeName'))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(
context.exception,
expected_result['errorLabelsContain'])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(
context.exception,
expected_result['errorLabelsOmit'])
# Reraise the exception if we're in the with_transaction
# callback.
if in_with_transaction:
raise context.exception
else:
result = self.run_operation(sessions, collection, op.copy())
if 'result' in op:
if op['name'] == 'runCommand':
self.check_command_result(expected_result, result)
else:
self.check_result(expected_result, result)
self._run_op(sessions, collection, op, in_with_transaction)
# TODO: factor with test_command_monitoring.py
def check_events(self, test, listener, session_ids):
@ -517,7 +567,29 @@ class SpecRunner(IntegrationTest):
def run_scenario(self, scenario_def, test):
self.maybe_skip_scenario(test)
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
self.kill_all_sessions()
self.addCleanup(self.kill_all_sessions)
self.setup_scenario(scenario_def)
database_name = self.get_scenario_db_name(scenario_def)
collection_name = self.get_scenario_coll_name(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
c[database_name][collection_name].distinct("x")
# Configure the fail point before creating the client.
if 'failPoint' in test:
fp = test['failPoint']
self.set_fail_point(fp)
self.addCleanup(self.set_fail_point, {
'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})
listener = OvertCommandListener()
pool_listener = CMAPListener()
server_listener = ServerAndTopologyEventListener()
# Create a new client, to avoid interference from pooled sessions.
client_options = self.parse_client_options(test['clientOptions'])
# MMAPv1 does not support retryable writes.
@ -526,28 +598,21 @@ class SpecRunner(IntegrationTest):
self.skipTest("MMAPv1 does not support retryWrites=True")
use_multi_mongos = test['useMultipleMongoses']
if client_context.is_mongos and use_multi_mongos:
client = rs_client(client_context.mongos_seeds(),
event_listeners=[listener], **client_options)
client = rs_client(
client_context.mongos_seeds(),
event_listeners=[listener, pool_listener, server_listener],
**client_options)
else:
client = rs_client(event_listeners=[listener], **client_options)
client = rs_client(
event_listeners=[listener, pool_listener, server_listener],
**client_options)
self.scenario_client = client
self.listener = listener
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addCleanup(client.close)
# Kill all sessions before and after each test to prevent an open
# transaction (from a test failure) from blocking collection/database
# operations during test set up and tear down.
self.kill_all_sessions()
self.addCleanup(self.kill_all_sessions)
database_name = self.get_scenario_db_name(scenario_def)
collection_name = self.get_scenario_coll_name(scenario_def)
self.setup_scenario(scenario_def)
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
c[database_name][collection_name].distinct("x")
# Create session0 and session1.
sessions = {}
session_ids = {}
@ -572,14 +637,6 @@ class SpecRunner(IntegrationTest):
self.addCleanup(end_sessions, sessions)
if 'failPoint' in test:
fp = test['failPoint']
self.set_fail_point(fp)
self.addCleanup(self.set_fail_point, {
'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})
listener.results.clear()
collection = client[database_name][collection_name]
self.run_test_ops(sessions, collection, test)
@ -613,6 +670,7 @@ class SpecRunner(IntegrationTest):
# CompareType(Binary) doesn't work.
self.assertEqual(wrap_types(expected_c['data']), actual_data)
def expect_any_error(op):
if isinstance(op, dict):
return op.get('error')