Drop per-request cert, verify, and trust_env (#617)
* Drop per-request cert/verify/trust_env * Remove cert/verify from the dispatcher API * Apply lint * Reintroduce cert/verify/trust_env on client methods, with errors
This commit is contained in:
parent
f3b799912e
commit
e9ebd1df98
@ -223,23 +223,26 @@ class Client:
|
||||
trust_env: bool = None,
|
||||
) -> Response:
|
||||
if cert is not None:
|
||||
warnings.warn(
|
||||
raise RuntimeError(
|
||||
"Passing a 'cert' argument when making a request on a client "
|
||||
"is due to be deprecated. Instantiate a new client instead, "
|
||||
"is not supported anymore. Instantiate a new client instead, "
|
||||
"passing any 'cert' arguments to the client itself."
|
||||
)
|
||||
|
||||
if verify is not None:
|
||||
warnings.warn(
|
||||
raise RuntimeError(
|
||||
"Passing a 'verify' argument when making a request on a client "
|
||||
"is due to be deprecated. Instantiate a new client instead, "
|
||||
"is not supported anymore. Instantiate a new client instead, "
|
||||
"passing any 'verify' arguments to the client itself."
|
||||
)
|
||||
|
||||
if trust_env is not None:
|
||||
warnings.warn(
|
||||
raise RuntimeError(
|
||||
"Passing a 'trust_env' argument when making a request on a client "
|
||||
"is due to be deprecated. Instantiate a new client instead, "
|
||||
"is not supported anymore. Instantiate a new client instead, "
|
||||
"passing any 'trust_env' argument to the client itself."
|
||||
)
|
||||
|
||||
if stream:
|
||||
warnings.warn(
|
||||
"The 'stream=True' argument is due to be deprecated. "
|
||||
@ -261,10 +264,7 @@ class Client:
|
||||
stream=stream,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
trust_env=trust_env,
|
||||
)
|
||||
return response
|
||||
|
||||
@ -388,25 +388,17 @@ class Client:
|
||||
stream: bool = False,
|
||||
auth: AuthTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
|
||||
trust_env: bool = None,
|
||||
) -> Response:
|
||||
if request.url.scheme not in ("http", "https"):
|
||||
raise InvalidURL('URL scheme must be "http" or "https".')
|
||||
|
||||
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
|
||||
|
||||
auth = self.setup_auth(request, trust_env, auth)
|
||||
auth = self.setup_auth(request, auth)
|
||||
|
||||
response = await self.send_handling_redirects(
|
||||
request,
|
||||
auth=auth,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
|
||||
)
|
||||
|
||||
if not stream:
|
||||
@ -417,11 +409,8 @@ class Client:
|
||||
|
||||
return response
|
||||
|
||||
def setup_auth(
|
||||
self, request: Request, trust_env: bool = None, auth: AuthTypes = None
|
||||
) -> Auth:
|
||||
def setup_auth(self, request: Request, auth: AuthTypes = None) -> Auth:
|
||||
auth = self.auth if auth is None else auth
|
||||
trust_env = self.trust_env if trust_env is None else trust_env
|
||||
|
||||
if auth is not None:
|
||||
if isinstance(auth, tuple):
|
||||
@ -436,7 +425,7 @@ class Client:
|
||||
if username or password:
|
||||
return BasicAuth(username=username, password=password)
|
||||
|
||||
if trust_env and "Authorization" not in request.headers:
|
||||
if self.trust_env and "Authorization" not in request.headers:
|
||||
credentials = self.netrc.get_credentials(request.url.authority)
|
||||
if credentials is not None:
|
||||
return BasicAuth(username=credentials[0], password=credentials[1])
|
||||
@ -448,8 +437,6 @@ class Client:
|
||||
request: Request,
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
allow_redirects: bool = True,
|
||||
history: typing.List[Response] = None,
|
||||
) -> Response:
|
||||
@ -463,7 +450,7 @@ class Client:
|
||||
raise RedirectLoop()
|
||||
|
||||
response = await self.send_handling_auth(
|
||||
request, auth=auth, timeout=timeout, verify=verify, cert=cert
|
||||
request, auth=auth, timeout=timeout,
|
||||
)
|
||||
response.history = list(history)
|
||||
|
||||
@ -479,8 +466,6 @@ class Client:
|
||||
self.send_handling_redirects,
|
||||
request=request,
|
||||
auth=auth,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
timeout=timeout,
|
||||
allow_redirects=False,
|
||||
history=history,
|
||||
@ -580,17 +565,12 @@ class Client:
|
||||
return request.stream
|
||||
|
||||
async def send_handling_auth(
|
||||
self,
|
||||
request: Request,
|
||||
auth: Auth,
|
||||
timeout: Timeout,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
self, request: Request, auth: Auth, timeout: Timeout,
|
||||
) -> Response:
|
||||
auth_flow = auth(request)
|
||||
request = next(auth_flow)
|
||||
while True:
|
||||
response = await self.send_single_request(request, timeout, verify, cert)
|
||||
response = await self.send_single_request(request, timeout)
|
||||
try:
|
||||
next_request = auth_flow.send(response)
|
||||
except StopIteration:
|
||||
@ -603,11 +583,7 @@ class Client:
|
||||
await response.close()
|
||||
|
||||
async def send_single_request(
|
||||
self,
|
||||
request: Request,
|
||||
timeout: Timeout,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
self, request: Request, timeout: Timeout,
|
||||
) -> Response:
|
||||
"""
|
||||
Sends a single request, without handling any redirections.
|
||||
@ -617,9 +593,7 @@ class Client:
|
||||
|
||||
try:
|
||||
with ElapsedTimer() as timer:
|
||||
response = await dispatcher.send(
|
||||
request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
response = await dispatcher.send(request, timeout=timeout)
|
||||
response.elapsed = timer.elapsed
|
||||
response.request = request
|
||||
except HTTPError as exc:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import typing
|
||||
|
||||
from ..config import CertTypes, TimeoutTypes, VerifyTypes
|
||||
from ..config import TimeoutTypes
|
||||
from ..content_streams import ByteStream
|
||||
from ..models import Request, Response
|
||||
from .base import Dispatcher
|
||||
@ -54,14 +54,7 @@ class ASGIDispatch(Dispatcher):
|
||||
self.root_path = root_path
|
||||
self.client = client
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: TimeoutTypes = None,
|
||||
) -> Response:
|
||||
|
||||
async def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from ..config import CertTypes, Timeout, VerifyTypes
|
||||
from ..config import Timeout
|
||||
from ..models import (
|
||||
HeaderTypes,
|
||||
QueryParamTypes,
|
||||
@ -29,20 +29,12 @@ class Dispatcher:
|
||||
data: RequestData = b"",
|
||||
params: QueryParamTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: Timeout = None,
|
||||
) -> Response:
|
||||
request = Request(method, url, data=data, params=params, headers=headers)
|
||||
return await self.send(request, verify=verify, cert=cert, timeout=timeout)
|
||||
return await self.send(request, timeout=timeout)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: Timeout = None,
|
||||
) -> Response:
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
async def close(self) -> None:
|
||||
|
||||
@ -40,31 +40,21 @@ class HTTPConnection(Dispatcher):
|
||||
self.open_connection: typing.Optional[OpenConnection] = None
|
||||
self.expires_at: typing.Optional[float] = None
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: Timeout = None,
|
||||
) -> Response:
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
timeout = Timeout() if timeout is None else timeout
|
||||
|
||||
if self.open_connection is None:
|
||||
await self.connect(verify=verify, cert=cert, timeout=timeout)
|
||||
await self.connect(timeout=timeout)
|
||||
|
||||
assert self.open_connection is not None
|
||||
response = await self.open_connection.send(request, timeout=timeout)
|
||||
|
||||
return response
|
||||
|
||||
async def connect(
|
||||
self, timeout: Timeout, verify: VerifyTypes = None, cert: CertTypes = None,
|
||||
) -> None:
|
||||
ssl = self.ssl.with_overrides(verify=verify, cert=cert)
|
||||
|
||||
async def connect(self, timeout: Timeout) -> None:
|
||||
host = self.origin.host
|
||||
port = self.origin.port
|
||||
ssl_context = await self.get_ssl_context(ssl)
|
||||
ssl_context = await self.get_ssl_context(self.ssl)
|
||||
|
||||
if self.release_func is None:
|
||||
on_release = None
|
||||
@ -92,12 +82,7 @@ class HTTPConnection(Dispatcher):
|
||||
self.set_open_connection(http_version, socket=stream, on_release=on_release)
|
||||
|
||||
async def tunnel_start_tls(
|
||||
self,
|
||||
origin: Origin,
|
||||
proxy_url: URL,
|
||||
timeout: Timeout = None,
|
||||
cert: CertTypes = None,
|
||||
verify: VerifyTypes = True,
|
||||
self, origin: Origin, proxy_url: URL, timeout: Timeout = None,
|
||||
) -> None:
|
||||
"""
|
||||
Upgrade this connection to use TLS, assuming it represents a TCP tunnel.
|
||||
@ -122,8 +107,7 @@ class HTTPConnection(Dispatcher):
|
||||
if origin.is_ssl:
|
||||
# Pull the socket stream off the internal HTTP connection object,
|
||||
# and run start_tls().
|
||||
ssl_config = SSLConfig(cert=cert, verify=verify)
|
||||
ssl_context = await self.get_ssl_context(ssl_config)
|
||||
ssl_context = await self.get_ssl_context(self.ssl)
|
||||
assert ssl_context is not None
|
||||
|
||||
logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")
|
||||
|
||||
@ -140,21 +140,13 @@ class ConnectionPool(Dispatcher):
|
||||
self.max_connections.release()
|
||||
await connection.close()
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: Timeout = None,
|
||||
) -> Response:
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
await self.check_keepalive_expiry()
|
||||
connection = await self.acquire_connection(
|
||||
origin=request.url.origin, timeout=timeout
|
||||
)
|
||||
try:
|
||||
response = await connection.send(
|
||||
request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
response = await connection.send(request, timeout=timeout)
|
||||
except BaseException as exc:
|
||||
self.active_connections.remove(connection)
|
||||
self.max_connections.release()
|
||||
|
||||
@ -117,11 +117,7 @@ class HTTPProxy(ConnectionPool):
|
||||
self.active_connections.add(connection)
|
||||
|
||||
await connection.tunnel_start_tls(
|
||||
origin=origin,
|
||||
proxy_url=self.proxy_url,
|
||||
timeout=timeout,
|
||||
cert=self.cert,
|
||||
verify=self.verify,
|
||||
origin=origin, proxy_url=self.proxy_url, timeout=timeout,
|
||||
)
|
||||
else:
|
||||
self.active_connections.add(connection)
|
||||
@ -183,14 +179,7 @@ class HTTPProxy(ConnectionPool):
|
||||
self.proxy_mode == DEFAULT_MODE and not origin.is_ssl
|
||||
) or self.proxy_mode == FORWARD_ONLY
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: Request,
|
||||
verify: VerifyTypes = None,
|
||||
cert: CertTypes = None,
|
||||
timeout: Timeout = None,
|
||||
) -> Response:
|
||||
|
||||
async def send(self, request: Request, timeout: Timeout = None) -> Response:
|
||||
if self.should_forward_origin(request.url.origin):
|
||||
# Change the request to have the target URL
|
||||
# as its full_path and switch the proxy URL
|
||||
@ -201,9 +190,7 @@ class HTTPProxy(ConnectionPool):
|
||||
for name, value in self.proxy_headers.items():
|
||||
request.headers.setdefault(name, value)
|
||||
|
||||
return await super().send(
|
||||
request=request, verify=verify, cert=cert, timeout=timeout
|
||||
)
|
||||
return await super().send(request=request, timeout=timeout)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
||||
@ -31,24 +31,12 @@ async def test_premature_close(server):
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_https_get_with_ssl_defaults(https_server, ca_cert_pem_file):
|
||||
async def test_https_get_with_ssl(https_server, ca_cert_pem_file):
|
||||
"""
|
||||
An HTTPS request, with default SSL configuration set on the client.
|
||||
An HTTPS request, with SSL configuration set on the client.
|
||||
"""
|
||||
async with HTTPConnection(origin=https_server.url, verify=ca_cert_pem_file) as conn:
|
||||
response = await conn.request("GET", https_server.url)
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("async_environment")
|
||||
async def test_https_get_with_sll_overrides(https_server, ca_cert_pem_file):
|
||||
"""
|
||||
An HTTPS request, with SSL configuration set on the request.
|
||||
"""
|
||||
async with HTTPConnection(origin=https_server.url) as conn:
|
||||
response = await conn.request("GET", https_server.url, verify=ca_cert_pem_file)
|
||||
await response.read()
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"Hello, world!"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user