diff --git a/pymongo/pool.py b/pymongo/pool.py index d6c5b773d..49652647c 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -217,14 +217,15 @@ class PoolOptions(object): '__connect_timeout', '__socket_timeout', '__wait_queue_timeout', '__wait_queue_multiple', '__ssl_context', '__ssl_match_hostname', '__socket_keepalive', - '__event_listeners', '__appname', '__metadata') + '__event_listeners', '__appname', '__metadata', + '__handshake_callback') def __init__(self, max_pool_size=100, min_pool_size=0, max_idle_time_ms=None, connect_timeout=None, socket_timeout=None, wait_queue_timeout=None, wait_queue_multiple=None, ssl_context=None, ssl_match_hostname=True, socket_keepalive=False, - event_listeners=None, appname=None): + event_listeners=None, appname=None, handshake_callback=None): self.__max_pool_size = max_pool_size self.__min_pool_size = min_pool_size @@ -242,6 +243,27 @@ class PoolOptions(object): if appname: self.__metadata['application'] = {'name': appname} + self.__handshake_callback = handshake_callback + + def with_options(self, **kwargs): + options = { + 'max_pool_size': self.max_pool_size, + 'min_pool_size': self.min_pool_size, + 'max_idle_time_ms': self.max_idle_time_ms, + 'connect_timeout': self.connect_timeout, + 'socket_timeout': self.socket_timeout, + 'wait_queue_timeout': self.wait_queue_timeout, + 'wait_queue_multiple': self.wait_queue_multiple, + 'ssl_context': self.ssl_context, + 'ssl_match_hostname': self.ssl_match_hostname, + 'socket_keepalive': self.socket_keepalive, + 'event_listeners': self.event_listeners, + 'appname': self.appname, + 'handshake_callback': self.handshake_callback} + + options.update(kwargs) + return PoolOptions(**options) + @property def max_pool_size(self): """The maximum allowable number of concurrent connections to each @@ -335,6 +357,11 @@ class PoolOptions(object): """ return self.__metadata.copy() + @property + def handshake_callback(self): + """Receives an ismaster reply and updates the topology.""" + return self.__handshake_callback + class SocketInfo(object): """Store a socket with some metadata. @@ -746,6 +773,8 @@ class Pool: ('ismaster', 1), ('client', self.opts.metadata) ]) + + start = _time() ismaster = IsMaster( command(sock, 'admin', @@ -754,6 +783,9 @@ class Pool: False, ReadPreference.PRIMARY, DEFAULT_CODEC_OPTIONS)) + + # Can raise ConnectionFailure. + self._handshake_callback(ismaster, _time() - start) else: ismaster = None return SocketInfo(sock, self, ismaster, self.address) @@ -761,6 +793,10 @@ class Pool: if sock is not None: sock.close() _raise_connection_failure(self.address, error) + except: + if sock is not None: + sock.close() + raise @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): @@ -889,6 +925,14 @@ class Pool: else: return self.connect() + def _handshake_callback(self, ismaster, round_trip_time): + callback = self.opts.handshake_callback + if callback: + kept = callback(self.address, ismaster, round_trip_time) + if not kept: + _raise_connection_failure( + self.address, "server removed from topology") + def _raise_wait_queue_timeout(self): raise ConnectionFailure( 'Timed out waiting for socket from pool with max_size %r and' diff --git a/pymongo/server_description.py b/pymongo/server_description.py index e2b29bcaf..d1b6e0b1f 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -38,6 +38,8 @@ class ServerDescription(object): - `ismaster`: Optional IsMaster instance - `round_trip_time`: Optional float - `error`: Optional, the last error attempting to connect to the server + - `from_handshake`: Optional, whether this is from expanding a connection + pool, rather than from background monitoring. """ __slots__ = ( @@ -45,14 +47,16 @@ class ServerDescription(object): '_primary', '_max_bson_size', '_max_message_size', '_max_write_batch_size', '_min_wire_version', '_max_wire_version', '_round_trip_time', '_me', '_is_writable', '_is_readable', '_error', - '_set_version', '_election_id', '_last_write_date', '_last_update_time') + '_set_version', '_election_id', '_last_write_date', '_last_update_time', + '_from_handshake') def __init__( self, address, ismaster=None, round_trip_time=None, - error=None): + error=None, + from_handshake=False): self._address = address if not ismaster: ismaster = IsMaster({}) @@ -75,6 +79,7 @@ class ServerDescription(object): self._me = ismaster.me self._last_update_time = _time() self._error = error + self._from_handshake = from_handshake # For tests. if ismaster.last_write_date: # Convert from datetime to seconds. diff --git a/pymongo/topology.py b/pymongo/topology.py index cfc6c9958..85a00f41d 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -28,13 +28,13 @@ else: from pymongo import common from pymongo import periodic_executor -from pymongo.pool import PoolOptions from pymongo.topology_description import (updated_topology_description, TOPOLOGY_TYPE, TopologyDescription) from pymongo.errors import ServerSelectionTimeoutError from pymongo.monotonic import time as _time from pymongo.server import Server +from pymongo.server_description import ServerDescription from pymongo.server_selectors import (any_server_selector, arbiter_server_selector, secondary_server_selector, @@ -237,7 +237,10 @@ class Topology(object): address) def on_change(self, server_description): - """Process a new ServerDescription after an ismaster call completes.""" + """Process a new ServerDescription after an ismaster call completes. + + Returns False if the server was removed from the topology. + """ # We do no I/O holding the lock. with self._lock: # Any monitored server was definitely in the topology description @@ -266,6 +269,9 @@ class Topology(object): # Wake waiters in select_servers(). self._condition.notify_all() + return self._description.has_server(server_description.address) + else: + return False def get_server_by_address(self, address): """Get a Server or None. @@ -337,9 +343,13 @@ class Topology(object): def update_pool(self): # Remove any stale sockets and add new sockets if pool is too small. + # Avoid locking around network I/O, or deadlocking when a new connection + # opens and calls Topology.on_change() with the ismaster reply. with self._lock: - for server in self._servers.values(): - server._pool.remove_stale_sockets() + pools = [server._pool for server in self._servers.values()] + + for pool in pools: + pool.remove_stale_sockets() def close(self): """Clear pools and terminate monitors. Topology reopens on demand.""" @@ -448,22 +458,36 @@ class Topology(object): self._servers.pop(address) def _create_pool_for_server(self, address): - return self._settings.pool_class(address, self._settings.pool_options) + # Server Discovery And Monitoring Spec: When a client calls ismaster + # to handshake a new connection for application operations, use the + # ismaster reply to update the topology. + ref = weakref.proxy(self) + + def handshake_callback(address, ismaster, round_trip_time): + sd = ServerDescription(address, ismaster, round_trip_time, + from_handshake=True) + + try: + # Return False if server was removed from topology. + return ref.on_change(sd) + except ReferenceError: + return True + + server_pool_options = self._settings.pool_options.with_options( + handshake_callback=handshake_callback) + + return self._settings.pool_class(address, server_pool_options) def _create_pool_for_monitor(self, address): - options = self._settings.pool_options - # According to the Server Discovery And Monitoring Spec, monitors use # connect_timeout for both connect_timeout and socket_timeout. The # pool only has one socket so maxPoolSize and so on aren't needed. - monitor_pool_options = PoolOptions( - connect_timeout=options.connect_timeout, - socket_timeout=options.connect_timeout, - ssl_context=options.ssl_context, - ssl_match_hostname=options.ssl_match_hostname, + opts = self._settings.pool_options + monitor_pool_options = opts.with_options( + socket_timeout=opts.connect_timeout, socket_keepalive=True, - event_listeners=options.event_listeners, - appname=options.appname) + max_pool_size=None, + min_pool_size=None) return self._settings.pool_class(address, monitor_pool_options, handshake=False) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index e5b4af1e4..c3765219e 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -58,6 +58,10 @@ class MockPool(Pool): sock_info.mock_port = self.mock_port yield sock_info + def _handshake_callback(self, ismaster, round_trip_time): + # Don't mock how PyMongo updates topology from ismaster reply. + return True + class MockMonitor(Monitor): def __init__( diff --git a/test/test_topology.py b/test/test_topology.py index 70bb0bf26..741118e31 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -22,6 +22,7 @@ import threading from bson.py3compat import imap from pymongo import common +from pymongo import monitoring from pymongo.read_preferences import ReadPreference, Secondary from pymongo.server_type import SERVER_TYPE from pymongo.topology import Topology @@ -37,7 +38,7 @@ from pymongo.server_selectors import (any_server_selector, writable_server_selector) from pymongo.settings import TopologySettings from test import client_knobs, unittest -from test.utils import wait_until +from test.utils import rs_or_single_client, wait_until class MockSocketInfo(object): @@ -297,6 +298,23 @@ class TestSingleServerTopology(TopologyTest): if tries > 10: self.fail("Didn't ever calculate correct new average") + def test_update_from_handshake(self): + class ServerHandshakes(monitoring.ServerListener, list): + def opened(self, e): + pass + + def description_changed(self, e): + if e.new_description._from_handshake: + self.append(e) + + def closed(self, e): + pass + + handshakes = ServerHandshakes() + client = rs_or_single_client(event_listeners=[handshakes]) + client.admin.command('ping') + wait_until(lambda: handshakes, 'record handshakes') + class TestMultiServerTopology(TopologyTest): def test_readable_writable(self):