Preserve forwarded client ports in proxy headers middleware (#2903)
Co-authored-by: takeda <411978+takeda@users.noreply.github.com>
This commit is contained in:
parent
77843e06dc
commit
18edfa7012
@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import ipaddress
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
import httpx._transports.asgi
|
||||
import pytest
|
||||
import websockets.client
|
||||
|
||||
@ -30,6 +31,9 @@ async def default_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISend
|
||||
client_addr = "NONE" # pragma: no cover
|
||||
else:
|
||||
host, port = client
|
||||
with contextlib.suppress(ValueError):
|
||||
if ipaddress.ip_address(host).version == 6:
|
||||
host = f"[{host}]"
|
||||
client_addr = f"{host}:{port}"
|
||||
|
||||
response = Response(f"{scheme}://{client_addr}", media_type="text/plain")
|
||||
@ -426,6 +430,31 @@ async def test_proxy_headers_multiple_proxies(trusted_hosts: str | list[str], ex
|
||||
assert response.text == expected
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
("trusted_hosts", "expected"),
|
||||
[
|
||||
# always trust
|
||||
("*", "https://1.2.3.4:1234"),
|
||||
# all proxies are trusted
|
||||
(["127.0.0.1", "2001:db8::1", "192.168.0.2"], "https://1.2.3.4:1234"),
|
||||
# should set first untrusted as remote address
|
||||
(["192.168.0.2", "127.0.0.1"], "https://[2001:db8::1]:8080"),
|
||||
# Mixed literals and networks
|
||||
(["127.0.0.1", "2001:db8::/32", "192.168.0.2"], "https://1.2.3.4:1234"),
|
||||
],
|
||||
)
|
||||
async def test_proxy_headers_multiple_proxies_with_ports(trusted_hosts: str | list[str], expected: str) -> None:
|
||||
async with make_httpx_client(trusted_hosts) as client:
|
||||
headers = {
|
||||
X_FORWARDED_FOR: "1.2.3.4:1234, [2001:db8::1]:8080, 192.168.0.2:9000",
|
||||
X_FORWARDED_PROTO: "https",
|
||||
}
|
||||
response = await client.get("/", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.text == expected
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_proxy_headers_invalid_x_forwarded_for() -> None:
|
||||
async with make_httpx_client("*") as client:
|
||||
@ -441,6 +470,38 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:
|
||||
assert response.text == "https://1.2.3.4:0"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
("forwarded_for", "expected"),
|
||||
[
|
||||
# IPv4 without port
|
||||
("1.2.3.4", "https://1.2.3.4:0"),
|
||||
# IPv4 with port
|
||||
("1.2.3.4:1234", "https://1.2.3.4:1234"),
|
||||
# Bracketed IPv6 with port
|
||||
("[2001:db8::1]:443", "https://[2001:db8::1]:443"),
|
||||
# Bracketed IPv6 without port
|
||||
("[2001:db8::1]", "https://[2001:db8::1]:0"),
|
||||
# Bare IPv6 without port
|
||||
("2001:db8::1", "https://[2001:db8::1]:0"),
|
||||
# Invalid IPv4 port falls back to the original host value
|
||||
("1.2.3.4:notaport", "https://1.2.3.4:notaport:0"),
|
||||
# Invalid bracketed IPv6 port keeps the host and drops the port
|
||||
("[2001:db8::1]:notaport", "https://[2001:db8::1]:0"),
|
||||
# Trailing data after a bracketed IPv6 host is left untouched
|
||||
("[2001:db8::1]extra", "https://[2001:db8::1]extra:0"),
|
||||
# Malformed bracket is left untouched
|
||||
("[2001:db8::1", "https://[2001:db8::1:0"),
|
||||
],
|
||||
)
|
||||
async def test_proxy_headers_x_forwarded_for_port_shapes(forwarded_for: str, expected: str) -> None:
|
||||
async with make_httpx_client("*") as client:
|
||||
headers = {X_FORWARDED_FOR: forwarded_for, X_FORWARDED_PROTO: "https"}
|
||||
response = await client.get("/", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.text == expected
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
"forwarded_proto,expected",
|
||||
|
||||
@ -45,16 +45,12 @@ class ProxyHeadersMiddleware:
|
||||
|
||||
if b"x-forwarded-for" in headers:
|
||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
||||
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
|
||||
host, port = self.trusted_hosts.get_trusted_client_address(x_forwarded_for)
|
||||
|
||||
if host:
|
||||
# If the x-forwarded-for header is empty then host is an empty string.
|
||||
# Only set the client if we actually got something usable.
|
||||
# See: https://github.com/Kludex/uvicorn/issues/1068
|
||||
|
||||
# We've lost the connecting client's port information by now,
|
||||
# so only include the host.
|
||||
port = 0
|
||||
scope["client"] = (host, port)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
@ -64,6 +60,41 @@ def _parse_raw_hosts(value: str) -> list[str]:
|
||||
return [item.strip() for item in value.split(",")]
|
||||
|
||||
|
||||
def _parse_host_port(value: str) -> tuple[str, int]:
|
||||
"""Parse a forwarded host value into host and optional port.
|
||||
|
||||
Accepts bare IPs, IPv4 `host:port`, and bracketed IPv6 `[host]:port`.
|
||||
Any unrecognized or malformed value is treated conservatively and returned
|
||||
without a port so trust checks do not silently normalize arbitrary input.
|
||||
"""
|
||||
|
||||
if value.startswith("["):
|
||||
bracket_end = value.find("]")
|
||||
if bracket_end == -1:
|
||||
return value, 0
|
||||
|
||||
host = value[1:bracket_end]
|
||||
remainder = value[bracket_end + 1 :]
|
||||
if not remainder:
|
||||
return host, 0
|
||||
if not remainder.startswith(":"):
|
||||
return value, 0
|
||||
|
||||
try:
|
||||
return host, int(remainder[1:])
|
||||
except ValueError:
|
||||
return host, 0
|
||||
|
||||
if value.count(":") == 1:
|
||||
host, port = value.rsplit(":", 1)
|
||||
try:
|
||||
return host, int(port)
|
||||
except ValueError:
|
||||
return value, 0
|
||||
|
||||
return value, 0
|
||||
|
||||
|
||||
class _TrustedHosts:
|
||||
"""Container for trusted hosts and networks"""
|
||||
|
||||
@ -122,21 +153,22 @@ class _TrustedHosts:
|
||||
except ValueError:
|
||||
return host in self.trusted_literals
|
||||
|
||||
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
|
||||
"""Extract the client host from x_forwarded_for header
|
||||
def get_trusted_client_address(self, x_forwarded_for: str) -> tuple[str, int]:
|
||||
"""Extract the client address from x_forwarded_for header.
|
||||
|
||||
In general this is the first "untrusted" host in the forwarded for list.
|
||||
"""
|
||||
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
|
||||
|
||||
if self.always_trust:
|
||||
return x_forwarded_for_hosts[0]
|
||||
return _parse_host_port(x_forwarded_for_hosts[0])
|
||||
|
||||
# Note: each proxy appends to the header list so check it in reverse order
|
||||
for host in reversed(x_forwarded_for_hosts):
|
||||
for host_port in reversed(x_forwarded_for_hosts):
|
||||
host, port = _parse_host_port(host_port)
|
||||
if host not in self:
|
||||
return host
|
||||
return host, port
|
||||
|
||||
# All hosts are trusted meaning that the client was also a trusted proxy
|
||||
# See https://github.com/Kludex/uvicorn/issues/1068#issuecomment-855371576
|
||||
return x_forwarded_for_hosts[0]
|
||||
return _parse_host_port(x_forwarded_for_hosts[0])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user