reuse the digest auth state to avoid unnecessary requests (#2463)
* reuse the digest auth challenge to avoid sending twice as many requests * fix for digest testcase * ran testing/linting scripts * codereview changes, removed tomchristie username from all authentication tests Co-authored-by: Philipp Reitter <p.reitter@accessio.at> Co-authored-by: Tom Christie <tom@tomchristie.com>
This commit is contained in:
parent
69e13cbc39
commit
8327e13454
@ -158,8 +158,15 @@ class DigestAuth(Auth):
|
||||
) -> None:
|
||||
self._username = to_bytes(username)
|
||||
self._password = to_bytes(password)
|
||||
self._last_challenge: typing.Optional[_DigestAuthChallenge] = None
|
||||
self._nonce_count = 1
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
if self._last_challenge:
|
||||
request.headers["Authorization"] = self._build_auth_header(
|
||||
request, self._last_challenge
|
||||
)
|
||||
|
||||
response = yield request
|
||||
|
||||
if response.status_code != 401 or "www-authenticate" not in response.headers:
|
||||
@ -175,8 +182,12 @@ class DigestAuth(Auth):
|
||||
# header, then we don't need to build an authenticated request.
|
||||
return
|
||||
|
||||
challenge = self._parse_challenge(request, response, auth_header)
|
||||
request.headers["Authorization"] = self._build_auth_header(request, challenge)
|
||||
self._last_challenge = self._parse_challenge(request, response, auth_header)
|
||||
self._nonce_count = 1
|
||||
|
||||
request.headers["Authorization"] = self._build_auth_header(
|
||||
request, self._last_challenge
|
||||
)
|
||||
yield request
|
||||
|
||||
def _parse_challenge(
|
||||
@ -225,9 +236,9 @@ class DigestAuth(Auth):
|
||||
# TODO: implement auth-int
|
||||
HA2 = digest(A2)
|
||||
|
||||
nonce_count = 1 # TODO: implement nonce counting
|
||||
nc_value = b"%08x" % nonce_count
|
||||
cnonce = self._get_client_nonce(nonce_count, challenge.nonce)
|
||||
nc_value = b"%08x" % self._nonce_count
|
||||
cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
|
||||
self._nonce_count += 1
|
||||
|
||||
HA1 = digest(A1)
|
||||
if challenge.algorithm.lower().endswith("-sess"):
|
||||
|
||||
@ -8,6 +8,7 @@ import hashlib
|
||||
import os
|
||||
import threading
|
||||
import typing
|
||||
from urllib.request import parse_keqv_list
|
||||
|
||||
import pytest
|
||||
|
||||
@ -151,14 +152,14 @@ class SyncOrAsyncAuth(Auth):
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_auth() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = ("tomchristie", "password123")
|
||||
auth = ("user", "password123")
|
||||
app = App()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
response = await client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
|
||||
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -167,7 +168,7 @@ async def test_basic_auth_with_stream() -> None:
|
||||
See: https://github.com/encode/httpx/pull/1312
|
||||
"""
|
||||
url = "https://example.org/"
|
||||
auth = ("tomchristie", "password123")
|
||||
auth = ("user", "password123")
|
||||
app = App()
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
@ -177,25 +178,25 @@ async def test_basic_auth_with_stream() -> None:
|
||||
await response.aread()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
|
||||
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_auth_in_url() -> None:
|
||||
url = "https://tomchristie:password123@example.org/"
|
||||
url = "https://user:password123@example.org/"
|
||||
app = App()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
response = await client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
|
||||
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_auth_on_session() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = ("tomchristie", "password123")
|
||||
auth = ("user", "password123")
|
||||
app = App()
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
@ -204,7 +205,7 @@ async def test_basic_auth_on_session() -> None:
|
||||
response = await client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
|
||||
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -279,7 +280,7 @@ async def test_trust_env_auth() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_disable_per_request() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = ("tomchristie", "password123")
|
||||
auth = ("user", "password123")
|
||||
app = App()
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
@ -317,13 +318,13 @@ async def test_auth_property() -> None:
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
assert client.auth is None
|
||||
|
||||
client.auth = ("tomchristie", "password123") # type: ignore
|
||||
client.auth = ("user", "password123") # type: ignore
|
||||
assert isinstance(client.auth, BasicAuth)
|
||||
|
||||
url = "https://example.org/"
|
||||
response = await client.get(url)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
|
||||
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -347,7 +348,7 @@ async def test_auth_invalid_type() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = App()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -360,7 +361,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
|
||||
|
||||
def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
auth_header = "Token ..."
|
||||
app = App(auth_header=auth_header, status_code=401)
|
||||
|
||||
@ -375,7 +376,7 @@ def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_200_response_including_digest_auth_header() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
auth_header = 'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
|
||||
app = App(auth_header=auth_header, status_code=200)
|
||||
|
||||
@ -390,7 +391,7 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_401_response_without_digest_auth_header() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = App(auth_header="", status_code=401)
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -419,7 +420,7 @@ async def test_digest_auth(
|
||||
algorithm: str, expected_hash_length: int, expected_response_length: int
|
||||
) -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp(algorithm=algorithm)
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -435,7 +436,7 @@ async def test_digest_auth(
|
||||
response_fields = [field.strip() for field in fields.split(",")]
|
||||
digest_data = dict(field.split("=") for field in response_fields)
|
||||
|
||||
assert digest_data["username"] == '"tomchristie"'
|
||||
assert digest_data["username"] == '"user"'
|
||||
assert digest_data["realm"] == '"httpx@example.org"'
|
||||
assert "nonce" in digest_data
|
||||
assert digest_data["uri"] == '"/"'
|
||||
@ -450,7 +451,7 @@ async def test_digest_auth(
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_no_specified_qop() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp(qop="")
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -469,7 +470,7 @@ async def test_digest_auth_no_specified_qop() -> None:
|
||||
assert "qop" not in digest_data
|
||||
assert "nc" not in digest_data
|
||||
assert "cnonce" not in digest_data
|
||||
assert digest_data["username"] == '"tomchristie"'
|
||||
assert digest_data["username"] == '"user"'
|
||||
assert digest_data["realm"] == '"httpx@example.org"'
|
||||
assert len(digest_data["nonce"]) == 64 + 2 # extra quotes
|
||||
assert digest_data["uri"] == '"/"'
|
||||
@ -482,7 +483,7 @@ async def test_digest_auth_no_specified_qop() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str) -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp(qop=qop)
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -495,7 +496,7 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_qop_auth_int_not_implemented() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp(qop="auth-int")
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -506,7 +507,7 @@ async def test_digest_auth_qop_auth_int_not_implemented() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp(qop="not-auth")
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -517,7 +518,7 @@ async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_incorrect_credentials() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp(send_response_after_attempt=2)
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -527,6 +528,62 @@ async def test_digest_auth_incorrect_credentials() -> None:
|
||||
assert len(response.history) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_reuses_challenge() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
response_1 = await client.get(url, auth=auth)
|
||||
response_2 = await client.get(url, auth=auth)
|
||||
|
||||
assert response_1.status_code == 200
|
||||
assert response_2.status_code == 200
|
||||
|
||||
assert len(response_1.history) == 1
|
||||
assert len(response_2.history) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_resets_nonce_count_after_401() -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp()
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
response_1 = await client.get(url, auth=auth)
|
||||
assert response_1.status_code == 200
|
||||
assert len(response_1.history) == 1
|
||||
|
||||
first_nonce = parse_keqv_list(
|
||||
response_1.request.headers["Authorization"].split(", ")
|
||||
)["nonce"]
|
||||
first_nc = parse_keqv_list(
|
||||
response_1.request.headers["Authorization"].split(", ")
|
||||
)["nc"]
|
||||
|
||||
# with this we now force a 401 on a subsequent (but initial) request
|
||||
app.send_response_after_attempt = 2
|
||||
|
||||
# we expect the client again to try to authenticate, i.e. the history length must be 1
|
||||
response_2 = await client.get(url, auth=auth)
|
||||
assert response_2.status_code == 200
|
||||
assert len(response_2.history) == 1
|
||||
|
||||
second_nonce = parse_keqv_list(
|
||||
response_2.request.headers["Authorization"].split(", ")
|
||||
)["nonce"]
|
||||
second_nc = parse_keqv_list(
|
||||
response_2.request.headers["Authorization"].split(", ")
|
||||
)["nc"]
|
||||
|
||||
assert first_nonce != second_nonce # ensures that the auth challenge was reset
|
||||
assert (
|
||||
first_nc == second_nc
|
||||
) # ensures the nonce count is reset when the authentication failed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"auth_header",
|
||||
[
|
||||
@ -539,7 +596,7 @@ async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
|
||||
auth_header: str,
|
||||
) -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = App(auth_header=auth_header, status_code=401)
|
||||
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
|
||||
@ -558,7 +615,7 @@ def test_sync_digest_auth_raises_protocol_error_on_malformed_header(
|
||||
auth_header: str,
|
||||
) -> None:
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = App(auth_header=auth_header, status_code=401)
|
||||
|
||||
with httpx.Client(transport=httpx.MockTransport(app)) as client:
|
||||
@ -629,7 +686,7 @@ class ConsumeBodyTransport(httpx.MockTransport):
|
||||
@pytest.mark.asyncio
|
||||
async def test_digest_auth_unavailable_streaming_body():
|
||||
url = "https://example.org/"
|
||||
auth = DigestAuth(username="tomchristie", password="password123")
|
||||
auth = DigestAuth(username="user", password="password123")
|
||||
app = DigestApp()
|
||||
|
||||
async def streaming_body():
|
||||
|
||||
@ -3,6 +3,8 @@ Unit tests for auth classes.
|
||||
|
||||
Integration tests also exist in tests/client/test_auth.py
|
||||
"""
|
||||
from urllib.request import parse_keqv_list
|
||||
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
@ -61,3 +63,41 @@ def test_digest_auth_with_401():
|
||||
response = httpx.Response(content=b"Hello, world!", status_code=200)
|
||||
with pytest.raises(StopIteration):
|
||||
flow.send(response)
|
||||
|
||||
|
||||
def test_digest_auth_with_401_nonce_counting():
|
||||
auth = httpx.DigestAuth(username="user", password="pass")
|
||||
request = httpx.Request("GET", "https://www.example.com")
|
||||
|
||||
# The initial request should not include an auth header.
|
||||
flow = auth.sync_auth_flow(request)
|
||||
request = next(flow)
|
||||
assert "Authorization" not in request.headers
|
||||
|
||||
# If a 401 response is returned, then a digest auth request is made.
|
||||
headers = {
|
||||
"WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
|
||||
}
|
||||
response = httpx.Response(
|
||||
content=b"Auth required", status_code=401, headers=headers
|
||||
)
|
||||
first_request = flow.send(response)
|
||||
assert first_request.headers["Authorization"].startswith("Digest")
|
||||
|
||||
# Each subsequent request contains the digest header by default...
|
||||
request = httpx.Request("GET", "https://www.example.com")
|
||||
flow = auth.sync_auth_flow(request)
|
||||
second_request = next(flow)
|
||||
assert second_request.headers["Authorization"].startswith("Digest")
|
||||
|
||||
# ... and the client nonce count (nc) is increased
|
||||
first_nc = parse_keqv_list(first_request.headers["Authorization"].split(", "))["nc"]
|
||||
second_nc = parse_keqv_list(second_request.headers["Authorization"].split(", "))[
|
||||
"nc"
|
||||
]
|
||||
assert int(first_nc, 16) + 1 == int(second_nc, 16)
|
||||
|
||||
# No other requests are made.
|
||||
response = httpx.Response(content=b"Hello, world!", status_code=200)
|
||||
with pytest.raises(StopIteration):
|
||||
flow.send(response)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user