Preserve forwarded client ports in proxy headers middleware (#2903)

Co-authored-by: takeda <411978+takeda@users.noreply.github.com>
This commit is contained in:
Marcelo Trylesinski 2026-04-14 11:56:40 +02:00 committed by GitHub
parent 77843e06dc
commit 18edfa7012
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 105 additions and 12 deletions

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import contextlib
import ipaddress
from typing import TYPE_CHECKING
import httpx
import httpx._transports.asgi
import pytest
import websockets.client
@ -30,6 +31,9 @@ async def default_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISend
client_addr = "NONE" # pragma: no cover
else:
host, port = client
with contextlib.suppress(ValueError):
if ipaddress.ip_address(host).version == 6:
host = f"[{host}]"
client_addr = f"{host}:{port}"
response = Response(f"{scheme}://{client_addr}", media_type="text/plain")
@ -426,6 +430,31 @@ async def test_proxy_headers_multiple_proxies(trusted_hosts: str | list[str], ex
assert response.text == expected
@pytest.mark.anyio
@pytest.mark.parametrize(
("trusted_hosts", "expected"),
[
# always trust
("*", "https://1.2.3.4:1234"),
# all proxies are trusted
(["127.0.0.1", "2001:db8::1", "192.168.0.2"], "https://1.2.3.4:1234"),
# should set first untrusted as remote address
(["192.168.0.2", "127.0.0.1"], "https://[2001:db8::1]:8080"),
# Mixed literals and networks
(["127.0.0.1", "2001:db8::/32", "192.168.0.2"], "https://1.2.3.4:1234"),
],
)
async def test_proxy_headers_multiple_proxies_with_ports(trusted_hosts: str | list[str], expected: str) -> None:
async with make_httpx_client(trusted_hosts) as client:
headers = {
X_FORWARDED_FOR: "1.2.3.4:1234, [2001:db8::1]:8080, 192.168.0.2:9000",
X_FORWARDED_PROTO: "https",
}
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == expected
@pytest.mark.anyio
async def test_proxy_headers_invalid_x_forwarded_for() -> None:
async with make_httpx_client("*") as client:
@ -441,6 +470,38 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:
assert response.text == "https://1.2.3.4:0"
@pytest.mark.anyio
@pytest.mark.parametrize(
("forwarded_for", "expected"),
[
# IPv4 without port
("1.2.3.4", "https://1.2.3.4:0"),
# IPv4 with port
("1.2.3.4:1234", "https://1.2.3.4:1234"),
# Bracketed IPv6 with port
("[2001:db8::1]:443", "https://[2001:db8::1]:443"),
# Bracketed IPv6 without port
("[2001:db8::1]", "https://[2001:db8::1]:0"),
# Bare IPv6 without port
("2001:db8::1", "https://[2001:db8::1]:0"),
# Invalid IPv4 port falls back to the original host value
("1.2.3.4:notaport", "https://1.2.3.4:notaport:0"),
# Invalid bracketed IPv6 port keeps the host and drops the port
("[2001:db8::1]:notaport", "https://[2001:db8::1]:0"),
# Trailing data after a bracketed IPv6 host is left untouched
("[2001:db8::1]extra", "https://[2001:db8::1]extra:0"),
# Malformed bracket is left untouched
("[2001:db8::1", "https://[2001:db8::1:0"),
],
)
async def test_proxy_headers_x_forwarded_for_port_shapes(forwarded_for: str, expected: str) -> None:
async with make_httpx_client("*") as client:
headers = {X_FORWARDED_FOR: forwarded_for, X_FORWARDED_PROTO: "https"}
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == expected
@pytest.mark.anyio
@pytest.mark.parametrize(
"forwarded_proto,expected",

View File

@ -45,16 +45,12 @@ class ProxyHeadersMiddleware:
if b"x-forwarded-for" in headers:
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
host, port = self.trusted_hosts.get_trusted_client_address(x_forwarded_for)
if host:
# If the x-forwarded-for header is empty then host is an empty string.
# Only set the client if we actually got something usable.
# See: https://github.com/Kludex/uvicorn/issues/1068
# We've lost the connecting client's port information by now,
# so only include the host.
port = 0
scope["client"] = (host, port)
return await self.app(scope, receive, send)
@ -64,6 +60,41 @@ def _parse_raw_hosts(value: str) -> list[str]:
return [item.strip() for item in value.split(",")]
def _parse_host_port(value: str) -> tuple[str, int]:
"""Parse a forwarded host value into host and optional port.
Accepts bare IPs, IPv4 `host:port`, and bracketed IPv6 `[host]:port`.
Any unrecognized or malformed value is treated conservatively and returned
without a port so trust checks do not silently normalize arbitrary input.
"""
if value.startswith("["):
bracket_end = value.find("]")
if bracket_end == -1:
return value, 0
host = value[1:bracket_end]
remainder = value[bracket_end + 1 :]
if not remainder:
return host, 0
if not remainder.startswith(":"):
return value, 0
try:
return host, int(remainder[1:])
except ValueError:
return host, 0
if value.count(":") == 1:
host, port = value.rsplit(":", 1)
try:
return host, int(port)
except ValueError:
return value, 0
return value, 0
class _TrustedHosts:
"""Container for trusted hosts and networks"""
@ -122,21 +153,22 @@ class _TrustedHosts:
except ValueError:
return host in self.trusted_literals
def get_trusted_client_host(self, x_forwarded_for: str) -> str:
"""Extract the client host from x_forwarded_for header
def get_trusted_client_address(self, x_forwarded_for: str) -> tuple[str, int]:
"""Extract the client address from x_forwarded_for header.
In general this is the first "untrusted" host in the forwarded for list.
"""
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
if self.always_trust:
return x_forwarded_for_hosts[0]
return _parse_host_port(x_forwarded_for_hosts[0])
# Note: each proxy appends to the header list so check it in reverse order
for host in reversed(x_forwarded_for_hosts):
for host_port in reversed(x_forwarded_for_hosts):
host, port = _parse_host_port(host_port)
if host not in self:
return host
return host, port
# All hosts are trusted meaning that the client was also a trusted proxy
# See https://github.com/Kludex/uvicorn/issues/1068#issuecomment-855371576
return x_forwarded_for_hosts[0]
return _parse_host_port(x_forwarded_for_hosts[0])