PYTHON-2674 Pool.reset only clears connections to the given serviceId (#628)

This commit is contained in:
Shane Harvey 2021-06-15 09:52:30 -07:00 committed by GitHub
parent 9c1ff6ad9d
commit 112ee69de8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 90 additions and 24 deletions

View File

@ -1977,7 +1977,7 @@ class _MongoClientErrorHandler(object):
# "Note that when a network error occurs before the handshake
# completes then the error's generation number is the generation
# of the pool at the time the connection attempt was started."
self.sock_generation = server.pool.generation
self.sock_generation = server.pool.gen.get_overall()
self.completed_handshake = False
self.service_id = None

View File

@ -530,7 +530,8 @@ class SocketInfo(object):
# The pool's generation changes with each reset() so we can close
# sockets created before the last reset.
self.generation = pool.generation
self.pool_gen = pool.gen
self.generation = self.pool_gen.get_overall()
self.ready = False
self.cancel_context = None
if not pool.handshake:
@ -616,6 +617,7 @@ class SocketInfo(object):
'Driver attempted to initialize in load balancing mode'
' but the server does not support this mode')
self.service_id = ismaster.service_id
self.generation = self.pool_gen.get(self.service_id)
return ismaster
def _next_reply(self):
@ -1044,6 +1046,37 @@ class _PoolClosedError(PyMongoError):
pass
class _PoolGeneration(object):
def __init__(self):
# Maps service_id to generation.
self._generations = collections.defaultdict(int)
# Overall pool generation.
self._generation = 0
def get(self, service_id):
"""Get the generation for the given service_id."""
if service_id is None:
return self._generation
return self._generations[service_id]
def get_overall(self):
"""Get the Pool's overall generation."""
return self._generation
def inc(self, service_id):
"""Increment the generation for the given service_id."""
self._generation += 1
if service_id is None:
for service_id in self._generations:
self._generations[service_id] += 1
else:
self._generations[service_id] += 1
def stale(self, gen, service_id):
"""Return if the given generation for a given service_id is stale."""
return gen != self.get(service_id)
class PoolState(object):
PAUSED = 1
READY = 2
@ -1080,7 +1113,8 @@ class Pool:
# Keep track of resets, so we notice sockets created before the most
# recent reset and close them.
self.generation = 0
# self.generation = 0
self.gen = _PoolGeneration()
self.pid = os.getpid()
self.address = address
self.opts = options
@ -1137,13 +1171,25 @@ class Pool:
if (self.opts.pause_enabled and pause and
not self.opts.load_balanced):
old_state, self.state = self.state, PoolState.PAUSED
self.generation += 1
self.gen.inc(service_id)
newpid = os.getpid()
if self.pid != newpid:
self.pid = newpid
self.active_sockets = 0
self.operation_count = 0
sockets, self.sockets = self.sockets, collections.deque()
if service_id is None:
sockets, self.sockets = self.sockets, collections.deque()
else:
discard = collections.deque()
keep = collections.deque()
for sock_info in self.sockets:
if sock_info.service_id == service_id:
discard.append(sock_info)
else:
keep.append(sock_info)
sockets = discard
self.sockets = keep
if close:
self.state = PoolState.CLOSED
# Clear the wait queue
@ -1184,6 +1230,9 @@ class Pool:
def close(self):
self._reset(close=True)
def stale_generation(self, gen, service_id):
return self.gen.stale(gen, service_id)
def remove_stale_sockets(self, reference_generation, all_credentials):
"""Removes stale sockets then adds new ones if pool is too small and
has not been reset. The `reference_generation` argument specifies the
@ -1222,7 +1271,7 @@ class Pool:
with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.generation != reference_generation:
if self.gen.get_overall() != reference_generation:
sock_info.close_socket(ConnectionClosedReason.STALE)
return
self.sockets.appendleft(sock_info)
@ -1450,7 +1499,8 @@ class Pool:
with self.lock:
# Hold the lock to ensure this section does not race with
# Pool.reset().
if sock_info.generation != self.generation:
if self.stale_generation(sock_info.generation,
sock_info.service_id):
sock_info.close_socket(ConnectionClosedReason.STALE)
else:
sock_info.update_last_checkin_time()
@ -1493,7 +1543,7 @@ class Pool:
sock_info.close_socket(ConnectionClosedReason.ERROR)
return True
if sock_info.generation != self.generation:
if self.stale_generation(sock_info.generation, sock_info.service_id):
sock_info.close_socket(ConnectionClosedReason.STALE)
return True

View File

@ -447,7 +447,8 @@ class Topology(object):
# Only update pools for data-bearing servers.
for sd in self.data_bearing_servers():
server = self._servers[sd.address]
servers.append((server, server.pool.generation))
servers.append((server,
server.pool.gen.get_overall()))
for server, generation in servers:
try:
@ -577,7 +578,8 @@ class Topology(object):
# Another thread removed this server from the topology.
return True
if err_ctx.sock_generation != server._pool.generation:
if server._pool.stale_generation(
err_ctx.sock_generation, err_ctx.service_id):
# This is an outdated error from a previous pool version.
return True

View File

@ -1468,7 +1468,7 @@ class TestClient(IntegrationTest):
self.addCleanup(client.close)
client.admin.command('ping')
pool = get_pool(client)
generation = pool.generation
generation = pool.gen.get_overall()
# Continuously reset the pool.
class ResetPoolThread(threading.Thread):
@ -1483,7 +1483,8 @@ class TestClient(IntegrationTest):
def run(self):
while self.running:
exc = AutoReconnect('mock pool error')
ctx = _ErrorContext(exc, 0, pool.generation, False, None)
ctx = _ErrorContext(
exc, 0, pool.gen.get_overall(), False, None)
client._topology.handle_error(pool.address, ctx)
time.sleep(0.001)
@ -1497,7 +1498,7 @@ class TestClient(IntegrationTest):
for _ in range(10):
client._topology.update_pool(
client._MongoClient__all_credentials)
if generation != pool.generation:
if generation != pool.gen.get_overall():
break
finally:
t.stop()

View File

@ -87,7 +87,8 @@ def got_app_error(topology, app_error):
server_address = common.partition_node(app_error['address'])
server = topology.get_server_by_address(server_address)
error_type = app_error['type']
generation = app_error.get('generation', server.pool.generation)
generation = app_error.get(
'generation', server.pool.gen.get_overall())
when = app_error['when']
max_wire_version = app_error['maxWireVersion']
# XXX: We could get better test coverage by mocking the errors on the
@ -181,7 +182,7 @@ def check_outcome(self, topology, outcome):
if expected_pool:
self.assertEqual(
expected_pool.get('generation'),
actual_server.pool.generation)
actual_server.pool.gen.get_overall())
self.assertEqual(outcome['setName'], topology.description.replica_set_name)
self.assertEqual(outcome.get('logicalSessionTimeoutMinutes'),
@ -270,7 +271,7 @@ class TestIgnoreStaleErrors(IntegrationTest):
# Wait for initial discovery.
client.admin.command('ping')
pool = get_pool(client)
starting_generation = pool.generation
starting_generation = pool.gen.get_overall()
wait_until(lambda: len(pool.sockets) == N_THREADS, 'created sockets')
def mock_command(*args, **kwargs):
@ -296,7 +297,8 @@ class TestIgnoreStaleErrors(IntegrationTest):
t.join()
# Expect a single pool reset for the network error
self.assertEqual(starting_generation+1, pool.generation)
self.assertEqual(
starting_generation+1, pool.gen.get_overall())
# Server should be selectable.
client.admin.command('ping')

View File

@ -659,11 +659,11 @@ class TestTopologyErrors(TopologyTest):
self.addCleanup(t.close)
server = wait_for_master(t)
self.assertEqual(1, ismaster_count[0])
generation = server.pool.generation
generation = server.pool.gen.get_overall()
# Pool is reset by ismaster failure.
t.request_check_all()
self.assertNotEqual(generation, server.pool.generation)
self.assertNotEqual(generation, server.pool.gen.get_overall())
def test_ismaster_retry(self):
# ismaster succeeds at first, then raises socket error, then succeeds.

View File

@ -124,8 +124,16 @@ def is_run_on_requirement_satisfied(requirement):
elif client_context.server_parameters[param] != val:
params_satisfied = False
auth_satisfied = True
req_auth = requirement.get('auth')
if req_auth is not None:
if req_auth:
auth_satisfied = client_context.auth_enabled
else:
auth_satisfied = not client_context.auth_enabled
return (topology_satisfied and min_version_satisfied and
max_version_satisfied and params_satisfied)
max_version_satisfied and params_satisfied and auth_satisfied)
def parse_collection_or_database_options(options):

View File

@ -39,7 +39,7 @@ from pymongo import (MongoClient,
from pymongo.collection import ReturnDocument
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.monitoring import _SENSITIVE_COMMANDS
from pymongo.pool import _CancellationContext
from pymongo.pool import _CancellationContext, _PoolGeneration
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.server_selectors import (any_server_selector,
@ -259,20 +259,23 @@ class MockSocketInfo(object):
class MockPool(object):
def __init__(self, address, options, handshake=True):
self.generation = 0
self.gen = _PoolGeneration()
self._lock = threading.Lock()
self.opts = options
self.operation_count = 0
def stale_generation(self, gen, service_id):
return self.gen.stale(gen, service_id)
def get_socket(self, all_credentials, checkout=False):
return MockSocketInfo()
def return_socket(self, *args, **kwargs):
pass
def _reset(self):
def _reset(self, service_id=None):
with self._lock:
self.generation += 1
self.gen.inc(service_id)
def ready(self):
pass