Refactor middleware (#325)
* Split middleware into a subpackage * Refactor basic auth header building * Add encoding parameter to to_bytes()
This commit is contained in:
parent
0f34f3b60f
commit
71648ece09
@ -23,12 +23,10 @@ from .dispatch.connection_pool import ConnectionPool
|
||||
from .dispatch.threaded import ThreadedDispatcher
|
||||
from .dispatch.wsgi import WSGIDispatch
|
||||
from .exceptions import HTTPError, InvalidURL
|
||||
from .middleware import (
|
||||
BaseMiddleware,
|
||||
BasicAuthMiddleware,
|
||||
CustomAuthMiddleware,
|
||||
RedirectMiddleware,
|
||||
)
|
||||
from .middleware.base import BaseMiddleware
|
||||
from .middleware.basic_auth import BasicAuthMiddleware
|
||||
from .middleware.custom_auth import CustomAuthMiddleware
|
||||
from .middleware.redirect import RedirectMiddleware
|
||||
from .models import (
|
||||
URL,
|
||||
AsyncRequest,
|
||||
|
||||
0
httpx/middleware/__init__.py
Normal file
0
httpx/middleware/__init__.py
Normal file
10
httpx/middleware/base.py
Normal file
10
httpx/middleware/base.py
Normal file
@ -0,0 +1,10 @@
|
||||
import typing
|
||||
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
|
||||
|
||||
class BaseMiddleware:
|
||||
async def __call__(
|
||||
self, request: AsyncRequest, get_response: typing.Callable
|
||||
) -> AsyncResponse:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
27
httpx/middleware/basic_auth.py
Normal file
27
httpx/middleware/basic_auth.py
Normal file
@ -0,0 +1,27 @@
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
from ..utils import to_bytes
|
||||
from .base import BaseMiddleware
|
||||
|
||||
|
||||
class BasicAuthMiddleware(BaseMiddleware):
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
):
|
||||
self.authorization_header = build_basic_auth_header(username, password)
|
||||
|
||||
async def __call__(
|
||||
self, request: AsyncRequest, get_response: typing.Callable
|
||||
) -> AsyncResponse:
|
||||
request.headers["Authorization"] = self.authorization_header
|
||||
return await get_response(request)
|
||||
|
||||
|
||||
def build_basic_auth_header(
|
||||
username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> str:
|
||||
userpass = b":".join((to_bytes(username), to_bytes(password)))
|
||||
token = b64encode(userpass).decode().strip()
|
||||
return f"Basic {token}"
|
||||
15
httpx/middleware/custom_auth.py
Normal file
15
httpx/middleware/custom_auth.py
Normal file
@ -0,0 +1,15 @@
|
||||
import typing
|
||||
|
||||
from ..models import AsyncRequest, AsyncResponse
|
||||
from .base import BaseMiddleware
|
||||
|
||||
|
||||
class CustomAuthMiddleware(BaseMiddleware):
|
||||
def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
|
||||
self.auth = auth
|
||||
|
||||
async def __call__(
|
||||
self, request: AsyncRequest, get_response: typing.Callable
|
||||
) -> AsyncResponse:
|
||||
request = self.auth(request)
|
||||
return await get_response(request)
|
||||
@ -1,51 +1,11 @@
|
||||
import functools
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
from .config import DEFAULT_MAX_REDIRECTS
|
||||
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from .models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
|
||||
from .status_codes import codes
|
||||
|
||||
|
||||
class BaseMiddleware:
|
||||
async def __call__(
|
||||
self, request: AsyncRequest, get_response: typing.Callable
|
||||
) -> AsyncResponse:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
class BasicAuthMiddleware(BaseMiddleware):
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
):
|
||||
if isinstance(username, str):
|
||||
username = username.encode("latin1")
|
||||
|
||||
if isinstance(password, str):
|
||||
password = password.encode("latin1")
|
||||
|
||||
userpass = b":".join((username, password))
|
||||
token = b64encode(userpass).decode().strip()
|
||||
|
||||
self.authorization_header = f"Basic {token}"
|
||||
|
||||
async def __call__(
|
||||
self, request: AsyncRequest, get_response: typing.Callable
|
||||
) -> AsyncResponse:
|
||||
request.headers["Authorization"] = self.authorization_header
|
||||
return await get_response(request)
|
||||
|
||||
|
||||
class CustomAuthMiddleware(BaseMiddleware):
|
||||
def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
|
||||
self.auth = auth
|
||||
|
||||
async def __call__(
|
||||
self, request: AsyncRequest, get_response: typing.Callable
|
||||
) -> AsyncResponse:
|
||||
request = self.auth(request)
|
||||
return await get_response(request)
|
||||
from ..config import DEFAULT_MAX_REDIRECTS
|
||||
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from ..models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
|
||||
from ..status_codes import codes
|
||||
from .base import BaseMiddleware
|
||||
|
||||
|
||||
class RedirectMiddleware(BaseMiddleware):
|
||||
@ -169,3 +169,7 @@ def get_logger(name: str) -> logging.Logger:
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
|
||||
return value.encode(encoding) if isinstance(value, str) else value
|
||||
|
||||
Loading…
Reference in New Issue
Block a user