diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 4a796154d..166b39a23 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -844,7 +844,7 @@ class ClientSession: ) -> Dict[str, Any]: return self._finish_transaction(conn, command_name) - return self._client._retry_internal(True, func, self, None) + return self._client._retry_internal(func, self, None, retryable=True) def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]: self._transaction.attempt += 1 diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 63fdab04a..1bfd1d4b6 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1405,92 +1405,31 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): Re-raises any exception thrown by func(). """ - retryable = bool( - retryable and self.options.retry_writes and session and not session.in_transaction - ) - return self._retry_internal(retryable, func, session, bulk) + return self._retry_internal(func=func, session=session, bulk=bulk, retryable=retryable) @_csot.apply def _retry_internal( self, - retryable: bool, func: Callable[[Optional[ClientSession], Connection, bool], T], session: Optional[ClientSession], bulk: Optional[_Bulk], + is_read: bool = False, + address: Optional[_Address] = None, + read_pref: Optional[_ServerMode] = None, + retryable: Optional[bool] = None, ) -> T: - """Internal retryable write helper.""" - max_wire_version = 0 - last_error: Optional[Exception] = None - retrying = False - multiple_retries = _csot.get_timeout() is not None + """Internal retryable helper for reads and writes.""" + return _ClientConnectionRetryable( + mongo_client=self, + func=func, + bulk=bulk, + is_read=is_read, + session=session, + read_pref=read_pref, + address=address, + retryable=retryable, + ).run() - def is_retrying() -> bool: - return bulk.retrying if bulk else retrying - - # Increment the transaction id up front to ensure any retry attempt - # will use the proper txnNumber, even if server or socket selection - # fails before the command can be sent. - if retryable and session and not session.in_transaction: - session._start_retryable_write() - if bulk: - bulk.started_retryable_write = True - - while True: - if is_retrying(): - remaining = _csot.remaining() - if remaining is not None and remaining <= 0: - assert last_error is not None - raise last_error - try: - server = self._select_server(writable_server_selector, session) - supports_session = ( - session is not None and server.description.retryable_writes_supported - ) - with self._checkout(server, session) as conn: - max_wire_version = conn.max_wire_version - if retryable and not supports_session: - if is_retrying(): - # A retry is not possible because this server does - # not support sessions raise the last error. - assert last_error is not None - raise last_error - retryable = False - return func(session, conn, retryable) - except ServerSelectionTimeoutError: - if is_retrying(): - # The application may think the write was never attempted - # if we raise ServerSelectionTimeoutError on the retry - # attempt. Raise the original exception instead. - assert last_error is not None - raise last_error - # A ServerSelectionTimeoutError error indicates that there may - # be a persistent outage. Attempting to retry in this case will - # most likely be a waste of time. - raise - except PyMongoError as exc: - if not retryable: - raise - assert session - # Add the RetryableWriteError label, if applicable. - _add_retryable_write_error(exc, max_wire_version) - retryable_error = exc.has_error_label("RetryableWriteError") - if retryable_error: - session._unpin() - if not retryable_error or (is_retrying() and not multiple_retries): - if exc.has_error_label("NoWritesPerformed") and last_error: - raise last_error from exc - else: - raise - if bulk: - bulk.retrying = True - else: - retrying = True - if not exc.has_error_label("NoWritesPerformed"): - last_error = exc - if last_error is None: - last_error = exc - - @_csot.apply def _retryable_read( self, func: Callable[[Optional[ClientSession], Server, Connection, _ServerMode], T], @@ -1506,54 +1445,15 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): Re-raises any exception thrown by func(). """ - retryable = ( - retryable and self.options.retry_reads and not (session and session.in_transaction) + return self._retry_internal( + func, + session, + None, + is_read=True, + address=address, + read_pref=read_pref, + retryable=retryable, ) - last_error: Optional[Exception] = None - retrying = False - multiple_retries = _csot.get_timeout() is not None - - while True: - if retrying: - remaining = _csot.remaining() - if remaining is not None and remaining <= 0: - assert last_error is not None - raise last_error - try: - server = self._select_server(read_pref, session, address=address) - with self._conn_from_server(read_pref, server, session) as ( - conn, - read_pref, - ): - if retrying and not retryable: - # A retry is not possible because this server does - # not support retryable reads, raise the last error. - assert last_error is not None - raise last_error - return func(session, server, conn, read_pref) - except ServerSelectionTimeoutError: - if retrying: - # The application may think the write was never attempted - # if we raise ServerSelectionTimeoutError on the retry - # attempt. Raise the original exception instead. - assert last_error is not None - raise last_error - # A ServerSelectionTimeoutError error indicates that there may - # be a persistent outage. Attempting to retry in this case will - # most likely be a waste of time. - raise - except ConnectionFailure as exc: - if not retryable or (retrying and not multiple_retries): - raise - retrying = True - last_error = exc - except OperationFailure as exc: - if not retryable or (retrying and not multiple_retries): - raise - if exc.code not in helpers._RETRYABLE_ERROR_CODES: - raise - retrying = True - last_error = exc def _retryable_write( self, @@ -2304,6 +2204,166 @@ class _MongoClientErrorHandler: return self.handle(exc_type, exc_val) +class _ClientConnectionRetryable: + """Responsible for executing retryable connections on read or write operations""" + + def __init__( + self, + mongo_client: MongoClient, + func: Callable[[Optional[ClientSession], Connection, bool], T], + bulk: Optional[_Bulk], + is_read: bool = False, + session: Optional[ClientSession] = None, + read_pref: Optional[_ServerMode] = None, + address: Optional[_Address] = None, + retryable: Optional[bool] = None, + ): + self._last_error: Optional[Exception] = None + self._retrying = False + self._multiple_retries = _csot.get_timeout() is not None + self._client = mongo_client + self._retry_operation = ( + mongo_client.options.retry_reads if is_read else mongo_client.options.retry_writes + ) + + self._func = func + self._bulk = bulk + self._session = session + self._is_read = is_read + self._retryable = retryable and self._retry_operation and not self._in_transaction() + self._read_pref = read_pref + self._server_selector = read_pref if is_read else writable_server_selector + self._address = address + self._server: Server = None + + def run(self) -> T: + """Runs the supplied func() and attempts a retry + + :raises self._last_error: Exception raised + :return: Result of the func() call + :rtype: T + """ + # Increment the transaction id up front to ensure any retry attempt + # will use the proper txnNumber, even if server or socket selection + # fails before the command can be sent. + if not self._in_transaction() and self._retryable and not self._is_read: + self._session._start_retryable_write() + if self._bulk: + self._bulk.started_retryable_write = True + + while True: + self._check_last_error(check_csot=True) + try: + self._server = self._client._select_server( + self._server_selector, self._session, address=self._address + ) + return self._read() if self._is_read else self._write() + except ServerSelectionTimeoutError: + # The application may think the write was never attempted + # if we raise ServerSelectionTimeoutError on the retry + # attempt. Raise the original exception instead. + self._check_last_error() + # A ServerSelectionTimeoutError error indicates that there may + # be a persistent outage. Attempting to retry in this case will + # most likely be a waste of time. + raise + except PyMongoError as exc: + # Execute specialized catch on read + if self._is_read: + if isinstance(exc, (ConnectionFailure, OperationFailure)): + exc_code = getattr(exc, "code", None) + if self._is_not_retry_eligible() or ( + exc_code and exc_code not in helpers._RETRYABLE_ERROR_CODES + ): + raise + self._retrying = True + self._last_error = exc + else: + raise + + # Assumed write operation henceforth + if not self._retryable: + raise + assert self._session + if exc.has_error_label("RetryableWriteError"): + self._session._unpin() + elif self._is_retrying() and not self._multiple_retries: + if exc.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + if self._bulk: + self._bulk.retrying = True + else: + self._retrying = True + if not exc.has_error_label("NoWritesPerformed"): + self._last_error = exc + if self._last_error is None: + self._last_error = exc + + def _is_not_retry_eligible(self) -> bool: + """Checks if the exchange is not eligible for retry""" + return not self._retryable or (self._retrying and not self._multiple_retries) + + def _is_retrying(self) -> bool: + """Checks if the exchange is currently undergoing a retry""" + return self._bulk.retrying if self._bulk else self._retrying + + def _in_transaction(self): + """Checks if the ongoing session is in a transaction""" + return self._session and self._session.in_transaction + + def _check_last_error(self, check_csot: bool = False): + """Checks if the ongoing client exchange experienced a exception during retry + + :param check_csot: Check csot timeout, defaults to False + """ + if self._is_retrying(): + remaining = _csot.remaining() + if not check_csot or (remaining is not None and remaining <= 0): + assert self._last_error is not None + raise self._last_error + + def _write(self) -> T: + """Wrapper method for write-type retryable client executions + + :return: output for func()'s call + """ + try: + max_wire_version = 0 + supports_session = ( + self._session is not None and self._server.description.retryable_writes_supported + ) + with self._client._checkout(self._server, self._session) as conn: + max_wire_version = conn.max_wire_version + if self._retryable and not supports_session: + # A retry is not possible because this server does + # not support sessions raise the last error. + self._check_last_error() + self._retryable = False + return self._func(self._session, conn, self._retryable) + except PyMongoError as exc: + if not self._retryable: + raise + assert self._session + # Add the RetryableWriteError label, if applicable. + _add_retryable_write_error(exc, max_wire_version) + raise + + def _read(self) -> T: + """Wrapper method for read-type retryable client executions + + :return: output for func()'s call + """ + with self._client._conn_from_server(self._read_pref, self._server, self._session) as ( + conn, + read_pref, + ): + if self._retrying and not self._retryable: + self._check_last_error() + return self._func(self._session, self._server, conn, read_pref) + + def _after_fork_child() -> None: """Releases the locks in child process and resets the topologies in all MongoClients.