PYTHON-1075 Support running the entire test suite with SSL/TLS

SSL connections are configurable via the environment variables
"CLIENT_PEM", "CA_PEM", and "CERT_REQS".
This commit is contained in:
Shane Harvey 2016-08-30 15:48:05 -07:00
parent bb6cd59525
commit 5905a86785
22 changed files with 280 additions and 286 deletions

View File

@ -35,7 +35,7 @@ import pymongo.errors
from bson.py3compat import _unicode
from pymongo import common
from pymongo.ssl_support import HAVE_SSL
from pymongo.ssl_support import HAVE_SSL, validate_cert_reqs
from test.version import Version
if HAVE_SSL:
@ -53,7 +53,18 @@ db_pwd = _unicode(os.environ.get("DB_PASSWORD", "password"))
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'certificates')
CLIENT_PEM = os.path.join(CERT_PATH, 'client.pem')
CLIENT_PEM = os.environ.get('CLIENT_PEM',
os.path.join(CERT_PATH, 'client.pem'))
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))
CERT_REQS = validate_cert_reqs('CERT_REQS', os.environ.get('CERT_REQS'))
_SSL_OPTIONS = dict(ssl=True)
if CLIENT_PEM:
_SSL_OPTIONS['ssl_certfile'] = CLIENT_PEM
if CA_PEM:
_SSL_OPTIONS['ssl_ca_certs'] = CA_PEM
if CERT_REQS is not None:
_SSL_OPTIONS['ssl_cert_reqs'] = CERT_REQS
def is_server_resolvable():
@ -70,6 +81,17 @@ def is_server_resolvable():
socket.setdefaulttimeout(socket_timeout)
def _connect(host, port, **kwargs):
try:
client = pymongo.MongoClient(
host, port, serverSelectionTimeoutMS=100, **kwargs)
client.admin.command('ismaster') # Can we connect?
# If connected, then return client with default timeout
return pymongo.MongoClient(host, port, **kwargs)
except pymongo.errors.ConnectionFailure:
return None
class client_knobs(object):
def __init__(
self,
@ -129,7 +151,6 @@ class ClientContext(object):
self.w = None
self.nodes = set()
self.replica_set_name = None
self.rs_client = None
self.cmd_line = None
self.version = Version(-1) # Needs to be comparable with Version
self.auth_enabled = False
@ -137,40 +158,22 @@ class ClientContext(object):
self.is_mongos = False
self.is_rs = False
self.has_ipv6 = False
self.ssl = False
self.ssl_cert_none = False
self.ssl_certfile = False
self.server_is_resolvable = is_server_resolvable()
self.client = self.rs_or_standalone_client = None
def connect(**kwargs):
try:
client = pymongo.MongoClient(
self.host,
self.port,
serverSelectionTimeoutMS=100,
**kwargs)
client.admin.command('ismaster') # Can we connect?
# If connected, then return client with default timeout
return pymongo.MongoClient(self.host, self.port, **kwargs)
except pymongo.errors.ConnectionFailure:
return None
self.client = connect()
self.ssl_client_options = {}
self.client = _connect(self.host, self.port)
if HAVE_SSL and not self.client:
# Is MongoDB configured for SSL?
self.client = connect(ssl=True, ssl_cert_reqs=ssl.CERT_NONE)
self.client = _connect(self.host, self.port, **_SSL_OPTIONS)
if self.client:
self.ssl_cert_none = True
# Can client connect with certfile?
client = connect(ssl=True,
ssl_cert_reqs=ssl.CERT_NONE,
ssl_certfile=CLIENT_PEM,)
if client:
self.ssl = True
self.ssl_client_options = _SSL_OPTIONS
self.ssl_certfile = True
self.client = client
if _SSL_OPTIONS.get('ssl_cert_reqs') == ssl.CERT_NONE:
self.ssl_cert_none = True
if self.client:
self.connected = True
@ -178,16 +181,16 @@ class ClientContext(object):
self.w = len(self.ismaster.get("hosts", [])) or 1
self.nodes = set([(self.host, self.port)])
self.replica_set_name = self.ismaster.get('setName', '')
self.rs_client = None
self.version = Version.from_client(self.client)
if self.replica_set_name:
self.is_rs = True
self.rs_client = pymongo.MongoClient(
self.ismaster['primary'], replicaSet=self.replica_set_name)
self.client = pymongo.MongoClient(
self.ismaster['primary'],
replicaSet=self.replica_set_name,
**self.ssl_client_options)
# Force connection
self.rs_client.admin.command('ismaster')
self.host, self.port = self.rs_client.primary
self.client = connect()
self.client.admin.command('ismaster')
self.host, self.port = self.client.primary
nodes = [partition_node(node.lower())
for node in self.ismaster.get('hosts', [])]
@ -197,8 +200,6 @@ class ClientContext(object):
for node in self.ismaster.get('arbiters', [])])
self.nodes = set(nodes)
self.rs_or_standalone_client = self.rs_client or self.client
try:
self.cmd_line = self.client.admin.command('getCmdLineOpts')
except pymongo.errors.OperationFailure as e:
@ -221,9 +222,6 @@ class ClientContext(object):
self.client.admin.add_user(db_user, db_pwd, **roles)
self.client.admin.authenticate(db_user, db_pwd)
if self.rs_client:
self.rs_client.admin.authenticate(db_user, db_pwd)
# May not have this if OperationFailure was raised earlier.
self.cmd_line = self.client.admin.command('getCmdLineOpts')
@ -233,12 +231,17 @@ class ClientContext(object):
params = self.cmd_line['parsed'].get('setParameter', [])
if 'enableTestCommands=1' in params:
self.test_commands_enabled = True
else:
params = self.cmd_line['parsed'].get('setParameter', {})
if params.get('enableTestCommands') == '1':
self.test_commands_enabled = True
self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid')
self.has_ipv6 = self._server_started_with_ipv6()
# Do this after we connect so we know who the primary is.
self.pair = "%s:%d" % (self.host, self.port)
@property
def pair(self):
return "%s:%d" % (self.host, self.port)
def _check_user_provided(self):
try:
@ -391,13 +394,13 @@ class ClientContext(object):
def require_ssl(self, func):
"""Run a test only if the client can connect over SSL."""
return self._require(self.ssl_cert_none or self.ssl_certfile,
return self._require(self.ssl,
"Must be able to connect via SSL",
func=func)
def require_no_ssl(self, func):
"""Run a test only if the client can connect over SSL."""
return self._require(not (self.ssl_cert_none or self.ssl_certfile),
return self._require(not self.ssl,
"Must be able to connect without SSL",
func=func)
@ -430,7 +433,7 @@ class IntegrationTest(unittest.TestCase):
@classmethod
@client_context.require_connection
def setUpClass(cls):
cls.client = client_context.rs_or_standalone_client
cls.client = client_context.client
cls.db = cls.client.pymongo_test

View File

@ -23,7 +23,7 @@ from pymongo import MongoClient
from pymongo.errors import AutoReconnect, NetworkTimeout
from pymongo.ismaster import IsMaster
from pymongo.monitor import Monitor
from pymongo.pool import Pool, PoolOptions
from pymongo.pool import Pool
from pymongo.server_description import ServerDescription
from test import client_context
@ -39,9 +39,7 @@ class MockPool(Pool):
self.mock_host, self.mock_port = pair
# Actually connect to the default server.
Pool.__init__(self,
(default_host, default_port),
PoolOptions(connect_timeout=20))
Pool.__init__(self, (default_host, default_port), *args, **kwargs)
@contextlib.contextmanager
def get_socket(self, all_credentials, checkout=False):
@ -125,7 +123,10 @@ class MockClient(MongoClient):
kwargs['_pool_class'] = partial(MockPool, self)
kwargs['_monitor_class'] = partial(MockMonitor, self)
super(MockClient, self).__init__(*args, **kwargs)
client_options = client_context.ssl_client_options.copy()
client_options.update(kwargs)
super(MockClient, self).__init__(*args, **client_options)
def kill_host(self, host):
"""Host is like 'a:1'."""

View File

@ -31,7 +31,7 @@ from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple
from pymongo.errors import OperationFailure
from pymongo.read_preferences import ReadPreference
from test import client_context, SkipTest, unittest, Version
from test.utils import delay
from test.utils import delay, rs_or_single_client_noauth, single_client_noauth
# YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX.
GSSAPI_HOST = os.environ.get('GSSAPI_HOST')
@ -305,14 +305,11 @@ class TestSASLPlain(unittest.TestCase):
self.assertRaises(OperationFailure, bad_pwd.admin.command, 'ismaster')
class TestSCRAMSHA1(unittest.TestCase):
@client_context.require_auth
@client_context.require_version_min(2, 7, 2)
def setUp(self):
self.replica_set_name = client_context.replica_set_name
# Before 2.7.7, SCRAM-SHA-1 had to be enabled from the command line.
if client_context.version < Version(2, 7, 7):
cmd_line = client_context.cmd_line
@ -321,7 +318,7 @@ class TestSCRAMSHA1(unittest.TestCase):
{}).get('authenticationMechanisms', ''):
raise SkipTest('SCRAM-SHA-1 mechanism not enabled')
client = client_context.rs_or_standalone_client
client = client_context.client
client.pymongo_test.add_user(
'user', 'pass',
roles=['userAdmin', 'readWrite'],
@ -329,42 +326,36 @@ class TestSCRAMSHA1(unittest.TestCase):
def test_scram_sha1(self):
host, port = client_context.host, client_context.port
client = MongoClient(host, port)
client = rs_or_single_client_noauth()
self.assertTrue(client.pymongo_test.authenticate(
'user', 'pass', mechanism='SCRAM-SHA-1'))
client.pymongo_test.command('dbstats')
client = MongoClient('mongodb://user:pass@%s:%d/pymongo_test'
'?authMechanism=SCRAM-SHA-1' % (host, port))
client = rs_or_single_client_noauth(
'mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1'
% (host, port))
client.pymongo_test.command('dbstats')
if self.replica_set_name:
client = MongoClient(host, port,
replicaSet='%s' % (self.replica_set_name,))
self.assertTrue(client.pymongo_test.authenticate(
'user', 'pass', mechanism='SCRAM-SHA-1'))
client.pymongo_test.command('dbstats')
if client_context.is_rs:
uri = ('mongodb://user:pass'
'@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1'
'&replicaSet=%s' % (host, port, self.replica_set_name))
client = MongoClient(uri)
'&replicaSet=%s' % (host, port,
client_context.replica_set_name))
client = single_client_noauth(uri)
client.pymongo_test.command('dbstats')
db = client.get_database(
'pymongo_test', read_preference=ReadPreference.SECONDARY)
db.command('dbstats')
def tearDown(self):
client_context.rs_or_standalone_client.pymongo_test.remove_user('user')
client_context.client.pymongo_test.remove_user('user')
class TestAuthURIOptions(unittest.TestCase):
@client_context.require_auth
def setUp(self):
client = MongoClient(client_context.host, client_context.port)
response = client.admin.command('ismaster')
self.replica_set_name = str(response.get('setName', ''))
client = rs_or_single_client_noauth()
client_context.client.admin.add_user('admin', 'pass',
roles=['userAdminAnyDatabase',
'dbAdminAnyDatabase',
@ -374,12 +365,10 @@ class TestAuthURIOptions(unittest.TestCase):
client.pymongo_test.add_user('user', 'pass',
roles=['userAdmin', 'readWrite'])
if self.replica_set_name:
# GLE requires authentication.
client.admin.authenticate('admin', 'pass')
if client_context.is_rs:
# Make sure the admin user is replicated after calling add_user
# above. This avoids a race in the replica set tests below.
client.admin.command('getLastError', w=len(response['hosts']))
client.admin.command('getLastError', w=client_context.w)
self.client = client
def tearDown(self):
@ -393,14 +382,14 @@ class TestAuthURIOptions(unittest.TestCase):
def test_uri_options(self):
# Test default to admin
host, port = client_context.host, client_context.port
client = MongoClient(
client = rs_or_single_client_noauth(
'mongodb://admin:pass@%s:%d' % (host, port))
self.assertTrue(client.admin.command('dbstats'))
if self.replica_set_name:
uri = ('mongodb://admin:pass'
'@%s:%d/?replicaSet=%s' % (host, port, self.replica_set_name))
client = MongoClient(uri)
if client_context.is_rs:
uri = ('mongodb://admin:pass@%s:%d/?replicaSet=%s' % (
host, port, client_context.replica_set_name))
client = single_client_noauth(uri)
self.assertTrue(client.admin.command('dbstats'))
db = client.get_database(
'admin', read_preference=ReadPreference.SECONDARY)
@ -408,14 +397,14 @@ class TestAuthURIOptions(unittest.TestCase):
# Test explicit database
uri = 'mongodb://user:pass@%s:%d/pymongo_test' % (host, port)
client = MongoClient(uri)
client = rs_or_single_client_noauth(uri)
self.assertRaises(OperationFailure, client.admin.command, 'dbstats')
self.assertTrue(client.pymongo_test.command('dbstats'))
if self.replica_set_name:
uri = ('mongodb://user:pass@%s:%d'
'/pymongo_test?replicaSet=%s' % (host, port, self.replica_set_name))
client = MongoClient(uri)
if client_context.is_rs:
uri = ('mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s' % (
host, port, client_context.replica_set_name))
client = single_client_noauth(uri)
self.assertRaises(OperationFailure,
client.admin.command, 'dbstats')
self.assertTrue(client.pymongo_test.command('dbstats'))
@ -426,15 +415,16 @@ class TestAuthURIOptions(unittest.TestCase):
# Test authSource
uri = ('mongodb://user:pass@%s:%d'
'/pymongo_test2?authSource=pymongo_test' % (host, port))
client = MongoClient(uri)
client = rs_or_single_client_noauth(uri)
self.assertRaises(OperationFailure,
client.pymongo_test2.command, 'dbstats')
self.assertTrue(client.pymongo_test.command('dbstats'))
if self.replica_set_name:
if client_context.is_rs:
uri = ('mongodb://user:pass@%s:%d/pymongo_test2?replicaSet='
'%s;authSource=pymongo_test' % (host, port, self.replica_set_name))
client = MongoClient(uri)
'%s;authSource=pymongo_test' % (
host, port, client_context.replica_set_name))
client = single_client_noauth(uri)
self.assertRaises(OperationFailure,
client.pymongo_test2.command, 'dbstats')
self.assertTrue(client.pymongo_test.command('dbstats'))
@ -449,7 +439,7 @@ class TestDelegatedAuth(unittest.TestCase):
@client_context.require_version_max(2, 5, 3)
@client_context.require_version_min(2, 4, 0)
def setUp(self):
self.client = client_context.rs_or_standalone_client
self.client = client_context.client
def tearDown(self):
self.client.pymongo_test.remove_user('user')
@ -465,7 +455,7 @@ class TestDelegatedAuth(unittest.TestCase):
self.client.pymongo_test2.add_user('user',
userSource='pymongo_test',
roles=['read'])
auth_c = MongoClient(client_context.host, client_context.port)
auth_c = rs_or_single_client_noauth()
self.assertRaises(OperationFailure,
auth_c.pymongo_test2.foo.find_one)
# Auth must occur on the db where the user is defined.

View File

@ -21,7 +21,6 @@ sys.path[0:0] = [""]
from bson import InvalidDocument, SON
from bson.objectid import ObjectId
from bson.py3compat import string_type
from pymongo import MongoClient
from pymongo.operations import *
from pymongo.common import partition_node
from pymongo.errors import (BulkWriteError,
@ -33,7 +32,11 @@ from test import (client_context,
unittest,
IntegrationTest,
SkipTest)
from test.utils import oid_generated_on_client, remove_all_users, wait_until
from test.utils import (oid_generated_on_client,
remove_all_users,
rs_or_single_client_noauth,
single_client,
wait_until)
class BulkTestBase(IntegrationTest):
@ -137,7 +140,6 @@ class TestBulk(BulkTestBase):
bulk.find({})
@client_context.require_version_min(3, 1, 9, -1)
@client_context.require_no_auth
def test_bypass_document_validation_bulk_op(self):
# Test insert
@ -971,7 +973,7 @@ class TestBulkWriteConcern(BulkTestBase):
if cls.w > 1:
for member in client_context.ismaster['hosts']:
if member != client_context.ismaster['primary']:
cls.secondary = MongoClient(*partition_node(member))
cls.secondary = single_client(*partition_node(member))
break
# We tested wtimeout errors by specifying a write concern greater than
@ -1243,7 +1245,7 @@ class TestBulkAuthorization(BulkTestBase):
def test_readonly(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
cli = MongoClient(client_context.host, client_context.port)
cli = rs_or_single_client_noauth()
db = cli.pymongo_test
coll = db.test
db.authenticate('readonly', 'pw')
@ -1254,7 +1256,7 @@ class TestBulkAuthorization(BulkTestBase):
def test_no_remove(self):
# We test that an authorization failure aborts the batch and is raised
# as OperationFailure.
cli = MongoClient(client_context.host, client_context.port)
cli = rs_or_single_client_noauth()
db = cli.pymongo_test
coll = db.test
db.authenticate('noremove', 'pw')

View File

@ -38,9 +38,9 @@ from pymongo.errors import (AutoReconnect,
ConfigurationError,
ConnectionFailure,
InvalidName,
OperationFailure,
NetworkTimeout,
InvalidURI,
NetworkTimeout,
OperationFailure,
WriteConcernError)
from pymongo.monitoring import (ServerHeartbeatListener,
ServerHeartbeatStartedEvent)
@ -68,6 +68,7 @@ from test.utils import (assertRaisesExactly,
one,
connected,
wait_until,
rs_client,
rs_or_single_client,
rs_or_single_client_noauth,
single_client,
@ -81,11 +82,8 @@ class ClientUnitTest(unittest.TestCase):
@classmethod
@client_context.require_connection
def setUpClass(cls):
cls.client = MongoClient(
client_context.host,
client_context.port,
connect=False,
serverSelectionTimeoutMS=100)
cls.client = rs_or_single_client(connect=False,
serverSelectionTimeoutMS=100)
def test_keyword_arg_defaults(self):
client = MongoClient(socketTimeoutMS=None,
@ -173,23 +171,23 @@ class ClientUnitTest(unittest.TestCase):
self.assertRaises(TypeError, iterate)
def test_get_default_database(self):
c = MongoClient(
"mongodb://%s:%d/foo" % (client_context.host, client_context.port),
connect=False)
c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host,
client_context.port),
connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_get_default_database_error(self):
# URI with no database.
c = MongoClient(
"mongodb://%s:%d/" % (client_context.host, client_context.port),
connect=False)
c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host,
client_context.port),
connect=False)
self.assertRaises(ConfigurationError, c.get_default_database)
def test_get_default_database_with_authsource(self):
# Ensure we distinguish database name from authSource.
uri = "mongodb://%s:%d/foo?authSource=src" % (
client_context.host, client_context.port)
c = MongoClient(uri, connect=False)
c = rs_or_single_client(uri, connect=False)
self.assertEqual(Database(c, 'foo'), c.get_default_database())
def test_primary_read_pref_with_tags(self):
@ -265,10 +263,9 @@ class ClientUnitTest(unittest.TestCase):
class TestClient(IntegrationTest):
def test_max_idle_time_reaper(self):
host, port = client_context.host, client_context.port
with client_knobs(kill_cursor_frequency=0.1):
# Assert reaper doesn't remove sockets when maxIdleTimeMS not set
client = MongoClient(host, port)
client = rs_or_single_client()
server = client._get_topology().select_server(any_server_selector)
with server._pool.get_socket({}) as sock_info:
pass
@ -276,7 +273,8 @@ class TestClient(IntegrationTest):
self.assertTrue(sock_info in server._pool.sockets)
# Assert reaper removes idle socket and replaces it with a new one
client = MongoClient(host, port, maxIdleTimeMS=.5, minPoolSize=1)
client = rs_or_single_client(maxIdleTimeMS=.5,
minPoolSize=1)
server = client._get_topology().select_server(any_server_selector)
with server._pool.get_socket({}) as sock_info:
pass
@ -287,7 +285,7 @@ class TestClient(IntegrationTest):
"reaper replaces stale socket with new one")
# Assert reaper has removed idle socket and NOT replaced it
client = MongoClient(host, port, maxIdleTimeMS=.5)
client = rs_or_single_client(maxIdleTimeMS=.5)
server = client._get_topology().select_server(any_server_selector)
with server._pool.get_socket({}):
pass
@ -296,14 +294,13 @@ class TestClient(IntegrationTest):
"stale socket reaped and new one NOT added to the pool")
def test_min_pool_size(self):
host, port = client_context.host, client_context.port
with client_knobs(kill_cursor_frequency=.1):
client = MongoClient(host, port)
client = rs_or_single_client()
server = client._get_topology().select_server(any_server_selector)
self.assertEqual(0, len(server._pool.sockets))
# Assert that pool started up at minPoolSize
client = MongoClient(host, port, minPoolSize=10)
client = rs_or_single_client(minPoolSize=10)
server = client._get_topology().select_server(any_server_selector)
wait_until(lambda: 10 == len(server._pool.sockets),
"pool initialized with 10 sockets")
@ -316,10 +313,9 @@ class TestClient(IntegrationTest):
self.assertFalse(sock_info in server._pool.sockets)
def test_max_idle_time_checkout(self):
host, port = client_context.host, client_context.port
# Use high frequency to test _get_socket_no_auth.
with client_knobs(kill_cursor_frequency=99999999):
client = MongoClient(host, port, maxIdleTimeMS=.5)
client = rs_or_single_client(maxIdleTimeMS=.5)
server = client._get_topology().select_server(any_server_selector)
with server._pool.get_socket({}) as sock_info:
pass
@ -333,7 +329,7 @@ class TestClient(IntegrationTest):
self.assertTrue(new_sock_info in server._pool.sockets)
# Test that sockets are reused if maxIdleTimeMS is not set.
client = MongoClient(host, port)
client = rs_or_single_client()
server = client._get_topology().select_server(any_server_selector)
with server._pool.get_socket({}) as sock_info:
pass
@ -344,22 +340,27 @@ class TestClient(IntegrationTest):
self.assertEqual(1, len(server._pool.sockets))
def test_constants(self):
"""This test uses MongoClient explicitly to make sure that host and
port are not overloaded.
"""
host, port = client_context.host, client_context.port
# Set bad defaults.
MongoClient.HOST = "somedomainthatdoesntexist.org"
MongoClient.PORT = 123456789
with self.assertRaises(AutoReconnect):
connected(MongoClient(serverSelectionTimeoutMS=10))
connected(MongoClient(serverSelectionTimeoutMS=10,
**client_context.ssl_client_options))
# Override the defaults. No error.
connected(MongoClient(host, port))
connected(MongoClient(host, port,
**client_context.ssl_client_options))
# Set good defaults.
MongoClient.HOST = host
MongoClient.PORT = port
# No error.
connected(MongoClient())
connected(MongoClient(**client_context.ssl_client_options))
def test_init_disconnected(self):
host, port = client_context.host, client_context.port
@ -401,10 +402,10 @@ class TestClient(IntegrationTest):
def test_equality(self):
c = connected(rs_or_single_client())
self.assertEqual(client_context.rs_or_standalone_client, c)
self.assertEqual(client_context.client, c)
# Explicitly test inequality
self.assertFalse(client_context.rs_or_standalone_client != c)
self.assertFalse(client_context.client != c)
def test_host_w_port(self):
with self.assertRaises(ValueError):
@ -543,12 +544,12 @@ class TestClient(IntegrationTest):
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port)))
# Auth with lazy connection.
rs_or_single_client(
rs_or_single_client_noauth(
"mongodb://user:pass@%s:%d/pymongo_test" % (host, port),
connect=False).pymongo_test.test.find_one()
# Wrong password.
bad_client = rs_or_single_client(
bad_client = rs_or_single_client_noauth(
"mongodb://user:wrong@%s:%d/pymongo_test" % (host, port),
connect=False)
@ -585,7 +586,7 @@ class TestClient(IntegrationTest):
@client_context.require_auth
def test_lazy_auth_raises_operation_failure(self):
lazy_client = rs_or_single_client(
lazy_client = rs_or_single_client_noauth(
"mongodb://user:wrong@%s/pymongo_test" % (client_context.host,),
connect=False)
@ -608,7 +609,7 @@ class TestClient(IntegrationTest):
uri = "mongodb://%s" % encoded_socket
# Confirm we can do operations via the socket.
client = MongoClient(uri)
client = rs_or_single_client(uri)
client.pymongo_test.test.insert_one({"dummy": "object"})
dbs = client.database_names()
self.assertTrue("pymongo_test" in dbs)
@ -973,11 +974,8 @@ class TestClient(IntegrationTest):
@client_context.require_no_replica_set
def test_connect_to_standalone_using_replica_set_name(self):
client = MongoClient(
client_context.host,
client_context.port,
replicaSet='anything',
serverSelectionTimeoutMS=100)
client = single_client(replicaSet='anything',
serverSelectionTimeoutMS=100)
with self.assertRaises(AutoReconnect):
client.test.test.find_one()
@ -988,12 +986,8 @@ class TestClient(IntegrationTest):
# the topology before the getMore message is sent. Test that
# MongoClient._send_message_with_response handles the error.
with self.assertRaises(AutoReconnect):
client = MongoClient(
client_context.host,
client_context.port,
connect=False,
serverSelectionTimeoutMS=100,
replicaSet=client_context.replica_set_name)
client = rs_client(connect=False,
serverSelectionTimeoutMS=100)
client._send_message_with_response(
operation=message._GetMore('pymongo_test', 'collection',
101, 1234, client.codec_options),

View File

@ -91,6 +91,7 @@ def raisesConfigurationErrorForOldMongoDB(func):
class TestCollation(unittest.TestCase):
@classmethod
@client_context.require_connection
def setUpClass(cls):
cls.listener = EventListener()
cls.saved_listeners = monitoring._LISTENERS

View File

@ -36,7 +36,7 @@ from bson.py3compat import itervalues
from bson.son import SON
from pymongo import (ASCENDING, DESCENDING, GEO2D,
GEOHAYSTACK, GEOSPHERE, HASHED, TEXT)
from pymongo import MongoClient, monitoring
from pymongo import monitoring
from pymongo.bulk import BulkWriteError
from pymongo.collection import Collection, ReturnDocument
from pymongo.command_cursor import CommandCursor
@ -50,6 +50,7 @@ from pymongo.errors import (DocumentTooLarge,
OperationFailure,
WriteConcernError)
from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command
from pymongo.mongo_client import MongoClient
from pymongo.operations import *
from pymongo.read_preferences import ReadPreference
from pymongo.results import (InsertOneResult,
@ -66,13 +67,12 @@ from test import client_context, unittest
class TestCollectionNoConnect(unittest.TestCase):
"""Test Collection features on a client that does not connect.
"""
@classmethod
@client_context.require_connection
def setUpClass(cls):
client = MongoClient(
client_context.host, client_context.port, connect=False)
cls.db = client.pymongo_test
cls.db = MongoClient(connect=False).pymongo_test
def test_collection(self):
self.assertRaises(TypeError, Collection, self.db, 5)
@ -92,12 +92,6 @@ class TestCollectionNoConnect(unittest.TestCase):
self.assertRaises(InvalidName, make_col, self.db.test, "tes..t")
self.assertRaises(InvalidName, make_col, self.db.test, "tes\x00t")
self.assertTrue(isinstance(self.db.test, Collection))
self.assertEqual(self.db.test, self.db["test"])
self.assertEqual(self.db.test, Collection(self.db, "test"))
self.assertEqual(self.db.test.mike, self.db["test.mike"])
self.assertEqual(self.db.test["mike"], self.db["test.mike"])
def test_getattr(self):
coll = self.db.test
self.assertTrue(isinstance(coll['_does_not_exist'], Collection))
@ -138,10 +132,17 @@ class TestCollection(IntegrationTest):
else:
yield self.db.test
def test_equality(self):
self.assertTrue(isinstance(self.db.test, Collection))
self.assertEqual(self.db.test, self.db["test"])
self.assertEqual(self.db.test, Collection(self.db, "test"))
self.assertEqual(self.db.test.mike, self.db["test.mike"])
self.assertEqual(self.db.test["mike"], self.db["test.mike"])
@client_context.require_version_min(3, 3, 9)
def test_create(self):
# No Exception.
db = client_context.rs_or_standalone_client.pymongo_test
db = client_context.client.pymongo_test
db.create_test_no_wc.drop()
Collection(db, name='create_test_no_wc', create=True)
with self.assertRaises(OperationFailure):
@ -810,7 +811,6 @@ class TestCollection(IntegrationTest):
DocumentTooLarge, coll.delete_one, {'data': large})
@client_context.require_version_min(3, 1, 9, -1)
@client_context.require_no_auth
def test_insert_bypass_document_validation(self):
db = self.db
db.test.drop()
@ -860,7 +860,6 @@ class TestCollection(IntegrationTest):
bypass_document_validation=True)
@client_context.require_version_min(3, 1, 9, -1)
@client_context.require_no_auth
def test_replace_bypass_document_validation(self):
db = self.db
db.test.drop()
@ -901,7 +900,6 @@ class TestCollection(IntegrationTest):
{"x": 1}, bypass_document_validation=True)
@client_context.require_version_min(3, 1, 9, -1)
@client_context.require_no_auth
def test_update_bypass_document_validation(self):
db = self.db
db.test.drop()
@ -980,7 +978,6 @@ class TestCollection(IntegrationTest):
{"$inc": {"x": 1}}, bypass_document_validation=True)
@client_context.require_version_min(3, 1, 9, -1)
@client_context.require_no_auth
def test_bypass_document_validation_bulk_write(self):
db = self.db
db.test.drop()

View File

@ -23,7 +23,6 @@ from bson.binary import UUIDLegacy, PYTHON_LEGACY, STANDARD
from bson.code import Code
from bson.codec_options import CodecOptions
from bson.objectid import ObjectId
from pymongo.mongo_client import MongoClient
from pymongo.errors import OperationFailure
from pymongo.write_concern import WriteConcern
from test import client_context, unittest, IntegrationTest
@ -165,10 +164,10 @@ class TestCommon(IntegrationTest):
{"count": 0}, reduce))
def test_write_concern(self):
c = MongoClient(connect=False)
c = rs_or_single_client(connect=False)
self.assertEqual(WriteConcern(), c.write_concern)
c = MongoClient(connect=False, w=2, wtimeout=1000)
c = rs_or_single_client(connect=False, w=2, wtimeout=1000)
wc = WriteConcern(w=2, wtimeout=1000)
self.assertEqual(wc, c.write_concern)
@ -199,24 +198,22 @@ class TestCommon(IntegrationTest):
self.assertTrue(new_coll.insert_one(doc))
self.assertRaises(OperationFailure, coll.insert_one, doc)
m = MongoClient("mongodb://%s/" % (pair,),
replicaSet=client_context.replica_set_name)
m = rs_or_single_client("mongodb://%s/" % (pair,),
replicaSet=client_context.replica_set_name)
coll = m.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert_one, doc)
m = MongoClient("mongodb://%s/?w=0" % (pair,),
replicaSet=client_context.replica_set_name)
m = rs_or_single_client("mongodb://%s/?w=0" % (pair,),
replicaSet=client_context.replica_set_name)
coll = m.pymongo_test.write_concern_test
coll.insert_one(doc)
# Equality tests
direct = connected(single_client(w=0))
self.assertEqual(direct,
connected(MongoClient("mongodb://%s/?w=0" % (pair,))))
self.assertFalse(direct !=
connected(MongoClient("mongodb://%s/?w=0" % (pair,))))
direct2 = connected(single_client("mongodb://%s/?w=0" % (pair,)))
self.assertEqual(direct, direct2)
self.assertFalse(direct != direct2)
if __name__ == "__main__":

View File

@ -24,8 +24,8 @@ sys.path[0:0] = [""]
from bson.code import Code
from bson.py3compat import PY3
from bson.son import SON
from pymongo import (MongoClient,
monitoring,
from pymongo import (monitoring,
MongoClient,
ASCENDING,
DESCENDING,
ALL,
@ -38,7 +38,8 @@ from test import (client_context,
SkipTest,
unittest,
IntegrationTest)
from test.utils import server_started_with_auth, single_client, EventListener
from test.utils import (rs_or_single_client,
EventListener)
if PY3:
long = int
@ -48,9 +49,7 @@ class TestCursorNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(
client_context.host, client_context.port, connect=False)
cls.db = client.test
cls.db = MongoClient(connect=False).test
def test_deepcopy_cursor_littered_with_regexes(self):
cursor = self.db.test.find({
@ -128,8 +127,15 @@ class TestCursorNoConnect(unittest.TestCase):
cursor.remove_option(128)
self.assertEqual(0, cursor._Cursor__query_flags)
class TestCursor(IntegrationTest):
def test_add_remove_option_exhaust(self):
# Exhaust - which mongos doesn't support
if client_context.client is not None and not self.db.client.is_mongos:
if client_context.is_mongos:
with self.assertRaises(InvalidOperation):
self.db.test.find(cursor_type=CursorType.EXHAUST)
else:
cursor = self.db.test.find(cursor_type=CursorType.EXHAUST)
self.assertEqual(64, cursor._Cursor__query_flags)
cursor2 = self.db.test.find().add_option(64)
@ -140,9 +146,6 @@ class TestCursorNoConnect(unittest.TestCase):
self.assertEqual(0, cursor._Cursor__query_flags)
self.assertFalse(cursor._Cursor__exhaust)
class TestCursor(IntegrationTest):
@client_context.require_version_min(2, 5, 3, -1)
def test_max_time_ms(self):
db = self.db
@ -226,7 +229,7 @@ class TestCursor(IntegrationTest):
listener.add_command_filter('killCursors')
saved_listeners = monitoring._LISTENERS
monitoring._LISTENERS = monitoring._Listeners([], [], [], [])
coll = single_client(
coll = rs_or_single_client(
event_listeners=[listener])[self.db.name].pymongo_test
results = listener.results
@ -312,6 +315,7 @@ class TestCursor(IntegrationTest):
@client_context.require_version_min(2, 5, 3, -1)
@client_context.require_test_commands
@client_context.require_no_mongos
def test_max_time_ms_getmore(self):
# Test that Cursor handles server timeout error in response to getmore.
coll = self.db.pymongo_test
@ -1172,7 +1176,7 @@ class TestCursor(IntegrationTest):
@client_context.require_no_mongos
def test_comment(self):
if server_started_with_auth(self.db.client):
if client_context.auth_enabled:
raise SkipTest("SERVER-4754 - This test uses profiling.")
# MongoDB 3.1.5 changed the ns for commands.

View File

@ -29,8 +29,7 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId
from bson.py3compat import string_type, text_type, PY3
from bson.son import SON
from pymongo import (MongoClient,
ALL,
from pymongo import (ALL,
auth,
OFF,
SLOW_ONLY,
@ -43,6 +42,7 @@ from pymongo.errors import (CollectionInvalid,
InvalidName,
OperationFailure,
WriteConcernError)
from pymongo.mongo_client import MongoClient
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
@ -63,12 +63,12 @@ if PY3:
class TestDatabaseNoConnect(unittest.TestCase):
"""Test Database features on a client that does not connect.
"""
@classmethod
@client_context.require_connection
def setUpClass(cls):
cls.client = MongoClient(
client_context.host, client_context.port, connect=False)
cls.client = MongoClient(connect=False)
def test_name(self):
self.assertRaises(TypeError, Database, self.client, 4)
@ -79,23 +79,6 @@ class TestDatabaseNoConnect(unittest.TestCase):
self.client, u"my\u0000db")
self.assertEqual("name", Database(self.client, "name").name)
def test_equality(self):
self.assertNotEqual(Database(self.client, "test"),
Database(self.client, "mike"))
self.assertEqual(Database(self.client, "test"),
Database(self.client, "test"))
# Explicitly test inequality
self.assertFalse(Database(self.client, "test") !=
Database(self.client, "test"))
def test_get_coll(self):
db = Database(self.client, "pymongo_test")
self.assertEqual(db.test, db["test"])
self.assertEqual(db.test, Collection(db, "test"))
self.assertNotEqual(db.test, Collection(db, "mike"))
self.assertEqual(db.test.mike, db["test.mike"])
def test_get_collection(self):
codec_options = CodecOptions(tz_aware=True)
write_concern = WriteConcern(w=2, j=True)
@ -128,6 +111,23 @@ class TestDatabaseNoConnect(unittest.TestCase):
class TestDatabase(IntegrationTest):
def test_equality(self):
self.assertNotEqual(Database(self.client, "test"),
Database(self.client, "mike"))
self.assertEqual(Database(self.client, "test"),
Database(self.client, "test"))
# Explicitly test inequality
self.assertFalse(Database(self.client, "test") !=
Database(self.client, "test"))
def test_get_coll(self):
db = Database(self.client, "pymongo_test")
self.assertEqual(db.test, db["test"])
self.assertEqual(db.test, Collection(db, "test"))
self.assertNotEqual(db.test, Collection(db, "mike"))
self.assertEqual(db.test.mike, db["test.mike"])
def test_repr(self):
self.assertEqual(repr(Database(self.client, "pymongo_test")),
"Database(%r, %s)" % (self.client,
@ -818,6 +818,7 @@ class TestDatabase(IntegrationTest):
@client_context.require_version_min(2, 5, 3, -1)
@client_context.require_test_commands
@client_context.require_no_mongos
def test_command_max_time_ms(self):
self.client.admin.command("configureFailPoint",
"maxTimeAlwaysTimeOut",

View File

@ -34,20 +34,18 @@ from gridfs.errors import NoFile
from pymongo import MongoClient
from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError
from test import (IntegrationTest,
client_context,
unittest,
qcheck)
from test.utils import rs_or_single_client
class TestGridFileNoConnect(unittest.TestCase):
"""Test GridFile features on a client that does not connect.
"""
@classmethod
@client_context.require_connection
def setUpClass(cls):
client = MongoClient(
client_context.host, client_context.port, connect=False)
cls.db = client.pymongo_test
cls.db = MongoClient(connect=False).pymongo_test
def test_grid_in_custom_opts(self):
self.assertRaises(TypeError, GridIn, "foo")
@ -73,17 +71,6 @@ class TestGridFileNoConnect(unittest.TestCase):
self.assertEqual(1000, b.chunk_size)
self.assertEqual(100, b.baz)
def test_grid_out_cursor_options(self):
self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {},
projection={"filename": 1})
cursor = GridOutCursor(self.db.fs, {})
cursor_clone = cursor.clone()
self.assertEqual(cursor_clone.__dict__, cursor.__dict__)
self.assertRaises(NotImplementedError, cursor.add_option, 0)
self.assertRaises(NotImplementedError, cursor.remove_option, 0)
class TestGridFile(IntegrationTest):
@ -239,6 +226,17 @@ class TestGridFile(IntegrationTest):
"upload_date", "aliases", "metadata", "md5"]:
self.assertRaises(AttributeError, setattr, b, attr, 5)
def test_grid_out_cursor_options(self):
self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {},
projection={"filename": 1})
cursor = GridOutCursor(self.db.fs, {})
cursor_clone = cursor.clone()
self.assertEqual(cursor_clone.__dict__, cursor.__dict__)
self.assertRaises(NotImplementedError, cursor.add_option, 0)
self.assertRaises(NotImplementedError, cursor.remove_option, 0)
def test_grid_out_custom_opts(self):
one = GridIn(self.db.fs, _id=5, filename="my_file",
contentType="text/html", chunkSize=1000, aliases=["foo"],
@ -608,7 +606,6 @@ Bye"""))
self.assertRaises(ServerSelectionTimeoutError, infile.write, b'data')
self.assertRaises(ServerSelectionTimeoutError, infile.close)
def test_unacknowledged(self):
# w=0 is prohibited.
with self.assertRaises(ConfigurationError):

View File

@ -79,9 +79,7 @@ class TestGridfsNoConnect(unittest.TestCase):
@classmethod
def setUpClass(cls):
client = MongoClient(
client_context.host, client_context.port, connect=False)
cls.db = client.pymongo_test
cls.db = MongoClient(connect=False).pymongo_test
def test_gridfs(self):
self.assertRaises(TypeError, gridfs.GridFS, "foo")
@ -512,9 +510,8 @@ class TestGridfsReplicaSet(TestReplicaSetClientBase):
self.assertRaises(ConnectionFailure, fs.put, 'data')
def tearDown(self):
rsc = client_context.rs_client
rsc.pymongo_test.drop_collection('fs.files')
rsc.pymongo_test.drop_collection('fs.chunks')
self.client.pymongo_test.drop_collection('fs.files')
self.client.pymongo_test.drop_collection('fs.chunks')
if __name__ == "__main__":

View File

@ -28,20 +28,18 @@ from bson.py3compat import StringIO, string_type
from gridfs.errors import NoFile, CorruptGridFile
from pymongo.errors import (ConfigurationError,
ConnectionFailure,
ServerSelectionTimeoutError,
OperationFailure)
ServerSelectionTimeoutError)
from pymongo.mongo_client import MongoClient
from pymongo.read_preferences import ReadPreference
from test import (client_context,
unittest,
IntegrationTest)
from test.test_replica_set_client import TestReplicaSetClientBase
from test.utils import (joinall,
single_client,
one,
rs_client,
rs_or_single_client,
rs_or_single_client_noauth,
remove_all_users)
rs_or_single_client)
class JustWrite(threading.Thread):
@ -496,9 +494,8 @@ class TestGridfsBucketReplicaSet(TestReplicaSetClientBase):
"test_filename", b'data')
def tearDown(self):
rsc = client_context.rs_client
rsc.pymongo_test.drop_collection('fs.files')
rsc.pymongo_test.drop_collection('fs.chunks')
self.client.pymongo_test.drop_collection('fs.files')
self.client.pymongo_test.drop_collection('fs.chunks')
if __name__ == "__main__":

View File

@ -28,7 +28,10 @@ from pymongo.errors import NotMasterError, OperationFailure
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
from test import unittest, client_context, client_knobs
from test.utils import single_client, wait_until, EventListener
from test.utils import (EventListener,
rs_or_single_client,
single_client,
wait_until)
class TestCommandMonitoring(unittest.TestCase):
@ -40,7 +43,7 @@ class TestCommandMonitoring(unittest.TestCase):
cls.saved_listeners = monitoring._LISTENERS
# Don't use any global subscribers.
monitoring._LISTENERS = monitoring._Listeners([], [], [], [])
cls.client = single_client(event_listeners=[cls.listener])
cls.client = rs_or_single_client(event_listeners=[cls.listener])
@classmethod
def tearDownClass(cls):
@ -384,7 +387,7 @@ class TestCommandMonitoring(unittest.TestCase):
@client_context.require_replica_set
def test_not_master_error(self):
address = next(iter(client_context.rs_client.secondaries))
address = next(iter(client_context.client.secondaries))
client = single_client(*address, event_listeners=[self.listener])
# Clear authentication command results from the listener.
client.admin.command('ismaster')

View File

@ -166,6 +166,10 @@ class _TestPoolingBase(unittest.TestCase):
pair=(client_context.host, client_context.port),
*args,
**kwargs):
# Start the pool with the correct ssl options.
pool_options = client_context.client._topology_settings.pool_options
kwargs['ssl_context'] = pool_options.ssl_context
kwargs['ssl_match_hostname'] = pool_options.ssl_match_hostname
return Pool(pair, PoolOptions(*args, **kwargs))
@ -173,14 +177,12 @@ class TestPooling(_TestPoolingBase):
def test_max_pool_size_validation(self):
host, port = client_context.host, client_context.port
self.assertRaises(
ValueError, MongoClient, host=host, port=port,
maxPoolSize=-1)
ValueError, MongoClient, host=host, port=port, maxPoolSize=-1)
self.assertRaises(
ValueError, MongoClient, host=host, port=port,
maxPoolSize='foo')
ValueError, MongoClient, host=host, port=port, maxPoolSize='foo')
c = MongoClient(host=host, port=port, maxPoolSize=100)
c = MongoClient(host=host, port=port, maxPoolSize=100, connect=False)
self.assertEqual(c.max_pool_size, 100)
def test_no_disconnect(self):

View File

@ -22,7 +22,7 @@ class TestRawBSONDocument(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.client = client_context.rs_or_standalone_client
cls.client = client_context.client
def tearDown(self):
if client_context.connected:

View File

@ -22,7 +22,7 @@ from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.read_concern import ReadConcern
from test import client_context, unittest
from test.utils import single_client, EventListener
from test.utils import single_client, rs_or_single_client, EventListener
class TestReadConcern(unittest.TestCase):
@ -61,8 +61,9 @@ class TestReadConcern(unittest.TestCase):
self.assertRaises(TypeError, ReadConcern, 42)
def test_read_concern_uri(self):
uri = 'mongodb://%s/?readConcernLevel=majority' % (client_context.pair,)
client = pymongo.MongoClient(uri)
uri = 'mongodb://%s/?readConcernLevel=majority' % (
client_context.pair,)
client = rs_or_single_client(uri, connect=False)
self.assertEqual(ReadConcern('majority'), client.read_concern)
@client_context.require_version_max(3, 1)

View File

@ -47,6 +47,8 @@ from test.version import Version
class TestSelections(unittest.TestCase):
@client_context.require_connection
def test_bool(self):
client = single_client()
@ -293,7 +295,9 @@ class TestReadPreferences(TestReadPreferencesBase):
class ReadPrefTester(MongoClient):
def __init__(self, *args, **kwargs):
self.has_read_from = set()
super(ReadPrefTester, self).__init__(*args, **kwargs)
client_options = client_context.ssl_client_options.copy()
client_options.update(kwargs)
super(ReadPrefTester, self).__init__(*args, **client_options)
@contextlib.contextmanager
def _socket_for_reads(self, read_preference):

View File

@ -38,6 +38,7 @@ from test import (client_context,
client_knobs,
IntegrationTest,
unittest,
SkipTest,
db_pwd,
db_user,
MockClientTest)
@ -106,7 +107,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase):
self.assertIn(client_context.pair, repr(client))
def test_properties(self):
c = client_context.rs_client
c = client_context.client
c.admin.command('ping')
wait_until(lambda: c.primary == self.primary, "discover primary")
@ -132,9 +133,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase):
tag_sets = [{'dc': 'la', 'rack': '2'}, {'foo': 'bar'}]
secondary = Secondary(tag_sets=tag_sets)
c = MongoClient(
client_context.pair,
replicaSet=self.name,
c = rs_client(
maxPoolSize=25,
document_class=SON,
tz_aware=True,
@ -213,19 +212,17 @@ class TestReplicaSetClient(TestReplicaSetClientBase):
# No error.
coll.find_one()
@client_context.require_replica_set
@client_context.require_ipv6
def test_ipv6(self):
port = client_context.port
c = MongoClient("mongodb://[::1]:%d" % (port,), replicaSet=self.name)
c = rs_client("mongodb://[::1]:%d" % (port,))
# Client switches to IPv4 once it has first ismaster response.
msg = 'discovered primary with IPv4 address "%r"' % (self.primary,)
wait_until(lambda: c.primary == self.primary, msg)
# Same outcome with both IPv4 and IPv6 seeds.
c = MongoClient("[::1]:%d,localhost:%d" % (port, port),
replicaSet=self.name)
c = rs_client("[::1]:%d,localhost:%d" % (port, port))
wait_until(lambda: c.primary == self.primary, msg)
@ -235,7 +232,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase):
auth_str = ""
uri = "mongodb://%slocalhost:%d,[::1]:%d" % (auth_str, port, port)
client = MongoClient(uri, replicaSet=self.name)
client = rs_client(uri)
client.pymongo_test.test.insert_one({"dummy": u"object"})
client.pymongo_test_bernie.test.insert_one({"dummy": u"object"})

View File

@ -157,9 +157,18 @@ class TestSSL(IntegrationTest):
self.assertTrue(coll.find_one()['ssl'])
coll.drop()
@classmethod
@unittest.skipUnless(HAVE_SSL, "The ssl module is not available.")
def setUp(self):
super(TestSSL, self).setUp()
def setUpClass(cls):
super(TestSSL, cls).setUpClass()
# MongoClient should connect to the primary by default.
cls.saved_port = MongoClient.PORT
MongoClient.PORT = client_context.port
@classmethod
def tearDownClass(cls):
MongoClient.PORT = cls.saved_port
super(TestSSL, cls).tearDownClass()
@client_context.require_ssl
def test_simple_ssl(self):
@ -259,7 +268,6 @@ class TestSSL(IntegrationTest):
ssl_ca_certs=CA_PEM)
self.assertClientWorks(client)
@client_context.require_ssl_certfile
@client_context.require_server_resolvable
@client_context.require_no_auth

View File

@ -131,7 +131,7 @@ class Disconnect(threading.Thread):
class TestThreads(IntegrationTest):
def setUp(self):
self.db = client_context.rs_or_standalone_client.pymongo_test
self.db = self.client.pymongo_test
def test_threading(self):
self.db.drop_collection("test")

View File

@ -98,45 +98,51 @@ class HeartbeatEventListener(monitoring.ServerHeartbeatListener):
self.results.append(event)
def _connection_string_noauth(h, p):
def _connection_string(h, p, authenticate):
if h.startswith("mongodb://"):
return h
return "mongodb://%s:%d" % (h, p)
def _connection_string(h, p):
if h.startswith("mongodb://"):
return h
elif client_context.auth_enabled:
elif client_context.auth_enabled and authenticate:
return "mongodb://%s:%s@%s:%d" % (db_user, db_pwd, h, p)
else:
return _connection_string_noauth(h, p)
return "mongodb://%s:%d" % (h, p)
def _mongo_client(host, port, authenticate=True, direct=False, **kwargs):
"""Create a new client over SSL/TLS if necessary."""
client_options = client_context.ssl_client_options.copy()
if client_context.replica_set_name and not direct:
client_options['replicaSet'] = client_context.replica_set_name
client_options.update(kwargs)
client = MongoClient(_connection_string(host, port, authenticate), port,
**client_options)
return client
def single_client_noauth(
h=client_context.host, p=client_context.port, **kwargs):
"""Make a direct connection. Don't authenticate."""
return MongoClient(_connection_string_noauth(h, p), p, **kwargs)
return _mongo_client(h, p, authenticate=False, direct=True, **kwargs)
def single_client(
h=client_context.host, p=client_context.port, **kwargs):
"""Make a direct connection, and authenticate if necessary."""
return MongoClient(_connection_string(h, p), p, **kwargs)
return _mongo_client(h, p, direct=True, **kwargs)
def rs_client_noauth(
h=client_context.host, p=client_context.port, **kwargs):
"""Connect to the replica set. Don't authenticate."""
return MongoClient(_connection_string_noauth(h, p), p,
replicaSet=client_context.replica_set_name, **kwargs)
return _mongo_client(h, p, authenticate=False, **kwargs)
def rs_client(
h=client_context.host, p=client_context.port, **kwargs):
"""Connect to the replica set and authenticate if necessary."""
return MongoClient(_connection_string(h, p), p,
replicaSet=client_context.replica_set_name, **kwargs)
return _mongo_client(h, p, **kwargs)
def rs_or_single_client_noauth(
@ -145,10 +151,7 @@ def rs_or_single_client_noauth(
Like rs_or_single_client, but does not authenticate.
"""
if client_context.replica_set_name:
return rs_client_noauth(h, p, **kwargs)
else:
return single_client_noauth(h, p, **kwargs)
return _mongo_client(h, p, authenticate=False, **kwargs)
def rs_or_single_client(
@ -157,10 +160,7 @@ def rs_or_single_client(
Authenticates if necessary.
"""
if client_context.replica_set_name:
return rs_client(h, p, **kwargs)
else:
return single_client(h, p, **kwargs)
return _mongo_client(h, p, **kwargs)
def one(s):
@ -319,9 +319,7 @@ def enable_text_search(client):
'setParameter', textSearchEnabled=True)
for host, port in client.secondaries:
client = MongoClient(host, port)
if client_context.auth_enabled:
client.admin.authenticate(db_user, db_pwd)
client = single_client(host, port)
client.admin.command('setParameter', textSearchEnabled=True)