Client handles redirect + auth (#552)
* Drop sync client
* Drop unused imports
* Async only
* Update tests/test_decoders.py (Co-Authored-By: Florimond Manca <florimond.manca@gmail.com>)
* Linting
* Update docs for async-only
* Import sorting
* Add async notes to docs
* Update README for 0.8 async switch
* Move auth away from middleware where possible
* Drop middleware sub-package
* Client.dispatcher -> Client.dispatch
* Docs tweak
* Linting
* Fix type checking issue
* Import ordering
* Fix up docstrings
* Minor docs fixes
* Linting
* Remove unused import
This commit is contained in:
parent
206c5372a6
commit
00e150f6a5
@ -24,13 +24,13 @@ HTTPX
|
||||
<em>A next-generation HTTP client for Python.</em>
|
||||
</div>
|
||||
|
||||
HTTPX is an asynchronous HTTP client, that supports HTTP/2 and HTTP/1.1.
|
||||
HTTPX is an asynchronous client library that supports HTTP/1.1 and HTTP/2.
|
||||
|
||||
It can be used in high-performance async web frameworks, using either asyncio
|
||||
or trio, and is able to support making large numbers of requests concurrently.
|
||||
or trio, and is able to support making large numbers of concurrent requests.
|
||||
|
||||
!!! note
|
||||
The 0.8 release switched HTTPX into focusing exclusively on the async
|
||||
The 0.8 release switched HTTPX into focusing exclusively on providing an async
|
||||
client. It is possible that we'll look at re-introducing a sync API at a
|
||||
later date.
|
||||
|
||||
@ -38,11 +38,10 @@ or trio, and is able to support making large numbers of requests concurrently.
|
||||
|
||||
Let's get started...
|
||||
|
||||
!!! note
|
||||
The standard Python REPL does not allow top-level async statements.
|
||||
The standard Python REPL does not allow top-level async statements.
|
||||
|
||||
To run async examples directly you'll probably want to either use `ipython`,
|
||||
or use Python 3.8 with `python -m asyncio`.
|
||||
To run these async examples you'll probably want to either use `ipython`,
|
||||
or use Python 3.8 with `python -m asyncio`.
|
||||
|
||||
```python
|
||||
>>> import httpx
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from .__version__ import __description__, __title__, __version__
|
||||
from .api import delete, get, head, options, patch, post, put, request
|
||||
from .auth import BasicAuth, DigestAuth
|
||||
from .client import Client
|
||||
from .concurrency.asyncio import AsyncioBackend
|
||||
from .concurrency.base import (
|
||||
@ -43,7 +44,6 @@ from .exceptions import (
|
||||
TooManyRedirects,
|
||||
WriteTimeout,
|
||||
)
|
||||
from .middleware.digest_auth import DigestAuth
|
||||
from .models import (
|
||||
URL,
|
||||
AuthTypes,
|
||||
@ -76,7 +76,9 @@ __all__ = [
|
||||
"patch",
|
||||
"put",
|
||||
"request",
|
||||
"BasicAuth",
|
||||
"Client",
|
||||
"DigestAuth",
|
||||
"AsyncioBackend",
|
||||
"USER_AGENT",
|
||||
"CertTypes",
|
||||
|
||||
@ -3,15 +3,34 @@ import os
|
||||
import re
|
||||
import time
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
from urllib.request import parse_http_list
|
||||
|
||||
from ..exceptions import ProtocolError
|
||||
from ..models import Request, Response, StatusCode
|
||||
from ..utils import to_bytes, to_str, unquote
|
||||
from .base import BaseMiddleware
|
||||
from .exceptions import ProtocolError
|
||||
from .middleware import Middleware
|
||||
from .models import Request, Response
|
||||
from .utils import to_bytes, to_str, unquote
|
||||
|
||||
|
||||
class DigestAuth(BaseMiddleware):
|
||||
class BasicAuth:
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
):
|
||||
self.auth_header = self.build_auth_header(username, password)
|
||||
|
||||
def __call__(self, request: Request) -> Request:
|
||||
request.headers["Authorization"] = self.auth_header
|
||||
return request
|
||||
|
||||
def build_auth_header(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> str:
|
||||
userpass = b":".join((to_bytes(username), to_bytes(password)))
|
||||
token = b64encode(userpass).decode().strip()
|
||||
return f"Basic {token}"
|
||||
|
||||
|
||||
class DigestAuth(Middleware):
|
||||
ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
|
||||
"MD5": hashlib.md5,
|
||||
"MD5-SESS": hashlib.md5,
|
||||
@ -33,12 +52,10 @@ class DigestAuth(BaseMiddleware):
|
||||
self, request: Request, get_response: typing.Callable
|
||||
) -> Response:
|
||||
response = await get_response(request)
|
||||
if not (
|
||||
StatusCode.is_client_error(response.status_code)
|
||||
and "www-authenticate" in response.headers
|
||||
):
|
||||
if response.status_code != 401 or "www-authenticate" not in response.headers:
|
||||
return response
|
||||
|
||||
await response.close()
|
||||
header = response.headers["www-authenticate"]
|
||||
try:
|
||||
challenge = DigestAuthChallenge.from_header(header)
|
||||
580
httpx/client.py
580
httpx/client.py
@ -5,6 +5,7 @@ from types import TracebackType
|
||||
|
||||
import hstspreload
|
||||
|
||||
from .auth import BasicAuth
|
||||
from .concurrency.asyncio import AsyncioBackend
|
||||
from .concurrency.base import ConcurrencyBackend
|
||||
from .config import (
|
||||
@ -21,11 +22,14 @@ from .dispatch.asgi import ASGIDispatch
|
||||
from .dispatch.base import Dispatcher
|
||||
from .dispatch.connection_pool import ConnectionPool
|
||||
from .dispatch.proxy_http import HTTPProxy
|
||||
from .exceptions import HTTPError, InvalidURL
|
||||
from .middleware.base import BaseMiddleware
|
||||
from .middleware.basic_auth import BasicAuthMiddleware
|
||||
from .middleware.custom_auth import CustomAuthMiddleware
|
||||
from .middleware.redirect import RedirectMiddleware
|
||||
from .exceptions import (
|
||||
HTTPError,
|
||||
InvalidURL,
|
||||
RedirectBodyUnavailable,
|
||||
RedirectLoop,
|
||||
TooManyRedirects,
|
||||
)
|
||||
from .middleware import Middleware
|
||||
from .models import (
|
||||
URL,
|
||||
AuthTypes,
|
||||
@ -42,6 +46,7 @@ from .models import (
|
||||
Response,
|
||||
URLTypes,
|
||||
)
|
||||
from .status_codes import codes
|
||||
from .utils import ElapsedTimer, get_environment_proxies, get_logger, get_netrc
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -125,8 +130,6 @@ class Client:
|
||||
if app is not None:
|
||||
dispatch = ASGIDispatch(app=app, backend=backend)
|
||||
|
||||
self.trust_env = True if trust_env is None else trust_env
|
||||
|
||||
if dispatch is None:
|
||||
dispatch = ConnectionPool(
|
||||
verify=verify,
|
||||
@ -135,7 +138,7 @@ class Client:
|
||||
http_versions=http_versions,
|
||||
pool_limits=pool_limits,
|
||||
backend=backend,
|
||||
trust_env=self.trust_env,
|
||||
trust_env=trust_env,
|
||||
uds=uds,
|
||||
)
|
||||
|
||||
@ -152,6 +155,7 @@ class Client:
|
||||
self._headers = Headers(headers)
|
||||
self._cookies = Cookies(cookies)
|
||||
self.max_redirects = max_redirects
|
||||
self.trust_env = trust_env
|
||||
self.dispatch = dispatch
|
||||
self.concurrency_backend = backend
|
||||
|
||||
@ -202,169 +206,46 @@ class Client:
|
||||
def params(self, params: QueryParamTypes) -> None:
|
||||
self._params = QueryParams(params)
|
||||
|
||||
def merge_url(self, url: URLTypes) -> URL:
|
||||
url = self.base_url.join(relative_url=url)
|
||||
if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
|
||||
url = url.copy_with(scheme="https")
|
||||
return url
|
||||
|
||||
def merge_cookies(
|
||||
self, cookies: CookieTypes = None
|
||||
) -> typing.Optional[CookieTypes]:
|
||||
if cookies or self.cookies:
|
||||
merged_cookies = Cookies(self.cookies)
|
||||
merged_cookies.update(cookies)
|
||||
return merged_cookies
|
||||
return cookies
|
||||
|
||||
def merge_headers(
|
||||
self, headers: HeaderTypes = None
|
||||
) -> typing.Optional[HeaderTypes]:
|
||||
if headers or self.headers:
|
||||
merged_headers = Headers(self.headers)
|
||||
merged_headers.update(headers)
|
||||
return merged_headers
|
||||
return headers
|
||||
|
||||
def merge_queryparams(
|
||||
self, params: QueryParamTypes = None
|
||||
) -> typing.Optional[QueryParamTypes]:
|
||||
if params or self.params:
|
||||
merged_queryparams = QueryParams(self.params)
|
||||
merged_queryparams.update(params)
|
||||
return merged_queryparams
|
||||
return params
|
||||
|
||||
async def _get_response(
|
||||
async def request(
|
||||
self,
|
||||
request: Request,
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = None,
|
||||
files: RequestFiles = None,
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
trust_env: bool = None,
|
||||
) -> Response:
|
||||
if request.url.scheme not in ("http", "https"):
|
||||
raise InvalidURL('URL scheme must be "http" or "https".')
|
||||
|
||||
dispatch = self._dispatcher_for_request(request, self.proxies)
|
||||
|
||||
async def get_response(request: Request) -> Response:
|
||||
try:
|
||||
with ElapsedTimer() as timer:
|
||||
response = await dispatch.send(
|
||||
request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
response.elapsed = timer.elapsed
|
||||
response.request = request
|
||||
except HTTPError as exc:
|
||||
# Add the original request to any HTTPError unless
|
||||
# there's already a request attached in the case of
|
||||
# a ProxyError.
|
||||
if exc.request is None:
|
||||
exc.request = request
|
||||
raise
|
||||
|
||||
self.cookies.extract_cookies(response)
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
status = f"{response.status_code} {response.reason_phrase}"
|
||||
response_line = f"{response.http_version} {status}"
|
||||
logger.debug(
|
||||
f'HTTP Request: {request.method} {request.url} "{response_line}"'
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def wrap(
|
||||
get_response: typing.Callable, middleware: BaseMiddleware
|
||||
) -> typing.Callable:
|
||||
return functools.partial(middleware, get_response=get_response)
|
||||
|
||||
get_response = wrap(
|
||||
get_response,
|
||||
RedirectMiddleware(allow_redirects=allow_redirects, cookies=self.cookies),
|
||||
request = self.build_request(
|
||||
method=method,
|
||||
url=url,
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
auth_middleware = self._get_auth_middleware(
|
||||
request=request,
|
||||
trust_env=self.trust_env if trust_env is None else trust_env,
|
||||
auth=self.auth if auth is None else auth,
|
||||
response = await self.send(
|
||||
request,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
)
|
||||
|
||||
if auth_middleware is not None:
|
||||
get_response = wrap(get_response, auth_middleware)
|
||||
|
||||
return await get_response(request)
|
||||
|
||||
def _get_auth_middleware(
|
||||
self, request: Request, trust_env: bool, auth: AuthTypes = None
|
||||
) -> typing.Optional[BaseMiddleware]:
|
||||
if isinstance(auth, tuple):
|
||||
return BasicAuthMiddleware(username=auth[0], password=auth[1])
|
||||
elif isinstance(auth, BaseMiddleware):
|
||||
return auth
|
||||
elif callable(auth):
|
||||
return CustomAuthMiddleware(auth=auth)
|
||||
|
||||
if auth is not None:
|
||||
raise TypeError(
|
||||
'When specified, "auth" must be a (username, password) tuple or '
|
||||
"a callable with signature (Request) -> Request "
|
||||
f"(got {auth!r})"
|
||||
)
|
||||
|
||||
if request.url.username or request.url.password:
|
||||
return BasicAuthMiddleware(
|
||||
username=request.url.username, password=request.url.password
|
||||
)
|
||||
|
||||
if trust_env:
|
||||
netrc_info = self._get_netrc()
|
||||
if netrc_info:
|
||||
netrc_login = netrc_info.authenticators(request.url.authority)
|
||||
if netrc_login:
|
||||
username, _, password = netrc_login
|
||||
assert password is not None
|
||||
return BasicAuthMiddleware(username=username, password=password)
|
||||
|
||||
return None
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def _get_netrc(self) -> typing.Optional[netrc.netrc]:
|
||||
return get_netrc()
|
||||
|
||||
def _dispatcher_for_request(
|
||||
self, request: Request, proxies: typing.Dict[str, Dispatcher]
|
||||
) -> Dispatcher:
|
||||
"""Gets the Dispatcher instance that should be used for a given Request"""
|
||||
if proxies:
|
||||
url = request.url
|
||||
is_default_port = (url.scheme == "http" and url.port == 80) or (
|
||||
url.scheme == "https" and url.port == 443
|
||||
)
|
||||
hostname = f"{url.host}:{url.port}"
|
||||
proxy_keys = (
|
||||
f"{url.scheme}://{hostname}",
|
||||
f"{url.scheme}://{url.host}" if is_default_port else None,
|
||||
f"all://{hostname}",
|
||||
f"all://{url.host}" if is_default_port else None,
|
||||
url.scheme,
|
||||
"all",
|
||||
)
|
||||
for proxy_key in proxy_keys:
|
||||
if proxy_key and proxy_key in proxies:
|
||||
dispatcher = proxies[proxy_key]
|
||||
return dispatcher
|
||||
|
||||
return self.dispatch
|
||||
return response
|
||||
|
||||
def build_request(
|
||||
self,
|
||||
@ -396,6 +277,321 @@ class Client:
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
def merge_url(self, url: URLTypes) -> URL:
|
||||
"""
|
||||
Merge a URL argument together with any 'base_url' on the client,
|
||||
to create the URL used for the outgoing request.
|
||||
"""
|
||||
url = self.base_url.join(relative_url=url)
|
||||
if url.scheme == "http" and hstspreload.in_hsts_preload(url.host):
|
||||
url = url.copy_with(scheme="https")
|
||||
return url
|
||||
|
||||
def merge_cookies(
|
||||
self, cookies: CookieTypes = None
|
||||
) -> typing.Optional[CookieTypes]:
|
||||
"""
|
||||
Merge a cookies argument together with any cookies on the client,
|
||||
to create the cookies used for the outgoing request.
|
||||
"""
|
||||
if cookies or self.cookies:
|
||||
merged_cookies = Cookies(self.cookies)
|
||||
merged_cookies.update(cookies)
|
||||
return merged_cookies
|
||||
return cookies
|
||||
|
||||
def merge_headers(
|
||||
self, headers: HeaderTypes = None
|
||||
) -> typing.Optional[HeaderTypes]:
|
||||
"""
|
||||
Merge a headers argument together with any headers on the client,
|
||||
to create the headers used for the outgoing request.
|
||||
"""
|
||||
if headers or self.headers:
|
||||
merged_headers = Headers(self.headers)
|
||||
merged_headers.update(headers)
|
||||
return merged_headers
|
||||
return headers
|
||||
|
||||
def merge_queryparams(
|
||||
self, params: QueryParamTypes = None
|
||||
) -> typing.Optional[QueryParamTypes]:
|
||||
"""
|
||||
Merge a queryparams argument together with any queryparams on the client,
|
||||
to create the queryparams used for the outgoing request.
|
||||
"""
|
||||
if params or self.params:
|
||||
merged_queryparams = QueryParams(self.params)
|
||||
merged_queryparams.update(params)
|
||||
return merged_queryparams
|
||||
return params
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
trust_env: bool = None,
|
||||
) -> Response:
|
||||
if request.url.scheme not in ("http", "https"):
|
||||
raise InvalidURL('URL scheme must be "http" or "https".')
|
||||
|
||||
auth = self.auth if auth is None else auth
|
||||
trust_env = self.trust_env if trust_env is None else trust_env
|
||||
|
||||
if not isinstance(auth, Middleware):
|
||||
request = self.authenticate(request, trust_env, auth)
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
else:
|
||||
get_response = functools.partial(
|
||||
self.send_handling_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
response = await auth(request, get_response)
|
||||
|
||||
if not stream:
|
||||
try:
|
||||
await response.read()
|
||||
finally:
|
||||
await response.close()
|
||||
|
||||
return response
|
||||
|
||||
def authenticate(
|
||||
self, request: Request, trust_env: bool, auth: AuthTypes = None
|
||||
) -> "Request":
|
||||
if auth is not None:
|
||||
if isinstance(auth, tuple):
|
||||
auth = BasicAuth(username=auth[0], password=auth[1])
|
||||
return auth(request)
|
||||
|
||||
username, password = request.url.username, request.url.password
|
||||
if username or password:
|
||||
auth = BasicAuth(username=username, password=password)
|
||||
return auth(request)
|
||||
|
||||
if trust_env:
|
||||
netrc_info = self._get_netrc()
|
||||
if netrc_info is not None:
|
||||
netrc_login = netrc_info.authenticators(request.url.authority)
|
||||
netrc_username, _, netrc_password = netrc_login or ("", None, None)
|
||||
if netrc_password is not None:
|
||||
auth = BasicAuth(username=netrc_username, password=netrc_password)
|
||||
return auth(request)
|
||||
|
||||
return request
|
||||
|
||||
async def send_handling_redirects(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
history: typing.List[Response] = None,
|
||||
) -> Response:
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
while True:
|
||||
if len(history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
if request.url in (response.url for response in history):
|
||||
raise RedirectLoop()
|
||||
|
||||
response = await self.send_single_request(
|
||||
request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
response.history = list(history)
|
||||
|
||||
if not response.is_redirect:
|
||||
return response
|
||||
|
||||
await response.close()
|
||||
request = self.build_redirect_request(request, response)
|
||||
history = history + [response]
|
||||
|
||||
if not allow_redirects:
|
||||
response.call_next = functools.partial(
|
||||
self.send_handling_redirects,
|
||||
request=request,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=False,
|
||||
history=history,
|
||||
)
|
||||
return response
|
||||
|
||||
def build_redirect_request(self, request: Request, response: Response) -> Request:
|
||||
"""
|
||||
Given a request and a redirect response, return a new request that
|
||||
should be used to effect the redirect.
|
||||
"""
|
||||
method = self.redirect_method(request, response)
|
||||
url = self.redirect_url(request, response)
|
||||
headers = self.redirect_headers(request, url, method)
|
||||
content = self.redirect_content(request, method)
|
||||
cookies = Cookies(self.cookies)
|
||||
return Request(
|
||||
method=method, url=url, headers=headers, data=content, cookies=cookies
|
||||
)
|
||||
|
||||
def redirect_method(self, request: Request, response: Response) -> str:
|
||||
"""
|
||||
When being redirected we may want to change the method of the request
|
||||
based on certain specs or browser behavior.
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.4.4
|
||||
if response.status_code == codes.SEE_OTHER and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# Do what the browsers do, despite standards...
|
||||
# Turn 302s into GETs.
|
||||
if response.status_code == codes.FOUND and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# If a POST is responded to with a 301, turn it into a GET.
|
||||
# This bizarre behaviour is explained in 'requests' issue 1704.
|
||||
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
|
||||
method = "GET"
|
||||
|
||||
return method
|
||||
|
||||
def redirect_url(self, request: Request, response: Response) -> URL:
|
||||
"""
|
||||
Return the URL for the redirect to follow.
|
||||
"""
|
||||
location = response.headers["Location"]
|
||||
|
||||
url = URL(location, allow_relative=True)
|
||||
|
||||
# Facilitate relative 'Location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
if url.is_relative_url:
|
||||
url = request.url.join(url)
|
||||
|
||||
# Attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
if request.url.fragment and not url.fragment:
|
||||
url = url.copy_with(fragment=request.url.fragment)
|
||||
|
||||
return url
|
||||
|
||||
def redirect_headers(self, request: Request, url: URL, method: str) -> Headers:
|
||||
"""
|
||||
Return the headers that should be used for the redirect request.
|
||||
"""
|
||||
headers = Headers(request.headers)
|
||||
|
||||
if url.origin != request.url.origin:
|
||||
# Strip Authorization headers when responses are redirected away from
|
||||
# the origin.
|
||||
headers.pop("Authorization", None)
|
||||
headers["Host"] = url.authority
|
||||
|
||||
if method != request.method and method == "GET":
|
||||
# If we've switched to a 'GET' request, then strip any headers which
|
||||
# are only relevant to the request body.
|
||||
headers.pop("Content-Length", None)
|
||||
headers.pop("Transfer-Encoding", None)
|
||||
|
||||
# We should use the client cookie store to determine any cookie header,
|
||||
# rather than whatever was on the original outgoing request.
|
||||
headers.pop("Cookie", None)
|
||||
|
||||
return headers
|
||||
|
||||
def redirect_content(self, request: Request, method: str) -> bytes:
|
||||
"""
|
||||
Return the body that should be used for the redirect request.
|
||||
"""
|
||||
if method != request.method and method == "GET":
|
||||
return b""
|
||||
if request.is_streaming:
|
||||
raise RedirectBodyUnavailable()
|
||||
return request.content
|
||||
|
||||
async def send_single_request(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
"""
|
||||
Sends a single request, without handling any redirections.
|
||||
"""
|
||||
|
||||
dispatcher = self._dispatcher_for_request(request, self.proxies)
|
||||
|
||||
try:
|
||||
with ElapsedTimer() as timer:
|
||||
response = await dispatcher.send(
|
||||
request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
response.elapsed = timer.elapsed
|
||||
response.request = request
|
||||
except HTTPError as exc:
|
||||
# Add the original request to any HTTPError unless
|
||||
# there's already a request attached in the case of
|
||||
# a ProxyError.
|
||||
if exc.request is None:
|
||||
exc.request = request
|
||||
raise
|
||||
|
||||
self.cookies.extract_cookies(response)
|
||||
|
||||
status = f"{response.status_code} {response.reason_phrase}"
|
||||
response_line = f"{response.http_version} {status}"
|
||||
logger.debug(f'HTTP Request: {request.method} {request.url} "{response_line}"')
|
||||
|
||||
return response
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def _get_netrc(self) -> typing.Optional[netrc.netrc]:
|
||||
return get_netrc()
|
||||
|
||||
def _dispatcher_for_request(
|
||||
self, request: Request, proxies: typing.Dict[str, Dispatcher]
|
||||
) -> Dispatcher:
|
||||
"""Gets the Dispatcher instance that should be used for a given Request"""
|
||||
if proxies:
|
||||
url = request.url
|
||||
is_default_port = (url.scheme == "http" and url.port == 80) or (
|
||||
url.scheme == "https" and url.port == 443
|
||||
)
|
||||
hostname = f"{url.host}:{url.port}"
|
||||
proxy_keys = (
|
||||
f"{url.scheme}://{hostname}",
|
||||
f"{url.scheme}://{url.host}" if is_default_port else None,
|
||||
f"all://{hostname}",
|
||||
f"all://{url.host}" if is_default_port else None,
|
||||
url.scheme,
|
||||
"all",
|
||||
)
|
||||
for proxy_key in proxy_keys:
|
||||
if proxy_key and proxy_key in proxies:
|
||||
dispatcher = proxies[proxy_key]
|
||||
return dispatcher
|
||||
|
||||
return self.dispatch
|
||||
|
||||
async def get(
|
||||
self,
|
||||
url: URLTypes,
|
||||
@ -624,70 +820,6 @@ class Client:
|
||||
trust_env=trust_env,
|
||||
)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: URLTypes,
|
||||
*,
|
||||
data: RequestData = None,
|
||||
files: RequestFiles = None,
|
||||
json: typing.Any = None,
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
trust_env: bool = None,
|
||||
) -> Response:
|
||||
request = self.build_request(
|
||||
method=method,
|
||||
url=url,
|
||||
data=data,
|
||||
files=files,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
)
|
||||
response = await self.send(
|
||||
request,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
)
|
||||
return response
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
*,
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
trust_env: bool = None,
|
||||
) -> Response:
|
||||
return await self._get_response(
|
||||
request=request,
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.dispatch.close()
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import enum
|
||||
from base64 import b64encode
|
||||
|
||||
import h11
|
||||
|
||||
@ -14,7 +15,6 @@ from ..config import (
|
||||
VerifyTypes,
|
||||
)
|
||||
from ..exceptions import ProxyError
|
||||
from ..middleware.basic_auth import build_basic_auth_header
|
||||
from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
|
||||
from ..utils import get_logger
|
||||
from .connection import HTTPConnection
|
||||
@ -69,13 +69,18 @@ class HTTPProxy(ConnectionPool):
|
||||
if url.username or url.password:
|
||||
self.proxy_headers.setdefault(
|
||||
"Proxy-Authorization",
|
||||
build_basic_auth_header(url.username, url.password),
|
||||
self.build_auth_header(url.username, url.password),
|
||||
)
|
||||
# Remove userinfo from the URL authority, e.g.:
|
||||
# 'username:password@proxy_host:proxy_port' -> 'proxy_host:proxy_port'
|
||||
credentials, _, authority = url.authority.rpartition("@")
|
||||
self.proxy_url = url.copy_with(authority=authority)
|
||||
|
||||
def build_auth_header(self, username: str, password: str) -> str:
|
||||
userpass = (username.encode("utf-8"), password.encode("utf-8"))
|
||||
token = b64encode(b":".join(userpass)).decode().strip()
|
||||
return f"Basic {token}"
|
||||
|
||||
async def acquire_connection(self, origin: Origin) -> HTTPConnection:
|
||||
if self.should_forward_origin(origin):
|
||||
logger.trace(
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import typing
|
||||
|
||||
from ..models import Request, Response
|
||||
from .models import Request, Response
|
||||
|
||||
|
||||
class BaseMiddleware:
|
||||
class Middleware:
|
||||
async def __call__(
|
||||
self, request: Request, get_response: typing.Callable
|
||||
) -> Response:
|
||||
@ -1,27 +0,0 @@
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
from ..models import Request, Response
|
||||
from ..utils import to_bytes
|
||||
from .base import BaseMiddleware
|
||||
|
||||
|
||||
class BasicAuthMiddleware(BaseMiddleware):
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
):
|
||||
self.authorization_header = build_basic_auth_header(username, password)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, get_response: typing.Callable
|
||||
) -> Response:
|
||||
request.headers["Authorization"] = self.authorization_header
|
||||
return await get_response(request)
|
||||
|
||||
|
||||
def build_basic_auth_header(
|
||||
username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> str:
|
||||
userpass = b":".join((to_bytes(username), to_bytes(password)))
|
||||
token = b64encode(userpass).decode().strip()
|
||||
return f"Basic {token}"
|
||||
@ -1,15 +0,0 @@
|
||||
import typing
|
||||
|
||||
from ..models import Request, Response
|
||||
from .base import BaseMiddleware
|
||||
|
||||
|
||||
class CustomAuthMiddleware(BaseMiddleware):
|
||||
def __init__(self, auth: typing.Callable[[Request], Request]):
|
||||
self.auth = auth
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, get_response: typing.Callable
|
||||
) -> Response:
|
||||
request = self.auth(request)
|
||||
return await get_response(request)
|
||||
@ -1,130 +0,0 @@
|
||||
import functools
|
||||
import typing
|
||||
|
||||
from ..config import DEFAULT_MAX_REDIRECTS
|
||||
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
|
||||
from ..models import URL, Cookies, Headers, Request, Response
|
||||
from ..status_codes import codes
|
||||
from .base import BaseMiddleware
|
||||
|
||||
|
||||
class RedirectMiddleware(BaseMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
allow_redirects: bool = True,
|
||||
max_redirects: int = DEFAULT_MAX_REDIRECTS,
|
||||
cookies: typing.Optional[Cookies] = None,
|
||||
):
|
||||
self.allow_redirects = allow_redirects
|
||||
self.max_redirects = max_redirects
|
||||
self.cookies = cookies
|
||||
self.history: typing.List[Response] = []
|
||||
|
||||
async def __call__(
|
||||
self, request: Request, get_response: typing.Callable
|
||||
) -> Response:
|
||||
if len(self.history) > self.max_redirects:
|
||||
raise TooManyRedirects()
|
||||
if request.url in (response.url for response in self.history):
|
||||
raise RedirectLoop()
|
||||
|
||||
response = await get_response(request)
|
||||
response.history = list(self.history)
|
||||
|
||||
if not response.is_redirect:
|
||||
return response
|
||||
|
||||
self.history.append(response)
|
||||
next_request = self.build_redirect_request(request, response)
|
||||
|
||||
if self.allow_redirects:
|
||||
return await self(next_request, get_response)
|
||||
|
||||
response.call_next = functools.partial(self, next_request, get_response)
|
||||
return response
|
||||
|
||||
def build_redirect_request(self, request: Request, response: Response) -> Request:
|
||||
method = self.redirect_method(request, response)
|
||||
url = self.redirect_url(request, response)
|
||||
headers = self.redirect_headers(request, url, method) # TODO: merge headers?
|
||||
content = self.redirect_content(request, method)
|
||||
cookies = Cookies(self.cookies)
|
||||
return Request(
|
||||
method=method, url=url, headers=headers, data=content, cookies=cookies
|
||||
)
|
||||
|
||||
def redirect_method(self, request: Request, response: Response) -> str:
|
||||
"""
|
||||
When being redirected we may want to change the method of the request
|
||||
based on certain specs or browser behavior.
|
||||
"""
|
||||
method = request.method
|
||||
|
||||
# https://tools.ietf.org/html/rfc7231#section-6.4.4
|
||||
if response.status_code == codes.SEE_OTHER and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# Do what the browsers do, despite standards...
|
||||
# Turn 302s into GETs.
|
||||
if response.status_code == codes.FOUND and method != "HEAD":
|
||||
method = "GET"
|
||||
|
||||
# If a POST is responded to with a 301, turn it into a GET.
|
||||
# This bizarre behaviour is explained in 'requests' issue 1704.
|
||||
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
|
||||
method = "GET"
|
||||
|
||||
return method
|
||||
|
||||
def redirect_url(self, request: Request, response: Response) -> URL:
|
||||
"""
|
||||
Return the URL for the redirect to follow.
|
||||
"""
|
||||
location = response.headers["Location"]
|
||||
|
||||
url = URL(location, allow_relative=True)
|
||||
|
||||
# Facilitate relative 'Location' headers, as allowed by RFC 7231.
|
||||
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
|
||||
if url.is_relative_url:
|
||||
url = request.url.join(url)
|
||||
|
||||
# Attach previous fragment if needed (RFC 7231 7.1.2)
|
||||
if request.url.fragment and not url.fragment:
|
||||
url = url.copy_with(fragment=request.url.fragment)
|
||||
|
||||
return url
|
||||
|
||||
def redirect_headers(self, request: Request, url: URL, method: str) -> Headers:
|
||||
"""
|
||||
Return the headers that should be used for the redirect request.
|
||||
"""
|
||||
headers = Headers(request.headers)
|
||||
|
||||
if url.origin != request.url.origin:
|
||||
# Strip Authorization headers when responses are redirected away from
|
||||
# the origin.
|
||||
headers.pop("Authorization", None)
|
||||
headers["Host"] = url.authority
|
||||
|
||||
if method != request.method and method == "GET":
|
||||
# If we've switched to a 'GET' request, then strip any headers which
|
||||
# are only relevant to the request body.
|
||||
headers.pop("Content-Length", None)
|
||||
headers.pop("Transfer-Encoding", None)
|
||||
|
||||
# We should use the client cookie store to determine any cookie header,
|
||||
# rather than whatever was on the original outgoing request.
|
||||
headers.pop("Cookie", None)
|
||||
|
||||
return headers
|
||||
|
||||
def redirect_content(self, request: Request, method: str) -> bytes:
|
||||
"""
|
||||
Return the body that should be used for the redirect request.
|
||||
"""
|
||||
if method != request.method and method == "GET":
|
||||
return b""
|
||||
if request.is_streaming:
|
||||
raise RedirectBodyUnavailable()
|
||||
return request.content
|
||||
Loading…
Reference in New Issue
Block a user