Use rfc3986 for URL handling

This commit is contained in:
Tom Christie 2019-04-30 17:39:40 +01:00
parent fd9823e59c
commit 302fe93df9
10 changed files with 94 additions and 103 deletions

View File

@ -8,6 +8,7 @@ from .dispatch.http11 import HTTP11Connection
from .exceptions import (
ConnectTimeout,
DecodingError,
InvalidURL,
PoolTimeout,
ProtocolError,
ReadTimeout,

View File

@ -1,12 +1,10 @@
import typing
from urllib.parse import urljoin, urlparse
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from ..interfaces import Adapter
from ..models import URL, Headers, Request, Response
from ..status_codes import codes
from ..utils import requote_uri
class RedirectAdapter(Adapter):
@ -98,25 +96,19 @@ class RedirectAdapter(Adapter):
"""
location = response.headers["Location"]
# Handle redirection without scheme (see: RFC 1808 Section 4)
if location.startswith("//"):
location = f"{request.url.scheme}:{location}"
# Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
parsed = urlparse(location)
if parsed.fragment == "" and request.url.fragment:
parsed = parsed._replace(fragment=request.url.fragment)
url = parsed.geturl()
url = URL(location, allow_relative=True)
# Facilitate relative 'location' headers, as allowed by RFC 7231.
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
# Compliant with RFC3986, we percent encode the url.
if not parsed.netloc:
url = urljoin(str(request.url), requote_uri(url))
else:
url = requote_uri(url)
if not url.is_absolute:
url = url.resolve_with(request.url.copy_with(fragment=None))
return URL(url)
# Attach previous fragment if needed (RFC 7231 7.1.2)
if request.url.fragment and not url.fragment:
url = url.copy_with(fragment=request.url.fragment)
return url
def redirect_headers(self, request: Request, url: URL) -> Headers:
"""

View File

@ -57,7 +57,7 @@ class HTTPConnection(Adapter):
assert isinstance(ssl, SSLConfig)
assert isinstance(timeout, TimeoutConfig)
hostname = self.origin.hostname
host = self.origin.host
port = self.origin.port
ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
@ -66,7 +66,7 @@ class HTTPConnection(Adapter):
else:
on_release = functools.partial(self.release_func, self)
reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
reader, writer, protocol = await connect(host, port, ssl_context, timeout)
if protocol == Protocol.HTTP_2:
self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
else:

View File

@ -99,7 +99,7 @@ class HTTP2Connection(Adapter):
stream_id = self.h2_state.get_next_available_stream_id()
headers = [
(b":method", request.method.encode()),
(b":authority", request.url.hostname.encode()),
(b":authority", request.url.host.encode()),
(b":scheme", request.url.scheme.encode()),
(b":path", request.url.full_path.encode()),
] + request.headers.raw

View File

@ -1,8 +1,8 @@
import cgi
import typing
from urllib.parse import urlsplit
import chardet
import rfc3986
from .config import SSLConfig, TimeoutConfig
from .decoders import (
@ -12,7 +12,7 @@ from .decoders import (
IdentityDecoder,
MultiDecoder,
)
from .exceptions import ResponseClosed, ResponseNotRead, StreamConsumed
from .exceptions import InvalidURL, ResponseClosed, ResponseNotRead, StreamConsumed
from .status_codes import codes
from .utils import (
get_reason_phrase,
@ -33,42 +33,45 @@ ByteOrByteStream = typing.Union[bytes, typing.AsyncIterator[bytes]]
class URL:
def __init__(self, url: URLTypes) -> None:
def __init__(self, url: URLTypes, allow_relative: bool = False) -> None:
if isinstance(url, str):
self.components = urlsplit(url)
self.components = rfc3986.api.uri_reference(url).normalize()
elif isinstance(url, rfc3986.uri.URIReference):
self.components = url
else:
self.components = url.components
if not self.components.scheme:
raise ValueError("No scheme included in URL.")
if self.components.scheme not in ("http", "https"):
raise ValueError('URL scheme must be "http" or "https".')
if not self.components.hostname:
raise ValueError("No hostname included in URL.")
if not allow_relative:
if not self.scheme:
raise InvalidURL("No scheme included in URL.")
if self.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')
if not self.host:
raise InvalidURL("No hostname included in URL.")
@property
def scheme(self) -> str:
return self.components.scheme
return self.components.scheme or ""
@property
def netloc(self) -> str:
return self.components.netloc
def authority(self) -> str:
return self.components.authority or ""
@property
def path(self) -> str:
return self.components.path
return self.components.path or "/"
@property
def query(self) -> str:
return self.components.query
return self.components.query or ""
@property
def fragment(self) -> str:
return self.components.fragment
return self.components.fragment or ""
@property
def hostname(self) -> str:
return self.components.hostname
def host(self) -> str:
return self.components.host or ""
@property
def port(self) -> int:
@ -89,10 +92,22 @@ class URL:
def is_secure(self) -> bool:
return self.components.scheme == "https"
@property
def is_absolute(self) -> bool:
return self.components.is_absolute()
@property
def origin(self) -> "Origin":
return Origin(self)
def copy_with(self, **kwargs: typing.Any) -> "URL":
return URL(self.components.copy_with(**kwargs))
def resolve_with(self, base_url: URLTypes) -> "URL":
if isinstance(base_url, URL):
base_url = base_url.components
return URL(self.components.resolve_with(base_url))
def __hash__(self) -> int:
return hash(str(self))
@ -100,7 +115,7 @@ class URL:
return isinstance(other, URL) and str(self) == str(other)
def __str__(self) -> str:
return self.components.geturl()
return self.components.unsplit()
def __repr__(self) -> str:
class_name = self.__class__.__name__
@ -109,23 +124,23 @@ class URL:
class Origin:
def __init__(self, url: typing.Union[str, URL]) -> None:
if isinstance(url, str):
def __init__(self, url: URLTypes) -> None:
if not isinstance(url, URL):
url = URL(url)
self.is_ssl = url.scheme == "https"
self.hostname = url.hostname.lower()
self.host = url.host
self.port = url.port
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, self.__class__)
and self.is_ssl == other.is_ssl
and self.hostname == other.hostname
and self.host == other.host
and self.port == other.port
)
def __hash__(self) -> int:
return hash((self.is_ssl, self.hostname, self.port))
return hash((self.is_ssl, self.host, self.port))
class Headers(typing.MutableMapping[str, str]):
@ -365,8 +380,8 @@ class Request:
)
has_accept_encoding = "accept-encoding" in self.headers
if not has_host:
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_host and self.url.authority:
auto_headers.append((b"host", self.url.authority.encode("ascii")))
if not has_content_length:
if self.is_streaming:
auto_headers.append((b"transfer-encoding", b"chunked"))

View File

@ -1,58 +1,6 @@
import codecs
import http
import typing
from urllib.parse import quote
from .exceptions import InvalidURL
# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
)


def unquote_unreserved(uri: str) -> str:
    """
    Un-escape any percent-escape sequences in a URI that are unreserved
    characters. This leaves all reserved, illegal and non-ASCII bytes encoded.

    A "%" followed by fewer than two characters, or by a pair that is not
    alphanumeric, is passed through untouched.

    Raises:
        InvalidURL: if a "%" is followed by two alphanumeric characters that
            are not both ASCII hexadecimal digits.
    """
    # `int(h, 16)` also accepts non-ASCII decimal digits (e.g. Arabic-Indic),
    # which would turn a bogus escape into a real character. Only ASCII hex
    # digits form a valid percent-escape.
    hex_digits = frozenset("0123456789abcdefABCDEF")

    head, *escapes = uri.split("%")
    decoded = [head]
    for part in escapes:
        h = part[0:2]
        if len(h) == 2 and h.isalnum():
            if not hex_digits.issuperset(h):
                raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
            c = chr(int(h, 16))
            if c in UNRESERVED_SET:
                # Unreserved character: safe to store un-escaped.
                decoded.append(c + part[2:])
                continue
        # Not an escape, or an escape of a reserved/illegal character:
        # restore the "%" that `split` consumed.
        decoded.append("%" + part)
    return "".join(decoded)
def requote_uri(uri: str) -> str:
    """
    Return `uri` consistently percent-quoted.

    Escapes of unreserved characters are first undone, then every illegal
    character is quoted while reserved characters, unreserved characters and
    existing "%" signs are left alone. If the URI contains an invalid
    percent-escape, it is quoted as-is with "%" itself being escaped too.
    """
    safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
    safe_without_percent = "!#$&'()*+,/:;=?@[]~"

    try:
        normalized = unquote_unreserved(uri)
    except InvalidURL:
        # The input could not be un-escaped, so stray "%" characters may be
        # present; quote them as well so they cannot cause trouble later.
        return quote(uri, safe=safe_without_percent)

    # Existing escapes are valid: keep "%" so they are not double-quoted.
    return quote(normalized, safe=safe_with_percent)
def normalize_header_key(value: typing.AnyStr, encoding: str = None) -> bytes:

View File

@ -2,6 +2,7 @@ certifi
chardet
h11
h2
rfc3986
# Optional
brotlipy

View File

@ -47,7 +47,7 @@ setup(
author_email="tom@tomchristie.com",
packages=get_packages("httpcore"),
data_files=[("", ["LICENSE.md"])],
install_requires=["h11", "h2", "certifi", "chardet"],
install_requires=["h11", "h2", "certifi", "chardet", "rfc3986"],
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",

View File

@ -93,11 +93,11 @@ def test_url():
def test_invalid_urls():
with pytest.raises(ValueError):
with pytest.raises(httpcore.InvalidURL):
httpcore.Request("GET", "example.org")
with pytest.raises(ValueError):
with pytest.raises(httpcore.InvalidURL):
httpcore.Request("GET", "invalid://example.org")
with pytest.raises(ValueError):
with pytest.raises(httpcore.InvalidURL):
httpcore.Request("GET", "http:///foo")

View File

@ -37,11 +37,13 @@ def test_response_default_encoding():
def test_response_set_explicit_encoding():
headers = {"Content-Type": "text-plain; charset=utf-8"} # Deliberately incorrect charset
headers = {
"Content-Type": "text-plain; charset=utf-8"
} # Deliberately incorrect charset
response = httpcore.Response(
200, content="Latin 1: ÿ".encode("latin-1"), headers=headers
)
response.encoding = 'latin-1'
response.encoding = "latin-1"
assert response.text == "Latin 1: ÿ"
assert response.encoding == "latin-1"
@ -71,6 +73,38 @@ async def test_read_response():
assert response.is_closed
@pytest.mark.asyncio
async def test_raw_interface():
    """Iterating `response.raw()` yields the complete body bytes."""
    response = httpcore.Response(200, content=b"Hello, world!")

    chunks = [chunk async for chunk in response.raw()]

    assert b"".join(chunks) == b"Hello, world!"
@pytest.mark.asyncio
async def test_stream_interface():
    """Iterating `response.stream()` yields the complete body bytes."""
    response = httpcore.Response(200, content=b"Hello, world!")

    chunks = [chunk async for chunk in response.stream()]

    assert b"".join(chunks) == b"Hello, world!"
@pytest.mark.asyncio
async def test_stream_interface_after_read():
    """`response.stream()` still yields the body after `read()` was awaited."""
    response = httpcore.Response(200, content=b"Hello, world!")
    await response.read()

    chunks = [chunk async for chunk in response.stream()]

    assert b"".join(chunks) == b"Hello, world!"
@pytest.mark.asyncio
async def test_streaming_response():
response = httpcore.Response(200, content=streaming_body())