Add support for sync-specific or async-specific auth flows (#1217)

* Add support for async auth flows

* Move body logic to Auth, add sync_auth_flow, add NoAuth

* Update tests

* Stick to next() / __anext__()

* Fix undefined name errors

* Add docs

* Add unit tests for auth classes

Co-authored-by: Tom Christie <tom@tomchristie.com>
This commit is contained in:
Florimond Manca 2020-09-09 15:37:20 +02:00 committed by GitHub
parent 15187e7c21
commit 016e4ee210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 238 additions and 14 deletions

View File

@ -724,6 +724,55 @@ class MyCustomAuth(httpx.Auth):
...
```
If you _do_ need to perform I/O other than HTTP requests, such as accessing a disk-based cache, or you need to use concurrency primitives, such as locks, then you should override `.sync_auth_flow()` and `.async_auth_flow()` (instead of `.auth_flow()`). The former will be used by `httpx.Client`, while the latter will be used by `httpx.AsyncClient`.
```python
import asyncio
import threading
import httpx
class MyCustomAuth(httpx.Auth):
def __init__(self):
self._sync_lock = threading.RLock()
self._async_lock = asyncio.Lock()
def sync_get_token(self):
with self._sync_lock:
...
def sync_auth_flow(self, request):
token = self.sync_get_token()
request.headers["Authorization"] = f"Token {token}"
yield request
async def async_get_token(self):
async with self._async_lock:
...
async def async_auth_flow(self, request):
token = await self.async_get_token()
request.headers["Authorization"] = f"Token {token}"
yield request
```
If you only want to support one of the two methods, then you should still override it, but raise an explicit `RuntimeError`.
```python
import httpx
import sync_only_library
class MyCustomAuth(httpx.Auth):
def sync_auth_flow(self, request):
token = sync_only_library.get_token(...)
request.headers["Authorization"] = f"Token {token}"
yield request
async def async_auth_flow(self, request):
raise RuntimeError("Cannot use a sync authentication class with httpx.AsyncClient")
```
## SSL certificates
When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA).

View File

@ -17,6 +17,11 @@ class Auth:
To implement a custom authentication scheme, subclass `Auth` and override
the `.auth_flow()` method.
If the authentication scheme does I/O such as disk access or network calls, or uses
synchronization primitives such as locks, you should override `.sync_auth_flow()`
and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
implementations that will be used by `Client` and `AsyncClient` respectively.
"""
requires_request_body = False
@ -46,6 +51,56 @@ class Auth:
"""
yield request
def sync_auth_flow(
self, request: Request
) -> typing.Generator[Request, Response, None]:
"""
Execute the authentication flow synchronously.
By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
request.read()
flow = self.auth_flow(request)
request = next(flow)
while True:
response = yield request
if self.requires_response_body:
response.read()
try:
request = flow.send(response)
except StopIteration:
break
async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
"""
Execute the authentication flow asynchronously.
By default, this defers to `.auth_flow()`. You should override this method
when the authentication scheme does I/O and/or uses concurrency primitives.
"""
if self.requires_request_body:
await request.aread()
flow = self.auth_flow(request)
request = next(flow)
while True:
response = yield request
if self.requires_response_body:
await response.aread()
try:
request = flow.send(response)
except StopIteration:
break
class FunctionAuth(Auth):
"""

View File

@ -785,15 +785,12 @@ class Client(BaseClient):
auth: Auth,
timeout: Timeout,
) -> Response:
if auth.requires_request_body:
request.read()
auth_flow = auth.auth_flow(request)
auth_flow = auth.sync_auth_flow(request)
request = next(auth_flow)
while True:
response = self._send_single_request(request, timeout)
if auth.requires_response_body:
response.read()
try:
next_request = auth_flow.send(response)
except StopIteration:
@ -1409,18 +1406,15 @@ class AsyncClient(BaseClient):
auth: Auth,
timeout: Timeout,
) -> Response:
if auth.requires_request_body:
await request.aread()
auth_flow = auth.async_auth_flow(request)
request = await auth_flow.__anext__()
auth_flow = auth.auth_flow(request)
request = next(auth_flow)
while True:
response = await self._send_single_request(request, timeout)
if auth.requires_response_body:
await response.aread()
try:
next_request = auth_flow.send(response)
except StopIteration:
next_request = await auth_flow.asend(response)
except StopAsyncIteration:
return response
except BaseException as exc:
await response.aclose()

View File

@ -1,5 +1,12 @@
"""
Integration tests for authentication.
Unit tests for auth classes also exist in tests/test_auth.py
"""
import asyncio
import hashlib
import os
import threading
import typing
import httpcore
@ -183,6 +190,31 @@ class ResponseBodyAuth(Auth):
yield request
class SyncOrAsyncAuth(Auth):
"""
A mock authentication scheme that uses a different implementation for the
sync and async cases.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._async_lock = asyncio.Lock()
def sync_auth_flow(
self, request: Request
) -> typing.Generator[Request, Response, None]:
with self._lock:
request.headers["Authorization"] = "sync-auth"
yield request
async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
async with self._async_lock:
request.headers["Authorization"] = "async-auth"
yield request
@pytest.mark.asyncio
async def test_basic_auth() -> None:
url = "https://example.org/"
@ -664,3 +696,34 @@ def test_sync_auth_reads_response_body() -> None:
assert response.status_code == 200
assert response.json() == {"auth": '{"auth": "xyz"}'}
@pytest.mark.asyncio
async def test_async_auth() -> None:
"""
Test that we can use an auth implementation specific to the async case, to
support cases that require performing I/O or using concurrency primitives (such
as checking a disk-based cache or fetching a token from a remote auth server).
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()
async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
response = await client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "async-auth"}
def test_sync_auth() -> None:
"""
Test that we can use an auth implementation specific to the sync case.
"""
url = "https://example.org/"
auth = SyncOrAsyncAuth()
with httpx.Client(transport=SyncMockTransport()) as client:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert response.json() == {"auth": "sync-auth"}

63
tests/test_auth.py Normal file
View File

@ -0,0 +1,63 @@
"""
Unit tests for auth classes.
Integration tests also exist in tests/client/test_auth.py
"""
import pytest
import httpx
def test_basic_auth():
auth = httpx.BasicAuth(username="user", password="pass")
request = httpx.Request("GET", "https://www.example.com")
# The initial request should include a basic auth header.
flow = auth.sync_auth_flow(request)
request = next(flow)
assert request.headers["Authorization"].startswith("Basic")
# No other requests are made.
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)
def test_digest_auth_with_200():
auth = httpx.DigestAuth(username="user", password="pass")
request = httpx.Request("GET", "https://www.example.com")
# The initial request should not include an auth header.
flow = auth.sync_auth_flow(request)
request = next(flow)
assert "Authorization" not in request.headers
# If a 200 response is returned, then no other requests are made.
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)
def test_digest_auth_with_401():
auth = httpx.DigestAuth(username="user", password="pass")
request = httpx.Request("GET", "https://www.example.com")
# The initial request should not include an auth header.
flow = auth.sync_auth_flow(request)
request = next(flow)
assert "Authorization" not in request.headers
# If a 401 response is returned, then a digest auth request is made.
headers = {
"WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
}
response = httpx.Response(
content=b"Auth required", status_code=401, headers=headers
)
request = flow.send(response)
assert request.headers["Authorization"].startswith("Digest")
# No other requests are made.
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)