PYTHON-1577 Allow applications to register a custom server selector (#371)

Prashant Mital 2018-08-30 17:33:03 -05:00 committed by GitHub
parent 58851e1221
commit bc26c0db69
10 changed files with 367 additions and 22 deletions

@ -27,5 +27,6 @@ MongoDB, you can start it like so:
gridfs
high_availability
mod_wsgi
server_selection
tailable
tls

@ -0,0 +1,108 @@
Server Selector Example
=======================
Users can exert fine-grained control over the `server selection algorithm`_
by setting the `server_selector` option on the :class:`~pymongo.MongoClient`
to an appropriate callable. This example shows how to use this functionality
to prefer servers running on ``localhost``.

.. warning::

    Use of custom server selector functions is a power user feature. Misusing
    custom server selectors can have unintended consequences such as degraded
    read/write performance.

.. testsetup::

   from pymongo import MongoClient

.. _server selection algorithm: https://docs.mongodb.com/manual/core/read-preference-mechanics/

Example: Selecting Servers Running on ``localhost``
---------------------------------------------------
To start, we need to write the server selector function that will be used.
The server selector function should accept a list of
:class:`~pymongo.server_description.ServerDescription` objects and return a
list of server descriptions that are suitable for the read or write operation.
A server selector must not create or modify
:class:`~pymongo.server_description.ServerDescription` objects, and must return
the selected instances unchanged.
In this example, we write a server selector that prioritizes servers running on
``localhost``. This can be desirable when using a sharded cluster with multiple
``mongos``, as locally run queries are likely to see lower latency and higher
throughput. Note, however, that whether preferring ``localhost`` is beneficial
depends heavily on the application.
In addition to comparing the hostname with ``localhost``, our server selector
function accounts for the edge case when no servers are running on
``localhost``. In this case, we allow the default server selection logic to
prevail by passing through the received server description list unchanged.
Failure to do this would render the client unable to communicate with MongoDB
in the event that no servers were running on ``localhost``.
The described server selection logic is implemented in the following server
selector function:

.. doctest::

   >>> def server_selector(server_descriptions):
   ...     servers = [
   ...         server for server in server_descriptions
   ...         if server.address[0] == 'localhost'
   ...     ]
   ...     if not servers:
   ...         return server_descriptions
   ...     return servers

Finally, we can create a :class:`~pymongo.MongoClient` instance with this
server selector.

.. doctest::

   >>> client = MongoClient(server_selector=server_selector)

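To see how this selector behaves without a running deployment, the sketch below applies the same logic to stand-in objects. ``FakeServerDescription`` and the host names are hypothetical and exist only for illustration; a real selector receives :class:`~pymongo.server_description.ServerDescription` instances, of which only the ``address`` attribute (a ``(host, port)`` tuple) is assumed here.

```python
from collections import namedtuple

# Hypothetical stand-in for pymongo's ServerDescription; only `address`
# (a (host, port) tuple) is modelled.
FakeServerDescription = namedtuple("FakeServerDescription", ["address"])


def server_selector(server_descriptions):
    """Prefer servers on localhost; otherwise pass the list through."""
    servers = [
        server for server in server_descriptions
        if server.address[0] == "localhost"
    ]
    if not servers:
        # No local server: defer to the default selection logic.
        return server_descriptions
    return servers


mixed = [
    FakeServerDescription(("localhost", 27017)),
    FakeServerDescription(("db1.example.com", 27017)),
]
remote_only = [FakeServerDescription(("db1.example.com", 27017))]

local_choice = server_selector(mixed)
fallback_choice = server_selector(remote_only)
print([s.address for s in local_choice])
print([s.address for s in fallback_choice])
```

In the second call the input list is returned unchanged, which is the fallback that keeps the client usable when nothing runs on ``localhost``.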
Server Selection Process
------------------------
This section dives deeper into the server selection process for reads and
writes. In the case of a write, the driver performs the following operations
(in order) during the selection process:

#. Select all writeable servers from the list of known hosts. For a replica set
   this is the primary, while for a sharded cluster this is all the known mongoses.

#. Apply the user-defined server selector function. Note that the custom server
   selector is **not** called if there are no servers left from the previous
   filtering stage.

#. Apply the ``localThresholdMS`` setting to the list of remaining hosts. This
   whittles the host list down to only contain servers whose latency is at most
   ``localThresholdMS`` milliseconds higher than the lowest observed latency.

#. Select a server at random from the remaining host list. The desired
   operation is then performed against the selected server.

In the case of **reads** the process is identical except for the first step.
Here, instead of selecting all writeable servers, we select all servers
matching the user's :class:`~pymongo.read_preferences.ReadPreference` from the
list of known hosts. As an example, for a 3-member replica set with a
:class:`~pymongo.read_preferences.Secondary` read preference, we would select
all available secondaries.
.. _server selection algorithm: https://docs.mongodb.com/manual/core/read-preference-mechanics/
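The four steps for a write can be sketched in plain Python. This is a simplified model, not the driver's implementation: ``Server``, its ``writable`` flag, ``select_for_write`` and ``prefer_localhost`` are hypothetical names, and round trip times are given in seconds.

```python
import random
from collections import namedtuple

# Hypothetical stand-in for a server description. `writable` simplifies the
# driver's server-type checks; `round_trip_time` is in seconds.
Server = namedtuple("Server", ["address", "round_trip_time", "writable"])


def prefer_localhost(servers):
    """Custom selector: keep localhost servers if there are any."""
    local = [s for s in servers if s.address[0] == "localhost"]
    return local or servers


def select_for_write(servers, custom_selector=None, local_threshold_ms=15):
    # Step 1: keep only writable servers.
    candidates = [s for s in servers if s.writable]
    # Step 2: apply the custom selector, but only if candidates remain.
    if custom_selector is not None and candidates:
        candidates = custom_selector(candidates)
    if not candidates:
        return None
    # Step 3: keep servers within localThresholdMS of the fastest candidate.
    fastest = min(s.round_trip_time for s in candidates)
    cutoff = fastest + local_threshold_ms / 1000.0
    candidates = [s for s in candidates if s.round_trip_time <= cutoff]
    # Step 4: pick one of the remaining servers at random.
    return random.choice(candidates)


mongoses = [
    Server(("localhost", 27017), 0.030, True),
    Server(("a.example.com", 27017), 0.005, True),
    Server(("b.example.com", 27017), 0.012, True),
]

default_choice = select_for_write(mongoses)
custom_choice = select_for_write(mongoses, custom_selector=prefer_localhost)
print(default_choice.address, custom_choice.address)
```

Without a custom selector the slower ``localhost`` server falls outside the latency window and is never chosen. With ``prefer_localhost`` it is the only candidate left when the window is computed, so it is always chosen. This is why the order of steps 2 and 3 matters.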

@ -25,6 +25,7 @@ from pymongo.pool import PoolOptions
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import (make_read_preference,
read_pref_mode_from_name)
from pymongo.server_selectors import any_server_selector
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern
@ -163,6 +164,8 @@ class ClientOptions(object):
self.__heartbeat_frequency = options.get(
'heartbeatfrequencyms', common.HEARTBEAT_FREQUENCY)
self.__retry_writes = options.get('retrywrites', common.RETRY_WRITES)
self.__server_selector = options.get(
'server_selector', any_server_selector)
@property
def _options(self):
@ -194,6 +197,10 @@ class ClientOptions(object):
"""The server selection timeout for this instance in seconds."""
return self.__server_selection_timeout
@property
def server_selector(self):
return self.__server_selector
@property
def heartbeat_frequency(self):
"""The monitoring frequency in seconds."""

@ -473,6 +473,15 @@ def validate_driver_or_none(option, value):
return value

def validate_is_callable_or_none(option, value):
    """Validates that 'value' is a callable or None."""
    if value is None:
        return value
    if not callable(value):
        raise ValueError("%s must be a callable" % (option,))
    return value

def validate_ok_for_replace(replacement):
"""Validate a replacement document."""
validate_is_mapping("replacement", replacement)
@ -552,7 +561,7 @@ URI_VALIDATORS = {
'unicode_decode_error_handler': validate_unicode_decode_error_handler,
'retrywrites': validate_boolean_or_string,
'compressors': validate_compressors,
'zlibcompressionlevel': validate_zlib_compression_level
'zlibcompressionlevel': validate_zlib_compression_level,
}
TIMEOUT_VALIDATORS = {
@ -572,6 +581,7 @@ KW_VALIDATORS = {
'tzinfo': validate_tzinfo,
'username': validate_string_or_none,
'password': validate_string_or_none,
'server_selector': validate_is_callable_or_none,
}
URI_VALIDATORS.update(TIMEOUT_VALIDATORS)
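The validator added in this file can be exercised on its own. The sketch below restates the function from the hunk above so that it runs without pymongo installed; the candidate values mirror those used in test_invalid_server_selector, and the variable names are illustrative.

```python
def validate_is_callable_or_none(option, value):
    """Return `value` if it is None or callable; raise ValueError otherwise."""
    if value is None:
        return value
    if not callable(value):
        raise ValueError("%s must be a callable" % (option,))
    return value


def passthrough(servers):
    """Trivial selector used only to have something callable."""
    return servers


# Callables and None are accepted and returned unchanged.
accepted = validate_is_callable_or_none("server_selector", passthrough)
accepted_none = validate_is_callable_or_none("server_selector", None)

# Non-callable values are rejected with a ValueError naming the option.
rejected = []
for candidate in ([], 10, "string", {}):
    try:
        validate_is_callable_or_none("server_selector", candidate)
    except ValueError as exc:
        rejected.append(str(exc))

print(accepted is passthrough, accepted_none, rejected)
```

Because the option is registered only in KW_VALIDATORS, it is accepted as a keyword argument but not as a URI option, which fits a value that has to be a Python callable.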

@ -215,6 +215,12 @@ class MongoClient(common.BaseObject):
milliseconds) the driver will wait during server monitoring when
connecting a new socket to a server before concluding the server
is unavailable. Defaults to ``20000`` (20 seconds).
- `server_selector`: (callable or None) Optional, user-provided
function that augments server selection rules. The function should
accept as an argument a list of
:class:`~pymongo.server_description.ServerDescription` objects and
return a list of server descriptions that should be considered
suitable for the desired operation.
- `serverSelectionTimeoutMS`: (integer) Controls how long (in
milliseconds) the driver will wait to find an available,
appropriate server to carry out a database operation; while it is
@ -331,6 +337,8 @@ class MongoClient(common.BaseObject):
is set, it must be a positive integer greater than or equal to
90 seconds.
.. seealso:: :doc:`/examples/server_selection`
| **Authentication:**
- `username`: A string.
@ -411,13 +419,16 @@ class MongoClient(common.BaseObject):
.. mongodoc:: connections
.. versionchanged:: 3.6
Added support for mongodb+srv:// URIs.
Added the ``retryWrites`` keyword argument and URI option.
.. versionchanged:: 3.8
Added the ``server_selector`` keyword argument.
.. versionchanged:: 3.7
Added the ``driver`` keyword argument.
.. versionchanged:: 3.6
Added support for mongodb+srv:// URIs.
Added the ``retryWrites`` keyword argument and URI option.
.. versionchanged:: 3.5
Add ``username`` and ``password`` options. Document the
``authSource``, ``authMechanism``, and ``authMechanismProperties``
@ -572,6 +583,7 @@ class MongoClient(common.BaseObject):
condition_class=condition_class,
local_threshold_ms=options.local_threshold_ms,
server_selection_timeout=options.server_selection_timeout,
server_selector=options.server_selector,
heartbeat_frequency=options.heartbeat_frequency)
self._topology = Topology(self._topology_settings)

@ -35,7 +35,8 @@ class TopologySettings(object):
condition_class=None,
local_threshold_ms=LOCAL_THRESHOLD_MS,
server_selection_timeout=SERVER_SELECTION_TIMEOUT,
heartbeat_frequency=common.HEARTBEAT_FREQUENCY):
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
server_selector=None):
"""Represent MongoClient's configuration.
Take a list of (host, port) pairs and optional replica set name.
@ -53,6 +54,7 @@ class TopologySettings(object):
self._condition_class = condition_class or threading.Condition
self._local_threshold_ms = local_threshold_ms
self._server_selection_timeout = server_selection_timeout
self._server_selector = server_selector
self._heartbeat_frequency = heartbeat_frequency
self._direct = (len(self._seeds) == 1 and not replica_set_name)
self._topology_id = ObjectId()
@ -90,6 +92,10 @@ class TopologySettings(object):
def server_selection_timeout(self):
return self._server_selection_timeout
@property
def server_selector(self):
return self._server_selector
@property
def heartbeat_frequency(self):
return self._heartbeat_frequency

@ -190,7 +190,7 @@ class Topology(object):
now = _time()
end_time = now + timeout
server_descriptions = self._description.apply_selector(
selector, address)
selector, address, custom_selector=self._settings.server_selector)
while not server_descriptions:
# No suitable servers.
@ -209,7 +209,8 @@ class Topology(object):
self._description.check_compatible()
now = _time()
server_descriptions = self._description.apply_selector(
selector, address)
selector, address,
custom_selector=self._settings.server_selector)
self._description.check_compatible()
return server_descriptions

@ -214,7 +214,7 @@ class TopologyDescription(object):
def heartbeat_frequency(self):
return self._topology_settings.heartbeat_frequency
def apply_selector(self, selector, address):
def apply_selector(self, selector, address, custom_selector=None):
def apply_local_threshold(selection):
if not selection:
@ -239,18 +239,23 @@ class TopologyDescription(object):
common_wv))
if self.topology_type == TOPOLOGY_TYPE.Single:
# Ignore the selector.
# Ignore selectors for standalone.
return self.known_servers
elif address:
# Ignore selectors when explicit address is requested.
description = self.server_descriptions().get(address)
return [description] if description else []
elif self.topology_type == TOPOLOGY_TYPE.Sharded:
# Ignore the read preference, but apply localThresholdMS.
return apply_local_threshold(
Selection.from_topology_description(self))
# Ignore read preference.
selection = Selection.from_topology_description(self)
else:
return apply_local_threshold(
selector(Selection.from_topology_description(self)))
selection = selector(Selection.from_topology_description(self))
# Apply custom selector followed by localThresholdMS.
if custom_selector is not None and selection:
selection = selection.with_server_descriptions(
custom_selector(selection.server_descriptions))
return apply_local_threshold(selection)
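The branching in the reworked apply_selector decides when a custom selector is consulted at all. The sketch below is a simplified model under stated assumptions: servers are plain (host, port) tuples, topology types are strings, localThresholdMS filtering is left out, and the names any_server and no_server are hypothetical.

```python
class CallCountSelector(object):
    """No-op selector that counts how many times it is called."""

    def __init__(self):
        self.call_count = 0

    def __call__(self, servers):
        self.call_count += 1
        return servers


def any_server(servers):
    return list(servers)


def no_server(servers):
    return []


def apply_selector(topology_type, servers, selector,
                   address=None, custom_selector=None):
    if topology_type == "Single":
        # Selectors are ignored for a standalone server.
        return list(servers)
    if address is not None:
        # Selectors are ignored when an explicit address is requested.
        return [s for s in servers if s == address]
    if topology_type == "Sharded":
        # The read preference is ignored for sharded clusters.
        selection = list(servers)
    else:
        selection = selector(servers)
    # The custom selector only sees a non-empty selection.
    if custom_selector is not None and selection:
        selection = custom_selector(selection)
    return selection


servers = [("a", 27017), ("b", 27017), ("c", 27017)]
counter = CallCountSelector()

apply_selector("Single", servers[:1], any_server, custom_selector=counter)
by_address = apply_selector("ReplicaSetWithPrimary", servers, any_server,
                            address=("b", 27017), custom_selector=counter)
empty = apply_selector("ReplicaSetNoPrimary", servers, no_server,
                       custom_selector=counter)
calls_when_bypassed = counter.call_count

apply_selector("ReplicaSetWithPrimary", servers, any_server,
               custom_selector=counter)
print(calls_when_bypassed, counter.call_count)
```

The first three calls never reach the custom selector. Only the last one, where the built-in selector leaves a non-empty selection, invokes it.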
def has_readable_server(self, read_preference=ReadPreference.PRIMARY):
"""Does this topology have any readable servers available matching the

@ -17,10 +17,20 @@
import os
import sys
from pymongo import MongoClient
from pymongo import ReadPreference
from pymongo.errors import ServerSelectionTimeoutError
from pymongo.server_selectors import writable_server_selector
from pymongo.settings import TopologySettings
from pymongo.topology import Topology
sys.path[0:0] = [""]
from test import unittest
from test.utils_selection_tests import create_selection_tests
from test import client_context, unittest, IntegrationTest
from test.utils import rs_or_single_client, wait_until, EventListener
from test.utils_selection_tests import (
create_selection_tests, get_addresses, get_topology_settings_dict,
make_server_description)
# Location of JSON test specifications.
@ -29,8 +39,183 @@ _TEST_PATH = os.path.join(
os.path.join('server_selection', 'server_selection'))

class CallCountSelector(object):
    """No-op selector that keeps track of how many times it is called."""

    def __init__(self):
        self.call_count = 0

    def __call__(self, servers):
        self.call_count += 1
        return servers


class SelectionStoreSelector(object):
    """No-op selector that keeps track of what was passed to it."""

    def __init__(self):
        self.selection = None

    def __call__(self, selection):
        self.selection = selection
        return selection

class TestAllScenarios(create_selection_tests(_TEST_PATH)):
pass
class TestCustomServerSelectorFunction(IntegrationTest):
@client_context.require_replica_set
def test_functional_select_max_port_number_host(self):
# Selector that returns server with highest port number.
def custom_selector(servers):
ports = [s.address[1] for s in servers]
idx = ports.index(max(ports))
return [servers[idx]]
# Initialize client with appropriate listeners.
listener = EventListener()
client = rs_or_single_client(
server_selector=custom_selector, event_listeners=[listener])
self.addCleanup(client.close)
coll = client.get_database(
'testdb', read_preference=ReadPreference.NEAREST).coll
self.addCleanup(client.drop_database, 'testdb')
# Wait for the node list to be fully populated.
def all_hosts_started():
return (len(client.admin.command('isMaster')['hosts']) ==
len(client._topology._description.readable_servers))
wait_until(all_hosts_started, 'receive heartbeat from all hosts')
expected_port = max([
n.address[1]
for n in client._topology._description.readable_servers])
# Insert 1 record and access it 10 times.
coll.insert_one({'name': 'John Doe'})
for _ in range(10):
coll.find_one({'name': 'John Doe'})
# Confirm all find commands are run against appropriate host.
for command in listener.results['started']:
if command.command_name == 'find':
self.assertEqual(
command.connection_id[1], expected_port)
def test_invalid_server_selector(self):
# Client initialization must fail if server_selector is not callable.
for selector_candidate in [list(), 10, 'string', {}]:
with self.assertRaisesRegex(ValueError, "must be a callable"):
MongoClient(connect=False, server_selector=selector_candidate)
# None value for server_selector is OK.
MongoClient(connect=False, server_selector=None)
@client_context.require_replica_set
def test_selector_called(self):
selector = CallCountSelector()
# Client setup.
mongo_client = rs_or_single_client(server_selector=selector)
test_collection = mongo_client.testdb.test_collection
self.addCleanup(mongo_client.drop_database, 'testdb')
self.addCleanup(mongo_client.close)
# Do N operations and test selector is called at least N times.
test_collection.insert_one({'age': 20, 'name': 'John'})
test_collection.insert_one({'age': 31, 'name': 'Jane'})
test_collection.update_one({'name': 'Jane'}, {'$set': {'age': 21}})
test_collection.find_one({'name': 'Roe'})
self.assertGreaterEqual(selector.call_count, 4)
@client_context.require_replica_set
def test_latency_threshold_application(self):
selector = SelectionStoreSelector()
scenario_def = {
'topology_description': {
'type': 'ReplicaSetWithPrimary', 'servers': [
{'address': 'b:27017',
'avg_rtt_ms': 10000,
'type': 'RSSecondary',
'tag': {}},
{'address': 'c:27017',
'avg_rtt_ms': 20000,
'type': 'RSSecondary',
'tag': {}},
{'address': 'a:27017',
'avg_rtt_ms': 30000,
'type': 'RSPrimary',
'tag': {}},
]}}
# Create & populate Topology such that all but one server is too slow.
rtt_times = [srv['avg_rtt_ms'] for srv in
scenario_def['topology_description']['servers']]
min_rtt_idx = rtt_times.index(min(rtt_times))
seeds, hosts = get_addresses(
scenario_def["topology_description"]["servers"])
settings = get_topology_settings_dict(
heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds,
server_selector=selector)
topology = Topology(TopologySettings(**settings))
topology.open()
for server in scenario_def['topology_description']['servers']:
server_description = make_server_description(server, hosts)
topology.on_change(server_description)
# Invoke server selection and assert no filtering based on latency
# prior to custom server selection logic kicking in.
server = topology.select_server(ReadPreference.NEAREST)
self.assertEqual(
len(selector.selection),
len(topology.description.server_descriptions()))
# Ensure proper filtering based on latency after custom selection.
self.assertEqual(
server.description.address, seeds[min_rtt_idx])
@client_context.require_replica_set
def test_server_selector_bypassed(self):
selector = CallCountSelector()
scenario_def = {
'topology_description': {
'type': 'ReplicaSetNoPrimary', 'servers': [
{'address': 'b:27017',
'avg_rtt_ms': 10000,
'type': 'RSSecondary',
'tag': {}},
{'address': 'c:27017',
'avg_rtt_ms': 20000,
'type': 'RSSecondary',
'tag': {}},
{'address': 'a:27017',
'avg_rtt_ms': 30000,
'type': 'RSSecondary',
'tag': {}},
]}}
# Create & populate Topology such that no server is writeable.
seeds, hosts = get_addresses(
scenario_def["topology_description"]["servers"])
settings = get_topology_settings_dict(
heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds,
server_selector=selector)
topology = Topology(TopologySettings(**settings))
topology.open()
for server in scenario_def['topology_description']['servers']:
server_description = make_server_description(server, hosts)
topology.on_change(server_description)
# Invoke server selection and assert no calls to our custom selector.
with self.assertRaisesRegex(
ServerSelectionTimeoutError, 'No primary available for writes'):
topology.select_server(
writable_server_selector, server_selection_timeout=0.1)
self.assertEqual(selector.call_count, 0)
if __name__ == "__main__":
unittest.main()

@ -139,6 +139,16 @@ def get_topology_type_name(scenario_def):
return name

def get_topology_settings_dict(**kwargs):
    settings = dict(
        monitor_class=MockMonitor,
        heartbeat_frequency=HEARTBEAT_FREQUENCY,
        pool_class=MockPool
    )
    settings.update(kwargs)
    return settings

def create_test(scenario_def):
def run_scenario(self):
# Initialize topologies.
@ -147,14 +157,14 @@ def create_test(scenario_def):
else:
frequency = HEARTBEAT_FREQUENCY
settings = dict(
monitor_class=MockMonitor,
heartbeat_frequency=frequency,
pool_class=MockPool)
settings['seeds'], hosts = get_addresses(
seeds, hosts = get_addresses(
scenario_def['topology_description']['servers'])
settings = get_topology_settings_dict(
heartbeat_frequency=frequency,
seeds=seeds
)
# "Eligible servers" is defined in the server selection spec as
# the set of servers matching both the ReadPreference's mode
# and tag sets.