From 2cc37059b836c5bf66adf99747992bdcd1b832b2 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Wed, 15 Apr 2015 17:14:25 -0400 Subject: [PATCH] PYTHON-898 - Send getMore to same mongos as initial query. --- pymongo/server_selectors.py | 4 ---- pymongo/topology.py | 30 ++++++++++++++++++++---------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pymongo/server_selectors.py b/pymongo/server_selectors.py index 621324232..c159c8ee0 100644 --- a/pymongo/server_selectors.py +++ b/pymongo/server_selectors.py @@ -21,10 +21,6 @@ def any_server_selector(server_descriptions): return server_descriptions -def address_server_selector(address, server_descriptions): - return [s for s in server_descriptions if s.address == address] - - def writable_server_selector(server_descriptions): return [s for s in server_descriptions if s.is_writable] diff --git a/pymongo/topology.py b/pymongo/topology.py index 8ee1090d2..40c2c2a0a 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -16,7 +16,6 @@ import random import threading -from functools import partial from bson.py3compat import itervalues from pymongo import common @@ -27,7 +26,7 @@ from pymongo.topology_description import (updated_topology_description, from pymongo.errors import ServerSelectionTimeoutError, InvalidOperation from pymongo.monotonic import time as _time from pymongo.server import Server -from pymongo.server_selectors import (address_server_selector, +from pymongo.server_selectors import (any_server_selector, apply_local_threshold, arbiter_server_selector, secondary_server_selector, @@ -59,7 +58,10 @@ class Topology(object): with self._lock: self._ensure_opened() - def select_servers(self, selector, server_selection_timeout=None): + def select_servers(self, + selector, + server_selection_timeout=None, + address=None): """Return a list of Servers matching selector, or time out. :Parameters: @@ -68,6 +70,7 @@ class Topology(object): - `server_selection_timeout` (optional): maximum seconds to wait. If not provided, the default value common.SERVER_SELECTION_TIMEOUT is used. + - `address`: optional server address to select. Calls self.open() if needed. @@ -84,7 +87,7 @@ class Topology(object): now = _time() end_time = now + server_timeout - server_descriptions = self._apply_selector(selector) + server_descriptions = self._apply_selector(selector, address) while not server_descriptions: # No suitable servers. @@ -102,15 +105,19 @@ class Topology(object): self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = _time() - server_descriptions = self._apply_selector(selector) + server_descriptions = self._apply_selector(selector, address) return [self.get_server_by_address(sd.address) for sd in server_descriptions] - def select_server(self, selector, server_selection_timeout=None): + def select_server(self, + selector, + server_selection_timeout=None, + address=None): """Like select_servers, but choose a random server if several match.""" return random.choice(self.select_servers(selector, - server_selection_timeout)) + server_selection_timeout, + address)) def select_server_by_address(self, address, server_selection_timeout=None): @@ -131,8 +138,9 @@ class Topology(object): Raises exc:`ServerSelectionTimeoutError` after `server_selection_timeout` if no matching servers are found. """ - selector = partial(address_server_selector, address) - return self.select_server(selector, server_selection_timeout) + return self.select_server(any_server_selector, + server_selection_timeout, + address) def on_change(self, server_description): """Process a new ServerDescription after an ismaster call completes.""" @@ -295,10 +303,12 @@ class Topology(object): for server in self._servers.values(): server.request_check() - def _apply_selector(self, selector): + def _apply_selector(self, selector, address): if self._description.topology_type == TOPOLOGY_TYPE.Single: # Ignore the selector. return self._description.known_servers + elif address: + return [self._description.server_descriptions().get(address)] elif self._description.topology_type == TOPOLOGY_TYPE.Sharded: return apply_local_threshold(self._settings.local_threshold_ms, self._description.known_servers)