Client handles redirect + auth (#552)

* Drop sync client

* Drop unused imports

* Async only

* Update tests/test_decoders.py

Co-Authored-By: Florimond Manca <florimond.manca@gmail.com>

* Linting

* Update docs for async-only

* Import sorting

* Add async notes to docs

* Update README for 0.8 async switch

* Move auth away from middleware where possible

* Drop middleware sub-package

* Client.dispatcher -> Client.dispatch

* Docs tweak

* Linting

* Fix type checking issue

* Import ordering

* Fix up docstrings

* Minor docs fixes

* Linting

* Remove unused import
This commit is contained in:
Tom Christie 2019-11-27 12:10:10 +00:00 committed by GitHub
parent 206c5372a6
commit 00e150f6a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 400 additions and 417 deletions

View File

@ -24,13 +24,13 @@ HTTPX
<em>A next-generation HTTP client for Python.</em>
</div>
HTTPX is an asynchronous HTTP client, that supports HTTP/2 and HTTP/1.1.
HTTPX is an asynchronous client library that supports HTTP/1.1 and HTTP/2.
It can be used in high-performance async web frameworks, using either asyncio
or trio, and is able to support making large numbers of requests concurrently.
or trio, and is able to support making large numbers of concurrent requests.
!!! note
The 0.8 release switched HTTPX into focusing exclusively on the async
The 0.8 release switched HTTPX to focus exclusively on providing an async
client. It is possible that we'll look at re-introducing a sync API at a
later date.
@ -38,11 +38,10 @@ or trio, and is able to support making large numbers of requests concurrently.
Let's get started...
!!! note
The standard Python REPL does not allow top-level async statements.
The standard Python REPL does not allow top-level async statements.
To run async examples directly you'll probably want to either use `ipython`,
or use Python 3.8 with `python -m asyncio`.
To run these async examples you'll probably want to either use `ipython`,
or use Python 3.8 with `python -m asyncio`.
```python
>>> import httpx

View File

@ -1,5 +1,6 @@
from .__version__ import __description__, __title__, __version__
from .api import delete, get, head, options, patch, post, put, request
from .auth import BasicAuth, DigestAuth
from .client import Client
from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import (
@ -43,7 +44,6 @@ from .exceptions import (
TooManyRedirects,
WriteTimeout,
)
from .middleware.digest_auth import DigestAuth
from .models import (
URL,
AuthTypes,
@ -76,7 +76,9 @@ __all__ = [
"patch",
"put",
"request",
"BasicAuth",
"Client",
"DigestAuth",
"AsyncioBackend",
"USER_AGENT",
"CertTypes",

View File

@ -3,15 +3,34 @@ import os
import re
import time
import typing
from base64 import b64encode
from urllib.request import parse_http_list
from ..exceptions import ProtocolError
from ..models import Request, Response, StatusCode
from ..utils import to_bytes, to_str, unquote
from .base import BaseMiddleware
from .exceptions import ProtocolError
from .middleware import Middleware
from .models import Request, Response
from .utils import to_bytes, to_str, unquote
class DigestAuth(BaseMiddleware):
class BasicAuth:
    """
    Callable that attaches HTTP Basic credentials to a request.

    The header value is computed once, at construction time, and is then
    written to the 'Authorization' header of every request passed in.
    """

    def __init__(
        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
    ):
        # Precompute the header so that `__call__` is a plain assignment.
        self.auth_header = self.build_auth_header(username, password)

    def __call__(self, request: Request) -> Request:
        """Set the 'Authorization' header on `request` and return it."""
        request.headers["Authorization"] = self.auth_header
        return request

    def build_auth_header(
        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
    ) -> str:
        """Return 'Basic <token>', the token being base64 of 'username:password'."""
        username_bytes = to_bytes(username)
        password_bytes = to_bytes(password)
        credentials = b":".join((username_bytes, password_bytes))
        encoded = b64encode(credentials).decode().strip()
        return f"Basic {encoded}"
class DigestAuth(Middleware):
ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
"MD5": hashlib.md5,
"MD5-SESS": hashlib.md5,
@ -33,12 +52,10 @@ class DigestAuth(BaseMiddleware):
self, request: Request, get_response: typing.Callable
) -> Response:
response = await get_response(request)
if not (
StatusCode.is_client_error(response.status_code)
and "www-authenticate" in response.headers
):
if response.status_code != 401 or "www-authenticate" not in response.headers:
return response
await response.close()
header = response.headers["www-authenticate"]
try:
challenge = DigestAuthChallenge.from_header(header)

View File

@ -5,6 +5,7 @@ from types import TracebackType
import hstspreload
from .auth import BasicAuth
from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import ConcurrencyBackend
from .config import (
@ -21,11 +22,14 @@ from .dispatch.asgi import ASGIDispatch
from .dispatch.base import Dispatcher
from .dispatch.connection_pool import ConnectionPool
from .dispatch.proxy_http import HTTPProxy
from .exceptions import HTTPError, InvalidURL
from .middleware.base import BaseMiddleware
from .middleware.basic_auth import BasicAuthMiddleware
from .middleware.custom_auth import CustomAuthMiddleware
from .middleware.redirect import RedirectMiddleware
from .exceptions import (
HTTPError,
InvalidURL,
RedirectBodyUnavailable,
RedirectLoop,
TooManyRedirects,
)
from .middleware import Middleware
from .models import (
URL,
AuthTypes,
@ -42,6 +46,7 @@ from .models import (
Response,
URLTypes,
)
from .status_codes import codes
from .utils import ElapsedTimer, get_environment_proxies, get_logger, get_netrc
logger = get_logger(__name__)
@ -125,8 +130,6 @@ class Client:
if app is not None:
dispatch = ASGIDispatch(app=app, backend=backend)
self.trust_env = True if trust_env is None else trust_env
if dispatch is None:
dispatch = ConnectionPool(
verify=verify,
@ -135,7 +138,7 @@ class Client:
http_versions=http_versions,
pool_limits=pool_limits,
backend=backend,
trust_env=self.trust_env,
trust_env=trust_env,
uds=uds,
)
@ -152,6 +155,7 @@ class Client:
self._headers = Headers(headers)
self._cookies = Cookies(cookies)
self.max_redirects = max_redirects
self.trust_env = trust_env
self.dispatch = dispatch
self.concurrency_backend = backend
@ -202,169 +206,46 @@ class Client:
def params(self, params: QueryParamTypes) -> None:
self._params = QueryParams(params)
def merge_url(self, url: URLTypes) -> URL:
    """
    Join `url` onto the client's `base_url`, switching the scheme to
    "https" when the host is reported by `hstspreload.in_hsts_preload`.
    """
    merged = self.base_url.join(relative_url=url)
    if merged.scheme != "http":
        return merged
    if not hstspreload.in_hsts_preload(merged.host):
        return merged
    return merged.copy_with(scheme="https")
def merge_cookies(
    self, cookies: CookieTypes = None
) -> typing.Optional[CookieTypes]:
    """
    Combine per-request `cookies` with the client's cookies.

    When neither side holds anything, `cookies` is returned untouched.
    """
    if not cookies and not self.cookies:
        return cookies
    combined = Cookies(self.cookies)
    combined.update(cookies)
    return combined
def merge_headers(
    self, headers: HeaderTypes = None
) -> typing.Optional[HeaderTypes]:
    """
    Combine per-request `headers` with the client's headers.

    When neither side holds anything, `headers` is returned untouched.
    """
    if not headers and not self.headers:
        return headers
    combined = Headers(self.headers)
    combined.update(headers)
    return combined
def merge_queryparams(
    self, params: QueryParamTypes = None
) -> typing.Optional[QueryParamTypes]:
    """
    Combine per-request `params` with the client's query parameters.

    When neither side holds anything, `params` is returned untouched.
    """
    if not params and not self.params:
        return params
    combined = QueryParams(self.params)
    combined.update(params)
    return combined
async def _get_response(
async def request(
self,
request: Request,
method: str,
url: URLTypes,
*,
data: RequestData = None,
files: RequestFiles = None,
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
verify: VerifyTypes = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
trust_env: bool = None,
) -> Response:
if request.url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')
dispatch = self._dispatcher_for_request(request, self.proxies)
async def get_response(request: Request) -> Response:
try:
with ElapsedTimer() as timer:
response = await dispatch.send(
request, verify=verify, cert=cert, timeout=timeout
)
response.elapsed = timer.elapsed
response.request = request
except HTTPError as exc:
# Add the original request to any HTTPError unless
# there's already a request attached in the case of
# a ProxyError.
if exc.request is None:
exc.request = request
raise
self.cookies.extract_cookies(response)
if not stream:
try:
await response.read()
finally:
await response.close()
status = f"{response.status_code} {response.reason_phrase}"
response_line = f"{response.http_version} {status}"
logger.debug(
f'HTTP Request: {request.method} {request.url} "{response_line}"'
)
return response
def wrap(
get_response: typing.Callable, middleware: BaseMiddleware
) -> typing.Callable:
return functools.partial(middleware, get_response=get_response)
get_response = wrap(
get_response,
RedirectMiddleware(allow_redirects=allow_redirects, cookies=self.cookies),
request = self.build_request(
method=method,
url=url,
data=data,
files=files,
json=json,
params=params,
headers=headers,
cookies=cookies,
)
auth_middleware = self._get_auth_middleware(
request=request,
trust_env=self.trust_env if trust_env is None else trust_env,
auth=self.auth if auth is None else auth,
response = await self.send(
request,
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
trust_env=trust_env,
)
if auth_middleware is not None:
get_response = wrap(get_response, auth_middleware)
return await get_response(request)
def _get_auth_middleware(
    self, request: Request, trust_env: bool, auth: AuthTypes = None
) -> typing.Optional[BaseMiddleware]:
    """
    Choose the auth middleware for `request`, or None when no auth applies.

    Precedence: an explicit `auth` argument, then credentials embedded in
    the request URL, then (only when `trust_env` is true) a netrc entry for
    the URL authority.

    Raises TypeError when `auth` is given but is neither a tuple, a
    middleware instance, nor a callable.
    """
    if isinstance(auth, tuple):
        return BasicAuthMiddleware(username=auth[0], password=auth[1])
    if isinstance(auth, BaseMiddleware):
        return auth
    if callable(auth):
        return CustomAuthMiddleware(auth=auth)
    if auth is not None:
        raise TypeError(
            'When specified, "auth" must be a (username, password) tuple or '
            "a callable with signature (Request) -> Request "
            f"(got {auth!r})"
        )

    if request.url.username or request.url.password:
        return BasicAuthMiddleware(
            username=request.url.username, password=request.url.password
        )

    # Everything below concerns netrc, which is only consulted on request.
    if not trust_env:
        return None
    netrc_info = self._get_netrc()
    if not netrc_info:
        return None
    netrc_login = netrc_info.authenticators(request.url.authority)
    if not netrc_login:
        return None
    username, _, password = netrc_login
    assert password is not None
    return BasicAuthMiddleware(username=username, password=password)
def _get_netrc(self) -> typing.Optional[netrc.netrc]:
    """
    Return the result of `get_netrc()`, calling it at most once per instance.

    The value is memoised on the instance instead of with
    `functools.lru_cache`: a cache on a method is keyed on `self`, so it
    keeps a strong reference to the instance, and with a size of 1 two
    instances used alternately evict each other on every call.
    """
    try:
        return self._cached_netrc
    except AttributeError:
        # A `None` result is stored too, so the lookup is not repeated.
        self._cached_netrc = get_netrc()
        return self._cached_netrc
def _dispatcher_for_request(
    self, request: Request, proxies: typing.Dict[str, Dispatcher]
) -> Dispatcher:
    """
    Return the dispatcher to use for `request`.

    Proxy keys are tried from most to least specific; the first key found
    in `proxies` wins. Host-only keys are considered only when the URL is
    on its scheme's default port. Falls back to `self.dispatch`.
    """
    if not proxies:
        return self.dispatch

    url = request.url
    on_default_port = (url.scheme == "http" and url.port == 80) or (
        url.scheme == "https" and url.port == 443
    )
    host_and_port = f"{url.host}:{url.port}"

    candidates = [f"{url.scheme}://{host_and_port}"]
    if on_default_port:
        candidates.append(f"{url.scheme}://{url.host}")
    candidates.append(f"all://{host_and_port}")
    if on_default_port:
        candidates.append(f"all://{url.host}")
    candidates.extend((url.scheme, "all"))

    for key in candidates:
        if key and key in proxies:
            return proxies[key]
    return self.dispatch
return response
def build_request(
self,
@ -396,6 +277,321 @@ class Client:
cookies=cookies,
)
def merge_url(self, url: URLTypes) -> URL:
    """
    Resolve `url` against the client's `base_url` to obtain the URL of the
    outgoing request.

    An "http" URL whose host is reported by `hstspreload.in_hsts_preload`
    is returned with its scheme replaced by "https".
    """
    merged = self.base_url.join(relative_url=url)
    upgrade = merged.scheme == "http" and hstspreload.in_hsts_preload(merged.host)
    return merged.copy_with(scheme="https") if upgrade else merged
def merge_cookies(
    self, cookies: CookieTypes = None
) -> typing.Optional[CookieTypes]:
    """
    Produce the cookies for an outgoing request by layering the `cookies`
    argument on top of the client's own cookies.

    If both are empty the argument is handed back unchanged.
    """
    nothing_to_merge = not (cookies or self.cookies)
    if nothing_to_merge:
        return cookies

    result = Cookies(self.cookies)
    result.update(cookies)
    return result
def merge_headers(
    self, headers: HeaderTypes = None
) -> typing.Optional[HeaderTypes]:
    """
    Produce the headers for an outgoing request by layering the `headers`
    argument on top of the client's own headers.

    If both are empty the argument is handed back unchanged.
    """
    nothing_to_merge = not (headers or self.headers)
    if nothing_to_merge:
        return headers

    result = Headers(self.headers)
    result.update(headers)
    return result
def merge_queryparams(
    self, params: QueryParamTypes = None
) -> typing.Optional[QueryParamTypes]:
    """
    Produce the query parameters for an outgoing request by layering the
    `params` argument on top of the client's own query parameters.

    If both are empty the argument is handed back unchanged.
    """
    nothing_to_merge = not (params or self.params)
    if nothing_to_merge:
        return params

    result = QueryParams(self.params)
    result.update(params)
    return result
async def send(
    self,
    request: Request,
    *,
    stream: bool = False,
    auth: AuthTypes = None,
    allow_redirects: bool = True,
    verify: VerifyTypes = None,
    cert: CertTypes = None,
    timeout: TimeoutTypes = None,
    trust_env: bool = None,
) -> Response:
    """
    Send `request`, applying authentication and redirect handling.

    `auth` and `trust_env` fall back to the client-level values when None.
    A `Middleware` auth wraps the redirect-handling send; any other auth is
    applied to the request up front via `authenticate`. Unless `stream` is
    true the response is read and then closed before being returned.

    Raises InvalidURL when the request scheme is not "http" or "https".
    """
    if request.url.scheme not in ("http", "https"):
        raise InvalidURL('URL scheme must be "http" or "https".')

    if auth is None:
        auth = self.auth
    if trust_env is None:
        trust_env = self.trust_env

    # Options shared by both ways of reaching `send_handling_redirects`.
    redirect_options = dict(
        verify=verify,
        cert=cert,
        timeout=timeout,
        allow_redirects=allow_redirects,
    )

    if isinstance(auth, Middleware):
        get_response = functools.partial(
            self.send_handling_redirects, **redirect_options
        )
        response = await auth(request, get_response)
    else:
        request = self.authenticate(request, trust_env, auth)
        response = await self.send_handling_redirects(request, **redirect_options)

    if stream:
        return response

    try:
        await response.read()
    finally:
        await response.close()
    return response
def authenticate(
    self, request: Request, trust_env: bool, auth: AuthTypes = None
) -> "Request":
    """
    Apply non-middleware authentication to `request` and return the result.

    Precedence:

    1. An explicit `auth`: a (username, password) tuple is applied as Basic
       auth; any other callable is invoked as `auth(request)`.
    2. Credentials embedded in the request URL, applied as Basic auth.
    3. When `trust_env` is true, a netrc entry for the URL authority that
       carries a password, applied as Basic auth.

    When none of these apply, `request` is returned unchanged.

    Raises:
        TypeError: if `auth` is given but is neither a tuple nor callable.
    """
    if auth is not None:
        if isinstance(auth, tuple):
            auth = BasicAuth(username=auth[0], password=auth[1])
        elif not callable(auth):
            # Fail with an explanatory message rather than the generic
            # "object is not callable" raised by the call below.
            raise TypeError(
                'When specified, "auth" must be a (username, password) tuple or '
                "a callable with signature (Request) -> Request "
                f"(got {auth!r})"
            )
        return auth(request)

    username, password = request.url.username, request.url.password
    if username or password:
        auth = BasicAuth(username=username, password=password)
        return auth(request)

    if trust_env:
        netrc_info = self._get_netrc()
        if netrc_info is not None:
            netrc_login = netrc_info.authenticators(request.url.authority)
            # No matching entry: fall through with a None password.
            netrc_username, _, netrc_password = netrc_login or ("", None, None)
            if netrc_password is not None:
                auth = BasicAuth(username=netrc_username, password=netrc_password)
                return auth(request)

    return request
# Send `request` and follow redirect responses.
#
# `history` accumulates the redirect responses seen so far; a copy of it is
# attached to every response as `response.history`. The loop only ends by
# returning a response or raising:
#   - TooManyRedirects when the history is longer than `self.max_redirects`
#   - RedirectLoop when the request URL equals the URL of a response already
#     in the history
# With `allow_redirects` false, the first redirect response is returned with
# `call_next` bound so the caller can continue the chain by hand.
async def send_handling_redirects(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
allow_redirects: bool = True,
history: typing.List[Response] = None,
) -> Response:
# `None` default avoids sharing one list between calls.
if history is None:
history = []
while True:
if len(history) > self.max_redirects:
raise TooManyRedirects()
if request.url in (response.url for response in history):
raise RedirectLoop()
response = await self.send_single_request(
request, verify=verify, cert=cert, timeout=timeout
)
# Copy, so later changes to `history` do not alter this response.
response.history = list(history)
if not response.is_redirect:
return response
# NOTE(review): the redirect response is closed here without being read.
# When `allow_redirects` is false this same response is returned below, and
# `send()` then awaits `response.read()` on it unless streaming. Confirm that
# reading an already-closed response is safe.
await response.close()
request = self.build_redirect_request(request, response)
# Builds a new list rather than appending in place, so a `history` list
# passed in by the caller is never mutated.
history = history + [response]
if not allow_redirects:
# Bind everything needed to resume from the already-built next request.
response.call_next = functools.partial(
self.send_handling_redirects,
request=request,
verify=verify,
cert=cert,
timeout=timeout,
allow_redirects=False,
history=history,
)
return response
def build_redirect_request(self, request: Request, response: Response) -> Request:
    """
    Construct the follow-up request for a redirect `response`.

    Method, URL, headers and body come from the `redirect_*` helpers;
    cookies are a copy of the client's cookie store.
    """
    next_method = self.redirect_method(request, response)
    next_url = self.redirect_url(request, response)
    return Request(
        method=next_method,
        url=next_url,
        headers=self.redirect_headers(request, next_url, next_method),
        data=self.redirect_content(request, next_method),
        cookies=Cookies(self.cookies),
    )
def redirect_method(self, request: Request, response: Response) -> str:
    """
    Return the HTTP method to use for the request that follows a redirect.

    The original method is kept unless the status code calls for a switch
    to GET, following the relevant specs and common browser behaviour.
    """
    original = request.method
    status = response.status_code

    # 303 (https://tools.ietf.org/html/rfc7231#section-6.4.4) and, as the
    # browsers do despite the standards, 302: everything but HEAD becomes GET.
    if status in (codes.SEE_OTHER, codes.FOUND):
        return original if original == "HEAD" else "GET"

    # A POST answered with a 301 is turned into a GET.
    # This bizarre behaviour is explained in 'requests' issue 1704.
    if status == codes.MOVED_PERMANENTLY and original == "POST":
        return "GET"

    return original
def redirect_url(self, request: Request, response: Response) -> URL:
    """
    Work out the URL a redirect `response` points to.

    The 'Location' header is parsed, resolved against the request URL when
    relative, and given the request's fragment when it has none of its own.
    """
    target = URL(response.headers["Location"], allow_relative=True)

    # Relative 'Location' headers are allowed by RFC 7231
    # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource').
    if target.is_relative_url:
        target = request.url.join(target)

    # Carry the previous fragment over if needed (RFC 7231 7.1.2).
    needs_fragment = request.url.fragment and not target.fragment
    if needs_fragment:
        target = target.copy_with(fragment=request.url.fragment)

    return target
def redirect_headers(self, request: Request, url: URL, method: str) -> Headers:
    """
    Build the headers for the request that follows a redirect.

    Starts from a copy of the original request headers and removes or
    replaces the ones that must not carry over.
    """
    redirected = Headers(request.headers)

    leaves_origin = url.origin != request.url.origin
    if leaves_origin:
        # Never forward credentials to a different origin, and point the
        # 'Host' header at the new authority.
        redirected.pop("Authorization", None)
        redirected["Host"] = url.authority

    switched_to_get = method != request.method and method == "GET"
    if switched_to_get:
        # If we've switched to a 'GET' request, then strip any headers
        # which are only relevant to the request body.
        for body_header in ("Content-Length", "Transfer-Encoding"):
            redirected.pop(body_header, None)

    # The client cookie store decides the cookie header, not whatever was
    # on the original outgoing request.
    redirected.pop("Cookie", None)

    return redirected
def redirect_content(self, request: Request, method: str) -> bytes:
    """
    Return the body for the request that follows a redirect.

    A switch to GET drops the body. Otherwise the original content is
    reused, which is impossible for a streaming request and raises
    RedirectBodyUnavailable.
    """
    switched_to_get = method != request.method and method == "GET"
    if switched_to_get:
        return b""
    if not request.is_streaming:
        return request.content
    raise RedirectBodyUnavailable()
async def send_single_request(
    self,
    request: Request,
    verify: VerifyTypes = None,
    cert: CertTypes = None,
    timeout: TimeoutTypes = None,
) -> Response:
    """
    Dispatch exactly one request; redirects are not followed here.

    The response is stamped with the elapsed time and the request, its
    cookies are extracted into the client's cookie store, and a debug line
    is logged. An HTTPError is re-raised, with `request` attached when the
    error carries none.
    """
    dispatcher = self._dispatcher_for_request(request, self.proxies)
    send_options = dict(verify=verify, cert=cert, timeout=timeout)

    try:
        with ElapsedTimer() as timer:
            response = await dispatcher.send(request, **send_options)
        response.elapsed = timer.elapsed
        response.request = request
    except HTTPError as exc:
        # Add the original request to any HTTPError unless
        # there's already a request attached in the case of
        # a ProxyError.
        if exc.request is None:
            exc.request = request
        raise

    self.cookies.extract_cookies(response)

    response_line = (
        f"{response.http_version} {response.status_code} {response.reason_phrase}"
    )
    logger.debug(f'HTTP Request: {request.method} {request.url} "{response_line}"')

    return response
def _get_netrc(self) -> typing.Optional[netrc.netrc]:
    """
    Return the result of `get_netrc()`, calling it at most once per instance.

    The value is memoised on the instance instead of with
    `functools.lru_cache`: a cache on a method is keyed on `self`, so it
    keeps a strong reference to the instance, and with a size of 1 two
    instances used alternately evict each other on every call.
    """
    try:
        return self._cached_netrc
    except AttributeError:
        # A `None` result is stored too, so the lookup is not repeated.
        self._cached_netrc = get_netrc()
        return self._cached_netrc
def _dispatcher_for_request(
    self, request: Request, proxies: typing.Dict[str, Dispatcher]
) -> Dispatcher:
    """
    Select the Dispatcher instance that should handle `request`.

    With no proxies configured, or no proxy key matching, the client's own
    dispatcher is used. Keys are checked from most to least specific, and
    the host-only forms apply only on the scheme's default port.
    """
    if not proxies:
        return self.dispatch

    url = request.url
    scheme, host, port = url.scheme, url.host, url.port
    default_port = (scheme == "http" and port == 80) or (
        scheme == "https" and port == 443
    )
    netloc = f"{host}:{port}"

    ordered_keys = (
        f"{scheme}://{netloc}",
        f"{scheme}://{host}" if default_port else None,
        f"all://{netloc}",
        f"all://{host}" if default_port else None,
        scheme,
        "all",
    )
    usable_keys = (key for key in ordered_keys if key and key in proxies)
    first_match = next(usable_keys, None)
    if first_match is None:
        return self.dispatch
    return proxies[first_match]
async def get(
self,
url: URLTypes,
@ -624,70 +820,6 @@ class Client:
trust_env=trust_env,
)
async def request(
    self,
    method: str,
    url: URLTypes,
    *,
    data: RequestData = None,
    files: RequestFiles = None,
    json: typing.Any = None,
    params: QueryParamTypes = None,
    headers: HeaderTypes = None,
    cookies: CookieTypes = None,
    stream: bool = False,
    auth: AuthTypes = None,
    allow_redirects: bool = True,
    cert: CertTypes = None,
    verify: VerifyTypes = None,
    timeout: TimeoutTypes = None,
    trust_env: bool = None,
) -> Response:
    """
    Build a request from the given arguments and send it.

    The content-related arguments go to `build_request`; the remaining
    options are forwarded to `send`, whose response is returned.
    """
    built = self.build_request(
        method=method,
        url=url,
        data=data,
        files=files,
        json=json,
        params=params,
        headers=headers,
        cookies=cookies,
    )
    return await self.send(
        built,
        stream=stream,
        auth=auth,
        allow_redirects=allow_redirects,
        verify=verify,
        cert=cert,
        timeout=timeout,
        trust_env=trust_env,
    )
async def send(
    self,
    request: Request,
    *,
    stream: bool = False,
    auth: AuthTypes = None,
    allow_redirects: bool = True,
    verify: VerifyTypes = None,
    cert: CertTypes = None,
    timeout: TimeoutTypes = None,
    trust_env: bool = None,
) -> Response:
    """
    Send an already-built `request` by delegating to `_get_response` with
    every option passed through unchanged.
    """
    options = dict(
        request=request,
        stream=stream,
        auth=auth,
        allow_redirects=allow_redirects,
        verify=verify,
        cert=cert,
        timeout=timeout,
        trust_env=trust_env,
    )
    response = await self._get_response(**options)
    return response
async def close(self) -> None:
    """Close the client's dispatcher."""
    dispatcher = self.dispatch
    await dispatcher.close()

View File

@ -1,4 +1,5 @@
import enum
from base64 import b64encode
import h11
@ -14,7 +15,6 @@ from ..config import (
VerifyTypes,
)
from ..exceptions import ProxyError
from ..middleware.basic_auth import build_basic_auth_header
from ..models import URL, Headers, HeaderTypes, Origin, Request, Response, URLTypes
from ..utils import get_logger
from .connection import HTTPConnection
@ -69,13 +69,18 @@ class HTTPProxy(ConnectionPool):
if url.username or url.password:
self.proxy_headers.setdefault(
"Proxy-Authorization",
build_basic_auth_header(url.username, url.password),
self.build_auth_header(url.username, url.password),
)
# Remove userinfo from the URL authority, e.g.:
# 'username:password@proxy_host:proxy_port' -> 'proxy_host:proxy_port'
credentials, _, authority = url.authority.rpartition("@")
self.proxy_url = url.copy_with(authority=authority)
def build_auth_header(self, username: str, password: str) -> str:
    """
    Return a Basic auth header value for the given credentials.

    Both parts are UTF-8 encoded, joined with ':' and base64-encoded.
    """
    credentials = username.encode("utf-8") + b":" + password.encode("utf-8")
    encoded = b64encode(credentials).decode().strip()
    return "Basic " + encoded
async def acquire_connection(self, origin: Origin) -> HTTPConnection:
if self.should_forward_origin(origin):
logger.trace(

View File

@ -1,9 +1,9 @@
import typing
from ..models import Request, Response
from .models import Request, Response
class BaseMiddleware:
class Middleware:
async def __call__(
self, request: Request, get_response: typing.Callable
) -> Response:

View File

@ -1,27 +0,0 @@
import typing
from base64 import b64encode
from ..models import Request, Response
from ..utils import to_bytes
from .base import BaseMiddleware
class BasicAuthMiddleware(BaseMiddleware):
    """
    Middleware that sets a precomputed Basic 'Authorization' header on the
    request before passing it on to `get_response`.
    """

    def __init__(
        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
    ):
        header_value = build_basic_auth_header(username, password)
        self.authorization_header = header_value

    async def __call__(
        self, request: Request, get_response: typing.Callable
    ) -> Response:
        """Add the header to `request`, then await the downstream response."""
        request.headers["Authorization"] = self.authorization_header
        response = await get_response(request)
        return response
def build_basic_auth_header(
    username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> str:
    """Return 'Basic <token>', the token being base64 of 'username:password'."""
    username_bytes = to_bytes(username)
    password_bytes = to_bytes(password)
    credentials = b":".join((username_bytes, password_bytes))
    encoded = b64encode(credentials).decode().strip()
    return f"Basic {encoded}"

View File

@ -1,15 +0,0 @@
import typing
from ..models import Request, Response
from .base import BaseMiddleware
class CustomAuthMiddleware(BaseMiddleware):
    """
    Middleware adapter around a user-supplied `(Request) -> Request`
    callable: the callable's result is what gets sent downstream.
    """

    def __init__(self, auth: typing.Callable[[Request], Request]):
        self.auth = auth

    async def __call__(
        self, request: Request, get_response: typing.Callable
    ) -> Response:
        """Run the auth callable on `request`, then await the response."""
        authenticated = self.auth(request)
        return await get_response(authenticated)

View File

@ -1,130 +0,0 @@
import functools
import typing
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from ..models import URL, Cookies, Headers, Request, Response
from ..status_codes import codes
from .base import BaseMiddleware
class RedirectMiddleware(BaseMiddleware):
    """
    Middleware that follows redirect responses.

    Every redirect response seen is appended to `self.history`. When
    `allow_redirects` is false the first redirect response is returned with
    `call_next` set, so the caller can continue the chain by hand.
    """

    def __init__(
        self,
        allow_redirects: bool = True,
        max_redirects: int = DEFAULT_MAX_REDIRECTS,
        cookies: typing.Optional[Cookies] = None,
    ):
        self.allow_redirects = allow_redirects
        self.max_redirects = max_redirects
        self.cookies = cookies
        self.history: typing.List[Response] = []

    async def __call__(
        self, request: Request, get_response: typing.Callable
    ) -> Response:
        """
        Send `request` through `get_response`, recursing on redirects.

        Raises TooManyRedirects when the history is longer than
        `max_redirects`, and RedirectLoop when the request URL equals the
        URL of a response already in the history.
        """
        if len(self.history) > self.max_redirects:
            raise TooManyRedirects()
        if request.url in (previous.url for previous in self.history):
            raise RedirectLoop()

        response = await get_response(request)
        # Snapshot the history, taken before this response is recorded.
        response.history = list(self.history)
        if not response.is_redirect:
            return response

        self.history.append(response)
        next_request = self.build_redirect_request(request, response)

        if not self.allow_redirects:
            response.call_next = functools.partial(self, next_request, get_response)
            return response
        return await self(next_request, get_response)

    def build_redirect_request(self, request: Request, response: Response) -> Request:
        """
        Construct the follow-up request for a redirect `response`, using
        the `redirect_*` helpers and a copy of the configured cookies.
        """
        next_method = self.redirect_method(request, response)
        next_url = self.redirect_url(request, response)
        # TODO: merge headers?
        next_headers = self.redirect_headers(request, next_url, next_method)
        next_content = self.redirect_content(request, next_method)
        return Request(
            method=next_method,
            url=next_url,
            headers=next_headers,
            data=next_content,
            cookies=Cookies(self.cookies),
        )

    def redirect_method(self, request: Request, response: Response) -> str:
        """
        Return the HTTP method for the request that follows a redirect,
        switching to GET where specs or browser behaviour call for it.
        """
        original = request.method
        status = response.status_code

        # 303 (https://tools.ietf.org/html/rfc7231#section-6.4.4) and, as
        # the browsers do despite the standards, 302: all but HEAD become GET.
        if status in (codes.SEE_OTHER, codes.FOUND):
            return original if original == "HEAD" else "GET"

        # A POST answered with a 301 is turned into a GET.
        # This bizarre behaviour is explained in 'requests' issue 1704.
        if status == codes.MOVED_PERMANENTLY and original == "POST":
            return "GET"

        return original

    def redirect_url(self, request: Request, response: Response) -> URL:
        """
        Work out the URL that a redirect `response` points to.
        """
        target = URL(response.headers["Location"], allow_relative=True)

        # Relative 'Location' headers are allowed by RFC 7231
        # (e.g. '/path/to/resource' instead of
        # 'http://domain.tld/path/to/resource').
        if target.is_relative_url:
            target = request.url.join(target)

        # Carry the previous fragment over if needed (RFC 7231 7.1.2).
        needs_fragment = request.url.fragment and not target.fragment
        if needs_fragment:
            target = target.copy_with(fragment=request.url.fragment)

        return target

    def redirect_headers(self, request: Request, url: URL, method: str) -> Headers:
        """
        Build the headers for the request that follows a redirect, starting
        from a copy of the original request headers.
        """
        redirected = Headers(request.headers)

        if url.origin != request.url.origin:
            # Never forward credentials to a different origin, and point
            # the 'Host' header at the new authority.
            redirected.pop("Authorization", None)
            redirected["Host"] = url.authority

        if method != request.method and method == "GET":
            # If we've switched to a 'GET' request, then strip any headers
            # which are only relevant to the request body.
            for body_header in ("Content-Length", "Transfer-Encoding"):
                redirected.pop(body_header, None)

        # The client cookie store decides the cookie header, not whatever
        # was on the original outgoing request.
        redirected.pop("Cookie", None)

        return redirected

    def redirect_content(self, request: Request, method: str) -> bytes:
        """
        Return the body for the request that follows a redirect.

        A switch to GET drops the body; a streaming request cannot be
        replayed and raises RedirectBodyUnavailable.
        """
        switched_to_get = method != request.method and method == "GET"
        if switched_to_get:
            return b""
        if not request.is_streaming:
            return request.content
        raise RedirectBodyUnavailable()