Move utility functions from _utils.py to _client.py (#3389)
This commit is contained in:
parent
b47d94c904
commit
7b19cd5f4b
@ -45,12 +45,7 @@ from ._types import (
|
||||
TimeoutTypes,
|
||||
)
|
||||
from ._urls import URL, QueryParams
|
||||
from ._utils import (
|
||||
URLPattern,
|
||||
get_environment_proxies,
|
||||
is_https_redirect,
|
||||
same_origin,
|
||||
)
|
||||
from ._utils import URLPattern, get_environment_proxies
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl # pragma: no cover
|
||||
@ -63,6 +58,38 @@ T = typing.TypeVar("T", bound="Client")
|
||||
U = typing.TypeVar("U", bound="AsyncClient")
|
||||
|
||||
|
||||
def _is_https_redirect(url: URL, location: URL) -> bool:
|
||||
"""
|
||||
Return 'True' if 'location' is a HTTPS upgrade of 'url'
|
||||
"""
|
||||
if url.host != location.host:
|
||||
return False
|
||||
|
||||
return (
|
||||
url.scheme == "http"
|
||||
and _port_or_default(url) == 80
|
||||
and location.scheme == "https"
|
||||
and _port_or_default(location) == 443
|
||||
)
|
||||
|
||||
|
||||
def _port_or_default(url: URL) -> int | None:
|
||||
if url.port is not None:
|
||||
return url.port
|
||||
return {"http": 80, "https": 443}.get(url.scheme)
|
||||
|
||||
|
||||
def _same_origin(url: URL, other: URL) -> bool:
|
||||
"""
|
||||
Return 'True' if the given URLs share the same origin.
|
||||
"""
|
||||
return (
|
||||
url.scheme == other.scheme
|
||||
and url.host == other.host
|
||||
and _port_or_default(url) == _port_or_default(other)
|
||||
)
|
||||
|
||||
|
||||
class UseClientDefault:
|
||||
"""
|
||||
For some parameters such as `auth=...` and `timeout=...` we need to be able
|
||||
@ -521,8 +548,8 @@ class BaseClient:
|
||||
"""
|
||||
headers = Headers(request.headers)
|
||||
|
||||
if not same_origin(url, request.url):
|
||||
if not is_https_redirect(request.url, url):
|
||||
if not _same_origin(url, request.url):
|
||||
if not _is_https_redirect(request.url, url):
|
||||
# Strip Authorization headers when responses are redirected
|
||||
# away from the origin. (Except for direct HTTP to HTTPS redirects.)
|
||||
headers.pop("Authorization", None)
|
||||
|
||||
@ -27,38 +27,6 @@ def primitive_value_to_str(value: PrimitiveData) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
def port_or_default(url: URL) -> int | None:
|
||||
if url.port is not None:
|
||||
return url.port
|
||||
return {"http": 80, "https": 443}.get(url.scheme)
|
||||
|
||||
|
||||
def same_origin(url: URL, other: URL) -> bool:
|
||||
"""
|
||||
Return 'True' if the given URLs share the same origin.
|
||||
"""
|
||||
return (
|
||||
url.scheme == other.scheme
|
||||
and url.host == other.host
|
||||
and port_or_default(url) == port_or_default(other)
|
||||
)
|
||||
|
||||
|
||||
def is_https_redirect(url: URL, location: URL) -> bool:
|
||||
"""
|
||||
Return 'True' if 'location' is a HTTPS upgrade of 'url'
|
||||
"""
|
||||
if url.host != location.host:
|
||||
return False
|
||||
|
||||
return (
|
||||
url.scheme == "http"
|
||||
and port_or_default(url) == 80
|
||||
and location.scheme == "https"
|
||||
and port_or_default(location) == 443
|
||||
)
|
||||
|
||||
|
||||
def get_environment_proxies() -> dict[str, str | None]:
|
||||
"""Gets proxy information from the environment"""
|
||||
|
||||
|
||||
@ -235,3 +235,59 @@ def test_host_with_non_default_port_in_url():
|
||||
def test_request_auto_headers():
|
||||
request = httpx.Request("GET", "https://www.example.org/")
|
||||
assert "host" in request.headers
|
||||
|
||||
|
||||
def test_same_origin():
|
||||
origin = httpx.URL("https://example.com")
|
||||
request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443")
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, origin, "GET")
|
||||
|
||||
assert headers["Host"] == request.url.netloc.decode("ascii")
|
||||
|
||||
|
||||
def test_not_same_origin():
|
||||
origin = httpx.URL("https://example.com")
|
||||
request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80")
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, origin, "GET")
|
||||
|
||||
assert headers["Host"] == origin.netloc.decode("ascii")
|
||||
|
||||
|
||||
def test_is_https_redirect():
|
||||
url = httpx.URL("https://example.com")
|
||||
request = httpx.Request(
|
||||
"GET", "http://example.com", headers={"Authorization": "empty"}
|
||||
)
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, url, "GET")
|
||||
|
||||
assert "Authorization" in headers
|
||||
|
||||
|
||||
def test_is_not_https_redirect():
|
||||
url = httpx.URL("https://www.example.com")
|
||||
request = httpx.Request(
|
||||
"GET", "http://example.com", headers={"Authorization": "empty"}
|
||||
)
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, url, "GET")
|
||||
|
||||
assert "Authorization" not in headers
|
||||
|
||||
|
||||
def test_is_not_https_redirect_if_not_default_ports():
|
||||
url = httpx.URL("https://example.com:1337")
|
||||
request = httpx.Request(
|
||||
"GET", "http://example.com:9999", headers={"Authorization": "empty"}
|
||||
)
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, url, "GET")
|
||||
|
||||
assert "Authorization" not in headers
|
||||
|
||||
@ -6,10 +6,7 @@ import random
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
from httpx._utils import (
|
||||
URLPattern,
|
||||
get_environment_proxies,
|
||||
)
|
||||
from httpx._utils import URLPattern, get_environment_proxies
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -115,62 +112,6 @@ def test_get_environment_proxies(environment, proxies):
|
||||
assert get_environment_proxies() == proxies
|
||||
|
||||
|
||||
def test_same_origin():
|
||||
origin = httpx.URL("https://example.com")
|
||||
request = httpx.Request("GET", "HTTPS://EXAMPLE.COM:443")
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, origin, "GET")
|
||||
|
||||
assert headers["Host"] == request.url.netloc.decode("ascii")
|
||||
|
||||
|
||||
def test_not_same_origin():
|
||||
origin = httpx.URL("https://example.com")
|
||||
request = httpx.Request("GET", "HTTP://EXAMPLE.COM:80")
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, origin, "GET")
|
||||
|
||||
assert headers["Host"] == origin.netloc.decode("ascii")
|
||||
|
||||
|
||||
def test_is_https_redirect():
|
||||
url = httpx.URL("https://example.com")
|
||||
request = httpx.Request(
|
||||
"GET", "http://example.com", headers={"Authorization": "empty"}
|
||||
)
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, url, "GET")
|
||||
|
||||
assert "Authorization" in headers
|
||||
|
||||
|
||||
def test_is_not_https_redirect():
|
||||
url = httpx.URL("https://www.example.com")
|
||||
request = httpx.Request(
|
||||
"GET", "http://example.com", headers={"Authorization": "empty"}
|
||||
)
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, url, "GET")
|
||||
|
||||
assert "Authorization" not in headers
|
||||
|
||||
|
||||
def test_is_not_https_redirect_if_not_default_ports():
|
||||
url = httpx.URL("https://example.com:1337")
|
||||
request = httpx.Request(
|
||||
"GET", "http://example.com:9999", headers={"Authorization": "empty"}
|
||||
)
|
||||
|
||||
client = httpx.Client()
|
||||
headers = client._redirect_headers(request, url, "GET")
|
||||
|
||||
assert "Authorization" not in headers
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["pattern", "url", "expected"],
|
||||
[
|
||||
|
||||
Loading…
Reference in New Issue
Block a user