diff --git a/pymongo/pool.py b/pymongo/pool.py index 5d62e4bee..15e5b4873 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -226,6 +226,9 @@ else: # main thread, to avoid the deadlock. See PYTHON-607. 'foo'.encode('idna') +# Remove after PYTHON-2712 +_MOCK_SERVICE_ID = False + def _raise_connection_failure(address, error, msg_prefix=None): """Convert a socket.error to ConnectionFailure and raise it.""" @@ -600,7 +603,7 @@ class SocketInfo(object): doc = self.command('admin', cmd, publish_events=False, exhaust_allowed=awaitable) # PYTHON-2712 will remove this topologyVersion fallback logic. - if self.opts.load_balanced: + if self.opts.load_balanced and _MOCK_SERVICE_ID: process_id = doc.get('topologyVersion', {}).get('processId') doc.setdefault('serviceId', process_id) ismaster = IsMaster(doc, awaitable=awaitable) @@ -627,7 +630,7 @@ class SocketInfo(object): if self.opts.load_balanced: if not ismaster.service_id: raise ConfigurationError( - 'Driver attempted to initialize in load balancing mode' + 'Driver attempted to initialize in load balancing mode,' ' but the server does not support this mode') self.service_id = ismaster.service_id self.generation = self.pool_gen.get(self.service_id) @@ -1338,11 +1341,11 @@ class Pool: raise sock_info = SocketInfo(sock, self, self.address, conn_id) - if self.handshake: - sock_info.ismaster(all_credentials) - self.is_writable = sock_info.is_writable - try: + if self.handshake: + sock_info.ismaster(all_credentials) + self.is_writable = sock_info.is_writable + sock_info.check_auth(all_credentials) except BaseException: sock_info.close_socket(ConnectionClosedReason.ERROR) diff --git a/test/__init__.py b/test/__init__.py index 9e76b28f5..1e76e58b8 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -97,6 +97,9 @@ TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER")) SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI") MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI") if TEST_LOADBALANCER: + # Remove after PYTHON-2712 + from pymongo import pool + pool._MOCK_SERVICE_ID = True res = parse_uri(SINGLE_MONGOS_LB_URI) host, port = res['nodelist'][0] db_user = res['username'] or db_user diff --git a/test/load_balancer/test_crud_unified.py b/test/load_balancer/test_crud_unified.py index dfe0935bb..4363f293f 100644 --- a/test/load_balancer/test_crud_unified.py +++ b/test/load_balancer/test_crud_unified.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_crud_unified import * if __name__ == '__main__': diff --git a/test/load_balancer/test_dns.py b/test/load_balancer/test_dns.py index 047b98b12..34e2329c8 100644 --- a/test/load_balancer/test_dns.py +++ b/test/load_balancer/test_dns.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_dns import * if __name__ == '__main__': diff --git a/test/load_balancer/test_load_balancer.py b/test/load_balancer/test_load_balancer.py index 99f8855ca..77e824e59 100644 --- a/test/load_balancer/test_load_balancer.py +++ b/test/load_balancer/test_load_balancer.py @@ -14,38 +14,12 @@ """Test the Load Balancer unified spec tests.""" -import os import sys sys.path[0:0] = [""] -from test import unittest, IntegrationTest, client_context -from test.utils import get_pool -from test.unified_format import generate_test_classes - -# Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'unified') - -# Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) - - -class TestLB(IntegrationTest): - @client_context.require_load_balancer - def test_unpin_committed_transaction(self): - pool = get_pool(self.client) - with self.client.start_session() as session: - with session.start_transaction(): - self.assertEqual(pool.active_sockets, 0) - self.db.test.insert_one({}, session=session) - self.assertEqual(pool.active_sockets, 1) # Pinned. - self.assertEqual(pool.active_sockets, 1) # Still pinned. - self.assertEqual(pool.active_sockets, 0) # Unpinned. - - def test_client_can_be_reopened(self): - self.client.close() - self.db.test.find_one({}) +from test import unittest +from test.test_load_balancer import * if __name__ == "__main__": diff --git a/test/load_balancer/test_retryable_change_stream.py b/test/load_balancer/test_retryable_change_stream.py index b7c902dd3..f08e27e9d 100644 --- a/test/load_balancer/test_retryable_change_stream.py +++ b/test/load_balancer/test_retryable_change_stream.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_change_stream import * if __name__ == '__main__': diff --git a/test/load_balancer/test_retryable_reads.py b/test/load_balancer/test_retryable_reads.py index c5de3c907..73510fab7 100644 --- a/test/load_balancer/test_retryable_reads.py +++ b/test/load_balancer/test_retryable_reads.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_retryable_reads import * if __name__ == '__main__': diff --git a/test/load_balancer/test_retryable_writes.py b/test/load_balancer/test_retryable_writes.py index 3800641b0..c920acb81 100644 --- a/test/load_balancer/test_retryable_writes.py +++ b/test/load_balancer/test_retryable_writes.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_retryable_writes import * if __name__ == '__main__': diff --git a/test/load_balancer/test_transactions_unified.py b/test/load_balancer/test_transactions_unified.py index 257202804..d2f7eac94 100644 --- a/test/load_balancer/test_transactions_unified.py +++ b/test/load_balancer/test_transactions_unified.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_transactions_unified import * if __name__ == '__main__': diff --git a/test/load_balancer/test_uri_options.py b/test/load_balancer/test_uri_options.py index b644d7d33..c7151d330 100644 --- a/test/load_balancer/test_uri_options.py +++ b/test/load_balancer/test_uri_options.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_uri_spec import * if __name__ == '__main__': diff --git a/test/load_balancer/test_versioned_api.py b/test/load_balancer/test_versioned_api.py index 7e801968c..2b188a6b1 100644 --- a/test/load_balancer/test_versioned_api.py +++ b/test/load_balancer/test_versioned_api.py @@ -13,10 +13,10 @@ # limitations under the License. import sys -import unittest sys.path[0:0] = [""] +from test import unittest from test.test_versioned_api import * if __name__ == '__main__': diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py new file mode 100644 index 000000000..90bde87e5 --- /dev/null +++ b/test/test_load_balancer.py @@ -0,0 +1,52 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Load Balancer unified spec tests.""" + +import os +import sys + +sys.path[0:0] = [""] + +from test import unittest, IntegrationTest, client_context +from test.utils import get_pool +from test.unified_format import generate_test_classes + +# Location of JSON test specifications. +TEST_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'load_balancer', 'unified') + +# Generate unified tests. +globals().update(generate_test_classes(TEST_PATH, module=__name__)) + + +class TestLB(IntegrationTest): + @client_context.require_load_balancer + def test_unpin_committed_transaction(self): + pool = get_pool(self.client) + with self.client.start_session() as session: + with session.start_transaction(): + self.assertEqual(pool.active_sockets, 0) + self.db.test.insert_one({}, session=session) + self.assertEqual(pool.active_sockets, 1) # Pinned. + self.assertEqual(pool.active_sockets, 1) # Still pinned. + self.assertEqual(pool.active_sockets, 0) # Unpinned. + + def test_client_can_be_reopened(self): + self.client.close() + self.db.test.find_one({}) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unified_format.py b/test/unified_format.py index d2433e5b4..284fee1ed 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -891,6 +891,10 @@ class UnifiedSpecTestMixinV1(IntegrationTest): if expect_error: return self.process_error(exc, expect_error) raise + else: + if expect_error: + self.fail('Excepted error %s but "%s" succeeded: %s' % ( + expect_error, opname, result)) if expect_result: actual = coerce_result(opname, result)