Add await response.next() interface
This commit is contained in:
parent
b392b7ea59
commit
450ea25b5a
@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import typing
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
@ -19,18 +20,25 @@ class RedirectAdapter(Adapter):
|
||||
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
allow_redirects = options.pop("allow_redirects", True)
|
||||
history = [] # type: typing.List[Response]
|
||||
seen_urls = set((request.url,))
|
||||
history = options.pop("history", []) # type: typing.List[Response]
|
||||
seen_urls = options.pop("seen_urls", set()) # type: typing.Set[URL]
|
||||
seen_urls.add(request.url)
|
||||
|
||||
while True:
|
||||
response = await self.dispatch.send(request, **options)
|
||||
response.history = list(history)
|
||||
if not allow_redirects or not response.is_redirect:
|
||||
if not response.is_redirect:
|
||||
break
|
||||
history.append(response)
|
||||
request = self.build_redirect_request(request, response)
|
||||
if not allow_redirects:
|
||||
next_options = dict(options)
|
||||
next_options["seen_urls"] = seen_urls
|
||||
next_options["history"] = history
|
||||
response.next = functools.partial(self.send, request=request, **next_options)
|
||||
break
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
request = self.build_redirect_request(request, response)
|
||||
if request.url in seen_urls:
|
||||
raise RedirectLoop()
|
||||
seen_urls.add(request.url)
|
||||
@ -71,6 +79,9 @@ class RedirectAdapter(Adapter):
|
||||
return method
|
||||
|
||||
def redirect_url(self, request: Request, response: Response) -> URL:
|
||||
"""
|
||||
Return the URL for the redirect to follow.
|
||||
"""
|
||||
location = response.headers["Location"]
|
||||
|
||||
# Handle redirection without scheme (see: RFC 1808 Section 4)
|
||||
@ -94,12 +105,19 @@ class RedirectAdapter(Adapter):
|
||||
return URL(url)
|
||||
|
||||
def redirect_headers(self, request: Request, url: URL) -> Headers:
|
||||
"""
|
||||
Strip Authorization headers when responses are redirected away from
|
||||
the origin.
|
||||
"""
|
||||
headers = Headers(request.headers)
|
||||
if url.origin != request.url.origin:
|
||||
del headers["Authorization"]
|
||||
return headers
|
||||
|
||||
def redirect_body(self, request: Request, method: str) -> bytes:
|
||||
"""
|
||||
Return the body that should be used for the redirect request.
|
||||
"""
|
||||
if method != request.method and method == "GET":
|
||||
return b""
|
||||
if request.is_streaming:
|
||||
|
||||
@ -116,6 +116,11 @@ async def test_disallow_redirects():
|
||||
assert response.url == URL("https://example.org/redirect_303")
|
||||
assert len(response.history) == 0
|
||||
|
||||
response = await response.next()
|
||||
assert response.status_code == codes.ok
|
||||
assert response.url == URL("https://example.org/")
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relative_redirect():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user