diff --git a/test/test_transactions.py b/test/test_transactions.py index 1b15536eb..892889f8c 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -498,28 +498,18 @@ class TestTransactions(TransactionsBase): self.assertGreater(len(addresses), 1) -class PatchMockTimer(object): - """Patches the client_session module's monotonic time to test timeouts.""" - def __init__(self): - self.monotonic_time = None - self.counter = 0 - - def time(self): - time = self.monotonic_time() - if self.counter > 0: - time += client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT - self.counter += 1 - return time +class PatchSessionTimeout(object): + """Patches the client_session's with_transaction timeout for testing.""" + def __init__(self, mock_timeout): + self.real_timeout = client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT + self.mock_timeout = mock_timeout def __enter__(self): - self.monotonic_time = client_session.monotonic.time - self.counter = 0 - client_session.monotonic.time = self.time + client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.mock_timeout return self def __exit__(self, exc_type, exc_val, exc_tb): - if self.monotonic_time: - client_session.monotonic.time = self.monotonic_time + client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout class TestTransactionsConvenientAPI(TransactionsBase): @@ -572,7 +562,7 @@ class TestTransactionsConvenientAPI(TransactionsBase): coll.insert_one({}) listener.results.clear() with client.start_session() as s: - with PatchMockTimer(): + with PatchSessionTimeout(0): with self.assertRaises(OperationFailure): s.with_transaction(callback) @@ -601,7 +591,7 @@ class TestTransactionsConvenientAPI(TransactionsBase): listener.results.clear() with client.start_session() as s: - with PatchMockTimer(): + with PatchSessionTimeout(0): with self.assertRaises(OperationFailure): s.with_transaction(callback) @@ -629,7 +619,7 @@ class TestTransactionsConvenientAPI(TransactionsBase): listener.results.clear() with client.start_session() as s: - with PatchMockTimer(): + with PatchSessionTimeout(0): with self.assertRaises(ConnectionFailure): s.with_transaction(callback)