Included create_ssl_context function to create the same context with SSLConfig and serve as API (#996)

* Included create_ssl_context function to create the same context with SSLConfig and serve as API.

* Changed create_ssl_context with SSLConfig into the client implementation and tests.

* Dropped the __repr__ and __eq__ methods from SSLConfig and removed SSLConfig using from tests

* Fixed test issue regarding cert_authority trust of ssl context

Co-authored-by: Tom Christie <tom@tomchristie.com>
Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
This commit is contained in:
Can Sarıgöl 2020-07-27 20:46:46 +02:00 committed by GitHub
parent 35f09d1394
commit 926a55a84f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 83 deletions

View File

@ -600,6 +600,23 @@ Alternatively, you can pass a standard library `ssl.SSLContext`.
<Response [200 OK]>
```
We also include a helper function for creating properly configured `SSLContext` instances.
```python
>>> context = httpx.create_ssl_context()
```
The `create_ssl_context` function accepts the same set of SSL configuration arguments
(`trust_env`, `verify`, `cert` and `http2` arguments)
as `httpx.Client` or `httpx.AsyncClient`
```python
>>> import httpx
>>> context = httpx.create_ssl_context(verify="/tmp/client.pem")
>>> httpx.get('https://example.org', verify=context)
<Response [200 OK]>
```
Or you can also disable the SSL verification entirely, which is _not_ recommended.
```python

View File

@ -2,7 +2,7 @@ from .__version__ import __description__, __title__, __version__
from ._api import delete, get, head, options, patch, post, put, request, stream
from ._auth import Auth, BasicAuth, DigestAuth
from ._client import AsyncClient, Client
from ._config import PoolLimits, Proxy, Timeout
from ._config import PoolLimits, Proxy, Timeout, create_ssl_context
from ._exceptions import (
CloseError,
ConnectError,
@ -61,6 +61,7 @@ __all__ = [
"PoolLimits",
"Proxy",
"Timeout",
"create_ssl_context",
"CloseError",
"ConnectError",
"ConnectTimeout",

View File

@ -13,9 +13,9 @@ from ._config import (
UNSET,
PoolLimits,
Proxy,
SSLConfig,
Timeout,
UnsetType,
create_ssl_context,
)
from ._content_streams import ContentStream
from ._exceptions import (
@ -499,9 +499,7 @@ class Client(BaseClient):
if app is not None:
return WSGITransport(app=app)
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
return httpcore.SyncConnectionPool(
ssl_context=ssl_context,
@ -520,9 +518,7 @@ class Client(BaseClient):
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
trust_env: bool = True,
) -> httpcore.SyncHTTPTransport:
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
return httpcore.SyncHTTPProxy(
proxy_url=proxy.url.raw,
@ -1032,9 +1028,7 @@ class AsyncClient(BaseClient):
if app is not None:
return ASGITransport(app=app)
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
return httpcore.AsyncConnectionPool(
ssl_context=ssl_context,
@ -1053,9 +1047,7 @@ class AsyncClient(BaseClient):
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
trust_env: bool = True,
) -> httpcore.AsyncHTTPTransport:
ssl_context = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env
).ssl_context
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
return httpcore.AsyncHTTPProxy(
proxy_url=proxy.url.raw,

View File

@ -40,6 +40,17 @@ class UnsetType:
UNSET = UnsetType()
def create_ssl_context(
cert: CertTypes = None,
verify: VerifyTypes = True,
trust_env: bool = None,
http2: bool = False,
) -> ssl.SSLContext:
return SSLConfig(
cert=cert, verify=verify, trust_env=trust_env, http2=http2
).ssl_context
class SSLConfig:
"""
SSL Configuration.
@ -61,17 +72,6 @@ class SSLConfig:
self.http2 = http2
self.ssl_context = self.load_ssl_context()
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.cert == other.cert
and self.verify == other.verify
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(cert={self.cert}, verify={self.verify})"
def load_ssl_context(self) -> ssl.SSLContext:
logger.trace(
f"load_ssl_context "

View File

@ -3,7 +3,7 @@ from typing import Dict, Iterator, List, Optional, Tuple
import httpcore
from .._config import SSLConfig
from .._config import create_ssl_context
from .._content_streams import ByteStream, IteratorStream
from .._exceptions import NetworkError, map_exceptions
from .._types import CertTypes, VerifyTypes
@ -30,12 +30,10 @@ class URLLib3Transport(httpcore.SyncHTTPTransport):
urllib3 is not None
), "urllib3 must be installed in order to use URLLib3Transport"
ssl_config = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env, http2=False
)
self.pool = urllib3.PoolManager(
ssl_context=ssl_config.ssl_context,
ssl_context=create_ssl_context(
verify=verify, cert=cert, trust_env=trust_env, http2=False
),
num_pools=pool_connections,
maxsize=pool_maxsize,
block=pool_block,
@ -139,14 +137,12 @@ class URLLib3ProxyTransport(URLLib3Transport):
urllib3 is not None
), "urllib3 must be installed in order to use URLLib3ProxyTransport"
ssl_config = SSLConfig(
verify=verify, cert=cert, trust_env=trust_env, http2=False
)
self.pool = urllib3.ProxyManager(
proxy_url=proxy_url,
proxy_headers=proxy_headers,
ssl_context=ssl_config.ssl_context,
ssl_context=create_ssl_context(
verify=verify, cert=cert, trust_env=trust_env, http2=False
),
num_pools=pool_connections,
maxsize=pool_maxsize,
block=pool_block,

View File

@ -1,5 +1,4 @@
import os
import socket
import ssl
import sys
from pathlib import Path
@ -8,63 +7,51 @@ import certifi
import pytest
import httpx
from httpx._config import SSLConfig
def test_load_ssl_config():
ssl_config = SSLConfig()
context = ssl_config.ssl_context
context = httpx.create_ssl_context()
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
def test_load_ssl_config_verify_non_existing_path():
with pytest.raises(IOError):
SSLConfig(verify="/path/to/nowhere")
httpx.create_ssl_context(verify="/path/to/nowhere")
def test_load_ssl_config_verify_existing_file():
ssl_config = SSLConfig(verify=certifi.where())
context = ssl_config.ssl_context
context = httpx.create_ssl_context(verify=certifi.where())
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@pytest.mark.parametrize("config", ("SSL_CERT_FILE", "SSL_CERT_DIR"))
def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config):
def test_load_ssl_config_verify_env_file(
https_server, ca_cert_pem_file, config, cert_authority
):
os.environ[config] = (
ca_cert_pem_file
if config.endswith("_FILE")
else str(Path(ca_cert_pem_file).parent)
)
ssl_config = SSLConfig(trust_env=True)
context = ssl_config.ssl_context
context = httpx.create_ssl_context(trust_env=True)
cert_authority.configure_trust(context)
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
assert ssl_config.verify == os.environ[config]
# Skipping 'SSL_CERT_DIR' functional test for now because
# we're unable to get the certificate within the directory to
# load into the SSLContext. :(
if config == "SSL_CERT_FILE":
host = https_server.url.host
port = https_server.url.port
conn = socket.create_connection((host, port))
context.wrap_socket(conn, server_hostname=host)
assert len(context.get_ca_certs()) == 1
assert len(context.get_ca_certs()) == 1
def test_load_ssl_config_verify_directory():
path = Path(certifi.where()).parent
ssl_config = SSLConfig(verify=str(path))
context = ssl_config.ssl_context
context = httpx.create_ssl_context(verify=str(path))
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
ssl_config = SSLConfig(cert=(cert_pem_file, cert_private_key_file))
context = ssl_config.ssl_context
context = httpx.create_ssl_context(cert=(cert_pem_file, cert_private_key_file))
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@ -73,10 +60,9 @@ def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
def test_load_ssl_config_cert_and_encrypted_key(
cert_pem_file, cert_encrypted_private_key_file, password
):
ssl_config = SSLConfig(
context = httpx.create_ssl_context(
cert=(cert_pem_file, cert_encrypted_private_key_file, password)
)
context = ssl_config.ssl_context
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
assert context.check_hostname is True
@ -85,36 +71,33 @@ def test_load_ssl_config_cert_and_key_invalid_password(
cert_pem_file, cert_encrypted_private_key_file
):
with pytest.raises(ssl.SSLError):
SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, "password1"))
httpx.create_ssl_context(
cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
)
def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
with pytest.raises(ssl.SSLError):
SSLConfig(cert=cert_pem_file)
httpx.create_ssl_context(cert=cert_pem_file)
def test_load_ssl_config_no_verify():
ssl_config = SSLConfig(verify=False)
context = ssl_config.ssl_context
context = httpx.create_ssl_context(verify=False)
assert context.verify_mode == ssl.VerifyMode.CERT_NONE
assert context.check_hostname is False
def test_load_ssl_context():
ssl_context = ssl.create_default_context()
ssl_config = SSLConfig(verify=ssl_context)
context = httpx.create_ssl_context(verify=ssl_context)
assert ssl_config.ssl_context is ssl_context
assert context is ssl_context
def test_ssl_repr():
ssl = SSLConfig(verify=False)
assert repr(ssl) == "SSLConfig(cert=None, verify=False)"
def test_ssl_eq():
ssl = SSLConfig(verify=False)
assert ssl == SSLConfig(verify=False)
def test_create_ssl_context_with_get_request(server, cert_pem_file):
context = httpx.create_ssl_context(verify=cert_pem_file)
response = httpx.get(server.url, verify=context)
assert response.status_code == 200
def test_limits_repr():
@ -190,22 +173,22 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch): # pragma: noc
with monkeypatch.context() as m:
m.delenv("SSLKEYLOGFILE", raising=False)
ssl_config = SSLConfig(trust_env=True)
context = httpx.create_ssl_context(trust_env=True)
assert ssl_config.ssl_context.keylog_filename is None # type: ignore
assert context.keylog_filename is None # type: ignore
filename = str(tmpdir.join("test.log"))
with monkeypatch.context() as m:
m.setenv("SSLKEYLOGFILE", filename)
ssl_config = SSLConfig(trust_env=True)
context = httpx.create_ssl_context(trust_env=True)
assert ssl_config.ssl_context.keylog_filename == filename # type: ignore
assert context.keylog_filename == filename # type: ignore
ssl_config = SSLConfig(trust_env=False)
context = httpx.create_ssl_context(trust_env=False)
assert ssl_config.ssl_context.keylog_filename is None # type: ignore
assert context.keylog_filename is None # type: ignore
@pytest.mark.parametrize(