* Initial work towards auth

* Add auth support

* Add test for custom auth

* Support auth-in-URL

* Support auth-on-session
This commit is contained in:
Tom Christie 2019-05-15 16:43:35 +01:00 committed by GitHub
parent 2e9ab2e904
commit 5afa7dd5cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 197 additions and 1 deletions

38
httpcore/auth.py Normal file
View File

@ -0,0 +1,38 @@
import typing
from base64 import b64encode
from .models import Request
class AuthBase:
"""
Base class that all auth implementations derive from.
"""
def __call__(self, request: Request) -> Request:
raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover
class HTTPBasicAuth(AuthBase):
def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> None:
self.username = username
self.password = password
def __call__(self, request: Request) -> Request:
request.headers["Authorization"] = self.build_auth_header()
return request
def build_auth_header(self) -> str:
username, password = self.username, self.password
if isinstance(username, str):
username = username.encode("latin1")
if isinstance(password, str):
password = password.encode("latin1")
userpass = b":".join((username, password))
token = b64encode(userpass).decode().strip()
return f"Basic {token}"

View File

@ -2,6 +2,7 @@ import asyncio
import typing
from types import TracebackType
from .auth import HTTPBasicAuth
from .config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
@ -16,6 +17,7 @@ from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from .interfaces import ConcurrencyBackend, Dispatcher
from .models import (
URL,
AuthTypes,
Headers,
HeaderTypes,
QueryParamTypes,
@ -31,6 +33,7 @@ from .status_codes import codes
class AsyncClient:
def __init__(
self,
auth: AuthTypes = None,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
@ -43,6 +46,7 @@ class AsyncClient:
ssl=ssl, timeout=timeout, pool_limits=pool_limits, backend=backend
)
self.auth = auth
self.max_redirects = max_redirects
self.dispatch = dispatch
@ -53,6 +57,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -63,6 +68,7 @@ class AsyncClient:
query_params=query_params,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -75,6 +81,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -85,6 +92,7 @@ class AsyncClient:
query_params=query_params,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -97,6 +105,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = False, #  Note: Differs to usual default.
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -107,6 +116,7 @@ class AsyncClient:
query_params=query_params,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -120,6 +130,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -131,6 +142,7 @@ class AsyncClient:
query_params=query_params,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -144,6 +156,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -155,6 +168,7 @@ class AsyncClient:
query_params=query_params,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -168,6 +182,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -179,6 +194,7 @@ class AsyncClient:
query_params=query_params,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -192,6 +208,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -203,6 +220,7 @@ class AsyncClient:
query_params=query_params,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -217,6 +235,7 @@ class AsyncClient:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -228,6 +247,7 @@ class AsyncClient:
response = await self.send(
request,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -238,6 +258,37 @@ class AsyncClient:
request.prepare()
async def send(
self,
request: Request,
*,
stream: bool = False,
auth: AuthTypes = None,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
allow_redirects: bool = True,
) -> Response:
if auth is None:
auth = self.auth
url = request.url
if auth is None and (url.username or url.password):
auth = HTTPBasicAuth(username=url.username, password=url.password)
if auth is not None:
if isinstance(auth, tuple):
auth = HTTPBasicAuth(username=auth[0], password=auth[1])
request = auth(request)
response = await self.send_handling_redirects(
request,
stream=stream,
ssl=ssl,
timeout=timeout,
allow_redirects=allow_redirects,
)
return response
async def send_handling_redirects(
self,
request: Request,
*,
@ -273,7 +324,7 @@ class AsyncClient:
async def send_next() -> Response:
nonlocal request, response, ssl, allow_redirects, timeout, history
request = self.build_redirect_request(request, response)
response = await self.send(
response = await self.send_handling_redirects(
request,
stream=stream,
allow_redirects=allow_redirects,
@ -375,6 +426,7 @@ class AsyncClient:
class Client:
def __init__(
self,
auth: AuthTypes = None,
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
@ -383,6 +435,7 @@ class Client:
backend: ConcurrencyBackend = None,
) -> None:
self._client = AsyncClient(
auth=auth,
ssl=ssl,
timeout=timeout,
pool_limits=pool_limits,
@ -401,6 +454,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -412,6 +466,7 @@ class Client:
response = self.send(
request,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -425,6 +480,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -434,6 +490,7 @@ class Client:
url,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -446,6 +503,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -455,6 +513,7 @@ class Client:
url,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -467,6 +526,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = False, #  Note: Differs to usual default.
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -476,6 +536,7 @@ class Client:
url,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -489,6 +550,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -499,6 +561,7 @@ class Client:
data=data,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -512,6 +575,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -522,6 +586,7 @@ class Client:
data=data,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -535,6 +600,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -545,6 +611,7 @@ class Client:
data=data,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -558,6 +625,7 @@ class Client:
query_params: QueryParamTypes = None,
headers: HeaderTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -568,6 +636,7 @@ class Client:
data=data,
headers=headers,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,
@ -581,6 +650,7 @@ class Client:
request: Request,
*,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
@ -589,6 +659,7 @@ class Client:
self._client.send(
request,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
ssl=ssl,
timeout=timeout,

View File

@ -43,6 +43,11 @@ HeaderTypes = typing.Union[
typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
]
AuthTypes = typing.Union[
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
typing.Callable[["Request"], "Request"]
]
RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
@ -93,6 +98,16 @@ class URL:
def authority(self) -> str:
return self.components.authority or ""
@property
def username(self) -> str:
userinfo = self.components.userinfo or ""
return userinfo.partition(':')[0]
@property
def password(self) -> str:
userinfo = self.components.userinfo or ""
return userinfo.partition(':')[2]
@property
def host(self) -> str:
return self.components.host or ""

72
tests/client/test_auth.py Normal file
View File

@ -0,0 +1,72 @@
import json
from urllib.parse import parse_qs
import pytest
from httpcore import (
URL,
Client,
Dispatcher,
Request,
Response,
SSLConfig,
TimeoutConfig,
)
class MockDispatch(Dispatcher):
async def send(
self,
request: Request,
stream: bool = False,
ssl: SSLConfig = None,
timeout: TimeoutConfig = None,
) -> Response:
body = json.dumps({"auth": request.headers.get('Authorization')}).encode()
return Response(200, content=body, request=request)
def test_basic_auth():
url = "https://example.org/"
auth = ('tomchristie', 'password123')
with Client(dispatch=MockDispatch()) as client:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
def test_basic_auth_in_url():
url = "https://tomchristie:password123@example.org/"
with Client(dispatch=MockDispatch()) as client:
response = client.get(url)
assert response.status_code == 200
assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
def test_basic_auth_on_session():
url = "https://example.org/"
auth = ('tomchristie', 'password123')
with Client(dispatch=MockDispatch(), auth=auth) as client:
response = client.get(url)
assert response.status_code == 200
assert json.loads(response.text) == {'auth': 'Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM='}
def test_custom_auth():
url = "https://example.org/"
def auth(request):
request.headers['Authorization'] = 'Token 123'
return request
with Client(dispatch=MockDispatch()) as client:
response = client.get(url, auth=auth)
assert response.status_code == 200
assert json.loads(response.text) == {'auth': 'Token 123'}