Record history of requests made during authentication (#718)
* Record history of requests made during authentication * Add asserts on digest auth history Co-Authored-By: Gaurav Dhameeja <gdhameeja@gmail.com>
This commit is contained in:
parent
910aa9094c
commit
bc6163c55a
@ -451,7 +451,7 @@ class AsyncClient:
|
||||
raise RedirectLoop()
|
||||
|
||||
response = await self.send_handling_auth(
|
||||
request, auth=auth, timeout=timeout,
|
||||
request, history, auth=auth, timeout=timeout,
|
||||
)
|
||||
response.history = list(history)
|
||||
|
||||
@ -571,7 +571,11 @@ class AsyncClient:
|
||||
return request.stream
|
||||
|
||||
async def send_handling_auth(
|
||||
self, request: Request, auth: Auth, timeout: Timeout,
|
||||
self,
|
||||
request: Request,
|
||||
history: typing.List[Response],
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
) -> Response:
|
||||
auth_flow = auth(request)
|
||||
request = next(auth_flow)
|
||||
@ -585,8 +589,10 @@ class AsyncClient:
|
||||
await response.aclose()
|
||||
raise exc from None
|
||||
else:
|
||||
response.history = list(history)
|
||||
await response.aread()
|
||||
request = next_request
|
||||
await response.aclose()
|
||||
history.append(response)
|
||||
|
||||
async def send_single_request(
|
||||
self, request: Request, timeout: Timeout,
|
||||
|
||||
@ -6,6 +6,7 @@ import typing
|
||||
import pytest
|
||||
|
||||
from httpx import URL, AsyncClient, DigestAuth, ProtocolError, Request, Response
|
||||
from httpx.auth import Auth, AuthFlow
|
||||
from httpx.config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from httpx.dispatch.base import Dispatcher
|
||||
|
||||
@ -218,6 +219,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": None}
|
||||
assert len(response.history) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -233,6 +235,7 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": None}
|
||||
assert len(response.history) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -245,6 +248,7 @@ async def test_digest_auth_401_response_without_digest_auth_header() -> None:
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"auth": None}
|
||||
assert len(response.history) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -271,6 +275,8 @@ async def test_digest_auth(
|
||||
response = await client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.history) == 1
|
||||
|
||||
authorization = typing.cast(dict, response.json())["auth"]
|
||||
scheme, _, fields = authorization.partition(" ")
|
||||
assert scheme == "Digest"
|
||||
@ -299,6 +305,8 @@ async def test_digest_auth_no_specified_qop() -> None:
|
||||
response = await client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.history) == 1
|
||||
|
||||
authorization = typing.cast(dict, response.json())["auth"]
|
||||
scheme, _, fields = authorization.partition(" ")
|
||||
assert scheme == "Digest"
|
||||
@ -325,7 +333,10 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
|
||||
client = AsyncClient(dispatch=MockDigestAuthDispatch(qop=qop))
|
||||
await client.get(url, auth=auth)
|
||||
response = await client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -357,6 +368,7 @@ async def test_digest_auth_incorrect_credentials() -> None:
|
||||
response = await client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -381,3 +393,52 @@ async def test_digest_auth_raises_protocol_error_on_malformed_header(
|
||||
|
||||
with pytest.raises(ProtocolError):
|
||||
await client.get(url, auth=auth)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_history() -> None:
|
||||
"""
|
||||
Test that intermediate requests sent as part of an authentication flow
|
||||
are recorded in the response history.
|
||||
"""
|
||||
|
||||
class RepeatAuth(Auth):
|
||||
"""
|
||||
A mock authentication scheme that requires clients to send
|
||||
the request a fixed number of times, and then send a last request containing
|
||||
an aggregation of nonces that the server sent in 'WWW-Authenticate' headers
|
||||
of intermediate responses.
|
||||
"""
|
||||
|
||||
def __init__(self, repeat: int):
|
||||
self.repeat = repeat
|
||||
|
||||
def __call__(self, request: Request) -> AuthFlow:
|
||||
nonces = []
|
||||
|
||||
for index in range(self.repeat):
|
||||
request.headers["Authorization"] = f"Repeat {index}"
|
||||
response = yield request
|
||||
nonces.append(response.headers["www-authenticate"])
|
||||
|
||||
key = ".".join(nonces)
|
||||
request.headers["Authorization"] = f"Repeat {key}"
|
||||
yield request
|
||||
|
||||
url = "https://example.org/"
|
||||
auth = RepeatAuth(repeat=2)
|
||||
client = AsyncClient(dispatch=MockDispatch(auth_header="abc"))
|
||||
|
||||
response = await client.get(url, auth=auth)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "Repeat abc.abc"}
|
||||
|
||||
assert len(response.history) == 2
|
||||
resp1, resp2 = response.history
|
||||
assert resp1.json() == {"auth": "Repeat 0"}
|
||||
assert resp2.json() == {"auth": "Repeat 1"}
|
||||
|
||||
assert len(resp2.history) == 1
|
||||
assert resp2.history == [resp1]
|
||||
|
||||
assert len(resp1.history) == 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user