PYTHON-1677 Connections survive primary stepdown

This commit is contained in:
Prashant Mital 2019-07-29 16:45:05 -07:00
parent 57302846b7
commit 611c3f86b3
No known key found for this signature in database
GPG Key ID: 3D2DAA9E483ABE51
13 changed files with 357 additions and 114 deletions

View File

@ -65,6 +65,8 @@ Version 3.9 adds support for MongoDB 4.2. Highlights include:
the buffer protocol.
- Resume tokens can now be accessed from a ``ChangeStream`` cursor using the
:attr:`~pymongo.change_stream.ChangeStream.resume_token` attribute.
- Connections now survive primary step-down. Applications should expect less
socket connection turnover during replica set elections.
.. _URI options specification: https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.rst

View File

@ -29,17 +29,21 @@ from pymongo.errors import (CursorNotFound,
WriteConcernError,
WTimeoutError)
# From the Server Discovery and Monitoring spec, the "not master" error codes
# are combined with the "node is recovering" error codes.
# From the SDAM spec, the "node is shutting down" codes.
_SHUTDOWN_CODES = frozenset([
11600, # InterruptedAtShutdown
91, # ShutdownInProgress
])
# From the SDAM spec, the "not master" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_MASTER_CODES = frozenset([
10107, # NotMaster
13435, # NotMasterNoSlaveOk
11600, # InterruptedAtShutdown
11602, # InterruptedDueToReplStateChange
13436, # NotMasterOrSecondary
189, # PrimarySteppedDown
91, # ShutdownInProgress
])
]) | _SHUTDOWN_CODES
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES = _NOT_MASTER_CODES | frozenset([
7, # HostNotFound

View File

@ -1195,9 +1195,11 @@ class MongoClient(common.BaseObject):
@contextlib.contextmanager
def _get_socket(self, server, session, exhaust=False):
with self._reset_on_error(server.description.address, session):
with server.get_socket(self.__all_credentials,
checkout=exhaust) as sock_info:
with _MongoClientErrorHandler(
self, server.description.address, session) as err_handler:
with server.get_socket(
self.__all_credentials, checkout=exhaust) as sock_info:
err_handler.contribute_socket(sock_info)
yield sock_info
def _select_server(self, server_selector, session, address=None):
@ -1289,8 +1291,10 @@ class MongoClient(common.BaseObject):
server = self._select_server(
operation.read_preference, operation.session, address=address)
with self._reset_on_error(server.description.address,
operation.session):
with _MongoClientErrorHandler(
self, server.description.address,
operation.session) as err_handler:
err_handler.contribute_socket(operation.exhaust_mgr.sock)
return server.run_operation_with_response(
operation.exhaust_mgr.sock,
operation,
@ -1314,49 +1318,6 @@ class MongoClient(common.BaseObject):
retryable=isinstance(operation, message._Query),
exhaust=exhaust)
@contextlib.contextmanager
def _reset_on_error(self, server_address, session):
"""On "not master" or "node is recovering" errors reset the server
according to the SDAM spec.
Unpin the session on transient transaction errors.
"""
try:
try:
yield
except PyMongoError as exc:
if session and exc.has_error_label(
"TransientTransactionError"):
session._unpin_mongos()
raise
except NetworkTimeout:
# The socket has been closed. Don't reset the server.
# Server Discovery And Monitoring Spec: "When an application
# operation fails because of any network error besides a socket
# timeout...."
if session:
session._server_session.mark_dirty()
raise
except NotMasterError:
# "When the client sees a "not master" error it MUST replace the
# server's description with type Unknown. It MUST request an
# immediate check of the server."
self._reset_server_and_request_check(server_address)
raise
except ConnectionFailure:
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."
self.__reset_server(server_address)
if session:
session._server_session.mark_dirty()
raise
except OperationFailure as exc:
if exc.code in helpers._RETRYABLE_ERROR_CODES:
# Do not request an immediate check since the server is likely
# shutting down.
self.__reset_server(server_address)
raise
def _retry_with_session(self, retryable, func, session, bulk):
"""Execute an operation with at most one consecutive retries
@ -1494,7 +1455,7 @@ class MongoClient(common.BaseObject):
with self._tmp_session(session) as s:
return self._retry_with_session(retryable, func, s, None)
def __reset_server(self, address):
def _reset_server(self, address):
"""Clear our connection pool for a server and mark it Unknown."""
self._topology.reset_server(address)
@ -2158,3 +2119,69 @@ class MongoClient(common.BaseObject):
raise TypeError("'MongoClient' object is not iterable")
next = __next__
class _MongoClientErrorHandler(object):
"""Error handler for MongoClient."""
__slots__ = ('_client', '_server_address', '_session', '_max_wire_version')
def __init__(self, client, server_address, session):
self._client = client
self._server_address = server_address
self._session = session
self._max_wire_version = None
def contribute_socket(self, sock_info):
"""Provide socket information to the error handler."""
# Currently, we only extract the max_wire_version information.
self._max_wire_version = sock_info.max_wire_version
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
return
if issubclass(exc_type, PyMongoError):
if self._session and exc_val.has_error_label(
"TransientTransactionError"):
self._session._unpin_mongos()
if issubclass(exc_type, NetworkTimeout):
# The socket has been closed. Don't reset the server.
# Server Discovery And Monitoring Spec: "When an application
# operation fails because of any network error besides a socket
# timeout...."
if self._session:
self._session._server_session.mark_dirty()
elif issubclass(exc_type, NotMasterError):
# As per the SDAM spec if:
# - the server sees a "not master" error, and
# - the server is not shutting down, and
# - the server version is >= 4.2, then
# we keep the existing connection pool, but mark the server type
# as Unknown and request an immediate check of the server.
# Otherwise, we clear the connection pool, mark the server as
# Unknown and request an immediate check of the server.
err_code = exc_val.details.get('code', -1)
is_shutting_down = err_code in helpers._SHUTDOWN_CODES
if (is_shutting_down or (self._max_wire_version is None) or
(self._max_wire_version <= 7)):
# Clear the pool, mark server Unknown and request check.
self._client._reset_server_and_request_check(
self._server_address)
else:
self._client._topology.mark_server_unknown_and_request_check(
self._server_address)
elif issubclass(exc_type, ConnectionFailure):
# "Client MUST replace the server's description with type Unknown
# ... MUST NOT request an immediate check of the server."
self._client._reset_server(self._server_address)
if self._session:
self._session._server_session.mark_dirty()
elif issubclass(exc_type, OperationFailure):
# Do not request an immediate check since the server is likely
# shutting down.
if exc_val.code in helpers._RETRYABLE_ERROR_CODES:
self._client._reset_server(self._server_address)

View File

@ -651,7 +651,8 @@ class SocketInfo(object):
"""
if unacknowledged and not self.is_writable:
# Write won't succeed, bail as if we'd received a not master error.
raise NotMasterError("not master")
raise NotMasterError("not master", {
"ok": 0, "errmsg": "not master", "code": 10107})
def legacy_write(self, request_id, msg, max_doc_size, with_last_error):
"""Send OP_INSERT, etc., optionally returning response as a dict.
@ -768,6 +769,9 @@ class SocketInfo(object):
def update_last_checkin_time(self):
self.last_checkin_time = _time()
def update_is_writable(self, is_writable):
# Overwrite this socket's is_writable flag. Called by
# Pool.update_is_writable and when the socket is returned to its pool.
self.is_writable = is_writable
def idle_time_seconds(self):
"""Seconds since this socket was last checked into its pool."""
return _time() - self.last_checkin_time
@ -958,6 +962,8 @@ class Pool:
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
self.closed = False
# Track whether the sockets in this pool are writable or not.
self.is_writable = None
# Keep track of resets, so we notice sockets created before the most
# recent reset and close them.
@ -1012,6 +1018,15 @@ class Pool:
for sock_info in sockets:
sock_info.close_socket(ConnectionClosedReason.STALE)
def update_is_writable(self, is_writable):
    """Store the new is_writable value and push it to the pooled sockets.

    The value is recorded on the pool first. Then, while holding the pool
    lock, every socket currently in ``self.sockets`` is updated to match.
    """
    self.is_writable = is_writable
    with self.lock:
        for sock_info in self.sockets:
            sock_info.update_is_writable(self.is_writable)
def reset(self):
self._reset(close=False)
@ -1075,6 +1090,7 @@ class Pool:
sock_info = SocketInfo(sock, self, self.address, conn_id)
if self.handshake:
sock_info.ismaster(self.opts.metadata, None)
self.is_writable = sock_info.is_writable
return sock_info
@ -1194,6 +1210,7 @@ class Pool:
sock_info.close_socket(ConnectionClosedReason.STALE)
elif not sock_info.closed:
sock_info.update_last_checkin_time()
sock_info.update_is_writable(self.is_writable)
with self.lock:
self.sockets.appendleft(sock_info)

View File

@ -409,12 +409,18 @@ class Topology(object):
Do *not* request an immediate check.
"""
with self._lock:
self._reset_server(address)
self._reset_server(address, reset_pool=True)
def reset_server_and_request_check(self, address):
"""Clear our pool for a server, mark it Unknown, and check it soon."""
with self._lock:
self._reset_server(address)
self._reset_server(address, reset_pool=True)
self._request_check(address)
def mark_server_unknown_and_request_check(self, address):
"""Mark a server Unknown, and check it soon.

Unlike reset_server_and_request_check, this passes reset_pool=False, so
the server's connection pool is not reset.
"""
# Both steps run under the topology lock.
with self._lock:
self._reset_server(address, reset_pool=False)
self._request_check(address)
def update_pool(self):
@ -523,8 +529,8 @@ class Topology(object):
for server in itervalues(self._servers):
server.open()
def _reset_server(self, address):
"""Clear our pool for a server and mark it Unknown.
def _reset_server(self, address, reset_pool):
"""Mark a server Unknown and optionally reset its pool.
Hold the lock when calling this. Does *not* request an immediate check.
"""
@ -532,7 +538,8 @@ class Topology(object):
# "server" is None if another thread removed it from the topology.
if server:
server.reset()
if reset_pool:
server.reset()
# Mark this server Unknown.
self._description = self._description.reset_server(address)
@ -578,7 +585,14 @@ class Topology(object):
self._servers[address] = server
server.open()
else:
# Cache old is_writable value.
was_writable = self._servers[address].description.is_writable
# Update server description.
self._servers[address].description = sd
# Update is_writable value of the pool, if it changed.
if was_writable != sd.is_writable:
self._servers[address].pool.update_is_writable(
sd.is_writable)
for address, server in list(self._servers.items()):
if not self._description.has_server(address):

View File

@ -159,7 +159,6 @@ class ClientContext(object):
"""Create a client and grab essential information from the server."""
self.connection_attempts = []
self.connected = False
self.ismaster = {}
self.w = None
self.nodes = set()
self.replica_set_name = None
@ -184,6 +183,10 @@ class ClientContext(object):
if COMPRESSORS:
self.default_client_options["compressors"] = COMPRESSORS
@property
def ismaster(self):
# Runs the isMaster command on self.client at every access; the reply
# is not cached on the context.
return self.client.admin.command('isMaster')
def _connect(self, host, port, **kwargs):
# Jython takes a long time to connect.
if sys.platform.startswith('java'):
@ -253,7 +256,7 @@ class ClientContext(object):
self.cmd_line = self.client.admin.command('getCmdLineOpts')
self.server_status = self.client.admin.command('serverStatus')
self.ismaster = ismaster = self.client.admin.command('isMaster')
ismaster = self.ismaster
self.sessions_enabled = 'logicalSessionTimeoutMinutes' in ismaster
if 'setName' in ismaster:
@ -276,18 +279,17 @@ class ClientContext(object):
**self.default_client_options)
# Get the authoritative ismaster result from the primary.
self.ismaster = self.client.admin.command('ismaster')
ismaster = self.ismaster
nodes = [partition_node(node.lower())
for node in self.ismaster.get('hosts', [])]
for node in ismaster.get('hosts', [])]
nodes.extend([partition_node(node.lower())
for node in self.ismaster.get('passives', [])])
for node in ismaster.get('passives', [])])
nodes.extend([partition_node(node.lower())
for node in self.ismaster.get('arbiters', [])])
for node in ismaster.get('arbiters', [])])
self.nodes = set(nodes)
else:
self.ismaster = ismaster
self.nodes = set([(host, port)])
self.w = len(self.ismaster.get("hosts", [])) or 1
self.w = len(ismaster.get("hosts", [])) or 1
self.version = Version.from_client(self.client)
if 'enableTestCommands=1' in self.cmd_line['argv']:

View File

@ -24,8 +24,7 @@ sys.path[0:0] = [""]
from pymongo.errors import (ConnectionFailure,
PyMongoError)
from pymongo.monitoring import (ConnectionPoolListener,
ConnectionCheckedInEvent,
from pymongo.monitoring import (ConnectionCheckedInEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
ConnectionCheckOutFailedReason,
@ -44,6 +43,7 @@ from test import (IntegrationTest,
unittest)
from test.utils import (camel_to_snake,
client_context,
CMAPListener,
get_pool,
get_pools,
rs_or_single_client,
@ -70,48 +70,6 @@ OBJECT_TYPES = {
}
class CMAPListener(ConnectionPoolListener):
def __init__(self):
self.events = []
def add_event(self, event):
self.events.append(event)
def event_count(self, event_type):
return len([event for event in self.events[:]
if isinstance(event, event_type)])
def connection_created(self, event):
self.add_event(event)
def connection_ready(self, event):
self.add_event(event)
def connection_closed(self, event):
self.add_event(event)
def connection_check_out_started(self, event):
self.add_event(event)
def connection_check_out_failed(self, event):
self.add_event(event)
def connection_checked_out(self, event):
self.add_event(event)
def connection_checked_in(self, event):
self.add_event(event)
def pool_created(self, event):
self.add_event(event)
def pool_cleared(self, event):
self.add_event(event)
def pool_closed(self, event):
self.add_event(event)
class CMAPThread(threading.Thread):
def __init__(self, name):
super(CMAPThread, self).__init__()

View File

@ -0,0 +1,138 @@
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test compliance with the connections survive primary step down spec."""
import sys
sys.path[0:0] = [""]
from bson import SON
from pymongo import monitoring
from pymongo.errors import NotMasterError
from pymongo.write_concern import WriteConcern
from test import (client_context,
unittest,
IntegrationTest)
from test.utils import (CMAPListener,
ensure_all_connected,
rs_or_single_client)
class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
# Integration tests that count PoolClearedEvent occurrences, recorded by a
# CMAPListener, after a primary step-down or an injected "not master" /
# shutdown error code.
@classmethod
@client_context.require_replica_set
def setUpClass(cls):
# Build one shared client with a CMAP listener attached and retryable
# writes disabled.
super(TestConnectionsSurvivePrimaryStepDown, cls).setUpClass()
cls.listener = CMAPListener()
cls.client = rs_or_single_client(event_listeners=[cls.listener],
retryWrites=False)
# Ensure connections to all servers in replica set. This is to test
# that the is_writable flag is properly updated for sockets that
# survive a replica set election.
ensure_all_connected(cls.client)
# Drop the events generated while connecting.
cls.listener.reset()
cls.db = cls.client.get_database(
"step-down", write_concern=WriteConcern("majority"))
cls.coll = cls.db.get_collection(
"step-down", write_concern=WriteConcern("majority"))
def setUp(self):
# Recreate the collection and clear recorded events before each test.
# Note that all ops use same write-concern as self.db (majority).
self.db.drop_collection("step-down")
self.db.create_collection("step-down")
self.listener.reset()
def set_fail_point(self, command_args):
# Send configureFailPoint for "failCommand" with the given extra
# fields (e.g. mode/data) merged into the command document.
cmd = SON([("configureFailPoint", "failCommand")])
cmd.update(command_args)
self.client.admin.command(cmd)
def verify_pool_cleared(self):
# Exactly one PoolClearedEvent must have been recorded.
self.assertEqual(
self.listener.event_count(monitoring.PoolClearedEvent), 1)
def verify_pool_not_cleared(self):
# No PoolClearedEvent may have been recorded.
self.assertEqual(
self.listener.event_count(monitoring.PoolClearedEvent), 0)
@client_context.require_version_min(4, 2, -1)
def test_get_more_iteration(self):
# Insert 5 documents with WC majority.
self.coll.insert_many([{'data': k} for k in range(5)])
# Start a find operation and retrieve first batch of results.
batch_size = 2
cursor = self.coll.find(batch_size=batch_size)
for _ in range(batch_size):
cursor.next()
# Force step-down the primary.
res = self.client.admin.command(
SON([("replSetStepDown", 5), ("force", True)]))
self.assertEqual(res["ok"], 1.0)
# Get next batch of results.
for _ in range(batch_size):
cursor.next()
# Verify pool not cleared.
self.verify_pool_not_cleared()
# Attempt insertion to mark server description as stale and prevent a
# notMaster error on the subsequent operation.
try:
self.coll.insert_one({})
except NotMasterError:
pass
# Next insert should succeed on the new primary without clearing pool.
self.coll.insert_one({})
self.verify_pool_not_cleared()
def run_scenario(self, error_code, retry, pool_status_checker):
# Shared driver: make the next insert fail with error_code, optionally
# retry, then call pool_status_checker to assert on pool-cleared events.
# Set fail point.
self.set_fail_point({"mode": {"times": 1},
"data": {"failCommands": ["insert"],
"errorCode": error_code}})
self.addCleanup(self.set_fail_point, {"mode": "off"})
# Insert record and verify failure.
with self.assertRaises(NotMasterError) as exc:
self.coll.insert_one({"test": 1})
self.assertEqual(exc.exception.details['code'], error_code)
# Retry before the CMAPListener assertion if retry is true.
if retry:
self.coll.insert_one({"test": 1})
# Verify pool cleared/not cleared.
pool_status_checker()
# Always retry here to ensure discovery of new primary.
self.coll.insert_one({"test": 1})
@client_context.require_version_min(4, 2, -1)
def test_not_master_keep_connection_pool(self):
# 10107 is NotMaster; on 4.2+ the pool is expected to be kept.
self.run_scenario(10107, True, self.verify_pool_not_cleared)
@client_context.require_version_min(4, 0, 0)
@client_context.require_version_max(4, 1, 0, -1)
def test_not_master_reset_connection_pool(self):
# 10107 is NotMaster; on 4.0 the pool is expected to be cleared.
self.run_scenario(10107, False, self.verify_pool_cleared)
@client_context.require_version_min(4, 0, 0)
def test_shutdown_in_progress(self):
# 91 is ShutdownInProgress; the pool is expected to be cleared.
self.run_scenario(91, False, self.verify_pool_cleared)
@client_context.require_version_min(4, 0, 0)
def test_interrupted_at_shutdown(self):
# 11600 is InterruptedAtShutdown; the pool is expected to be cleared.
self.run_scenario(11600, False, self.verify_pool_cleared)
if __name__ == "__main__":
unittest.main()

View File

@ -63,6 +63,9 @@ class MockPool(object):
def close(self):
self._reset()
def update_is_writable(self, is_writable):
# No-op stub so code that calls pool.update_is_writable can run against
# this mock pool.
pass
class MockMonitor(object):
def __init__(self, server_description, topology, pool, topology_settings):

View File

@ -60,6 +60,9 @@ class MockPool(object):
def close(self):
self._reset()
def update_is_writable(self, is_writable):
# No-op stub so code that calls pool.update_is_writable can run against
# this mock pool.
pass
def remove_stale_sockets(self):
pass

View File

@ -73,6 +73,9 @@ class MockPool(object):
def close(self):
self._reset()
def update_is_writable(self, is_writable):
# No-op stub so code that calls pool.update_is_writable can run against
# this mock pool.
pass
def remove_stale_sockets(self):
pass

View File

@ -33,9 +33,10 @@ from bson.objectid import ObjectId
from pymongo import (MongoClient,
monitoring)
from pymongo.errors import OperationFailure
from pymongo.monitoring import _SENSITIVE_COMMANDS
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.monitoring import _SENSITIVE_COMMANDS, ConnectionPoolListener
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.server_selectors import (any_server_selector,
writable_server_selector)
from pymongo.write_concern import WriteConcern
@ -68,6 +69,51 @@ class WhiteListEventListener(monitoring.CommandListener):
self.results['failed'].append(event)
class CMAPListener(ConnectionPoolListener):
    """Connection pool listener that records every event it is handed."""

    def __init__(self):
        self.events = []

    def reset(self):
        """Forget all events recorded so far."""
        self.events = []

    def add_event(self, event):
        """Append *event* to the list of recorded events."""
        self.events.append(event)

    def event_count(self, event_type):
        """Return how many recorded events are instances of *event_type*."""
        # Count over a copy of the list, as the original implementation did.
        snapshot = self.events[:]
        return sum(1 for recorded in snapshot
                   if isinstance(recorded, event_type))

    # Every listener callback below simply records the event it receives.

    def connection_created(self, event):
        self.add_event(event)

    def connection_ready(self, event):
        self.add_event(event)

    def connection_closed(self, event):
        self.add_event(event)

    def connection_check_out_started(self, event):
        self.add_event(event)

    def connection_check_out_failed(self, event):
        self.add_event(event)

    def connection_checked_out(self, event):
        self.add_event(event)

    def connection_checked_in(self, event):
        self.add_event(event)

    def pool_created(self, event):
        self.add_event(event)

    def pool_cleared(self, event):
        self.add_event(event)

    def pool_closed(self, event):
        self.add_event(event)
class EventListener(monitoring.CommandListener):
def __init__(self):
@ -359,6 +405,29 @@ def rs_or_single_client(h=None, p=None, **kwargs):
return _mongo_client(h, p, **kwargs)
def ensure_all_connected(client):
    """Ensure that the client's connection pool has socket connections to all
    members of a replica set. Raises ConfigurationError when called with a
    non-replica set client.

    Depending on the use-case, the caller may need to clear any event listeners
    that are configured on the client.
    """
    ismaster = client.admin.command("isMaster")
    if 'setName' not in ismaster:
        raise ConfigurationError("cluster is not a replica set")

    target_host_list = set(ismaster['hosts'])
    connected_host_list = set([ismaster['me']])
    admindb = client.get_database('admin')

    # Run isMaster until we have connected to each host at least once.
    # Use a subset test rather than equality: a secondary read may be served
    # by a member whose "me" is not listed in "hosts" (presumably e.g. a
    # priority-0 member reported under "passives"). With an equality test,
    # such an extra entry would keep this loop from ever terminating.
    while not target_host_list.issubset(connected_host_list):
        ismaster = admindb.command("isMaster",
                                   read_preference=ReadPreference.SECONDARY)
        connected_host_list.add(ismaster["me"])
def one(s):
"""Get one element of a set"""
return next(iter(s))

View File

@ -53,6 +53,9 @@ class MockPool(object):
def close(self):
pass
def update_is_writable(self, is_writable):
# No-op stub so code that calls pool.update_is_writable can run against
# this mock pool.
pass
def remove_stale_sockets(self):
pass