Rollin'
This commit is contained in:
parent
fab6fcd397
commit
8a29a0a1ad
@ -14,7 +14,7 @@ from .exceptions import (
|
||||
)
|
||||
from .http2 import HTTP2Connection
|
||||
from .http11 import HTTP11Connection
|
||||
from .models import URL, Origin, Request, Response
|
||||
from .models import URL, Headers, Origin, Request, Response
|
||||
from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
|
||||
from .sync import SyncClient, SyncConnectionPool
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ class Adapter:
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
**options: typing.Any,
|
||||
) -> Response:
|
||||
|
||||
@ -15,4 +15,4 @@ class AuthAdapter(Adapter):
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.dispatch.close()
|
||||
await self.dispatch.close()
|
||||
|
||||
@ -40,7 +40,7 @@ class Client:
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
stream: bool = False,
|
||||
allow_redirects: bool = True,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
@ -61,7 +61,7 @@ class Client:
|
||||
self,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
stream: bool = False,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
@ -75,7 +75,7 @@ class Client:
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
stream: bool = False,
|
||||
ssl: typing.Optional[SSLConfig] = None,
|
||||
timeout: typing.Optional[TimeoutConfig] = None,
|
||||
|
||||
@ -15,4 +15,4 @@ class CookieAdapter(Adapter):
|
||||
return await self.dispatch.send(request, **options)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.dispatch.close()
|
||||
await self.dispatch.close()
|
||||
|
||||
@ -110,17 +110,17 @@ class MultiDecoder(Decoder):
|
||||
|
||||
|
||||
SUPPORTED_DECODERS = {
|
||||
b"identity": IdentityDecoder,
|
||||
b"deflate": DeflateDecoder,
|
||||
b"gzip": GZipDecoder,
|
||||
b"br": BrotliDecoder,
|
||||
"identity": IdentityDecoder,
|
||||
"deflate": DeflateDecoder,
|
||||
"gzip": GZipDecoder,
|
||||
"br": BrotliDecoder,
|
||||
}
|
||||
|
||||
|
||||
if brotli is None:
|
||||
SUPPORTED_DECODERS.pop(b"br") # pragma: nocover
|
||||
SUPPORTED_DECODERS.pop("br") # pragma: nocover
|
||||
|
||||
|
||||
ACCEPT_ENCODING = b", ".join(
|
||||
[key for key in SUPPORTED_DECODERS.keys() if key != b"identity"]
|
||||
ACCEPT_ENCODING = ", ".join(
|
||||
[key for key in SUPPORTED_DECODERS.keys() if key != "identity"]
|
||||
)
|
||||
|
||||
@ -52,3 +52,8 @@ class ResponseClosed(Exception):
|
||||
Attempted to read or stream response content, but the request has been
|
||||
closed without loading the body.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidURL(Exception):
|
||||
"""
|
||||
"""
|
||||
|
||||
@ -51,7 +51,7 @@ class HTTP11Connection(Adapter):
|
||||
# Start sending the request.
|
||||
method = request.method.encode()
|
||||
target = request.url.full_path
|
||||
headers = request.headers
|
||||
headers = request.headers.raw
|
||||
event = h11.Request(method=method, target=target, headers=headers)
|
||||
await self._send_event(event, timeout)
|
||||
|
||||
|
||||
@ -96,7 +96,7 @@ class HTTP2Connection(Adapter):
|
||||
(b":authority", request.url.hostname.encode()),
|
||||
(b":scheme", request.url.scheme.encode()),
|
||||
(b":path", request.url.full_path.encode()),
|
||||
] + request.headers
|
||||
] + request.headers.raw
|
||||
self.h2_state.send_headers(stream_id, headers)
|
||||
data_to_send = self.h2_state.data_to_send()
|
||||
await self.writer.write(data_to_send, timeout)
|
||||
|
||||
@ -39,6 +39,10 @@ class URL:
|
||||
def query(self) -> str:
|
||||
return self.components.query
|
||||
|
||||
@property
|
||||
def fragment(self) -> str:
|
||||
return self.components.fragment
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return self.components.hostname
|
||||
@ -87,32 +91,140 @@ class Origin:
|
||||
return hash((self.is_ssl, self.hostname, self.port))
|
||||
|
||||
|
||||
class Headers(typing.MutableMapping[str, str]):
|
||||
"""
|
||||
A case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
|
||||
self._list = [(k.lower(), v) for k, v in headers]
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
return self._list
|
||||
|
||||
def keys(self) -> typing.List[str]: # type: ignore
|
||||
return [key.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def values(self) -> typing.List[str]: # type: ignore
|
||||
return [value.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
|
||||
return [
|
||||
(key.decode("latin-1"), value.decode("latin-1"))
|
||||
for key, value in self._list
|
||||
]
|
||||
|
||||
def get(self, key: str, default: typing.Any = None) -> typing.Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def getlist(self, key: str) -> typing.List[str]:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
return [
|
||||
item_value.decode("latin-1")
|
||||
for item_key, item_value in self._list
|
||||
if item_key == get_header_key
|
||||
]
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return header_value.decode("latin-1")
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set the header `key` to `value`, removing any duplicate entries.
|
||||
Retains insertion order.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
found_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
found_indexes.append(idx)
|
||||
|
||||
for idx in reversed(found_indexes[1:]):
|
||||
del self._list[idx]
|
||||
|
||||
if found_indexes:
|
||||
idx = found_indexes[0]
|
||||
self._list[idx] = (set_key, set_value)
|
||||
else:
|
||||
self._list.append((set_key, set_value))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""
|
||||
Remove the header `key`.
|
||||
"""
|
||||
del_key = key.lower().encode("latin-1")
|
||||
|
||||
pop_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
pop_indexes.append(idx)
|
||||
|
||||
for idx in reversed(pop_indexes):
|
||||
del self._list[idx]
|
||||
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
as_dict = dict(self.items())
|
||||
if len(as_dict) == len(self):
|
||||
return f"{class_name}({as_dict!r})"
|
||||
return f"{class_name}(raw={self.raw!r})"
|
||||
|
||||
|
||||
class Request:
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
):
|
||||
self.method = method.upper()
|
||||
self.url = URL(url) if isinstance(url, str) else url
|
||||
self.headers = list(headers)
|
||||
if isinstance(body, bytes):
|
||||
self.is_streaming = False
|
||||
self.body = body
|
||||
else:
|
||||
self.is_streaming = True
|
||||
self.body_aiter = body
|
||||
self.headers = self._auto_headers() + self.headers
|
||||
self.headers = self.build_headers(headers)
|
||||
|
||||
def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
def build_headers(
|
||||
self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
|
||||
) -> Headers:
|
||||
has_host = False
|
||||
has_content_length = False
|
||||
has_accept_encoding = False
|
||||
|
||||
for header, value in self.headers:
|
||||
for header, value in init_headers:
|
||||
header = header.strip().lower()
|
||||
if header == b"host":
|
||||
has_host = True
|
||||
@ -121,20 +233,20 @@ class Request:
|
||||
elif header == b"accept-encoding":
|
||||
has_accept_encoding = True
|
||||
|
||||
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
|
||||
if not has_host:
|
||||
headers.append((b"host", self.url.netloc.encode("ascii")))
|
||||
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
|
||||
if not has_content_length:
|
||||
if self.is_streaming:
|
||||
headers.append((b"transfer-encoding", b"chunked"))
|
||||
auto_headers.append((b"transfer-encoding", b"chunked"))
|
||||
elif self.body:
|
||||
content_length = str(len(self.body)).encode()
|
||||
headers.append((b"content-length", content_length))
|
||||
auto_headers.append((b"content-length", content_length))
|
||||
if not has_accept_encoding:
|
||||
headers.append((b"accept-encoding", ACCEPT_ENCODING))
|
||||
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
|
||||
|
||||
return headers
|
||||
return Headers(auto_headers + init_headers)
|
||||
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
if self.is_streaming:
|
||||
@ -151,7 +263,7 @@ class Response:
|
||||
*,
|
||||
reason: typing.Optional[str] = None,
|
||||
protocol: typing.Optional[str] = None,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
on_close: typing.Callable = None,
|
||||
):
|
||||
@ -164,18 +276,17 @@ class Response:
|
||||
else:
|
||||
self.reason = reason
|
||||
self.protocol = protocol
|
||||
self.headers = list(headers)
|
||||
self.headers = Headers(headers)
|
||||
self.on_close = on_close
|
||||
self.is_closed = False
|
||||
self.is_streamed = False
|
||||
|
||||
decoders = [] # type: typing.List[Decoder]
|
||||
for header, value in self.headers:
|
||||
if header.strip().lower() == b"content-encoding":
|
||||
for part in value.split(b","):
|
||||
part = part.strip().lower()
|
||||
decoder_cls = SUPPORTED_DECODERS[part]
|
||||
decoders.append(decoder_cls())
|
||||
value = self.headers.get("content-encoding", "identity")
|
||||
for part in value.split(","):
|
||||
part = part.strip().lower()
|
||||
decoder_cls = SUPPORTED_DECODERS[part]
|
||||
decoders.append(decoder_cls())
|
||||
|
||||
if len(decoders) == 0:
|
||||
self.decoder = IdentityDecoder() # type: Decoder
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import typing
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from .adapters import Adapter
|
||||
from .exceptions import TooManyRedirects
|
||||
from .models import Request, Response
|
||||
from .models import URL, Request, Response
|
||||
from .status_codes import codes
|
||||
from .utils import requote_uri
|
||||
|
||||
|
||||
class RedirectAdapter(Adapter):
|
||||
@ -29,7 +32,55 @@ class RedirectAdapter(Adapter):
|
||||
return response
|
||||
|
||||
async def close(self) -> None:
|
||||
self.dispatch.close()
|
||||
await self.dispatch.close()
|
||||
|
||||
def build_redirect_request(self, request: Request, response: Response) -> Request:
|
||||
method = self.redirect_method(request, response)
|
||||
url = self.redirect_url(request, response)
|
||||
raise NotImplementedError()
|
||||
|
||||
def redirect_method(self, request: Request, response: Response) -> str:
|
||||
"""
|
||||
When being redirected we may want to change the method of the request
|
||||
based on certain specs or browser behavior.
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.4.4
|
||||
if response.status_code == codes["see_other"] and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# Do what the browsers do, despite standards...
|
||||
# First, turn 302s into GETs.
|
||||
if response.status_code == codes["found"] and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# Second, if a POST is responded to with a 301, turn it into a GET.
|
||||
# This bizarre behaviour is explained in Issue 1704.
|
||||
if response.status_code == codes["moved"] and method == "POST":
|
||||
method = "GET"
|
||||
|
||||
return method
|
||||
|
||||
def redirect_url(self, request: Request, response: Response) -> URL:
|
||||
location = response.headers["Location"]
|
||||
|
||||
# Handle redirection without scheme (see: RFC 1808 Section 4)
|
||||
if location.startswith("//"):
|
||||
location = f"{request.url.scheme}:{location}"
|
||||
|
||||
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
parsed = urlparse(location)
|
||||
if parsed.fragment == "" and request.url.fragment:
|
||||
parsed = parsed._replace(fragment=request.url.fragment)
|
||||
url = parsed.geturl()
|
||||
|
||||
# Facilitate relative 'location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
# Compliant with RFC3986, we percent encode the url.
|
||||
if not parsed.netloc:
|
||||
url = urljoin(str(request.url), requote_uri(url))
|
||||
else:
|
||||
url = requote_uri(url)
|
||||
|
||||
return URL(url)
|
||||
|
||||
123
httpcore/status_codes.py
Normal file
123
httpcore/status_codes.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
The ``codes`` object defines a mapping from common names for HTTP statuses
|
||||
to their numerical codes, accessible either as attributes or as dictionary
|
||||
items.
|
||||
>>> requests.codes['temporary_redirect']
|
||||
307
|
||||
>>> requests.codes.teapot
|
||||
418
|
||||
Some codes have multiple names, and both upper- and lower-case versions of
|
||||
the names are allowed. For example, ``codes.ok``, ``codes.OK``, and
|
||||
``codes.okay`` all correspond to the HTTP status code 200.
|
||||
"""
|
||||
|
||||
import typing
|
||||
|
||||
from .structures import LookupDict
|
||||
|
||||
_codes = {
|
||||
# Informational.
|
||||
100: ("continue",),
|
||||
101: ("switching_protocols",),
|
||||
102: ("processing",),
|
||||
103: ("checkpoint",),
|
||||
122: ("uri_too_long", "request_uri_too_long"),
|
||||
200: ("ok", "okay", "all_ok", "all_okay", "all_good", "\\o/", "✓"),
|
||||
201: ("created",),
|
||||
202: ("accepted",),
|
||||
203: ("non_authoritative_info", "non_authoritative_information"),
|
||||
204: ("no_content",),
|
||||
205: ("reset_content", "reset"),
|
||||
206: ("partial_content", "partial"),
|
||||
207: ("multi_status", "multiple_status", "multi_stati", "multiple_stati"),
|
||||
208: ("already_reported",),
|
||||
226: ("im_used",),
|
||||
# Redirection.
|
||||
300: ("multiple_choices",),
|
||||
301: ("moved_permanently", "moved", "\\o-"),
|
||||
302: ("found",),
|
||||
303: ("see_other", "other"),
|
||||
304: ("not_modified",),
|
||||
305: ("use_proxy",),
|
||||
306: ("switch_proxy",),
|
||||
307: ("temporary_redirect", "temporary_moved", "temporary"),
|
||||
308: (
|
||||
"permanent_redirect",
|
||||
"resume_incomplete",
|
||||
"resume",
|
||||
), # These 2 to be removed in 3.0
|
||||
# Client Error.
|
||||
400: ("bad_request", "bad"),
|
||||
401: ("unauthorized",),
|
||||
402: ("payment_required", "payment"),
|
||||
403: ("forbidden",),
|
||||
404: ("not_found", "-o-"),
|
||||
405: ("method_not_allowed", "not_allowed"),
|
||||
406: ("not_acceptable",),
|
||||
407: ("proxy_authentication_required", "proxy_auth", "proxy_authentication"),
|
||||
408: ("request_timeout", "timeout"),
|
||||
409: ("conflict",),
|
||||
410: ("gone",),
|
||||
411: ("length_required",),
|
||||
412: ("precondition_failed", "precondition"),
|
||||
413: ("request_entity_too_large",),
|
||||
414: ("request_uri_too_large",),
|
||||
415: ("unsupported_media_type", "unsupported_media", "media_type"),
|
||||
416: (
|
||||
"requested_range_not_satisfiable",
|
||||
"requested_range",
|
||||
"range_not_satisfiable",
|
||||
),
|
||||
417: ("expectation_failed",),
|
||||
418: ("im_a_teapot", "teapot", "i_am_a_teapot"),
|
||||
421: ("misdirected_request",),
|
||||
422: ("unprocessable_entity", "unprocessable"),
|
||||
423: ("locked",),
|
||||
424: ("failed_dependency", "dependency"),
|
||||
425: ("unordered_collection", "unordered"),
|
||||
426: ("upgrade_required", "upgrade"),
|
||||
428: ("precondition_required", "precondition"),
|
||||
429: ("too_many_requests", "too_many"),
|
||||
431: ("header_fields_too_large", "fields_too_large"),
|
||||
444: ("no_response", "none"),
|
||||
449: ("retry_with", "retry"),
|
||||
450: ("blocked_by_windows_parental_controls", "parental_controls"),
|
||||
451: ("unavailable_for_legal_reasons", "legal_reasons"),
|
||||
499: ("client_closed_request",),
|
||||
# Server Error.
|
||||
500: ("internal_server_error", "server_error", "/o\\", "✗"),
|
||||
501: ("not_implemented",),
|
||||
502: ("bad_gateway",),
|
||||
503: ("service_unavailable", "unavailable"),
|
||||
504: ("gateway_timeout",),
|
||||
505: ("http_version_not_supported", "http_version"),
|
||||
506: ("variant_also_negotiates",),
|
||||
507: ("insufficient_storage",),
|
||||
509: ("bandwidth_limit_exceeded", "bandwidth"),
|
||||
510: ("not_extended",),
|
||||
511: ("network_authentication_required", "network_auth", "network_authentication"),
|
||||
} # type: typing.Dict[int, typing.Sequence[str]]
|
||||
|
||||
codes = LookupDict(name="status_codes")
|
||||
|
||||
|
||||
def _init() -> None:
|
||||
for code, titles in _codes.items():
|
||||
for title in titles:
|
||||
setattr(codes, title, code)
|
||||
if not title.startswith(("\\", "/")):
|
||||
setattr(codes, title.upper(), code)
|
||||
|
||||
def doc(code: int) -> str:
|
||||
names = ", ".join("``%s``" % n for n in _codes[code])
|
||||
return "* %d: %s" % (code, names)
|
||||
|
||||
global __doc__
|
||||
__doc__ = (
|
||||
__doc__ + "\n" + "\n".join(doc(code) for code in sorted(_codes))
|
||||
if __doc__ is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
_init()
|
||||
20
httpcore/structures.py
Normal file
20
httpcore/structures.py
Normal file
@ -0,0 +1,20 @@
|
||||
import typing
|
||||
|
||||
|
||||
class LookupDict(dict):
|
||||
"""Dictionary lookup object."""
|
||||
|
||||
def __init__(self, name: str = None) -> None:
|
||||
self.name = name
|
||||
super(LookupDict, self).__init__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<lookup '%s'>" % (self.name)
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> typing.Any:
|
||||
# We allow fall-through here, so values default to None
|
||||
|
||||
return self.__dict__.get(key, None)
|
||||
|
||||
def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
return self.__dict__.get(key, default)
|
||||
@ -5,7 +5,7 @@ from types import TracebackType
|
||||
from .adapters import Adapter
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .connection_pool import ConnectionPool
|
||||
from .models import URL, Response
|
||||
from .models import URL, Headers, Response
|
||||
|
||||
|
||||
class SyncResponse:
|
||||
@ -22,7 +22,7 @@ class SyncResponse:
|
||||
return self._response.reason
|
||||
|
||||
@property
|
||||
def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
def headers(self) -> Headers:
|
||||
return self._response.headers
|
||||
|
||||
@property
|
||||
@ -54,7 +54,7 @@ class SyncClient:
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
**options: typing.Any
|
||||
) -> SyncResponse:
|
||||
|
||||
52
httpcore/utils.py
Normal file
52
httpcore/utils.py
Normal file
@ -0,0 +1,52 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
from .exceptions import InvalidURL
|
||||
|
||||
# The unreserved URI characters (RFC 3986)
|
||||
UNRESERVED_SET = frozenset(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
|
||||
)
|
||||
|
||||
|
||||
def unquote_unreserved(uri: str) -> str:
|
||||
"""Un-escape any percent-escape sequences in a URI that are unreserved
|
||||
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
|
||||
:rtype: str
|
||||
"""
|
||||
parts = uri.split("%")
|
||||
for i in range(1, len(parts)):
|
||||
h = parts[i][0:2]
|
||||
if len(h) == 2 and h.isalnum():
|
||||
try:
|
||||
c = chr(int(h, 16))
|
||||
except ValueError:
|
||||
raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
|
||||
|
||||
if c in UNRESERVED_SET:
|
||||
parts[i] = c + parts[i][2:]
|
||||
else:
|
||||
parts[i] = "%" + parts[i]
|
||||
else:
|
||||
parts[i] = "%" + parts[i]
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def requote_uri(uri: str) -> str:
|
||||
"""
|
||||
Re-quote the given URI.
|
||||
|
||||
This function passes the given URI through an unquote/quote cycle to
|
||||
ensure that it is fully and consistently quoted.
|
||||
"""
|
||||
safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
|
||||
safe_without_percent = "!#$&'()*+,/:;=?@[]~"
|
||||
try:
|
||||
# Unquote only the unreserved characters
|
||||
# Then quote only illegal characters (do not quote reserved,
|
||||
# unreserved, or '%')
|
||||
return quote(unquote_unreserved(uri), safe=safe_with_percent)
|
||||
except InvalidURL:
|
||||
# We couldn't unquote the given URI, so let's try quoting it, but
|
||||
# there may be unquoted '%'s in the URI. We need to make sure they're
|
||||
# properly quoted so they do not cause issues elsewhere.
|
||||
return quote(uri, safe=safe_without_percent)
|
||||
@ -5,19 +5,20 @@ import httpcore
|
||||
|
||||
def test_host_header():
|
||||
request = httpcore.Request("GET", "http://example.org")
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
|
||||
)
|
||||
|
||||
|
||||
def test_content_length_header():
|
||||
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"content-length", b"8"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
(b"content-length", b"8"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_transfer_encoding_header():
|
||||
@ -27,31 +28,31 @@ def test_transfer_encoding_header():
|
||||
body = streaming_body(b"test 123")
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", body=body)
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"transfer-encoding", b"chunked"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
(b"transfer-encoding", b"chunked"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_override_host_header():
|
||||
headers = [(b"host", b"1.2.3.4:80")]
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
assert request.headers == [
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"host", b"1.2.3.4:80"),
|
||||
]
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
|
||||
)
|
||||
|
||||
|
||||
def test_override_accept_encoding_header():
|
||||
headers = [(b"accept-encoding", b"identity")]
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"identity"),
|
||||
]
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"host", b"example.org"), (b"accept-encoding", b"identity")]
|
||||
)
|
||||
|
||||
|
||||
def test_override_content_length_header():
|
||||
@ -62,11 +63,13 @@ def test_override_content_length_header():
|
||||
headers = [(b"content-length", b"8")]
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
|
||||
assert request.headers == [
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"content-length", b"8"),
|
||||
]
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
(b"accept-encoding", b"deflate, gzip, br"),
|
||||
(b"content-length", b"8"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_url():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user