PYTHON-3739 Refactor retryable reads and writes logic to avoid duplication (#1344)

* first draft commit; consolidated _retryable_(read|write) to call _retry_internal

* removed extra self usage

* formatting

* swapped last_error usage

* switched to using more objective syntax

* black formatter

* don't use conn_from_server

* changed variable naming is_write -> is_read; consolidated error handling; revisited is_retrying

* added an explicit if not self._is_read catch

* switched self._in_transaction to be self._not_in_transaction

* fixed logic on checking if a read/write was in transaction and added commentary

* fixed encryption-based error getting retried

* separated server selection as the exception raised gets handled differently in each caller

* do not mutate 'retryable' within the class instantiation

* centralized usage of _retryable_write to avoid _retry_with_session used outwardly

* added docstrings to our _retryable_(read|write) operations

* refactored docstrings to align with rest of the file

* clearer docstrings and function calls
This commit is contained in:
Jib 2023-08-29 10:13:38 -04:00 committed by GitHub
parent 52112a2220
commit 3e1a4ab56e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 264 additions and 130 deletions

View File

@ -443,8 +443,7 @@ class _Bulk:
)
client = self.collection.database.client
with client._tmp_session(session) as s:
client._retry_with_session(self.is_retryable, retryable_bulk, s, self)
client._retryable_write(self.is_retryable, retryable_bulk, session, bulk=self)
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
_raise_bulk_write_error(full_result)

View File

@ -844,7 +844,7 @@ class ClientSession:
) -> Dict[str, Any]:
return self._finish_transaction(conn, command_name)
return self._client._retry_internal(True, func, self, None)
return self._client._retry_internal(func, self, None, retryable=True)
def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]:
self._transaction.attempt += 1

View File

@ -137,6 +137,9 @@ if TYPE_CHECKING:
T = TypeVar("T")
_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T]
_ReadCall = Callable[[Optional["ClientSession"], "Server", "Connection", _ServerMode], T]
class MongoClient(common.BaseObject, Generic[_DocumentType]):
"""
@ -1396,176 +1399,125 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
def _retry_with_session(
self,
retryable: bool,
func: Callable[[Optional[ClientSession], Connection, bool], T],
func: _WriteCall[T],
session: Optional[ClientSession],
bulk: Optional[_Bulk],
) -> T:
"""Execute an operation with at most one consecutive retries
Returns func()'s return value on success. On error retries the same
command once.
command.
Re-raises any exception thrown by func().
"""
# Ensure that the options supports retry_writes and there is a valid session not in
# transaction, otherwise, we will not support retry behavior for this txn.
retryable = bool(
retryable and self.options.retry_writes and session and not session.in_transaction
)
return self._retry_internal(retryable, func, session, bulk)
return self._retry_internal(
func=func,
session=session,
bulk=bulk,
retryable=retryable,
)
@_csot.apply
def _retry_internal(
self,
retryable: bool,
func: Callable[[Optional[ClientSession], Connection, bool], T],
func: _WriteCall[T] | _ReadCall[T],
session: Optional[ClientSession],
bulk: Optional[_Bulk],
is_read: bool = False,
address: Optional[_Address] = None,
read_pref: Optional[_ServerMode] = None,
retryable: bool = False,
) -> T:
"""Internal retryable write helper."""
max_wire_version = 0
last_error: Optional[Exception] = None
retrying = False
multiple_retries = _csot.get_timeout() is not None
"""Internal retryable helper for all client transactions.
def is_retrying() -> bool:
return bulk.retrying if bulk else retrying
:Parameters:
- `func`: Callback function we want to retry
- `session`: Client Session on which the transaction should occur
- `bulk`: Abstraction to handle bulk write operations
- `is_read`: If this is an exclusive read transaction, defaults to False
- `address`: Server Address, defaults to None
- `read_pref`: Topology of read operation, defaults to None
- `retryable`: If the operation should be retried once, defaults to False
# Increment the transaction id up front to ensure any retry attempt
# will use the proper txnNumber, even if server or socket selection
# fails before the command can be sent.
if retryable and session and not session.in_transaction:
session._start_retryable_write()
if bulk:
bulk.started_retryable_write = True
:Returns:
Output of the calling func()
"""
return _ClientConnectionRetryable(
mongo_client=self,
func=func,
bulk=bulk,
is_read=is_read,
session=session,
read_pref=read_pref,
address=address,
retryable=retryable,
).run()
while True:
if is_retrying():
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
assert last_error is not None
raise last_error
try:
server = self._select_server(writable_server_selector, session)
supports_session = (
session is not None and server.description.retryable_writes_supported
)
with self._checkout(server, session) as conn:
max_wire_version = conn.max_wire_version
if retryable and not supports_session:
if is_retrying():
# A retry is not possible because this server does
# not support sessions raise the last error.
assert last_error is not None
raise last_error
retryable = False
return func(session, conn, retryable)
except ServerSelectionTimeoutError:
if is_retrying():
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
# attempt. Raise the original exception instead.
assert last_error is not None
raise last_error
# A ServerSelectionTimeoutError error indicates that there may
# be a persistent outage. Attempting to retry in this case will
# most likely be a waste of time.
raise
except PyMongoError as exc:
if not retryable:
raise
assert session
# Add the RetryableWriteError label, if applicable.
_add_retryable_write_error(exc, max_wire_version)
retryable_error = exc.has_error_label("RetryableWriteError")
if retryable_error:
session._unpin()
if not retryable_error or (is_retrying() and not multiple_retries):
if exc.has_error_label("NoWritesPerformed") and last_error:
raise last_error from exc
else:
raise
if bulk:
bulk.retrying = True
else:
retrying = True
if not exc.has_error_label("NoWritesPerformed"):
last_error = exc
if last_error is None:
last_error = exc
@_csot.apply
def _retryable_read(
self,
func: Callable[[Optional[ClientSession], Server, Connection, _ServerMode], T],
func: _ReadCall[T],
read_pref: _ServerMode,
session: Optional[ClientSession],
address: Optional[_Address] = None,
retryable: bool = True,
) -> T:
"""Execute an operation with at most one consecutive retries
"""Execute an operation with consecutive retries if possible
Returns func()'s return value on success. On error retries the same
command once.
command.
Re-raises any exception thrown by func().
- `func`: Read call we want to execute
- `read_pref`: Desired topology of read operation
- `session`: Client session we should use to execute operation
- `address`: Optional address when sending a message, defaults to None
- `retryable`: if we should attempt retries
(may not always be supported even if supplied), defaults to True
"""
retryable = (
# Ensure that the client supports retrying on reads and there is no session in
# transaction, otherwise, we will not support retry behavior for this call.
retryable = bool(
retryable and self.options.retry_reads and not (session and session.in_transaction)
)
last_error: Optional[Exception] = None
retrying = False
multiple_retries = _csot.get_timeout() is not None
while True:
if retrying:
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
assert last_error is not None
raise last_error
try:
server = self._select_server(read_pref, session, address=address)
with self._conn_from_server(read_pref, server, session) as (
conn,
read_pref,
):
if retrying and not retryable:
# A retry is not possible because this server does
# not support retryable reads, raise the last error.
assert last_error is not None
raise last_error
return func(session, server, conn, read_pref)
except ServerSelectionTimeoutError:
if retrying:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
# attempt. Raise the original exception instead.
assert last_error is not None
raise last_error
# A ServerSelectionTimeoutError error indicates that there may
# be a persistent outage. Attempting to retry in this case will
# most likely be a waste of time.
raise
except ConnectionFailure as exc:
if not retryable or (retrying and not multiple_retries):
raise
retrying = True
last_error = exc
except OperationFailure as exc:
if not retryable or (retrying and not multiple_retries):
raise
if exc.code not in helpers._RETRYABLE_ERROR_CODES:
raise
retrying = True
last_error = exc
return self._retry_internal(
func,
session,
None,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
)
def _retryable_write(
self,
retryable: bool,
func: Callable[[Optional[ClientSession], Connection, bool], T],
func: _WriteCall[T],
session: Optional[ClientSession],
bulk: Optional[_Bulk] = None,
) -> T:
"""Internal retryable write helper."""
"""Execute an operation with consecutive retries if possible
Returns func()'s return value on success. On error retries the same
command.
Re-raises any exception thrown by func().
:Parameters:
- `retryable`: if we should attempt retries (may not always be supported)
- `func`: write call we want to execute during a session
- `session`: Client session we will use to execute write operation
- `bulk`: bulk abstraction to execute operations in bulk, defaults to None
"""
with self._tmp_session(session) as s:
return self._retry_with_session(retryable, func, s, None)
return self._retry_with_session(retryable, func, s, bulk)
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
@ -2307,6 +2259,189 @@ class _MongoClientErrorHandler:
return self.handle(exc_type, exc_val)
class _ClientConnectionRetryable(Generic[T]):
    """Responsible for executing retryable connections on read or write operations"""

    def __init__(
        self,
        mongo_client: MongoClient,
        func: _WriteCall[T] | _ReadCall[T],
        bulk: Optional[_Bulk],
        is_read: bool = False,
        session: Optional[ClientSession] = None,
        read_pref: Optional[_ServerMode] = None,
        address: Optional[_Address] = None,
        retryable: bool = False,
    ):
        """Capture everything needed to run (and possibly re-run) ``func``.

        :Parameters:
          - `mongo_client`: Client used for server selection and connection checkout
          - `func`: Callback to execute; called with the write signature
            ``(session, conn, retryable)`` or the read signature
            ``(session, server, conn, read_pref)`` depending on `is_read`
          - `bulk`: Optional bulk abstraction; when supplied, its ``retrying``
            flag is used to track retry state instead of this object's own flag
          - `is_read`: True to run `func` as a read, defaults to False
          - `session`: Client session passed through to `func`, defaults to None
          - `read_pref`: Read preference, used as the server selector for reads,
            defaults to None
          - `address`: Optional address forwarded to server selection, defaults to None
          - `retryable`: Whether a failed attempt may be retried, defaults to False
        """
        self._last_error: Optional[Exception] = None
        self._retrying = False
        # More than one retry is only permitted when _csot reports a timeout.
        self._multiple_retries = _csot.get_timeout() is not None
        self._client = mongo_client

        self._func = func
        self._bulk = bulk
        self._session = session
        self._is_read = is_read
        self._retryable = retryable
        self._read_pref = read_pref
        # Reads select a server with the read preference, writes with the
        # writable server selector.
        self._server_selector: Callable[[Selection], Selection] = (
            read_pref if is_read else writable_server_selector  # type: ignore
        )
        self._address = address
        self._server: Server = None  # type: ignore

    def run(self) -> T:
        """Runs the supplied func() and attempts a retry

        :Raises:
            self._last_error: Last exception raised

        :Returns:
            Result of the func() call
        """
        # Increment the transaction id up front to ensure any retry attempt
        # will use the proper txnNumber, even if server or socket selection
        # fails before the command can be sent.
        if self._is_session_state_retryable() and self._retryable and not self._is_read:
            self._session._start_retryable_write()  # type: ignore
            if self._bulk:
                self._bulk.started_retryable_write = True

        while True:
            # When retrying, give up with the previous error once the
            # _csot deadline (if any) has been exhausted.
            self._check_last_error(check_csot=True)
            try:
                return self._read() if self._is_read else self._write()
            except ServerSelectionTimeoutError:
                # The application may think the write was never attempted
                # if we raise ServerSelectionTimeoutError on the retry
                # attempt. Raise the original exception instead.
                self._check_last_error()
                # A ServerSelectionTimeoutError error indicates that there may
                # be a persistent outage. Attempting to retry in this case will
                # most likely be a waste of time.
                raise
            except PyMongoError as exc:
                # Execute specialized catch on read
                if self._is_read:
                    if isinstance(exc, (ConnectionFailure, OperationFailure)):
                        # ConnectionFailures are retried regardless of any code.
                        # Every other error reaching this branch is an
                        # OperationFailure and is only retried when its code is
                        # one of the retryable codes; a missing (None) or zero
                        # code is NOT retryable and must be re-raised.
                        is_connection_failure = isinstance(exc, ConnectionFailure)
                        exc_code = getattr(exc, "code", None)
                        if self._is_not_eligible_for_retry() or (
                            not is_connection_failure
                            and exc_code not in helpers._RETRYABLE_ERROR_CODES
                        ):
                            raise
                        self._retrying = True
                        self._last_error = exc
                    else:
                        raise

                # Specialized catch on write operation
                if not self._is_read:
                    if not self._retryable:
                        raise
                    retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
                    if retryable_write_error_exc:
                        assert self._session
                        self._session._unpin()
                    if not retryable_write_error_exc or self._is_not_eligible_for_retry():
                        # Prefer surfacing the earlier error when this attempt
                        # is labelled as having performed no writes.
                        if exc.has_error_label("NoWritesPerformed") and self._last_error:
                            raise self._last_error from exc
                        else:
                            raise
                    if self._bulk:
                        self._bulk.retrying = True
                    else:
                        self._retrying = True
                    # Only remember a "NoWritesPerformed" error when there is
                    # no earlier error to report instead.
                    if not exc.has_error_label("NoWritesPerformed"):
                        self._last_error = exc
                    if self._last_error is None:
                        self._last_error = exc

    def _is_not_eligible_for_retry(self) -> bool:
        """Checks if the exchange is not eligible for retry"""
        return not self._retryable or (self._is_retrying() and not self._multiple_retries)

    def _is_retrying(self) -> bool:
        """Checks if the exchange is currently undergoing a retry"""
        return self._bulk.retrying if self._bulk else self._retrying

    def _is_session_state_retryable(self) -> bool:
        """Checks if provided session is eligible for retry

        reads: Make sure there is no ongoing transaction (if provided a session)
        writes: Make sure there is a session without an active transaction
        """
        if self._is_read:
            return not (self._session and self._session.in_transaction)
        return bool(self._session and not self._session.in_transaction)

    def _check_last_error(self, check_csot: bool = False) -> None:
        """Checks if the ongoing client exchange experienced a exception previously.
        If so, raise last error

        Does nothing unless a retry is in progress. With ``check_csot=False``
        the last error is always raised while retrying; with ``check_csot=True``
        it is raised only when ``_csot.remaining()`` is not None and is <= 0.

        :Parameters:
          - `check_csot`: Checks CSOT to ensure we are retrying with time remaining defaults to False
        """
        if self._is_retrying():
            remaining = _csot.remaining()
            if not check_csot or (remaining is not None and remaining <= 0):
                assert self._last_error is not None
                raise self._last_error

    def _get_server(self) -> Server:
        """Retrieves a server object based on provided object context

        :Returns:
            Abstraction to connect to server
        """
        return self._client._select_server(
            self._server_selector, self._session, address=self._address
        )

    def _write(self) -> T:
        """Wrapper method for write-type retryable client executions

        Any PyMongoError raised while the operation is still retryable is passed
        to ``_add_retryable_write_error`` before being re-raised, so that
        ``run()`` can inspect its error labels.

        :Returns:
            Output for func()'s call
        """
        try:
            # Stays 0 if the failure happens before a connection is checked out.
            max_wire_version = 0
            self._server = self._get_server()
            supports_session = (
                self._session is not None and self._server.description.retryable_writes_supported
            )
            with self._client._checkout(self._server, self._session) as conn:
                max_wire_version = conn.max_wire_version
                if self._retryable and not supports_session:
                    # A retry is not possible because this server does
                    # not support sessions raise the last error.
                    self._check_last_error()
                    self._retryable = False
                return self._func(self._session, conn, self._retryable)  # type: ignore
        except PyMongoError as exc:
            if not self._retryable:
                raise
            # Add the RetryableWriteError label, if applicable.
            _add_retryable_write_error(exc, max_wire_version)
            raise

    def _read(self) -> T:
        """Wrapper method for read-type retryable client executions

        :Returns:
            Output for func()'s call
        """
        self._server = self._get_server()
        assert self._read_pref is not None, "Read Preference required on read calls"
        with self._client._conn_from_server(self._read_pref, self._server, self._session) as (
            conn,
            read_pref,
        ):
            # A retry that is no longer permitted surfaces the previous error.
            if self._retrying and not self._retryable:
                self._check_last_error()
            return self._func(self._session, self._server, conn, read_pref)  # type: ignore
def _after_fork_child() -> None:
"""Releases the locks in child process and resets the
topologies in all MongoClients.