PYTHON-1563 - Avoid import deadlocks in the test suite

This commit is contained in:
Bernie Hackett 2018-06-02 09:49:46 -07:00
parent c63c068611
commit cb85eb02a2
5 changed files with 80 additions and 92 deletions

View File

@ -18,6 +18,7 @@
import os
import socket
import sys
import threading
import time
import warnings
@ -88,22 +89,6 @@ def is_server_resolvable():
socket.setdefaulttimeout(socket_timeout)
def _connect(host, port, **kwargs):
client = pymongo.MongoClient(host, port, **kwargs)
start = time.time()
# Jython takes a long time to connect.
if sys.platform.startswith('java'):
time_limit = 10
else:
time_limit = .5
while not client.nodes:
time.sleep(0.05)
if time.time() - start > time_limit:
return None
return client
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
cmd = SON([('createUser', user)])
# X509 doesn't use a password
@ -189,11 +174,42 @@ class ClientContext(object):
self.server_is_resolvable = is_server_resolvable()
self.default_client_options = {}
self.sessions_enabled = False
self.client = self._connect(host, port)
self.client = None
self.conn_lock = threading.Lock()
if COMPRESSORS:
self.default_client_options["compressors"] = COMPRESSORS
def _connect(self, host, port, **kwargs):
# Jython takes a long time to connect.
if sys.platform.startswith('java'):
timeout_ms = 10000
else:
timeout_ms = 500
if COMPRESSORS:
kwargs["compressors"] = COMPRESSORS
client = pymongo.MongoClient(
host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs)
try:
try:
client.admin.command('isMaster') # Can we connect?
except pymongo.errors.OperationFailure as exc:
# SERVER-32063
self.connection_attempts.append(
'connected client %r, but isMaster failed: %s' % (
client, exc))
else:
self.connection_attempts.append(
'successfully connected client %r' % (client,))
# If connected, then return client with default timeout
return pymongo.MongoClient(host, port, **kwargs)
except pymongo.errors.ConnectionFailure as exc:
self.connection_attempts.append(
'failed to connect client %r: %s' % (client, exc))
return None
def _init_client(self):
self.client = self._connect(host, port)
if HAVE_SSL and not self.client:
# Is MongoDB configured for SSL?
self.client = self._connect(host, port, **_SSL_OPTIONS)
@ -283,33 +299,10 @@ class ClientContext(object):
self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid')
self.has_ipv6 = self._server_started_with_ipv6()
def _connect(self, host, port, **kwargs):
# Jython takes a long time to connect.
if sys.platform.startswith('java'):
timeout_ms = 10000
else:
timeout_ms = 500
if COMPRESSORS:
kwargs["compressors"] = COMPRESSORS
client = pymongo.MongoClient(
host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs)
try:
try:
client.admin.command('isMaster') # Can we connect?
except pymongo.errors.OperationFailure as exc:
# SERVER-32063
self.connection_attempts.append(
'connected client %r, but isMaster failed: %s' % (
client, exc))
else:
self.connection_attempts.append(
'successfully connected client %r' % (client,))
# If connected, then return client with default timeout
return pymongo.MongoClient(host, port, **kwargs)
except pymongo.errors.ConnectionFailure as exc:
self.connection_attempts.append(
'failed to connect client %r: %s' % (client, exc))
return None
def init(self):
with self.conn_lock:
if not self.client and not self.connection_attempts:
self._init_client()
def connection_attempt_info(self):
return '\n'.join(self.connection_attempts)
@ -400,11 +393,12 @@ class ClientContext(object):
def make_wrapper(f):
@wraps(f)
def wrap(*args, **kwargs):
self.init()
# Always raise SkipTest if we can't connect to MongoDB
if not self.connected:
raise SkipTest(
"Cannot connect to MongoDB on %s" % (self.pair,))
if condition:
if condition():
return f(*args, **kwargs)
raise SkipTest(msg)
return wrap
@ -426,40 +420,40 @@ class ClientContext(object):
def require_connection(self, func):
"""Run a test only if we can connect to MongoDB."""
return self._require(
self.connected,
lambda: True, # _require checks if we're connected
"Cannot connect to MongoDB on %s" % (self.pair,),
func=func)
def require_version_min(self, *ver):
"""Run a test only if the server version is at least ``version``."""
other_version = Version(*ver)
return self._require(self.version >= other_version,
return self._require(lambda: self.version >= other_version,
"Server version must be at least %s"
% str(other_version))
def require_version_max(self, *ver):
"""Run a test only if the server version is at most ``version``."""
other_version = Version(*ver)
return self._require(self.version <= other_version,
return self._require(lambda: self.version <= other_version,
"Server version must be at most %s"
% str(other_version))
def require_auth(self, func):
"""Run a test only if the server is running with auth enabled."""
return self.check_auth_with_sharding(
self._require(self.auth_enabled,
self._require(lambda: self.auth_enabled,
"Authentication is not enabled on the server",
func=func))
def require_no_auth(self, func):
"""Run a test only if the server is running without auth enabled."""
return self._require(not self.auth_enabled,
return self._require(lambda: not self.auth_enabled,
"Authentication must not be enabled on the server",
func=func)
def require_replica_set(self, func):
"""Run a test only if the client is connected to a replica set."""
return self._require(self.is_rs,
return self._require(lambda: self.is_rs,
"Not connected to a replica set",
func=func)
@ -467,51 +461,51 @@ class ClientContext(object):
"""Run a test only if the client is connected to a replica set that has
`count` secondaries.
"""
sec_count = 0 if not self.client else len(self.client.secondaries)
return self._require(sec_count >= count,
"Need %d secondaries, %d available"
% (count, sec_count))
def sec_count():
return 0 if not self.client else len(self.client.secondaries)
return self._require(lambda: sec_count() >= count,
"Not enough secondaries available")
def require_no_replica_set(self, func):
"""Run a test if the client is *not* connected to a replica set."""
return self._require(
not self.is_rs,
lambda: not self.is_rs,
"Connected to a replica set, not a standalone mongod",
func=func)
def require_ipv6(self, func):
"""Run a test only if the client can connect to a server via IPv6."""
return self._require(self.has_ipv6,
return self._require(lambda: self.has_ipv6,
"No IPv6",
func=func)
def require_no_mongos(self, func):
"""Run a test only if the client is not connected to a mongos."""
return self._require(not self.is_mongos,
return self._require(lambda: not self.is_mongos,
"Must be connected to a mongod, not a mongos",
func=func)
def require_mongos(self, func):
"""Run a test only if the client is connected to a mongos."""
return self._require(self.is_mongos,
return self._require(lambda: self.is_mongos,
"Must be connected to a mongos",
func=func)
def require_standalone(self, func):
"""Run a test only if the client is connected to a standalone."""
return self._require(not (self.is_mongos or self.is_rs),
return self._require(lambda: not (self.is_mongos or self.is_rs),
"Must be connected to a standalone",
func=func)
def require_no_standalone(self, func):
"""Run a test only if the client is not connected to a standalone."""
return self._require(self.is_mongos or self.is_rs,
return self._require(lambda: self.is_mongos or self.is_rs,
"Must be connected to a replica set or mongos",
func=func)
def check_auth_with_sharding(self, func):
"""Skip a test when connected to mongos < 2.0 and running with auth."""
condition = not (self.auth_enabled and
condition = lambda: not (self.auth_enabled and
self.is_mongos and self.version < (2,))
return self._require(condition,
"Auth with sharding requires MongoDB >= 2.0.0",
@ -519,44 +513,44 @@ class ClientContext(object):
def require_test_commands(self, func):
"""Run a test only if the server has test commands enabled."""
return self._require(self.test_commands_enabled,
return self._require(lambda: self.test_commands_enabled,
"Test commands must be enabled",
func=func)
def require_ssl(self, func):
"""Run a test only if the client can connect over SSL."""
return self._require(self.ssl,
return self._require(lambda: self.ssl,
"Must be able to connect via SSL",
func=func)
def require_no_ssl(self, func):
"""Run a test only if the client can connect over SSL."""
return self._require(not self.ssl,
return self._require(lambda: not self.ssl,
"Must be able to connect without SSL",
func=func)
def require_ssl_cert_none(self, func):
"""Run a test only if the client can connect with ssl.CERT_NONE."""
return self._require(self.ssl_cert_none,
return self._require(lambda: self.ssl_cert_none,
"Must be able to connect with ssl.CERT_NONE",
func=func)
def require_ssl_certfile(self, func):
"""Run a test only if the client can connect with ssl_certfile."""
return self._require(self.ssl_certfile,
return self._require(lambda: self.ssl_certfile,
"Must be able to connect with ssl_certfile",
func=func)
def require_server_resolvable(self, func):
"""Run a test only if the hostname 'server' is resolvable."""
return self._require(self.server_is_resolvable,
return self._require(lambda: self.server_is_resolvable,
"No hosts entry for 'server'. Cannot validate "
"hostname in the certificate",
func=func)
def require_sessions(self, func):
"""Run a test only if the deployment supports sessions."""
return self._require(self.sessions_enabled,
return self._require(lambda: self.sessions_enabled,
"Sessions not supported",
func=func)

View File

@ -28,8 +28,6 @@ from pymongo.server_description import ServerDescription
from test import client_context
default_host, default_port = client_context.host, client_context.port
class MockPool(Pool):
def __init__(self, client, pair, *args, **kwargs):
@ -39,7 +37,7 @@ class MockPool(Pool):
self.mock_host, self.mock_port = pair
# Actually connect to the default server.
Pool.__init__(self, (default_host, default_port), *args, **kwargs)
Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs)
@contextlib.contextmanager
def get_socket(self, all_credentials, checkout=False):

View File

@ -32,11 +32,6 @@ from test.utils import wait_until
_TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'dns')
_SSL_OPTS = client_context.default_client_options.copy()
if client_context.ssl is True:
# Our test certs don't support the SRV hosts used in these tests.
_SSL_OPTS['ssl_match_hostname'] = False
class TestDNS(unittest.TestCase):
pass
@ -78,7 +73,12 @@ def create_test(test_case):
hostname = next(iter(client_context.client.nodes))[0]
# The replica set members must be configured as 'localhost'.
if hostname == 'localhost':
client = MongoClient(uri, **_SSL_OPTS)
copts = client_context.default_client_options.copy()
if client_context.ssl is True:
# Our test certs don't support the SRV hosts used in these tests.
copts['ssl_match_hostname'] = False
client = MongoClient(uri, **copts)
# Force server selection
client.admin.command('ismaster')
wait_until(

View File

@ -455,13 +455,13 @@ def create_tests():
new_test = create_test(scenario_def, test)
new_test = client_context.require_transactions(new_test)
new_test = client_context._require(
not test.get('skipReason'),
lambda: not test.get('skipReason'),
test.get('skipReason'),
new_test)
if 'secondary' in test_name:
new_test = client_context._require(
client_context.has_secondaries,
lambda: client_context.has_secondaries,
'No secondaries',
new_test)

View File

@ -132,6 +132,8 @@ def _connection_string(h, p, authenticate):
def _mongo_client(host, port, authenticate=True, direct=False, **kwargs):
"""Create a new client over SSL/TLS if necessary."""
host = host or client_context.host
port = port or client_context.port
client_options = client_context.default_client_options.copy()
if client_context.replica_set_name and not direct:
client_options['replicaSet'] = client_context.replica_set_name
@ -143,32 +145,27 @@ def _mongo_client(host, port, authenticate=True, direct=False, **kwargs):
return client
def single_client_noauth(
h=client_context.host, p=client_context.port, **kwargs):
def single_client_noauth(h=None, p=None, **kwargs):
"""Make a direct connection. Don't authenticate."""
return _mongo_client(h, p, authenticate=False, direct=True, **kwargs)
def single_client(
h=client_context.host, p=client_context.port, **kwargs):
def single_client(h=None, p=None, **kwargs):
"""Make a direct connection, and authenticate if necessary."""
return _mongo_client(h, p, direct=True, **kwargs)
def rs_client_noauth(
h=client_context.host, p=client_context.port, **kwargs):
def rs_client_noauth(h=None, p=None, **kwargs):
"""Connect to the replica set. Don't authenticate."""
return _mongo_client(h, p, authenticate=False, **kwargs)
def rs_client(
h=client_context.host, p=client_context.port, **kwargs):
def rs_client(h=None, p=None, **kwargs):
"""Connect to the replica set and authenticate if necessary."""
return _mongo_client(h, p, **kwargs)
def rs_or_single_client_noauth(
h=client_context.host, p=client_context.port, **kwargs):
def rs_or_single_client_noauth(h=None, p=None, **kwargs):
"""Connect to the replica set if there is one, otherwise the standalone.
Like rs_or_single_client, but does not authenticate.
@ -176,8 +173,7 @@ def rs_or_single_client_noauth(
return _mongo_client(h, p, authenticate=False, **kwargs)
def rs_or_single_client(
h=client_context.host, p=client_context.port, **kwargs):
def rs_or_single_client(h=None, p=None, **kwargs):
"""Connect to the replica set if there is one, otherwise the standalone.
Authenticates if necessary.