diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 728c9a167..dde4ef242 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1977,7 +1977,7 @@ class _MongoClientErrorHandler(object): # "Note that when a network error occurs before the handshake # completes then the error's generation number is the generation # of the pool at the time the connection attempt was started." - self.sock_generation = server.pool.generation + self.sock_generation = server.pool.gen.get_overall() self.completed_handshake = False self.service_id = None diff --git a/pymongo/pool.py b/pymongo/pool.py index df4a430cd..53766f65f 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -530,7 +530,8 @@ class SocketInfo(object): # The pool's generation changes with each reset() so we can close # sockets created before the last reset. - self.generation = pool.generation + self.pool_gen = pool.gen + self.generation = self.pool_gen.get_overall() self.ready = False self.cancel_context = None if not pool.handshake: @@ -616,6 +617,7 @@ class SocketInfo(object): 'Driver attempted to initialize in load balancing mode' ' but the server does not support this mode') self.service_id = ismaster.service_id + self.generation = self.pool_gen.get(self.service_id) return ismaster def _next_reply(self): @@ -1044,6 +1046,37 @@ class _PoolClosedError(PyMongoError): pass +class _PoolGeneration(object): + def __init__(self): + # Maps service_id to generation. + self._generations = collections.defaultdict(int) + # Overall pool generation. + self._generation = 0 + + def get(self, service_id): + """Get the generation for the given service_id.""" + if service_id is None: + return self._generation + return self._generations[service_id] + + def get_overall(self): + """Get the Pool's overall generation.""" + return self._generation + + def inc(self, service_id): + """Increment the generation for the given service_id.""" + self._generation += 1 + if service_id is None: + for service_id in self._generations: + self._generations[service_id] += 1 + else: + self._generations[service_id] += 1 + + def stale(self, gen, service_id): + """Return if the given generation for a given service_id is stale.""" + return gen != self.get(service_id) + + class PoolState(object): PAUSED = 1 READY = 2 @@ -1080,7 +1113,8 @@ class Pool: # Keep track of resets, so we notice sockets created before the most # recent reset and close them. - self.generation = 0 + # self.generation = 0 + self.gen = _PoolGeneration() self.pid = os.getpid() self.address = address self.opts = options @@ -1137,13 +1171,25 @@ class Pool: if (self.opts.pause_enabled and pause and not self.opts.load_balanced): old_state, self.state = self.state, PoolState.PAUSED - self.generation += 1 + self.gen.inc(service_id) newpid = os.getpid() if self.pid != newpid: self.pid = newpid self.active_sockets = 0 self.operation_count = 0 - sockets, self.sockets = self.sockets, collections.deque() + if service_id is None: + sockets, self.sockets = self.sockets, collections.deque() + else: + discard = collections.deque() + keep = collections.deque() + for sock_info in self.sockets: + if sock_info.service_id == service_id: + discard.append(sock_info) + else: + keep.append(sock_info) + sockets = discard + self.sockets = keep + if close: self.state = PoolState.CLOSED # Clear the wait queue @@ -1184,6 +1230,9 @@ class Pool: def close(self): self._reset(close=True) + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) + def remove_stale_sockets(self, reference_generation, all_credentials): """Removes stale sockets then adds new ones if pool is too small and has not been reset. The `reference_generation` argument specifies the @@ -1222,7 +1271,7 @@ class Pool: with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. - if self.generation != reference_generation: + if self.gen.get_overall() != reference_generation: sock_info.close_socket(ConnectionClosedReason.STALE) return self.sockets.appendleft(sock_info) @@ -1450,7 +1499,8 @@ class Pool: with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). - if sock_info.generation != self.generation: + if self.stale_generation(sock_info.generation, + sock_info.service_id): sock_info.close_socket(ConnectionClosedReason.STALE) else: sock_info.update_last_checkin_time() @@ -1493,7 +1543,7 @@ class Pool: sock_info.close_socket(ConnectionClosedReason.ERROR) return True - if sock_info.generation != self.generation: + if self.stale_generation(sock_info.generation, sock_info.service_id): sock_info.close_socket(ConnectionClosedReason.STALE) return True diff --git a/pymongo/topology.py b/pymongo/topology.py index 18d5c4c8f..baa4293dd 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -447,7 +447,8 @@ class Topology(object): # Only update pools for data-bearing servers. for sd in self.data_bearing_servers(): server = self._servers[sd.address] - servers.append((server, server.pool.generation)) + servers.append((server, + server.pool.gen.get_overall())) for server, generation in servers: try: @@ -577,7 +578,8 @@ class Topology(object): # Another thread removed this server from the topology. return True - if err_ctx.sock_generation != server._pool.generation: + if server._pool.stale_generation( + err_ctx.sock_generation, err_ctx.service_id): # This is an outdated error from a previous pool version. return True diff --git a/test/test_client.py b/test/test_client.py index a8246e338..88adaac6a 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1468,7 +1468,7 @@ class TestClient(IntegrationTest): self.addCleanup(client.close) client.admin.command('ping') pool = get_pool(client) - generation = pool.generation + generation = pool.gen.get_overall() # Continuously reset the pool. class ResetPoolThread(threading.Thread): @@ -1483,7 +1483,8 @@ class TestClient(IntegrationTest): def run(self): while self.running: exc = AutoReconnect('mock pool error') - ctx = _ErrorContext(exc, 0, pool.generation, False, None) + ctx = _ErrorContext( + exc, 0, pool.gen.get_overall(), False, None) client._topology.handle_error(pool.address, ctx) time.sleep(0.001) @@ -1497,7 +1498,7 @@ class TestClient(IntegrationTest): for _ in range(10): client._topology.update_pool( client._MongoClient__all_credentials) - if generation != pool.generation: + if generation != pool.gen.get_overall(): break finally: t.stop() diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 7fa96b5e1..0cbc40f1f 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -87,7 +87,8 @@ def got_app_error(topology, app_error): server_address = common.partition_node(app_error['address']) server = topology.get_server_by_address(server_address) error_type = app_error['type'] - generation = app_error.get('generation', server.pool.generation) + generation = app_error.get( + 'generation', server.pool.gen.get_overall()) when = app_error['when'] max_wire_version = app_error['maxWireVersion'] # XXX: We could get better test coverage by mocking the errors on the @@ -181,7 +182,7 @@ def check_outcome(self, topology, outcome): if expected_pool: self.assertEqual( expected_pool.get('generation'), - actual_server.pool.generation) + actual_server.pool.gen.get_overall()) self.assertEqual(outcome['setName'], topology.description.replica_set_name) self.assertEqual(outcome.get('logicalSessionTimeoutMinutes'), @@ -270,7 +271,7 @@ class TestIgnoreStaleErrors(IntegrationTest): # Wait for initial discovery. client.admin.command('ping') pool = get_pool(client) - starting_generation = pool.generation + starting_generation = pool.gen.get_overall() wait_until(lambda: len(pool.sockets) == N_THREADS, 'created sockets') def mock_command(*args, **kwargs): @@ -296,7 +297,8 @@ class TestIgnoreStaleErrors(IntegrationTest): t.join() # Expect a single pool reset for the network error - self.assertEqual(starting_generation+1, pool.generation) + self.assertEqual( + starting_generation+1, pool.gen.get_overall()) # Server should be selectable. client.admin.command('ping') diff --git a/test/test_topology.py b/test/test_topology.py index 5e2f683f7..bb94207f4 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -659,11 +659,11 @@ class TestTopologyErrors(TopologyTest): self.addCleanup(t.close) server = wait_for_master(t) self.assertEqual(1, ismaster_count[0]) - generation = server.pool.generation + generation = server.pool.gen.get_overall() # Pool is reset by ismaster failure. t.request_check_all() - self.assertNotEqual(generation, server.pool.generation) + self.assertNotEqual(generation, server.pool.gen.get_overall()) def test_ismaster_retry(self): # ismaster succeeds at first, then raises socket error, then succeeds. diff --git a/test/unified_format.py b/test/unified_format.py index 7fa7f5513..91e02e9e2 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -124,8 +124,16 @@ def is_run_on_requirement_satisfied(requirement): elif client_context.server_parameters[param] != val: params_satisfied = False + auth_satisfied = True + req_auth = requirement.get('auth') + if req_auth is not None: + if req_auth: + auth_satisfied = client_context.auth_enabled + else: + auth_satisfied = not client_context.auth_enabled + return (topology_satisfied and min_version_satisfied and - max_version_satisfied and params_satisfied) + max_version_satisfied and params_satisfied and auth_satisfied) def parse_collection_or_database_options(options): diff --git a/test/utils.py b/test/utils.py index 682782a43..fa2865c83 100644 --- a/test/utils.py +++ b/test/utils.py @@ -39,7 +39,7 @@ from pymongo import (MongoClient, from pymongo.collection import ReturnDocument from pymongo.errors import ConfigurationError, OperationFailure from pymongo.monitoring import _SENSITIVE_COMMANDS -from pymongo.pool import _CancellationContext +from pymongo.pool import _CancellationContext, _PoolGeneration from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import (any_server_selector, @@ -259,20 +259,23 @@ class MockSocketInfo(object): class MockPool(object): def __init__(self, address, options, handshake=True): - self.generation = 0 + self.gen = _PoolGeneration() self._lock = threading.Lock() self.opts = options self.operation_count = 0 + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) + def get_socket(self, all_credentials, checkout=False): return MockSocketInfo() def return_socket(self, *args, **kwargs): pass - def _reset(self): + def _reset(self, service_id=None): with self._lock: - self.generation += 1 + self.gen.inc(service_id) def ready(self): pass