Move utility functions from _utils.py to _client.py (#3389)

This commit is contained in:
RafaelWO 2024-11-15 12:42:52 +01:00 committed by GitHub
parent b47d94c904
commit 7b19cd5f4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 92 additions and 100 deletions

View File

@ -45,12 +45,7 @@ from ._types import (
TimeoutTypes,
)
from ._urls import URL, QueryParams
from ._utils import (
URLPattern,
get_environment_proxies,
is_https_redirect,
same_origin,
)
from ._utils import URLPattern, get_environment_proxies
if typing.TYPE_CHECKING:
import ssl # pragma: no cover
@ -63,6 +58,38 @@ T = typing.TypeVar("T", bound="Client")
U = typing.TypeVar("U", bound="AsyncClient")
def _is_https_redirect(url: URL, location: URL) -> bool:
"""
Return 'True' if 'location' is a HTTPS upgrade of 'url'
"""
if url.host != location.host:
return False
return (
url.scheme == "http"
and _port_or_default(url) == 80
and location.scheme == "https"
and _port_or_default(location) == 443
)
def _port_or_default(url: URL) -> int | None:
if url.port is not None:
return url.port
return {"http": 80, "https": 443}.get(url.scheme)
def _same_origin(url: URL, other: URL) -> bool:
"""
Return 'True' if the given URLs share the same origin.
"""
return (
url.scheme == other.scheme
and url.host == other.host
and _port_or_default(url) == _port_or_default(other)
)
class UseClientDefault:
"""
For some parameters such as `auth=...` and `timeout=...` we need to be able
@ -521,8 +548,8 @@ class BaseClient:
"""
headers = Headers(request.headers)
if not same_origin(url, request.url):
if not is_https_redirect(request.url, url):
if not _same_origin(url, request.url):
if not _is_https_redirect(request.url, url):
# Strip Authorization headers when responses are redirected
# away from the origin. (Except for direct HTTP to HTTPS redirects.)
headers.pop("Authorization", None)

View File

@ -27,38 +27,6 @@ def primitive_value_to_str(value: PrimitiveData) -> str:
return str(value)
def port_or_default(url: URL) -> int | None:
if url.port is not None:
return url.port
return {"http": 80, "https": 443}.get(url.scheme)
def same_origin(url: URL, other: URL) -> bool:
"""
Return 'True' if the given URLs share the same origin.
"""
return (
url.scheme == other.scheme
and url.host == other.host
and port_or_default(url) == port_or_default(other)
)
def is_https_redirect(url: URL, location: URL) -> bool:
"""
Return 'True' if 'location' is a HTTPS upgrade of 'url'
"""
if url.host != location.host:
return False
return (
url.scheme == "http"
and port_or_default(url) == 80
and location.scheme == "https"
and port_or_default(location) == 443
)
def get_environment_proxies() -> dict[str, str | None]:
"""Gets proxy information from the environment"""

View File

@ -235,3 +235,59 @@ def test_host_with_non_default_port_in_url():
def test_request_auto_headers():
request = httpx.Request("GET", "https://www.example.org/")
assert "host" in request.headers
def test_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443")
client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")
assert headers["Host"] == request.url.netloc.decode("ascii")
def test_not_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80")
client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")
assert headers["Host"] == origin.netloc.decode("ascii")
def test_is_https_redirect():
url = httpx.URL("https://example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)
client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")
assert "Authorization" in headers
def test_is_not_https_redirect():
url = httpx.URL("https://www.example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)
client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")
assert "Authorization" not in headers
def test_is_not_https_redirect_if_not_default_ports():
url = httpx.URL("https://example.com:1337")
request = httpx.Request(
"GET", "http://example.com:9999", headers={"Authorization": "empty"}
)
client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")
assert "Authorization" not in headers

View File

@ -6,10 +6,7 @@ import random
import pytest
import httpx
from httpx._utils import (
URLPattern,
get_environment_proxies,
)
from httpx._utils import URLPattern, get_environment_proxies
@pytest.mark.parametrize(
@ -115,62 +112,6 @@ def test_get_environment_proxies(environment, proxies):
assert get_environment_proxies() == proxies
def test_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443")
client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")
assert headers["Host"] == request.url.netloc.decode("ascii")
def test_not_same_origin():
origin = httpx.URL("https://example.com")
request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80")
client = httpx.Client()
headers = client._redirect_headers(request, origin, "GET")
assert headers["Host"] == origin.netloc.decode("ascii")
def test_is_https_redirect():
url = httpx.URL("https://example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)
client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")
assert "Authorization" in headers
def test_is_not_https_redirect():
url = httpx.URL("https://www.example.com")
request = httpx.Request(
"GET", "http://example.com", headers={"Authorization": "empty"}
)
client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")
assert "Authorization" not in headers
def test_is_not_https_redirect_if_not_default_ports():
url = httpx.URL("https://example.com:1337")
request = httpx.Request(
"GET", "http://example.com:9999", headers={"Authorization": "empty"}
)
client = httpx.Client()
headers = client._redirect_headers(request, url, "GET")
assert "Authorization" not in headers
@pytest.mark.parametrize(
["pattern", "url", "expected"],
[