PYTHON-5788 - Refine withTransaction timeout error wrapping semantics… (#2745)

This commit is contained in:
Noah Stapp 2026-04-13 11:07:37 -04:00 committed by GitHub
parent 02320d68e7
commit d864822d72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 14 deletions

View File

@ -516,9 +516,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
return NetworkTimeout(str(error))
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T")
@ -804,15 +809,17 @@ class AsyncClientSession:
await self.commit_transaction()
except PyMongoError as exc:
last_error = exc
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue
if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise

View File

@ -514,9 +514,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
return NetworkTimeout(str(error))
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error
_T = TypeVar("_T")
@ -800,15 +805,17 @@ class ClientSession:
self.commit_transaction()
except PyMongoError as exc:
last_error = exc
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue
if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise

View File

@ -500,10 +500,12 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands
@async_client_context.require_transactions
@ -534,10 +536,12 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@async_client_context.require_test_commands
@async_client_context.require_transactions
@ -565,7 +569,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic
@ -573,6 +577,8 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@async_client_context.require_transactions
async def test_callback_not_retried_after_csot_timeout(self):

View File

@ -492,10 +492,12 @@ class TestTransactionsConvenientAPI(TransactionsBase):
listener.reset()
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands
@client_context.require_transactions
@ -524,10 +526,12 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
@client_context.require_test_commands
@client_context.require_transactions
@ -553,7 +557,7 @@ class TestTransactionsConvenientAPI(TransactionsBase):
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)
# One insert for the callback and two commits (includes the automatic
@ -561,6 +565,8 @@ class TestTransactionsConvenientAPI(TransactionsBase):
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))
@client_context.require_transactions
def test_callback_not_retried_after_csot_timeout(self):