Redirections

This commit is contained in:
Tom Christie 2019-04-29 12:35:58 +01:00
parent 8b4ef87a89
commit 041813bc4a
5 changed files with 92 additions and 10 deletions

View File

@ -10,10 +10,11 @@ from .exceptions import (
PoolTimeout,
ProtocolError,
ReadTimeout,
RedirectLoop,
ResponseClosed,
StreamConsumed,
Timeout,
TooManyRedirects
TooManyRedirects,
)
from .interfaces import Adapter
from .models import URL, Headers, Origin, Request, Response

View File

@ -2,7 +2,7 @@ import typing
from urllib.parse import urljoin, urlparse
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import TooManyRedirects
from ..exceptions import RedirectLoop, TooManyRedirects
from ..interfaces import Adapter
from ..models import URL, Request, Response
from ..status_codes import codes
@ -20,6 +20,7 @@ class RedirectAdapter(Adapter):
async def send(self, request: Request, **options: typing.Any) -> Response:
allow_redirects = options.pop("allow_redirects", True)
history = []
seen_urls = set((request.url,))
while True:
response = await self.dispatch.send(request, **options)
@ -29,6 +30,9 @@ class RedirectAdapter(Adapter):
if len(history) > self.max_redirects:
raise TooManyRedirects()
request = self.build_redirect_request(request, response)
if request.url in seen_urls:
raise RedirectLoop()
seen_urls.add(request.url)
return response

View File

@ -28,12 +28,24 @@ class PoolTimeout(Timeout):
"""
class TooManyRedirects(Exception):
class RedirectError(Exception):
"""
Base class for HTTP redirect errors.
"""
class TooManyRedirects(RedirectError):
"""
Too many redirects.
"""
class RedirectLoop(RedirectError):
"""
Infinite redirect loop.
"""
class ProtocolError(Exception):
"""
Malformed HTTP.

View File

@ -70,6 +70,12 @@ class URL:
def origin(self) -> "Origin":
return Origin(self)
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, URL) and str(self) == str(other)
def __str__(self) -> str:
return self.components.geturl()

View File

@ -1,7 +1,16 @@
import pytest
from urllib.parse import parse_qs
from httpcore import Adapter, RedirectAdapter, Request, Response, TooManyRedirects, codes
import pytest
from httpcore import (
Adapter,
RedirectAdapter,
RedirectLoop,
Request,
Response,
TooManyRedirects,
codes,
)
class MockDispatch(Adapter):
@ -9,12 +18,21 @@ class MockDispatch(Adapter):
pass
async def send(self, request: Request, **options) -> Response:
if request.url.path == "/redirect_303":
return Response(
codes.see_other, headers=[(b"location", b"https://example.org/")]
)
if request.url.path == "/redirect_301": # "Moved Permanently"
return Response(301, headers=[(b"location", b"https://example.org/")])
elif request.url.path == "/redirect_302": # "Found"
return Response(302, headers=[(b"location", b"https://example.org/")])
elif request.url.path == "/redirect_303": # "See Other"
return Response(303, headers=[(b"location", b"https://example.org/")])
elif request.url.path == "/relative_redirect":
return Response(codes.see_other, headers=[(b"location", b"/")])
elif request.url.path == "/no_scheme_redirect":
return Response(codes.see_other, headers=[(b"location", b"//example.org/")])
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
count = int(params.get("count", "0")[0])
@ -22,9 +40,27 @@ class MockDispatch(Adapter):
location = "/multiple_redirects?count=" + str(count - 1)
headers = [(b"location", location.encode())] if count else []
return Response(code, headers=headers)
if request.url.path == "/redirect_loop":
return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")])
return Response(codes.ok, body=b"Hello, world!")
@pytest.mark.asyncio
async def test_redirect_301():
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_301")
assert response.status_code == codes.ok
@pytest.mark.asyncio
async def test_redirect_302():
client = RedirectAdapter(MockDispatch())
response = await client.request("POST", "https://example.org/redirect_302")
assert response.status_code == codes.ok
@pytest.mark.asyncio
async def test_redirect_303():
client = RedirectAdapter(MockDispatch())
@ -39,10 +75,26 @@ async def test_relative_redirect():
assert response.status_code == codes.ok
@pytest.mark.asyncio
async def test_no_scheme_redirect():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/no_scheme_redirect")
assert response.status_code == codes.ok
@pytest.mark.asyncio
async def test_fragment_redirect():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/relative_redirect#fragment")
assert response.status_code == codes.ok
@pytest.mark.asyncio
async def test_multiple_redirects():
client = RedirectAdapter(MockDispatch())
response = await client.request("GET", "https://example.org/multiple_redirects?count=20")
response = await client.request(
"GET", "https://example.org/multiple_redirects?count=20"
)
assert response.status_code == codes.ok
@ -51,3 +103,10 @@ async def test_too_many_redirects():
client = RedirectAdapter(MockDispatch())
with pytest.raises(TooManyRedirects):
await client.request("GET", "https://example.org/multiple_redirects?count=21")
@pytest.mark.asyncio
async def test_redirect_loop():
client = RedirectAdapter(MockDispatch())
with pytest.raises(RedirectLoop):
await client.request("GET", "https://example.org/redirect_loop")