Add await response.next() interface

This commit is contained in:
Tom Christie 2019-04-29 16:18:41 +01:00
parent b392b7ea59
commit 450ea25b5a
2 changed files with 27 additions and 4 deletions

View File

@ -1,3 +1,4 @@
import functools
import typing
from urllib.parse import urljoin, urlparse
@ -19,18 +20,25 @@ class RedirectAdapter(Adapter):
async def send(self, request: Request, **options: typing.Any) -> Response:
allow_redirects = options.pop("allow_redirects", True)
history = [] # type: typing.List[Response]
seen_urls = set((request.url,))
history = options.pop("history", []) # type: typing.List[Response]
seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL]
seen_urls.add(request.url)
while True:
response = await self.dispatch.send(request, **options)
response.history = list(history)
if not allow_redirects or not response.is_redirect:
if not response.is_redirect:
break
history.append(response)
request = self.build_redirect_request(request, response)
if not allow_redirects:
next_options = dict(options)
next_options["seen_urls"] = seen_urls
next_options["history"] = history
response.next = functools.partial(self.send, request=request, **next_options)
break
if len(history) > self.max_redirects:
raise TooManyRedirects()
request = self.build_redirect_request(request, response)
if request.url in seen_urls:
raise RedirectLoop()
seen_urls.add(request.url)
@ -71,6 +79,9 @@ class RedirectAdapter(Adapter):
return method
def redirect_url(self, request: Request, response: Response) -> URL:
"""
Return the URL for the redirect to follow.
"""
location = response.headers["Location"]
# Handle redirection without scheme (see: RFC 1808 Section 4)
@ -94,12 +105,19 @@ class RedirectAdapter(Adapter):
return URL(url)
def redirect_headers(self, request: Request, url: URL) -> Headers:
"""
Strip Authorization headers when responses are redirected away from
the origin.
"""
headers = Headers(request.headers)
if url.origin != request.url.origin:
del headers["Authorization"]
return headers
def redirect_body(self, request: Request, method: str) -> bytes:
"""
Return the body that should be used for the redirect request.
"""
if method != request.method and method == "GET":
return b""
if request.is_streaming:

View File

@ -116,6 +116,11 @@ async def test_disallow_redirects():
assert response.url == URL("https://example.org/redirect_303")
assert len(response.history) == 0
response = await response.next()
assert response.status_code == codes.ok
assert response.url == URL("https://example.org/")
assert len(response.history) == 1
@pytest.mark.asyncio
async def test_relative_redirect():