Auth (#65)
* Initial work towards auth * Add auth support * Add test for custom auth * Support auth-in-URL * Support auth-on-session
This commit is contained in:
parent
2e9ab2e904
commit
5afa7dd5cc
38
httpcore/auth.py
Normal file
38
httpcore/auth.py
Normal file
@ -0,0 +1,38 @@
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
from .models import Request
|
||||
|
||||
|
||||
class AuthBase:
|
||||
"""
|
||||
Base class that all auth implementations derive from.
|
||||
"""
|
||||
|
||||
def __call__(self, request: Request) -> Request:
|
||||
raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover
|
||||
|
||||
|
||||
class HTTPBasicAuth(AuthBase):
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> None:
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
def __call__(self, request: Request) -> Request:
|
||||
request.headers["Authorization"] = self.build_auth_header()
|
||||
return request
|
||||
|
||||
def build_auth_header(self) -> str:
|
||||
username, password = self.username, self.password
|
||||
|
||||
if isinstance(username, str):
|
||||
username = username.encode("latin1")
|
||||
|
||||
if isinstance(password, str):
|
||||
password = password.encode("latin1")
|
||||
|
||||
userpass = b":".join((username, password))
|
||||
token = b64encode(userpass).decode().strip()
|
||||
return f"Basic {token}"
|
||||
@ -2,6 +2,7 @@ import asyncio
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from .auth import HTTPBasicAuth
|
||||
from .config import (
|
||||
DEFAULT_MAX_REDIRECTS,
|
||||
DEFAULT_POOL_LIMITS,
|
||||
@ -16,6 +17,7 @@ from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from .interfaces import ConcurrencyBackend, Dispatcher
|
||||
from .models import (
|
||||
URL,
|
||||
AuthTypes,
|
||||
Headers,
|
||||
HeaderTypes,
|
||||
QueryParamTypes,
|
||||
@ -31,6 +33,7 @@ from .status_codes import codes
|
||||
class AsyncClient:
|
||||
def __init__(
|
||||
self,
|
||||
auth: AuthTypes = None,
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
@ -43,6 +46,7 @@ class AsyncClient:
|
||||
ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
|
||||
)
|
||||
|
||||
self.auth = auth
|
||||
self.max_redirects = max_redirects
|
||||
self.dispatch = dispatch
|
||||
|
||||
@ -53,6 +57,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -63,6 +68,7 @@ class AsyncClient:
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -75,6 +81,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -85,6 +92,7 @@ class AsyncClient:
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -97,6 +105,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = False, # Note: Differs to usual default.
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -107,6 +116,7 @@ class AsyncClient:
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -120,6 +130,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -131,6 +142,7 @@ class AsyncClient:
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -144,6 +156,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -155,6 +168,7 @@ class AsyncClient:
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -168,6 +182,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -179,6 +194,7 @@ class AsyncClient:
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -192,6 +208,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -203,6 +220,7 @@ class AsyncClient:
|
||||
query_params=query_params,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -217,6 +235,7 @@ class AsyncClient:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -228,6 +247,7 @@ class AsyncClient:
|
||||
response = await self.send(
|
||||
request,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -238,6 +258,37 @@ class AsyncClient:
|
||||
request.prepare()
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
allow_redirects: bool = True,
|
||||
) -> Response:
|
||||
if auth is None:
|
||||
auth = self.auth
|
||||
|
||||
url = request.url
|
||||
if auth is None and (url.username or url.password):
|
||||
auth = HTTPBasicAuth(username=url.username, password=url.password)
|
||||
|
||||
if auth is not None:
|
||||
if isinstance(auth, tuple):
|
||||
auth = HTTPBasicAuth(username=auth[0], password=auth[1])
|
||||
request = auth(request)
|
||||
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
stream=stream,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
return response
|
||||
|
||||
async def send_handling_redirects(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
@ -273,7 +324,7 @@ class AsyncClient:
|
||||
async def send_next() -> Response:
|
||||
nonlocal request, response, ssl, allow_redirects, timeout, history
|
||||
request = self.build_redirect_request(request, response)
|
||||
response = await self.send(
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
stream=stream,
|
||||
allow_redirects=allow_redirects,
|
||||
@ -375,6 +426,7 @@ class AsyncClient:
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
auth: AuthTypes = None,
|
||||
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
|
||||
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
|
||||
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
|
||||
@ -383,6 +435,7 @@ class Client:
|
||||
backend: ConcurrencyBackend = None,
|
||||
) -> None:
|
||||
self._client = AsyncClient(
|
||||
auth=auth,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
pool_limits=pool_limits,
|
||||
@ -401,6 +454,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -412,6 +466,7 @@ class Client:
|
||||
response = self.send(
|
||||
request,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -425,6 +480,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -434,6 +490,7 @@ class Client:
|
||||
url,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -446,6 +503,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -455,6 +513,7 @@ class Client:
|
||||
url,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -467,6 +526,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = False, # Note: Differs to usual default.
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -476,6 +536,7 @@ class Client:
|
||||
url,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -489,6 +550,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -499,6 +561,7 @@ class Client:
|
||||
data=data,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -512,6 +575,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -522,6 +586,7 @@ class Client:
|
||||
data=data,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -535,6 +600,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -545,6 +611,7 @@ class Client:
|
||||
data=data,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -558,6 +625,7 @@ class Client:
|
||||
query_params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -568,6 +636,7 @@ class Client:
|
||||
data=data,
|
||||
headers=headers,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
@ -581,6 +650,7 @@ class Client:
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
@ -589,6 +659,7 @@ class Client:
|
||||
self._client.send(
|
||||
request,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
ssl=ssl,
|
||||
timeout=timeout,
|
||||
|
||||
@ -43,6 +43,11 @@ HeaderTypes = typing.Union[
|
||||
typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
|
||||
]
|
||||
|
||||
AuthTypes = typing.Union[
|
||||
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
|
||||
typing.Callable[["Request"], "Request"]
|
||||
]
|
||||
|
||||
RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
|
||||
|
||||
ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
|
||||
@ -93,6 +98,16 @@ class URL:
|
||||
def authority(self) -> str:
|
||||
return self.components.authority or ""
|
||||
|
||||
@property
|
||||
def username(self) -> str:
|
||||
userinfo = self.components.userinfo or ""
|
||||
return userinfo.partition(':')[0]
|
||||
|
||||
@property
|
||||
def password(self) -> str:
|
||||
userinfo = self.components.userinfo or ""
|
||||
return userinfo.partition(':')[2]
|
||||
|
||||
@property
|
||||
def host(self) -> str:
|
||||
return self.components.host or ""
|
||||
|
||||
72
tests/client/test_auth.py
Normal file
72
tests/client/test_auth.py
Normal file
@ -0,0 +1,72 @@
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import pytest
|
||||
|
||||
from httpcore import (
|
||||
URL,
|
||||
Client,
|
||||
Dispatcher,
|
||||
Request,
|
||||
Response,
|
||||
SSLConfig,
|
||||
TimeoutConfig,
|
||||
)
|
||||
|
||||
|
||||
class MockDispatch(Dispatcher):
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
stream: bool = False,
|
||||
ssl: SSLConfig = None,
|
||||
timeout: TimeoutConfig = None,
|
||||
) -> Response:
|
||||
body = json.dumps({"auth": request.headers.get('Authorization')}).encode()
|
||||
return Response(200, content=body, request=request)
|
||||
|
||||
|
||||
def test_basic_auth():
|
||||
url = "https://example.org/"
|
||||
auth = ('tomchristie', 'password123')
|
||||
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
|
||||
|
||||
|
||||
def test_basic_auth_in_url():
|
||||
url = "https://tomchristie:password123@example.org/"
|
||||
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
|
||||
|
||||
|
||||
def test_basic_auth_on_session():
|
||||
url = "https://example.org/"
|
||||
auth = ('tomchristie', 'password123')
|
||||
|
||||
with Client(dispatch=MockDispatch(), auth=auth) as client:
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
|
||||
|
||||
|
||||
def test_custom_auth():
|
||||
url = "https://example.org/"
|
||||
|
||||
def auth(request):
|
||||
request.headers['Authorization'] = 'Token 123'
|
||||
return request
|
||||
|
||||
with Client(dispatch=MockDispatch()) as client:
|
||||
response = client.get(url, auth=auth)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json.loads(response.text) == {'auth': 'Token 123'}
|
||||
Loading…
Reference in New Issue
Block a user