This commit is contained in:
Tom Christie 2019-04-26 17:00:47 +01:00
parent fab6fcd397
commit 8a29a0a1ad
16 changed files with 432 additions and 67 deletions

View File

@ -14,7 +14,7 @@ from .exceptions import (
)
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
from .models import URL, Origin, Request, Response
from .models import URL, Headers, Origin, Request, Response
from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
from .sync import SyncClient, SyncConnectionPool

View File

@ -10,7 +10,7 @@ class Adapter:
method: str,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
**options: typing.Any,
) -> Response:

View File

@ -15,4 +15,4 @@ class AuthAdapter(Adapter):
return await self.dispatch.send(request, **options)
async def close(self) -> None:
self.dispatch.close()
await self.dispatch.close()

View File

@ -40,7 +40,7 @@ class Client:
url: typing.Union[str, URL],
*,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
allow_redirects: bool = True,
ssl: typing.Optional[SSLConfig] = None,
@ -61,7 +61,7 @@ class Client:
self,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
@ -75,7 +75,7 @@ class Client:
url: typing.Union[str, URL],
*,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,

View File

@ -15,4 +15,4 @@ class CookieAdapter(Adapter):
return await self.dispatch.send(request, **options)
async def close(self) -> None:
self.dispatch.close()
await self.dispatch.close()

View File

@ -110,17 +110,17 @@ class MultiDecoder(Decoder):
SUPPORTED_DECODERS = {
b"identity": IdentityDecoder,
b"deflate": DeflateDecoder,
b"gzip": GZipDecoder,
b"br": BrotliDecoder,
"identity": IdentityDecoder,
"deflate": DeflateDecoder,
"gzip": GZipDecoder,
"br": BrotliDecoder,
}
if brotli is None:
SUPPORTED_DECODERS.pop(b"br") # pragma: nocover
SUPPORTED_DECODERS.pop("br") # pragma: nocover
ACCEPT_ENCODING = b", ".join(
[key for key in SUPPORTED_DECODERS.keys() if key != b"identity"]
ACCEPT_ENCODING = ", ".join(
[key for key in SUPPORTED_DECODERS.keys() if key != "identity"]
)

View File

@ -52,3 +52,8 @@ class ResponseClosed(Exception):
Attempted to read or stream response content, but the request has been
closed without loading the body.
"""
class InvalidURL(Exception):
"""
"""

View File

@ -51,7 +51,7 @@ class HTTP11Connection(Adapter):
#  Start sending the request.
method = request.method.encode()
target = request.url.full_path
headers = request.headers
headers = request.headers.raw
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event, timeout)

View File

@ -96,7 +96,7 @@ class HTTP2Connection(Adapter):
(b":authority", request.url.hostname.encode()),
(b":scheme", request.url.scheme.encode()),
(b":path", request.url.full_path.encode()),
] + request.headers
] + request.headers.raw
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)

View File

@ -39,6 +39,10 @@ class URL:
def query(self) -> str:
return self.components.query
@property
def fragment(self) -> str:
return self.components.fragment
@property
def hostname(self) -> str:
return self.components.hostname
@ -87,32 +91,140 @@ class Origin:
return hash((self.is_ssl, self.hostname, self.port))
class Headers(typing.MutableMapping[str, str]):
"""
A case-insensitive multidict.
"""
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
self._list = [(k.lower(), v) for k, v in headers]
@property
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return self._list
def keys(self) -> typing.List[str]: # type: ignore
return [key.decode("latin-1") for key, value in self._list]
def values(self) -> typing.List[str]: # type: ignore
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
return [
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in self._list
]
def get(self, key: str, default: typing.Any = None) -> typing.Any:
try:
return self[key]
except KeyError:
return default
def getlist(self, key: str) -> typing.List[str]:
get_header_key = key.lower().encode("latin-1")
return [
item_value.decode("latin-1")
for item_key, item_value in self._list
if item_key == get_header_key
]
def __getitem__(self, key: str) -> str:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return header_value.decode("latin-1")
raise KeyError(key)
def __setitem__(self, key: str, value: str) -> None:
"""
Set the header `key` to `value`, removing any duplicate entries.
Retains insertion order.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
found_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)
for idx in reversed(found_indexes[1:]):
del self._list[idx]
if found_indexes:
idx = found_indexes[0]
self._list[idx] = (set_key, set_value)
else:
self._list.append((set_key, set_value))
def __delitem__(self, key: str) -> None:
"""
Remove the header `key`.
"""
del_key = key.lower().encode("latin-1")
pop_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
for idx in reversed(pop_indexes):
del self._list[idx]
def __contains__(self, key: typing.Any) -> bool:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._list)
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, Headers):
return False
return sorted(self._list) == sorted(other._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
as_dict = dict(self.items())
if len(as_dict) == len(self):
return f"{class_name}({as_dict!r})"
return f"{class_name}(raw={self.raw!r})"
class Request:
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method.upper()
self.url = URL(url) if isinstance(url, str) else url
self.headers = list(headers)
if isinstance(body, bytes):
self.is_streaming = False
self.body = body
else:
self.is_streaming = True
self.body_aiter = body
self.headers = self._auto_headers() + self.headers
self.headers = self.build_headers(headers)
def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
def build_headers(
self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
) -> Headers:
has_host = False
has_content_length = False
has_accept_encoding = False
for header, value in self.headers:
for header, value in init_headers:
header = header.strip().lower()
if header == b"host":
has_host = True
@ -121,20 +233,20 @@ class Request:
elif header == b"accept-encoding":
has_accept_encoding = True
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
if not has_host:
headers.append((b"host", self.url.netloc.encode("ascii")))
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_content_length:
if self.is_streaming:
headers.append((b"transfer-encoding", b"chunked"))
auto_headers.append((b"transfer-encoding", b"chunked"))
elif self.body:
content_length = str(len(self.body)).encode()
headers.append((b"content-length", content_length))
auto_headers.append((b"content-length", content_length))
if not has_accept_encoding:
headers.append((b"accept-encoding", ACCEPT_ENCODING))
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
return headers
return Headers(auto_headers + init_headers)
async def stream(self) -> typing.AsyncIterator[bytes]:
if self.is_streaming:
@ -151,7 +263,7 @@ class Response:
*,
reason: typing.Optional[str] = None,
protocol: typing.Optional[str] = None,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
):
@ -164,18 +276,17 @@ class Response:
else:
self.reason = reason
self.protocol = protocol
self.headers = list(headers)
self.headers = Headers(headers)
self.on_close = on_close
self.is_closed = False
self.is_streamed = False
decoders = [] # type: typing.List[Decoder]
for header, value in self.headers:
if header.strip().lower() == b"content-encoding":
for part in value.split(b","):
part = part.strip().lower()
decoder_cls = SUPPORTED_DECODERS[part]
decoders.append(decoder_cls())
value = self.headers.get("content-encoding", "identity")
for part in value.split(","):
part = part.strip().lower()
decoder_cls = SUPPORTED_DECODERS[part]
decoders.append(decoder_cls())
if len(decoders) == 0:
self.decoder = IdentityDecoder() # type: Decoder

View File

@ -1,8 +1,11 @@
import typing
from urllib.parse import urljoin, urlparse
from .adapters import Adapter
from .exceptions import TooManyRedirects
from .models import Request, Response
from .models import URL, Request, Response
from .status_codes import codes
from .utils import requote_uri
class RedirectAdapter(Adapter):
@ -29,7 +32,55 @@ class RedirectAdapter(Adapter):
return response
async def close(self) -> None:
self.dispatch.close()
await self.dispatch.close()
def build_redirect_request(self, request: Request, response: Response) -> Request:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
raise NotImplementedError()
def redirect_method(self, request: Request, response: Response) -> str:
"""
When being redirected we may want to change the method of the request
based on certain specs or browser behavior.
"""
method = request.method
# https://tools.ietf.org/html/rfc7231#section-6.4.4
if response.status_code == codes["see_other"] and method != "HEAD":
method = "GET"
# Do what the browsers do, despite standards...
# First, turn 302s into GETs.
if response.status_code == codes["found"] and method != "HEAD":
method = "GET"
# Second, if a POST is responded to with a 301, turn it into a GET.
# This bizarre behaviour is explained in Issue 1704.
if response.status_code == codes["moved"] and method == "POST":
method = "GET"
return method
def redirect_url(self, request: Request, response: Response) -> URL:
location = response.headers["Location"]
# Handle redirection without scheme (see: RFC 1808 Section 4)
if location.startswith("//"):
location = f"{request.url.scheme}:{location}"
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
parsed = urlparse(location)
if parsed.fragment == "" and request.url.fragment:
parsed = parsed._replace(fragment=request.url.fragment)
url = parsed.geturl()
# Facilitate relative 'location' headers, as allowed by RFC 7231.
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
# Compliant with RFC3986, we percent encode the url.
if not parsed.netloc:
url = urljoin(str(request.url), requote_uri(url))
else:
url = requote_uri(url)
return URL(url)

123
httpcore/status_codes.py Normal file
View File

@ -0,0 +1,123 @@
"""
The ``codes`` object defines a mapping from common names for HTTP statuses
to their numerical codes, accessible either as attributes or as dictionary
items.
>>> requests.codes['temporary_redirect']
307
>>> requests.codes.teapot
418
Some codes have multiple names, and both upper- and lower-case versions of
the names are allowed. For example, ``codes.ok``, ``codes.OK``, and
``codes.okay`` all correspond to the HTTP status code 200.
"""
import typing
from .structures import LookupDict
_codes = {
# Informational.
100: ("continue",),
101: ("switching_protocols",),
102: ("processing",),
103: ("checkpoint",),
122: ("uri_too_long", "request_uri_too_long"),
200: ("ok", "okay", "all_ok", "all_okay", "all_good", "\\o/", ""),
201: ("created",),
202: ("accepted",),
203: ("non_authoritative_info", "non_authoritative_information"),
204: ("no_content",),
205: ("reset_content", "reset"),
206: ("partial_content", "partial"),
207: ("multi_status", "multiple_status", "multi_stati", "multiple_stati"),
208: ("already_reported",),
226: ("im_used",),
# Redirection.
300: ("multiple_choices",),
301: ("moved_permanently", "moved", "\\o-"),
302: ("found",),
303: ("see_other", "other"),
304: ("not_modified",),
305: ("use_proxy",),
306: ("switch_proxy",),
307: ("temporary_redirect", "temporary_moved", "temporary"),
308: (
"permanent_redirect",
"resume_incomplete",
"resume",
), # These 2 to be removed in 3.0
# Client Error.
400: ("bad_request", "bad"),
401: ("unauthorized",),
402: ("payment_required", "payment"),
403: ("forbidden",),
404: ("not_found", "-o-"),
405: ("method_not_allowed", "not_allowed"),
406: ("not_acceptable",),
407: ("proxy_authentication_required", "proxy_auth", "proxy_authentication"),
408: ("request_timeout", "timeout"),
409: ("conflict",),
410: ("gone",),
411: ("length_required",),
412: ("precondition_failed", "precondition"),
413: ("request_entity_too_large",),
414: ("request_uri_too_large",),
415: ("unsupported_media_type", "unsupported_media", "media_type"),
416: (
"requested_range_not_satisfiable",
"requested_range",
"range_not_satisfiable",
),
417: ("expectation_failed",),
418: ("im_a_teapot", "teapot", "i_am_a_teapot"),
421: ("misdirected_request",),
422: ("unprocessable_entity", "unprocessable"),
423: ("locked",),
424: ("failed_dependency", "dependency"),
425: ("unordered_collection", "unordered"),
426: ("upgrade_required", "upgrade"),
428: ("precondition_required", "precondition"),
429: ("too_many_requests", "too_many"),
431: ("header_fields_too_large", "fields_too_large"),
444: ("no_response", "none"),
449: ("retry_with", "retry"),
450: ("blocked_by_windows_parental_controls", "parental_controls"),
451: ("unavailable_for_legal_reasons", "legal_reasons"),
499: ("client_closed_request",),
# Server Error.
500: ("internal_server_error", "server_error", "/o\\", ""),
501: ("not_implemented",),
502: ("bad_gateway",),
503: ("service_unavailable", "unavailable"),
504: ("gateway_timeout",),
505: ("http_version_not_supported", "http_version"),
506: ("variant_also_negotiates",),
507: ("insufficient_storage",),
509: ("bandwidth_limit_exceeded", "bandwidth"),
510: ("not_extended",),
511: ("network_authentication_required", "network_auth", "network_authentication"),
} # type: typing.Dict[int, typing.Sequence[str]]
codes = LookupDict(name="status_codes")
def _init() -> None:
for code, titles in _codes.items():
for title in titles:
setattr(codes, title, code)
if not title.startswith(("\\", "/")):
setattr(codes, title.upper(), code)
def doc(code: int) -> str:
names = ", ".join("``%s``" % n for n in _codes[code])
return "* %d: %s" % (code, names)
global __doc__
__doc__ = (
__doc__ + "\n" + "\n".join(doc(code) for code in sorted(_codes))
if __doc__ is not None
else None
)
_init()

20
httpcore/structures.py Normal file
View File

@ -0,0 +1,20 @@
import typing
class LookupDict(dict):
"""Dictionary lookup object."""
def __init__(self, name: str = None) -> None:
self.name = name
super(LookupDict, self).__init__()
def __repr__(self) -> str:
return "<lookup '%s'>" % (self.name)
def __getitem__(self, key: typing.Any) -> typing.Any:
# We allow fall-through here, so values default to None
return self.__dict__.get(key, None)
def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
return self.__dict__.get(key, default)

View File

@ -5,7 +5,7 @@ from types import TracebackType
from .adapters import Adapter
from .config import SSLConfig, TimeoutConfig
from .connection_pool import ConnectionPool
from .models import URL, Response
from .models import URL, Headers, Response
class SyncResponse:
@ -22,7 +22,7 @@ class SyncResponse:
return self._response.reason
@property
def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
def headers(self) -> Headers:
return self._response.headers
@property
@ -54,7 +54,7 @@ class SyncClient:
method: str,
url: typing.Union[str, URL],
*,
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
**options: typing.Any
) -> SyncResponse:

52
httpcore/utils.py Normal file
View File

@ -0,0 +1,52 @@
from urllib.parse import quote
from .exceptions import InvalidURL
# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
)
def unquote_unreserved(uri: str) -> str:
"""Un-escape any percent-escape sequences in a URI that are unreserved
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
:rtype: str
"""
parts = uri.split("%")
for i in range(1, len(parts)):
h = parts[i][0:2]
if len(h) == 2 and h.isalnum():
try:
c = chr(int(h, 16))
except ValueError:
raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
if c in UNRESERVED_SET:
parts[i] = c + parts[i][2:]
else:
parts[i] = "%" + parts[i]
else:
parts[i] = "%" + parts[i]
return "".join(parts)
def requote_uri(uri: str) -> str:
"""
Re-quote the given URI.
This function passes the given URI through an unquote/quote cycle to
ensure that it is fully and consistently quoted.
"""
safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
safe_without_percent = "!#$&'()*+,/:;=?@[]~"
try:
# Unquote only the unreserved characters
# Then quote only illegal characters (do not quote reserved,
# unreserved, or '%')
return quote(unquote_unreserved(uri), safe=safe_with_percent)
except InvalidURL:
# We couldn't unquote the given URI, so let's try quoting it, but
# there may be unquoted '%'s in the URI. We need to make sure they're
# properly quoted so they do not cause issues elsewhere.
return quote(uri, safe=safe_without_percent)

View File

@ -5,19 +5,20 @@ import httpcore
def test_host_header():
request = httpcore.Request("GET", "http://example.org")
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
]
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
)
def test_content_length_header():
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
assert request.headers == [
(b"host", b"example.org"),
(b"content-length", b"8"),
(b"accept-encoding", b"deflate, gzip, br"),
]
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
(b"content-length", b"8"),
(b"accept-encoding", b"deflate, gzip, br"),
]
)
def test_transfer_encoding_header():
@ -27,31 +28,31 @@ def test_transfer_encoding_header():
body = streaming_body(b"test 123")
request = httpcore.Request("POST", "http://example.org", body=body)
assert request.headers == [
(b"host", b"example.org"),
(b"transfer-encoding", b"chunked"),
(b"accept-encoding", b"deflate, gzip, br"),
]
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
(b"transfer-encoding", b"chunked"),
(b"accept-encoding", b"deflate, gzip, br"),
]
)
def test_override_host_header():
headers = [(b"host", b"1.2.3.4:80")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
assert request.headers == [
(b"accept-encoding", b"deflate, gzip, br"),
(b"host", b"1.2.3.4:80"),
]
assert request.headers == httpcore.Headers(
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
)
def test_override_accept_encoding_header():
headers = [(b"accept-encoding", b"identity")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"identity"),
]
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"identity")]
)
def test_override_content_length_header():
@ -62,11 +63,13 @@ def test_override_content_length_header():
headers = [(b"content-length", b"8")]
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
assert request.headers == [
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
(b"content-length", b"8"),
]
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
(b"accept-encoding", b"deflate, gzip, br"),
(b"content-length", b"8"),
]
)
def test_url():