Preserve header casing (#1338)
This commit is contained in:
parent
0eed6a3734
commit
c725387b2d
105
httpx/_models.py
105
httpx/_models.py
@ -525,27 +525,28 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
|
||||
def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None:
|
||||
if headers is None:
|
||||
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes, bytes]]
|
||||
elif isinstance(headers, Headers):
|
||||
self._list = list(headers.raw)
|
||||
self._list = list(headers._list)
|
||||
elif isinstance(headers, dict):
|
||||
self._list = [
|
||||
(normalize_header_key(k, encoding), normalize_header_value(v, encoding))
|
||||
(
|
||||
normalize_header_key(k, lower=False, encoding=encoding),
|
||||
normalize_header_key(k, lower=True, encoding=encoding),
|
||||
normalize_header_value(v, encoding),
|
||||
)
|
||||
for k, v in headers.items()
|
||||
]
|
||||
else:
|
||||
self._list = [
|
||||
(normalize_header_key(k, encoding), normalize_header_value(v, encoding))
|
||||
(
|
||||
normalize_header_key(k, lower=False, encoding=encoding),
|
||||
normalize_header_key(k, lower=True, encoding=encoding),
|
||||
normalize_header_value(v, encoding),
|
||||
)
|
||||
for k, v in headers
|
||||
]
|
||||
|
||||
self._dict = {} # type: typing.Dict[bytes, bytes]
|
||||
for key, value in self._list:
|
||||
if key in self._dict:
|
||||
self._dict[key] = self._dict[key] + b", " + value
|
||||
else:
|
||||
self._dict[key] = value
|
||||
|
||||
self._encoding = encoding
|
||||
|
||||
@property
|
||||
@ -582,25 +583,36 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
"""
|
||||
Returns a list of the raw header items, as byte pairs.
|
||||
"""
|
||||
return list(self._list)
|
||||
return [(raw_key, value) for raw_key, _, value in self._list]
|
||||
|
||||
def keys(self) -> typing.KeysView[str]:
|
||||
return {key.decode(self.encoding): None for key in self._dict.keys()}.keys()
|
||||
return {key.decode(self.encoding): None for _, key, value in self._list}.keys()
|
||||
|
||||
def values(self) -> typing.ValuesView[str]:
|
||||
return {
|
||||
key: value.decode(self.encoding) for key, value in self._dict.items()
|
||||
}.values()
|
||||
values_dict: typing.Dict[str, str] = {}
|
||||
for _, key, value in self._list:
|
||||
str_key = key.decode(self.encoding)
|
||||
str_value = value.decode(self.encoding)
|
||||
if str_key in values_dict:
|
||||
values_dict[str_key] += f", {str_value}"
|
||||
else:
|
||||
values_dict[str_key] = str_value
|
||||
return values_dict.values()
|
||||
|
||||
def items(self) -> typing.ItemsView[str, str]:
|
||||
"""
|
||||
Return `(key, value)` items of headers. Concatenate headers
|
||||
into a single comma seperated value when a key occurs multiple times.
|
||||
"""
|
||||
return {
|
||||
key.decode(self.encoding): value.decode(self.encoding)
|
||||
for key, value in self._dict.items()
|
||||
}.items()
|
||||
values_dict: typing.Dict[str, str] = {}
|
||||
for _, key, value in self._list:
|
||||
str_key = key.decode(self.encoding)
|
||||
str_value = value.decode(self.encoding)
|
||||
if str_key in values_dict:
|
||||
values_dict[str_key] += f", {str_value}"
|
||||
else:
|
||||
values_dict[str_key] = str_value
|
||||
return values_dict.items()
|
||||
|
||||
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
|
||||
"""
|
||||
@ -610,7 +622,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
"""
|
||||
return [
|
||||
(key.decode(self.encoding), value.decode(self.encoding))
|
||||
for key, value in self._list
|
||||
for _, key, value in self._list
|
||||
]
|
||||
|
||||
def get(self, key: str, default: typing.Any = None) -> typing.Any:
|
||||
@ -633,8 +645,8 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
|
||||
values = [
|
||||
item_value.decode(self.encoding)
|
||||
for item_key, item_value in self._list
|
||||
if item_key == get_header_key
|
||||
for _, item_key, item_value in self._list
|
||||
if item_key.lower() == get_header_key
|
||||
]
|
||||
|
||||
if not split_commas:
|
||||
@ -647,11 +659,11 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
|
||||
def update(self, headers: HeaderTypes = None) -> None: # type: ignore
|
||||
headers = Headers(headers)
|
||||
for header in headers:
|
||||
self[header] = headers[header]
|
||||
for key, value in headers.raw:
|
||||
self[key.decode(headers.encoding)] = value.decode(headers.encoding)
|
||||
|
||||
def copy(self) -> "Headers":
|
||||
return Headers(dict(self.items()), encoding=self.encoding)
|
||||
return Headers(self, encoding=self.encoding)
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
"""
|
||||
@ -663,7 +675,7 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
normalized_key = key.lower().encode(self.encoding)
|
||||
|
||||
items = []
|
||||
for header_key, header_value in self._list:
|
||||
for _, header_key, header_value in self._list:
|
||||
if header_key == normalized_key:
|
||||
items.append(header_value.decode(self.encoding))
|
||||
|
||||
@ -677,14 +689,13 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
Set the header `key` to `value`, removing any duplicate entries.
|
||||
Retains insertion order.
|
||||
"""
|
||||
set_key = key.lower().encode(self._encoding or "utf-8")
|
||||
set_key = key.encode(self._encoding or "utf-8")
|
||||
set_value = value.encode(self._encoding or "utf-8")
|
||||
|
||||
self._dict[set_key] = set_value
|
||||
lookup_key = set_key.lower()
|
||||
|
||||
found_indexes = []
|
||||
for idx, (item_key, _) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
for idx, (_, item_key, _) in enumerate(self._list):
|
||||
if item_key == lookup_key:
|
||||
found_indexes.append(idx)
|
||||
|
||||
for idx in reversed(found_indexes[1:]):
|
||||
@ -692,9 +703,9 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
|
||||
if found_indexes:
|
||||
idx = found_indexes[0]
|
||||
self._list[idx] = (set_key, set_value)
|
||||
self._list[idx] = (set_key, lookup_key, set_value)
|
||||
else:
|
||||
self._list.append((set_key, set_value))
|
||||
self._list.append((set_key, lookup_key, set_value))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""
|
||||
@ -702,19 +713,20 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
"""
|
||||
del_key = key.lower().encode(self.encoding)
|
||||
|
||||
del self._dict[del_key]
|
||||
|
||||
pop_indexes = []
|
||||
for idx, (item_key, _) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
for idx, (_, item_key, _) in enumerate(self._list):
|
||||
if item_key.lower() == del_key:
|
||||
pop_indexes.append(idx)
|
||||
|
||||
if not pop_indexes:
|
||||
raise KeyError(key)
|
||||
|
||||
for idx in reversed(pop_indexes):
|
||||
del self._list[idx]
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
header_key = key.lower().encode(self.encoding)
|
||||
return header_key in self._dict
|
||||
return header_key in [key for _, key, _ in self._list]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self.keys())
|
||||
@ -727,7 +739,10 @@ class Headers(typing.MutableMapping[str, str]):
|
||||
other_headers = Headers(other)
|
||||
except ValueError:
|
||||
return False
|
||||
return sorted(self._list) == sorted(other_headers._list)
|
||||
|
||||
self_list = [(key, value) for _, key, value in self._list]
|
||||
other_list = [(key, value) for _, key, value in other_headers._list]
|
||||
return sorted(self_list) == sorted(other_list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@ -793,15 +808,15 @@ class Request:
|
||||
def _prepare(self, default_headers: typing.Dict[str, str]) -> None:
|
||||
for key, value in default_headers.items():
|
||||
# Ignore Transfer-Encoding if the Content-Length has been set explicitly.
|
||||
if key.lower() == "transfer-encoding" and "content-length" in self.headers:
|
||||
if key.lower() == "transfer-encoding" and "Content-Length" in self.headers:
|
||||
continue
|
||||
self.headers.setdefault(key, value)
|
||||
|
||||
auto_headers: typing.List[typing.Tuple[bytes, bytes]] = []
|
||||
|
||||
has_host = "host" in self.headers
|
||||
has_host = "Host" in self.headers
|
||||
has_content_length = (
|
||||
"content-length" in self.headers or "transfer-encoding" in self.headers
|
||||
"Content-Length" in self.headers or "Transfer-Encoding" in self.headers
|
||||
)
|
||||
|
||||
if not has_host and self.url.host:
|
||||
@ -810,9 +825,9 @@ class Request:
|
||||
host_header = self.url.host.encode("ascii")
|
||||
else:
|
||||
host_header = self.url.netloc.encode("ascii")
|
||||
auto_headers.append((b"host", host_header))
|
||||
auto_headers.append((b"Host", host_header))
|
||||
if not has_content_length and self.method in ("POST", "PUT", "PATCH"):
|
||||
auto_headers.append((b"content-length", b"0"))
|
||||
auto_headers.append((b"Content-Length", b"0"))
|
||||
|
||||
self.headers = Headers(auto_headers + self.headers.raw)
|
||||
|
||||
|
||||
@ -30,14 +30,19 @@ _HTML5_FORM_ENCODING_RE = re.compile(
|
||||
|
||||
|
||||
def normalize_header_key(
|
||||
value: typing.Union[str, bytes], encoding: str = None
|
||||
value: typing.Union[str, bytes],
|
||||
lower: bool,
|
||||
encoding: str = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Coerce str/bytes into a strictly byte-wise HTTP header key.
|
||||
"""
|
||||
if isinstance(value, bytes):
|
||||
return value.lower()
|
||||
return value.encode(encoding or "ascii").lower()
|
||||
bytes_value = value
|
||||
else:
|
||||
bytes_value = value.encode(encoding or "ascii")
|
||||
|
||||
return bytes_value.lower() if lower else bytes_value
|
||||
|
||||
|
||||
def normalize_header_value(
|
||||
|
||||
@ -271,3 +271,32 @@ def test_client_closed_state_using_with_block():
|
||||
assert client.is_closed
|
||||
with pytest.raises(RuntimeError):
|
||||
client.get("http://example.com")
|
||||
|
||||
|
||||
def echo_raw_headers(request: httpx.Request) -> httpx.Response:
|
||||
data = [
|
||||
(name.decode("ascii"), value.decode("ascii"))
|
||||
for name, value in request.headers.raw
|
||||
]
|
||||
return httpx.Response(200, json=data)
|
||||
|
||||
|
||||
def test_raw_client_header():
|
||||
"""
|
||||
Set a header in the Client.
|
||||
"""
|
||||
url = "http://example.org/echo_headers"
|
||||
headers = {"Example-Header": "example-value"}
|
||||
|
||||
client = httpx.Client(transport=MockTransport(echo_raw_headers), headers=headers)
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [
|
||||
["Host", "example.org"],
|
||||
["Accept", "*/*"],
|
||||
["Accept-Encoding", "gzip, deflate, br"],
|
||||
["Connection", "keep-alive"],
|
||||
["User-Agent", f"python-httpx/{httpx.__version__}"],
|
||||
["Example-Header", "example-value"],
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user