PYTHON-2834 Direct read/write retries to another mongos if possible (#1421)
This commit is contained in:
parent
b0cd7d2361
commit
5dc60342ed
@ -1277,6 +1277,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
server_selector: Callable[[Selection], Selection],
|
||||
session: Optional[ClientSession],
|
||||
address: Optional[_Address] = None,
|
||||
deprioritized_servers: Optional[list[Server]] = None,
|
||||
) -> Server:
|
||||
"""Select a server to run an operation on this client.
|
||||
|
||||
@ -1300,7 +1301,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
if not server:
|
||||
raise AutoReconnect("server %s:%s no longer available" % address) # noqa: UP031
|
||||
else:
|
||||
server = topology.select_server(server_selector)
|
||||
server = topology.select_server(
|
||||
server_selector, deprioritized_servers=deprioritized_servers
|
||||
)
|
||||
return server
|
||||
except PyMongoError as exc:
|
||||
# Server selection errors in a transaction are transient.
|
||||
@ -2291,6 +2294,7 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
)
|
||||
self._address = address
|
||||
self._server: Server = None # type: ignore
|
||||
self._deprioritized_servers: list[Server] = []
|
||||
|
||||
def run(self) -> T:
|
||||
"""Runs the supplied func() and attempts a retry
|
||||
@ -2359,6 +2363,9 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
if self._last_error is None:
|
||||
self._last_error = exc
|
||||
|
||||
if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
|
||||
self._deprioritized_servers.append(self._server)
|
||||
|
||||
def _is_not_eligible_for_retry(self) -> bool:
|
||||
"""Checks if the exchange is not eligible for retry"""
|
||||
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
|
||||
@ -2397,7 +2404,10 @@ class _ClientConnectionRetryable(Generic[T]):
|
||||
Abstraction to connect to server
|
||||
"""
|
||||
return self._client._select_server(
|
||||
self._server_selector, self._session, address=self._address
|
||||
self._server_selector,
|
||||
self._session,
|
||||
address=self._address,
|
||||
deprioritized_servers=self._deprioritized_servers,
|
||||
)
|
||||
|
||||
def _write(self) -> T:
|
||||
|
||||
@ -282,8 +282,10 @@ class Topology:
|
||||
selector: Callable[[Selection], Selection],
|
||||
server_selection_timeout: Optional[float] = None,
|
||||
address: Optional[_Address] = None,
|
||||
deprioritized_servers: Optional[list[Server]] = None,
|
||||
) -> Server:
|
||||
servers = self.select_servers(selector, server_selection_timeout, address)
|
||||
servers = _filter_servers(servers, deprioritized_servers)
|
||||
if len(servers) == 1:
|
||||
return servers[0]
|
||||
server1, server2 = random.sample(servers, 2)
|
||||
@ -297,9 +299,12 @@ class Topology:
|
||||
selector: Callable[[Selection], Selection],
|
||||
server_selection_timeout: Optional[float] = None,
|
||||
address: Optional[_Address] = None,
|
||||
deprioritized_servers: Optional[list[Server]] = None,
|
||||
) -> Server:
|
||||
"""Like select_servers, but choose a random server if several match."""
|
||||
server = self._select_server(selector, server_selection_timeout, address)
|
||||
server = self._select_server(
|
||||
selector, server_selection_timeout, address, deprioritized_servers
|
||||
)
|
||||
if _csot.get_timeout():
|
||||
_csot.set_rtt(server.description.min_round_trip_time)
|
||||
return server
|
||||
@ -931,3 +936,16 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
|
||||
if current_tv["processId"] != new_tv["processId"]:
|
||||
return False
|
||||
return current_tv["counter"] > new_tv["counter"]
|
||||
|
||||
|
||||
def _filter_servers(
|
||||
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
|
||||
) -> list[Server]:
|
||||
"""Filter out deprioritized servers from a list of server candidates."""
|
||||
if not deprioritized_servers:
|
||||
return candidates
|
||||
|
||||
filtered = [server for server in candidates if server not in deprioritized_servers]
|
||||
|
||||
# If not possible to pick a prioritized server, return the original list
|
||||
return filtered or candidates
|
||||
|
||||
@ -20,6 +20,9 @@ import pprint
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from bson import SON
|
||||
from pymongo.errors import AutoReconnect
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import (
|
||||
@ -31,9 +34,12 @@ from test import (
|
||||
)
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
rs_client,
|
||||
rs_or_single_client,
|
||||
set_fail_point,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunner
|
||||
|
||||
@ -221,5 +227,48 @@ class TestPoolPausedError(IntegrationTest):
|
||||
self.assertEqual(1, len(failed), msg)
|
||||
|
||||
|
||||
class TestRetryableReads(IntegrationTest):
    """Retryable-read tests that require a sharded cluster with several mongoses."""

    @client_context.require_multiple_mongoses
    @client_context.require_failCommand_fail_point
    def test_retryable_reads_in_sharded_cluster_multiple_available(self):
        """A failing ``find`` is attempted twice and both attempts fail.

        A ``failCommand`` fail point that closes the connection on the first
        ``find`` from appName ``retryableReadTest`` is enabled on every
        mongos seed. A retrying client is then expected to raise
        ``AutoReconnect`` after exactly two failed command events and no
        successful ones.
        """
        fail_command = {
            "configureFailPoint": "failCommand",
            "mode": {"times": 1},
            "data": {
                "failCommands": ["find"],
                "closeConnection": True,
                "appName": "retryableReadTest",
            },
        }
        # Same command with the fail point switched off. Built once so the
        # enabling document is never mutated.
        disable_command = dict(fail_command, mode="off")

        for mongos in client_context.mongos_seeds().split(","):
            mongos_client = rs_or_single_client(mongos)
            self.addCleanup(mongos_client.close)
            set_fail_point(mongos_client, fail_command)
            # Cleanups run in LIFO order, so the fail point is disabled before
            # the client is closed, and this happens even when an assertion
            # below fails.
            self.addCleanup(set_fail_point, mongos_client, disable_command)

        listener = OvertCommandListener()
        client = rs_or_single_client(
            client_context.mongos_seeds(),
            appName="retryableReadTest",
            event_listeners=[listener],
            retryReads=True,
        )
        # Close the client under test too; it was previously left open.
        self.addCleanup(client.close)

        with self.fail_point(fail_command):
            with self.assertRaises(AutoReconnect):
                client.t.t.find_one({})

        # Each fail point fires only once ("times": 1), so two failed events
        # presumably mean the retry was sent to a different mongos.
        self.assertEqual(len(listener.failed_events), 2)
        self.assertEqual(len(listener.succeeded_events), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -31,6 +31,7 @@ from test.utils import (
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
rs_or_single_client,
|
||||
set_fail_point,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunner
|
||||
from test.version import Version
|
||||
@ -40,6 +41,7 @@ from bson.int64 import Int64
|
||||
from bson.raw_bson import RawBSONDocument
|
||||
from bson.son import SON
|
||||
from pymongo.errors import (
|
||||
AutoReconnect,
|
||||
ConnectionFailure,
|
||||
OperationFailure,
|
||||
ServerSelectionTimeoutError,
|
||||
@ -469,6 +471,46 @@ class TestRetryableWrites(IgnoreDeprecationsTest):
|
||||
self.assertEqual(final_txn, expected_txn)
|
||||
self.assertEqual(coll.find_one(projection={"_id": True}), {"_id": 1})
|
||||
|
||||
@client_context.require_multiple_mongoses
@client_context.require_failCommand_fail_point
def test_retryable_writes_in_sharded_cluster_multiple_available(self):
    """A failing ``insert`` is attempted twice and both attempts fail.

    A ``failCommand`` fail point that closes the connection on the first
    ``insert`` from appName ``retryableWriteTest`` is enabled on every
    mongos seed. A retrying client is then expected to raise
    ``AutoReconnect`` after exactly two failed command events and no
    successful ones.
    """
    fail_command = {
        "configureFailPoint": "failCommand",
        "mode": {"times": 1},
        "data": {
            "failCommands": ["insert"],
            "closeConnection": True,
            "appName": "retryableWriteTest",
        },
    }
    # Same command with the fail point switched off. Built once so the
    # enabling document is never mutated.
    disable_command = dict(fail_command, mode="off")

    for mongos in client_context.mongos_seeds().split(","):
        mongos_client = rs_or_single_client(mongos)
        self.addCleanup(mongos_client.close)
        set_fail_point(mongos_client, fail_command)
        # Cleanups run in LIFO order, so the fail point is disabled before
        # the client is closed, and this happens even when an assertion
        # below fails.
        self.addCleanup(set_fail_point, mongos_client, disable_command)

    listener = OvertCommandListener()
    client = rs_or_single_client(
        client_context.mongos_seeds(),
        appName="retryableWriteTest",
        event_listeners=[listener],
        retryWrites=True,
    )
    # Close the client under test too; it was previously left open.
    self.addCleanup(client.close)

    with self.assertRaises(AutoReconnect):
        client.t.t.insert_one({"x": 1})

    # Each fail point fires only once ("times": 1), so two failed events
    # presumably mean the retry was sent to a different mongos.
    self.assertEqual(len(listener.failed_events), 2)
    self.assertEqual(len(listener.succeeded_events), 0)
|
||||
|
||||
|
||||
class TestWriteConcernError(IntegrationTest):
|
||||
RUN_ON_LOAD_BALANCER = True
|
||||
|
||||
@ -30,11 +30,12 @@ from pymongo.hello import Hello, HelloCompat
|
||||
from pymongo.monitor import Monitor
|
||||
from pymongo.pool import PoolOptions
|
||||
from pymongo.read_preferences import ReadPreference, Secondary
|
||||
from pymongo.server import Server
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import any_server_selector, writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.settings import TopologySettings
|
||||
from pymongo.topology import Topology, _ErrorContext
|
||||
from pymongo.topology import Topology, _ErrorContext, _filter_servers
|
||||
from pymongo.topology_description import TOPOLOGY_TYPE
|
||||
|
||||
|
||||
@ -685,6 +686,23 @@ class TestMultiServerTopology(TopologyTest):
|
||||
self.assertNotIn(("a", 27017), t.description.server_descriptions())
|
||||
self.assertEqual(t.description.topology_type_name, "Unknown")
|
||||
|
||||
def test_filtered_server_selection(self):
    """Exercise ``_filter_servers`` with partial, full, empty and omitted deprioritization."""
    first = Server(ServerDescription(("localhost", 27017)), pool=object(), monitor=object())  # type: ignore[arg-type]
    second = Server(ServerDescription(("localhost2", 27017)), pool=object(), monitor=object())  # type: ignore[arg-type]
    candidates = [first, second]

    # One server deprioritized: only the other one is kept.
    self.assertEqual(_filter_servers(candidates, deprioritized_servers=[second]), [first])

    # Every server deprioritized: the full candidate list is returned.
    self.assertEqual(
        _filter_servers(candidates, deprioritized_servers=[first, second]), candidates
    )

    # An empty deprioritized list leaves the candidates untouched.
    self.assertEqual(_filter_servers(candidates, deprioritized_servers=[]), candidates)

    # Omitting the argument behaves the same way.
    self.assertEqual(_filter_servers(candidates), candidates)
|
||||
|
||||
|
||||
def wait_for_primary(topology):
|
||||
"""Wait for a Topology to discover a writable server.
|
||||
|
||||
@ -1153,3 +1153,9 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac
|
||||
raise AssertionError(f"Unsupported cursorType: {cursor_type}")
|
||||
else:
|
||||
arguments[c2s] = arguments.pop(arg_name)
|
||||
|
||||
|
||||
def set_fail_point(client, command_args):
    """Run a ``configureFailPoint: failCommand`` admin command on ``client``.

    :param client: Client whose ``admin`` database receives the command.
    :param command_args: Mapping merged into the command document, e.g. the
        ``mode`` and ``data`` fields of the fail point.
    """
    fail_point_cmd = SON()
    # Insert the command name before merging the arguments so it is the
    # first key of the document sent to the server.
    fail_point_cmd["configureFailPoint"] = "failCommand"
    fail_point_cmd.update(command_args)
    client.admin.command(fail_point_cmd)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user