From bc26c0db697d6f856d19b18eab95c4e3c6c26105 Mon Sep 17 00:00:00 2001 From: Prashant Mital <5883388+prashantmital@users.noreply.github.com> Date: Thu, 30 Aug 2018 17:33:03 -0500 Subject: [PATCH] PYTHON-1577 Allow applications to register a custom server selector (#371) PYTHON-1577 Allow applications to register a custom server selector --- doc/examples/index.rst | 1 + doc/examples/server_selection.rst | 108 +++++++++++++++++ pymongo/client_options.py | 7 ++ pymongo/common.py | 12 +- pymongo/mongo_client.py | 18 ++- pymongo/settings.py | 8 +- pymongo/topology.py | 5 +- pymongo/topology_description.py | 19 +-- test/test_server_selection.py | 189 +++++++++++++++++++++++++++++- test/utils_selection_tests.py | 22 +++- 10 files changed, 367 insertions(+), 22 deletions(-) create mode 100644 doc/examples/server_selection.rst diff --git a/doc/examples/index.rst b/doc/examples/index.rst index c33024caf..ab3d3086a 100644 --- a/doc/examples/index.rst +++ b/doc/examples/index.rst @@ -27,5 +27,6 @@ MongoDB, you can start it like so: gridfs high_availability mod_wsgi + server_selection tailable tls diff --git a/doc/examples/server_selection.rst b/doc/examples/server_selection.rst new file mode 100644 index 000000000..28659c133 --- /dev/null +++ b/doc/examples/server_selection.rst @@ -0,0 +1,108 @@ +Server Selector Example +======================= + +Users can exert fine-grained control over the `server selection algorithm`_ +by setting the `server_selector` option on the :class:`~pymongo.MongoClient` +to an appropriate callable. This example shows how to use this functionality +to prefer servers running on ``localhost``. + + +.. warning:: + + Use of custom server selector functions is a power user feature. Misusing + custom server selectors can have unintended consequences such as degraded + read/write performance. + + +.. testsetup:: + + from pymongo import MongoClient + + +.. _server selection algorithm: https://docs.mongodb.com/manual/core/read-preference-mechanics/ + + +Example: Selecting Servers Running on ``localhost`` +--------------------------------------------------- + +To start, we need to write the server selector function that will be used. +The server selector function should accept a list of +:class:`~pymongo.server_description.ServerDescription` objects and return a +list of server descriptions that are suitable for the read or write operation. +A server selector must not create or modify +:class:`~pymongo.server_description.ServerDescription` objects, and must return +the selected instances unchanged. + +In this example, we write a server selector that prioritizes servers running on +``localhost``. This can be desirable when using a sharded cluster with multiple +``mongos``, as locally run queries are likely to see lower latency and higher +throughput. Please note, however, that it is highly dependent on the +application if preferring ``localhost`` is beneficial or not. + +In addition to comparing the hostname with ``localhost``, our server selector +function accounts for the edge case when no servers are running on +``localhost``. In this case, we allow the default server selection logic to +prevail by passing through the received server description list unchanged. +Failure to do this would render the client unable to communicate with MongoDB +in the event that no servers were running on ``localhost``. + + +The described server selection logic is implemented in the following server +selector function: + + +.. doctest:: + + >>> def server_selector(server_descriptions): + ... servers = [ + ... server for server in server_descriptions + ... if server.address[0] == 'localhost' + ... ] + ... if not servers: + ... return server_descriptions + ... return servers + + + +Finally, we can create a :class:`~pymongo.MongoClient` instance with this +server selector. + + +.. doctest:: + + >>> client = MongoClient(server_selector=server_selector) + + + +Server Selection Process +------------------------ + +This section dives deeper into the server selection process for reads and +writes. In the case of a write, the driver performs the following operations +(in order) during the selection process: + + +#. Select all writeable servers from the list of known hosts. For a replica set + this is the primary, while for a sharded cluster this is all the known mongoses. + +#. Apply the user-defined server selector function. Note that the custom server + selector is **not** called if there are no servers left from the previous + filtering stage. + +#. Apply the ``localThresholdMS`` setting to the list of remaining hosts. This + whittles the host list down to only contain servers whose latency is at most + ``localThresholdMS`` milliseconds higher than the lowest observed latency. + +#. Select a server at random from the remaining host list. The desired + operation is then performed against the selected server. + + +In the case of **reads** the process is identical except for the first step. +Here, instead of selecting all writeable servers, we select all servers +matching the user's :class:`~pymongo.read_preferences.ReadPreference` from the +list of known hosts. As an example, for a 3-member replica set with a +:class:`~pymongo.read_preferences.Secondary` read preference, we would select +all available secondaries. + + +.. _server selection algorithm: https://docs.mongodb.com/manual/core/read-preference-mechanics/ \ No newline at end of file diff --git a/pymongo/client_options.py b/pymongo/client_options.py index b91e1d88c..3c3865b32 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -25,6 +25,7 @@ from pymongo.pool import PoolOptions from pymongo.read_concern import ReadConcern from pymongo.read_preferences import (make_read_preference, read_pref_mode_from_name) +from pymongo.server_selectors import any_server_selector from pymongo.ssl_support import get_ssl_context from pymongo.write_concern import WriteConcern @@ -163,6 +164,8 @@ class ClientOptions(object): self.__heartbeat_frequency = options.get( 'heartbeatfrequencyms', common.HEARTBEAT_FREQUENCY) self.__retry_writes = options.get('retrywrites', common.RETRY_WRITES) + self.__server_selector = options.get( + 'server_selector', any_server_selector) @property def _options(self): @@ -194,6 +197,10 @@ class ClientOptions(object): """The server selection timeout for this instance in seconds.""" return self.__server_selection_timeout + @property + def server_selector(self): + return self.__server_selector + @property def heartbeat_frequency(self): """The monitoring frequency in seconds.""" diff --git a/pymongo/common.py b/pymongo/common.py index 258bd05f5..91a8c2d92 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -473,6 +473,15 @@ def validate_driver_or_none(option, value): return value +def validate_is_callable_or_none(option, value): + """Validates that 'value' is a callable.""" + if value is None: + return value + if not callable(value): + raise ValueError("%s must be a callable" % (option,)) + return value + + def validate_ok_for_replace(replacement): """Validate a replacement document.""" validate_is_mapping("replacement", replacement) @@ -552,7 +561,7 @@ URI_VALIDATORS = { 'unicode_decode_error_handler': validate_unicode_decode_error_handler, 'retrywrites': validate_boolean_or_string, 'compressors': validate_compressors, - 'zlibcompressionlevel': validate_zlib_compression_level + 'zlibcompressionlevel': validate_zlib_compression_level, } TIMEOUT_VALIDATORS = { @@ -572,6 +581,7 @@ KW_VALIDATORS = { 'tzinfo': validate_tzinfo, 'username': validate_string_or_none, 'password': validate_string_or_none, + 'server_selector': validate_is_callable_or_none, } URI_VALIDATORS.update(TIMEOUT_VALIDATORS) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index d1941a60a..9409ed4e4 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -215,6 +215,12 @@ class MongoClient(common.BaseObject): milliseconds) the driver will wait during server monitoring when connecting a new socket to a server before concluding the server is unavailable. Defaults to ``20000`` (20 seconds). + - `server_selector`: (callable or None) Optional, user-provided + function that augments server selection rules. The function should + accept as an argument a list of + :class:`~pymongo.server_description.ServerDescription` objects and + return a list of server descriptions that should be considered + suitable for the desired operation. - `serverSelectionTimeoutMS`: (integer) Controls how long (in milliseconds) the driver will wait to find an available, appropriate server to carry out a database operation; while it is @@ -331,6 +337,8 @@ class MongoClient(common.BaseObject): is set, it must be a positive integer greater than or equal to 90 seconds. + .. seealso:: :doc:`/examples/server_selection` + | **Authentication:** - `username`: A string. @@ -411,13 +419,16 @@ class MongoClient(common.BaseObject): .. mongodoc:: connections - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - Added the ``retryWrites`` keyword argument and URI option. + .. versionchanged:: 3.8 + Added the ``server_selector`` keyword argument. .. versionchanged:: 3.7 Added the ``driver`` keyword argument. + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + Added the ``retryWrites`` keyword argument and URI option. + .. versionchanged:: 3.5 Add ``username`` and ``password`` options. Document the ``authSource``, ``authMechanism``, and ``authMechanismProperties `` @@ -572,6 +583,7 @@ class MongoClient(common.BaseObject): condition_class=condition_class, local_threshold_ms=options.local_threshold_ms, server_selection_timeout=options.server_selection_timeout, + server_selector=options.server_selector, heartbeat_frequency=options.heartbeat_frequency) self._topology = Topology(self._topology_settings) diff --git a/pymongo/settings.py b/pymongo/settings.py index 88bac1a65..a03bbc98a 100644 --- a/pymongo/settings.py +++ b/pymongo/settings.py @@ -35,7 +35,8 @@ class TopologySettings(object): condition_class=None, local_threshold_ms=LOCAL_THRESHOLD_MS, server_selection_timeout=SERVER_SELECTION_TIMEOUT, - heartbeat_frequency=common.HEARTBEAT_FREQUENCY): + heartbeat_frequency=common.HEARTBEAT_FREQUENCY, + server_selector=None): """Represent MongoClient's configuration. Take a list of (host, port) pairs and optional replica set name. @@ -53,6 +54,7 @@ class TopologySettings(object): self._condition_class = condition_class or threading.Condition self._local_threshold_ms = local_threshold_ms self._server_selection_timeout = server_selection_timeout + self._server_selector = server_selector self._heartbeat_frequency = heartbeat_frequency self._direct = (len(self._seeds) == 1 and not replica_set_name) self._topology_id = ObjectId() @@ -90,6 +92,10 @@ class TopologySettings(object): def server_selection_timeout(self): return self._server_selection_timeout + @property + def server_selector(self): + return self._server_selector + @property def heartbeat_frequency(self): return self._heartbeat_frequency diff --git a/pymongo/topology.py b/pymongo/topology.py index aca1d91f1..c69c127ff 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -190,7 +190,7 @@ class Topology(object): now = _time() end_time = now + timeout server_descriptions = self._description.apply_selector( - selector, address) + selector, address, custom_selector=self._settings.server_selector) while not server_descriptions: # No suitable servers. @@ -209,7 +209,8 @@ class Topology(object): self._description.check_compatible() now = _time() server_descriptions = self._description.apply_selector( - selector, address) + selector, address, + custom_selector=self._settings.server_selector) self._description.check_compatible() return server_descriptions diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index ac282bb14..eca761520 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -214,7 +214,7 @@ class TopologyDescription(object): def heartbeat_frequency(self): return self._topology_settings.heartbeat_frequency - def apply_selector(self, selector, address): + def apply_selector(self, selector, address, custom_selector=None): def apply_local_threshold(selection): if not selection: @@ -239,18 +239,23 @@ class TopologyDescription(object): common_wv)) if self.topology_type == TOPOLOGY_TYPE.Single: - # Ignore the selector. + # Ignore selectors for standalone. return self.known_servers elif address: + # Ignore selectors when explicit address is requested. description = self.server_descriptions().get(address) return [description] if description else [] elif self.topology_type == TOPOLOGY_TYPE.Sharded: - # Ignore the read preference, but apply localThresholdMS. - return apply_local_threshold( - Selection.from_topology_description(self)) + # Ignore read preference. + selection = Selection.from_topology_description(self) else: - return apply_local_threshold( - selector(Selection.from_topology_description(self))) + selection = selector(Selection.from_topology_description(self)) + + # Apply custom selector followed by localThresholdMS. + if custom_selector is not None and selection: + selection = selection.with_server_descriptions( + custom_selector(selection.server_descriptions)) + return apply_local_threshold(selection) def has_readable_server(self, read_preference=ReadPreference.PRIMARY): """Does this topology have any readable servers available matching the diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 57b3ed302..ff2729dc7 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -17,10 +17,20 @@ import os import sys +from pymongo import MongoClient +from pymongo import ReadPreference +from pymongo.errors import ServerSelectionTimeoutError +from pymongo.server_selectors import writable_server_selector +from pymongo.settings import TopologySettings +from pymongo.topology import Topology + sys.path[0:0] = [""] -from test import unittest -from test.utils_selection_tests import create_selection_tests +from test import client_context, unittest, IntegrationTest +from test.utils import rs_or_single_client, wait_until, EventListener +from test.utils_selection_tests import ( + create_selection_tests, get_addresses, get_topology_settings_dict, + make_server_description) # Location of JSON test specifications. @@ -29,8 +39,183 @@ _TEST_PATH = os.path.join( os.path.join('server_selection', 'server_selection')) +class CallCountSelector(object): + """No-op selector that keeps track of how many times it is called.""" + def __init__(self): + self.call_count = 0 + + def __call__(self, servers): + self.call_count += 1 + return servers + + +class SelectionStoreSelector(object): + """No-op selector that keeps track of what was passed to it.""" + def __init__(self): + self.selection = None + + def __call__(self, selection): + self.selection = selection + return selection + + + class TestAllScenarios(create_selection_tests(_TEST_PATH)): pass + +class TestCustomServerSelectorFunction(IntegrationTest): + @client_context.require_replica_set + def test_functional_select_max_port_number_host(self): + # Selector that returns server with highest port number. + def custom_selector(servers): + ports = [s.address[1] for s in servers] + idx = ports.index(max(ports)) + return [servers[idx]] + + # Initialize client with appropriate listeners. + listener = EventListener() + client = rs_or_single_client( + server_selector=custom_selector, event_listeners=[listener]) + self.addCleanup(client.close) + coll = client.get_database( + 'testdb', read_preference=ReadPreference.NEAREST).coll + self.addCleanup(client.drop_database, 'testdb') + + # Wait the node list to be fully populated. + def all_hosts_started(): + return (len(client.admin.command('isMaster')['hosts']) == + len(client._topology._description.readable_servers)) + + wait_until(all_hosts_started, 'receive heartbeat from all hosts') + expected_port = max([ + n.address[1] + for n in client._topology._description.readable_servers]) + + # Insert 1 record and access it 10 times. + coll.insert_one({'name': 'John Doe'}) + for _ in range(10): + coll.find_one({'name': 'John Doe'}) + + # Confirm all find commands are run against appropriate host. + for command in listener.results['started']: + if command.command_name == 'find': + self.assertEqual( + command.connection_id[1], expected_port) + + def test_invalid_server_selector(self): + # Client initialization must fail if server_selector is not callable. + for selector_candidate in [list(), 10, 'string', {}]: + with self.assertRaisesRegex(ValueError, "must be a callable"): + MongoClient(connect=False, server_selector=selector_candidate) + + # None value for server_selector is OK. + MongoClient(connect=False, server_selector=None) + + @client_context.require_replica_set + def test_selector_called(self): + selector = CallCountSelector() + + # Client setup. + mongo_client = rs_or_single_client(server_selector=selector) + test_collection = mongo_client.testdb.test_collection + self.addCleanup(mongo_client.drop_database, 'testdb') + self.addCleanup(mongo_client.close) + + # Do N operations and test selector is called at least N times. + test_collection.insert_one({'age': 20, 'name': 'John'}) + test_collection.insert_one({'age': 31, 'name': 'Jane'}) + test_collection.update_one({'name': 'Jane'}, {'$set': {'age': 21}}) + test_collection.find_one({'name': 'Roe'}) + self.assertGreaterEqual(selector.call_count, 4) + + @client_context.require_replica_set + def test_latency_threshold_application(self): + selector = SelectionStoreSelector() + + scenario_def = { + 'topology_description': { + 'type': 'ReplicaSetWithPrimary', 'servers': [ + {'address': 'b:27017', + 'avg_rtt_ms': 10000, + 'type': 'RSSecondary', + 'tag': {}}, + {'address': 'c:27017', + 'avg_rtt_ms': 20000, + 'type': 'RSSecondary', + 'tag': {}}, + {'address': 'a:27017', + 'avg_rtt_ms': 30000, + 'type': 'RSPrimary', + 'tag': {}}, + ]}} + + # Create & populate Topology such that all but one server is too slow. + rtt_times = [srv['avg_rtt_ms'] for srv in + scenario_def['topology_description']['servers']] + min_rtt_idx = rtt_times.index(min(rtt_times)) + seeds, hosts = get_addresses( + scenario_def["topology_description"]["servers"]) + settings = get_topology_settings_dict( + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, + server_selector=selector) + topology = Topology(TopologySettings(**settings)) + topology.open() + for server in scenario_def['topology_description']['servers']: + server_description = make_server_description(server, hosts) + topology.on_change(server_description) + + # Invoke server selection and assert no filtering based on latency + # prior to custom server selection logic kicking in. + server = topology.select_server(ReadPreference.NEAREST) + self.assertEqual( + len(selector.selection), + len(topology.description.server_descriptions())) + + # Ensure proper filtering based on latency after custom selection. + self.assertEqual( + server.description.address, seeds[min_rtt_idx]) + + @client_context.require_replica_set + def test_server_selector_bypassed(self): + selector = CallCountSelector() + + scenario_def = { + 'topology_description': { + 'type': 'ReplicaSetNoPrimary', 'servers': [ + {'address': 'b:27017', + 'avg_rtt_ms': 10000, + 'type': 'RSSecondary', + 'tag': {}}, + {'address': 'c:27017', + 'avg_rtt_ms': 20000, + 'type': 'RSSecondary', + 'tag': {}}, + {'address': 'a:27017', + 'avg_rtt_ms': 30000, + 'type': 'RSSecondary', + 'tag': {}}, + ]}} + + # Create & populate Topology such that no server is writeable. + seeds, hosts = get_addresses( + scenario_def["topology_description"]["servers"]) + settings = get_topology_settings_dict( + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, + server_selector=selector) + topology = Topology(TopologySettings(**settings)) + topology.open() + for server in scenario_def['topology_description']['servers']: + server_description = make_server_description(server, hosts) + topology.on_change(server_description) + + # Invoke server selection and assert no calls to our custom selector. + with self.assertRaisesRegex( + ServerSelectionTimeoutError, 'No primary available for writes'): + topology.select_server( + writable_server_selector, server_selection_timeout=0.1) + self.assertEqual(selector.call_count, 0) + + if __name__ == "__main__": unittest.main() diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 66d31b823..d9c1a5a03 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -139,6 +139,16 @@ def get_topology_type_name(scenario_def): return name +def get_topology_settings_dict(**kwargs): + settings = dict( + monitor_class=MockMonitor, + heartbeat_frequency=HEARTBEAT_FREQUENCY, + pool_class=MockPool + ) + settings.update(kwargs) + return settings + + def create_test(scenario_def): def run_scenario(self): # Initialize topologies. @@ -147,14 +157,14 @@ def create_test(scenario_def): else: frequency = HEARTBEAT_FREQUENCY - settings = dict( - monitor_class=MockMonitor, - heartbeat_frequency=frequency, - pool_class=MockPool) - - settings['seeds'], hosts = get_addresses( + seeds, hosts = get_addresses( scenario_def['topology_description']['servers']) + settings = get_topology_settings_dict( + heartbeat_frequency=frequency, + seeds=seeds + ) + # "Eligible servers" is defined in the server selection spec as # the set of servers matching both the ReadPreference's mode # and tag sets.