diff --git a/test/__init__.py b/test/__init__.py index f75894a35..43b53e9b8 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -35,7 +35,7 @@ import pymongo.errors from bson.py3compat import _unicode from pymongo import common -from pymongo.ssl_support import HAVE_SSL +from pymongo.ssl_support import HAVE_SSL, validate_cert_reqs from test.version import Version if HAVE_SSL: @@ -53,7 +53,18 @@ db_pwd = _unicode(os.environ.get("DB_PASSWORD", "password")) CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'certificates') -CLIENT_PEM = os.path.join(CERT_PATH, 'client.pem') +CLIENT_PEM = os.environ.get('CLIENT_PEM', + os.path.join(CERT_PATH, 'client.pem')) +CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem')) +CERT_REQS = validate_cert_reqs('CERT_REQS', os.environ.get('CERT_REQS')) + +_SSL_OPTIONS = dict(ssl=True) +if CLIENT_PEM: + _SSL_OPTIONS['ssl_certfile'] = CLIENT_PEM +if CA_PEM: + _SSL_OPTIONS['ssl_ca_certs'] = CA_PEM +if CERT_REQS is not None: + _SSL_OPTIONS['ssl_cert_reqs'] = CERT_REQS def is_server_resolvable(): @@ -70,6 +81,17 @@ def is_server_resolvable(): socket.setdefaulttimeout(socket_timeout) +def _connect(host, port, **kwargs): + try: + client = pymongo.MongoClient( + host, port, serverSelectionTimeoutMS=100, **kwargs) + client.admin.command('ismaster') # Can we connect? + # If connected, then return client with default timeout + return pymongo.MongoClient(host, port, **kwargs) + except pymongo.errors.ConnectionFailure: + return None + + class client_knobs(object): def __init__( self, @@ -129,7 +151,6 @@ class ClientContext(object): self.w = None self.nodes = set() self.replica_set_name = None - self.rs_client = None self.cmd_line = None self.version = Version(-1) # Needs to be comparable with Version self.auth_enabled = False @@ -137,40 +158,22 @@ class ClientContext(object): self.is_mongos = False self.is_rs = False self.has_ipv6 = False + self.ssl = False self.ssl_cert_none = False self.ssl_certfile = False self.server_is_resolvable = is_server_resolvable() - - self.client = self.rs_or_standalone_client = None - - def connect(**kwargs): - try: - client = pymongo.MongoClient( - self.host, - self.port, - serverSelectionTimeoutMS=100, - **kwargs) - client.admin.command('ismaster') # Can we connect? - # If connected, then return client with default timeout - return pymongo.MongoClient(self.host, self.port, **kwargs) - except pymongo.errors.ConnectionFailure: - return None - - self.client = connect() + self.ssl_client_options = {} + self.client = _connect(self.host, self.port) if HAVE_SSL and not self.client: # Is MongoDB configured for SSL? - self.client = connect(ssl=True, ssl_cert_reqs=ssl.CERT_NONE) + self.client = _connect(self.host, self.port, **_SSL_OPTIONS) if self.client: - self.ssl_cert_none = True - - # Can client connect with certfile? - client = connect(ssl=True, - ssl_cert_reqs=ssl.CERT_NONE, - ssl_certfile=CLIENT_PEM,) - if client: + self.ssl = True + self.ssl_client_options = _SSL_OPTIONS self.ssl_certfile = True - self.client = client + if _SSL_OPTIONS.get('ssl_cert_reqs') == ssl.CERT_NONE: + self.ssl_cert_none = True if self.client: self.connected = True @@ -178,16 +181,16 @@ class ClientContext(object): self.w = len(self.ismaster.get("hosts", [])) or 1 self.nodes = set([(self.host, self.port)]) self.replica_set_name = self.ismaster.get('setName', '') - self.rs_client = None self.version = Version.from_client(self.client) if self.replica_set_name: self.is_rs = True - self.rs_client = pymongo.MongoClient( - self.ismaster['primary'], replicaSet=self.replica_set_name) + self.client = pymongo.MongoClient( + self.ismaster['primary'], + replicaSet=self.replica_set_name, + **self.ssl_client_options) # Force connection - self.rs_client.admin.command('ismaster') - self.host, self.port = self.rs_client.primary - self.client = connect() + self.client.admin.command('ismaster') + self.host, self.port = self.client.primary nodes = [partition_node(node.lower()) for node in self.ismaster.get('hosts', [])] @@ -197,8 +200,6 @@ class ClientContext(object): for node in self.ismaster.get('arbiters', [])]) self.nodes = set(nodes) - self.rs_or_standalone_client = self.rs_client or self.client - try: self.cmd_line = self.client.admin.command('getCmdLineOpts') except pymongo.errors.OperationFailure as e: @@ -221,9 +222,6 @@ class ClientContext(object): self.client.admin.add_user(db_user, db_pwd, **roles) self.client.admin.authenticate(db_user, db_pwd) - if self.rs_client: - self.rs_client.admin.authenticate(db_user, db_pwd) - # May not have this if OperationFailure was raised earlier. self.cmd_line = self.client.admin.command('getCmdLineOpts') @@ -233,12 +231,17 @@ class ClientContext(object): params = self.cmd_line['parsed'].get('setParameter', []) if 'enableTestCommands=1' in params: self.test_commands_enabled = True + else: + params = self.cmd_line['parsed'].get('setParameter', {}) + if params.get('enableTestCommands') == '1': + self.test_commands_enabled = True self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid') self.has_ipv6 = self._server_started_with_ipv6() - # Do this after we connect so we know who the primary is. - self.pair = "%s:%d" % (self.host, self.port) + @property + def pair(self): + return "%s:%d" % (self.host, self.port) def _check_user_provided(self): try: @@ -391,13 +394,13 @@ class ClientContext(object): def require_ssl(self, func): """Run a test only if the client can connect over SSL.""" - return self._require(self.ssl_cert_none or self.ssl_certfile, + return self._require(self.ssl, "Must be able to connect via SSL", func=func) def require_no_ssl(self, func): """Run a test only if the client can connect over SSL.""" - return self._require(not (self.ssl_cert_none or self.ssl_certfile), + return self._require(not self.ssl, "Must be able to connect without SSL", func=func) @@ -430,7 +433,7 @@ class IntegrationTest(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): - cls.client = client_context.rs_or_standalone_client + cls.client = client_context.client cls.db = cls.client.pymongo_test diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index a7685d88b..e5b4af1e4 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -23,7 +23,7 @@ from pymongo import MongoClient from pymongo.errors import AutoReconnect, NetworkTimeout from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor -from pymongo.pool import Pool, PoolOptions +from pymongo.pool import Pool from pymongo.server_description import ServerDescription from test import client_context @@ -39,9 +39,7 @@ class MockPool(Pool): self.mock_host, self.mock_port = pair # Actually connect to the default server. - Pool.__init__(self, - (default_host, default_port), - PoolOptions(connect_timeout=20)) + Pool.__init__(self, (default_host, default_port), *args, **kwargs) @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): @@ -125,7 +123,10 @@ class MockClient(MongoClient): kwargs['_pool_class'] = partial(MockPool, self) kwargs['_monitor_class'] = partial(MockMonitor, self) - super(MockClient, self).__init__(*args, **kwargs) + client_options = client_context.ssl_client_options.copy() + client_options.update(kwargs) + + super(MockClient, self).__init__(*args, **client_options) def kill_host(self, host): """Host is like 'a:1'.""" diff --git a/test/test_auth.py b/test/test_auth.py index fdb79772d..ab5fac0f5 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -31,7 +31,7 @@ from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple from pymongo.errors import OperationFailure from pymongo.read_preferences import ReadPreference from test import client_context, SkipTest, unittest, Version -from test.utils import delay +from test.utils import delay, rs_or_single_client_noauth, single_client_noauth # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. GSSAPI_HOST = os.environ.get('GSSAPI_HOST') @@ -305,14 +305,11 @@ class TestSASLPlain(unittest.TestCase): self.assertRaises(OperationFailure, bad_pwd.admin.command, 'ismaster') - class TestSCRAMSHA1(unittest.TestCase): @client_context.require_auth @client_context.require_version_min(2, 7, 2) def setUp(self): - self.replica_set_name = client_context.replica_set_name - # Before 2.7.7, SCRAM-SHA-1 had to be enabled from the command line. if client_context.version < Version(2, 7, 7): cmd_line = client_context.cmd_line @@ -321,7 +318,7 @@ class TestSCRAMSHA1(unittest.TestCase): {}).get('authenticationMechanisms', ''): raise SkipTest('SCRAM-SHA-1 mechanism not enabled') - client = client_context.rs_or_standalone_client + client = client_context.client client.pymongo_test.add_user( 'user', 'pass', roles=['userAdmin', 'readWrite'], @@ -329,42 +326,36 @@ class TestSCRAMSHA1(unittest.TestCase): def test_scram_sha1(self): host, port = client_context.host, client_context.port - client = MongoClient(host, port) + client = rs_or_single_client_noauth() self.assertTrue(client.pymongo_test.authenticate( 'user', 'pass', mechanism='SCRAM-SHA-1')) client.pymongo_test.command('dbstats') - client = MongoClient('mongodb://user:pass@%s:%d/pymongo_test' - '?authMechanism=SCRAM-SHA-1' % (host, port)) + client = rs_or_single_client_noauth( + 'mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1' + % (host, port)) client.pymongo_test.command('dbstats') - if self.replica_set_name: - client = MongoClient(host, port, - replicaSet='%s' % (self.replica_set_name,)) - self.assertTrue(client.pymongo_test.authenticate( - 'user', 'pass', mechanism='SCRAM-SHA-1')) - client.pymongo_test.command('dbstats') - + if client_context.is_rs: uri = ('mongodb://user:pass' '@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1' - '&replicaSet=%s' % (host, port, self.replica_set_name)) - client = MongoClient(uri) + '&replicaSet=%s' % (host, port, + client_context.replica_set_name)) + client = single_client_noauth(uri) client.pymongo_test.command('dbstats') db = client.get_database( 'pymongo_test', read_preference=ReadPreference.SECONDARY) db.command('dbstats') def tearDown(self): - client_context.rs_or_standalone_client.pymongo_test.remove_user('user') + client_context.client.pymongo_test.remove_user('user') class TestAuthURIOptions(unittest.TestCase): @client_context.require_auth def setUp(self): - client = MongoClient(client_context.host, client_context.port) - response = client.admin.command('ismaster') - self.replica_set_name = str(response.get('setName', '')) + client = rs_or_single_client_noauth() client_context.client.admin.add_user('admin', 'pass', roles=['userAdminAnyDatabase', 'dbAdminAnyDatabase', @@ -374,12 +365,10 @@ class TestAuthURIOptions(unittest.TestCase): client.pymongo_test.add_user('user', 'pass', roles=['userAdmin', 'readWrite']) - if self.replica_set_name: - # GLE requires authentication. - client.admin.authenticate('admin', 'pass') + if client_context.is_rs: # Make sure the admin user is replicated after calling add_user # above. This avoids a race in the replica set tests below. - client.admin.command('getLastError', w=len(response['hosts'])) + client.admin.command('getLastError', w=client_context.w) self.client = client def tearDown(self): @@ -393,14 +382,14 @@ class TestAuthURIOptions(unittest.TestCase): def test_uri_options(self): # Test default to admin host, port = client_context.host, client_context.port - client = MongoClient( + client = rs_or_single_client_noauth( 'mongodb://admin:pass@%s:%d' % (host, port)) self.assertTrue(client.admin.command('dbstats')) - if self.replica_set_name: - uri = ('mongodb://admin:pass' - '@%s:%d/?replicaSet=%s' % (host, port, self.replica_set_name)) - client = MongoClient(uri) + if client_context.is_rs: + uri = ('mongodb://admin:pass@%s:%d/?replicaSet=%s' % ( + host, port, client_context.replica_set_name)) + client = single_client_noauth(uri) self.assertTrue(client.admin.command('dbstats')) db = client.get_database( 'admin', read_preference=ReadPreference.SECONDARY) @@ -408,14 +397,14 @@ class TestAuthURIOptions(unittest.TestCase): # Test explicit database uri = 'mongodb://user:pass@%s:%d/pymongo_test' % (host, port) - client = MongoClient(uri) + client = rs_or_single_client_noauth(uri) self.assertRaises(OperationFailure, client.admin.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) - if self.replica_set_name: - uri = ('mongodb://user:pass@%s:%d' - '/pymongo_test?replicaSet=%s' % (host, port, self.replica_set_name)) - client = MongoClient(uri) + if client_context.is_rs: + uri = ('mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s' % ( + host, port, client_context.replica_set_name)) + client = single_client_noauth(uri) self.assertRaises(OperationFailure, client.admin.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) @@ -426,15 +415,16 @@ class TestAuthURIOptions(unittest.TestCase): # Test authSource uri = ('mongodb://user:pass@%s:%d' '/pymongo_test2?authSource=pymongo_test' % (host, port)) - client = MongoClient(uri) + client = rs_or_single_client_noauth(uri) self.assertRaises(OperationFailure, client.pymongo_test2.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) - if self.replica_set_name: + if client_context.is_rs: uri = ('mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=' - '%s;authSource=pymongo_test' % (host, port, self.replica_set_name)) - client = MongoClient(uri) + '%s;authSource=pymongo_test' % ( + host, port, client_context.replica_set_name)) + client = single_client_noauth(uri) self.assertRaises(OperationFailure, client.pymongo_test2.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) @@ -449,7 +439,7 @@ class TestDelegatedAuth(unittest.TestCase): @client_context.require_version_max(2, 5, 3) @client_context.require_version_min(2, 4, 0) def setUp(self): - self.client = client_context.rs_or_standalone_client + self.client = client_context.client def tearDown(self): self.client.pymongo_test.remove_user('user') @@ -465,7 +455,7 @@ class TestDelegatedAuth(unittest.TestCase): self.client.pymongo_test2.add_user('user', userSource='pymongo_test', roles=['read']) - auth_c = MongoClient(client_context.host, client_context.port) + auth_c = rs_or_single_client_noauth() self.assertRaises(OperationFailure, auth_c.pymongo_test2.foo.find_one) # Auth must occur on the db where the user is defined. diff --git a/test/test_bulk.py b/test/test_bulk.py index a1782599f..21adc2be7 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -21,7 +21,6 @@ sys.path[0:0] = [""] from bson import InvalidDocument, SON from bson.objectid import ObjectId from bson.py3compat import string_type -from pymongo import MongoClient from pymongo.operations import * from pymongo.common import partition_node from pymongo.errors import (BulkWriteError, @@ -33,7 +32,11 @@ from test import (client_context, unittest, IntegrationTest, SkipTest) -from test.utils import oid_generated_on_client, remove_all_users, wait_until +from test.utils import (oid_generated_on_client, + remove_all_users, + rs_or_single_client_noauth, + single_client, + wait_until) class BulkTestBase(IntegrationTest): @@ -137,7 +140,6 @@ class TestBulk(BulkTestBase): bulk.find({}) @client_context.require_version_min(3, 1, 9, -1) - @client_context.require_no_auth def test_bypass_document_validation_bulk_op(self): # Test insert @@ -971,7 +973,7 @@ class TestBulkWriteConcern(BulkTestBase): if cls.w > 1: for member in client_context.ismaster['hosts']: if member != client_context.ismaster['primary']: - cls.secondary = MongoClient(*partition_node(member)) + cls.secondary = single_client(*partition_node(member)) break # We tested wtimeout errors by specifying a write concern greater than @@ -1243,7 +1245,7 @@ class TestBulkAuthorization(BulkTestBase): def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = MongoClient(client_context.host, client_context.port) + cli = rs_or_single_client_noauth() db = cli.pymongo_test coll = db.test db.authenticate('readonly', 'pw') @@ -1254,7 +1256,7 @@ class TestBulkAuthorization(BulkTestBase): def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = MongoClient(client_context.host, client_context.port) + cli = rs_or_single_client_noauth() db = cli.pymongo_test coll = db.test db.authenticate('noremove', 'pw') diff --git a/test/test_client.py b/test/test_client.py index 3f3359f13..2aca604df 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -38,9 +38,9 @@ from pymongo.errors import (AutoReconnect, ConfigurationError, ConnectionFailure, InvalidName, - OperationFailure, - NetworkTimeout, InvalidURI, + NetworkTimeout, + OperationFailure, WriteConcernError) from pymongo.monitoring import (ServerHeartbeatListener, ServerHeartbeatStartedEvent) @@ -68,6 +68,7 @@ from test.utils import (assertRaisesExactly, one, connected, wait_until, + rs_client, rs_or_single_client, rs_or_single_client_noauth, single_client, @@ -81,11 +82,8 @@ class ClientUnitTest(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): - cls.client = MongoClient( - client_context.host, - client_context.port, - connect=False, - serverSelectionTimeoutMS=100) + cls.client = rs_or_single_client(connect=False, + serverSelectionTimeoutMS=100) def test_keyword_arg_defaults(self): client = MongoClient(socketTimeoutMS=None, @@ -173,23 +171,23 @@ class ClientUnitTest(unittest.TestCase): self.assertRaises(TypeError, iterate) def test_get_default_database(self): - c = MongoClient( - "mongodb://%s:%d/foo" % (client_context.host, client_context.port), - connect=False) + c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host, + client_context.port), + connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) def test_get_default_database_error(self): # URI with no database. - c = MongoClient( - "mongodb://%s:%d/" % (client_context.host, client_context.port), - connect=False) + c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host, + client_context.port), + connect=False) self.assertRaises(ConfigurationError, c.get_default_database) def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % ( client_context.host, client_context.port) - c = MongoClient(uri, connect=False) + c = rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) def test_primary_read_pref_with_tags(self): @@ -265,10 +263,9 @@ class ClientUnitTest(unittest.TestCase): class TestClient(IntegrationTest): def test_max_idle_time_reaper(self): - host, port = client_context.host, client_context.port with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove sockets when maxIdleTimeMS not set - client = MongoClient(host, port) + client = rs_or_single_client() server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass @@ -276,7 +273,8 @@ class TestClient(IntegrationTest): self.assertTrue(sock_info in server._pool.sockets) # Assert reaper removes idle socket and replaces it with a new one - client = MongoClient(host, port, maxIdleTimeMS=.5, minPoolSize=1) + client = rs_or_single_client(maxIdleTimeMS=.5, + minPoolSize=1) server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass @@ -287,7 +285,7 @@ class TestClient(IntegrationTest): "reaper replaces stale socket with new one") # Assert reaper has removed idle socket and NOT replaced it - client = MongoClient(host, port, maxIdleTimeMS=.5) + client = rs_or_single_client(maxIdleTimeMS=.5) server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}): pass @@ -296,14 +294,13 @@ class TestClient(IntegrationTest): "stale socket reaped and new one NOT added to the pool") def test_min_pool_size(self): - host, port = client_context.host, client_context.port with client_knobs(kill_cursor_frequency=.1): - client = MongoClient(host, port) + client = rs_or_single_client() server = client._get_topology().select_server(any_server_selector) self.assertEqual(0, len(server._pool.sockets)) # Assert that pool started up at minPoolSize - client = MongoClient(host, port, minPoolSize=10) + client = rs_or_single_client(minPoolSize=10) server = client._get_topology().select_server(any_server_selector) wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets") @@ -316,10 +313,9 @@ class TestClient(IntegrationTest): self.assertFalse(sock_info in server._pool.sockets) def test_max_idle_time_checkout(self): - host, port = client_context.host, client_context.port # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = MongoClient(host, port, maxIdleTimeMS=.5) + client = rs_or_single_client(maxIdleTimeMS=.5) server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass @@ -333,7 +329,7 @@ class TestClient(IntegrationTest): self.assertTrue(new_sock_info in server._pool.sockets) # Test that sockets are reused if maxIdleTimeMS is not set. - client = MongoClient(host, port) + client = rs_or_single_client() server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass @@ -344,22 +340,27 @@ class TestClient(IntegrationTest): self.assertEqual(1, len(server._pool.sockets)) def test_constants(self): + """This test uses MongoClient explicitly to make sure that host and + port are not overloaded. + """ host, port = client_context.host, client_context.port # Set bad defaults. MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - connected(MongoClient(serverSelectionTimeoutMS=10)) + connected(MongoClient(serverSelectionTimeoutMS=10, + **client_context.ssl_client_options)) # Override the defaults. No error. - connected(MongoClient(host, port)) + connected(MongoClient(host, port, + **client_context.ssl_client_options)) # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port # No error. - connected(MongoClient()) + connected(MongoClient(**client_context.ssl_client_options)) def test_init_disconnected(self): host, port = client_context.host, client_context.port @@ -401,10 +402,10 @@ class TestClient(IntegrationTest): def test_equality(self): c = connected(rs_or_single_client()) - self.assertEqual(client_context.rs_or_standalone_client, c) + self.assertEqual(client_context.client, c) # Explicitly test inequality - self.assertFalse(client_context.rs_or_standalone_client != c) + self.assertFalse(client_context.client != c) def test_host_w_port(self): with self.assertRaises(ValueError): @@ -543,12 +544,12 @@ class TestClient(IntegrationTest): "mongodb://user:pass@%s:%d/pymongo_test" % (host, port))) # Auth with lazy connection. - rs_or_single_client( + rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False).pymongo_test.test.find_one() # Wrong password. - bad_client = rs_or_single_client( + bad_client = rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False) @@ -585,7 +586,7 @@ class TestClient(IntegrationTest): @client_context.require_auth def test_lazy_auth_raises_operation_failure(self): - lazy_client = rs_or_single_client( + lazy_client = rs_or_single_client_noauth( "mongodb://user:wrong@%s/pymongo_test" % (client_context.host,), connect=False) @@ -608,7 +609,7 @@ class TestClient(IntegrationTest): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = MongoClient(uri) + client = rs_or_single_client(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = client.database_names() self.assertTrue("pymongo_test" in dbs) @@ -973,11 +974,8 @@ class TestClient(IntegrationTest): @client_context.require_no_replica_set def test_connect_to_standalone_using_replica_set_name(self): - client = MongoClient( - client_context.host, - client_context.port, - replicaSet='anything', - serverSelectionTimeoutMS=100) + client = single_client(replicaSet='anything', + serverSelectionTimeoutMS=100) with self.assertRaises(AutoReconnect): client.test.test.find_one() @@ -988,12 +986,8 @@ class TestClient(IntegrationTest): # the topology before the getMore message is sent. Test that # MongoClient._send_message_with_response handles the error. with self.assertRaises(AutoReconnect): - client = MongoClient( - client_context.host, - client_context.port, - connect=False, - serverSelectionTimeoutMS=100, - replicaSet=client_context.replica_set_name) + client = rs_client(connect=False, + serverSelectionTimeoutMS=100) client._send_message_with_response( operation=message._GetMore('pymongo_test', 'collection', 101, 1234, client.codec_options), diff --git a/test/test_collation.py b/test/test_collation.py index 4b1e20235..7ce5dbc44 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -91,6 +91,7 @@ def raisesConfigurationErrorForOldMongoDB(func): class TestCollation(unittest.TestCase): @classmethod + @client_context.require_connection def setUpClass(cls): cls.listener = EventListener() cls.saved_listeners = monitoring._LISTENERS diff --git a/test/test_collection.py b/test/test_collection.py index 0257ddeb8..2424b8ed6 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -36,7 +36,7 @@ from bson.py3compat import itervalues from bson.son import SON from pymongo import (ASCENDING, DESCENDING, GEO2D, GEOHAYSTACK, GEOSPHERE, HASHED, TEXT) -from pymongo import MongoClient, monitoring +from pymongo import monitoring from pymongo.bulk import BulkWriteError from pymongo.collection import Collection, ReturnDocument from pymongo.command_cursor import CommandCursor @@ -50,6 +50,7 @@ from pymongo.errors import (DocumentTooLarge, OperationFailure, WriteConcernError) from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command +from pymongo.mongo_client import MongoClient from pymongo.operations import * from pymongo.read_preferences import ReadPreference from pymongo.results import (InsertOneResult, @@ -66,13 +67,12 @@ from test import client_context, unittest class TestCollectionNoConnect(unittest.TestCase): + """Test Collection features on a client that does not connect. + """ @classmethod - @client_context.require_connection def setUpClass(cls): - client = MongoClient( - client_context.host, client_context.port, connect=False) - cls.db = client.pymongo_test + cls.db = MongoClient(connect=False).pymongo_test def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -92,12 +92,6 @@ class TestCollectionNoConnect(unittest.TestCase): self.assertRaises(InvalidName, make_col, self.db.test, "tes..t") self.assertRaises(InvalidName, make_col, self.db.test, "tes\x00t") - self.assertTrue(isinstance(self.db.test, Collection)) - self.assertEqual(self.db.test, self.db["test"]) - self.assertEqual(self.db.test, Collection(self.db, "test")) - self.assertEqual(self.db.test.mike, self.db["test.mike"]) - self.assertEqual(self.db.test["mike"], self.db["test.mike"]) - def test_getattr(self): coll = self.db.test self.assertTrue(isinstance(coll['_does_not_exist'], Collection)) @@ -138,10 +132,17 @@ class TestCollection(IntegrationTest): else: yield self.db.test + def test_equality(self): + self.assertTrue(isinstance(self.db.test, Collection)) + self.assertEqual(self.db.test, self.db["test"]) + self.assertEqual(self.db.test, Collection(self.db, "test")) + self.assertEqual(self.db.test.mike, self.db["test.mike"]) + self.assertEqual(self.db.test["mike"], self.db["test.mike"]) + @client_context.require_version_min(3, 3, 9) def test_create(self): # No Exception. - db = client_context.rs_or_standalone_client.pymongo_test + db = client_context.client.pymongo_test db.create_test_no_wc.drop() Collection(db, name='create_test_no_wc', create=True) with self.assertRaises(OperationFailure): @@ -810,7 +811,6 @@ class TestCollection(IntegrationTest): DocumentTooLarge, coll.delete_one, {'data': large}) @client_context.require_version_min(3, 1, 9, -1) - @client_context.require_no_auth def test_insert_bypass_document_validation(self): db = self.db db.test.drop() @@ -860,7 +860,6 @@ class TestCollection(IntegrationTest): bypass_document_validation=True) @client_context.require_version_min(3, 1, 9, -1) - @client_context.require_no_auth def test_replace_bypass_document_validation(self): db = self.db db.test.drop() @@ -901,7 +900,6 @@ class TestCollection(IntegrationTest): {"x": 1}, bypass_document_validation=True) @client_context.require_version_min(3, 1, 9, -1) - @client_context.require_no_auth def test_update_bypass_document_validation(self): db = self.db db.test.drop() @@ -980,7 +978,6 @@ class TestCollection(IntegrationTest): {"$inc": {"x": 1}}, bypass_document_validation=True) @client_context.require_version_min(3, 1, 9, -1) - @client_context.require_no_auth def test_bypass_document_validation_bulk_write(self): db = self.db db.test.drop() diff --git a/test/test_common.py b/test/test_common.py index 2d1811b48..6153d11d1 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -23,7 +23,6 @@ from bson.binary import UUIDLegacy, PYTHON_LEGACY, STANDARD from bson.code import Code from bson.codec_options import CodecOptions from bson.objectid import ObjectId -from pymongo.mongo_client import MongoClient from pymongo.errors import OperationFailure from pymongo.write_concern import WriteConcern from test import client_context, unittest, IntegrationTest @@ -165,10 +164,10 @@ class TestCommon(IntegrationTest): {"count": 0}, reduce)) def test_write_concern(self): - c = MongoClient(connect=False) + c = rs_or_single_client(connect=False) self.assertEqual(WriteConcern(), c.write_concern) - c = MongoClient(connect=False, w=2, wtimeout=1000) + c = rs_or_single_client(connect=False, w=2, wtimeout=1000) wc = WriteConcern(w=2, wtimeout=1000) self.assertEqual(wc, c.write_concern) @@ -199,24 +198,22 @@ class TestCommon(IntegrationTest): self.assertTrue(new_coll.insert_one(doc)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = MongoClient("mongodb://%s/" % (pair,), - replicaSet=client_context.replica_set_name) + m = rs_or_single_client("mongodb://%s/" % (pair,), + replicaSet=client_context.replica_set_name) coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert_one, doc) - m = MongoClient("mongodb://%s/?w=0" % (pair,), - replicaSet=client_context.replica_set_name) + m = rs_or_single_client("mongodb://%s/?w=0" % (pair,), + replicaSet=client_context.replica_set_name) coll = m.pymongo_test.write_concern_test coll.insert_one(doc) # Equality tests direct = connected(single_client(w=0)) - self.assertEqual(direct, - connected(MongoClient("mongodb://%s/?w=0" % (pair,)))) - - self.assertFalse(direct != - connected(MongoClient("mongodb://%s/?w=0" % (pair,)))) + direct2 = connected(single_client("mongodb://%s/?w=0" % (pair,))) + self.assertEqual(direct, direct2) + self.assertFalse(direct != direct2) if __name__ == "__main__": diff --git a/test/test_cursor.py b/test/test_cursor.py index a213ee2aa..b4c631980 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -24,8 +24,8 @@ sys.path[0:0] = [""] from bson.code import Code from bson.py3compat import PY3 from bson.son import SON -from pymongo import (MongoClient, - monitoring, +from pymongo import (monitoring, + MongoClient, ASCENDING, DESCENDING, ALL, @@ -38,7 +38,8 @@ from test import (client_context, SkipTest, unittest, IntegrationTest) -from test.utils import server_started_with_auth, single_client, EventListener +from test.utils import (rs_or_single_client, + EventListener) if PY3: long = int @@ -48,9 +49,7 @@ class TestCursorNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): - client = MongoClient( - client_context.host, client_context.port, connect=False) - cls.db = client.test + cls.db = MongoClient(connect=False).test def test_deepcopy_cursor_littered_with_regexes(self): cursor = self.db.test.find({ @@ -128,8 +127,15 @@ class TestCursorNoConnect(unittest.TestCase): cursor.remove_option(128) self.assertEqual(0, cursor._Cursor__query_flags) + +class TestCursor(IntegrationTest): + + def test_add_remove_option_exhaust(self): # Exhaust - which mongos doesn't support - if client_context.client is not None and not self.db.client.is_mongos: + if client_context.is_mongos: + with self.assertRaises(InvalidOperation): + self.db.test.find(cursor_type=CursorType.EXHAUST) + else: cursor = self.db.test.find(cursor_type=CursorType.EXHAUST) self.assertEqual(64, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(64) @@ -140,9 +146,6 @@ class TestCursorNoConnect(unittest.TestCase): self.assertEqual(0, cursor._Cursor__query_flags) self.assertFalse(cursor._Cursor__exhaust) - -class TestCursor(IntegrationTest): - @client_context.require_version_min(2, 5, 3, -1) def test_max_time_ms(self): db = self.db @@ -226,7 +229,7 @@ class TestCursor(IntegrationTest): listener.add_command_filter('killCursors') saved_listeners = monitoring._LISTENERS monitoring._LISTENERS = monitoring._Listeners([], [], [], []) - coll = single_client( + coll = rs_or_single_client( event_listeners=[listener])[self.db.name].pymongo_test results = listener.results @@ -312,6 +315,7 @@ class TestCursor(IntegrationTest): @client_context.require_version_min(2, 5, 3, -1) @client_context.require_test_commands + @client_context.require_no_mongos def test_max_time_ms_getmore(self): # Test that Cursor handles server timeout error in response to getmore. coll = self.db.pymongo_test @@ -1172,7 +1176,7 @@ class TestCursor(IntegrationTest): @client_context.require_no_mongos def test_comment(self): - if server_started_with_auth(self.db.client): + if client_context.auth_enabled: raise SkipTest("SERVER-4754 - This test uses profiling.") # MongoDB 3.1.5 changed the ns for commands. diff --git a/test/test_database.py b/test/test_database.py index 2ed5ce9d5..12cccfc02 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -29,8 +29,7 @@ from bson.dbref import DBRef from bson.objectid import ObjectId from bson.py3compat import string_type, text_type, PY3 from bson.son import SON -from pymongo import (MongoClient, - ALL, +from pymongo import (ALL, auth, OFF, SLOW_ONLY, @@ -43,6 +42,7 @@ from pymongo.errors import (CollectionInvalid, InvalidName, OperationFailure, WriteConcernError) +from pymongo.mongo_client import MongoClient from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern @@ -63,12 +63,12 @@ if PY3: class TestDatabaseNoConnect(unittest.TestCase): + """Test Database features on a client that does not connect. + """ @classmethod - @client_context.require_connection def setUpClass(cls): - cls.client = MongoClient( - client_context.host, client_context.port, connect=False) + cls.client = MongoClient(connect=False) def test_name(self): self.assertRaises(TypeError, Database, self.client, 4) @@ -79,23 +79,6 @@ class TestDatabaseNoConnect(unittest.TestCase): self.client, u"my\u0000db") self.assertEqual("name", Database(self.client, "name").name) - def test_equality(self): - self.assertNotEqual(Database(self.client, "test"), - Database(self.client, "mike")) - self.assertEqual(Database(self.client, "test"), - Database(self.client, "test")) - - # Explicitly test inequality - self.assertFalse(Database(self.client, "test") != - Database(self.client, "test")) - - def test_get_coll(self): - db = Database(self.client, "pymongo_test") - self.assertEqual(db.test, db["test"]) - self.assertEqual(db.test, Collection(db, "test")) - self.assertNotEqual(db.test, Collection(db, "mike")) - self.assertEqual(db.test.mike, db["test.mike"]) - def test_get_collection(self): codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) @@ -128,6 +111,23 @@ class TestDatabaseNoConnect(unittest.TestCase): class TestDatabase(IntegrationTest): + def test_equality(self): + self.assertNotEqual(Database(self.client, "test"), + Database(self.client, "mike")) + self.assertEqual(Database(self.client, "test"), + Database(self.client, "test")) + + # Explicitly test inequality + self.assertFalse(Database(self.client, "test") != + Database(self.client, "test")) + + def test_get_coll(self): + db = Database(self.client, "pymongo_test") + self.assertEqual(db.test, db["test"]) + self.assertEqual(db.test, Collection(db, "test")) + self.assertNotEqual(db.test, Collection(db, "mike")) + self.assertEqual(db.test.mike, db["test.mike"]) + def test_repr(self): self.assertEqual(repr(Database(self.client, "pymongo_test")), "Database(%r, %s)" % (self.client, @@ -818,6 +818,7 @@ class TestDatabase(IntegrationTest): @client_context.require_version_min(2, 5, 3, -1) @client_context.require_test_commands + @client_context.require_no_mongos def test_command_max_time_ms(self): self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 5ad9fa08f..76f12e0dd 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -34,20 +34,18 @@ from gridfs.errors import NoFile from pymongo import MongoClient from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError from test import (IntegrationTest, - client_context, unittest, qcheck) from test.utils import rs_or_single_client class TestGridFileNoConnect(unittest.TestCase): + """Test GridFile features on a client that does not connect. + """ @classmethod - @client_context.require_connection def setUpClass(cls): - client = MongoClient( - client_context.host, client_context.port, connect=False) - cls.db = client.pymongo_test + cls.db = MongoClient(connect=False).pymongo_test def test_grid_in_custom_opts(self): self.assertRaises(TypeError, GridIn, "foo") @@ -73,17 +71,6 @@ class TestGridFileNoConnect(unittest.TestCase): self.assertEqual(1000, b.chunk_size) self.assertEqual(100, b.baz) - def test_grid_out_cursor_options(self): - self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {}, - projection={"filename": 1}) - - cursor = GridOutCursor(self.db.fs, {}) - cursor_clone = cursor.clone() - self.assertEqual(cursor_clone.__dict__, cursor.__dict__) - - self.assertRaises(NotImplementedError, cursor.add_option, 0) - self.assertRaises(NotImplementedError, cursor.remove_option, 0) - class TestGridFile(IntegrationTest): @@ -239,6 +226,17 @@ class TestGridFile(IntegrationTest): "upload_date", "aliases", "metadata", "md5"]: self.assertRaises(AttributeError, setattr, b, attr, 5) + def test_grid_out_cursor_options(self): + self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {}, + projection={"filename": 1}) + + cursor = GridOutCursor(self.db.fs, {}) + cursor_clone = cursor.clone() + self.assertEqual(cursor_clone.__dict__, cursor.__dict__) + + self.assertRaises(NotImplementedError, cursor.add_option, 0) + self.assertRaises(NotImplementedError, cursor.remove_option, 0) + def test_grid_out_custom_opts(self): one = GridIn(self.db.fs, _id=5, filename="my_file", contentType="text/html", chunkSize=1000, aliases=["foo"], @@ -608,7 +606,6 @@ Bye""")) self.assertRaises(ServerSelectionTimeoutError, infile.write, b'data') self.assertRaises(ServerSelectionTimeoutError, infile.close) - def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 07cf5c3bd..8d8fc4e2f 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -79,9 +79,7 @@ class TestGridfsNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): - client = MongoClient( - client_context.host, client_context.port, connect=False) - cls.db = client.pymongo_test + cls.db = MongoClient(connect=False).pymongo_test def test_gridfs(self): self.assertRaises(TypeError, gridfs.GridFS, "foo") @@ -512,9 +510,8 @@ class TestGridfsReplicaSet(TestReplicaSetClientBase): self.assertRaises(ConnectionFailure, fs.put, 'data') def tearDown(self): - rsc = client_context.rs_client - rsc.pymongo_test.drop_collection('fs.files') - rsc.pymongo_test.drop_collection('fs.chunks') + self.client.pymongo_test.drop_collection('fs.files') + self.client.pymongo_test.drop_collection('fs.chunks') if __name__ == "__main__": diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 51d37511a..a22d61ad2 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -28,20 +28,18 @@ from bson.py3compat import StringIO, string_type from gridfs.errors import NoFile, CorruptGridFile from pymongo.errors import (ConfigurationError, ConnectionFailure, - ServerSelectionTimeoutError, - OperationFailure) + ServerSelectionTimeoutError) from pymongo.mongo_client import MongoClient from pymongo.read_preferences import ReadPreference from test import (client_context, + unittest, IntegrationTest) from test.test_replica_set_client import TestReplicaSetClientBase from test.utils import (joinall, single_client, one, rs_client, - rs_or_single_client, - rs_or_single_client_noauth, - remove_all_users) + rs_or_single_client) class JustWrite(threading.Thread): @@ -496,9 +494,8 @@ class TestGridfsBucketReplicaSet(TestReplicaSetClientBase): "test_filename", b'data') def tearDown(self): - rsc = client_context.rs_client - rsc.pymongo_test.drop_collection('fs.files') - rsc.pymongo_test.drop_collection('fs.chunks') + self.client.pymongo_test.drop_collection('fs.files') + self.client.pymongo_test.drop_collection('fs.chunks') if __name__ == "__main__": diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 013eb284d..2568d2373 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -28,7 +28,10 @@ from pymongo.errors import NotMasterError, OperationFailure from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern from test import unittest, client_context, client_knobs -from test.utils import single_client, wait_until, EventListener +from test.utils import (EventListener, + rs_or_single_client, + single_client, + wait_until) class TestCommandMonitoring(unittest.TestCase): @@ -40,7 +43,7 @@ class TestCommandMonitoring(unittest.TestCase): cls.saved_listeners = monitoring._LISTENERS # Don't use any global subscribers. monitoring._LISTENERS = monitoring._Listeners([], [], [], []) - cls.client = single_client(event_listeners=[cls.listener]) + cls.client = rs_or_single_client(event_listeners=[cls.listener]) @classmethod def tearDownClass(cls): @@ -384,7 +387,7 @@ class TestCommandMonitoring(unittest.TestCase): @client_context.require_replica_set def test_not_master_error(self): - address = next(iter(client_context.rs_client.secondaries)) + address = next(iter(client_context.client.secondaries)) client = single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. client.admin.command('ismaster') diff --git a/test/test_pooling.py b/test/test_pooling.py index 177d38e7e..4e2a5f775 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -166,6 +166,10 @@ class _TestPoolingBase(unittest.TestCase): pair=(client_context.host, client_context.port), *args, **kwargs): + # Start the pool with the correct ssl options. + pool_options = client_context.client._topology_settings.pool_options + kwargs['ssl_context'] = pool_options.ssl_context + kwargs['ssl_match_hostname'] = pool_options.ssl_match_hostname return Pool(pair, PoolOptions(*args, **kwargs)) @@ -173,14 +177,12 @@ class TestPooling(_TestPoolingBase): def test_max_pool_size_validation(self): host, port = client_context.host, client_context.port self.assertRaises( - ValueError, MongoClient, host=host, port=port, - maxPoolSize=-1) + ValueError, MongoClient, host=host, port=port, maxPoolSize=-1) self.assertRaises( - ValueError, MongoClient, host=host, port=port, - maxPoolSize='foo') + ValueError, MongoClient, host=host, port=port, maxPoolSize='foo') - c = MongoClient(host=host, port=port, maxPoolSize=100) + c = MongoClient(host=host, port=port, maxPoolSize=100, connect=False) self.assertEqual(c.max_pool_size, 100) def test_no_disconnect(self): diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 25cd0b760..b81f04fb8 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -22,7 +22,7 @@ class TestRawBSONDocument(unittest.TestCase): @classmethod def setUpClass(cls): - cls.client = client_context.rs_or_standalone_client + cls.client = client_context.client def tearDown(self): if client_context.connected: diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 20ac21a04..1eaa34c4e 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -22,7 +22,7 @@ from pymongo.errors import ConfigurationError, OperationFailure from pymongo.read_concern import ReadConcern from test import client_context, unittest -from test.utils import single_client, EventListener +from test.utils import single_client, rs_or_single_client, EventListener class TestReadConcern(unittest.TestCase): @@ -61,8 +61,9 @@ class TestReadConcern(unittest.TestCase): self.assertRaises(TypeError, ReadConcern, 42) def test_read_concern_uri(self): - uri = 'mongodb://%s/?readConcernLevel=majority' % (client_context.pair,) - client = pymongo.MongoClient(uri) + uri = 'mongodb://%s/?readConcernLevel=majority' % ( + client_context.pair,) + client = rs_or_single_client(uri, connect=False) self.assertEqual(ReadConcern('majority'), client.read_concern) @client_context.require_version_max(3, 1) diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 22d3cae2e..5b024c665 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -47,6 +47,8 @@ from test.version import Version class TestSelections(unittest.TestCase): + + @client_context.require_connection def test_bool(self): client = single_client() @@ -293,7 +295,9 @@ class TestReadPreferences(TestReadPreferencesBase): class ReadPrefTester(MongoClient): def __init__(self, *args, **kwargs): self.has_read_from = set() - super(ReadPrefTester, self).__init__(*args, **kwargs) + client_options = client_context.ssl_client_options.copy() + client_options.update(kwargs) + super(ReadPrefTester, self).__init__(*args, **client_options) @contextlib.contextmanager def _socket_for_reads(self, read_preference): diff --git a/test/test_replica_set_client.py b/test/test_replica_set_client.py index 83de8990e..fc937f0e8 100644 --- a/test/test_replica_set_client.py +++ b/test/test_replica_set_client.py @@ -38,6 +38,7 @@ from test import (client_context, client_knobs, IntegrationTest, unittest, + SkipTest, db_pwd, db_user, MockClientTest) @@ -106,7 +107,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase): self.assertIn(client_context.pair, repr(client)) def test_properties(self): - c = client_context.rs_client + c = client_context.client c.admin.command('ping') wait_until(lambda: c.primary == self.primary, "discover primary") @@ -132,9 +133,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase): tag_sets = [{'dc': 'la', 'rack': '2'}, {'foo': 'bar'}] secondary = Secondary(tag_sets=tag_sets) - c = MongoClient( - client_context.pair, - replicaSet=self.name, + c = rs_client( maxPoolSize=25, document_class=SON, tz_aware=True, @@ -213,19 +212,17 @@ class TestReplicaSetClient(TestReplicaSetClientBase): # No error. coll.find_one() - @client_context.require_replica_set @client_context.require_ipv6 def test_ipv6(self): port = client_context.port - c = MongoClient("mongodb://[::1]:%d" % (port,), replicaSet=self.name) + c = rs_client("mongodb://[::1]:%d" % (port,)) # Client switches to IPv4 once it has first ismaster response. msg = 'discovered primary with IPv4 address "%r"' % (self.primary,) wait_until(lambda: c.primary == self.primary, msg) # Same outcome with both IPv4 and IPv6 seeds. - c = MongoClient("[::1]:%d,localhost:%d" % (port, port), - replicaSet=self.name) + c = rs_client("[::1]:%d,localhost:%d" % (port, port)) wait_until(lambda: c.primary == self.primary, msg) @@ -235,7 +232,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase): auth_str = "" uri = "mongodb://%slocalhost:%d,[::1]:%d" % (auth_str, port, port) - client = MongoClient(uri, replicaSet=self.name) + client = rs_client(uri) client.pymongo_test.test.insert_one({"dummy": u"object"}) client.pymongo_test_bernie.test.insert_one({"dummy": u"object"}) diff --git a/test/test_ssl.py b/test/test_ssl.py index c40c88bfd..1f7917e2d 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -157,9 +157,18 @@ class TestSSL(IntegrationTest): self.assertTrue(coll.find_one()['ssl']) coll.drop() + @classmethod @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") - def setUp(self): - super(TestSSL, self).setUp() + def setUpClass(cls): + super(TestSSL, cls).setUpClass() + # MongoClient should connect to the primary by default. + cls.saved_port = MongoClient.PORT + MongoClient.PORT = client_context.port + + @classmethod + def tearDownClass(cls): + MongoClient.PORT = cls.saved_port + super(TestSSL, cls).tearDownClass() @client_context.require_ssl def test_simple_ssl(self): @@ -259,7 +268,6 @@ class TestSSL(IntegrationTest): ssl_ca_certs=CA_PEM) self.assertClientWorks(client) - @client_context.require_ssl_certfile @client_context.require_server_resolvable @client_context.require_no_auth diff --git a/test/test_threads.py b/test/test_threads.py index 2477b9829..f1c39b059 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -131,7 +131,7 @@ class Disconnect(threading.Thread): class TestThreads(IntegrationTest): def setUp(self): - self.db = client_context.rs_or_standalone_client.pymongo_test + self.db = self.client.pymongo_test def test_threading(self): self.db.drop_collection("test") diff --git a/test/utils.py b/test/utils.py index f74fe6bec..a8b08ad11 100644 --- a/test/utils.py +++ b/test/utils.py @@ -98,45 +98,51 @@ class HeartbeatEventListener(monitoring.ServerHeartbeatListener): self.results.append(event) -def _connection_string_noauth(h, p): +def _connection_string(h, p, authenticate): if h.startswith("mongodb://"): return h - return "mongodb://%s:%d" % (h, p) - - -def _connection_string(h, p): - if h.startswith("mongodb://"): - return h - elif client_context.auth_enabled: + elif client_context.auth_enabled and authenticate: return "mongodb://%s:%s@%s:%d" % (db_user, db_pwd, h, p) else: - return _connection_string_noauth(h, p) + return "mongodb://%s:%d" % (h, p) + + +def _mongo_client(host, port, authenticate=True, direct=False, **kwargs): + """Create a new client over SSL/TLS if necessary.""" + client_options = client_context.ssl_client_options.copy() + + if client_context.replica_set_name and not direct: + client_options['replicaSet'] = client_context.replica_set_name + client_options.update(kwargs) + + client = MongoClient(_connection_string(host, port, authenticate), port, + **client_options) + + return client def single_client_noauth( h=client_context.host, p=client_context.port, **kwargs): """Make a direct connection. Don't authenticate.""" - return MongoClient(_connection_string_noauth(h, p), p, **kwargs) + return _mongo_client(h, p, authenticate=False, direct=True, **kwargs) def single_client( h=client_context.host, p=client_context.port, **kwargs): """Make a direct connection, and authenticate if necessary.""" - return MongoClient(_connection_string(h, p), p, **kwargs) + return _mongo_client(h, p, direct=True, **kwargs) def rs_client_noauth( h=client_context.host, p=client_context.port, **kwargs): """Connect to the replica set. Don't authenticate.""" - return MongoClient(_connection_string_noauth(h, p), p, - replicaSet=client_context.replica_set_name, **kwargs) + return _mongo_client(h, p, authenticate=False, **kwargs) def rs_client( h=client_context.host, p=client_context.port, **kwargs): """Connect to the replica set and authenticate if necessary.""" - return MongoClient(_connection_string(h, p), p, - replicaSet=client_context.replica_set_name, **kwargs) + return _mongo_client(h, p, **kwargs) def rs_or_single_client_noauth( @@ -145,10 +151,7 @@ def rs_or_single_client_noauth( Like rs_or_single_client, but does not authenticate. """ - if client_context.replica_set_name: - return rs_client_noauth(h, p, **kwargs) - else: - return single_client_noauth(h, p, **kwargs) + return _mongo_client(h, p, authenticate=False, **kwargs) def rs_or_single_client( @@ -157,10 +160,7 @@ def rs_or_single_client( Authenticates if necessary. """ - if client_context.replica_set_name: - return rs_client(h, p, **kwargs) - else: - return single_client(h, p, **kwargs) + return _mongo_client(h, p, **kwargs) def one(s): @@ -319,9 +319,7 @@ def enable_text_search(client): 'setParameter', textSearchEnabled=True) for host, port in client.secondaries: - client = MongoClient(host, port) - if client_context.auth_enabled: - client.admin.authenticate(db_user, db_pwd) + client = single_client(host, port) client.admin.command('setParameter', textSearchEnabled=True)