diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py index c7369db90..361db4ca9 100644 --- a/test/asynchronous/test_retryable_reads.py +++ b/test/asynchronous/test_retryable_reads.py @@ -19,7 +19,7 @@ import os import pprint import sys import threading -from test.asynchronous.utils import async_set_fail_point +from test.asynchronous.utils import async_ensure_all_connected, async_set_fail_point from unittest import mock from pymongo import MongoClient @@ -295,6 +295,9 @@ class TestRetryableReads(AsyncIntegrationTest): enableOverloadRetargeting=True, ) + # Ensure the client has discovered all nodes. + await async_ensure_all_connected(client) + # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. command_args = { "configureFailPoint": "failCommand", @@ -334,6 +337,9 @@ class TestRetryableReads(AsyncIntegrationTest): event_listeners=[listener], retryReads=True, readPreference="primaryPreferred" ) + # Ensure the client has discovered all nodes. + await async_ensure_all_connected(client) + # 2. Configure a fail point with the RetryableError error label. command_args = { "configureFailPoint": "failCommand", @@ -375,6 +381,9 @@ class TestRetryableReads(AsyncIntegrationTest): readPreference="primaryPreferred", ) + # Ensure the client has discovered all nodes. + await async_ensure_all_connected(client) + # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. command_args = { "configureFailPoint": "failCommand", diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 751319479..2da87f8b2 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -19,7 +19,7 @@ import os import pprint import sys import threading -from test.utils import set_fail_point +from test.utils import ensure_all_connected, set_fail_point from unittest import mock from pymongo import MongoClient @@ -293,6 +293,9 @@ class TestRetryableReads(IntegrationTest): enableOverloadRetargeting=True, ) + # Ensure the client has discovered all nodes. + ensure_all_connected(client) + # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. command_args = { "configureFailPoint": "failCommand", @@ -332,6 +335,9 @@ class TestRetryableReads(IntegrationTest): event_listeners=[listener], retryReads=True, readPreference="primaryPreferred" ) + # Ensure the client has discovered all nodes. + ensure_all_connected(client) + # 2. Configure a fail point with the RetryableError error label. command_args = { "configureFailPoint": "failCommand", @@ -373,6 +379,9 @@ class TestRetryableReads(IntegrationTest): readPreference="primaryPreferred", ) + # Ensure the client has discovered all nodes. + ensure_all_connected(client) + # 2. Configure a fail point with the RetryableError and SystemOverloadedError error labels. command_args = { "configureFailPoint": "failCommand",