diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index 34b407b94..223c39228 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -207,7 +207,7 @@ static PyObject* _test_long_long_to_str(PyObject* self, PyObject* args) { * * Returns a new ref */ static PyObject* _error(char* name) { - PyObject* error; + PyObject* error = NULL; PyObject* errors = PyImport_ImportModule("bson.errors"); if (!errors) { return NULL; @@ -279,7 +279,7 @@ static PyObject* datetime_from_millis(long long millis) { * micros = diff * 1000 111000 * Resulting in datetime(1, 1, 1, 1, 1, 1, 111000) -- the expected result */ - PyObject* datetime; + PyObject* datetime = NULL; int diff = (int)(((millis % 1000) + 1000) % 1000); int microseconds = diff * 1000; Time64_T seconds = (millis - diff) / 1000; @@ -294,7 +294,7 @@ static PyObject* datetime_from_millis(long long millis) { timeinfo.tm_sec, microseconds); if(!datetime) { - PyObject *etype, *evalue, *etrace; + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; /* * Calling _error clears the error state, so fetch it first. @@ -350,8 +350,8 @@ static PyObject* datetime_ms_from_millis(PyObject* self, long long millis){ return NULL; } - PyObject* dt; - PyObject* ll_millis; + PyObject* dt = NULL; + PyObject* ll_millis = NULL; if (!(ll_millis = PyLong_FromLongLong(millis))){ return NULL; @@ -1790,7 +1790,7 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) { PyObject* result; unsigned char check_keys; unsigned char top_level = 1; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer; PyObject* raw_bson_document_bytes_obj; @@ -2512,8 +2512,8 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, * Wrap any non-InvalidBSON errors in InvalidBSON. */ if (PyErr_Occurred()) { - PyObject *etype, *evalue, *etrace; - PyObject *InvalidBSON; + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; + PyObject *InvalidBSON = NULL; /* * Calling _error clears the error state, so fetch it first. @@ -2585,8 +2585,8 @@ static int _element_to_dict(PyObject* self, const char* string, if (!*name) { /* If NULL is returned then wrap the UnicodeDecodeError in an InvalidBSON error */ - PyObject *etype, *evalue, *etrace; - PyObject *InvalidBSON; + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; + PyObject *InvalidBSON = NULL; PyErr_Fetch(&etype, &evalue, &etrace); if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) { @@ -2620,7 +2620,7 @@ static PyObject* _cbson_element_to_dict(PyObject* self, PyObject* args) { /* TODO: Support buffer protocol */ char* string; PyObject* bson; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; unsigned position; unsigned max; @@ -2732,7 +2732,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { int32_t size; Py_ssize_t total_size; const char* string; - PyObject* bson; + PyObject* bson = NULL; codec_options_t options; PyObject* result = NULL; PyObject* options_obj; diff --git a/pymongo/_cmessagemodule.c b/pymongo/_cmessagemodule.c index f95b94938..b5adbeec3 100644 --- a/pymongo/_cmessagemodule.c +++ b/pymongo/_cmessagemodule.c @@ -45,7 +45,7 @@ struct module_state { * * Returns a new ref */ static PyObject* _error(char* name) { - PyObject* error; + PyObject* error = NULL; PyObject* errors = PyImport_ImportModule("pymongo.errors"); if (!errors) { return NULL; @@ -75,9 +75,9 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) { int begin, cur_size, max_size = 0; int num_to_skip; int num_to_return; - PyObject* query; - PyObject* field_selector; - PyObject* options_obj; + PyObject* query = NULL; + PyObject* field_selector = NULL; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer = NULL; int length_location, message_length; @@ -221,12 +221,12 @@ static PyObject* _cbson_op_msg(PyObject* self, PyObject* args) { /* NOTE just using a random number as the request_id */ int request_id = rand(); unsigned int flags; - PyObject* command; + PyObject* command = NULL; char* identifier = NULL; Py_ssize_t identifier_length = 0; - PyObject* docs; - PyObject* doc; - PyObject* options_obj; + PyObject* docs = NULL; + PyObject* doc = NULL; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer = NULL; int length_location, message_length; @@ -535,12 +535,12 @@ static PyObject* _cbson_encode_batched_op_msg(PyObject* self, PyObject* args) { unsigned char op; unsigned char ack; - PyObject* command; - PyObject* docs; + PyObject* command = NULL; + PyObject* docs = NULL; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); @@ -592,12 +592,12 @@ _cbson_batched_op_msg(PyObject* self, PyObject* args) { unsigned char ack; int request_id; int position; - PyObject* command; - PyObject* docs; + PyObject* command = NULL; + PyObject* docs = NULL; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); @@ -868,12 +868,12 @@ _cbson_encode_batched_write_command(PyObject* self, PyObject* args) { char *ns = NULL; unsigned char op; Py_ssize_t ns_len; - PyObject* command; - PyObject* docs; + PyObject* command = NULL; + PyObject* docs = NULL; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; - PyObject* options_obj; + PyObject* options_obj = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 46f66af62..b5fc5d8ac 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -42,6 +42,7 @@ from unittest import SkipTest from bson.son import SON from pymongo import common, message +from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] from pymongo.uri_parser import parse_uri @@ -150,6 +151,16 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs): return authdb.command(cmd) +async def async_repl_set_step_down(client, **kwargs): + """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" + cmd = SON([("replSetStepDown", 1)]) + cmd.update(kwargs) + + # Unfreeze a secondary to ensure a speedy election. + await client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY) + await client.admin.command(cmd) + + class client_knobs: def __init__( self, diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py new file mode 100644 index 000000000..289cf4975 --- /dev/null +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -0,0 +1,148 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test compliance with the connections survive primary step down spec.""" +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.helpers import async_repl_set_step_down +from test.utils import ( + CMAPListener, + async_ensure_all_connected, +) + +from bson import SON +from pymongo import monitoring +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.errors import NotPrimaryError +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest): + listener: CMAPListener + coll: AsyncCollection + + @classmethod + @async_client_context.require_replica_set + async def _setup_class(cls): + await super()._setup_class() + cls.listener = CMAPListener() + cls.client = await cls.unmanaged_async_rs_or_single_client( + event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + ) + + # Ensure connections to all servers in replica set. This is to test + # that the is_writable flag is properly updated for connections that + # survive a replica set election. + await async_ensure_all_connected(cls.client) + cls.listener.reset() + + cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) + cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) + + @classmethod + async def _tearDown_class(cls): + await cls.client.close() + + async def asyncSetUp(self): + # Note that all ops use same write-concern as self.db (majority). + await self.db.drop_collection("step-down") + await self.db.create_collection("step-down") + self.listener.reset() + + async def set_fail_point(self, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await self.client.admin.command(cmd) + + def verify_pool_cleared(self): + self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 1) + + def verify_pool_not_cleared(self): + self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 0) + + @async_client_context.require_version_min(4, 2, -1) + async def test_get_more_iteration(self): + # Insert 5 documents with WC majority. + await self.coll.insert_many([{"data": k} for k in range(5)]) + # Start a find operation and retrieve first batch of results. + batch_size = 2 + cursor = self.coll.find(batch_size=batch_size) + for _ in range(batch_size): + await cursor.next() + # Force step-down the primary. + await async_repl_set_step_down(self.client, replSetStepDown=5, force=True) + # Get await anext batch of results. + for _ in range(batch_size): + await cursor.next() + # Verify pool not cleared. + self.verify_pool_not_cleared() + # Attempt insertion to mark server description as stale and prevent a + # NotPrimaryError on the subsequent operation. + try: + await self.coll.insert_one({}) + except NotPrimaryError: + pass + # Next insert should succeed on the new primary without clearing pool. + await self.coll.insert_one({}) + self.verify_pool_not_cleared() + + async def run_scenario(self, error_code, retry, pool_status_checker): + # Set fail point. + await self.set_fail_point( + {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": error_code}} + ) + self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) + # Insert record and verify failure. + with self.assertRaises(NotPrimaryError) as exc: + await self.coll.insert_one({"test": 1}) + self.assertEqual(exc.exception.details["code"], error_code) # type: ignore[call-overload] + # Retry before CMAPListener assertion if retry_before=True. + if retry: + await self.coll.insert_one({"test": 1}) + # Verify pool cleared/not cleared. + pool_status_checker() + # Always retry here to ensure discovery of new primary. + await self.coll.insert_one({"test": 1}) + + @async_client_context.require_version_min(4, 2, -1) + @async_client_context.require_test_commands + async def test_not_primary_keep_connection_pool(self): + await self.run_scenario(10107, True, self.verify_pool_not_cleared) + + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_version_max(4, 1, 0, -1) + @async_client_context.require_test_commands + async def test_not_primary_reset_connection_pool(self): + await self.run_scenario(10107, False, self.verify_pool_cleared) + + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_test_commands + async def test_shutdown_in_progress(self): + await self.run_scenario(91, False, self.verify_pool_cleared) + + @async_client_context.require_version_min(4, 0, 0) + @async_client_context.require_test_commands + async def test_interrupted_at_shutdown(self): + await self.run_scenario(11600, False, self.verify_pool_cleared) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/helpers.py b/test/helpers.py index bf6186d1a..11d5ab037 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -42,6 +42,7 @@ from unittest import SkipTest from bson.son import SON from pymongo import common, message +from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] from pymongo.uri_parser import parse_uri @@ -150,6 +151,16 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs): return authdb.command(cmd) +def repl_set_step_down(client, **kwargs): + """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" + cmd = SON([("replSetStepDown", 1)]) + cmd.update(kwargs) + + # Unfreeze a secondary to ensure a speedy election. + client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY) + client.admin.command(cmd) + + class client_knobs: def __init__( self, diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index fba767574..54cc4e048 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -20,10 +20,10 @@ import sys sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest +from test.helpers import repl_set_step_down from test.utils import ( CMAPListener, ensure_all_connected, - repl_set_step_down, ) from bson import SON @@ -32,6 +32,8 @@ from pymongo.errors import NotPrimaryError from pymongo.synchronous.collection import Collection from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): listener: CMAPListener @@ -39,8 +41,8 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): @classmethod @client_context.require_replica_set - def setUpClass(cls): - super().setUpClass() + def _setup_class(cls): + super()._setup_class() cls.listener = CMAPListener() cls.client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 @@ -56,7 +58,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.client.close() def setUp(self): diff --git a/test/utils.py b/test/utils.py index 961503489..9c78cff3a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -599,6 +599,44 @@ def ensure_all_connected(client: MongoClient) -> None: ) +async def async_ensure_all_connected(client: AsyncMongoClient) -> None: + """Ensure that the client's connection pool has socket connections to all + members of a replica set. Raises ConfigurationError when called with a + non-replica set client. + + Depending on the use-case, the caller may need to clear any event listeners + that are configured on the client. + """ + hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" not in hello: + raise ConfigurationError("cluster is not a replica set") + + target_host_list = set(hello["hosts"] + hello.get("passives", [])) + connected_host_list = {hello["me"]} + + # Run hello until we have connected to each host at least once. + async def discover(): + i = 0 + while i < 100 and connected_host_list != target_host_list: + hello: dict = await client.admin.command( + HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY + ) + connected_host_list.update([hello["me"]]) + i += 1 + return connected_host_list + + try: + + async def predicate(): + return target_host_list == await discover() + + await async_wait_until(predicate, "connected to all hosts") + except AssertionError as exc: + raise AssertionError( + f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" + ) + + def one(s): """Get one element of a set""" return next(iter(s)) @@ -761,16 +799,6 @@ async def async_wait_until(predicate, success_description, timeout=10): await asyncio.sleep(interval) -def repl_set_step_down(client, **kwargs): - """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" - cmd = SON([("replSetStepDown", 1)]) - cmd.update(kwargs) - - # Unfreeze a secondary to ensure a speedy election. - client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY) - client.admin.command(cmd) - - def is_mongos(client): res = client.admin.command(HelloCompat.LEGACY_CMD) return res.get("msg", "") == "isdbgrid" diff --git a/tools/synchro.py b/tools/synchro.py index 3333b0de2..d8ec9ae46 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -105,6 +105,8 @@ replacements = { "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", "async_set_fail_point": "set_fail_point", + "async_ensure_all_connected": "ensure_all_connected", + "async_repl_set_step_down": "repl_set_step_down", } docstring_replacements: dict[tuple[str, str], str] = { @@ -186,6 +188,7 @@ converted_tests = [ "test_client_bulk_write.py", "test_client_context.py", "test_collection.py", + "test_connections_survive_primary_stepdown_spec.py", "test_cursor.py", "test_database.py", "test_encryption.py",