PYTHON-1852 Use TLS option names in test suite ClientContext (#442)
This commit is contained in:
parent
74202455aa
commit
84fd04ec6d
@ -81,15 +81,12 @@ CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
CLIENT_PEM = os.environ.get('CLIENT_PEM',
|
||||
os.path.join(CERT_PATH, 'client.pem'))
|
||||
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))
|
||||
CERT_REQS = validate_cert_reqs('CERT_REQS', os.environ.get('CERT_REQS'))
|
||||
|
||||
_SSL_OPTIONS = dict(ssl=True)
|
||||
TLS_OPTIONS = dict(tls=True)
|
||||
if CLIENT_PEM:
|
||||
_SSL_OPTIONS['ssl_certfile'] = CLIENT_PEM
|
||||
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
|
||||
if CA_PEM:
|
||||
_SSL_OPTIONS['ssl_ca_certs'] = CA_PEM
|
||||
if CERT_REQS is not None:
|
||||
_SSL_OPTIONS['ssl_cert_reqs'] = CERT_REQS
|
||||
TLS_OPTIONS['tlsCAFile'] = CA_PEM
|
||||
|
||||
COMPRESSORS = os.environ.get("COMPRESSORS")
|
||||
|
||||
@ -187,8 +184,7 @@ class ClientContext(object):
|
||||
self.mongoses = []
|
||||
self.is_rs = False
|
||||
self.has_ipv6 = False
|
||||
self.ssl = False
|
||||
self.ssl_cert_none = False
|
||||
self.tls = False
|
||||
self.ssl_certfile = False
|
||||
self.server_is_resolvable = is_server_resolvable()
|
||||
self.default_client_options = {}
|
||||
@ -235,13 +231,11 @@ class ClientContext(object):
|
||||
self.client = self._connect(host, port)
|
||||
if HAVE_SSL and not self.client:
|
||||
# Is MongoDB configured for SSL?
|
||||
self.client = self._connect(host, port, **_SSL_OPTIONS)
|
||||
self.client = self._connect(host, port, **TLS_OPTIONS)
|
||||
if self.client:
|
||||
self.ssl = True
|
||||
self.default_client_options.update(_SSL_OPTIONS)
|
||||
self.tls = True
|
||||
self.default_client_options.update(TLS_OPTIONS)
|
||||
self.ssl_certfile = True
|
||||
if _SSL_OPTIONS.get('ssl_cert_reqs') == ssl.CERT_NONE:
|
||||
self.ssl_cert_none = True
|
||||
|
||||
if self.client:
|
||||
self.connected = True
|
||||
@ -608,22 +602,16 @@ class ClientContext(object):
|
||||
"failCommand fail point must be supported",
|
||||
func=func)
|
||||
|
||||
def require_ssl(self, func):
|
||||
"""Run a test only if the client can connect over SSL."""
|
||||
return self._require(lambda: self.ssl,
|
||||
"Must be able to connect via SSL",
|
||||
def require_tls(self, func):
|
||||
"""Run a test only if the client can connect over TLS."""
|
||||
return self._require(lambda: self.tls,
|
||||
"Must be able to connect via TLS",
|
||||
func=func)
|
||||
|
||||
def require_no_ssl(self, func):
|
||||
"""Run a test only if the client can connect over SSL."""
|
||||
return self._require(lambda: not self.ssl,
|
||||
"Must be able to connect without SSL",
|
||||
func=func)
|
||||
|
||||
def require_ssl_cert_none(self, func):
|
||||
"""Run a test only if the client can connect with ssl.CERT_NONE."""
|
||||
return self._require(lambda: self.ssl_cert_none,
|
||||
"Must be able to connect with ssl.CERT_NONE",
|
||||
def require_no_tls(self, func):
|
||||
"""Run a test only if the client can connect over TLS."""
|
||||
return self._require(lambda: not self.tls,
|
||||
"Must be able to connect without TLS",
|
||||
func=func)
|
||||
|
||||
def require_ssl_certfile(self, func):
|
||||
|
||||
@ -919,7 +919,7 @@ class TestClient(IntegrationTest):
|
||||
assertRaisesExactly(
|
||||
OperationFailure, lazy_client.test.collection.find_one)
|
||||
|
||||
@client_context.require_no_ssl
|
||||
@client_context.require_no_tls
|
||||
def test_unix_socket(self):
|
||||
if not hasattr(socket, "AF_UNIX"):
|
||||
raise SkipTest("UNIX-sockets are not supported on this system")
|
||||
@ -1086,7 +1086,7 @@ class TestClient(IntegrationTest):
|
||||
|
||||
@client_context.require_ipv6
|
||||
def test_ipv6(self):
|
||||
if client_context.ssl:
|
||||
if client_context.tls:
|
||||
if not HAVE_IPADDRESS:
|
||||
raise SkipTest("Need the ipaddress module to test with SSL")
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ class TestDNS(unittest.TestCase):
|
||||
def create_test(test_case):
|
||||
|
||||
@client_context.require_replica_set
|
||||
@client_context.require_ssl
|
||||
@client_context.require_tls
|
||||
def run_test(self):
|
||||
if not _HAVE_DNSPYTHON:
|
||||
raise unittest.SkipTest("DNS tests require the dnspython module")
|
||||
@ -68,7 +68,7 @@ def create_test(test_case):
|
||||
# The replica set members must be configured as 'localhost'.
|
||||
if hostname == 'localhost':
|
||||
copts = client_context.default_client_options.copy()
|
||||
if client_context.ssl is True:
|
||||
if client_context.tls is True:
|
||||
# Our test certs don't support the SRV hosts used in these tests.
|
||||
copts['ssl_match_hostname'] = False
|
||||
|
||||
|
||||
@ -195,7 +195,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase):
|
||||
|
||||
@client_context.require_ipv6
|
||||
def test_ipv6(self):
|
||||
if client_context.ssl:
|
||||
if client_context.tls:
|
||||
if not HAVE_IPADDRESS:
|
||||
raise SkipTest("Need the ipaddress module to test with SSL")
|
||||
|
||||
|
||||
@ -196,7 +196,7 @@ class TestSSL(IntegrationTest):
|
||||
MongoClient.PORT = cls.saved_port
|
||||
super(TestSSL, cls).tearDownClass()
|
||||
|
||||
@client_context.require_ssl
|
||||
@client_context.require_tls
|
||||
def test_simple_ssl(self):
|
||||
# Expects the server to be running with ssl and with
|
||||
# no --sslPEMKeyFile or with --sslWeakCertificateValidation
|
||||
|
||||
Loading…
Reference in New Issue
Block a user