Use rfc3986 for URL handling
This commit is contained in:
parent
fd9823e59c
commit
302fe93df9
@ -8,6 +8,7 @@ from .dispatch.http11 import HTTP11Connection
|
||||
from .exceptions import (
|
||||
ConnectTimeout,
|
||||
DecodingError,
|
||||
InvalidURL,
|
||||
PoolTimeout,
|
||||
ProtocolError,
|
||||
ReadTimeout,
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
import typing
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from ..config import DEFAULT_MAX_REDIRECTS
|
||||
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from ..interfaces import Adapter
|
||||
from ..models import URL, Headers, Request, Response
|
||||
from ..status_codes import codes
|
||||
from ..utils import requote_uri
|
||||
|
||||
|
||||
class RedirectAdapter(Adapter):
|
||||
@ -98,25 +96,19 @@ class RedirectAdapter(Adapter):
|
||||
"""
|
||||
location = response.headers["Location"]
|
||||
|
||||
# Handle redirection without scheme (see: RFC 1808 Section 4)
|
||||
if location.startswith("//"):
|
||||
location = f"{request.url.scheme}:{location}"
|
||||
|
||||
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
parsed = urlparse(location)
|
||||
if parsed.fragment == "" and request.url.fragment:
|
||||
parsed = parsed._replace(fragment=request.url.fragment)
|
||||
url = parsed.geturl()
|
||||
url = URL(location, allow_relative=True)
|
||||
|
||||
# Facilitate relative 'location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
# Compliant with RFC3986, we percent encode the url.
|
||||
if not parsed.netloc:
|
||||
url = urljoin(str(request.url), requote_uri(url))
|
||||
else:
|
||||
url = requote_uri(url)
|
||||
if not url.is_absolute:
|
||||
url = url.resolve_with(request.url.copy_with(fragment=None))
|
||||
|
||||
return URL(url)
|
||||
# Attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
if request.url.fragment and not url.fragment:
|
||||
url = url.copy_with(fragment=request.url.fragment)
|
||||
|
||||
return url
|
||||
|
||||
def redirect_headers(self, request: Request, url: URL) -> Headers:
|
||||
"""
|
||||
|
||||
@ -57,7 +57,7 @@ class HTTPConnection(Adapter):
|
||||
assert isinstance(ssl, SSLConfig)
|
||||
assert isinstance(timeout, TimeoutConfig)
|
||||
|
||||
hostname = self.origin.hostname
|
||||
host = self.origin.host
|
||||
port = self.origin.port
|
||||
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
|
||||
|
||||
@ -66,7 +66,7 @@ class HTTPConnection(Adapter):
|
||||
else:
|
||||
on_release = functools.partial(self.release_func, self)
|
||||
|
||||
reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
|
||||
reader, writer, protocol = await connect(host, port, ssl_context, timeout)
|
||||
if protocol == Protocol.HTTP_2:
|
||||
self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
|
||||
else:
|
||||
|
||||
@ -99,7 +99,7 @@ class HTTP2Connection(Adapter):
|
||||
stream_id = self.h2_state.get_next_available_stream_id()
|
||||
headers = [
|
||||
(b":method", request.method.encode()),
|
||||
(b":authority", request.url.hostname.encode()),
|
||||
(b":authority", request.url.host.encode()),
|
||||
(b":scheme", request.url.scheme.encode()),
|
||||
(b":path", request.url.full_path.encode()),
|
||||
] + request.headers.raw
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import cgi
|
||||
import typing
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import chardet
|
||||
import rfc3986
|
||||
|
||||
from .config import SSLConfig, TimeoutConfig
|
||||
from .decoders import (
|
||||
@ -12,7 +12,7 @@ from .decoders import (
|
||||
IdentityDecoder,
|
||||
MultiDecoder,
|
||||
)
|
||||
from .exceptions import ResponseClosed, ResponseNotRead, StreamConsumed
|
||||
from .exceptions import InvalidURL, ResponseClosed, ResponseNotRead, StreamConsumed
|
||||
from .status_codes import codes
|
||||
from .utils import (
|
||||
get_reason_phrase,
|
||||
@ -33,42 +33,45 @@ ByteOrByteStream = typing.Union[bytes, typing.AsyncIterator[bytes]]
|
||||
|
||||
|
||||
class URL:
|
||||
def __init__(self, url: URLTypes) -> None:
|
||||
def __init__(self, url: URLTypes, allow_relative: bool = False) -> None:
|
||||
if isinstance(url, str):
|
||||
self.components = urlsplit(url)
|
||||
self.components = rfc3986.api.uri_reference(url).normalize()
|
||||
elif isinstance(url, rfc3986.uri.URIReference):
|
||||
self.components = url
|
||||
else:
|
||||
self.components = url.components
|
||||
|
||||
if not self.components.scheme:
|
||||
raise ValueError("No scheme included in URL.")
|
||||
if self.components.scheme not in ("http", "https"):
|
||||
raise ValueError('URL scheme must be "http" or "https".')
|
||||
if not self.components.hostname:
|
||||
raise ValueError("No hostname included in URL.")
|
||||
if not allow_relative:
|
||||
if not self.scheme:
|
||||
raise InvalidURL("No scheme included in URL.")
|
||||
if self.scheme not in ("http", "https"):
|
||||
raise InvalidURL('URL scheme must be "http" or "https".')
|
||||
if not self.host:
|
||||
raise InvalidURL("No hostname included in URL.")
|
||||
|
||||
@property
|
||||
def scheme(self) -> str:
|
||||
return self.components.scheme
|
||||
return self.components.scheme or ""
|
||||
|
||||
@property
|
||||
def netloc(self) -> str:
|
||||
return self.components.netloc
|
||||
def authority(self) -> str:
|
||||
return self.components.authority or ""
|
||||
|
||||
@property
|
||||
def path(self) -> str:
|
||||
return self.components.path
|
||||
return self.components.path or "/"
|
||||
|
||||
@property
|
||||
def query(self) -> str:
|
||||
return self.components.query
|
||||
return self.components.query or ""
|
||||
|
||||
@property
|
||||
def fragment(self) -> str:
|
||||
return self.components.fragment
|
||||
return self.components.fragment or ""
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return self.components.hostname
|
||||
def host(self) -> str:
|
||||
return self.components.host or ""
|
||||
|
||||
@property
|
||||
def port(self) -> int:
|
||||
@ -89,10 +92,22 @@ class URL:
|
||||
def is_secure(self) -> bool:
|
||||
return self.components.scheme == "https"
|
||||
|
||||
@property
|
||||
def is_absolute(self) -> bool:
|
||||
return self.components.is_absolute()
|
||||
|
||||
@property
|
||||
def origin(self) -> "Origin":
|
||||
return Origin(self)
|
||||
|
||||
def copy_with(self, **kwargs: typing.Any) -> "URL":
|
||||
return URL(self.components.copy_with(**kwargs))
|
||||
|
||||
def resolve_with(self, base_url: URLTypes) -> "URL":
|
||||
if isinstance(base_url, URL):
|
||||
base_url = base_url.components
|
||||
return URL(self.components.resolve_with(base_url))
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
@ -100,7 +115,7 @@ class URL:
|
||||
return isinstance(other, URL) and str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.components.geturl()
|
||||
return self.components.unsplit()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
@ -109,23 +124,23 @@ class URL:
|
||||
|
||||
|
||||
class Origin:
|
||||
def __init__(self, url: typing.Union[str, URL]) -> None:
|
||||
if isinstance(url, str):
|
||||
def __init__(self, url: URLTypes) -> None:
|
||||
if not isinstance(url, URL):
|
||||
url = URL(url)
|
||||
self.is_ssl = url.scheme == "https"
|
||||
self.hostname = url.hostname.lower()
|
||||
self.host = url.host
|
||||
self.port = url.port
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self.is_ssl == other.is_ssl
|
||||
and self.hostname == other.hostname
|
||||
and self.host == other.host
|
||||
and self.port == other.port
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.is_ssl, self.hostname, self.port))
|
||||
return hash((self.is_ssl, self.host, self.port))
|
||||
|
||||
|
||||
class Headers(typing.MutableMapping[str, str]):
|
||||
@ -365,8 +380,8 @@ class Request:
|
||||
)
|
||||
has_accept_encoding = "accept-encoding" in self.headers
|
||||
|
||||
if not has_host:
|
||||
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
|
||||
if not has_host and self.url.authority:
|
||||
auto_headers.append((b"host", self.url.authority.encode("ascii")))
|
||||
if not has_content_length:
|
||||
if self.is_streaming:
|
||||
auto_headers.append((b"transfer-encoding", b"chunked"))
|
||||
|
||||
@ -1,58 +1,6 @@
|
||||
import codecs
|
||||
import http
|
||||
import typing
|
||||
from urllib.parse import quote
|
||||
|
||||
from .exceptions import InvalidURL
|
||||
|
||||
# The unreserved URI characters (RFC 3986)
|
||||
UNRESERVED_SET = frozenset(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
|
||||
)
|
||||
|
||||
|
||||
def unquote_unreserved(uri: str) -> str:
|
||||
"""
|
||||
Un-escape any percent-escape sequences in a URI that are unreserved
|
||||
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
|
||||
"""
|
||||
parts = uri.split("%")
|
||||
for i in range(1, len(parts)):
|
||||
h = parts[i][0:2]
|
||||
if len(h) == 2 and h.isalnum():
|
||||
try:
|
||||
c = chr(int(h, 16))
|
||||
except ValueError:
|
||||
raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
|
||||
|
||||
if c in UNRESERVED_SET:
|
||||
parts[i] = c + parts[i][2:]
|
||||
else:
|
||||
parts[i] = "%" + parts[i]
|
||||
else:
|
||||
parts[i] = "%" + parts[i]
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def requote_uri(uri: str) -> str:
|
||||
"""
|
||||
Re-quote the given URI.
|
||||
|
||||
This function passes the given URI through an unquote/quote cycle to
|
||||
ensure that it is fully and consistently quoted.
|
||||
"""
|
||||
safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
|
||||
safe_without_percent = "!#$&'()*+,/:;=?@[]~"
|
||||
try:
|
||||
# Unquote only the unreserved characters
|
||||
# Then quote only illegal characters (do not quote reserved,
|
||||
# unreserved, or '%')
|
||||
return quote(unquote_unreserved(uri), safe=safe_with_percent)
|
||||
except InvalidURL:
|
||||
# We couldn't unquote the given URI, so let's try quoting it, but
|
||||
# there may be unquoted '%'s in the URI. We need to make sure they're
|
||||
# properly quoted so they do not cause issues elsewhere.
|
||||
return quote(uri, safe=safe_without_percent)
|
||||
|
||||
|
||||
def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:
|
||||
|
||||
@ -2,6 +2,7 @@ certifi
|
||||
chardet
|
||||
h11
|
||||
h2
|
||||
rfc3986
|
||||
|
||||
# Optional
|
||||
brotlipy
|
||||
|
||||
2
setup.py
2
setup.py
@ -47,7 +47,7 @@ setup(
|
||||
author_email="tom@tomchristie.com",
|
||||
packages=get_packages("httpcore"),
|
||||
data_files=[("", ["LICENSE.md"])],
|
||||
install_requires=["h11", "h2", "certifi", "chardet"],
|
||||
install_requires=["h11", "h2", "certifi", "chardet", "rfc3986"],
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Environment :: Web Environment",
|
||||
|
||||
@ -93,11 +93,11 @@ def test_url():
|
||||
|
||||
|
||||
def test_invalid_urls():
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(httpcore.InvalidURL):
|
||||
httpcore.Request("GET", "example.org")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(httpcore.InvalidURL):
|
||||
httpcore.Request("GET", "invalid://example.org")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(httpcore.InvalidURL):
|
||||
httpcore.Request("GET", "http:///foo")
|
||||
|
||||
@ -37,11 +37,13 @@ def test_response_default_encoding():
|
||||
|
||||
|
||||
def test_response_set_explicit_encoding():
|
||||
headers = {"Content-Type": "text-plain; charset=utf-8"} # Deliberately incorrect charset
|
||||
headers = {
|
||||
"Content-Type": "text-plain; charset=utf-8"
|
||||
} # Deliberately incorrect charset
|
||||
response = httpcore.Response(
|
||||
200, content="Latin 1: ÿ".encode("latin-1"), headers=headers
|
||||
)
|
||||
response.encoding = 'latin-1'
|
||||
response.encoding = "latin-1"
|
||||
assert response.text == "Latin 1: ÿ"
|
||||
assert response.encoding == "latin-1"
|
||||
|
||||
@ -71,6 +73,38 @@ async def test_read_response():
|
||||
assert response.is_closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_interface():
|
||||
response = httpcore.Response(200, content=b"Hello, world!")
|
||||
|
||||
raw = b""
|
||||
async for part in response.raw():
|
||||
raw += part
|
||||
assert raw == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_interface():
|
||||
response = httpcore.Response(200, content=b"Hello, world!")
|
||||
|
||||
content = b""
|
||||
async for part in response.stream():
|
||||
content += part
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_interface_after_read():
|
||||
response = httpcore.Response(200, content=b"Hello, world!")
|
||||
|
||||
await response.read()
|
||||
|
||||
content = b""
|
||||
async for part in response.stream():
|
||||
content += part
|
||||
assert content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_response():
|
||||
response = httpcore.Response(200, content=streaming_body())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user