Record history of requests made during authentication (#718)

* Record history of requests made during authentication

* Add asserts on digest auth history

Co-Authored-By: Gaurav Dhameeja <gdhameeja@gmail.com>
This commit is contained in:
Florimond Manca 2020-01-03 22:59:16 +01:00 committed by GitHub
parent 910aa9094c
commit bc6163c55a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 4 deletions

View File

@ -451,7 +451,7 @@ class AsyncClient:
raise RedirectLoop()
response = await self.send_handling_auth(
request, auth=auth, timeout=timeout,
request, history, auth=auth, timeout=timeout,
)
response.history = list(history)
@ -571,7 +571,11 @@ class AsyncClient:
return request.stream
async def send_handling_auth(
self, request: Request, auth: Auth, timeout: Timeout,
self,
request: Request,
history: typing.List[Response],
auth: Auth,
timeout: Timeout,
) -> Response:
auth_flow = auth(request)
request = next(auth_flow)
@ -585,8 +589,10 @@ class AsyncClient:
await response.aclose()
raise exc from None
else:
response.history = list(history)
await response.aread()
request = next_request
await response.aclose()
history.append(response)
async def send_single_request(
self, request: Request, timeout: Timeout,

View File

@ -6,6 +6,7 @@ import typing
import pytest
from httpx import URL, AsyncClient, DigestAuth, ProtocolError, Request, Response
from httpx.auth import Auth, AuthFlow
from httpx.config import CertTypes, TimeoutTypes, VerifyTypes
from httpx.dispatch.base import Dispatcher
@ -218,6 +219,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
assert response.status_code == 200
assert response.json() == {"auth": None}
assert len(response.history) == 0
@pytest.mark.asyncio
@ -233,6 +235,7 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
assert response.status_code == 200
assert response.json() == {"auth": None}
assert len(response.history) == 0
@pytest.mark.asyncio
@ -245,6 +248,7 @@ async def test_digest_auth_401_response_without_digest_auth_header() -> None:
assert response.status_code == 401
assert response.json() == {"auth": None}
assert len(response.history) == 0
@pytest.mark.parametrize(
@ -271,6 +275,8 @@ async def test_digest_auth(
response = await client.get(url, auth=auth)
assert response.status_code == 200
assert len(response.history) == 1
authorization = typing.cast(dict, response.json())["auth"]
scheme, _, fields = authorization.partition(" ")
assert scheme == "Digest"
@ -299,6 +305,8 @@ async def test_digest_auth_no_specified_qop() -> None:
response = await client.get(url, auth=auth)
assert response.status_code == 200
assert len(response.history) == 1
authorization = typing.cast(dict, response.json())["auth"]
scheme, _, fields = authorization.partition(" ")
assert scheme == "Digest"
@ -325,7 +333,10 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
auth = DigestAuth(username="tomchristie", password="password123")
client = AsyncClient(dispatch=MockDigestAuthDispatch(qop=qop))
await client.get(url, auth=auth)
response = await client.get(url, auth=auth)
assert response.status_code == 200
assert len(response.history) == 1
@pytest.mark.asyncio
@ -357,6 +368,7 @@ async def test_digest_auth_incorrect_credentials() -> None:
response = await client.get(url, auth=auth)
assert response.status_code == 401
assert len(response.history) == 1
@pytest.mark.parametrize(
@ -381,3 +393,52 @@ async def test_digest_auth_raises_protocol_error_on_malformed_header(
with pytest.raises(ProtocolError):
await client.get(url, auth=auth)
@pytest.mark.asyncio
async def test_auth_history() -> None:
"""
Test that intermediate requests sent as part of an authentication flow
are recorded in the response history.
"""
class RepeatAuth(Auth):
"""
A mock authentication scheme that requires clients to send
the request a fixed number of times, and then send a last request containing
an aggregation of nonces that the server sent in 'WWW-Authenticate' headers
of intermediate responses.
"""
def __init__(self, repeat: int):
self.repeat = repeat
def __call__(self, request: Request) -> AuthFlow:
nonces = []
for index in range(self.repeat):
request.headers["Authorization"] = f"Repeat {index}"
response = yield request
nonces.append(response.headers["www-authenticate"])
key = ".".join(nonces)
request.headers["Authorization"] = f"Repeat {key}"
yield request
url = "https://example.org/"
auth = RepeatAuth(repeat=2)
client = AsyncClient(dispatch=MockDispatch(auth_header="abc"))
response = await client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "Repeat abc.abc"}
assert len(response.history) == 2
resp1, resp2 = response.history
assert resp1.json() == {"auth": "Repeat 0"}
assert resp2.json() == {"auth": "Repeat 1"}
assert len(resp2.history) == 1
assert resp2.history == [resp1]
assert len(resp1.history) == 0