Move HSTS preload checking to client (#184)

This commit is contained in:
Seth Michael Larson 2019-08-01 04:26:45 -05:00 committed by Tom Christie
parent 919e8d3f9b
commit 9142a893ff
4 changed files with 18 additions and 19 deletions

View File

@ -2,6 +2,8 @@ import inspect
import typing
from types import TracebackType
import hstspreload
from .auth import HTTPBasicAuth
from .concurrency import AsyncioBackend
from .config import (
@ -105,6 +107,12 @@ class BaseClient:
self.concurrency_backend = backend
self.trust_env = trust_env
def merge_url(self, url: URLTypes) -> URL:
url = self.base_url.join(relative_url=url)
if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
url = url.copy_with(scheme="https")
return url
def merge_cookies(
self, cookies: CookieTypes = None
) -> typing.Optional[CookieTypes]:
@ -564,7 +572,7 @@ class AsyncClient(BaseClient):
timeout: TimeoutTypes = None,
trust_env: bool = True,
) -> AsyncResponse:
url = self.base_url.join(url)
url = self.merge_url(url)
headers = self.merge_headers(headers)
cookies = self.merge_cookies(cookies)
request = AsyncRequest(
@ -648,7 +656,7 @@ class Client(BaseClient):
timeout: TimeoutTypes = None,
trust_env: bool = True,
) -> Response:
url = self.base_url.join(url)
url = self.merge_url(url)
headers = self.merge_headers(headers)
cookies = self.merge_cookies(cookies)
request = AsyncRequest(

View File

@ -8,7 +8,6 @@ from http.cookiejar import Cookie, CookieJar
from urllib.parse import parse_qsl, urlencode
import chardet
import hstspreload
import rfc3986
from .config import USER_AGENT
@ -109,14 +108,6 @@ class URL:
if not self.host:
raise InvalidURL("No host included in URL.")
# If the URL is HTTP but the host is on the HSTS preload list switch to HTTPS.
if (
self.scheme == "http"
and self.host
and hstspreload.in_hsts_preload(self.host)
):
self._uri_reference = self._uri_reference.copy_with(scheme="https")
@property
def scheme(self) -> str:
return self._uri_reference.scheme or ""

View File

@ -150,3 +150,11 @@ def test_base_url(server):
response = http.get("/")
assert response.status_code == 200
assert str(response.url) == base_url
def test_merge_url():
client = httpx.Client(base_url="https://www.paypal.com/")
url = client.merge_url("http://www.paypal.com")
assert url.scheme == "https"
assert url.is_ssl

View File

@ -176,11 +176,3 @@ def test_url_set():
url_set = set(urls)
assert all(url in urls for url in url_set)
def test_hsts_preload_converted_to_https():
url = URL("http://www.paypal.com")
assert url.is_ssl
assert url.scheme == "https"
assert url == "https://www.paypal.com"