Tighten up response.next() (#297)
This commit is contained in:
parent
994403bec9
commit
a46462764e
@ -27,6 +27,7 @@ from .exceptions import (
|
||||
CookieConflict,
|
||||
DecodingError,
|
||||
InvalidURL,
|
||||
NotRedirectResponse,
|
||||
PoolTimeout,
|
||||
ProtocolError,
|
||||
ReadTimeout,
|
||||
@ -92,6 +93,7 @@ __all__ = [
|
||||
"CookieConflict",
|
||||
"DecodingError",
|
||||
"InvalidURL",
|
||||
"NotRedirectResponse",
|
||||
"PoolTimeout",
|
||||
"ProtocolError",
|
||||
"ReadTimeout",
|
||||
|
||||
@ -252,7 +252,7 @@ class BaseClient:
|
||||
|
||||
if response.is_redirect:
|
||||
|
||||
async def send_next() -> AsyncResponse:
|
||||
async def call_next() -> AsyncResponse:
|
||||
nonlocal request, response, verify, cert
|
||||
nonlocal allow_redirects, timeout, history
|
||||
request = self.build_redirect_request(request, response)
|
||||
@ -266,7 +266,7 @@ class BaseClient:
|
||||
)
|
||||
return response
|
||||
|
||||
response.next = send_next # type: ignore
|
||||
response.call_next = call_next # type: ignore
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -96,6 +96,12 @@ class RedirectLoop(RedirectError):
|
||||
"""
|
||||
|
||||
|
||||
class NotRedirectResponse(RedirectError):
|
||||
"""
|
||||
Response was not a redirect response.
|
||||
"""
|
||||
|
||||
|
||||
# Stream exceptions...
|
||||
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from .exceptions import (
|
||||
CookieConflict,
|
||||
HTTPError,
|
||||
InvalidURL,
|
||||
NotRedirectResponse,
|
||||
ResponseClosed,
|
||||
ResponseNotRead,
|
||||
StreamConsumed,
|
||||
@ -705,7 +706,7 @@ class BaseResponse:
|
||||
|
||||
self.request = request
|
||||
self.on_close = on_close
|
||||
self.next: typing.Optional[typing.Callable] = None
|
||||
self.call_next: typing.Optional[typing.Callable] = None
|
||||
|
||||
@property
|
||||
def reason_phrase(self) -> str:
|
||||
@ -950,6 +951,15 @@ class AsyncResponse(BaseResponse):
|
||||
yield part
|
||||
await self.close()
|
||||
|
||||
async def next(self) -> "AsyncResponse":
|
||||
"""
|
||||
Get the next response from a redirect response.
|
||||
"""
|
||||
if not self.is_redirect:
|
||||
raise NotRedirectResponse()
|
||||
assert self.call_next is not None
|
||||
return await self.call_next()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
|
||||
@ -10,6 +10,7 @@ from httpx import (
|
||||
AsyncRequest,
|
||||
AsyncResponse,
|
||||
CertTypes,
|
||||
NotRedirectResponse,
|
||||
RedirectBodyUnavailable,
|
||||
RedirectLoop,
|
||||
TimeoutTypes,
|
||||
@ -27,7 +28,10 @@ class MockDispatch(AsyncDispatcher):
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> AsyncResponse:
|
||||
if request.url.path == "/redirect_301":
|
||||
if request.url.path == "/no_redirect":
|
||||
return AsyncResponse(codes.OK, request=request)
|
||||
|
||||
elif request.url.path == "/redirect_301":
|
||||
status_code = codes.MOVED_PERMANENTLY
|
||||
headers = {"location": "https://example.org/"}
|
||||
return AsyncResponse(status_code, headers=headers, request=request)
|
||||
@ -100,6 +104,15 @@ class MockDispatch(AsyncDispatcher):
|
||||
return AsyncResponse(codes.OK, content=b"Hello, world!", request=request)
|
||||
|
||||
|
||||
async def test_no_redirect(backend):
|
||||
client = AsyncClient(dispatch=MockDispatch(), backend=backend)
|
||||
url = "https://example.com/no_redirect"
|
||||
response = await client.get(url)
|
||||
assert response.status_code == 200
|
||||
with pytest.raises(NotRedirectResponse):
|
||||
await response.next()
|
||||
|
||||
|
||||
async def test_redirect_301(backend):
|
||||
client = AsyncClient(dispatch=MockDispatch(), backend=backend)
|
||||
response = await client.post("https://example.org/redirect_301")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user