diff --git a/doc/changelog.rst b/doc/changelog.rst index e85d7d3c3..12a863113 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -33,6 +33,13 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include: - The ``retryWrites`` URI option now defaults to ``True``. Supported write operations that fail with a retryable error will automatically be retried one time, with at-most-once semantics. +- Support for retryable reads and the ``retryReads`` URI option which is + enabled by default. See the :class:`~pymongo.mongo_client.MongoClient` + documentation for details. + + Now that supported operations are retried automatically and transparently, + users should consider adjusting any custom retry logic to prevent + an application from inadvertently retrying for too long. .. _URI options specification: https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.rst diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 8b2b57289..6e77bc311 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -118,8 +118,8 @@ class ChangeStream(object): """ read_preference = self._target._read_preference_for(session) client = self._database.client - with client._socket_for_reads( - read_preference, session) as (sock_info, slave_ok): + + def _cmd(session, server, sock_info, slave_ok): pipeline = self._full_pipeline() cmd = SON([("aggregate", self._aggregation_target), ("pipeline", pipeline), @@ -160,6 +160,8 @@ class ChangeStream(object): max_await_time_ms=self._max_await_time_ms, session=session, explicit_session=explicit_session) + return client._retryable_read(_cmd, read_preference, session) + def _create_cursor(self): with self._database.client._tmp_session(self._session, close=False) as s: return self._run_aggregation_cmd( diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 14541ae7d..040284895 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -164,6 +164,7 @@ class ClientOptions(object): self.__heartbeat_frequency = options.get( 'heartbeatfrequencyms', common.HEARTBEAT_FREQUENCY) self.__retry_writes = options.get('retrywrites', common.RETRY_WRITES) + self.__retry_reads = options.get('retryreads', common.RETRY_READS) self.__server_selector = options.get( 'server_selector', any_server_selector) @@ -235,3 +236,8 @@ class ClientOptions(object): def retry_writes(self): """If this instance should retry supported write operations.""" return self.__retry_writes + + @property + def retry_reads(self): + """If this instance should retry supported read operations.""" + return self.__retry_reads diff --git a/pymongo/collection.py b/pymongo/collection.py index f10e66033..2eef6395d 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -188,12 +188,6 @@ class Collection(common.BaseObject): return self.__database.client._socket_for_reads( self._read_preference_for(session), session) - def _socket_for_primary_reads(self, session): - read_pref = ((session and session._txn_read_preference()) - or ReadPreference.PRIMARY) - return self.__database.client._socket_for_reads( - read_pref, session), read_pref - def _socket_for_writes(self, session): return self.__database.client._socket_for_writes(session) @@ -1572,7 +1566,7 @@ class Collection(common.BaseObject): def _count(self, cmd, collation=None, session=None): """Internal count helper.""" - with self._socket_for_reads(session) as (sock_info, slave_ok): + def _cmd(session, server, sock_info, slave_ok): res = self._command( sock_info, cmd, @@ -1582,9 +1576,12 @@ class Collection(common.BaseObject): read_concern=self.read_concern, collation=collation, session=session) - if res.get("errmsg", "") == "ns missing": - return 0 - return int(res["n"]) + if res.get("errmsg", "") == "ns missing": + return 0 + return int(res["n"]) + + return self.__database.client._retryable_read( + _cmd, self._read_preference_for(session), session) def _aggregate_one_result( self, sock_info, slave_ok, cmd, collation=None, session=None): @@ -1693,12 +1690,16 @@ class Collection(common.BaseObject): kwargs["hint"] = helpers._index_document(kwargs["hint"]) collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) - with self._socket_for_reads(session) as (sock_info, slave_ok): + + def _cmd(session, server, sock_info, slave_ok): result = self._aggregate_one_result( sock_info, slave_ok, cmd, collation, session) - if not result: - return 0 - return result['n'] + if not result: + return 0 + return result['n'] + + return self.__database.client._retryable_read( + _cmd, self._read_preference_for(session), session) def count(self, filter=None, session=None, **kwargs): """**DEPRECATED** - Get the number of documents in this collection. @@ -2149,8 +2150,10 @@ class Collection(common.BaseObject): codec_options = CodecOptions(SON) coll = self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY) - sock_ctx, read_pref = self._socket_for_primary_reads(session) - with sock_ctx as (sock_info, slave_ok): + read_pref = ((session and session._txn_read_preference()) + or ReadPreference.PRIMARY) + + def _cmd(session, server, sock_info, slave_ok): cmd = SON([("listIndexes", self.__name), ("cursor", {})]) if sock_info.max_wire_version > 2: with self.__database.client._tmp_session(session, False) as s: @@ -2179,6 +2182,9 @@ class Collection(common.BaseObject): # will never be a getMore call. return CommandCursor(coll, cursor, sock_info.address) + return self.__database.client._retryable_read( + _cmd, read_pref, session) + def index_information(self, session=None): """Get information on this collection's indexes. @@ -2275,10 +2281,11 @@ class Collection(common.BaseObject): "useCursor", kwargs.pop("useCursor")) batch_size = common.validate_non_negative_integer_or_none( "batchSize", kwargs.pop("batchSize", None)) + + dollar_out = pipeline and '$out' in pipeline[-1] # If the server does not support the "cursor" option we # ignore useCursor and batchSize. - with self._socket_for_reads(session) as (sock_info, slave_ok): - dollar_out = pipeline and '$out' in pipeline[-1] + def _cmd(session, server, sock_info, slave_ok): if use_cursor: if "cursor" not in kwargs: kwargs["cursor"] = {} @@ -2336,6 +2343,10 @@ class Collection(common.BaseObject): max_await_time_ms=max_await_time_ms, session=session, explicit_session=explicit_session) + return self.__database.client._retryable_read( + _cmd, self._read_preference_for(session), session, + retryable=not dollar_out) + def aggregate(self, pipeline, session=None, **kwargs): """Perform an aggregation using the aggregation framework on this collection. @@ -2681,12 +2692,53 @@ class Collection(common.BaseObject): kwargs["query"] = filter collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) - with self._socket_for_reads(session) as (sock_info, slave_ok): - return self._command(sock_info, cmd, slave_ok, - read_concern=self.read_concern, - collation=collation, - session=session, - user_fields={"values": 1})["values"] + def _cmd(session, server, sock_info, slave_ok): + return self._command( + sock_info, cmd, slave_ok, read_concern=self.read_concern, + collation=collation, session=session, + user_fields={"values": 1})["values"] + + return self.__database.client._retryable_read( + _cmd, self._read_preference_for(session), session) + + def _map_reduce(self, map, reduce, out, session, read_pref, **kwargs): + """Internal mapReduce helper.""" + cmd = SON([("mapReduce", self.__name), + ("map", map), + ("reduce", reduce), + ("out", out)]) + collation = validate_collation_or_none(kwargs.pop('collation', None)) + cmd.update(kwargs) + + inline = 'inline' in out + + if inline: + user_fields = {'results': 1} + else: + user_fields = None + + read_pref = ((session and session._txn_read_preference()) + or read_pref) + + with self.__database.client._socket_for_reads(read_pref, session) as ( + sock_info, slave_ok): + if (sock_info.max_wire_version >= 4 and + ('readConcern' not in cmd) and + inline): + read_concern = self.read_concern + else: + read_concern = None + if 'writeConcern' not in cmd and not inline: + write_concern = self._write_concern_for(session) + else: + write_concern = None + + return self._command( + sock_info, cmd, slave_ok, read_pref, + read_concern=read_concern, + write_concern=write_concern, + collation=collation, session=session, + user_fields=user_fields) def map_reduce(self, map, reduce, out, full_response=False, session=None, **kwargs): @@ -2747,36 +2799,8 @@ class Collection(common.BaseObject): raise TypeError("'out' must be an instance of " "%s or a mapping" % (string_type.__name__,)) - cmd = SON([("mapreduce", self.__name), - ("map", map), - ("reduce", reduce), - ("out", out)]) - collation = validate_collation_or_none(kwargs.pop('collation', None)) - cmd.update(kwargs) - - inline = 'inline' in cmd['out'] - sock_ctx, read_pref = self._socket_for_primary_reads(session) - with sock_ctx as (sock_info, slave_ok): - if (sock_info.max_wire_version >= 4 and 'readConcern' not in cmd and - inline): - read_concern = self.read_concern - else: - read_concern = None - if 'writeConcern' not in cmd and not inline: - write_concern = self._write_concern_for(session) - else: - write_concern = None - if inline: - user_fields = {'results': 1} - else: - user_fields = None - - response = self._command( - sock_info, cmd, slave_ok, read_pref, - read_concern=read_concern, - write_concern=write_concern, - collation=collation, session=session, - user_fields=user_fields) + response = self._map_reduce(map, reduce, out, session, + ReadPreference.PRIMARY, **kwargs) if full_response or not response.get('result'): return response @@ -2822,23 +2846,8 @@ class Collection(common.BaseObject): Added the `collation` option. """ - cmd = SON([("mapreduce", self.__name), - ("map", map), - ("reduce", reduce), - ("out", {"inline": 1})]) - user_fields = {'results': 1} - collation = validate_collation_or_none(kwargs.pop('collation', None)) - cmd.update(kwargs) - with self._socket_for_reads(session) as (sock_info, slave_ok): - if sock_info.max_wire_version >= 4 and 'readConcern' not in cmd: - res = self._command(sock_info, cmd, slave_ok, - read_concern=self.read_concern, - collation=collation, session=session, - user_fields=user_fields) - else: - res = self._command(sock_info, cmd, slave_ok, - collation=collation, session=session, - user_fields=user_fields) + res = self._map_reduce(map, reduce, {"inline": 1}, session, + self.read_preference, **kwargs) if full_response: return res diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index be24782f0..40b1ad8ae 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -128,9 +128,8 @@ class CommandCursor(object): client = self.__collection.database.client try: - response = client._send_message_with_response( - operation, address=self.__address, - unpack_res=self._unpack_response) + response = client._run_operation_with_response( + operation, self._unpack_response, address=self.__address) except OperationFailure: kill() raise diff --git a/pymongo/common.py b/pymongo/common.py index d40658ae0..3216b99ca 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -91,6 +91,9 @@ LOCAL_THRESHOLD_MS = 15 # Default value for retryWrites. RETRY_WRITES = True +# Default value for retryReads. +RETRY_READS = True + # mongod/s 2.6 and above return code 59 when a command doesn't exist. COMMAND_NOT_FOUND_CODES = (59,) @@ -569,6 +572,7 @@ URI_OPTIONS_VALIDATOR_MAP = { 'readpreference': validate_read_preference_mode, 'readpreferencetags': validate_read_preference_tags, 'replicaset': validate_string_or_none, + 'retryreads': validate_boolean_or_string, 'retrywrites': validate_boolean_or_string, 'serverselectiontimeoutms': validate_timeout_or_zero, 'sockettimeoutms': validate_timeout_or_none, diff --git a/pymongo/cursor.py b/pymongo/cursor.py index adcfe099a..cc3df06bc 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -937,10 +937,11 @@ class Cursor(object): Can raise ConnectionFailure. """ client = self.__collection.database.client + try: - response = client._send_message_with_response( - operation, exhaust=self.__exhaust, address=self.__address, - unpack_res=self._unpack_response) + response = client._run_operation_with_response( + operation, self._unpack_response, exhaust=self.__exhaust, + address=self.__address) except OperationFailure: self.__killed = True diff --git a/pymongo/database.py b/pymongo/database.py index 1cf0794e2..c7f6308f8 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -657,6 +657,22 @@ class Database(common.BaseObject): check, allowable_errors, read_preference, codec_options, session=session, **kwargs) + def _retryable_read_command(self, command, value=1, check=True, + allowable_errors=None, read_preference=None, + codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs): + """Same as command but used for retryable read commands.""" + if read_preference is None: + read_preference = ((session and session._txn_read_preference()) + or ReadPreference.PRIMARY) + + def _cmd(session, server, sock_info, slave_ok): + return self._command(sock_info, command, slave_ok, value, + check, allowable_errors, read_preference, + codec_options, session=session, **kwargs) + + return self.__client._retryable_read( + _cmd, read_preference, session) + def _list_collections(self, sock_info, slave_okay, session, read_preference, **kwargs): """Internal listCollections helper.""" @@ -718,12 +734,15 @@ class Database(common.BaseObject): kwargs['filter'] = filter read_pref = ((session and session._txn_read_preference()) or ReadPreference.PRIMARY) - with self.__client._socket_for_reads( - read_pref, session) as (sock_info, slave_okay): + + def _cmd(session, server, sock_info, slave_okay): return self._list_collections( sock_info, slave_okay, session, read_preference=read_pref, **kwargs) + return self.__client._retryable_read( + _cmd, read_pref, session) + def list_collection_names(self, session=None, filter=None, **kwargs): """Get a list of all the collection names in this database. diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index e98f62a84..6d94817cb 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -276,6 +276,36 @@ class MongoClient(common.BaseObject): pipeline operator and any operation with an unacknowledged write concern (e.g. {w: 0})). See https://github.com/mongodb/specifications/blob/master/source/retryable-writes/retryable-writes.rst + - `retryReads`: (boolean) Whether supported read operations + executed within this MongoClient will be retried once after a + network error on MongoDB 3.6+. Defaults to ``True``. + The supported read operations are: + :meth:`~pymongo.collection.Collection.find`, + :meth:`~pymongo.collection.Collection.find_one`, + :meth:`~pymongo.collection.Collection.aggregate` without ``$out``, + :meth:`~pymongo.collection.Collection.distinct`, + :meth:`~pymongo.collection.Collection.count`, + :meth:`~pymongo.collection.Collection.estimated_document_count`, + :meth:`~pymongo.collection.Collection.count_documents`, + :meth:`pymongo.collection.Collection.watch`, + :meth:`~pymongo.collection.Collection.list_indexes`, + :meth:`pymongo.database.Database.watch`, + :meth:`~pymongo.database.Database.list_collections`, + :meth:`pymongo.mongo_client.MongoClient.watch`, + and :meth:`~pymongo.mongo_client.MongoClient.list_databases`. + + Unsupported read operations include, but are not limited to: + :meth:`~pymongo.collection.Collection.map_reduce`, + :meth:`~pymongo.collection.Collection.inline_map_reduce`, + :meth:`~pymongo.database.Database.command`, + and any getMore operation on a cursor. + + Enabling retryable reads makes applications more resilient to + transient errors such as network failures, database upgrades, and + replica set failovers. For an exact definition of which errors + trigger a retry, see the `retryable reads specification + `_. + - `socketKeepAlive`: (boolean) **DEPRECATED** Whether to send periodic keep-alive packets on connected sockets. Defaults to ``True``. Disabling it is not recommended, see @@ -441,7 +471,8 @@ class MongoClient(common.BaseObject): .. mongodoc:: connections - .. versionchanged:: 4.0 + .. versionchanged:: 3.9 + Added the ``retryReads`` keyword argument and URI option. Added the ``tlsInsecure`` keyword argument and URI option. The following keyword arguments and URI options were deprecated: @@ -1032,6 +1063,11 @@ class MongoClient(common.BaseObject): """If this instance should retry supported write operations.""" return self.__options.retry_writes + @property + def retry_reads(self): + """If this instance should retry supported write operations.""" + return self.__options.retry_reads + def _is_writable(self): """Attempt to connect to a writable server, or return False. """ @@ -1173,6 +1209,24 @@ class MongoClient(common.BaseObject): server = self._select_server(writable_server_selector, session) return self._get_socket(server, session) + @contextlib.contextmanager + def _slaveok_for_server(self, read_preference, server, session, + exhaust=False): + assert read_preference is not None, "read_preference must not be None" + # Get a socket for a server matching the read preference, and yield + # sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to + # mongods with topology type Single. If the server type is Mongos, + # follow the rules for passing read preference to mongos, even for + # topology type Single." + # Thread safe: if the type is single it cannot change. + topology = self._get_topology() + single = topology.description.topology_type == TOPOLOGY_TYPE.Single + + with self._get_socket(server, session, exhaust=exhaust) as sock_info: + slave_ok = (single and not sock_info.is_mongos) or ( + read_preference != ReadPreference.PRIMARY) + yield sock_info, slave_ok + @contextlib.contextmanager def _socket_for_reads(self, read_preference, session): assert read_preference is not None, "read_preference must not be None" @@ -1191,25 +1245,25 @@ class MongoClient(common.BaseObject): read_preference != ReadPreference.PRIMARY) yield sock_info, slave_ok - def _send_message_with_response(self, operation, exhaust=False, - address=None, unpack_res=None): - """Send a message to MongoDB and return a Response. + def _run_operation_with_response(self, operation, unpack_res, + exhaust=False, address=None): + """Run a _Query/_GetMore operation and return a Response. :Parameters: - `operation`: a _Query or _GetMore object. - - `read_preference` (optional): A ReadPreference. + - `unpack_res`: A callable that decodes the wire protocol response. - `exhaust` (optional): If True, the socket used stays checked out. It is returned along with its Pool in the Response. - `address` (optional): Optional address when sending a message to a specific server, used for getMore. """ - server = self._select_server( - operation.read_preference, operation.session, address=address) - if operation.exhaust_mgr: + server = self._select_server( + operation.read_preference, operation.session, address=address) + with self._reset_on_error(server.description.address, operation.session): - return server.send_message_with_response( + return server.run_operation_with_response( operation.exhaust_mgr.sock, operation, True, @@ -1217,24 +1271,21 @@ class MongoClient(common.BaseObject): exhaust, unpack_res) - # If this is a direct connection to a mongod, *always* set the slaveOk - # bit. See bullet point 2 in server-selection.rst#topology-type-single. - topology = self._get_topology() - set_slave_ok = ( - topology.description.topology_type == TOPOLOGY_TYPE.Single - and server.description.server_type != SERVER_TYPE.Mongos) or ( - operation.read_preference != ReadPreference.PRIMARY) - - with self._get_socket(server, operation.session, - exhaust=exhaust) as sock_info: - return server.send_message_with_response( + def _cmd(session, server, sock_info, slave_ok): + return server.run_operation_with_response( sock_info, operation, - set_slave_ok, + slave_ok, self._event_listeners, exhaust, unpack_res) + return self._retryable_read( + _cmd, operation.read_preference, operation.session, + address=address, + retryable=isinstance(operation, message._Query), + exhaust=exhaust) + @contextlib.contextmanager def _reset_on_error(self, server_address, session): """On "not master" or "node is recovering" errors reset the server @@ -1354,6 +1405,58 @@ class MongoClient(common.BaseObject): retrying = True last_error = exc + def _retryable_read(self, func, read_pref, session, address=None, + retryable=True, exhaust=False): + """Execute an operation with at most one consecutive retries + + Returns func()'s return value on success. On error retries the same + command once. + + Re-raises any exception thrown by func(). + """ + retryable = (retryable and + self.retry_reads + and not (session and session._in_transaction)) + last_error = None + retrying = False + + while True: + try: + server = self._select_server( + read_pref, session, address=address) + if not server.description.retryable_reads_supported: + retryable = False + with self._slaveok_for_server(read_pref, server, session, + exhaust=exhaust) as (sock_info, + slave_ok): + if retrying and not retryable: + # A retry is not possible because this server does + # not support retryable reads, raise the last error. + raise last_error + return func(session, server, sock_info, slave_ok) + except ServerSelectionTimeoutError: + if retrying: + # The application may think the write was never attempted + # if we raise ServerSelectionTimeoutError on the retry + # attempt. Raise the original exception instead. + raise last_error + # A ServerSelectionTimeoutError error indicates that there may + # be a persistent outage. Attempting to retry in this case will + # most likely be a waste of time. + raise + except ConnectionFailure as exc: + if not retryable or retrying: + raise + retrying = True + last_error = exc + except OperationFailure as exc: + if not retryable or retrying: + raise + if exc.code not in helpers._RETRYABLE_ERROR_CODES: + raise + retrying = True + last_error = exc + def _retryable_write(self, retryable, func, session): """Internal retryable write helper.""" with self._tmp_session(session) as s: @@ -1752,7 +1855,7 @@ class MongoClient(common.BaseObject): cmd = SON([("listDatabases", 1)]) cmd.update(kwargs) admin = self._database_default_options("admin") - res = admin.command(cmd, session=session) + res = admin._retryable_read_command(cmd, session=session) # listDatabases doesn't return a cursor (yet). Fake one. cursor = { "id": 0, diff --git a/pymongo/server.py b/pymongo/server.py index 24e49593c..65d8f57c4 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -14,8 +14,6 @@ """Communicate with one MongoDB server in a topology.""" -import contextlib - from datetime import datetime from pymongo.errors import NotMasterError, OperationFailure @@ -67,7 +65,7 @@ class Server(object): """Check the server's state soon.""" self._monitor.request_check() - def send_message_with_response( + def run_operation_with_response( self, sock_info, operation, @@ -75,17 +73,19 @@ class Server(object): listeners, exhaust, unpack_res): - """Send a message to MongoDB and return a Response object. + """Run a _Query or _GetMore operation and return a Response object. - Can raise ConnectionFailure. + This method is used only to run _Query/_GetMore operations from + cursors. + Can raise ConnectionFailure, OperationFailure, etc. :Parameters: - `operation`: A _Query or _GetMore object. - `set_slave_okay`: Pass to operation.get_message. - `all_credentials`: dict, maps auth source to MongoCredential. - `listeners`: Instance of _EventListeners or None. - - `exhaust` (optional): If True, the socket used stays checked out. - It is returned along with its Pool in the Response. + - `exhaust`: If True, then this is an exhaust cursor operation. + - `unpack_res`: A callable that decodes the wire protocol response. """ duration = None publish = listeners.enabled_for_commands diff --git a/pymongo/server_description.py b/pymongo/server_description.py index d978b4799..04e9dbfe7 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -202,5 +202,10 @@ class ServerDescription(object): self._ls_timeout_minutes is not None and self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)) + @property + def retryable_reads_supported(self): + """Checks if this server supports retryable writes.""" + return self._max_wire_version >= 6 + # For unittesting only. Use under no circumstances! _host_to_round_trip_time = {} diff --git a/test/test_client.py b/test/test_client.py index 2e033081f..aeea75942 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -37,7 +37,7 @@ from pymongo import auth, message from pymongo.common import _UUID_REPRESENTATIONS from pymongo.command_cursor import CommandCursor from pymongo.compression_support import _HAVE_SNAPPY -from pymongo.cursor import CursorType +from pymongo.cursor import Cursor, CursorType from pymongo.database import Database from pymongo.errors import (AutoReconnect, ConfigurationError, @@ -1161,7 +1161,7 @@ class TestClient(IntegrationTest): def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1) + client = rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1188,7 +1188,8 @@ class TestClient(IntegrationTest): # Get a client with one socket so we detect if it's leaked. c = connected(rs_or_single_client(maxPoolSize=1, - waitQueueTimeoutMS=1)) + waitQueueTimeoutMS=1, + retryReads=False)) # Simulate an authenticate() call on a different socket. credentials = auth._build_credentials_tuple( @@ -1220,15 +1221,17 @@ class TestClient(IntegrationTest): def test_stale_getmore(self): # A cursor is created, but its member goes down and is removed from # the topology before the getMore message is sent. Test that - # MongoClient._send_message_with_response handles the error. + # MongoClient._run_operation_with_response handles the error. with self.assertRaises(AutoReconnect): client = rs_client(connect=False, serverSelectionTimeoutMS=100) - client._send_message_with_response( + client._run_operation_with_response( operation=message._GetMore('pymongo_test', 'collection', 101, 1234, client.codec_options, ReadPreference.PRIMARY, None, client, None, None), + unpack_res=Cursor( + client.pymongo_test.collection)._unpack_response, address=('not-a-member', 27017)) def test_heartbeat_frequency_ms(self): @@ -1419,7 +1422,8 @@ class TestExhaustCursor(IntegrationTest): def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(rs_or_single_client(maxPoolSize=1)) + client = connected(rs_or_single_client(maxPoolSize=1, + retryReads=False)) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1576,7 +1580,8 @@ class TestMongoClientFailover(MockClientTest): members=['a:1', 'b:2', 'c:3'], mongoses=[], host='b:2', # Pass a secondary. - replicaSet='rs') + replicaSet='rs', + retryReads=False) wait_until(lambda: len(c.nodes) == 3, 'connect') @@ -1604,7 +1609,8 @@ class TestMongoClientFailover(MockClientTest): mongoses=[], host='a:1', replicaSet='rs', - connect=False) + connect=False, + retryReads=False) # Set host-specific information so we can test whether it is reset. c.set_wire_version_range('a:1', 2, 6) diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index a0e30e9f0..c188992ce 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -325,6 +325,15 @@ class ReadPrefTester(MongoClient): self.record_a_read(sock_info.address) yield sock_info, slave_ok + @contextlib.contextmanager + def _slaveok_for_server(self, read_preference, server, session, + exhaust=False): + context = super(ReadPrefTester, self)._slaveok_for_server( + read_preference, server, session, exhaust=exhaust) + with context as (sock_info, slave_ok): + self.record_a_read(sock_info.address) + yield sock_info, slave_ok + def record_a_read(self, address): server = self._get_topology().select_server_by_address(address, 0) self.has_read_from.add(server) diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py new file mode 100644 index 000000000..90346f763 --- /dev/null +++ b/test/test_retryable_reads.py @@ -0,0 +1,83 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test retryable reads spec.""" + +import os +import sys + +sys.path[0:0] = [""] + +from pymongo.mongo_client import MongoClient + +from test import unittest, client_context, PyMongoTestCase +from test.utils import TestCreator +from test.utils_spec_runner import SpecRunner + + +# Location of JSON test specifications. +_TEST_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'retryable_reads') + + +class TestClientOptions(PyMongoTestCase): + def test_default(self): + client = MongoClient(connect=False) + self.assertEqual(client.retry_reads, True) + + def test_kwargs(self): + client = MongoClient(retryReads=True, connect=False) + self.assertEqual(client.retry_reads, True) + client = MongoClient(retryReads=False, connect=False) + self.assertEqual(client.retry_reads, False) + + def test_uri(self): + client = MongoClient('mongodb://h/?retryReads=true', connect=False) + self.assertEqual(client.retry_reads, True) + client = MongoClient('mongodb://h/?retryReads=false', connect=False) + self.assertEqual(client.retry_reads, False) + + +class TestSpec(SpecRunner): + + @classmethod + @client_context.require_version_min(4, 0) + def setUpClass(cls): + super(TestSpec, cls).setUpClass() + if client_context.is_mongos and client_context.version[:2] <= (4, 0): + raise unittest.SkipTest("4.0 mongos does not support failCommand") + + def maybe_skip_scenario(self, test): + super(TestSpec, self).maybe_skip_scenario(test) + skip_names = [ + 'listCollectionObjects', 'listIndexNames', 'listDatabaseObjects'] + for name in skip_names: + if name.lower() in test['description'].lower(): + raise unittest.SkipTest( + 'PyMongo does not support %s' % (name,)) + + +def create_test(scenario_def, test, name): + @client_context.require_test_commands + def run_scenario(self): + self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = TestCreator(create_test, TestSpec, _TEST_PATH) +test_creator.create_tests() + +if __name__ == "__main__": + unittest.main() diff --git a/test/utils.py b/test/utils.py index 364cab7ad..a197364fc 100644 --- a/test/utils.py +++ b/test/utils.py @@ -141,7 +141,7 @@ class ScenarioDict(dict): def convert(v): if isinstance(v, collections.Mapping): return ScenarioDict(v) - if isinstance(v, py3compat.string_type): + if isinstance(v, (py3compat.string_type, bytes)): return v if isinstance(v, collections.Sequence): return [convert(item) for item in v] @@ -264,8 +264,10 @@ class TestCreator(object): # Construct test from scenario. for test_def in scenario_def['tests']: test_name = 'test_%s_%s_%s' % ( - dirname, test_type, - str(test_def['description'].replace(" ", "_"))) + dirname, + test_type.replace("-", "_").replace('.', '_'), + str(test_def['description'].replace(" ", "_").replace( + '.', '_'))) new_test = self._create_test( scenario_def, test_def, test_name) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 78cea5451..368b87e5f 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -175,6 +175,10 @@ class SpecRunner(IntegrationTest): name = camel_to_snake(operation['name']) if name == 'run_command': name = 'command' + elif name == 'download_by_name': + name = 'open_download_stream_by_name' + elif name == 'download': + name = 'open_download_stream' def parse_options(opts): if 'readPreference' in opts: @@ -197,14 +201,21 @@ class SpecRunner(IntegrationTest): **dict(parse_options(operation['collectionOptions']))) object_name = operation['object'] - objects = { - 'client': database.client, - 'database': database, - 'collection': collection, - 'testRunner': self - } - objects.update(sessions) - obj = objects[object_name] + if object_name == 'gridfsbucket': + # Only create the GridFSBucket when we need it (for the gridfs + # retryable reads tests). + obj = GridFSBucket( + database, bucket_name=collection.name, + disable_md5=True) + else: + objects = { + 'client': database.client, + 'database': database, + 'collection': collection, + 'testRunner': self + } + objects.update(sessions) + obj = objects[object_name] # Combine arguments with options and handle special cases. arguments = operation.get('arguments', {}) @@ -244,6 +255,8 @@ class SpecRunner(IntegrationTest): ordered_command = SON([(operation['command_name'], 1)]) ordered_command.update(arguments['command']) arguments['command'] = ordered_command + elif name == 'open_download_stream' and arg_name == 'id': + arguments['file_id'] = arguments.pop(arg_name) elif name == 'with_transaction' and arg_name == 'callback': callback_ops = arguments[arg_name]['operations'] arguments['callback'] = lambda _: self.run_operations( @@ -261,6 +274,11 @@ class SpecRunner(IntegrationTest): arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY) return out.find() + if name == "map_reduce": + if isinstance(result, dict) and 'results' in result: + return result['results'] + if 'download' in name: + result = Binary(result.read()) if isinstance(result, Cursor) or isinstance(result, CommandCursor): return list(result) @@ -271,7 +289,7 @@ class SpecRunner(IntegrationTest): in_with_transaction=False): for op in ops: expected_result = op.get('result') - if expect_error(expected_result): + if expect_error(op): with self.assertRaises(PyMongoError, msg=op['name']) as context: self.run_operation(sessions, collection, op.copy()) @@ -391,13 +409,23 @@ class SpecRunner(IntegrationTest): database_name = scenario_def['database_name'] write_concern_db = client_context.client.get_database( database_name, write_concern=WriteConcern(w='majority')) - collection_name = scenario_def['collection_name'] - write_concern_coll = write_concern_db[collection_name] - write_concern_coll.drop() - write_concern_db.create_collection(collection_name) - if scenario_def['data']: - # Load data. - write_concern_coll.insert_many(scenario_def['data']) + if 'bucket_name' in scenario_def: + # Create a bucket for the retryable reads GridFS tests. + collection_name = scenario_def['bucket_name'] + client_context.client.drop_database(database_name) + if scenario_def['data']: + data = scenario_def['data'] + # Load data. + write_concern_db['fs.chunks'].insert_many(data['fs.chunks']) + write_concern_db['fs.files'].insert_many(data['fs.files']) + else: + collection_name = scenario_def['collection_name'] + write_concern_coll = write_concern_db[collection_name] + write_concern_coll.drop() + write_concern_db.create_collection(collection_name) + if scenario_def['data']: + # Load data. + write_concern_coll.insert_many(scenario_def['data']) # SPEC-1245 workaround StaleDbVersion on distinct for c in self.mongos_clients: @@ -473,6 +501,13 @@ class SpecRunner(IntegrationTest): self.assertEqual(list(primary_coll.find()), expected_c['data']) +def expect_any_error(op): + if isinstance(op, dict): + return op.get('error') + + return False + + def expect_error_message(expected_result): if isinstance(expected_result, dict): return expected_result['errorContains'] @@ -501,8 +536,10 @@ def expect_error_labels_omit(expected_result): return False -def expect_error(expected_result): - return (expect_error_message(expected_result) +def expect_error(op): + expected_result = op.get('result') + return (expect_any_error(op) or + expect_error_message(expected_result) or expect_error_code(expected_result) or expect_error_labels_contain(expected_result) or expect_error_labels_omit(expected_result))