httpx/httpcore/pool.py
2019-04-06 13:18:39 +01:00

127 lines
3.9 KiB
Python

import asyncio
import os
import ssl
import typing
from types import TracebackType
from .config import (
DEFAULT_CA_BUNDLE_PATH,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
PoolLimits,
SSLConfig,
TimeoutConfig,
)
from .connections import Connection
from .datastructures import URL, Request, Response
class ConnectionPool:
    """
    Send HTTP requests over `Connection` instances.

    NOTE(review): despite the name, no pooling happens in this class yet —
    `acquire_connection` opens a brand-new `Connection` on every call, and
    `limits` is stored but never consulted here. Confirm before relying on
    connection reuse or limit enforcement.
    """

    def __init__(
        self,
        *,
        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
        limits: PoolLimits = DEFAULT_POOL_LIMITS,
    ):
        """
        :param ssl: TLS settings; `verify` and `cert` are read when building
            the SSL context. (This parameter shadows the `ssl` module inside
            `__init__` only.)
        :param timeout: Passed to every `Connection` this pool creates.
        :param limits: Stored on the instance; not used by this class yet.
        """
        self.ssl_config = ssl
        self.timeout = timeout
        self.limits = limits
        self.is_closed = False

    async def request(
        self,
        method: str,
        url: str,
        *,
        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
        stream: bool = False,
    ) -> Response:
        """
        Send a single request and return its response.

        :param method: HTTP method, passed through to `Request`.
        :param url: Target URL string; parsed with `URL`.
        :param headers: Header (name, value) pairs as bytes.
        :param body: Request body — bytes, or an async iterator of bytes.
        :param stream: Passed through to `Connection.send`.
        :return: Whatever `Connection.send` returns.
        """
        parsed_url = URL(url)
        request = Request(method, parsed_url, headers=headers, body=body)
        ssl_context = await self.get_ssl_context(parsed_url)
        connection = await self.acquire_connection(parsed_url, ssl=ssl_context)
        # NOTE(review): the connection is neither closed nor returned to a pool
        # here — confirm that `Connection`/`Response` own its lifecycle,
        # otherwise every request leaks an open connection.
        response = await connection.send(request, stream=stream)
        return response

    async def acquire_connection(
        self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False
    ) -> Connection:
        """
        Open and return a new connection to `url`'s host and port.

        :param url: Parsed URL; only `hostname` and `port` are used.
        :param ssl: `False` for plain connections, or the SSL context to use.
        """
        connection = Connection(timeout=self.timeout)
        await connection.open(url.hostname, url.port, ssl=ssl)
        return connection

    async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]:
        """
        Return the SSL context for `url`, or `False` if it is not secure.

        The context is built on first use and cached on `self.ssl_context`.
        The same context is then returned for every later secure URL.
        """
        if not url.is_secure:
            return False

        if not hasattr(self, "ssl_context"):
            if not self.ssl_config.verify:
                self.ssl_context = self.get_ssl_context_no_verify()
            else:
                # Run the SSL loading in a threadpool, since it makes disk accesses.
                loop = asyncio.get_event_loop()
                self.ssl_context = await loop.run_in_executor(
                    None, self.get_ssl_context_verify
                )

        return self.ssl_context

    def get_ssl_context_no_verify(self) -> ssl.SSLContext:
        """
        Return an SSL context for unverified connections.

        Both certificate verification and hostname checking are disabled.
        """
        # PROTOCOL_TLS_CLIENT replaces the deprecated PROTOCOL_SSLv23 alias,
        # which emits a DeprecationWarning on Python 3.10+. PROTOCOL_TLS_CLIENT
        # turns on hostname checking and CERT_REQUIRED by default, so both are
        # switched off explicitly. check_hostname must be cleared first:
        # verify_mode cannot be set to CERT_NONE while it is enabled.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        context.options |= ssl.OP_NO_SSLv2
        context.options |= ssl.OP_NO_SSLv3
        context.options |= ssl.OP_NO_COMPRESSION
        context.set_default_verify_paths()
        return context

    def get_ssl_context_verify(self) -> ssl.SSLContext:
        """
        Return an SSL context for verified connections.

        CA certificates are loaded from `DEFAULT_CA_BUNDLE_PATH` when
        `ssl_config.verify` is a bool. Otherwise they are loaded from the
        path it names, which may be a file or a directory. If
        `ssl_config.cert` is set, a client certificate is also loaded. It is
        either a single certfile path or a (certfile, keyfile) pair.

        :raises IOError: if `verify` is a path that does not exist.
        """
        cert = self.ssl_config.cert
        verify = self.ssl_config.verify

        if isinstance(verify, bool):
            ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
        elif os.path.exists(verify):
            ca_bundle_path = verify
        else:
            raise IOError(
                "Could not find a suitable TLS CA certificate bundle, "
                "invalid path: {}".format(verify)
            )

        context = ssl.create_default_context()
        # A bundle path that is neither a file nor a directory is skipped,
        # leaving only the defaults loaded by create_default_context().
        if os.path.isfile(ca_bundle_path):
            context.load_verify_locations(cafile=ca_bundle_path)
        elif os.path.isdir(ca_bundle_path):
            context.load_verify_locations(capath=ca_bundle_path)

        if cert is not None:
            if isinstance(cert, str):
                context.load_cert_chain(certfile=cert)
            else:
                context.load_cert_chain(certfile=cert[0], keyfile=cert[1])

        return context

    async def close(self) -> None:
        """Mark the pool as closed (only sets `is_closed`; no I/O happens)."""
        self.is_closed = True

    async def __aenter__(self) -> "ConnectionPool":
        return self

    async def __aexit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]] = None,
        exc_value: typing.Optional[BaseException] = None,
        traceback: typing.Optional[TracebackType] = None,
    ) -> None:
        await self.close()