Add support for sync-specific or async-specific auth flows (#1217)
* Add support for async auth flows * Move body logic to Auth, add sync_auth_flow, add NoAuth * Update tests * Stick to next() / __anext__() * Fix undefined name errors * Add docs * Add unit tests for auth classes Co-authored-by: Tom Christie <tom@tomchristie.com>
This commit is contained in:
parent
15187e7c21
commit
016e4ee210
@ -724,6 +724,55 @@ class MyCustomAuth(httpx.Auth):
|
||||
...
|
||||
```
|
||||
|
||||
If you _do_ need to perform I/O other than HTTP requests, such as accessing a disk-based cache, or you need to use concurrency primitives, such as locks, then you should override `.sync_auth_flow()` and `.async_auth_flow()` (instead of `.auth_flow()`). The former will be used by `httpx.Client`, while the latter will be used by `httpx.AsyncClient`.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import threading
|
||||
import httpx
|
||||
|
||||
|
||||
class MyCustomAuth(httpx.Auth):
|
||||
def __init__(self):
|
||||
self._sync_lock = threading.RLock()
|
||||
self._async_lock = asyncio.Lock()
|
||||
|
||||
def sync_get_token(self):
|
||||
with self._sync_lock:
|
||||
...
|
||||
|
||||
def sync_auth_flow(self, request):
|
||||
token = self.sync_get_token()
|
||||
request.headers["Authorization"] = f"Token {token}"
|
||||
yield request
|
||||
|
||||
async def async_get_token(self):
|
||||
async with self._async_lock:
|
||||
...
|
||||
|
||||
async def async_auth_flow(self, request):
|
||||
token = await self.async_get_token()
|
||||
request.headers["Authorization"] = f"Token {token}"
|
||||
yield request
|
||||
```
|
||||
|
||||
If you only want to support one of the two methods, then you should still override it, but raise an explicit `RuntimeError`.
|
||||
|
||||
```python
|
||||
import httpx
|
||||
import sync_only_library
|
||||
|
||||
|
||||
class MyCustomAuth(httpx.Auth):
|
||||
def sync_auth_flow(self, request):
|
||||
token = sync_only_library.get_token(...)
|
||||
request.headers["Authorization"] = f"Token {token}"
|
||||
yield request
|
||||
|
||||
async def async_auth_flow(self, request):
|
||||
raise RuntimeError("Cannot use a sync authentication class with httpx.AsyncClient")
|
||||
```
|
||||
|
||||
## SSL certificates
|
||||
|
||||
When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA).
|
||||
|
||||
@ -17,6 +17,11 @@ class Auth:
|
||||
|
||||
To implement a custom authentication scheme, subclass `Auth` and override
|
||||
the `.auth_flow()` method.
|
||||
|
||||
If the authentication scheme does I/O such as disk access or network calls, or uses
|
||||
synchronization primitives such as locks, you should override `.sync_auth_flow()`
|
||||
and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
|
||||
implementations that will be used by `Client` and `AsyncClient` respectively.
|
||||
"""
|
||||
|
||||
requires_request_body = False
|
||||
@ -46,6 +51,56 @@ class Auth:
|
||||
"""
|
||||
yield request
|
||||
|
||||
def sync_auth_flow(
|
||||
self, request: Request
|
||||
) -> typing.Generator[Request, Response, None]:
|
||||
"""
|
||||
Execute the authentication flow synchronously.
|
||||
|
||||
By default, this defers to `.auth_flow()`. You should override this method
|
||||
when the authentication scheme does I/O and/or uses concurrency primitives.
|
||||
"""
|
||||
if self.requires_request_body:
|
||||
request.read()
|
||||
|
||||
flow = self.auth_flow(request)
|
||||
request = next(flow)
|
||||
|
||||
while True:
|
||||
response = yield request
|
||||
if self.requires_response_body:
|
||||
response.read()
|
||||
|
||||
try:
|
||||
request = flow.send(response)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
async def async_auth_flow(
|
||||
self, request: Request
|
||||
) -> typing.AsyncGenerator[Request, Response]:
|
||||
"""
|
||||
Execute the authentication flow asynchronously.
|
||||
|
||||
By default, this defers to `.auth_flow()`. You should override this method
|
||||
when the authentication scheme does I/O and/or uses concurrency primitives.
|
||||
"""
|
||||
if self.requires_request_body:
|
||||
await request.aread()
|
||||
|
||||
flow = self.auth_flow(request)
|
||||
request = next(flow)
|
||||
|
||||
while True:
|
||||
response = yield request
|
||||
if self.requires_response_body:
|
||||
await response.aread()
|
||||
|
||||
try:
|
||||
request = flow.send(response)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
|
||||
class FunctionAuth(Auth):
|
||||
"""
|
||||
|
||||
@ -785,15 +785,12 @@ class Client(BaseClient):
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
) -> Response:
|
||||
if auth.requires_request_body:
|
||||
request.read()
|
||||
|
||||
auth_flow = auth.auth_flow(request)
|
||||
auth_flow = auth.sync_auth_flow(request)
|
||||
request = next(auth_flow)
|
||||
|
||||
while True:
|
||||
response = self._send_single_request(request, timeout)
|
||||
if auth.requires_response_body:
|
||||
response.read()
|
||||
|
||||
try:
|
||||
next_request = auth_flow.send(response)
|
||||
except StopIteration:
|
||||
@ -1409,18 +1406,15 @@ class AsyncClient(BaseClient):
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
) -> Response:
|
||||
if auth.requires_request_body:
|
||||
await request.aread()
|
||||
auth_flow = auth.async_auth_flow(request)
|
||||
request = await auth_flow.__anext__()
|
||||
|
||||
auth_flow = auth.auth_flow(request)
|
||||
request = next(auth_flow)
|
||||
while True:
|
||||
response = await self._send_single_request(request, timeout)
|
||||
if auth.requires_response_body:
|
||||
await response.aread()
|
||||
|
||||
try:
|
||||
next_request = auth_flow.send(response)
|
||||
except StopIteration:
|
||||
next_request = await auth_flow.asend(response)
|
||||
except StopAsyncIteration:
|
||||
return response
|
||||
except BaseException as exc:
|
||||
await response.aclose()
|
||||
|
||||
@ -1,5 +1,12 @@
|
||||
"""
|
||||
Integration tests for authentication.
|
||||
|
||||
Unit tests for auth classes also exist in tests/test_auth.py
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import threading
|
||||
import typing
|
||||
|
||||
import httpcore
|
||||
@ -183,6 +190,31 @@ class ResponseBodyAuth(Auth):
|
||||
yield request
|
||||
|
||||
|
||||
class SyncOrAsyncAuth(Auth):
|
||||
"""
|
||||
A mock authentication scheme that uses a different implementation for the
|
||||
sync and async cases.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._async_lock = asyncio.Lock()
|
||||
|
||||
def sync_auth_flow(
|
||||
self, request: Request
|
||||
) -> typing.Generator[Request, Response, None]:
|
||||
with self._lock:
|
||||
request.headers["Authorization"] = "sync-auth"
|
||||
yield request
|
||||
|
||||
async def async_auth_flow(
|
||||
self, request: Request
|
||||
) -> typing.AsyncGenerator[Request, Response]:
|
||||
async with self._async_lock:
|
||||
request.headers["Authorization"] = "async-auth"
|
||||
yield request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_auth() -> None:
|
||||
url = "https://example.org/"
|
||||
@ -664,3 +696,34 @@ def test_sync_auth_reads_response_body() -> None:
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": '{"auth": "xyz"}'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_auth() -> None:
|
||||
"""
|
||||
Test that we can use an auth implementation specific to the async case, to
|
||||
support cases that require performing I/O or using concurrency primitives (such
|
||||
as checking a disk-based cache or fetching a token from a remote auth server).
|
||||
"""
|
||||
url = "https://example.org/"
|
||||
auth = SyncOrAsyncAuth()
|
||||
|
||||
async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
|
||||
response = await client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "async-auth"}
|
||||
|
||||
|
||||
def test_sync_auth() -> None:
|
||||
"""
|
||||
Test that we can use an auth implementation specific to the sync case.
|
||||
"""
|
||||
url = "https://example.org/"
|
||||
auth = SyncOrAsyncAuth()
|
||||
|
||||
with httpx.Client(transport=SyncMockTransport()) as client:
|
||||
response = client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"auth": "sync-auth"}
|
||||
|
||||
63
tests/test_auth.py
Normal file
63
tests/test_auth.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Unit tests for auth classes.
|
||||
|
||||
Integration tests also exist in tests/client/test_auth.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def test_basic_auth():
|
||||
auth = httpx.BasicAuth(username="user", password="pass")
|
||||
request = httpx.Request("GET", "https://www.example.com")
|
||||
|
||||
# The initial request should include a basic auth header.
|
||||
flow = auth.sync_auth_flow(request)
|
||||
request = next(flow)
|
||||
assert request.headers["Authorization"].startswith("Basic")
|
||||
|
||||
# No other requests are made.
|
||||
response = httpx.Response(content=b"Hello, world!", status_code=200)
|
||||
with pytest.raises(StopIteration):
|
||||
flow.send(response)
|
||||
|
||||
|
||||
def test_digest_auth_with_200():
|
||||
auth = httpx.DigestAuth(username="user", password="pass")
|
||||
request = httpx.Request("GET", "https://www.example.com")
|
||||
|
||||
# The initial request should not include an auth header.
|
||||
flow = auth.sync_auth_flow(request)
|
||||
request = next(flow)
|
||||
assert "Authorization" not in request.headers
|
||||
|
||||
# If a 200 response is returned, then no other requests are made.
|
||||
response = httpx.Response(content=b"Hello, world!", status_code=200)
|
||||
with pytest.raises(StopIteration):
|
||||
flow.send(response)
|
||||
|
||||
|
||||
def test_digest_auth_with_401():
|
||||
auth = httpx.DigestAuth(username="user", password="pass")
|
||||
request = httpx.Request("GET", "https://www.example.com")
|
||||
|
||||
# The initial request should not include an auth header.
|
||||
flow = auth.sync_auth_flow(request)
|
||||
request = next(flow)
|
||||
assert "Authorization" not in request.headers
|
||||
|
||||
# If a 401 response is returned, then a digest auth request is made.
|
||||
headers = {
|
||||
"WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
|
||||
}
|
||||
response = httpx.Response(
|
||||
content=b"Auth required", status_code=401, headers=headers
|
||||
)
|
||||
request = flow.send(response)
|
||||
assert request.headers["Authorization"].startswith("Digest")
|
||||
|
||||
# No other requests are made.
|
||||
response = httpx.Response(content=b"Hello, world!", status_code=200)
|
||||
with pytest.raises(StopIteration):
|
||||
flow.send(response)
|
||||
Loading…
Reference in New Issue
Block a user