PYTHON-1674 Support retryable reads

Add retryReads URI option that defaults to True.
Supported read operations will be retried once after transient
network, election, and shutdown errors on MongoDB 3.6+.
Supported operations are:
listCollections, listIndexes, and listDatabases
distinct
count, estimated_document_count, count_documents
aggregate (not including $out)
find (only the initial find command is retried; getMore commands are not
retried).
ChangeStreams: watch (initial aggregate command).
GridFS read APIs.

Test changes:
Add retryable reads spec test runner.
Disable retryable reads in network error tests.
This commit is contained in:
Shane Harvey 2019-04-12 15:15:00 -07:00
parent 0ef728acd1
commit a15266083b
16 changed files with 431 additions and 139 deletions

View File

@ -33,6 +33,13 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include:
- The ``retryWrites`` URI option now defaults to ``True``. Supported write
operations that fail with a retryable error will automatically be retried one
time, with at-most-once semantics.
- Support for retryable reads and the ``retryReads`` URI option, which is
enabled by default. See the :class:`~pymongo.mongo_client.MongoClient`
documentation for details.
Now that supported operations are retried automatically and transparently,
users should consider adjusting any custom retry logic to prevent
an application from inadvertently retrying for too long.
.. _URI options specification: https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.rst

View File

@ -118,8 +118,8 @@ class ChangeStream(object):
"""
read_preference = self._target._read_preference_for(session)
client = self._database.client
with client._socket_for_reads(
read_preference, session) as (sock_info, slave_ok):
def _cmd(session, server, sock_info, slave_ok):
pipeline = self._full_pipeline()
cmd = SON([("aggregate", self._aggregation_target),
("pipeline", pipeline),
@ -160,6 +160,8 @@ class ChangeStream(object):
max_await_time_ms=self._max_await_time_ms,
session=session, explicit_session=explicit_session)
return client._retryable_read(_cmd, read_preference, session)
def _create_cursor(self):
with self._database.client._tmp_session(self._session, close=False) as s:
return self._run_aggregation_cmd(

View File

@ -164,6 +164,7 @@ class ClientOptions(object):
self.__heartbeat_frequency = options.get(
'heartbeatfrequencyms', common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get('retrywrites', common.RETRY_WRITES)
self.__retry_reads = options.get('retryreads', common.RETRY_READS)
self.__server_selector = options.get(
'server_selector', any_server_selector)
@ -235,3 +236,8 @@ class ClientOptions(object):
def retry_writes(self):
"""If this instance should retry supported write operations."""
return self.__retry_writes
@property
def retry_reads(self):
"""If this instance should retry supported read operations."""
# The value is captured in __init__ from the 'retryreads' option and
# falls back to common.RETRY_READS when that option is not supplied.
return self.__retry_reads

View File

@ -188,12 +188,6 @@ class Collection(common.BaseObject):
return self.__database.client._socket_for_reads(
self._read_preference_for(session), session)
def _socket_for_primary_reads(self, session):
read_pref = ((session and session._txn_read_preference())
or ReadPreference.PRIMARY)
return self.__database.client._socket_for_reads(
read_pref, session), read_pref
def _socket_for_writes(self, session):
return self.__database.client._socket_for_writes(session)
@ -1572,7 +1566,7 @@ class Collection(common.BaseObject):
def _count(self, cmd, collation=None, session=None):
"""Internal count helper."""
with self._socket_for_reads(session) as (sock_info, slave_ok):
def _cmd(session, server, sock_info, slave_ok):
res = self._command(
sock_info,
cmd,
@ -1582,9 +1576,12 @@ class Collection(common.BaseObject):
read_concern=self.read_concern,
collation=collation,
session=session)
if res.get("errmsg", "") == "ns missing":
return 0
return int(res["n"])
if res.get("errmsg", "") == "ns missing":
return 0
return int(res["n"])
return self.__database.client._retryable_read(
_cmd, self._read_preference_for(session), session)
def _aggregate_one_result(
self, sock_info, slave_ok, cmd, collation=None, session=None):
@ -1693,12 +1690,16 @@ class Collection(common.BaseObject):
kwargs["hint"] = helpers._index_document(kwargs["hint"])
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads(session) as (sock_info, slave_ok):
def _cmd(session, server, sock_info, slave_ok):
result = self._aggregate_one_result(
sock_info, slave_ok, cmd, collation, session)
if not result:
return 0
return result['n']
if not result:
return 0
return result['n']
return self.__database.client._retryable_read(
_cmd, self._read_preference_for(session), session)
def count(self, filter=None, session=None, **kwargs):
"""**DEPRECATED** - Get the number of documents in this collection.
@ -2149,8 +2150,10 @@ class Collection(common.BaseObject):
codec_options = CodecOptions(SON)
coll = self.with_options(codec_options=codec_options,
read_preference=ReadPreference.PRIMARY)
sock_ctx, read_pref = self._socket_for_primary_reads(session)
with sock_ctx as (sock_info, slave_ok):
read_pref = ((session and session._txn_read_preference())
or ReadPreference.PRIMARY)
def _cmd(session, server, sock_info, slave_ok):
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
if sock_info.max_wire_version > 2:
with self.__database.client._tmp_session(session, False) as s:
@ -2179,6 +2182,9 @@ class Collection(common.BaseObject):
# will never be a getMore call.
return CommandCursor(coll, cursor, sock_info.address)
return self.__database.client._retryable_read(
_cmd, read_pref, session)
def index_information(self, session=None):
"""Get information on this collection's indexes.
@ -2275,10 +2281,11 @@ class Collection(common.BaseObject):
"useCursor", kwargs.pop("useCursor"))
batch_size = common.validate_non_negative_integer_or_none(
"batchSize", kwargs.pop("batchSize", None))
dollar_out = pipeline and '$out' in pipeline[-1]
# If the server does not support the "cursor" option we
# ignore useCursor and batchSize.
with self._socket_for_reads(session) as (sock_info, slave_ok):
dollar_out = pipeline and '$out' in pipeline[-1]
def _cmd(session, server, sock_info, slave_ok):
if use_cursor:
if "cursor" not in kwargs:
kwargs["cursor"] = {}
@ -2336,6 +2343,10 @@ class Collection(common.BaseObject):
max_await_time_ms=max_await_time_ms,
session=session, explicit_session=explicit_session)
return self.__database.client._retryable_read(
_cmd, self._read_preference_for(session), session,
retryable=not dollar_out)
def aggregate(self, pipeline, session=None, **kwargs):
"""Perform an aggregation using the aggregation framework on this
collection.
@ -2681,12 +2692,53 @@ class Collection(common.BaseObject):
kwargs["query"] = filter
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads(session) as (sock_info, slave_ok):
return self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation,
session=session,
user_fields={"values": 1})["values"]
def _cmd(session, server, sock_info, slave_ok):
return self._command(
sock_info, cmd, slave_ok, read_concern=self.read_concern,
collation=collation, session=session,
user_fields={"values": 1})["values"]
return self.__database.client._retryable_read(
_cmd, self._read_preference_for(session), session)
def _map_reduce(self, map, reduce, out, session, read_pref, **kwargs):
    """Internal mapReduce helper.

    Builds the ``mapReduce`` command and runs it on a socket selected
    for reads. This helper is not routed through the retryable read
    path.

    :Parameters:
      - `map`: the map function, sent as the command's ``map`` field.
      - `reduce`: the reduce function, sent as the ``reduce`` field.
      - `out`: output target, either a collection name or a mapping
        such as ``{"inline": 1}``.
      - `session`: a session or None. When the session supplies a
        transaction read preference it overrides `read_pref`.
      - `read_pref`: read preference used to select the server.
      - `**kwargs`: additional command options; ``collation`` is
        validated and passed separately, the rest are merged into the
        command document.

    Returns the value returned by ``self._command``.
    """
    cmd = SON([("mapReduce", self.__name),
               ("map", map),
               ("reduce", reduce),
               ("out", out)])
    collation = validate_collation_or_none(kwargs.pop('collation', None))
    cmd.update(kwargs)

    # Only a mapping can request inline output. A bare ``'inline' in out``
    # is a substring test when ``out`` is a collection name, so a name
    # such as "inline_results" would wrongly be treated as inline.
    inline = not isinstance(out, string_type) and 'inline' in out
    user_fields = {'results': 1} if inline else None

    read_pref = ((session and session._txn_read_preference())
                 or read_pref)

    with self.__database.client._socket_for_reads(read_pref, session) as (
            sock_info, slave_ok):
        # A read concern is only sent for inline output on wire version
        # 4+ and only when the caller did not set one explicitly.
        if (sock_info.max_wire_version >= 4 and
                ('readConcern' not in cmd) and
                inline):
            read_concern = self.read_concern
        else:
            read_concern = None
        # A write concern only applies when output goes to a collection.
        if 'writeConcern' not in cmd and not inline:
            write_concern = self._write_concern_for(session)
        else:
            write_concern = None

        return self._command(
            sock_info, cmd, slave_ok, read_pref,
            read_concern=read_concern,
            write_concern=write_concern,
            collation=collation, session=session,
            user_fields=user_fields)
def map_reduce(self, map, reduce, out, full_response=False, session=None,
**kwargs):
@ -2747,36 +2799,8 @@ class Collection(common.BaseObject):
raise TypeError("'out' must be an instance of "
"%s or a mapping" % (string_type.__name__,))
cmd = SON([("mapreduce", self.__name),
("map", map),
("reduce", reduce),
("out", out)])
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
inline = 'inline' in cmd['out']
sock_ctx, read_pref = self._socket_for_primary_reads(session)
with sock_ctx as (sock_info, slave_ok):
if (sock_info.max_wire_version >= 4 and 'readConcern' not in cmd and
inline):
read_concern = self.read_concern
else:
read_concern = None
if 'writeConcern' not in cmd and not inline:
write_concern = self._write_concern_for(session)
else:
write_concern = None
if inline:
user_fields = {'results': 1}
else:
user_fields = None
response = self._command(
sock_info, cmd, slave_ok, read_pref,
read_concern=read_concern,
write_concern=write_concern,
collation=collation, session=session,
user_fields=user_fields)
response = self._map_reduce(map, reduce, out, session,
ReadPreference.PRIMARY, **kwargs)
if full_response or not response.get('result'):
return response
@ -2822,23 +2846,8 @@ class Collection(common.BaseObject):
Added the `collation` option.
"""
cmd = SON([("mapreduce", self.__name),
("map", map),
("reduce", reduce),
("out", {"inline": 1})])
user_fields = {'results': 1}
collation = validate_collation_or_none(kwargs.pop('collation', None))
cmd.update(kwargs)
with self._socket_for_reads(session) as (sock_info, slave_ok):
if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd:
res = self._command(sock_info, cmd, slave_ok,
read_concern=self.read_concern,
collation=collation, session=session,
user_fields=user_fields)
else:
res = self._command(sock_info, cmd, slave_ok,
collation=collation, session=session,
user_fields=user_fields)
res = self._map_reduce(map, reduce, {"inline": 1}, session,
self.read_preference, **kwargs)
if full_response:
return res

View File

@ -128,9 +128,8 @@ class CommandCursor(object):
client = self.__collection.database.client
try:
response = client._send_message_with_response(
operation, address=self.__address,
unpack_res=self._unpack_response)
response = client._run_operation_with_response(
operation, self._unpack_response, address=self.__address)
except OperationFailure:
kill()
raise

View File

@ -91,6 +91,9 @@ LOCAL_THRESHOLD_MS = 15
# Default value for retryWrites.
RETRY_WRITES = True
# Default value for retryReads.
RETRY_READS = True
# mongod/s 2.6 and above return code 59 when a command doesn't exist.
COMMAND_NOT_FOUND_CODES = (59,)
@ -569,6 +572,7 @@ URI_OPTIONS_VALIDATOR_MAP = {
'readpreference': validate_read_preference_mode,
'readpreferencetags': validate_read_preference_tags,
'replicaset': validate_string_or_none,
'retryreads': validate_boolean_or_string,
'retrywrites': validate_boolean_or_string,
'serverselectiontimeoutms': validate_timeout_or_zero,
'sockettimeoutms': validate_timeout_or_none,

View File

@ -937,10 +937,11 @@ class Cursor(object):
Can raise ConnectionFailure.
"""
client = self.__collection.database.client
try:
response = client._send_message_with_response(
operation, exhaust=self.__exhaust, address=self.__address,
unpack_res=self._unpack_response)
response = client._run_operation_with_response(
operation, self._unpack_response, exhaust=self.__exhaust,
address=self.__address)
except OperationFailure:
self.__killed = True

View File

@ -657,6 +657,22 @@ class Database(common.BaseObject):
check, allowable_errors, read_preference,
codec_options, session=session, **kwargs)
def _retryable_read_command(self, command, value=1, check=True,
                            allowable_errors=None, read_preference=None,
                            codec_options=DEFAULT_CODEC_OPTIONS,
                            session=None, **kwargs):
    """Run ``command`` through the client's retryable read helper.

    Takes the same arguments as ``command``. When `read_preference` is
    None, the session's transaction read preference is used if it has
    one, otherwise ``ReadPreference.PRIMARY``.
    """
    if read_preference is None:
        txn_pref = session and session._txn_read_preference()
        read_preference = txn_pref or ReadPreference.PRIMARY

    def _run(session, server, sock_info, slave_ok):
        # Callback handed to the client's _retryable_read helper.
        return self._command(
            sock_info, command, slave_ok, value, check, allowable_errors,
            read_preference, codec_options, session=session, **kwargs)

    return self.__client._retryable_read(_run, read_preference, session)
def _list_collections(self, sock_info, slave_okay, session,
read_preference, **kwargs):
"""Internal listCollections helper."""
@ -718,12 +734,15 @@ class Database(common.BaseObject):
kwargs['filter'] = filter
read_pref = ((session and session._txn_read_preference())
or ReadPreference.PRIMARY)
with self.__client._socket_for_reads(
read_pref, session) as (sock_info, slave_okay):
def _cmd(session, server, sock_info, slave_okay):
return self._list_collections(
sock_info, slave_okay, session, read_preference=read_pref,
**kwargs)
return self.__client._retryable_read(
_cmd, read_pref, session)
def list_collection_names(self, session=None, filter=None, **kwargs):
"""Get a list of all the collection names in this database.

View File

@ -276,6 +276,36 @@ class MongoClient(common.BaseObject):
pipeline operator and any operation with an unacknowledged write
concern (e.g. {w: 0})). See
https://github.com/mongodb/specifications/blob/master/source/retryable-writes/retryable-writes.rst
- `retryReads`: (boolean) Whether supported read operations
executed within this MongoClient will be retried once after a
network error on MongoDB 3.6+. Defaults to ``True``.
The supported read operations are:
:meth:`~pymongo.collection.Collection.find`,
:meth:`~pymongo.collection.Collection.find_one`,
:meth:`~pymongo.collection.Collection.aggregate` without ``$out``,
:meth:`~pymongo.collection.Collection.distinct`,
:meth:`~pymongo.collection.Collection.count`,
:meth:`~pymongo.collection.Collection.estimated_document_count`,
:meth:`~pymongo.collection.Collection.count_documents`,
:meth:`pymongo.collection.Collection.watch`,
:meth:`~pymongo.collection.Collection.list_indexes`,
:meth:`pymongo.database.Database.watch`,
:meth:`~pymongo.database.Database.list_collections`,
:meth:`pymongo.mongo_client.MongoClient.watch`,
and :meth:`~pymongo.mongo_client.MongoClient.list_databases`.
Unsupported read operations include, but are not limited to:
:meth:`~pymongo.collection.Collection.map_reduce`,
:meth:`~pymongo.collection.Collection.inline_map_reduce`,
:meth:`~pymongo.database.Database.command`,
and any getMore operation on a cursor.
Enabling retryable reads makes applications more resilient to
transient errors such as network failures, database upgrades, and
replica set failovers. For an exact definition of which errors
trigger a retry, see the `retryable reads specification
<https://github.com/mongodb/specifications/blob/master/source/retryable-reads/retryable-reads.rst>`_.
- `socketKeepAlive`: (boolean) **DEPRECATED** Whether to send
periodic keep-alive packets on connected sockets. Defaults to
``True``. Disabling it is not recommended, see
@ -441,7 +471,8 @@ class MongoClient(common.BaseObject):
.. mongodoc:: connections
.. versionchanged:: 4.0
.. versionchanged:: 3.9
Added the ``retryReads`` keyword argument and URI option.
Added the ``tlsInsecure`` keyword argument and URI option.
The following keyword arguments and URI options were deprecated:
@ -1032,6 +1063,11 @@ class MongoClient(common.BaseObject):
"""If this instance should retry supported write operations."""
return self.__options.retry_writes
@property
def retry_reads(self):
    """If this instance should retry supported read operations."""
    return self.__options.retry_reads
def _is_writable(self):
"""Attempt to connect to a writable server, or return False.
"""
@ -1173,6 +1209,24 @@ class MongoClient(common.BaseObject):
server = self._select_server(writable_server_selector, session)
return self._get_socket(server, session)
@contextlib.contextmanager
def _slaveok_for_server(self, read_preference, server, session,
                        exhaust=False):
    """Check out a socket from ``server`` and yield (sock_info, slave_ok).

    ``slave_ok`` is true when the topology type is Single and the socket
    is not connected to a mongos, or when ``read_preference`` is anything
    other than PRIMARY.
    """
    assert read_preference is not None, "read_preference must not be None"
    # Server Selection Spec: "slaveOK must be sent to mongods with
    # topology type Single. If the server type is Mongos, follow the rules
    # for passing read preference to mongos, even for topology type
    # Single."
    # Thread safe: if the type is single it cannot change.
    is_single = (self._get_topology().description.topology_type
                 == TOPOLOGY_TYPE.Single)
    with self._get_socket(server, session, exhaust=exhaust) as sock_info:
        direct_to_mongod = is_single and not sock_info.is_mongos
        slave_ok = direct_to_mongod or (
            read_preference != ReadPreference.PRIMARY)
        yield sock_info, slave_ok
@contextlib.contextmanager
def _socket_for_reads(self, read_preference, session):
assert read_preference is not None, "read_preference must not be None"
@ -1191,25 +1245,25 @@ class MongoClient(common.BaseObject):
read_preference != ReadPreference.PRIMARY)
yield sock_info, slave_ok
def _send_message_with_response(self, operation, exhaust=False,
address=None, unpack_res=None):
"""Send a message to MongoDB and return a Response.
def _run_operation_with_response(self, operation, unpack_res,
exhaust=False, address=None):
"""Run a _Query/_GetMore operation and return a Response.
:Parameters:
- `operation`: a _Query or _GetMore object.
- `read_preference` (optional): A ReadPreference.
- `unpack_res`: A callable that decodes the wire protocol response.
- `exhaust` (optional): If True, the socket used stays checked out.
It is returned along with its Pool in the Response.
- `address` (optional): Optional address when sending a message
to a specific server, used for getMore.
"""
server = self._select_server(
operation.read_preference, operation.session, address=address)
if operation.exhaust_mgr:
server = self._select_server(
operation.read_preference, operation.session, address=address)
with self._reset_on_error(server.description.address,
operation.session):
return server.send_message_with_response(
return server.run_operation_with_response(
operation.exhaust_mgr.sock,
operation,
True,
@ -1217,24 +1271,21 @@ class MongoClient(common.BaseObject):
exhaust,
unpack_res)
# If this is a direct connection to a mongod, *always* set the slaveOk
# bit. See bullet point 2 in server-selection.rst#topology-type-single.
topology = self._get_topology()
set_slave_ok = (
topology.description.topology_type == TOPOLOGY_TYPE.Single
and server.description.server_type != SERVER_TYPE.Mongos) or (
operation.read_preference != ReadPreference.PRIMARY)
with self._get_socket(server, operation.session,
exhaust=exhaust) as sock_info:
return server.send_message_with_response(
def _cmd(session, server, sock_info, slave_ok):
return server.run_operation_with_response(
sock_info,
operation,
set_slave_ok,
slave_ok,
self._event_listeners,
exhaust,
unpack_res)
return self._retryable_read(
_cmd, operation.read_preference, operation.session,
address=address,
retryable=isinstance(operation, message._Query),
exhaust=exhaust)
@contextlib.contextmanager
def _reset_on_error(self, server_address, session):
"""On "not master" or "node is recovering" errors reset the server
@ -1354,6 +1405,58 @@ class MongoClient(common.BaseObject):
retrying = True
last_error = exc
def _retryable_read(self, func, read_pref, session, address=None,
                    retryable=True, exhaust=False):
    """Execute a read operation, retrying it at most once.

    Returns func()'s return value on success. On a retryable error the
    same operation is attempted one more time, after selecting a server
    again.

    Retries are disabled when `retryable` is False, when this client's
    ``retry_reads`` is False, when `session` is in a transaction, or when
    the selected server does not support retryable reads.

    :Parameters:
      - `func`: callable invoked as
        ``func(session, server, sock_info, slave_ok)``.
      - `read_pref`: read preference used for server selection.
      - `session`: a session or None.
      - `address` (optional): passed through to server selection.
      - `retryable` (optional): set to False to never retry.
      - `exhaust` (optional): passed through when checking out the socket.

    Re-raises any exception thrown by func() that is not retried.
    """
    retryable = (retryable and
                 self.retry_reads
                 and not (session and session._in_transaction))
    last_error = None
    retrying = False

    while True:
        try:
            server = self._select_server(
                read_pref, session, address=address)
            if not server.description.retryable_reads_supported:
                retryable = False
            with self._slaveok_for_server(read_pref, server, session,
                                          exhaust=exhaust) as (sock_info,
                                                               slave_ok):
                if retrying and not retryable:
                    # A retry is not possible because this server does
                    # not support retryable reads, raise the last error.
                    raise last_error
                return func(session, server, sock_info, slave_ok)
        except ServerSelectionTimeoutError:
            if retrying:
                # The application may think the read was never attempted
                # if we raise ServerSelectionTimeoutError on the retry
                # attempt. Raise the original exception instead.
                raise last_error
            # A ServerSelectionTimeoutError error indicates that there may
            # be a persistent outage. Attempting to retry in this case will
            # most likely be a waste of time.
            raise
        except ConnectionFailure as exc:
            if not retryable or retrying:
                raise
            retrying = True
            last_error = exc
        except OperationFailure as exc:
            if not retryable or retrying:
                raise
            # Only server errors with a known retryable code are retried.
            if exc.code not in helpers._RETRYABLE_ERROR_CODES:
                raise
            retrying = True
            last_error = exc
def _retryable_write(self, retryable, func, session):
"""Internal retryable write helper."""
with self._tmp_session(session) as s:
@ -1752,7 +1855,7 @@ class MongoClient(common.BaseObject):
cmd = SON([("listDatabases", 1)])
cmd.update(kwargs)
admin = self._database_default_options("admin")
res = admin.command(cmd, session=session)
res = admin._retryable_read_command(cmd, session=session)
# listDatabases doesn't return a cursor (yet). Fake one.
cursor = {
"id": 0,

View File

@ -14,8 +14,6 @@
"""Communicate with one MongoDB server in a topology."""
import contextlib
from datetime import datetime
from pymongo.errors import NotMasterError, OperationFailure
@ -67,7 +65,7 @@ class Server(object):
"""Check the server's state soon."""
self._monitor.request_check()
def send_message_with_response(
def run_operation_with_response(
self,
sock_info,
operation,
@ -75,17 +73,19 @@ class Server(object):
listeners,
exhaust,
unpack_res):
"""Send a message to MongoDB and return a Response object.
"""Run a _Query or _GetMore operation and return a Response object.
Can raise ConnectionFailure.
This method is used only to run _Query/_GetMore operations from
cursors.
Can raise ConnectionFailure, OperationFailure, etc.
:Parameters:
- `operation`: A _Query or _GetMore object.
- `set_slave_okay`: Pass to operation.get_message.
- `all_credentials`: dict, maps auth source to MongoCredential.
- `listeners`: Instance of _EventListeners or None.
- `exhaust` (optional): If True, the socket used stays checked out.
It is returned along with its Pool in the Response.
- `exhaust`: If True, then this is an exhaust cursor operation.
- `unpack_res`: A callable that decodes the wire protocol response.
"""
duration = None
publish = listeners.enabled_for_commands

View File

@ -202,5 +202,10 @@ class ServerDescription(object):
self._ls_timeout_minutes is not None and
self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary))
@property
def retryable_reads_supported(self):
    """Checks if this server supports retryable reads."""
    # Retryable reads require a max wire version of at least 6.
    return self._max_wire_version >= 6
# For unittesting only. Use under no circumstances!
_host_to_round_trip_time = {}

View File

@ -37,7 +37,7 @@ from pymongo import auth, message
from pymongo.common import _UUID_REPRESENTATIONS
from pymongo.command_cursor import CommandCursor
from pymongo.compression_support import _HAVE_SNAPPY
from pymongo.cursor import CursorType
from pymongo.cursor import Cursor, CursorType
from pymongo.database import Database
from pymongo.errors import (AutoReconnect,
ConfigurationError,
@ -1161,7 +1161,7 @@ class TestClient(IntegrationTest):
def test_exhaust_network_error(self):
# When doing an exhaust query, the socket stays checked out on success
# but must be checked in on error to avoid semaphore leaks.
client = rs_or_single_client(maxPoolSize=1)
client = rs_or_single_client(maxPoolSize=1, retryReads=False)
collection = client.pymongo_test.test
pool = get_pool(client)
pool._check_interval_seconds = None # Never check.
@ -1188,7 +1188,8 @@ class TestClient(IntegrationTest):
# Get a client with one socket so we detect if it's leaked.
c = connected(rs_or_single_client(maxPoolSize=1,
waitQueueTimeoutMS=1))
waitQueueTimeoutMS=1,
retryReads=False))
# Simulate an authenticate() call on a different socket.
credentials = auth._build_credentials_tuple(
@ -1220,15 +1221,17 @@ class TestClient(IntegrationTest):
def test_stale_getmore(self):
# A cursor is created, but its member goes down and is removed from
# the topology before the getMore message is sent. Test that
# MongoClient._send_message_with_response handles the error.
# MongoClient._run_operation_with_response handles the error.
with self.assertRaises(AutoReconnect):
client = rs_client(connect=False,
serverSelectionTimeoutMS=100)
client._send_message_with_response(
client._run_operation_with_response(
operation=message._GetMore('pymongo_test', 'collection',
101, 1234, client.codec_options,
ReadPreference.PRIMARY,
None, client, None, None),
unpack_res=Cursor(
client.pymongo_test.collection)._unpack_response,
address=('not-a-member', 27017))
def test_heartbeat_frequency_ms(self):
@ -1419,7 +1422,8 @@ class TestExhaustCursor(IntegrationTest):
def test_exhaust_query_network_error(self):
# When doing an exhaust query, the socket stays checked out on success
# but must be checked in on error to avoid semaphore leaks.
client = connected(rs_or_single_client(maxPoolSize=1))
client = connected(rs_or_single_client(maxPoolSize=1,
retryReads=False))
collection = client.pymongo_test.test
pool = get_pool(client)
pool._check_interval_seconds = None # Never check.
@ -1576,7 +1580,8 @@ class TestMongoClientFailover(MockClientTest):
members=['a:1', 'b:2', 'c:3'],
mongoses=[],
host='b:2', # Pass a secondary.
replicaSet='rs')
replicaSet='rs',
retryReads=False)
wait_until(lambda: len(c.nodes) == 3, 'connect')
@ -1604,7 +1609,8 @@ class TestMongoClientFailover(MockClientTest):
mongoses=[],
host='a:1',
replicaSet='rs',
connect=False)
connect=False,
retryReads=False)
# Set host-specific information so we can test whether it is reset.
c.set_wire_version_range('a:1', 2, 6)

View File

@ -325,6 +325,15 @@ class ReadPrefTester(MongoClient):
self.record_a_read(sock_info.address)
yield sock_info, slave_ok
@contextlib.contextmanager
def _slaveok_for_server(self, read_preference, server, session,
                        exhaust=False):
    """Wrap the parent class's _slaveok_for_server and record the address
    of the socket that was handed out before yielding it to the caller.
    """
    parent_ctx = super(ReadPrefTester, self)._slaveok_for_server(
        read_preference, server, session, exhaust=exhaust)
    with parent_ctx as checked_out:
        sock_info, slave_ok = checked_out
        self.record_a_read(sock_info.address)
        yield sock_info, slave_ok
def record_a_read(self, address):
server = self._get_topology().select_server_by_address(address, 0)
self.has_read_from.add(server)

View File

@ -0,0 +1,83 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test retryable reads spec."""
import os
import sys
sys.path[0:0] = [""]
from pymongo.mongo_client import MongoClient
from test import unittest, client_context, PyMongoTestCase
from test.utils import TestCreator
from test.utils_spec_runner import SpecRunner
# Location of JSON test specifications.
_TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'retryable_reads')
class TestClientOptions(PyMongoTestCase):
    """Checks how the retryReads option surfaces as client.retry_reads."""

    def test_default(self):
        # With no option given, retry_reads is True.
        default_client = MongoClient(connect=False)
        self.assertEqual(default_client.retry_reads, True)

    def test_kwargs(self):
        # The keyword argument is reflected as-is.
        for flag in (True, False):
            client = MongoClient(retryReads=flag, connect=False)
            self.assertEqual(client.retry_reads, flag)

    def test_uri(self):
        # The URI option is parsed into the matching boolean.
        cases = (
            ('mongodb://h/?retryReads=true', True),
            ('mongodb://h/?retryReads=false', False),
        )
        for uri, flag in cases:
            client = MongoClient(uri, connect=False)
            self.assertEqual(client.retry_reads, flag)
class TestSpec(SpecRunner):
    """Runner for the retryable reads JSON spec tests."""

    @classmethod
    @client_context.require_version_min(4, 0)
    def setUpClass(cls):
        super(TestSpec, cls).setUpClass()
        on_old_mongos = (client_context.is_mongos and
                         client_context.version[:2] <= (4, 0))
        if on_old_mongos:
            raise unittest.SkipTest("4.0 mongos does not support failCommand")

    def maybe_skip_scenario(self, test):
        """Skip scenarios whose description names an unsupported helper."""
        super(TestSpec, self).maybe_skip_scenario(test)
        unsupported = (
            'listCollectionObjects', 'listIndexNames', 'listDatabaseObjects')
        description = test['description'].lower()
        for helper in unsupported:
            if helper.lower() in description:
                raise unittest.SkipTest(
                    'PyMongo does not support %s' % (helper,))
def create_test(scenario_def, test, name):
# Factory handed to TestCreator below: returns a test method bound to one
# (scenario_def, test) pair. `name` is accepted but not used in this body.
@client_context.require_test_commands
def run_scenario(self):
# Delegates to the test case's own run_scenario method.
self.run_scenario(scenario_def, test)
return run_scenario
test_creator = TestCreator(create_test, TestSpec, _TEST_PATH)
test_creator.create_tests()
if __name__ == "__main__":
unittest.main()

View File

@ -141,7 +141,7 @@ class ScenarioDict(dict):
def convert(v):
if isinstance(v, collections.Mapping):
return ScenarioDict(v)
if isinstance(v, py3compat.string_type):
if isinstance(v, (py3compat.string_type, bytes)):
return v
if isinstance(v, collections.Sequence):
return [convert(item) for item in v]
@ -264,8 +264,10 @@ class TestCreator(object):
# Construct test from scenario.
for test_def in scenario_def['tests']:
test_name = 'test_%s_%s_%s' % (
dirname, test_type,
str(test_def['description'].replace(" ", "_")))
dirname,
test_type.replace("-", "_").replace('.', '_'),
str(test_def['description'].replace(" ", "_").replace(
'.', '_')))
new_test = self._create_test(
scenario_def, test_def, test_name)

View File

@ -175,6 +175,10 @@ class SpecRunner(IntegrationTest):
name = camel_to_snake(operation['name'])
if name == 'run_command':
name = 'command'
elif name == 'download_by_name':
name = 'open_download_stream_by_name'
elif name == 'download':
name = 'open_download_stream'
def parse_options(opts):
if 'readPreference' in opts:
@ -197,14 +201,21 @@ class SpecRunner(IntegrationTest):
**dict(parse_options(operation['collectionOptions'])))
object_name = operation['object']
objects = {
'client': database.client,
'database': database,
'collection': collection,
'testRunner': self
}
objects.update(sessions)
obj = objects[object_name]
if object_name == 'gridfsbucket':
# Only create the GridFSBucket when we need it (for the gridfs
# retryable reads tests).
obj = GridFSBucket(
database, bucket_name=collection.name,
disable_md5=True)
else:
objects = {
'client': database.client,
'database': database,
'collection': collection,
'testRunner': self
}
objects.update(sessions)
obj = objects[object_name]
# Combine arguments with options and handle special cases.
arguments = operation.get('arguments', {})
@ -244,6 +255,8 @@ class SpecRunner(IntegrationTest):
ordered_command = SON([(operation['command_name'], 1)])
ordered_command.update(arguments['command'])
arguments['command'] = ordered_command
elif name == 'open_download_stream' and arg_name == 'id':
arguments['file_id'] = arguments.pop(arg_name)
elif name == 'with_transaction' and arg_name == 'callback':
callback_ops = arguments[arg_name]['operations']
arguments['callback'] = lambda _: self.run_operations(
@ -261,6 +274,11 @@ class SpecRunner(IntegrationTest):
arguments["pipeline"][-1]["$out"],
read_preference=ReadPreference.PRIMARY)
return out.find()
if name == "map_reduce":
if isinstance(result, dict) and 'results' in result:
return result['results']
if 'download' in name:
result = Binary(result.read())
if isinstance(result, Cursor) or isinstance(result, CommandCursor):
return list(result)
@ -271,7 +289,7 @@ class SpecRunner(IntegrationTest):
in_with_transaction=False):
for op in ops:
expected_result = op.get('result')
if expect_error(expected_result):
if expect_error(op):
with self.assertRaises(PyMongoError,
msg=op['name']) as context:
self.run_operation(sessions, collection, op.copy())
@ -391,13 +409,23 @@ class SpecRunner(IntegrationTest):
database_name = scenario_def['database_name']
write_concern_db = client_context.client.get_database(
database_name, write_concern=WriteConcern(w='majority'))
collection_name = scenario_def['collection_name']
write_concern_coll = write_concern_db[collection_name]
write_concern_coll.drop()
write_concern_db.create_collection(collection_name)
if scenario_def['data']:
# Load data.
write_concern_coll.insert_many(scenario_def['data'])
if 'bucket_name' in scenario_def:
# Create a bucket for the retryable reads GridFS tests.
collection_name = scenario_def['bucket_name']
client_context.client.drop_database(database_name)
if scenario_def['data']:
data = scenario_def['data']
# Load data.
write_concern_db['fs.chunks'].insert_many(data['fs.chunks'])
write_concern_db['fs.files'].insert_many(data['fs.files'])
else:
collection_name = scenario_def['collection_name']
write_concern_coll = write_concern_db[collection_name]
write_concern_coll.drop()
write_concern_db.create_collection(collection_name)
if scenario_def['data']:
# Load data.
write_concern_coll.insert_many(scenario_def['data'])
# SPEC-1245 workaround StaleDbVersion on distinct
for c in self.mongos_clients:
@ -473,6 +501,13 @@ class SpecRunner(IntegrationTest):
self.assertEqual(list(primary_coll.find()), expected_c['data'])
def expect_any_error(op):
    """Return the operation's 'error' entry when ``op`` is a dict.

    Yields None when the dict has no 'error' key, and False when ``op``
    is not a dict at all.
    """
    return op.get('error') if isinstance(op, dict) else False
def expect_error_message(expected_result):
if isinstance(expected_result, dict):
return expected_result['errorContains']
@ -501,8 +536,10 @@ def expect_error_labels_omit(expected_result):
return False
def expect_error(expected_result):
return (expect_error_message(expected_result)
def expect_error(op):
    """Return a truthy value when the operation ``op`` is expected to fail.

    The operation-level 'error' marker is consulted first, then each
    error expectation that can appear in the operation's 'result'. The
    first truthy answer is returned; if none is truthy, the answer of the
    last check is returned.
    """
    expected_result = op.get('result')
    outcome = expect_any_error(op)
    if outcome:
        return outcome
    result_checks = (
        expect_error_message,
        expect_error_code,
        expect_error_labels_contain,
        expect_error_labels_omit,
    )
    for check in result_checks:
        outcome = check(expected_result)
        if outcome:
            return outcome
    return outcome