Tighten up response.next() (#297)

This commit is contained in:
Florimond Manca 2019-08-31 16:09:35 +02:00 committed by Seth Michael Larson
parent 994403bec9
commit a46462764e
5 changed files with 35 additions and 4 deletions

View File

@ -27,6 +27,7 @@ from .exceptions import (
CookieConflict,
DecodingError,
InvalidURL,
NotRedirectResponse,
PoolTimeout,
ProtocolError,
ReadTimeout,
@ -92,6 +93,7 @@ __all__ = [
"CookieConflict",
"DecodingError",
"InvalidURL",
"NotRedirectResponse",
"PoolTimeout",
"ProtocolError",
"ReadTimeout",

View File

@ -252,7 +252,7 @@ class BaseClient:
if response.is_redirect:
async def send_next() -> AsyncResponse:
async def call_next() -> AsyncResponse:
nonlocal request, response, verify, cert
nonlocal allow_redirects, timeout, history
request = self.build_redirect_request(request, response)
@ -266,7 +266,7 @@ class BaseClient:
)
return response
response.next = send_next # type: ignore
response.call_next = call_next # type: ignore
return response

View File

@ -96,6 +96,12 @@ class RedirectLoop(RedirectError):
"""
class NotRedirectResponse(RedirectError):
"""
Response was not a redirect response.
"""
# Stream exceptions...

View File

@ -23,6 +23,7 @@ from .exceptions import (
CookieConflict,
HTTPError,
InvalidURL,
NotRedirectResponse,
ResponseClosed,
ResponseNotRead,
StreamConsumed,
@ -705,7 +706,7 @@ class BaseResponse:
self.request = request
self.on_close = on_close
self.next: typing.Optional[typing.Callable] = None
self.call_next: typing.Optional[typing.Callable] = None
@property
def reason_phrase(self) -> str:
@ -950,6 +951,15 @@ class AsyncResponse(BaseResponse):
yield part
await self.close()
async def next(self) -> "AsyncResponse":
"""
Get the next response from a redirect response.
"""
if not self.is_redirect:
raise NotRedirectResponse()
assert self.call_next is not None
return await self.call_next()
async def close(self) -> None:
"""
Close the response and release the connection.

View File

@ -10,6 +10,7 @@ from httpx import (
AsyncRequest,
AsyncResponse,
CertTypes,
NotRedirectResponse,
RedirectBodyUnavailable,
RedirectLoop,
TimeoutTypes,
@ -27,7 +28,10 @@ class MockDispatch(AsyncDispatcher):
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> AsyncResponse:
if request.url.path == "/redirect_301":
if request.url.path == "/no_redirect":
return AsyncResponse(codes.OK, request=request)
elif request.url.path == "/redirect_301":
status_code = codes.MOVED_PERMANENTLY
headers = {"location": "https://example.org/"}
return AsyncResponse(status_code, headers=headers, request=request)
@ -100,6 +104,15 @@ class MockDispatch(AsyncDispatcher):
return AsyncResponse(codes.OK, content=b"Hello, world!", request=request)
async def test_no_redirect(backend):
client = AsyncClient(dispatch=MockDispatch(), backend=backend)
url = "https://example.com/no_redirect"
response = await client.get(url)
assert response.status_code == 200
with pytest.raises(NotRedirectResponse):
await response.next()
async def test_redirect_301(backend):
client = AsyncClient(dispatch=MockDispatch(), backend=backend)
response = await client.post("https://example.org/redirect_301")