Drop per-request cert, verify, and trust_env (#617)

* Drop per-request cert/verify/trust_env

* Remove cert/verify from the dispatcher API

* Apply lint

* Reintroduce cert/verify/trust_env on client methods, with errors
This commit is contained in:
Florimond Manca 2019-12-29 16:01:20 +01:00 committed by Tom Christie
parent f3b799912e
commit e9ebd1df98
7 changed files with 36 additions and 126 deletions

View File

@ -223,23 +223,26 @@ class Client:
trust_env: bool = None,
) -> Response:
if cert is not None:
warnings.warn(
raise RuntimeError(
"Passing a 'cert' argument when making a request on a client "
"is due to be deprecated. Instantiate a new client instead, "
"is not supported anymore. Instantiate a new client instead, "
"passing any 'cert' arguments to the client itself."
)
if verify is not None:
warnings.warn(
raise RuntimeError(
"Passing a 'verify' argument when making a request on a client "
"is due to be deprecated. Instantiate a new client instead, "
"is not supported anymore. Instantiate a new client instead, "
"passing any 'verify' arguments to the client itself."
)
if trust_env is not None:
warnings.warn(
raise RuntimeError(
"Passing a 'trust_env' argument when making a request on a client "
"is due to be deprecated. Instantiate a new client instead, "
"is not supported anymore. Instantiate a new client instead, "
"passing any 'trust_env' argument to the client itself."
)
if stream:
warnings.warn(
"The 'stream=True' argument is due to be deprecated. "
@ -261,10 +264,7 @@ class Client:
stream=stream,
auth=auth,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
trust_env=trust_env,
)
return response
@ -388,25 +388,17 @@ class Client:
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
trust_env: bool = None,
) -> Response:
if request.url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')
timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
auth = self.setup_auth(request, trust_env, auth)
auth = self.setup_auth(request, auth)
response = await self.send_handling_redirects(
request,
auth=auth,
verify=verify,
cert=cert,
timeout=timeout,
allow_redirects=allow_redirects,
request, auth=auth, timeout=timeout, allow_redirects=allow_redirects,
)
if not stream:
@ -417,11 +409,8 @@ class Client:
return response
def setup_auth(
self, request: Request, trust_env: bool = None, auth: AuthTypes = None
) -> Auth:
def setup_auth(self, request: Request, auth: AuthTypes = None) -> Auth:
auth = self.auth if auth is None else auth
trust_env = self.trust_env if trust_env is None else trust_env
if auth is not None:
if isinstance(auth, tuple):
@ -436,7 +425,7 @@ class Client:
if username or password:
return BasicAuth(username=username, password=password)
if trust_env and "Authorization" not in request.headers:
if self.trust_env and "Authorization" not in request.headers:
credentials = self.netrc.get_credentials(request.url.authority)
if credentials is not None:
return BasicAuth(username=credentials[0], password=credentials[1])
@ -448,8 +437,6 @@ class Client:
request: Request,
auth: Auth,
timeout: Timeout,
verify: VerifyTypes = None,
cert: CertTypes = None,
allow_redirects: bool = True,
history: typing.List[Response] = None,
) -> Response:
@ -463,7 +450,7 @@ class Client:
raise RedirectLoop()
response = await self.send_handling_auth(
request, auth=auth, timeout=timeout, verify=verify, cert=cert
request, auth=auth, timeout=timeout,
)
response.history = list(history)
@ -479,8 +466,6 @@ class Client:
self.send_handling_redirects,
request=request,
auth=auth,
verify=verify,
cert=cert,
timeout=timeout,
allow_redirects=False,
history=history,
@ -580,17 +565,12 @@ class Client:
return request.stream
async def send_handling_auth(
self,
request: Request,
auth: Auth,
timeout: Timeout,
verify: VerifyTypes = None,
cert: CertTypes = None,
self, request: Request, auth: Auth, timeout: Timeout,
) -> Response:
auth_flow = auth(request)
request = next(auth_flow)
while True:
response = await self.send_single_request(request, timeout, verify, cert)
response = await self.send_single_request(request, timeout)
try:
next_request = auth_flow.send(response)
except StopIteration:
@ -603,11 +583,7 @@ class Client:
await response.close()
async def send_single_request(
self,
request: Request,
timeout: Timeout,
verify: VerifyTypes = None,
cert: CertTypes = None,
self, request: Request, timeout: Timeout,
) -> Response:
"""
Sends a single request, without handling any redirections.
@ -617,9 +593,7 @@ class Client:
try:
with ElapsedTimer() as timer:
response = await dispatcher.send(
request, verify=verify, cert=cert, timeout=timeout
)
response = await dispatcher.send(request, timeout=timeout)
response.elapsed = timer.elapsed
response.request = request
except HTTPError as exc:

View File

@ -1,6 +1,6 @@
import typing
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..config import TimeoutTypes
from ..content_streams import ByteStream
from ..models import Request, Response
from .base import Dispatcher
@ -54,14 +54,7 @@ class ASGIDispatch(Dispatcher):
self.root_path = root_path
self.client = client
async def send(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
async def send(self, request: Request, timeout: TimeoutTypes = None) -> Response:
scope = {
"type": "http",
"asgi": {"version": "3.0"},

View File

@ -1,7 +1,7 @@
import typing
from types import TracebackType
from ..config import CertTypes, Timeout, VerifyTypes
from ..config import Timeout
from ..models import (
HeaderTypes,
QueryParamTypes,
@ -29,20 +29,12 @@ class Dispatcher:
data: RequestData = b"",
params: QueryParamTypes = None,
headers: HeaderTypes = None,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: Timeout = None,
) -> Response:
request = Request(method, url, data=data, params=params, headers=headers)
return await self.send(request, verify=verify, cert=cert, timeout=timeout)
return await self.send(request, timeout=timeout)
async def send(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: Timeout = None,
) -> Response:
async def send(self, request: Request, timeout: Timeout = None) -> Response:
raise NotImplementedError() # pragma: nocover
async def close(self) -> None:

View File

@ -40,31 +40,21 @@ class HTTPConnection(Dispatcher):
self.open_connection: typing.Optional[OpenConnection] = None
self.expires_at: typing.Optional[float] = None
async def send(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: Timeout = None,
) -> Response:
async def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
if self.open_connection is None:
await self.connect(verify=verify, cert=cert, timeout=timeout)
await self.connect(timeout=timeout)
assert self.open_connection is not None
response = await self.open_connection.send(request, timeout=timeout)
return response
async def connect(
self, timeout: Timeout, verify: VerifyTypes = None, cert: CertTypes = None,
) -> None:
ssl = self.ssl.with_overrides(verify=verify, cert=cert)
async def connect(self, timeout: Timeout) -> None:
host = self.origin.host
port = self.origin.port
ssl_context = await self.get_ssl_context(ssl)
ssl_context = await self.get_ssl_context(self.ssl)
if self.release_func is None:
on_release = None
@ -92,12 +82,7 @@ class HTTPConnection(Dispatcher):
self.set_open_connection(http_version, socket=stream, on_release=on_release)
async def tunnel_start_tls(
self,
origin: Origin,
proxy_url: URL,
timeout: Timeout = None,
cert: CertTypes = None,
verify: VerifyTypes = True,
self, origin: Origin, proxy_url: URL, timeout: Timeout = None,
) -> None:
"""
Upgrade this connection to use TLS, assuming it represents a TCP tunnel.
@ -122,8 +107,7 @@ class HTTPConnection(Dispatcher):
if origin.is_ssl:
# Pull the socket stream off the internal HTTP connection object,
# and run start_tls().
ssl_config = SSLConfig(cert=cert, verify=verify)
ssl_context = await self.get_ssl_context(ssl_config)
ssl_context = await self.get_ssl_context(self.ssl)
assert ssl_context is not None
logger.trace(f"tunnel_start_tls proxy_url={proxy_url!r} origin={origin!r}")

View File

@ -140,21 +140,13 @@ class ConnectionPool(Dispatcher):
self.max_connections.release()
await connection.close()
async def send(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: Timeout = None,
) -> Response:
async def send(self, request: Request, timeout: Timeout = None) -> Response:
await self.check_keepalive_expiry()
connection = await self.acquire_connection(
origin=request.url.origin, timeout=timeout
)
try:
response = await connection.send(
request, verify=verify, cert=cert, timeout=timeout
)
response = await connection.send(request, timeout=timeout)
except BaseException as exc:
self.active_connections.remove(connection)
self.max_connections.release()

View File

@ -117,11 +117,7 @@ class HTTPProxy(ConnectionPool):
self.active_connections.add(connection)
await connection.tunnel_start_tls(
origin=origin,
proxy_url=self.proxy_url,
timeout=timeout,
cert=self.cert,
verify=self.verify,
origin=origin, proxy_url=self.proxy_url, timeout=timeout,
)
else:
self.active_connections.add(connection)
@ -183,14 +179,7 @@ class HTTPProxy(ConnectionPool):
self.proxy_mode == DEFAULT_MODE and not origin.is_ssl
) or self.proxy_mode == FORWARD_ONLY
async def send(
self,
request: Request,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: Timeout = None,
) -> Response:
async def send(self, request: Request, timeout: Timeout = None) -> Response:
if self.should_forward_origin(request.url.origin):
# Change the request to have the target URL
# as its full_path and switch the proxy URL
@ -201,9 +190,7 @@ class HTTPProxy(ConnectionPool):
for name, value in self.proxy_headers.items():
request.headers.setdefault(name, value)
return await super().send(
request=request, verify=verify, cert=cert, timeout=timeout
)
return await super().send(request=request, timeout=timeout)
def __repr__(self) -> str:
return (

View File

@ -31,24 +31,12 @@ async def test_premature_close(server):
@pytest.mark.usefixtures("async_environment")
async def test_https_get_with_ssl_defaults(https_server, ca_cert_pem_file):
async def test_https_get_with_ssl(https_server, ca_cert_pem_file):
"""
An HTTPS request, with default SSL configuration set on the client.
An HTTPS request, with SSL configuration set on the client.
"""
async with HTTPConnection(origin=https_server.url, verify=ca_cert_pem_file) as conn:
response = await conn.request("GET", https_server.url)
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
@pytest.mark.usefixtures("async_environment")
async def test_https_get_with_sll_overrides(https_server, ca_cert_pem_file):
"""
An HTTPS request, with SSL configuration set on the request.
"""
async with HTTPConnection(origin=https_server.url) as conn:
response = await conn.request("GET", https_server.url, verify=ca_cert_pem_file)
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"