PYTHON-2674 Pool.reset only clears connections to the given serviceId (#628)
This commit is contained in:
parent
9c1ff6ad9d
commit
112ee69de8
@ -1977,7 +1977,7 @@ class _MongoClientErrorHandler(object):
|
||||
# "Note that when a network error occurs before the handshake
|
||||
# completes then the error's generation number is the generation
|
||||
# of the pool at the time the connection attempt was started."
|
||||
self.sock_generation = server.pool.generation
|
||||
self.sock_generation = server.pool.gen.get_overall()
|
||||
self.completed_handshake = False
|
||||
self.service_id = None
|
||||
|
||||
|
||||
@ -530,7 +530,8 @@ class SocketInfo(object):
|
||||
|
||||
# The pool's generation changes with each reset() so we can close
|
||||
# sockets created before the last reset.
|
||||
self.generation = pool.generation
|
||||
self.pool_gen = pool.gen
|
||||
self.generation = self.pool_gen.get_overall()
|
||||
self.ready = False
|
||||
self.cancel_context = None
|
||||
if not pool.handshake:
|
||||
@ -616,6 +617,7 @@ class SocketInfo(object):
|
||||
'Driver attempted to initialize in load balancing mode'
|
||||
' but the server does not support this mode')
|
||||
self.service_id = ismaster.service_id
|
||||
self.generation = self.pool_gen.get(self.service_id)
|
||||
return ismaster
|
||||
|
||||
def _next_reply(self):
|
||||
@ -1044,6 +1046,37 @@ class _PoolClosedError(PyMongoError):
|
||||
pass
|
||||
|
||||
|
||||
class _PoolGeneration(object):
|
||||
def __init__(self):
|
||||
# Maps service_id to generation.
|
||||
self._generations = collections.defaultdict(int)
|
||||
# Overall pool generation.
|
||||
self._generation = 0
|
||||
|
||||
def get(self, service_id):
|
||||
"""Get the generation for the given service_id."""
|
||||
if service_id is None:
|
||||
return self._generation
|
||||
return self._generations[service_id]
|
||||
|
||||
def get_overall(self):
|
||||
"""Get the Pool's overall generation."""
|
||||
return self._generation
|
||||
|
||||
def inc(self, service_id):
|
||||
"""Increment the generation for the given service_id."""
|
||||
self._generation += 1
|
||||
if service_id is None:
|
||||
for service_id in self._generations:
|
||||
self._generations[service_id] += 1
|
||||
else:
|
||||
self._generations[service_id] += 1
|
||||
|
||||
def stale(self, gen, service_id):
|
||||
"""Return if the given generation for a given service_id is stale."""
|
||||
return gen != self.get(service_id)
|
||||
|
||||
|
||||
class PoolState(object):
|
||||
PAUSED = 1
|
||||
READY = 2
|
||||
@ -1080,7 +1113,8 @@ class Pool:
|
||||
|
||||
# Keep track of resets, so we notice sockets created before the most
|
||||
# recent reset and close them.
|
||||
self.generation = 0
|
||||
# self.generation = 0
|
||||
self.gen = _PoolGeneration()
|
||||
self.pid = os.getpid()
|
||||
self.address = address
|
||||
self.opts = options
|
||||
@ -1137,13 +1171,25 @@ class Pool:
|
||||
if (self.opts.pause_enabled and pause and
|
||||
not self.opts.load_balanced):
|
||||
old_state, self.state = self.state, PoolState.PAUSED
|
||||
self.generation += 1
|
||||
self.gen.inc(service_id)
|
||||
newpid = os.getpid()
|
||||
if self.pid != newpid:
|
||||
self.pid = newpid
|
||||
self.active_sockets = 0
|
||||
self.operation_count = 0
|
||||
sockets, self.sockets = self.sockets, collections.deque()
|
||||
if service_id is None:
|
||||
sockets, self.sockets = self.sockets, collections.deque()
|
||||
else:
|
||||
discard = collections.deque()
|
||||
keep = collections.deque()
|
||||
for sock_info in self.sockets:
|
||||
if sock_info.service_id == service_id:
|
||||
discard.append(sock_info)
|
||||
else:
|
||||
keep.append(sock_info)
|
||||
sockets = discard
|
||||
self.sockets = keep
|
||||
|
||||
if close:
|
||||
self.state = PoolState.CLOSED
|
||||
# Clear the wait queue
|
||||
@ -1184,6 +1230,9 @@ class Pool:
|
||||
def close(self):
|
||||
self._reset(close=True)
|
||||
|
||||
def stale_generation(self, gen, service_id):
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
def remove_stale_sockets(self, reference_generation, all_credentials):
|
||||
"""Removes stale sockets then adds new ones if pool is too small and
|
||||
has not been reset. The `reference_generation` argument specifies the
|
||||
@ -1222,7 +1271,7 @@ class Pool:
|
||||
with self.lock:
|
||||
# Close connection and return if the pool was reset during
|
||||
# socket creation or while acquiring the pool lock.
|
||||
if self.generation != reference_generation:
|
||||
if self.gen.get_overall() != reference_generation:
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
return
|
||||
self.sockets.appendleft(sock_info)
|
||||
@ -1450,7 +1499,8 @@ class Pool:
|
||||
with self.lock:
|
||||
# Hold the lock to ensure this section does not race with
|
||||
# Pool.reset().
|
||||
if sock_info.generation != self.generation:
|
||||
if self.stale_generation(sock_info.generation,
|
||||
sock_info.service_id):
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
else:
|
||||
sock_info.update_last_checkin_time()
|
||||
@ -1493,7 +1543,7 @@ class Pool:
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
return True
|
||||
|
||||
if sock_info.generation != self.generation:
|
||||
if self.stale_generation(sock_info.generation, sock_info.service_id):
|
||||
sock_info.close_socket(ConnectionClosedReason.STALE)
|
||||
return True
|
||||
|
||||
|
||||
@ -447,7 +447,8 @@ class Topology(object):
|
||||
# Only update pools for data-bearing servers.
|
||||
for sd in self.data_bearing_servers():
|
||||
server = self._servers[sd.address]
|
||||
servers.append((server, server.pool.generation))
|
||||
servers.append((server,
|
||||
server.pool.gen.get_overall()))
|
||||
|
||||
for server, generation in servers:
|
||||
try:
|
||||
@ -577,7 +578,8 @@ class Topology(object):
|
||||
# Another thread removed this server from the topology.
|
||||
return True
|
||||
|
||||
if err_ctx.sock_generation != server._pool.generation:
|
||||
if server._pool.stale_generation(
|
||||
err_ctx.sock_generation, err_ctx.service_id):
|
||||
# This is an outdated error from a previous pool version.
|
||||
return True
|
||||
|
||||
|
||||
@ -1468,7 +1468,7 @@ class TestClient(IntegrationTest):
|
||||
self.addCleanup(client.close)
|
||||
client.admin.command('ping')
|
||||
pool = get_pool(client)
|
||||
generation = pool.generation
|
||||
generation = pool.gen.get_overall()
|
||||
|
||||
# Continuously reset the pool.
|
||||
class ResetPoolThread(threading.Thread):
|
||||
@ -1483,7 +1483,8 @@ class TestClient(IntegrationTest):
|
||||
def run(self):
|
||||
while self.running:
|
||||
exc = AutoReconnect('mock pool error')
|
||||
ctx = _ErrorContext(exc, 0, pool.generation, False, None)
|
||||
ctx = _ErrorContext(
|
||||
exc, 0, pool.gen.get_overall(), False, None)
|
||||
client._topology.handle_error(pool.address, ctx)
|
||||
time.sleep(0.001)
|
||||
|
||||
@ -1497,7 +1498,7 @@ class TestClient(IntegrationTest):
|
||||
for _ in range(10):
|
||||
client._topology.update_pool(
|
||||
client._MongoClient__all_credentials)
|
||||
if generation != pool.generation:
|
||||
if generation != pool.gen.get_overall():
|
||||
break
|
||||
finally:
|
||||
t.stop()
|
||||
|
||||
@ -87,7 +87,8 @@ def got_app_error(topology, app_error):
|
||||
server_address = common.partition_node(app_error['address'])
|
||||
server = topology.get_server_by_address(server_address)
|
||||
error_type = app_error['type']
|
||||
generation = app_error.get('generation', server.pool.generation)
|
||||
generation = app_error.get(
|
||||
'generation', server.pool.gen.get_overall())
|
||||
when = app_error['when']
|
||||
max_wire_version = app_error['maxWireVersion']
|
||||
# XXX: We could get better test coverage by mocking the errors on the
|
||||
@ -181,7 +182,7 @@ def check_outcome(self, topology, outcome):
|
||||
if expected_pool:
|
||||
self.assertEqual(
|
||||
expected_pool.get('generation'),
|
||||
actual_server.pool.generation)
|
||||
actual_server.pool.gen.get_overall())
|
||||
|
||||
self.assertEqual(outcome['setName'], topology.description.replica_set_name)
|
||||
self.assertEqual(outcome.get('logicalSessionTimeoutMinutes'),
|
||||
@ -270,7 +271,7 @@ class TestIgnoreStaleErrors(IntegrationTest):
|
||||
# Wait for initial discovery.
|
||||
client.admin.command('ping')
|
||||
pool = get_pool(client)
|
||||
starting_generation = pool.generation
|
||||
starting_generation = pool.gen.get_overall()
|
||||
wait_until(lambda: len(pool.sockets) == N_THREADS, 'created sockets')
|
||||
|
||||
def mock_command(*args, **kwargs):
|
||||
@ -296,7 +297,8 @@ class TestIgnoreStaleErrors(IntegrationTest):
|
||||
t.join()
|
||||
|
||||
# Expect a single pool reset for the network error
|
||||
self.assertEqual(starting_generation+1, pool.generation)
|
||||
self.assertEqual(
|
||||
starting_generation+1, pool.gen.get_overall())
|
||||
|
||||
# Server should be selectable.
|
||||
client.admin.command('ping')
|
||||
|
||||
@ -659,11 +659,11 @@ class TestTopologyErrors(TopologyTest):
|
||||
self.addCleanup(t.close)
|
||||
server = wait_for_master(t)
|
||||
self.assertEqual(1, ismaster_count[0])
|
||||
generation = server.pool.generation
|
||||
generation = server.pool.gen.get_overall()
|
||||
|
||||
# Pool is reset by ismaster failure.
|
||||
t.request_check_all()
|
||||
self.assertNotEqual(generation, server.pool.generation)
|
||||
self.assertNotEqual(generation, server.pool.gen.get_overall())
|
||||
|
||||
def test_ismaster_retry(self):
|
||||
# ismaster succeeds at first, then raises socket error, then succeeds.
|
||||
|
||||
@ -124,8 +124,16 @@ def is_run_on_requirement_satisfied(requirement):
|
||||
elif client_context.server_parameters[param] != val:
|
||||
params_satisfied = False
|
||||
|
||||
auth_satisfied = True
|
||||
req_auth = requirement.get('auth')
|
||||
if req_auth is not None:
|
||||
if req_auth:
|
||||
auth_satisfied = client_context.auth_enabled
|
||||
else:
|
||||
auth_satisfied = not client_context.auth_enabled
|
||||
|
||||
return (topology_satisfied and min_version_satisfied and
|
||||
max_version_satisfied and params_satisfied)
|
||||
max_version_satisfied and params_satisfied and auth_satisfied)
|
||||
|
||||
|
||||
def parse_collection_or_database_options(options):
|
||||
|
||||
@ -39,7 +39,7 @@ from pymongo import (MongoClient,
|
||||
from pymongo.collection import ReturnDocument
|
||||
from pymongo.errors import ConfigurationError, OperationFailure
|
||||
from pymongo.monitoring import _SENSITIVE_COMMANDS
|
||||
from pymongo.pool import _CancellationContext
|
||||
from pymongo.pool import _CancellationContext, _PoolGeneration
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.server_selectors import (any_server_selector,
|
||||
@ -259,20 +259,23 @@ class MockSocketInfo(object):
|
||||
|
||||
class MockPool(object):
|
||||
def __init__(self, address, options, handshake=True):
|
||||
self.generation = 0
|
||||
self.gen = _PoolGeneration()
|
||||
self._lock = threading.Lock()
|
||||
self.opts = options
|
||||
self.operation_count = 0
|
||||
|
||||
def stale_generation(self, gen, service_id):
|
||||
return self.gen.stale(gen, service_id)
|
||||
|
||||
def get_socket(self, all_credentials, checkout=False):
|
||||
return MockSocketInfo()
|
||||
|
||||
def return_socket(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _reset(self):
|
||||
def _reset(self, service_id=None):
|
||||
with self._lock:
|
||||
self.generation += 1
|
||||
self.gen.inc(service_id)
|
||||
|
||||
def ready(self):
|
||||
pass
|
||||
|
||||
Loading…
Reference in New Issue
Block a user