PYTHON-3739 Refactor retryable reads and writes logic to avoid duplication (#1344)
* first draft commit; consolidated _retryable_(read|write) to call _retry_internal * removed extra self usage * formatting * swapped last_error usage * switched to using more objective syntax * black formatter * don't use conn_from_server * changed variable naming is_write -> is_read; consolidated errorhandling; revisited is_retrying * added an explicit if not self._is_read catch * switched self._in_transaction to be self._not_in_transaction * fixed logic on checking if a read/write was in transaction and added commentary * fixed encryption-based error getting retried * separated server selection as the exception raised gets handled differently in each caller * do not mutate 'retryable' within the class instantiation * centralized usage of _retryable_write to avoid _retry_with_session used outwardly * added docstrings to our _retryable_(read|write) operations * refactored docstrings to align with rest of the file * clearer docstrings and function calls
This commit is contained in:
parent
52112a2220
commit
3e1a4ab56e
@ -443,8 +443,7 @@ class _Bulk:
|
||||
)
|
||||
|
||||
client = self.collection.database.client
|
||||
with client._tmp_session(session) as s:
|
||||
client._retry_with_session(self.is_retryable, retryable_bulk, s, self)
|
||||
client._retryable_write(self.is_retryable, retryable_bulk, session, bulk=self)
|
||||
|
||||
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
|
||||
_raise_bulk_write_error(full_result)
|
||||
|
||||
@ -844,7 +844,7 @@ class ClientSession:
|
||||
) -> Dict[str, Any]:
|
||||
return self._finish_transaction(conn, command_name)
|
||||
|
||||
return self._client._retry_internal(True, func, self, None)
|
||||
return self._client._retry_internal(func, self, None, retryable=True)
|
||||
|
||||
def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]:
|
||||
self._transaction.attempt += 1
|
||||
|
||||
@ -137,6 +137,9 @@ if TYPE_CHECKING:
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T]
|
||||
_ReadCall = Callable[[Optional["ClientSession"], "Server", "Connection", _ServerMode], T]
|
||||
|
||||
|
||||
class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
@ -1396,176 +1399,125 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
def _retry_with_session(
|
||||
self,
|
||||
retryable: bool,
|
||||
func: Callable[[Optional[ClientSession], Connection, bool], T],
|
||||
func: _WriteCall[T],
|
||||
session: Optional[ClientSession],
|
||||
bulk: Optional[_Bulk],
|
||||
) -> T:
|
||||
"""Execute an operation with at most one consecutive retries
|
||||
|
||||
Returns func()'s return value on success. On error retries the same
|
||||
command once.
|
||||
command.
|
||||
|
||||
Re-raises any exception thrown by func().
|
||||
"""
|
||||
# Ensure that the options supports retry_writes and there is a valid session not in
|
||||
# transaction, otherwise, we will not support retry behavior for this txn.
|
||||
retryable = bool(
|
||||
retryable and self.options.retry_writes and session and not session.in_transaction
|
||||
)
|
||||
return self._retry_internal(retryable, func, session, bulk)
|
||||
return self._retry_internal(
|
||||
func=func,
|
||||
session=session,
|
||||
bulk=bulk,
|
||||
retryable=retryable,
|
||||
)
|
||||
|
||||
@_csot.apply
|
||||
def _retry_internal(
|
||||
self,
|
||||
retryable: bool,
|
||||
func: Callable[[Optional[ClientSession], Connection, bool], T],
|
||||
func: _WriteCall[T] | _ReadCall[T],
|
||||
session: Optional[ClientSession],
|
||||
bulk: Optional[_Bulk],
|
||||
is_read: bool = False,
|
||||
address: Optional[_Address] = None,
|
||||
read_pref: Optional[_ServerMode] = None,
|
||||
retryable: bool = False,
|
||||
) -> T:
|
||||
"""Internal retryable write helper."""
|
||||
max_wire_version = 0
|
||||
last_error: Optional[Exception] = None
|
||||
retrying = False
|
||||
multiple_retries = _csot.get_timeout() is not None
|
||||
"""Internal retryable helper for all client transactions.
|
||||
|
||||
def is_retrying() -> bool:
|
||||
return bulk.retrying if bulk else retrying
|
||||
:Parameters:
|
||||
- `func`: Callback function we want to retry
|
||||
- `session`: Client Session on which the transaction should occur
|
||||
- `bulk`: Abstraction to handle bulk write operations
|
||||
- `is_read`: If this is an exclusive read transaction, defaults to False
|
||||
- `address`: Server Address, defaults to None
|
||||
- `read_pref`: Topology of read operation, defaults to None
|
||||
- `retryable`: If the operation should be retried once, defaults to None
|
||||
|
||||
# Increment the transaction id up front to ensure any retry attempt
|
||||
# will use the proper txnNumber, even if server or socket selection
|
||||
# fails before the command can be sent.
|
||||
if retryable and session and not session.in_transaction:
|
||||
session._start_retryable_write()
|
||||
if bulk:
|
||||
bulk.started_retryable_write = True
|
||||
:Returns:
|
||||
Output of the calling func()
|
||||
"""
|
||||
return _ClientConnectionRetryable(
|
||||
mongo_client=self,
|
||||
func=func,
|
||||
bulk=bulk,
|
||||
is_read=is_read,
|
||||
session=session,
|
||||
read_pref=read_pref,
|
||||
address=address,
|
||||
retryable=retryable,
|
||||
).run()
|
||||
|
||||
while True:
|
||||
if is_retrying():
|
||||
remaining = _csot.remaining()
|
||||
if remaining is not None and remaining <= 0:
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
try:
|
||||
server = self._select_server(writable_server_selector, session)
|
||||
supports_session = (
|
||||
session is not None and server.description.retryable_writes_supported
|
||||
)
|
||||
with self._checkout(server, session) as conn:
|
||||
max_wire_version = conn.max_wire_version
|
||||
if retryable and not supports_session:
|
||||
if is_retrying():
|
||||
# A retry is not possible because this server does
|
||||
# not support sessions raise the last error.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
retryable = False
|
||||
return func(session, conn, retryable)
|
||||
except ServerSelectionTimeoutError:
|
||||
if is_retrying():
|
||||
# The application may think the write was never attempted
|
||||
# if we raise ServerSelectionTimeoutError on the retry
|
||||
# attempt. Raise the original exception instead.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
# A ServerSelectionTimeoutError error indicates that there may
|
||||
# be a persistent outage. Attempting to retry in this case will
|
||||
# most likely be a waste of time.
|
||||
raise
|
||||
except PyMongoError as exc:
|
||||
if not retryable:
|
||||
raise
|
||||
assert session
|
||||
# Add the RetryableWriteError label, if applicable.
|
||||
_add_retryable_write_error(exc, max_wire_version)
|
||||
retryable_error = exc.has_error_label("RetryableWriteError")
|
||||
if retryable_error:
|
||||
session._unpin()
|
||||
if not retryable_error or (is_retrying() and not multiple_retries):
|
||||
if exc.has_error_label("NoWritesPerformed") and last_error:
|
||||
raise last_error from exc
|
||||
else:
|
||||
raise
|
||||
if bulk:
|
||||
bulk.retrying = True
|
||||
else:
|
||||
retrying = True
|
||||
if not exc.has_error_label("NoWritesPerformed"):
|
||||
last_error = exc
|
||||
if last_error is None:
|
||||
last_error = exc
|
||||
|
||||
@_csot.apply
|
||||
def _retryable_read(
|
||||
self,
|
||||
func: Callable[[Optional[ClientSession], Server, Connection, _ServerMode], T],
|
||||
func: _ReadCall[T],
|
||||
read_pref: _ServerMode,
|
||||
session: Optional[ClientSession],
|
||||
address: Optional[_Address] = None,
|
||||
retryable: bool = True,
|
||||
) -> T:
|
||||
"""Execute an operation with at most one consecutive retries
|
||||
"""Execute an operation with consecutive retries if possible
|
||||
|
||||
Returns func()'s return value on success. On error retries the same
|
||||
command once.
|
||||
command.
|
||||
|
||||
Re-raises any exception thrown by func().
|
||||
|
||||
- `func`: Read call we want to execute
|
||||
- `read_pref`: Desired topology of read operation
|
||||
- `session`: Client session we should use to execute operation
|
||||
- `address`: Optional address when sending a message, defaults to None
|
||||
- `retryable`: if we should attempt retries
|
||||
(may not always be supported even if supplied), defaults to False
|
||||
"""
|
||||
retryable = (
|
||||
|
||||
# Ensure that the client supports retrying on reads and there is no session in
|
||||
# transaction, otherwise, we will not support retry behavior for this call.
|
||||
retryable = bool(
|
||||
retryable and self.options.retry_reads and not (session and session.in_transaction)
|
||||
)
|
||||
last_error: Optional[Exception] = None
|
||||
retrying = False
|
||||
multiple_retries = _csot.get_timeout() is not None
|
||||
|
||||
while True:
|
||||
if retrying:
|
||||
remaining = _csot.remaining()
|
||||
if remaining is not None and remaining <= 0:
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
try:
|
||||
server = self._select_server(read_pref, session, address=address)
|
||||
with self._conn_from_server(read_pref, server, session) as (
|
||||
conn,
|
||||
read_pref,
|
||||
):
|
||||
if retrying and not retryable:
|
||||
# A retry is not possible because this server does
|
||||
# not support retryable reads, raise the last error.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
return func(session, server, conn, read_pref)
|
||||
except ServerSelectionTimeoutError:
|
||||
if retrying:
|
||||
# The application may think the write was never attempted
|
||||
# if we raise ServerSelectionTimeoutError on the retry
|
||||
# attempt. Raise the original exception instead.
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
# A ServerSelectionTimeoutError error indicates that there may
|
||||
# be a persistent outage. Attempting to retry in this case will
|
||||
# most likely be a waste of time.
|
||||
raise
|
||||
except ConnectionFailure as exc:
|
||||
if not retryable or (retrying and not multiple_retries):
|
||||
raise
|
||||
retrying = True
|
||||
last_error = exc
|
||||
except OperationFailure as exc:
|
||||
if not retryable or (retrying and not multiple_retries):
|
||||
raise
|
||||
if exc.code not in helpers._RETRYABLE_ERROR_CODES:
|
||||
raise
|
||||
retrying = True
|
||||
last_error = exc
|
||||
return self._retry_internal(
|
||||
func,
|
||||
session,
|
||||
None,
|
||||
is_read=True,
|
||||
address=address,
|
||||
read_pref=read_pref,
|
||||
retryable=retryable,
|
||||
)
|
||||
|
||||
def _retryable_write(
|
||||
self,
|
||||
retryable: bool,
|
||||
func: Callable[[Optional[ClientSession], Connection, bool], T],
|
||||
func: _WriteCall[T],
|
||||
session: Optional[ClientSession],
|
||||
bulk: Optional[_Bulk] = None,
|
||||
) -> T:
|
||||
"""Internal retryable write helper."""
|
||||
"""Execute an operation with consecutive retries if possible
|
||||
|
||||
Returns func()'s return value on success. On error retries the same
|
||||
command.
|
||||
|
||||
Re-raises any exception thrown by func().
|
||||
|
||||
:Parameters:
|
||||
- `retryable`: if we should attempt retries (may not always be supported)
|
||||
- `func`: write call we want to execute during a session
|
||||
- `session`: Client session we will use to execute write operation
|
||||
- `bulk`: bulk abstraction to execute operations in bulk, defaults to None
|
||||
"""
|
||||
with self._tmp_session(session) as s:
|
||||
return self._retry_with_session(retryable, func, s, None)
|
||||
return self._retry_with_session(retryable, func, s, bulk)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
@ -2307,6 +2259,189 @@ class _MongoClientErrorHandler:
|
||||
return self.handle(exc_type, exc_val)
|
||||
|
||||
|
||||
class _ClientConnectionRetryable(Generic[T]):
|
||||
"""Responsible for executing retryable connections on read or write operations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mongo_client: MongoClient,
|
||||
func: _WriteCall[T] | _ReadCall[T],
|
||||
bulk: Optional[_Bulk],
|
||||
is_read: bool = False,
|
||||
session: Optional[ClientSession] = None,
|
||||
read_pref: Optional[_ServerMode] = None,
|
||||
address: Optional[_Address] = None,
|
||||
retryable: bool = False,
|
||||
):
|
||||
self._last_error: Optional[Exception] = None
|
||||
self._retrying = False
|
||||
self._multiple_retries = _csot.get_timeout() is not None
|
||||
self._client = mongo_client
|
||||
|
||||
self._func = func
|
||||
self._bulk = bulk
|
||||
self._session = session
|
||||
self._is_read = is_read
|
||||
self._retryable = retryable
|
||||
self._read_pref = read_pref
|
||||
self._server_selector: Callable[[Selection], Selection] = (
|
||||
read_pref if is_read else writable_server_selector # type: ignore
|
||||
)
|
||||
self._address = address
|
||||
self._server: Server = None # type: ignore
|
||||
|
||||
def run(self) -> T:
|
||||
"""Runs the supplied func() and attempts a retry
|
||||
|
||||
:Raises:
|
||||
self._last_error: Last exception raised
|
||||
|
||||
:Returns:
|
||||
Result of the func() call
|
||||
"""
|
||||
# Increment the transaction id up front to ensure any retry attempt
|
||||
# will use the proper txnNumber, even if server or socket selection
|
||||
# fails before the command can be sent.
|
||||
if self._is_session_state_retryable() and self._retryable and not self._is_read:
|
||||
self._session._start_retryable_write() # type: ignore
|
||||
if self._bulk:
|
||||
self._bulk.started_retryable_write = True
|
||||
|
||||
while True:
|
||||
self._check_last_error(check_csot=True)
|
||||
try:
|
||||
return self._read() if self._is_read else self._write()
|
||||
except ServerSelectionTimeoutError:
|
||||
# The application may think the write was never attempted
|
||||
# if we raise ServerSelectionTimeoutError on the retry
|
||||
# attempt. Raise the original exception instead.
|
||||
self._check_last_error()
|
||||
# A ServerSelectionTimeoutError error indicates that there may
|
||||
# be a persistent outage. Attempting to retry in this case will
|
||||
# most likely be a waste of time.
|
||||
raise
|
||||
except PyMongoError as exc:
|
||||
# Execute specialized catch on read
|
||||
if self._is_read:
|
||||
if isinstance(exc, (ConnectionFailure, OperationFailure)):
|
||||
# ConnectionFailures do not supply a code property
|
||||
exc_code = getattr(exc, "code", None)
|
||||
if self._is_not_eligible_for_retry() or (
|
||||
exc_code and exc_code not in helpers._RETRYABLE_ERROR_CODES
|
||||
):
|
||||
raise
|
||||
self._retrying = True
|
||||
self._last_error = exc
|
||||
else:
|
||||
raise
|
||||
|
||||
# Specialized catch on write operation
|
||||
if not self._is_read:
|
||||
if not self._retryable:
|
||||
raise
|
||||
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
|
||||
if retryable_write_error_exc:
|
||||
assert self._session
|
||||
self._session._unpin()
|
||||
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
|
||||
if exc.has_error_label("NoWritesPerformed") and self._last_error:
|
||||
raise self._last_error from exc
|
||||
else:
|
||||
raise
|
||||
if self._bulk:
|
||||
self._bulk.retrying = True
|
||||
else:
|
||||
self._retrying = True
|
||||
if not exc.has_error_label("NoWritesPerformed"):
|
||||
self._last_error = exc
|
||||
if self._last_error is None:
|
||||
self._last_error = exc
|
||||
|
||||
def _is_not_eligible_for_retry(self) -> bool:
|
||||
"""Checks if the exchange is not eligible for retry"""
|
||||
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
|
||||
|
||||
def _is_retrying(self) -> bool:
|
||||
"""Checks if the exchange is currently undergoing a retry"""
|
||||
return self._bulk.retrying if self._bulk else self._retrying
|
||||
|
||||
def _is_session_state_retryable(self) -> bool:
|
||||
"""Checks if provided session is eligible for retry
|
||||
|
||||
reads: Make sure there is no ongoing transaction (if provided a session)
|
||||
writes: Make sure there is a session without an active transaction
|
||||
"""
|
||||
if self._is_read:
|
||||
return not (self._session and self._session.in_transaction)
|
||||
return bool(self._session and not self._session.in_transaction)
|
||||
|
||||
def _check_last_error(self, check_csot: bool = False) -> None:
|
||||
"""Checks if the ongoing client exchange experienced a exception previously.
|
||||
If so, raise last error
|
||||
|
||||
:Parameters:
|
||||
- `check_csot`: Checks CSOT to ensure we are retrying with time remaining defaults to False
|
||||
"""
|
||||
if self._is_retrying():
|
||||
remaining = _csot.remaining()
|
||||
if not check_csot or (remaining is not None and remaining <= 0):
|
||||
assert self._last_error is not None
|
||||
raise self._last_error
|
||||
|
||||
def _get_server(self) -> Server:
|
||||
"""Retrieves a server object based on provided object context
|
||||
|
||||
:Returns:
|
||||
Abstraction to connect to server
|
||||
"""
|
||||
return self._client._select_server(
|
||||
self._server_selector, self._session, address=self._address
|
||||
)
|
||||
|
||||
def _write(self) -> T:
|
||||
"""Wrapper method for write-type retryable client executions
|
||||
|
||||
:Returns:
|
||||
Output for func()'s call
|
||||
"""
|
||||
try:
|
||||
max_wire_version = 0
|
||||
self._server = self._get_server()
|
||||
supports_session = (
|
||||
self._session is not None and self._server.description.retryable_writes_supported
|
||||
)
|
||||
with self._client._checkout(self._server, self._session) as conn:
|
||||
max_wire_version = conn.max_wire_version
|
||||
if self._retryable and not supports_session:
|
||||
# A retry is not possible because this server does
|
||||
# not support sessions raise the last error.
|
||||
self._check_last_error()
|
||||
self._retryable = False
|
||||
return self._func(self._session, conn, self._retryable) # type: ignore
|
||||
except PyMongoError as exc:
|
||||
if not self._retryable:
|
||||
raise
|
||||
# Add the RetryableWriteError label, if applicable.
|
||||
_add_retryable_write_error(exc, max_wire_version)
|
||||
raise
|
||||
|
||||
def _read(self) -> T:
|
||||
"""Wrapper method for read-type retryable client executions
|
||||
|
||||
:Returns:
|
||||
Output for func()'s call
|
||||
"""
|
||||
self._server = self._get_server()
|
||||
assert self._read_pref is not None, "Read Preference required on read calls"
|
||||
with self._client._conn_from_server(self._read_pref, self._server, self._session) as (
|
||||
conn,
|
||||
read_pref,
|
||||
):
|
||||
if self._retrying and not self._retryable:
|
||||
self._check_last_error()
|
||||
return self._func(self._session, self._server, conn, read_pref) # type: ignore
|
||||
|
||||
|
||||
def _after_fork_child() -> None:
|
||||
"""Releases the locks in child process and resets the
|
||||
topologies in all MongoClients.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user