Included create_ssl_context function to create the same context with SSLConfig and serve as API (#996)
* Included create_ssl_context function to create the same context with SSLConfig and serve as API. * Changed create_ssl_context with SSLConfig into the client implementation and tests. * Dropped the __repr__ and __eq__ methods from SSLConfig and removed SSLConfig using from tests * Fixed test issue regarding cert_authority trust of ssl context Co-authored-by: Tom Christie <tom@tomchristie.com> Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
This commit is contained in:
parent
35f09d1394
commit
926a55a84f
@ -600,6 +600,23 @@ Alternatively, you can pass a standard library `ssl.SSLContext`.
|
||||
<Response [200 OK]>
|
||||
```
|
||||
|
||||
We also include a helper function for creating properly configured `SSLContext` instances.
|
||||
|
||||
```python
|
||||
>>> context = httpx.create_ssl_context()
|
||||
```
|
||||
|
||||
The `create_ssl_context` function accepts the same set of SSL configuration arguments
|
||||
(`trust_env`, `verify`, `cert` and `http2` arguments)
|
||||
as `httpx.Client` or `httpx.AsyncClient`
|
||||
|
||||
```python
|
||||
>>> import httpx
|
||||
>>> context = httpx.create_ssl_context(verify="/tmp/client.pem")
|
||||
>>> httpx.get('https://example.org', verify=context)
|
||||
<Response [200 OK]>
|
||||
```
|
||||
|
||||
Or you can also disable the SSL verification entirely, which is _not_ recommended.
|
||||
|
||||
```python
|
||||
|
||||
@ -2,7 +2,7 @@ from .__version__ import __description__, __title__, __version__
|
||||
from ._api import delete, get, head, options, patch, post, put, request, stream
|
||||
from ._auth import Auth, BasicAuth, DigestAuth
|
||||
from ._client import AsyncClient, Client
|
||||
from ._config import PoolLimits, Proxy, Timeout
|
||||
from ._config import PoolLimits, Proxy, Timeout, create_ssl_context
|
||||
from ._exceptions import (
|
||||
CloseError,
|
||||
ConnectError,
|
||||
@ -61,6 +61,7 @@ __all__ = [
|
||||
"PoolLimits",
|
||||
"Proxy",
|
||||
"Timeout",
|
||||
"create_ssl_context",
|
||||
"CloseError",
|
||||
"ConnectError",
|
||||
"ConnectTimeout",
|
||||
|
||||
@ -13,9 +13,9 @@ from ._config import (
|
||||
UNSET,
|
||||
PoolLimits,
|
||||
Proxy,
|
||||
SSLConfig,
|
||||
Timeout,
|
||||
UnsetType,
|
||||
create_ssl_context,
|
||||
)
|
||||
from ._content_streams import ContentStream
|
||||
from ._exceptions import (
|
||||
@ -499,9 +499,7 @@ class Client(BaseClient):
|
||||
if app is not None:
|
||||
return WSGITransport(app=app)
|
||||
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
|
||||
|
||||
return httpcore.SyncConnectionPool(
|
||||
ssl_context=ssl_context,
|
||||
@ -520,9 +518,7 @@ class Client(BaseClient):
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
trust_env: bool = True,
|
||||
) -> httpcore.SyncHTTPTransport:
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
|
||||
|
||||
return httpcore.SyncHTTPProxy(
|
||||
proxy_url=proxy.url.raw,
|
||||
@ -1032,9 +1028,7 @@ class AsyncClient(BaseClient):
|
||||
if app is not None:
|
||||
return ASGITransport(app=app)
|
||||
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
|
||||
|
||||
return httpcore.AsyncConnectionPool(
|
||||
ssl_context=ssl_context,
|
||||
@ -1053,9 +1047,7 @@ class AsyncClient(BaseClient):
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
trust_env: bool = True,
|
||||
) -> httpcore.AsyncHTTPTransport:
|
||||
ssl_context = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env
|
||||
).ssl_context
|
||||
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
|
||||
|
||||
return httpcore.AsyncHTTPProxy(
|
||||
proxy_url=proxy.url.raw,
|
||||
|
||||
@ -40,6 +40,17 @@ class UnsetType:
|
||||
UNSET = UnsetType()
|
||||
|
||||
|
||||
def create_ssl_context(
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
trust_env: bool = None,
|
||||
http2: bool = False,
|
||||
) -> ssl.SSLContext:
|
||||
return SSLConfig(
|
||||
cert=cert, verify=verify, trust_env=trust_env, http2=http2
|
||||
).ssl_context
|
||||
|
||||
|
||||
class SSLConfig:
|
||||
"""
|
||||
SSL Configuration.
|
||||
@ -61,17 +72,6 @@ class SSLConfig:
|
||||
self.http2 = http2
|
||||
self.ssl_context = self.load_ssl_context()
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.cert == other.cert
|
||||
and self.verify == other.verify
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(cert={self.cert}, verify={self.verify})"
|
||||
|
||||
def load_ssl_context(self) -> ssl.SSLContext:
|
||||
logger.trace(
|
||||
f"load_ssl_context "
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import httpcore
|
||||
|
||||
from .._config import SSLConfig
|
||||
from .._config import create_ssl_context
|
||||
from .._content_streams import ByteStream, IteratorStream
|
||||
from .._exceptions import NetworkError, map_exceptions
|
||||
from .._types import CertTypes, VerifyTypes
|
||||
@ -30,12 +30,10 @@ class URLLib3Transport(httpcore.SyncHTTPTransport):
|
||||
urllib3 is not None
|
||||
), "urllib3 must be installed in order to use URLLib3Transport"
|
||||
|
||||
ssl_config = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env, http2=False
|
||||
)
|
||||
|
||||
self.pool = urllib3.PoolManager(
|
||||
ssl_context=ssl_config.ssl_context,
|
||||
ssl_context=create_ssl_context(
|
||||
verify=verify, cert=cert, trust_env=trust_env, http2=False
|
||||
),
|
||||
num_pools=pool_connections,
|
||||
maxsize=pool_maxsize,
|
||||
block=pool_block,
|
||||
@ -139,14 +137,12 @@ class URLLib3ProxyTransport(URLLib3Transport):
|
||||
urllib3 is not None
|
||||
), "urllib3 must be installed in order to use URLLib3ProxyTransport"
|
||||
|
||||
ssl_config = SSLConfig(
|
||||
verify=verify, cert=cert, trust_env=trust_env, http2=False
|
||||
)
|
||||
|
||||
self.pool = urllib3.ProxyManager(
|
||||
proxy_url=proxy_url,
|
||||
proxy_headers=proxy_headers,
|
||||
ssl_context=ssl_config.ssl_context,
|
||||
ssl_context=create_ssl_context(
|
||||
verify=verify, cert=cert, trust_env=trust_env, http2=False
|
||||
),
|
||||
num_pools=pool_connections,
|
||||
maxsize=pool_maxsize,
|
||||
block=pool_block,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@ -8,63 +7,51 @@ import certifi
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._config import SSLConfig
|
||||
|
||||
|
||||
def test_load_ssl_config():
|
||||
ssl_config = SSLConfig()
|
||||
context = ssl_config.ssl_context
|
||||
context = httpx.create_ssl_context()
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
def test_load_ssl_config_verify_non_existing_path():
|
||||
with pytest.raises(IOError):
|
||||
SSLConfig(verify="/path/to/nowhere")
|
||||
httpx.create_ssl_context(verify="/path/to/nowhere")
|
||||
|
||||
|
||||
def test_load_ssl_config_verify_existing_file():
|
||||
ssl_config = SSLConfig(verify=certifi.where())
|
||||
context = ssl_config.ssl_context
|
||||
context = httpx.create_ssl_context(verify=certifi.where())
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", ("SSL_CERT_FILE", "SSL_CERT_DIR"))
|
||||
def test_load_ssl_config_verify_env_file(https_server, ca_cert_pem_file, config):
|
||||
def test_load_ssl_config_verify_env_file(
|
||||
https_server, ca_cert_pem_file, config, cert_authority
|
||||
):
|
||||
os.environ[config] = (
|
||||
ca_cert_pem_file
|
||||
if config.endswith("_FILE")
|
||||
else str(Path(ca_cert_pem_file).parent)
|
||||
)
|
||||
ssl_config = SSLConfig(trust_env=True)
|
||||
context = ssl_config.ssl_context
|
||||
context = httpx.create_ssl_context(trust_env=True)
|
||||
cert_authority.configure_trust(context)
|
||||
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
assert ssl_config.verify == os.environ[config]
|
||||
|
||||
# Skipping 'SSL_CERT_DIR' functional test for now because
|
||||
# we're unable to get the certificate within the directory to
|
||||
# load into the SSLContext. :(
|
||||
if config == "SSL_CERT_FILE":
|
||||
host = https_server.url.host
|
||||
port = https_server.url.port
|
||||
conn = socket.create_connection((host, port))
|
||||
context.wrap_socket(conn, server_hostname=host)
|
||||
assert len(context.get_ca_certs()) == 1
|
||||
assert len(context.get_ca_certs()) == 1
|
||||
|
||||
|
||||
def test_load_ssl_config_verify_directory():
|
||||
path = Path(certifi.where()).parent
|
||||
ssl_config = SSLConfig(verify=str(path))
|
||||
context = ssl_config.ssl_context
|
||||
context = httpx.create_ssl_context(verify=str(path))
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
|
||||
def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
|
||||
ssl_config = SSLConfig(cert=(cert_pem_file, cert_private_key_file))
|
||||
context = ssl_config.ssl_context
|
||||
context = httpx.create_ssl_context(cert=(cert_pem_file, cert_private_key_file))
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
@ -73,10 +60,9 @@ def test_load_ssl_config_cert_and_key(cert_pem_file, cert_private_key_file):
|
||||
def test_load_ssl_config_cert_and_encrypted_key(
|
||||
cert_pem_file, cert_encrypted_private_key_file, password
|
||||
):
|
||||
ssl_config = SSLConfig(
|
||||
context = httpx.create_ssl_context(
|
||||
cert=(cert_pem_file, cert_encrypted_private_key_file, password)
|
||||
)
|
||||
context = ssl_config.ssl_context
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
|
||||
assert context.check_hostname is True
|
||||
|
||||
@ -85,36 +71,33 @@ def test_load_ssl_config_cert_and_key_invalid_password(
|
||||
cert_pem_file, cert_encrypted_private_key_file
|
||||
):
|
||||
with pytest.raises(ssl.SSLError):
|
||||
SSLConfig(cert=(cert_pem_file, cert_encrypted_private_key_file, "password1"))
|
||||
httpx.create_ssl_context(
|
||||
cert=(cert_pem_file, cert_encrypted_private_key_file, "password1")
|
||||
)
|
||||
|
||||
|
||||
def test_load_ssl_config_cert_without_key_raises(cert_pem_file):
|
||||
with pytest.raises(ssl.SSLError):
|
||||
SSLConfig(cert=cert_pem_file)
|
||||
httpx.create_ssl_context(cert=cert_pem_file)
|
||||
|
||||
|
||||
def test_load_ssl_config_no_verify():
|
||||
ssl_config = SSLConfig(verify=False)
|
||||
context = ssl_config.ssl_context
|
||||
context = httpx.create_ssl_context(verify=False)
|
||||
assert context.verify_mode == ssl.VerifyMode.CERT_NONE
|
||||
assert context.check_hostname is False
|
||||
|
||||
|
||||
def test_load_ssl_context():
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_config = SSLConfig(verify=ssl_context)
|
||||
context = httpx.create_ssl_context(verify=ssl_context)
|
||||
|
||||
assert ssl_config.ssl_context is ssl_context
|
||||
assert context is ssl_context
|
||||
|
||||
|
||||
def test_ssl_repr():
|
||||
ssl = SSLConfig(verify=False)
|
||||
assert repr(ssl) == "SSLConfig(cert=None, verify=False)"
|
||||
|
||||
|
||||
def test_ssl_eq():
|
||||
ssl = SSLConfig(verify=False)
|
||||
assert ssl == SSLConfig(verify=False)
|
||||
def test_create_ssl_context_with_get_request(server, cert_pem_file):
|
||||
context = httpx.create_ssl_context(verify=cert_pem_file)
|
||||
response = httpx.get(server.url, verify=context)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_limits_repr():
|
||||
@ -190,22 +173,22 @@ def test_ssl_config_support_for_keylog_file(tmpdir, monkeypatch): # pragma: noc
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("SSLKEYLOGFILE", raising=False)
|
||||
|
||||
ssl_config = SSLConfig(trust_env=True)
|
||||
context = httpx.create_ssl_context(trust_env=True)
|
||||
|
||||
assert ssl_config.ssl_context.keylog_filename is None # type: ignore
|
||||
assert context.keylog_filename is None # type: ignore
|
||||
|
||||
filename = str(tmpdir.join("test.log"))
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("SSLKEYLOGFILE", filename)
|
||||
|
||||
ssl_config = SSLConfig(trust_env=True)
|
||||
context = httpx.create_ssl_context(trust_env=True)
|
||||
|
||||
assert ssl_config.ssl_context.keylog_filename == filename # type: ignore
|
||||
assert context.keylog_filename == filename # type: ignore
|
||||
|
||||
ssl_config = SSLConfig(trust_env=False)
|
||||
context = httpx.create_ssl_context(trust_env=False)
|
||||
|
||||
assert ssl_config.ssl_context.keylog_filename is None # type: ignore
|
||||
assert context.keylog_filename is None # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user