PYTHON-2744 Run LB tests against non-LB clusters (#638)
Fix serviceId fallback to make spec test pass. Fix socket leak when SocketInfo connection handshake fails.
This commit is contained in:
parent
c8f32a7a37
commit
bf78a9b2ef
@ -226,6 +226,9 @@ else:
|
||||
# main thread, to avoid the deadlock. See PYTHON-607.
|
||||
'foo'.encode('idna')
|
||||
|
||||
# Remove after PYTHON-2712
|
||||
_MOCK_SERVICE_ID = False
|
||||
|
||||
|
||||
def _raise_connection_failure(address, error, msg_prefix=None):
|
||||
"""Convert a socket.error to ConnectionFailure and raise it."""
|
||||
@ -600,7 +603,7 @@ class SocketInfo(object):
|
||||
doc = self.command('admin', cmd, publish_events=False,
|
||||
exhaust_allowed=awaitable)
|
||||
# PYTHON-2712 will remove this topologyVersion fallback logic.
|
||||
if self.opts.load_balanced:
|
||||
if self.opts.load_balanced and _MOCK_SERVICE_ID:
|
||||
process_id = doc.get('topologyVersion', {}).get('processId')
|
||||
doc.setdefault('serviceId', process_id)
|
||||
ismaster = IsMaster(doc, awaitable=awaitable)
|
||||
@ -627,7 +630,7 @@ class SocketInfo(object):
|
||||
if self.opts.load_balanced:
|
||||
if not ismaster.service_id:
|
||||
raise ConfigurationError(
|
||||
'Driver attempted to initialize in load balancing mode'
|
||||
'Driver attempted to initialize in load balancing mode,'
|
||||
' but the server does not support this mode')
|
||||
self.service_id = ismaster.service_id
|
||||
self.generation = self.pool_gen.get(self.service_id)
|
||||
@ -1338,11 +1341,11 @@ class Pool:
|
||||
raise
|
||||
|
||||
sock_info = SocketInfo(sock, self, self.address, conn_id)
|
||||
if self.handshake:
|
||||
sock_info.ismaster(all_credentials)
|
||||
self.is_writable = sock_info.is_writable
|
||||
|
||||
try:
|
||||
if self.handshake:
|
||||
sock_info.ismaster(all_credentials)
|
||||
self.is_writable = sock_info.is_writable
|
||||
|
||||
sock_info.check_auth(all_credentials)
|
||||
except BaseException:
|
||||
sock_info.close_socket(ConnectionClosedReason.ERROR)
|
||||
|
||||
@ -97,6 +97,9 @@ TEST_LOADBALANCER = bool(os.environ.get("TEST_LOADBALANCER"))
|
||||
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
|
||||
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
|
||||
if TEST_LOADBALANCER:
|
||||
# Remove after PYTHON-2712
|
||||
from pymongo import pool
|
||||
pool._MOCK_SERVICE_ID = True
|
||||
res = parse_uri(SINGLE_MONGOS_LB_URI)
|
||||
host, port = res['nodelist'][0]
|
||||
db_user = res['username'] or db_user
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_crud_unified import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_dns import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -14,38 +14,12 @@
|
||||
|
||||
"""Test the Load Balancer unified spec tests."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest, IntegrationTest, client_context
|
||||
from test.utils import get_pool
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), 'unified')
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
|
||||
class TestLB(IntegrationTest):
|
||||
@client_context.require_load_balancer
|
||||
def test_unpin_committed_transaction(self):
|
||||
pool = get_pool(self.client)
|
||||
with self.client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
self.assertEqual(pool.active_sockets, 0)
|
||||
self.db.test.insert_one({}, session=session)
|
||||
self.assertEqual(pool.active_sockets, 1) # Pinned.
|
||||
self.assertEqual(pool.active_sockets, 1) # Still pinned.
|
||||
self.assertEqual(pool.active_sockets, 0) # Unpinned.
|
||||
|
||||
def test_client_can_be_reopened(self):
|
||||
self.client.close()
|
||||
self.db.test.find_one({})
|
||||
from test import unittest
|
||||
from test.test_load_balancer import *
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_change_stream import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_retryable_reads import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_retryable_writes import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_transactions_unified import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_uri_spec import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.test_versioned_api import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
52
test/test_load_balancer.py
Normal file
52
test/test_load_balancer.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright 2021-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the Load Balancer unified spec tests."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest, IntegrationTest, client_context
|
||||
from test.utils import get_pool
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), 'load_balancer', 'unified')
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__))
|
||||
|
||||
|
||||
class TestLB(IntegrationTest):
|
||||
@client_context.require_load_balancer
|
||||
def test_unpin_committed_transaction(self):
|
||||
pool = get_pool(self.client)
|
||||
with self.client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
self.assertEqual(pool.active_sockets, 0)
|
||||
self.db.test.insert_one({}, session=session)
|
||||
self.assertEqual(pool.active_sockets, 1) # Pinned.
|
||||
self.assertEqual(pool.active_sockets, 1) # Still pinned.
|
||||
self.assertEqual(pool.active_sockets, 0) # Unpinned.
|
||||
|
||||
def test_client_can_be_reopened(self):
|
||||
self.client.close()
|
||||
self.db.test.find_one({})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -891,6 +891,10 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
if expect_error:
|
||||
return self.process_error(exc, expect_error)
|
||||
raise
|
||||
else:
|
||||
if expect_error:
|
||||
self.fail('Excepted error %s but "%s" succeeded: %s' % (
|
||||
expect_error, opname, result))
|
||||
|
||||
if expect_result:
|
||||
actual = coerce_result(opname, result)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user