stepped through each logical call to ensure functional parity; added refactoring suggestions

This commit is contained in:
Jib 2023-08-15 17:29:49 -04:00
parent a69c8fdc56
commit 77e99c97d6
2 changed files with 185 additions and 125 deletions

View File

@ -844,7 +844,7 @@ class ClientSession:
) -> Dict[str, Any]:
return self._finish_transaction(conn, command_name)
return self._client._retry_internal(True, func, self, None)
return self._client._retry_internal(func, self, None, retryable=True)
def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]:
self._transaction.attempt += 1

View File

@ -1405,92 +1405,31 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
Re-raises any exception thrown by func().
"""
retryable = bool(
retryable and self.options.retry_writes and session and not session.in_transaction
)
return self._retry_internal(retryable, func, session, bulk)
return self._retry_internal(func=func, session=session, bulk=bulk, retryable=retryable)
@_csot.apply
def _retry_internal(
self,
retryable: bool,
func: Callable[[Optional[ClientSession], Connection, bool], T],
session: Optional[ClientSession],
bulk: Optional[_Bulk],
is_read: bool = False,
address: Optional[_Address] = None,
read_pref: Optional[_ServerMode] = None,
retryable: Optional[bool] = None,
) -> T:
"""Internal retryable write helper."""
max_wire_version = 0
last_error: Optional[Exception] = None
retrying = False
multiple_retries = _csot.get_timeout() is not None
"""Internal retryable helper for reads and writes."""
return _ClientConnectionRetryable(
mongo_client=self,
func=func,
bulk=bulk,
is_read=is_read,
session=session,
read_pref=read_pref,
address=address,
retryable=retryable,
).run()
def is_retrying() -> bool:
return bulk.retrying if bulk else retrying
# Increment the transaction id up front to ensure any retry attempt
# will use the proper txnNumber, even if server or socket selection
# fails before the command can be sent.
if retryable and session and not session.in_transaction:
session._start_retryable_write()
if bulk:
bulk.started_retryable_write = True
while True:
if is_retrying():
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
assert last_error is not None
raise last_error
try:
server = self._select_server(writable_server_selector, session)
supports_session = (
session is not None and server.description.retryable_writes_supported
)
with self._checkout(server, session) as conn:
max_wire_version = conn.max_wire_version
if retryable and not supports_session:
if is_retrying():
# A retry is not possible because this server does
# not support sessions raise the last error.
assert last_error is not None
raise last_error
retryable = False
return func(session, conn, retryable)
except ServerSelectionTimeoutError:
if is_retrying():
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
# attempt. Raise the original exception instead.
assert last_error is not None
raise last_error
# A ServerSelectionTimeoutError error indicates that there may
# be a persistent outage. Attempting to retry in this case will
# most likely be a waste of time.
raise
except PyMongoError as exc:
if not retryable:
raise
assert session
# Add the RetryableWriteError label, if applicable.
_add_retryable_write_error(exc, max_wire_version)
retryable_error = exc.has_error_label("RetryableWriteError")
if retryable_error:
session._unpin()
if not retryable_error or (is_retrying() and not multiple_retries):
if exc.has_error_label("NoWritesPerformed") and last_error:
raise last_error from exc
else:
raise
if bulk:
bulk.retrying = True
else:
retrying = True
if not exc.has_error_label("NoWritesPerformed"):
last_error = exc
if last_error is None:
last_error = exc
@_csot.apply
def _retryable_read(
self,
func: Callable[[Optional[ClientSession], Server, Connection, _ServerMode], T],
@ -1506,54 +1445,15 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
Re-raises any exception thrown by func().
"""
retryable = (
retryable and self.options.retry_reads and not (session and session.in_transaction)
return self._retry_internal(
func,
session,
None,
is_read=True,
address=address,
read_pref=read_pref,
retryable=retryable,
)
last_error: Optional[Exception] = None
retrying = False
multiple_retries = _csot.get_timeout() is not None
while True:
if retrying:
remaining = _csot.remaining()
if remaining is not None and remaining <= 0:
assert last_error is not None
raise last_error
try:
server = self._select_server(read_pref, session, address=address)
with self._conn_from_server(read_pref, server, session) as (
conn,
read_pref,
):
if retrying and not retryable:
# A retry is not possible because this server does
# not support retryable reads, raise the last error.
assert last_error is not None
raise last_error
return func(session, server, conn, read_pref)
except ServerSelectionTimeoutError:
if retrying:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
# attempt. Raise the original exception instead.
assert last_error is not None
raise last_error
# A ServerSelectionTimeoutError error indicates that there may
# be a persistent outage. Attempting to retry in this case will
# most likely be a waste of time.
raise
except ConnectionFailure as exc:
if not retryable or (retrying and not multiple_retries):
raise
retrying = True
last_error = exc
except OperationFailure as exc:
if not retryable or (retrying and not multiple_retries):
raise
if exc.code not in helpers._RETRYABLE_ERROR_CODES:
raise
retrying = True
last_error = exc
def _retryable_write(
self,
@ -2304,6 +2204,166 @@ class _MongoClientErrorHandler:
return self.handle(exc_type, exc_val)
class _ClientConnectionRetryable:
"""Responsible for executing retryable connections on read or write operations"""
def __init__(
self,
mongo_client: MongoClient,
func: Callable[[Optional[ClientSession], Connection, bool], T],
bulk: Optional[_Bulk],
is_read: bool = False,
session: Optional[ClientSession] = None,
read_pref: Optional[_ServerMode] = None,
address: Optional[_Address] = None,
retryable: Optional[bool] = None,
):
self._last_error: Optional[Exception] = None
self._retrying = False
self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client
self._retry_operation = (
mongo_client.options.retry_reads if is_read else mongo_client.options.retry_writes
)
self._func = func
self._bulk = bulk
self._session = session
self._is_read = is_read
self._retryable = retryable and self._retry_operation and not self._in_transaction()
self._read_pref = read_pref
self._server_selector = read_pref if is_read else writable_server_selector
self._address = address
self._server: Server = None
def run(self) -> T:
"""Runs the supplied func() and attempts a retry
:raises self._last_error: Exception raised
:return: Result of the func() call
:rtype: T
"""
# Increment the transaction id up front to ensure any retry attempt
# will use the proper txnNumber, even if server or socket selection
# fails before the command can be sent.
if not self._in_transaction() and self._retryable and not self._is_read:
self._session._start_retryable_write()
if self._bulk:
self._bulk.started_retryable_write = True
while True:
self._check_last_error(check_csot=True)
try:
self._server = self._client._select_server(
self._server_selector, self._session, address=self._address
)
return self._read() if self._is_read else self._write()
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
# attempt. Raise the original exception instead.
self._check_last_error()
# A ServerSelectionTimeoutError error indicates that there may
# be a persistent outage. Attempting to retry in this case will
# most likely be a waste of time.
raise
except PyMongoError as exc:
# Execute specialized catch on read
if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)):
exc_code = getattr(exc, "code", None)
if self._is_not_retry_eligible() or (
exc_code and exc_code not in helpers._RETRYABLE_ERROR_CODES
):
raise
self._retrying = True
self._last_error = exc
else:
raise
# Assumed write operation henceforth
if not self._retryable:
raise
assert self._session
if exc.has_error_label("RetryableWriteError"):
self._session._unpin()
elif self._is_retrying() and not self._multiple_retries:
if exc.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
if self._bulk:
self._bulk.retrying = True
else:
self._retrying = True
if not exc.has_error_label("NoWritesPerformed"):
self._last_error = exc
if self._last_error is None:
self._last_error = exc
def _is_not_retry_eligible(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._retrying and not self._multiple_retries)
def _is_retrying(self) -> bool:
"""Checks if the exchange is currently undergoing a retry"""
return self._bulk.retrying if self._bulk else self._retrying
def _in_transaction(self):
"""Checks if the ongoing session is in a transaction"""
return self._session and self._session.in_transaction
def _check_last_error(self, check_csot: bool = False):
"""Checks if the ongoing client exchange experienced a exception during retry
:param check_csot: Check csot timeout, defaults to False
"""
if self._is_retrying():
remaining = _csot.remaining()
if not check_csot or (remaining is not None and remaining <= 0):
assert self._last_error is not None
raise self._last_error
def _write(self) -> T:
"""Wrapper method for write-type retryable client executions
:return: output for func()'s call
"""
try:
max_wire_version = 0
supports_session = (
self._session is not None and self._server.description.retryable_writes_supported
)
with self._client._checkout(self._server, self._session) as conn:
max_wire_version = conn.max_wire_version
if self._retryable and not supports_session:
# A retry is not possible because this server does
# not support sessions raise the last error.
self._check_last_error()
self._retryable = False
return self._func(self._session, conn, self._retryable)
except PyMongoError as exc:
if not self._retryable:
raise
assert self._session
# Add the RetryableWriteError label, if applicable.
_add_retryable_write_error(exc, max_wire_version)
raise
def _read(self) -> T:
"""Wrapper method for read-type retryable client executions
:return: output for func()'s call
"""
with self._client._conn_from_server(self._read_pref, self._server, self._session) as (
conn,
read_pref,
):
if self._retrying and not self._retryable:
self._check_last_error()
return self._func(self._session, self._server, conn, read_pref)
def _after_fork_child() -> None:
"""Releases the locks in child process and resets the
topologies in all MongoClients.