Make all the tests from test_headers and test_auth to be async

This commit is contained in:
Kar Petrosyan 2025-02-27 20:01:04 +04:00
parent a803813702
commit 6baba9c1ce

View File

@ -8,7 +8,6 @@ import hashlib
import netrc
import os
import sys
import threading
import typing
from urllib.request import parse_keqv_list
@ -134,23 +133,15 @@ class ResponseBodyAuth(httpx.Auth):
yield request
class SyncOrAsyncAuth(httpx.Auth):
class AsyncAuth(httpx.Auth):
"""
A mock authentication scheme that uses a different implementation for the
sync and async cases.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._async_lock = anyio.Lock()
def sync_auth_flow(
self, request: httpx.Request
) -> typing.Generator[httpx.Request, httpx.Response, None]:
with self._lock:
request.headers["Authorization"] = "sync-auth"
yield request
async def async_auth_flow(
self, request: httpx.Request
) -> typing.AsyncGenerator[httpx.Request, httpx.Response]:
@ -234,7 +225,8 @@ async def test_custom_auth() -> None:
assert response.json() == {"auth": "Token 123"}
def test_netrc_auth_credentials_exist() -> None:
@pytest.mark.anyio
async def test_netrc_auth_credentials_exist() -> None:
"""
When netrc auth is being used and a request is made to a host that is
in the netrc file, then the relevant credentials should be applied.
@ -244,8 +236,10 @@ def test_netrc_auth_credentials_exist() -> None:
app = App()
auth = httpx.NetRCAuth(netrc_file)
with httpx.Client(transport=httpx.MockTransport(app), auth=auth) as client:
response = client.get(url)
async with httpx.AsyncClient(
transport=httpx.MockTransport(app), auth=auth
) as client:
response = await client.get(url)
assert response.status_code == 200
assert response.json() == {
@ -253,7 +247,8 @@ def test_netrc_auth_credentials_exist() -> None:
}
def test_netrc_auth_credentials_do_not_exist() -> None:
@pytest.mark.anyio
async def test_netrc_auth_credentials_do_not_exist() -> None:
"""
When netrc auth is being used and a request is made to a host that is
not in the netrc file, then no credentials should be applied.
@ -263,8 +258,10 @@ def test_netrc_auth_credentials_do_not_exist() -> None:
app = App()
auth = httpx.NetRCAuth(netrc_file)
with httpx.Client(transport=httpx.MockTransport(app), auth=auth) as client:
response = client.get(url)
async with httpx.AsyncClient(
transport=httpx.MockTransport(app), auth=auth
) as client:
response = await client.get(url)
assert response.status_code == 200
assert response.json() == {"auth": None}
@ -274,7 +271,8 @@ def test_netrc_auth_credentials_do_not_exist() -> None:
sys.version_info >= (3, 11),
reason="netrc files without a password are valid from Python >= 3.11",
)
def test_netrc_auth_nopassword_parse_error() -> None: # pragma: no cover
@pytest.mark.anyio
async def test_netrc_auth_nopassword_parse_error() -> None: # pragma: no cover
"""
Python has different netrc parsing behaviours with different versions.
For Python < 3.11 a netrc file with no password is invalid. In this case
@ -300,7 +298,8 @@ async def test_auth_disable_per_request() -> None:
assert response.json() == {"auth": None}
def test_auth_hidden_url() -> None:
@pytest.mark.anyio
async def test_auth_hidden_url() -> None:
url = "http://example-username:example-password@example.org/"
expected = "URL('http://example-username:[secure]@example.org/')"
assert url == httpx.URL(url)
@ -367,18 +366,19 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->
assert len(response.history) == 0
def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
@pytest.mark.anyio
async def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
url = "https://example.org/"
auth = httpx.DigestAuth(username="user", password="password123")
auth_header = "Token ..."
app = App(auth_header=auth_header, status_code=401)
client = httpx.Client(transport=httpx.MockTransport(app))
response = client.get(url, auth=auth)
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
response = await client.get(url, auth=auth)
assert response.status_code == 401
assert response.json() == {"auth": None}
assert len(response.history) == 0
assert response.status_code == 401
assert response.json() == {"auth": None}
assert len(response.history) == 0
@pytest.mark.anyio
@ -601,7 +601,7 @@ async def test_digest_auth_resets_nonce_count_after_401() -> None:
],
)
@pytest.mark.anyio
async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
async def test_digest_auth_raises_protocol_error_on_malformed_header(
auth_header: str,
) -> None:
url = "https://example.org/"
@ -613,27 +613,8 @@ async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
await client.get(url, auth=auth)
@pytest.mark.parametrize(
"auth_header",
[
'Digest realm="httpx@example.org", qop="auth"', # missing fields
'Digest realm="httpx@example.org", qop="auth,au', # malformed fields list
],
)
def test_sync_digest_auth_raises_protocol_error_on_malformed_header(
auth_header: str,
) -> None:
url = "https://example.org/"
auth = httpx.DigestAuth(username="user", password="password123")
app = App(auth_header=auth_header, status_code=401)
with httpx.Client(transport=httpx.MockTransport(app)) as client:
with pytest.raises(httpx.ProtocolError):
client.get(url, auth=auth)
@pytest.mark.anyio
async def test_async_auth_history() -> None:
async def test_auth_history() -> None:
"""
Test that intermediate requests sent as part of an authentication flow
are recorded in the response history.
@ -659,36 +640,11 @@ async def test_async_auth_history() -> None:
assert len(resp1.history) == 0
def test_sync_auth_history() -> None:
"""
Test that intermediate requests sent as part of an authentication flow
are recorded in the response history.
"""
url = "https://example.org/"
auth = RepeatAuth(repeat=2)
app = App(auth_header="abc")
with httpx.Client(transport=httpx.MockTransport(app)) as client:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "Repeat abc.abc"}
assert len(response.history) == 2
resp1, resp2 = response.history
assert resp1.json() == {"auth": "Repeat 0"}
assert resp2.json() == {"auth": "Repeat 1"}
assert len(resp2.history) == 1
assert resp2.history == [resp1]
assert len(resp1.history) == 0
class ConsumeBodyTransport(httpx.MockTransport):
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
assert isinstance(request.stream, httpx.AsyncByteStream)
[_ async for _ in request.stream]
async for _ in request.stream:
pass
return self.handler(request) # type: ignore[return-value]
@ -707,7 +663,7 @@ async def test_digest_auth_unavailable_streaming_body():
@pytest.mark.anyio
async def test_async_auth_reads_response_body() -> None:
async def test_auth_reads_response_body() -> None:
"""
Test that we can read the response body in an auth flow if `requires_response_body`
is set.
@ -723,31 +679,15 @@ async def test_async_auth_reads_response_body() -> None:
assert response.json() == {"auth": '{"auth":"xyz"}'}
def test_sync_auth_reads_response_body() -> None:
"""
Test that we can read the response body in an auth flow if `requires_response_body`
is set.
"""
url = "https://example.org/"
auth = ResponseBodyAuth("xyz")
app = App()
with httpx.Client(transport=httpx.MockTransport(app)) as client:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": '{"auth":"xyz"}'}
@pytest.mark.anyio
async def test_async_auth() -> None:
async def test_auth() -> None:
"""
Test that we can use an auth implementation specific to the async case, to
support cases that require performing I/O or using concurrency primitives (such
as checking a disk-based cache or fetching a token from a remote auth server).
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()
auth = AsyncAuth()
app = App()
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
@ -755,18 +695,3 @@ async def test_async_auth() -> None:
assert response.status_code == 200
assert response.json() == {"auth": "async-auth"}
def test_sync_auth() -> None:
"""
Test that we can use an auth implementation specific to the sync case.
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()
app = App()
with httpx.Client(transport=httpx.MockTransport(app)) as client:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "sync-auth"}