Move HSTS preload checking to client (#184)
This commit is contained in:
parent
919e8d3f9b
commit
9142a893ff
@ -2,6 +2,8 @@ import inspect
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
import hstspreload
|
||||
|
||||
from .auth import HTTPBasicAuth
|
||||
from .concurrency import AsyncioBackend
|
||||
from .config import (
|
||||
@ -105,6 +107,12 @@ class BaseClient:
|
||||
self.concurrency_backend = backend
|
||||
self.trust_env = trust_env
|
||||
|
||||
def merge_url(self, url: URLTypes) -> URL:
|
||||
url = self.base_url.join(relative_url=url)
|
||||
if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
|
||||
url = url.copy_with(scheme="https")
|
||||
return url
|
||||
|
||||
def merge_cookies(
|
||||
self, cookies: CookieTypes = None
|
||||
) -> typing.Optional[CookieTypes]:
|
||||
@ -564,7 +572,7 @@ class AsyncClient(BaseClient):
|
||||
timeout: TimeoutTypes = None,
|
||||
trust_env: bool = True,
|
||||
) -> AsyncResponse:
|
||||
url = self.base_url.join(url)
|
||||
url = self.merge_url(url)
|
||||
headers = self.merge_headers(headers)
|
||||
cookies = self.merge_cookies(cookies)
|
||||
request = AsyncRequest(
|
||||
@ -648,7 +656,7 @@ class Client(BaseClient):
|
||||
timeout: TimeoutTypes = None,
|
||||
trust_env: bool = True,
|
||||
) -> Response:
|
||||
url = self.base_url.join(url)
|
||||
url = self.merge_url(url)
|
||||
headers = self.merge_headers(headers)
|
||||
cookies = self.merge_cookies(cookies)
|
||||
request = AsyncRequest(
|
||||
|
||||
@ -8,7 +8,6 @@ from http.cookiejar import Cookie, CookieJar
|
||||
from urllib.parse import parse_qsl, urlencode
|
||||
|
||||
import chardet
|
||||
import hstspreload
|
||||
import rfc3986
|
||||
|
||||
from .config import USER_AGENT
|
||||
@ -109,14 +108,6 @@ class URL:
|
||||
if not self.host:
|
||||
raise InvalidURL("No host included in URL.")
|
||||
|
||||
# If the URL is HTTP but the host is on the HSTS preload list switch to HTTPS.
|
||||
if (
|
||||
self.scheme == "http"
|
||||
and self.host
|
||||
and hstspreload.in_hsts_preload(self.host)
|
||||
):
|
||||
self._uri_reference = self._uri_reference.copy_with(scheme="https")
|
||||
|
||||
@property
|
||||
def scheme(self) -> str:
|
||||
return self._uri_reference.scheme or ""
|
||||
|
||||
@ -150,3 +150,11 @@ def test_base_url(server):
|
||||
response = http.get("/")
|
||||
assert response.status_code == 200
|
||||
assert str(response.url) == base_url
|
||||
|
||||
|
||||
def test_merge_url():
|
||||
client = httpx.Client(base_url="https://www.paypal.com/")
|
||||
url = client.merge_url("http://www.paypal.com")
|
||||
|
||||
assert url.scheme == "https"
|
||||
assert url.is_ssl
|
||||
|
||||
@ -176,11 +176,3 @@ def test_url_set():
|
||||
url_set = set(urls)
|
||||
|
||||
assert all(url in urls for url in url_set)
|
||||
|
||||
|
||||
def test_hsts_preload_converted_to_https():
|
||||
url = URL("http://www.paypal.com")
|
||||
|
||||
assert url.is_ssl
|
||||
assert url.scheme == "https"
|
||||
assert url == "https://www.paypal.com"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user