diff --git a/test/__init__.py b/test/__init__.py index 3e54523ec..d72232f06 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -18,6 +18,7 @@ import os import socket import sys +import threading import time import warnings @@ -88,22 +89,6 @@ def is_server_resolvable(): socket.setdefaulttimeout(socket_timeout) -def _connect(host, port, **kwargs): - client = pymongo.MongoClient(host, port, **kwargs) - start = time.time() - # Jython takes a long time to connect. - if sys.platform.startswith('java'): - time_limit = 10 - else: - time_limit = .5 - while not client.nodes: - time.sleep(0.05) - if time.time() - start > time_limit: - return None - - return client - - def _create_user(authdb, user, pwd=None, roles=None, **kwargs): cmd = SON([('createUser', user)]) # X509 doesn't use a password @@ -189,11 +174,42 @@ class ClientContext(object): self.server_is_resolvable = is_server_resolvable() self.default_client_options = {} self.sessions_enabled = False - self.client = self._connect(host, port) + self.client = None + self.conn_lock = threading.Lock() if COMPRESSORS: self.default_client_options["compressors"] = COMPRESSORS + def _connect(self, host, port, **kwargs): + # Jython takes a long time to connect. + if sys.platform.startswith('java'): + timeout_ms = 10000 + else: + timeout_ms = 500 + if COMPRESSORS: + kwargs["compressors"] = COMPRESSORS + client = pymongo.MongoClient( + host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs) + try: + try: + client.admin.command('isMaster') # Can we connect? + except pymongo.errors.OperationFailure as exc: + # SERVER-32063 + self.connection_attempts.append( + 'connected client %r, but isMaster failed: %s' % ( + client, exc)) + else: + self.connection_attempts.append( + 'successfully connected client %r' % (client,)) + # If connected, then return client with default timeout + return pymongo.MongoClient(host, port, **kwargs) + except pymongo.errors.ConnectionFailure as exc: + self.connection_attempts.append( + 'failed to connect client %r: %s' % (client, exc)) + return None + + def _init_client(self): + self.client = self._connect(host, port) if HAVE_SSL and not self.client: # Is MongoDB configured for SSL? self.client = self._connect(host, port, **_SSL_OPTIONS) @@ -283,33 +299,10 @@ class ClientContext(object): self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid') self.has_ipv6 = self._server_started_with_ipv6() - def _connect(self, host, port, **kwargs): - # Jython takes a long time to connect. - if sys.platform.startswith('java'): - timeout_ms = 10000 - else: - timeout_ms = 500 - if COMPRESSORS: - kwargs["compressors"] = COMPRESSORS - client = pymongo.MongoClient( - host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs) - try: - try: - client.admin.command('isMaster') # Can we connect? - except pymongo.errors.OperationFailure as exc: - # SERVER-32063 - self.connection_attempts.append( - 'connected client %r, but isMaster failed: %s' % ( - client, exc)) - else: - self.connection_attempts.append( - 'successfully connected client %r' % (client,)) - # If connected, then return client with default timeout - return pymongo.MongoClient(host, port, **kwargs) - except pymongo.errors.ConnectionFailure as exc: - self.connection_attempts.append( - 'failed to connect client %r: %s' % (client, exc)) - return None + def init(self): + with self.conn_lock: + if not self.client and not self.connection_attempts: + self._init_client() def connection_attempt_info(self): return '\n'.join(self.connection_attempts) @@ -400,11 +393,12 @@ class ClientContext(object): def make_wrapper(f): @wraps(f) def wrap(*args, **kwargs): + self.init() # Always raise SkipTest if we can't connect to MongoDB if not self.connected: raise SkipTest( "Cannot connect to MongoDB on %s" % (self.pair,)) - if condition: + if condition(): return f(*args, **kwargs) raise SkipTest(msg) return wrap @@ -426,40 +420,40 @@ class ClientContext(object): def require_connection(self, func): """Run a test only if we can connect to MongoDB.""" return self._require( - self.connected, + lambda: True, # _require checks if we're connected "Cannot connect to MongoDB on %s" % (self.pair,), func=func) def require_version_min(self, *ver): """Run a test only if the server version is at least ``version``.""" other_version = Version(*ver) - return self._require(self.version >= other_version, + return self._require(lambda: self.version >= other_version, "Server version must be at least %s" % str(other_version)) def require_version_max(self, *ver): """Run a test only if the server version is at most ``version``.""" other_version = Version(*ver) - return self._require(self.version <= other_version, + return self._require(lambda: self.version <= other_version, "Server version must be at most %s" % str(other_version)) def require_auth(self, func): """Run a test only if the server is running with auth enabled.""" return self.check_auth_with_sharding( - self._require(self.auth_enabled, + self._require(lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func)) def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" - return self._require(not self.auth_enabled, + return self._require(lambda: not self.auth_enabled, "Authentication must not be enabled on the server", func=func) def require_replica_set(self, func): """Run a test only if the client is connected to a replica set.""" - return self._require(self.is_rs, + return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func) @@ -467,51 +461,51 @@ class ClientContext(object): """Run a test only if the client is connected to a replica set that has `count` secondaries. """ - sec_count = 0 if not self.client else len(self.client.secondaries) - return self._require(sec_count >= count, - "Need %d secondaries, %d available" - % (count, sec_count)) + def sec_count(): + return 0 if not self.client else len(self.client.secondaries) + return self._require(lambda: sec_count() >= count, + "Not enough secondaries available") def require_no_replica_set(self, func): """Run a test if the client is *not* connected to a replica set.""" return self._require( - not self.is_rs, + lambda: not self.is_rs, "Connected to a replica set, not a standalone mongod", func=func) def require_ipv6(self, func): """Run a test only if the client can connect to a server via IPv6.""" - return self._require(self.has_ipv6, + return self._require(lambda: self.has_ipv6, "No IPv6", func=func) def require_no_mongos(self, func): """Run a test only if the client is not connected to a mongos.""" - return self._require(not self.is_mongos, + return self._require(lambda: not self.is_mongos, "Must be connected to a mongod, not a mongos", func=func) def require_mongos(self, func): """Run a test only if the client is connected to a mongos.""" - return self._require(self.is_mongos, + return self._require(lambda: self.is_mongos, "Must be connected to a mongos", func=func) def require_standalone(self, func): """Run a test only if the client is connected to a standalone.""" - return self._require(not (self.is_mongos or self.is_rs), + return self._require(lambda: not (self.is_mongos or self.is_rs), "Must be connected to a standalone", func=func) def require_no_standalone(self, func): """Run a test only if the client is not connected to a standalone.""" - return self._require(self.is_mongos or self.is_rs, + return self._require(lambda: self.is_mongos or self.is_rs, "Must be connected to a replica set or mongos", func=func) def check_auth_with_sharding(self, func): """Skip a test when connected to mongos < 2.0 and running with auth.""" - condition = not (self.auth_enabled and + condition = lambda: not (self.auth_enabled and self.is_mongos and self.version < (2,)) return self._require(condition, "Auth with sharding requires MongoDB >= 2.0.0", @@ -519,44 +513,44 @@ class ClientContext(object): def require_test_commands(self, func): """Run a test only if the server has test commands enabled.""" - return self._require(self.test_commands_enabled, + return self._require(lambda: self.test_commands_enabled, "Test commands must be enabled", func=func) def require_ssl(self, func): """Run a test only if the client can connect over SSL.""" - return self._require(self.ssl, + return self._require(lambda: self.ssl, "Must be able to connect via SSL", func=func) def require_no_ssl(self, func): """Run a test only if the client can connect over SSL.""" - return self._require(not self.ssl, + return self._require(lambda: not self.ssl, "Must be able to connect without SSL", func=func) def require_ssl_cert_none(self, func): """Run a test only if the client can connect with ssl.CERT_NONE.""" - return self._require(self.ssl_cert_none, + return self._require(lambda: self.ssl_cert_none, "Must be able to connect with ssl.CERT_NONE", func=func) def require_ssl_certfile(self, func): """Run a test only if the client can connect with ssl_certfile.""" - return self._require(self.ssl_certfile, + return self._require(lambda: self.ssl_certfile, "Must be able to connect with ssl_certfile", func=func) def require_server_resolvable(self, func): """Run a test only if the hostname 'server' is resolvable.""" - return self._require(self.server_is_resolvable, + return self._require(lambda: self.server_is_resolvable, "No hosts entry for 'server'. Cannot validate " "hostname in the certificate", func=func) def require_sessions(self, func): """Run a test only if the deployment supports sessions.""" - return self._require(self.sessions_enabled, + return self._require(lambda: self.sessions_enabled, "Sessions not supported", func=func) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 63ee695c3..7ed2e8a30 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -28,8 +28,6 @@ from pymongo.server_description import ServerDescription from test import client_context -default_host, default_port = client_context.host, client_context.port - class MockPool(Pool): def __init__(self, client, pair, *args, **kwargs): @@ -39,7 +37,7 @@ class MockPool(Pool): self.mock_host, self.mock_port = pair # Actually connect to the default server. - Pool.__init__(self, (default_host, default_port), *args, **kwargs) + Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): diff --git a/test/test_dns.py b/test/test_dns.py index 2fec78d0f..dc659be0a 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -32,11 +32,6 @@ from test.utils import wait_until _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'dns') -_SSL_OPTS = client_context.default_client_options.copy() -if client_context.ssl is True: - # Our test certs don't support the SRV hosts used in these tests. - _SSL_OPTS['ssl_match_hostname'] = False - class TestDNS(unittest.TestCase): pass @@ -78,7 +73,12 @@ def create_test(test_case): hostname = next(iter(client_context.client.nodes))[0] # The replica set members must be configured as 'localhost'. if hostname == 'localhost': - client = MongoClient(uri, **_SSL_OPTS) + copts = client_context.default_client_options.copy() + if client_context.ssl is True: + # Our test certs don't support the SRV hosts used in these tests. + copts['ssl_match_hostname'] = False + + client = MongoClient(uri, **copts) # Force server selection client.admin.command('ismaster') wait_until( diff --git a/test/test_transactions.py b/test/test_transactions.py index cb23f8370..e0e3a6148 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -455,13 +455,13 @@ def create_tests(): new_test = create_test(scenario_def, test) new_test = client_context.require_transactions(new_test) new_test = client_context._require( - not test.get('skipReason'), + lambda: not test.get('skipReason'), test.get('skipReason'), new_test) if 'secondary' in test_name: new_test = client_context._require( - client_context.has_secondaries, + lambda: client_context.has_secondaries, 'No secondaries', new_test) diff --git a/test/utils.py b/test/utils.py index 14e8a0186..b50b1e509 100644 --- a/test/utils.py +++ b/test/utils.py @@ -132,6 +132,8 @@ def _connection_string(h, p, authenticate): def _mongo_client(host, port, authenticate=True, direct=False, **kwargs): """Create a new client over SSL/TLS if necessary.""" + host = host or client_context.host + port = port or client_context.port client_options = client_context.default_client_options.copy() if client_context.replica_set_name and not direct: client_options['replicaSet'] = client_context.replica_set_name @@ -143,32 +145,27 @@ def _mongo_client(host, port, authenticate=True, direct=False, **kwargs): return client -def single_client_noauth( - h=client_context.host, p=client_context.port, **kwargs): +def single_client_noauth(h=None, p=None, **kwargs): """Make a direct connection. Don't authenticate.""" return _mongo_client(h, p, authenticate=False, direct=True, **kwargs) -def single_client( - h=client_context.host, p=client_context.port, **kwargs): +def single_client(h=None, p=None, **kwargs): """Make a direct connection, and authenticate if necessary.""" return _mongo_client(h, p, direct=True, **kwargs) -def rs_client_noauth( - h=client_context.host, p=client_context.port, **kwargs): +def rs_client_noauth(h=None, p=None, **kwargs): """Connect to the replica set. Don't authenticate.""" return _mongo_client(h, p, authenticate=False, **kwargs) -def rs_client( - h=client_context.host, p=client_context.port, **kwargs): +def rs_client(h=None, p=None, **kwargs): """Connect to the replica set and authenticate if necessary.""" return _mongo_client(h, p, **kwargs) -def rs_or_single_client_noauth( - h=client_context.host, p=client_context.port, **kwargs): +def rs_or_single_client_noauth(h=None, p=None, **kwargs): """Connect to the replica set if there is one, otherwise the standalone. Like rs_or_single_client, but does not authenticate. @@ -176,8 +173,7 @@ def rs_or_single_client_noauth( return _mongo_client(h, p, authenticate=False, **kwargs) -def rs_or_single_client( - h=client_context.host, p=client_context.port, **kwargs): +def rs_or_single_client(h=None, p=None, **kwargs): """Connect to the replica set if there is one, otherwise the standalone. Authenticates if necessary.