Rejig request preparing
This commit is contained in:
parent
041813bc4a
commit
4f23ab4f0d
@ -4,7 +4,7 @@ from urllib.parse import urljoin, urlparse
|
||||
from ..config import DEFAULT_MAX_REDIRECTS
|
||||
from ..exceptions import RedirectLoop, TooManyRedirects
|
||||
from ..interfaces import Adapter
|
||||
from ..models import URL, Request, Response
|
||||
from ..models import URL, Headers, Request, Response
|
||||
from ..status_codes import codes
|
||||
from ..utils import requote_uri
|
||||
|
||||
@ -19,11 +19,12 @@ class RedirectAdapter(Adapter):
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
allow_redirects = options.pop("allow_redirects", True)
|
||||
history = []
|
||||
history = [] # type: typing.List[Response]
|
||||
seen_urls = set((request.url,))
|
||||
|
||||
while True:
|
||||
response = await self.dispatch.send(request, **options)
|
||||
response.history = list(history)
|
||||
if not allow_redirects or not response.is_redirect:
|
||||
break
|
||||
history.append(response)
|
||||
@ -42,7 +43,8 @@ class RedirectAdapter(Adapter):
|
||||
def build_redirect_request(self, request: Request, response: Response) -> Request:
|
||||
method = self.redirect_method(request, response)
|
||||
url = self.redirect_url(request, response)
|
||||
return Request(method=method, url=url)
|
||||
headers = self.redirect_headers(request, url)
|
||||
return Request(method=method, url=url, headers=headers)
|
||||
|
||||
def redirect_method(self, request: Request, response: Response) -> str:
|
||||
"""
|
||||
@ -89,3 +91,9 @@ class RedirectAdapter(Adapter):
|
||||
url = requote_uri(url)
|
||||
|
||||
return URL(url)
|
||||
|
||||
def redirect_headers(self, request: Request, url: URL) -> Headers:
|
||||
headers = Headers(request.headers)
|
||||
if url.origin != request.url.origin:
|
||||
del headers["Authorization"]
|
||||
return headers
|
||||
|
||||
@ -37,7 +37,7 @@ class HTTPConnection(Adapter):
|
||||
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
if self.h11_connection is None and self.h2_connection is None:
|
||||
|
||||
@ -10,6 +10,7 @@ from ..config import (
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
from ..decoders import ACCEPT_ENCODING
|
||||
from ..exceptions import PoolTimeout
|
||||
from ..interfaces import Adapter
|
||||
from ..models import Origin, Request, Response
|
||||
@ -104,7 +105,7 @@ class ConnectionPool(Adapter):
|
||||
return len(self.keepalive_connections) + len(self.active_connections)
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
connection = await self.acquire_connection(request.url.origin)
|
||||
|
||||
@ -46,7 +46,7 @@ class HTTP11Connection(Adapter):
|
||||
self.h11_state = h11.Connection(our_role=h11.CLIENT)
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
timeout = options.get("timeout")
|
||||
@ -87,6 +87,7 @@ class HTTP11Connection(Adapter):
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=self.response_closed,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
|
||||
@ -32,7 +32,7 @@ class HTTP2Connection(Adapter):
|
||||
self.initialized = False
|
||||
|
||||
def prepare_request(self, request: Request) -> None:
|
||||
pass
|
||||
request.prepare()
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
timeout = options.get("timeout")
|
||||
@ -75,6 +75,7 @@ class HTTP2Connection(Adapter):
|
||||
headers=headers,
|
||||
body=body,
|
||||
on_close=on_close,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
|
||||
@ -79,6 +79,11 @@ class URL:
|
||||
def __str__(self) -> str:
|
||||
return self.components.geturl()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
url_str = str(self)
|
||||
return f"{class_name}({url_str!r})"
|
||||
|
||||
|
||||
class Origin:
|
||||
def __init__(self, url: typing.Union[str, URL]) -> None:
|
||||
@ -100,13 +105,21 @@ class Origin:
|
||||
return hash((self.is_ssl, self.hostname, self.port))
|
||||
|
||||
|
||||
HeaderTypes = typing.Union["Headers", typing.List[typing.Tuple[bytes, bytes]]]
|
||||
|
||||
|
||||
class Headers(typing.MutableMapping[str, str]):
|
||||
"""
|
||||
A case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
|
||||
self._list = [(k.lower(), v) for k, v in headers]
|
||||
def __init__(self, headers: HeaderTypes = None) -> None:
|
||||
if headers is None:
|
||||
self._list = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
elif isinstance(headers, Headers):
|
||||
self._list = list(headers.raw)
|
||||
else:
|
||||
self._list = [(k.lower(), v) for k, v in headers]
|
||||
|
||||
@property
|
||||
def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
|
||||
@ -213,7 +226,7 @@ class Request:
|
||||
method: str,
|
||||
url: typing.Union[str, URL],
|
||||
*,
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
headers: HeaderTypes = None,
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
):
|
||||
self.method = method.upper()
|
||||
@ -224,26 +237,24 @@ class Request:
|
||||
else:
|
||||
self.is_streaming = True
|
||||
self.body_aiter = body
|
||||
self.headers = self.build_headers(headers)
|
||||
self.headers = Headers(headers)
|
||||
|
||||
def build_headers(
|
||||
self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
|
||||
) -> Headers:
|
||||
has_host = False
|
||||
has_content_length = False
|
||||
has_accept_encoding = False
|
||||
|
||||
for header, value in init_headers:
|
||||
header = header.strip().lower()
|
||||
if header == b"host":
|
||||
has_host = True
|
||||
elif header in (b"content-length", b"transfer-encoding"):
|
||||
has_content_length = True
|
||||
elif header == b"accept-encoding":
|
||||
has_accept_encoding = True
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
if self.is_streaming:
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
elif self.body:
|
||||
yield self.body
|
||||
|
||||
def prepare(self) -> None:
|
||||
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
|
||||
has_host = "host" in self.headers
|
||||
has_content_length = (
|
||||
"content-length" in self.headers or "transfer-encoding" in self.headers
|
||||
)
|
||||
has_accept_encoding = "accept-encoding" in self.headers
|
||||
|
||||
if not has_host:
|
||||
auto_headers.append((b"host", self.url.netloc.encode("ascii")))
|
||||
if not has_content_length:
|
||||
@ -255,14 +266,8 @@ class Request:
|
||||
if not has_accept_encoding:
|
||||
auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
|
||||
|
||||
return Headers(auto_headers + init_headers)
|
||||
|
||||
async def stream(self) -> typing.AsyncIterator[bytes]:
|
||||
if self.is_streaming:
|
||||
async for part in self.body_aiter:
|
||||
yield part
|
||||
elif self.body:
|
||||
yield self.body
|
||||
for item in reversed(auto_headers):
|
||||
self.headers.raw.insert(0, item)
|
||||
|
||||
|
||||
class Response:
|
||||
@ -275,6 +280,8 @@ class Response:
|
||||
headers: typing.List[typing.Tuple[bytes, bytes]] = [],
|
||||
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
|
||||
on_close: typing.Callable = None,
|
||||
request: Request = None,
|
||||
history: typing.List["Response"] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
if not reason:
|
||||
@ -310,6 +317,13 @@ class Response:
|
||||
else:
|
||||
self.body_aiter = body
|
||||
|
||||
self.request = request
|
||||
self.history = [] if history is None else list(history)
|
||||
|
||||
@property
|
||||
def url(self) -> typing.Optional[URL]:
|
||||
return None if self.request is None else self.request.url
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""
|
||||
Read and return the response content.
|
||||
@ -358,4 +372,6 @@ class Response:
|
||||
|
||||
@property
|
||||
def is_redirect(self) -> bool:
|
||||
return self.status_code in (301, 302, 303, 307, 308)
|
||||
return (
|
||||
self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers
|
||||
)
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import pytest
|
||||
|
||||
from httpcore import (
|
||||
URL,
|
||||
Adapter,
|
||||
RedirectAdapter,
|
||||
RedirectLoop,
|
||||
@ -19,32 +21,60 @@ class MockDispatch(Adapter):
|
||||
|
||||
async def send(self, request: Request, **options) -> Response:
|
||||
if request.url.path == "/redirect_301": # "Moved Permanently"
|
||||
return Response(301, headers=[(b"location", b"https://example.org/")])
|
||||
return Response(
|
||||
301, headers=[(b"location", b"https://example.org/")], request=request
|
||||
)
|
||||
|
||||
elif request.url.path == "/redirect_302": # "Found"
|
||||
return Response(302, headers=[(b"location", b"https://example.org/")])
|
||||
return Response(
|
||||
302, headers=[(b"location", b"https://example.org/")], request=request
|
||||
)
|
||||
|
||||
elif request.url.path == "/redirect_303": # "See Other"
|
||||
return Response(303, headers=[(b"location", b"https://example.org/")])
|
||||
return Response(
|
||||
303, headers=[(b"location", b"https://example.org/")], request=request
|
||||
)
|
||||
|
||||
elif request.url.path == "/relative_redirect":
|
||||
return Response(codes.see_other, headers=[(b"location", b"/")])
|
||||
return Response(
|
||||
codes.see_other, headers=[(b"location", b"/")], request=request
|
||||
)
|
||||
|
||||
elif request.url.path == "/no_scheme_redirect":
|
||||
return Response(codes.see_other, headers=[(b"location", b"//example.org/")])
|
||||
return Response(
|
||||
codes.see_other,
|
||||
headers=[(b"location", b"//example.org/")],
|
||||
request=request,
|
||||
)
|
||||
|
||||
elif request.url.path == "/multiple_redirects":
|
||||
params = parse_qs(request.url.query)
|
||||
count = int(params.get("count", "0")[0])
|
||||
redirect_count = count - 1
|
||||
code = codes.see_other if count else codes.ok
|
||||
location = "/multiple_redirects?count=" + str(count - 1)
|
||||
location = "/multiple_redirects"
|
||||
if redirect_count:
|
||||
location += "?count=" + str(redirect_count)
|
||||
headers = [(b"location", location.encode())] if count else []
|
||||
return Response(code, headers=headers)
|
||||
return Response(code, headers=headers, request=request)
|
||||
|
||||
if request.url.path == "/redirect_loop":
|
||||
return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")])
|
||||
return Response(
|
||||
codes.see_other,
|
||||
headers=[(b"location", b"/redirect_loop")],
|
||||
request=request,
|
||||
)
|
||||
|
||||
return Response(codes.ok, body=b"Hello, world!")
|
||||
elif request.url.path == "/cross_domain":
|
||||
location = b"https://example.org/cross_domain_target"
|
||||
return Response(301, headers=[(b"location", location)], request=request)
|
||||
|
||||
elif request.url.path == "/cross_domain_target":
|
||||
headers = {k.decode(): v.decode() for k, v in request.headers.raw}
|
||||
body = json.dumps({"headers": headers}).encode()
|
||||
return Response(codes.ok, body=body, request=request)
|
||||
|
||||
return Response(codes.ok, body=b"Hello, world!", request=request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -52,6 +82,8 @@ async def test_redirect_301():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("POST", "https://example.org/redirect_301")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -59,6 +91,8 @@ async def test_redirect_302():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("POST", "https://example.org/redirect_302")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -66,6 +100,8 @@ async def test_redirect_303():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/redirect_303")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -73,6 +109,8 @@ async def test_relative_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/relative_redirect")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -80,13 +118,19 @@ async def test_no_scheme_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/no_scheme_redirect")
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fragment_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/relative_redirect#fragment")
|
||||
response = await client.request(
|
||||
"GET", "https://example.org/relative_redirect#fragment"
|
||||
)
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/#fragment")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -96,6 +140,8 @@ async def test_multiple_redirects():
|
||||
"GET", "https://example.org/multiple_redirects?count=20"
|
||||
)
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/multiple_redirects")
|
||||
assert len(response.history) == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -110,3 +156,25 @@ async def test_redirect_loop():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
with pytest.raises(RedirectLoop):
|
||||
await client.request("GET", "https://example.org/redirect_loop")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_domain_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
headers = [(b"Authorization", b"abc")]
|
||||
url = "https://example.com/cross_domain"
|
||||
response = await client.request("GET", url, headers=headers)
|
||||
data = json.loads(response.body.decode())
|
||||
assert response.url == URL("https://example.org/cross_domain_target")
|
||||
assert data == {"headers": {}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_domain_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
headers = [(b"Authorization", b"abc")]
|
||||
url = "https://example.org/cross_domain"
|
||||
response = await client.request("GET", url, headers=headers)
|
||||
data = json.loads(response.body.decode())
|
||||
assert response.url == URL("https://example.org/cross_domain_target")
|
||||
assert data == {"headers": {"authorization": "abc"}}
|
||||
|
||||
@ -5,6 +5,7 @@ import httpcore
|
||||
|
||||
def test_host_header():
|
||||
request = httpcore.Request("GET", "http://example.org")
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
|
||||
)
|
||||
@ -12,6 +13,7 @@ def test_host_header():
|
||||
|
||||
def test_content_length_header():
|
||||
request = httpcore.Request("POST", "http://example.org", body=b"test 123")
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
@ -28,6 +30,7 @@ def test_transfer_encoding_header():
|
||||
body = streaming_body(b"test 123")
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", body=body)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
@ -41,6 +44,7 @@ def test_override_host_header():
|
||||
headers = [(b"host", b"1.2.3.4:80")]
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
|
||||
)
|
||||
@ -50,6 +54,7 @@ def test_override_accept_encoding_header():
|
||||
headers = [(b"accept-encoding", b"identity")]
|
||||
|
||||
request = httpcore.Request("GET", "http://example.org", headers=headers)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[(b"host", b"example.org"), (b"accept-encoding", b"identity")]
|
||||
)
|
||||
@ -63,6 +68,7 @@ def test_override_content_length_header():
|
||||
headers = [(b"content-length", b"8")]
|
||||
|
||||
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
|
||||
request.prepare()
|
||||
assert request.headers == httpcore.Headers(
|
||||
[
|
||||
(b"host", b"example.org"),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user