Fix Headers.update to correctly handle repeated headers (#2038)

This commit is contained in:
Adrian Garcia Badaracco 2022-01-21 08:35:10 -08:00 committed by GitHub
parent 8dc9b6bd59
commit 321d4aa509
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 2 deletions

View File

@ -964,8 +964,10 @@ class Headers(typing.MutableMapping[str, str]):
def update(self, headers: HeaderTypes = None) -> None: # type: ignore
headers = Headers(headers)
for key, value in headers.raw:
self[key.decode(headers.encoding)] = value.decode(headers.encoding)
for key in headers.keys():
if key in self:
self.pop(key)
self._list.extend(headers._list)
def copy(self) -> "Headers":
return Headers(self, encoding=self.encoding)

View File

@ -10,6 +10,16 @@ def echo_headers(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json=data)
def echo_repeated_headers_multi_items(request: httpx.Request) -> httpx.Response:
data = {"headers": list(request.headers.multi_items())}
return httpx.Response(200, json=data)
def echo_repeated_headers_items(request: httpx.Request) -> httpx.Response:
data = {"headers": list(request.headers.items())}
return httpx.Response(200, json=data)
def test_client_header():
"""
Set a header in the Client.
@ -110,6 +120,35 @@ def test_header_update():
}
def test_header_repeated_items():
url = "http://example.org/echo_headers"
client = httpx.Client(transport=httpx.MockTransport(echo_repeated_headers_items))
response = client.get(url, headers=[("x-header", "1"), ("x-header", "2,3")])
assert response.status_code == 200
echoed_headers = response.json()["headers"]
# as per RFC 7230, the whitespace after a comma is insignificant
# so we split and strip here so that we can do a safe comparison
assert ["x-header", ["1", "2", "3"]] in [
[k, [subv.lstrip() for subv in v.split(",")]] for k, v in echoed_headers
]
def test_header_repeated_multi_items():
url = "http://example.org/echo_headers"
client = httpx.Client(
transport=httpx.MockTransport(echo_repeated_headers_multi_items)
)
response = client.get(url, headers=[("x-header", "1"), ("x-header", "2,3")])
assert response.status_code == 200
echoed_headers = response.json()["headers"]
assert ["x-header", "1"] in echoed_headers
assert ["x-header", "2,3"] in echoed_headers
def test_remove_default_header():
"""
Remove a default header from the Client.