Rejig request preparing

This commit is contained in:
Tom Christie 2019-04-29 14:25:24 +01:00
parent 041813bc4a
commit 4f23ab4f0d
8 changed files with 146 additions and 45 deletions

View File

@ -4,7 +4,7 @@ from urllib.parse import urljoin, urlparse
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectLoop, TooManyRedirects
from ..interfaces import Adapter
from ..models import URL, Request, Response
from ..models import URL, Headers, Request, Response
from ..status_codes import codes
from ..utils import requote_uri
@ -19,11 +19,12 @@ class RedirectAdapter(Adapter):
async def send(self, request: Request, **options: typing.Any) -> Response:
allow_redirects = options.pop("allow_redirects", True)
history = []
history = [] # type: typing.List[Response]
seen_urls = set((request.url,))
while True:
response = await self.dispatch.send(request, **options)
response.history = list(history)
if not allow_redirects or not response.is_redirect:
break
history.append(response)
@ -42,7 +43,8 @@ class RedirectAdapter(Adapter):
def build_redirect_request(self, request: Request, response: Response) -> Request:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
return Request(method=method, url=url)
headers = self.redirect_headers(request, url)
return Request(method=method, url=url, headers=headers)
def redirect_method(self, request: Request, response: Response) -> str:
"""
@ -89,3 +91,9 @@ class RedirectAdapter(Adapter):
url = requote_uri(url)
return URL(url)
def redirect_headers(self, request: Request, url: URL) -> Headers:
headers = Headers(request.headers)
if url.origin != request.url.origin:
del headers["Authorization"]
return headers

View File

@ -37,7 +37,7 @@ class HTTPConnection(Adapter):
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
def prepare_request(self, request: Request) -> None:
pass
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
if self.h11_connection is None and self.h2_connection is None:

View File

@ -10,6 +10,7 @@ from ..config import (
SSLConfig,
TimeoutConfig,
)
from ..decoders import ACCEPT_ENCODING
from ..exceptions import PoolTimeout
from ..interfaces import Adapter
from ..models import Origin, Request, Response
@ -104,7 +105,7 @@ class ConnectionPool(Adapter):
return len(self.keepalive_connections) + len(self.active_connections)
def prepare_request(self, request: Request) -> None:
pass
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
connection = await self.acquire_connection(request.url.origin)

View File

@ -46,7 +46,7 @@ class HTTP11Connection(Adapter):
self.h11_state = h11.Connection(our_role=h11.CLIENT)
def prepare_request(self, request: Request) -> None:
pass
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
timeout = options.get("timeout")
@ -87,6 +87,7 @@ class HTTP11Connection(Adapter):
headers=headers,
body=body,
on_close=self.response_closed,
request=request,
)
if not stream:

View File

@ -32,7 +32,7 @@ class HTTP2Connection(Adapter):
self.initialized = False
def prepare_request(self, request: Request) -> None:
pass
request.prepare()
async def send(self, request: Request, **options: typing.Any) -> Response:
timeout = options.get("timeout")
@ -75,6 +75,7 @@ class HTTP2Connection(Adapter):
headers=headers,
body=body,
on_close=on_close,
request=request,
)
if not stream:

View File

@ -79,6 +79,11 @@ class URL:
def __str__(self) -> str:
return self.components.geturl()
def __repr__(self) -> str:
class_name = self.__class__.__name__
url_str = str(self)
return f"{class_name}({url_str!r})"
class Origin:
def __init__(self, url: typing.Union[str, URL]) -> None:
@ -100,13 +105,21 @@ class Origin:
return hash((self.is_ssl, self.hostname, self.port))
HeaderTypes = typing.Union["Headers", typing.List[typing.Tuple[bytes, bytes]]]
class Headers(typing.MutableMapping[str, str]):
"""
A case-insensitive multidict.
"""
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
self._list = [(k.lower(), v) for k, v in headers]
def __init__(self, headers: HeaderTypes = None) -> None:
if headers is None:
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
elif isinstance(headers, Headers):
self._list = list(headers.raw)
else:
self._list = [(k.lower(), v) for k, v in headers]
@property
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
@ -213,7 +226,7 @@ class Request:
method: str,
url: typing.Union[str, URL],
*,
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
headers: HeaderTypes = None,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method.upper()
@ -224,26 +237,24 @@ class Request:
else:
self.is_streaming = True
self.body_aiter = body
self.headers = self.build_headers(headers)
self.headers = Headers(headers)
def build_headers(
self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
) -> Headers:
has_host = False
has_content_length = False
has_accept_encoding = False
for header, value in init_headers:
header = header.strip().lower()
if header == b"host":
has_host = True
elif header in (b"content-length", b"transfer-encoding"):
has_content_length = True
elif header == b"accept-encoding":
has_accept_encoding = True
async def stream(self) -> typing.AsyncIterator[bytes]:
if self.is_streaming:
async for part in self.body_aiter:
yield part
elif self.body:
yield self.body
def prepare(self) -> None:
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
has_host = "host" in self.headers
has_content_length = (
"content-length" in self.headers or "transfer-encoding" in self.headers
)
has_accept_encoding = "accept-encoding" in self.headers
if not has_host:
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_content_length:
@ -255,14 +266,8 @@ class Request:
if not has_accept_encoding:
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
return Headers(auto_headers + init_headers)
async def stream(self) -> typing.AsyncIterator[bytes]:
if self.is_streaming:
async for part in self.body_aiter:
yield part
elif self.body:
yield self.body
for item in reversed(auto_headers):
self.headers.raw.insert(0, item)
class Response:
@ -275,6 +280,8 @@ class Response:
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
request: Request = None,
history: typing.List["Response"] = None,
):
self.status_code = status_code
if not reason:
@ -310,6 +317,13 @@ class Response:
else:
self.body_aiter = body
self.request = request
self.history = [] if history is None else list(history)
@property
def url(self) -> typing.Optional[URL]:
return None if self.request is None else self.request.url
async def read(self) -> bytes:
"""
Read and return the response content.
@ -358,4 +372,6 @@ class Response:
@property
def is_redirect(self) -> bool:
return self.status_code in (301, 302, 303, 307, 308)
return (
self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers
)

View File

@ -1,8 +1,10 @@
import json
from urllib.parse import parse_qs
import pytest
from httpcore import (
URL,
Adapter,
RedirectAdapter,
RedirectLoop,
@ -19,32 +21,60 @@ class MockDispatch(Adapter):
async def send(self, request: Request, **options) -> Response:
if request.url.path == "/redirect_301": # "Moved Permanently"
return Response(301, headers=[(b"location", b"https://example.org/")])
return Response(
301, headers=[(b"location", b"https://example.org/")], request=request
)
elif request.url.path == "/redirect_302": # "Found"
return Response(302, headers=[(b"location", b"https://example.org/")])
return Response(
302, headers=[(b"location", b"https://example.org/")], request=request
)
elif request.url.path == "/redirect_303": # "See Other"
return Response(303, headers=[(b"location", b"https://example.org/")])
return Response(
303, headers=[(b"location", b"https://example.org/")], request=request
)
elif request.url.path == "/relative_redirect":
return Response(codes.see_other, headers=[(b"location", b"/")])
return Response(
codes.see_other, headers=[(b"location", b"/")], request=request
)
elif request.url.path == "/no_scheme_redirect":
return Response(codes.see_other, headers=[(b"location", b"//example.org/")])
return Response(
codes.see_other,
headers=[(b"location", b"//example.org/")],
request=request,
)
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
count = int(params.get("count", "0")[0])
redirect_count = count - 1
code = codes.see_other if count else codes.ok
location = "/multiple_redirects?count=" + str(count - 1)
location = "/multiple_redirects"
if redirect_count:
location += "?count=" + str(redirect_count)
headers = [(b"location", location.encode())] if count else []
return Response(code, headers=headers)
return Response(code, headers=headers, request=request)
if request.url.path == "/redirect_loop":
return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")])
return Response(
codes.see_other,
headers=[(b"location", b"/redirect_loop")],
request=request,
)
return Response(codes.ok, body=b"Hello, world!")
elif request.url.path == "/cross_domain":
location = b"https://example.org/cross_domain_target"
return Response(301, headers=[(b"location", location)], request=request)
elif request.url.path == "/cross_domain_target":
headers = {k.decode(): v.decode() for k, v in request.headers.raw}
body = json.dumps({"headers": headers}).encode()
return Response(codes.ok, body=body, request=request)
return Response(codes.ok, body=b"Hello, world!", request=request)
@pytest.mark.asyncio
@ -52,6 +82,8 @@ async def test_redirect_301():
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_301")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
@ -59,6 +91,8 @@ async def test_redirect_302():
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_302")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
@ -66,6 +100,8 @@ async def test_redirect_303():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/redirect_303")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
@ -73,6 +109,8 @@ async def test_relative_redirect():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/relative_redirect")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
@ -80,13 +118,19 @@ async def test_no_scheme_redirect():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/no_scheme_redirect")
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_fragment_redirect():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/relative_redirect#fragment")
response = await client.request(
"GET", "https://example.org/relative_redirect#fragment"
)
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/#fragment")
assert len(response.history) == 1
@pytest.mark.asyncio
@ -96,6 +140,8 @@ async def test_multiple_redirects():
"GET", "https://example.org/multiple_redirects?count=20"
)
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/multiple_redirects")
assert len(response.history) == 20
@pytest.mark.asyncio
@ -110,3 +156,25 @@ async def test_redirect_loop():
client = RedirectAdapter(MockDispatch())
with pytest.raises(RedirectLoop):
await client.request("GET", "https://example.org/redirect_loop")
@pytest.mark.asyncio
async def test_cross_domain_redirect():
client = RedirectAdapter(MockDispatch())
headers = [(b"Authorization", b"abc")]
url = "https://example.com/cross_domain"
response = await client.request("GET", url, headers=headers)
data = json.loads(response.body.decode())
assert response.url == URL("https://example.org/cross_domain_target")
assert data == {"headers": {}}
@pytest.mark.asyncio
async def test_same_domain_redirect():
client = RedirectAdapter(MockDispatch())
headers = [(b"Authorization", b"abc")]
url = "https://example.org/cross_domain"
response = await client.request("GET", url, headers=headers)
data = json.loads(response.body.decode())
assert response.url == URL("https://example.org/cross_domain_target")
assert data == {"headers": {"authorization": "abc"}}

View File

@ -5,6 +5,7 @@ import httpcore
def test_host_header():
request = httpcore.Request("GET", "http://example.org")
request.prepare()
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
)
@ -12,6 +13,7 @@ def test_host_header():
def test_content_length_header():
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
@ -28,6 +30,7 @@ def test_transfer_encoding_header():
body = streaming_body(b"test 123")
request = httpcore.Request("POST", "http://example.org", body=body)
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),
@ -41,6 +44,7 @@ def test_override_host_header():
headers = [(b"host", b"1.2.3.4:80")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
request.prepare()
assert request.headers == httpcore.Headers(
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
)
@ -50,6 +54,7 @@ def test_override_accept_encoding_header():
headers = [(b"accept-encoding", b"identity")]
request = httpcore.Request("GET", "http://example.org", headers=headers)
request.prepare()
assert request.headers == httpcore.Headers(
[(b"host", b"example.org"), (b"accept-encoding", b"identity")]
)
@ -63,6 +68,7 @@ def test_override_content_length_header():
headers = [(b"content-length", b"8")]
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
request.prepare()
assert request.headers == httpcore.Headers(
[
(b"host", b"example.org"),