From a80169d1fa49734290ae73f113edf0c7bb589877 Mon Sep 17 00:00:00 2001 From: Julius Park Date: Fri, 24 Sep 2021 15:37:24 -0700 Subject: [PATCH] PYTHON-2463 Do not allow a MongoClient to be reused after it is closed (#737) --- doc/changelog.rst | 3 + doc/migrate-to-pymongo4.rst | 6 ++ pymongo/mongo_client.py | 9 ++- pymongo/topology.py | 14 ++++- test/test_client.py | 98 +++++++++++++----------------- test/test_load_balancer.py | 4 -- test/test_mongos_load_balancing.py | 15 ----- test/test_pooling.py | 9 --- test/test_raw_bson.py | 2 + test/test_replica_set_reconfig.py | 21 +++---- test/test_server_selection.py | 2 +- test/test_threads.py | 34 ----------- test/unified_format.py | 2 +- 13 files changed, 79 insertions(+), 140 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index d570d8377..b20314012 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -143,6 +143,9 @@ Breaking Changes in 4.0 opposed to the previous syntax which was simply ``if collection:`` or ``if database:``. You must now explicitly compare with None. +- :class:`~pymongo.mongo_client.MongoClient` cannot execute any operations + after being closed. The previous behavior would simply reconnect. However, + now you must create a new instance. - Classes :class:`~bson.int64.Int64`, :class:`~bson.min_key.MinKey`, :class:`~bson.max_key.MaxKey`, :class:`~bson.timestamp.Timestamp`, :class:`~bson.regex.Regex`, and :class:`~bson.dbref.DBRef` all implement diff --git a/doc/migrate-to-pymongo4.rst b/doc/migrate-to-pymongo4.rst index 0630f5816..eea17e8bc 100644 --- a/doc/migrate-to-pymongo4.rst +++ b/doc/migrate-to-pymongo4.rst @@ -180,6 +180,12 @@ can be changed to this:: now defaults to ``False`` instead of ``True``. ``json_util.loads`` now decodes datetime as naive by default. +MongoClient cannot execute operations after ``close()`` +....................................................... + +:class:`~pymongo.mongo_client.MongoClient` cannot execute any operations +after being closed. The previous behavior would simply reconnect. However, +now you must create a new instance. Database -------- diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index db56187b7..f105b1023 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -704,7 +704,6 @@ class MongoClient(common.BaseObject): self.__kill_cursors_queue = [] self._event_listeners = options.pool_options.event_listeners - super(MongoClient, self).__init__(options.codec_options, options.read_preference, options.write_concern, @@ -1127,10 +1126,10 @@ class MongoClient(common.BaseObject): sending one or more endSessions commands. Close all sockets in the connection pools and stop the monitor threads. - If this instance is used again it will be automatically re-opened and - the threads restarted unless auto encryption is enabled. A client - enabled with auto encryption cannot be used again after being closed; - any attempt will raise :exc:`~.errors.InvalidOperation`. + + .. versionchanged:: 4.0 + Once closed, the client cannot be used again and any attempt will + raise :exc:`~pymongo.errors.InvalidOperation`. .. versionchanged:: 3.6 End all server sessions created by this client. diff --git a/pymongo/topology.py b/pymongo/topology.py index 8a2818b5c..9ca3029bf 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -33,7 +33,8 @@ from pymongo.errors import (ConnectionFailure, OperationFailure, PyMongoError, ServerSelectionTimeoutError, - WriteError) + WriteError, + InvalidOperation) from pymongo.hello import Hello from pymongo.monitor import SrvMonitor from pymongo.pool import PoolOptions @@ -112,6 +113,7 @@ class Topology(object): # Store the seed list to help diagnose errors in _error_message(). self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False + self._closed = False self._lock = threading.Lock() self._condition = self._settings.condition_class(self._lock) self._servers = {} @@ -461,7 +463,9 @@ class Topology(object): raise def close(self): - """Clear pools and terminate monitors. Topology reopens on demand.""" + """Clear pools and terminate monitors. Topology does not reopen on + demand. Any further operations will raise + :exc:`~.errors.InvalidOperation`. """ with self._lock: for server in self._servers.values(): server.close() @@ -477,6 +481,7 @@ class Topology(object): self._srv_monitor.close() self._opened = False + self._closed = True # Publish only after releasing the lock. if self._publish_tp: @@ -550,6 +555,11 @@ class Topology(object): Hold the lock when calling this. """ + if self._closed: + raise InvalidOperation("Once a MongoClient is closed, " + "all operations will fail. Please create " + "a new client object if you wish to " + "reconnect.") if not self._opened: self._opened = True self._update_servers() diff --git a/test/test_client.py b/test/test_client.py index cce028670..993c58267 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -50,7 +50,8 @@ from pymongo.errors import (AutoReconnect, NetworkTimeout, OperationFailure, ServerSelectionTimeoutError, - WriteConcernError) + WriteConcernError, + InvalidOperation) from pymongo.hello import HelloCompat from pymongo.mongo_client import MongoClient from pymongo.monitoring import (ServerHeartbeatListener, @@ -772,28 +773,22 @@ class TestClient(IntegrationTest): self.assertNotIn("pymongo_test2", dbs) def test_close(self): - coll = self.client.pymongo_test.bar - - self.client.close() - self.client.close() - - coll.count_documents({}) - - self.client.close() - self.client.close() - - coll.count_documents({}) + test_client = rs_or_single_client() + coll = test_client.pymongo_test.bar + test_client.close() + self.assertRaises(InvalidOperation, coll.count_documents, {}) def test_close_kills_cursors(self): if sys.platform.startswith('java'): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") + test_client = rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() - self.client._process_periodic_tasks() + test_client._process_periodic_tasks() # Add some test data. - coll = self.client.pymongo_test.test_close_kills_cursors + coll = test_client.pymongo_test.test_close_kills_cursors docs_inserted = 1000 coll.insert_many([{"i": i} for i in range(docs_inserted)]) @@ -811,13 +806,13 @@ class TestClient(IntegrationTest): gc.collect() # Close the client and ensure the topology is closed. - self.assertTrue(self.client._topology._opened) - self.client.close() - self.assertFalse(self.client._topology._opened) - + self.assertTrue(test_client._topology._opened) + test_client.close() + self.assertFalse(test_client._topology._opened) + test_client = rs_or_single_client() # The killCursors task should not need to re-open the topology. - self.client._process_periodic_tasks() - self.assertFalse(self.client._topology._opened) + test_client._process_periodic_tasks() + self.assertTrue(test_client._topology._opened) def test_close_stops_kill_cursors_thread(self): client = rs_client() @@ -828,12 +823,9 @@ class TestClient(IntegrationTest): client.close() self.assertTrue(client._kill_cursors_executor._stopped) - # Reusing the closed client should restart the thread. - client.admin.command('ping') - self.assertFalse(client._kill_cursors_executor._stopped) - - # Again, closing the client should stop the thread. - client.close() + # Reusing the closed client should raise an InvalidOperation error. + self.assertRaises(InvalidOperation, client.admin.command, 'ping') + # Thread is still stopped. self.assertTrue(client._kill_cursors_executor._stopped) def test_uri_connect_option(self): @@ -1128,12 +1120,13 @@ class TestClient(IntegrationTest): with contextlib.closing(client): self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) - self.assertEqual(1, len(get_pool(client).sockets)) - self.assertEqual(0, len(get_pool(client).sockets)) - + with self.assertRaises(InvalidOperation): + client.pymongo_test.test.find_one() + client = rs_or_single_client() with client as client: self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) - self.assertEqual(0, len(get_pool(client).sockets)) + with self.assertRaises(InvalidOperation): + client.pymongo_test.test.find_one() def test_interrupt_signal(self): if sys.platform.startswith('java'): @@ -1787,35 +1780,26 @@ class TestClientLazyConnect(IntegrationTest): class TestMongoClientFailover(MockClientTest): def test_discover_primary(self): - # Disable background refresh. - with client_knobs(heartbeat_frequency=999999): - c = MockClient( - standalones=[], - members=['a:1', 'b:2', 'c:3'], - mongoses=[], - host='b:2', # Pass a secondary. - replicaSet='rs') - self.addCleanup(c.close) + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='b:2', # Pass a secondary. + replicaSet='rs', + heartbeatFrequencyMS=500) + self.addCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, 'connect') - self.assertEqual(c.address, ('a', 1)) + wait_until(lambda: len(c.nodes) == 3, 'connect') - # Fail over. - c.kill_host('a:1') - c.mock_primary = 'b:2' - - c.close() - self.assertEqual(0, len(c.nodes)) - - t = c._get_topology() - t.select_servers(writable_server_selector) # Reconnect. - self.assertEqual(c.address, ('b', 2)) - - # a:1 not longer in nodes. - self.assertLess(len(c.nodes), 3) - - # c:3 is rediscovered. - t.select_server_by_address(('c', 3)) + self.assertEqual(c.address, ('a', 1)) + # Fail over. + c.kill_host('a:1') + c.mock_primary = 'b:2' + wait_until(lambda: c.address == ('b', 2), "wait for server " + "address to be " + "updated") + # a:1 not longer in nodes. + self.assertLess(len(c.nodes), 3) def test_reconnect(self): # Verify the node list isn't forgotten during a network failure. diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 62b40be50..da77734e9 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -62,10 +62,6 @@ class TestLB(IntegrationTest): self.assertEqual(pool.active_sockets, 1) # Still pinned. self.assertEqual(pool.active_sockets, 0) # Unpinned. - def test_client_can_be_reopened(self): - self.client.close() - self.db.test.find_one({}) - @client_context.require_failCommand_fail_point def test_cursor_gc(self): def create_resource(coll): diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index e67b1186e..575bc458c 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -92,21 +92,6 @@ class TestMongosLoadBalancing(MockClientTest): do_simple_op(client, nthreads) wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') - def test_reconnect(self): - nthreads = 10 - client = connected(self.mock_client()) - - # connected() ensures we've contacted at least one mongos. Wait for - # all of them. - wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') - - # Trigger reconnect. - client.close() - do_simple_op(client, nthreads) - - wait_until(lambda: len(client.nodes) == 3, - 'reconnect to all mongoses') - def test_failover(self): nthreads = 10 client = connected(self.mock_client(localThresholdMS=0.001)) diff --git a/test/test_pooling.py b/test/test_pooling.py index becbacc1e..266e080ca 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -102,12 +102,6 @@ class NonUnique(MongoThread): raise AssertionError("Should have raised DuplicateKeyError") -class Disconnect(MongoThread): - def run_mongo_thread(self): - for _ in range(N): - self.client.close() - - class SocketGetter(MongoThread): """Utility for TestPooling. @@ -198,9 +192,6 @@ class TestPooling(_TestPoolingBase): def test_no_disconnect(self): run_cases(self.c, [NonUnique, Unique, InsertOneAndFind]) - def test_disconnect(self): - run_cases(self.c, [InsertOneAndFind, Disconnect, Unique]) - def test_pool_reuses_open_socket(self): # Test Pool's _check_closed() method doesn't close a healthy socket. cx_pool = self.create_pool(max_pool_size=10) diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 7fb53c6da..7e1bf6f83 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -25,6 +25,7 @@ from bson.errors import InvalidBSON from bson.raw_bson import RawBSONDocument, DEFAULT_RAW_BSON_OPTIONS from bson.son import SON from test import client_context, unittest +from test.utils import rs_or_single_client from test.test_client import IntegrationTest @@ -43,6 +44,7 @@ class TestRawBSONDocument(IntegrationTest): @classmethod def setUpClass(cls): super(TestRawBSONDocument, cls).setUpClass() + client_context.client = rs_or_single_client() cls.client = client_context.client def tearDown(self): diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index a25d7af16..f19a32ea4 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -18,7 +18,7 @@ import sys sys.path[0:0] = [""] -from pymongo.errors import ConnectionFailure, AutoReconnect +from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError from pymongo import ReadPreference from test import unittest, client_context, client_knobs, MockClientTest from test.pymongo_mocks import MockClient @@ -42,13 +42,10 @@ class TestSecondaryBecomesStandalone(MockClientTest): mongoses=[], host='a:1,b:2,c:3', replicaSet='rs', - serverSelectionTimeoutMS=100) + serverSelectionTimeoutMS=100, + connect=False) self.addCleanup(c.close) - # MongoClient connects to primary by default. - wait_until(lambda: c.address is not None, 'connect to primary') - self.assertEqual(c.address, ('a', 1)) - # C is brought up as a standalone. c.mock_members.remove('c:3') c.mock_standalones.append('c:3') @@ -57,14 +54,15 @@ class TestSecondaryBecomesStandalone(MockClientTest): c.kill_host('a:1') c.kill_host('b:2') - # Force reconnect. - c.close() - - with self.assertRaises(AutoReconnect): + with self.assertRaises(ServerSelectionTimeoutError): c.db.command('ping') - self.assertEqual(c.address, None) + # Client can still discover the primary node + c.revive_host('a:1') + wait_until(lambda: c.address is not None, 'connect to primary') + self.assertEqual(c.address, ('a', 1)) + def test_replica_set_client(self): c = MockClient( standalones=[], @@ -158,7 +156,6 @@ class TestSecondaryAdded(MockClientTest): c.mock_members.append('c:3') c.mock_hello_hosts.append('c:3') - c.close() c.db.command('ping') self.assertEqual(c.address, ('a', 1)) diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 1e4165246..46fce3b13 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -111,8 +111,8 @@ class TestCustomServerSelectorFunction(IntegrationTest): # Client setup. mongo_client = rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection - self.addCleanup(mongo_client.drop_database, 'testdb') self.addCleanup(mongo_client.close) + self.addCleanup(mongo_client.drop_database, 'testdb') # Do N operations and test selector is called at least N times. test_collection.insert_one({'age': 20, 'name': 'John'}) diff --git a/test/test_threads.py b/test/test_threads.py index 854d3cab0..a3cde207a 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -111,21 +111,6 @@ class Update(threading.Thread): assert error -class Disconnect(threading.Thread): - - def __init__(self, client, n): - threading.Thread.__init__(self) - self.client = client - self.n = n - self.passed = False - - def run(self): - for _ in range(self.n): - self.client.close() - - self.passed = True - - class TestThreads(IntegrationTest): def setUp(self): self.db = self.client.pymongo_test @@ -180,25 +165,6 @@ class TestThreads(IntegrationTest): error.join() okay.join() - def test_client_disconnect(self): - db = rs_or_single_client(serverSelectionTimeoutMS=30000).pymongo_test - db.drop_collection("test") - db.test.insert_many([{"x": i} for i in range(1000)]) - - # Start 10 threads that execute a query, and 10 threads that call - # client.close() 10 times in a row. - threads = [SaveAndFind(db.test) for _ in range(10)] - threads.extend(Disconnect(db.client, 10) for _ in range(10)) - - for t in threads: - t.start() - - for t in threads: - t.join(300) - - for t in threads: - self.assertTrue(t.passed) - if __name__ == "__main__": unittest.main() diff --git a/test/unified_format.py b/test/unified_format.py index 69e79e3b7..0a2f0b996 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -975,9 +975,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest): "session %s" % (spec['session'],)) client = single_client('%s:%s' % session._pinned_address) + self.addCleanup(client.close) self.__set_fail_point( client=client, command_args=spec['failPoint']) - self.addCleanup(client.close) def _testOperation_assertSessionTransactionState(self, spec): session = self.entity_map[spec['session']]