Redirections
This commit is contained in:
parent
8b4ef87a89
commit
041813bc4a
@ -10,10 +10,11 @@ from .exceptions import (
|
||||
PoolTimeout,
|
||||
ProtocolError,
|
||||
ReadTimeout,
|
||||
RedirectLoop,
|
||||
ResponseClosed,
|
||||
StreamConsumed,
|
||||
Timeout,
|
||||
TooManyRedirects
|
||||
TooManyRedirects,
|
||||
)
|
||||
from .interfaces import Adapter
|
||||
from .models import URL, Headers, Origin, Request, Response
|
||||
|
||||
@ -2,7 +2,7 @@ import typing
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from ..config import DEFAULT_MAX_REDIRECTS
|
||||
from ..exceptions import TooManyRedirects
|
||||
from ..exceptions import RedirectLoop, TooManyRedirects
|
||||
from ..interfaces import Adapter
|
||||
from ..models import URL, Request, Response
|
||||
from ..status_codes import codes
|
||||
@ -20,6 +20,7 @@ class RedirectAdapter(Adapter):
|
||||
async def send(self, request: Request, **options: typing.Any) -> Response:
|
||||
allow_redirects = options.pop("allow_redirects", True)
|
||||
history = []
|
||||
seen_urls = set((request.url,))
|
||||
|
||||
while True:
|
||||
response = await self.dispatch.send(request, **options)
|
||||
@ -29,6 +30,9 @@ class RedirectAdapter(Adapter):
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
request = self.build_redirect_request(request, response)
|
||||
if request.url in seen_urls:
|
||||
raise RedirectLoop()
|
||||
seen_urls.add(request.url)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -28,12 +28,24 @@ class PoolTimeout(Timeout):
|
||||
"""
|
||||
|
||||
|
||||
class TooManyRedirects(Exception):
|
||||
class RedirectError(Exception):
|
||||
"""
|
||||
Base class for HTTP redirect errors.
|
||||
"""
|
||||
|
||||
|
||||
class TooManyRedirects(RedirectError):
|
||||
"""
|
||||
Too many redirects.
|
||||
"""
|
||||
|
||||
|
||||
class RedirectLoop(RedirectError):
|
||||
"""
|
||||
Infinite redirect loop.
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
"""
|
||||
Malformed HTTP.
|
||||
|
||||
@ -70,6 +70,12 @@ class URL:
|
||||
def origin(self) -> "Origin":
|
||||
return Origin(self)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, URL) and str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.components.geturl()
|
||||
|
||||
|
||||
@ -1,7 +1,16 @@
|
||||
import pytest
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from httpcore import Adapter, RedirectAdapter, Request, Response, TooManyRedirects, codes
|
||||
import pytest
|
||||
|
||||
from httpcore import (
|
||||
Adapter,
|
||||
RedirectAdapter,
|
||||
RedirectLoop,
|
||||
Request,
|
||||
Response,
|
||||
TooManyRedirects,
|
||||
codes,
|
||||
)
|
||||
|
||||
|
||||
class MockDispatch(Adapter):
|
||||
@ -9,12 +18,21 @@ class MockDispatch(Adapter):
|
||||
pass
|
||||
|
||||
async def send(self, request: Request, **options) -> Response:
|
||||
if request.url.path == "/redirect_303":
|
||||
return Response(
|
||||
codes.see_other, headers=[(b"location", b"https://example.org/")]
|
||||
)
|
||||
if request.url.path == "/redirect_301": # "Moved Permanently"
|
||||
return Response(301, headers=[(b"location", b"https://example.org/")])
|
||||
|
||||
elif request.url.path == "/redirect_302": # "Found"
|
||||
return Response(302, headers=[(b"location", b"https://example.org/")])
|
||||
|
||||
elif request.url.path == "/redirect_303": # "See Other"
|
||||
return Response(303, headers=[(b"location", b"https://example.org/")])
|
||||
|
||||
elif request.url.path == "/relative_redirect":
|
||||
return Response(codes.see_other, headers=[(b"location", b"/")])
|
||||
|
||||
elif request.url.path == "/no_scheme_redirect":
|
||||
return Response(codes.see_other, headers=[(b"location", b"//example.org/")])
|
||||
|
||||
elif request.url.path == "/multiple_redirects":
|
||||
params = parse_qs(request.url.query)
|
||||
count = int(params.get("count", "0")[0])
|
||||
@ -22,9 +40,27 @@ class MockDispatch(Adapter):
|
||||
location = "/multiple_redirects?count=" + str(count - 1)
|
||||
headers = [(b"location", location.encode())] if count else []
|
||||
return Response(code, headers=headers)
|
||||
|
||||
if request.url.path == "/redirect_loop":
|
||||
return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")])
|
||||
|
||||
return Response(codes.ok, body=b"Hello, world!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_301():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("POST", "https://example.org/redirect_301")
|
||||
assert response.status_code == codes.ok
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_302():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("POST", "https://example.org/redirect_302")
|
||||
assert response.status_code == codes.ok
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_303():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
@ -39,10 +75,26 @@ async def test_relative_redirect():
|
||||
assert response.status_code == codes.ok
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_scheme_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/no_scheme_redirect")
|
||||
assert response.status_code == codes.ok
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fragment_redirect():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/relative_redirect#fragment")
|
||||
assert response.status_code == codes.ok
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_redirects():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
response = await client.request("GET", "https://example.org/multiple_redirects?count=20")
|
||||
response = await client.request(
|
||||
"GET", "https://example.org/multiple_redirects?count=20"
|
||||
)
|
||||
assert response.status_code == codes.ok
|
||||
|
||||
|
||||
@ -51,3 +103,10 @@ async def test_too_many_redirects():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
with pytest.raises(TooManyRedirects):
|
||||
await client.request("GET", "https://example.org/multiple_redirects?count=21")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redirect_loop():
|
||||
client = RedirectAdapter(MockDispatch())
|
||||
with pytest.raises(RedirectLoop):
|
||||
await client.request("GET", "https://example.org/redirect_loop")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user