PYTHON-2834 Direct read/write retries to another mongos if possible (#1421)

This commit is contained in:
Noah Stapp 2023-11-14 12:49:42 -08:00 committed by GitHub
parent b0cd7d2361
commit 5dc60342ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 147 additions and 4 deletions

View File

@ -1277,6 +1277,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
server_selector: Callable[[Selection], Selection],
session: Optional[ClientSession],
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> Server:
"""Select a server to run an operation on this client.
@ -1300,7 +1301,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
if not server:
raise AutoReconnect("server %s:%s no longer available" % address) # noqa: UP031
else:
server = topology.select_server(server_selector)
server = topology.select_server(
server_selector, deprioritized_servers=deprioritized_servers
)
return server
except PyMongoError as exc:
# Server selection errors in a transaction are transient.
@ -2291,6 +2294,7 @@ class _ClientConnectionRetryable(Generic[T]):
)
self._address = address
self._server: Server = None # type: ignore
self._deprioritized_servers: list[Server] = []
def run(self) -> T:
"""Runs the supplied func() and attempts a retry
@ -2359,6 +2363,9 @@ class _ClientConnectionRetryable(Generic[T]):
if self._last_error is None:
self._last_error = exc
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
self._deprioritized_servers.append(self._server)
def _is_not_eligible_for_retry(self) -> bool:
    """Return whether this exchange must not be retried.

    The exchange is ineligible when it is not retryable at all, or when a
    retry is already underway and multiple retries are not allowed.
    """
    if not self._retryable:
        return True
    # Retryable exchange: only a retry already in progress, combined with
    # multiple retries being disallowed, rules out another attempt.
    return self._is_retrying() and not self._multiple_retries
@ -2397,7 +2404,10 @@ class _ClientConnectionRetryable(Generic[T]):
Abstraction to connect to server
"""
return self._client._select_server(
self._server_selector, self._session, address=self._address
self._server_selector,
self._session,
address=self._address,
deprioritized_servers=self._deprioritized_servers,
)
def _write(self) -> T:

View File

@ -282,8 +282,10 @@ class Topology:
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> Server:
servers = self.select_servers(selector, server_selection_timeout, address)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
@ -297,9 +299,12 @@ class Topology:
selector: Callable[[Selection], Selection],
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> Server:
"""Like select_servers, but choose a random server if several match."""
server = self._select_server(selector, server_selection_timeout, address)
server = self._select_server(
selector, server_selection_timeout, address, deprioritized_servers
)
if _csot.get_timeout():
_csot.set_rtt(server.description.min_round_trip_time)
return server
@ -931,3 +936,16 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]
def _filter_servers(
    candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
    """Remove deprioritized servers from ``candidates`` when alternatives remain.

    Returns ``candidates`` unchanged when ``deprioritized_servers`` is empty or
    ``None``, or when every candidate is deprioritized; otherwise returns a new
    list holding only the candidates that are not deprioritized.
    """
    if deprioritized_servers:
        preferred = []
        for candidate in candidates:
            if candidate not in deprioritized_servers:
                preferred.append(candidate)
        if preferred:
            return preferred
    # Nothing to exclude, or excluding would leave no server to choose from.
    return candidates

View File

@ -20,6 +20,9 @@ import pprint
import sys
import threading
from bson import SON
from pymongo.errors import AutoReconnect
sys.path[0:0] = [""]
from test import (
@ -31,9 +34,12 @@ from test import (
)
from test.utils import (
CMAPListener,
EventListener,
OvertCommandListener,
SpecTestCreator,
rs_client,
rs_or_single_client,
set_fail_point,
)
from test.utils_spec_runner import SpecRunner
@ -221,5 +227,48 @@ class TestPoolPausedError(IntegrationTest):
self.assertEqual(1, len(failed), msg)
class TestRetryableReads(IntegrationTest):
    """Retryable-read tests that need a sharded cluster with several mongoses."""

    @client_context.require_multiple_mongoses
    @client_context.require_failCommand_fail_point
    def test_retryable_reads_in_sharded_cluster_multiple_available(self):
        """A failing ``find`` is attempted exactly twice and never succeeds.

        The ``failCommand`` fail point is armed on every mongos seed so that
        one ``find`` from the ``retryableReadTest`` application fails with a
        closed connection. The read must raise ``AutoReconnect`` and the
        command listener must record two failed and zero succeeded commands.
        """
        fail_command = {
            "configureFailPoint": "failCommand",
            "mode": {"times": 1},
            "data": {
                "failCommands": ["find"],
                "closeConnection": True,
                "appName": "retryableReadTest",
            },
        }
        # Same fail point with mode "off"; built once and used to disarm each mongos.
        disable_command = {**fail_command, "mode": "off"}

        for mongos in client_context.mongos_seeds().split(","):
            mongos_client = rs_or_single_client(mongos)
            # Register the close first: cleanups run in reverse order, so the
            # fail point is disabled while the client is still open, and the
            # client is closed even if arming the fail point raises.
            self.addCleanup(mongos_client.close)
            set_fail_point(mongos_client, fail_command)
            # Disarm via cleanup so a failed assertion below cannot leave the
            # fail point active for later tests.
            self.addCleanup(set_fail_point, mongos_client, disable_command)

        listener = OvertCommandListener()
        client = rs_or_single_client(
            client_context.mongos_seeds(),
            appName="retryableReadTest",
            event_listeners=[listener],
            retryReads=True,
        )
        self.addCleanup(client.close)

        with self.fail_point(fail_command):
            with self.assertRaises(AutoReconnect):
                client.t.t.find_one({})

        self.assertEqual(len(listener.failed_events), 2)
        self.assertEqual(len(listener.succeeded_events), 0)
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@ -31,6 +31,7 @@ from test.utils import (
OvertCommandListener,
SpecTestCreator,
rs_or_single_client,
set_fail_point,
)
from test.utils_spec_runner import SpecRunner
from test.version import Version
@ -40,6 +41,7 @@ from bson.int64 import Int64
from bson.raw_bson import RawBSONDocument
from bson.son import SON
from pymongo.errors import (
AutoReconnect,
ConnectionFailure,
OperationFailure,
ServerSelectionTimeoutError,
@ -469,6 +471,46 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
self.assertEqual(final_txn, expected_txn)
self.assertEqual(coll.find_one(projection={"_id": True}), {"_id": 1})
@client_context.require_multiple_mongoses
@client_context.require_failCommand_fail_point
def test_retryable_writes_in_sharded_cluster_multiple_available(self):
    """A failing ``insert`` is attempted exactly twice and never succeeds.

    The ``failCommand`` fail point is armed on every mongos seed so that one
    ``insert`` from the ``retryableWriteTest`` application fails with a closed
    connection. The write must raise ``AutoReconnect`` and the command
    listener must record two failed and zero succeeded commands.
    """
    fail_command = {
        "configureFailPoint": "failCommand",
        "mode": {"times": 1},
        "data": {
            "failCommands": ["insert"],
            "closeConnection": True,
            "appName": "retryableWriteTest",
        },
    }
    # Same fail point with mode "off"; built once and used to disarm each mongos.
    disable_command = {**fail_command, "mode": "off"}

    for mongos in client_context.mongos_seeds().split(","):
        mongos_client = rs_or_single_client(mongos)
        # Register the close first: cleanups run in reverse order, so the
        # fail point is disabled while the client is still open, and the
        # client is closed even if arming the fail point raises.
        self.addCleanup(mongos_client.close)
        set_fail_point(mongos_client, fail_command)
        # Disarm via cleanup so a failed assertion below cannot leave the
        # fail point active for later tests.
        self.addCleanup(set_fail_point, mongos_client, disable_command)

    listener = OvertCommandListener()
    client = rs_or_single_client(
        client_context.mongos_seeds(),
        appName="retryableWriteTest",
        event_listeners=[listener],
        retryWrites=True,
    )
    self.addCleanup(client.close)

    with self.assertRaises(AutoReconnect):
        client.t.t.insert_one({"x": 1})

    self.assertEqual(len(listener.failed_events), 2)
    self.assertEqual(len(listener.succeeded_events), 0)
class TestWriteConcernError(IntegrationTest):
RUN_ON_LOAD_BALANCER = True

View File

@ -30,11 +30,12 @@ from pymongo.hello import Hello, HelloCompat
from pymongo.monitor import Monitor
from pymongo.pool import PoolOptions
from pymongo.read_preferences import ReadPreference, Secondary
from pymongo.server import Server
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import any_server_selector, writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.settings import TopologySettings
from pymongo.topology import Topology, _ErrorContext
from pymongo.topology import Topology, _ErrorContext, _filter_servers
from pymongo.topology_description import TOPOLOGY_TYPE
@ -685,6 +686,23 @@ class TestMultiServerTopology(TopologyTest):
self.assertNotIn(("a", 27017), t.description.server_descriptions())
self.assertEqual(t.description.topology_type_name, "Unknown")
def test_filtered_server_selection(self):
    """Check ``_filter_servers`` with one, all, empty and omitted deprioritized lists."""
    first = Server(ServerDescription(("localhost", 27017)), pool=object(), monitor=object())  # type: ignore[arg-type]
    second = Server(ServerDescription(("localhost2", 27017)), pool=object(), monitor=object())  # type: ignore[arg-type]
    candidates = [first, second]

    # One deprioritized server: only the other candidate is kept.
    self.assertEqual(_filter_servers(candidates, deprioritized_servers=[second]), [first])

    # Every candidate deprioritized: the full candidate list comes back.
    self.assertEqual(
        _filter_servers(candidates, deprioritized_servers=[first, second]), candidates
    )

    # Empty deprioritized list: nothing is filtered.
    self.assertEqual(_filter_servers(candidates, deprioritized_servers=[]), candidates)

    # Argument omitted: nothing is filtered.
    self.assertEqual(_filter_servers(candidates), candidates)
def wait_for_primary(topology):
"""Wait for a Topology to discover a writable server.

View File

@ -1153,3 +1153,9 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac
raise AssertionError(f"Unsupported cursorType: {cursor_type}")
else:
arguments[c2s] = arguments.pop(arg_name)
def set_fail_point(client, command_args):
    """Send a ``configureFailPoint`` admin command through ``client``.

    The command document is created with ``configureFailPoint: "failCommand"``
    as its first entry and then updated with everything in ``command_args``.
    """
    fail_point_cmd = SON([("configureFailPoint", "failCommand")])
    fail_point_cmd.update(command_args)
    client.admin.command(fail_point_cmd)