Push SSL context loading into config.py
This commit is contained in:
parent
30a03f4faa
commit
25c652a438
@ -1,3 +1,4 @@
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
import certifi
|
||||
@ -32,6 +33,58 @@ class SSLConfig:
|
||||
class_name = self.__class__.__name__
|
||||
return f"{class_name}(cert={self.cert}, verify={self.verify})"
|
||||
|
||||
async def load_ssl_context(self) -> ssl.SSLContext:
|
||||
if not hasattr(self, "ssl_context"):
|
||||
if not self.verify:
|
||||
self.ssl_context = self.load_ssl_context_no_verify()
|
||||
else:
|
||||
# Run the SSL loading in a threadpool, since it makes disk accesses.
|
||||
loop = asyncio.get_event_loop()
|
||||
self.ssl_context = await loop.run_in_executor(
|
||||
None, self.load_ssl_context_verify
|
||||
)
|
||||
|
||||
return self.ssl_context
|
||||
|
||||
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for unverified connections.
|
||||
"""
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
context.set_default_verify_paths()
|
||||
return context
|
||||
|
||||
def load_ssl_context_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for verified connections.
|
||||
"""
|
||||
if isinstance(self.verify, bool):
|
||||
ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
|
||||
elif os.path.exists(self.verify):
|
||||
ca_bundle_path = self.verify
|
||||
else:
|
||||
raise IOError(
|
||||
"Could not find a suitable TLS CA certificate bundle, "
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
|
||||
context = ssl.create_default_context()
|
||||
if os.path.isfile(ca_bundle_path):
|
||||
context.load_verify_locations(cafile=ca_bundle_path)
|
||||
elif os.path.isdir(ca_bundle_path):
|
||||
context.load_verify_locations(capath=ca_bundle_path)
|
||||
|
||||
if self.cert is not None:
|
||||
if isinstance(self.cert, str):
|
||||
context.load_cert_chain(certfile=self.cert)
|
||||
else:
|
||||
context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class TimeoutConfig:
|
||||
"""
|
||||
|
||||
@ -99,7 +99,11 @@ class ConnectionPool:
|
||||
self.num_active_connections += 1
|
||||
|
||||
except (KeyError, IndexError):
|
||||
ssl_context = await self.get_ssl_context(url, ssl)
|
||||
if url.is_secure:
|
||||
ssl_context = await ssl.load_ssl_context()
|
||||
else:
|
||||
ssl_context = None
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._max_connections.acquire(), timeout.pool_timeout
|
||||
@ -134,63 +138,6 @@ class ConnectionPool:
|
||||
except KeyError:
|
||||
self._keepalive_connections[key] = [connection]
|
||||
|
||||
async def get_ssl_context(
|
||||
self, url: URL, config: SSLConfig
|
||||
) -> typing.Optional[ssl.SSLContext]:
|
||||
if not url.is_secure:
|
||||
return None
|
||||
|
||||
if not hasattr(self, "ssl_context"):
|
||||
if not config.verify:
|
||||
self.ssl_context = self.get_ssl_context_no_verify()
|
||||
else:
|
||||
# Run the SSL loading in a threadpool, since it makes disk accesses.
|
||||
loop = asyncio.get_event_loop()
|
||||
self.ssl_context = await loop.run_in_executor(
|
||||
None, self.get_ssl_context_verify
|
||||
)
|
||||
|
||||
return self.ssl_context
|
||||
|
||||
def get_ssl_context_no_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for unverified connections.
|
||||
"""
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
context.options |= ssl.OP_NO_SSLv2
|
||||
context.options |= ssl.OP_NO_SSLv3
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
context.set_default_verify_paths()
|
||||
return context
|
||||
|
||||
def get_ssl_context_verify(self, config: SSLConfig) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for verified connections.
|
||||
"""
|
||||
if isinstance(config.verify, bool):
|
||||
ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
|
||||
elif os.path.exists(config.verify):
|
||||
ca_bundle_path = config.verify
|
||||
else:
|
||||
raise IOError(
|
||||
"Could not find a suitable TLS CA certificate bundle, "
|
||||
"invalid path: {}".format(config.verify)
|
||||
)
|
||||
|
||||
context = ssl.create_default_context()
|
||||
if os.path.isfile(ca_bundle_path):
|
||||
context.load_verify_locations(cafile=ca_bundle_path)
|
||||
elif os.path.isdir(ca_bundle_path):
|
||||
context.load_verify_locations(capath=ca_bundle_path)
|
||||
|
||||
if config.cert is not None:
|
||||
if isinstance(config.cert, str):
|
||||
context.load_cert_chain(certfile=config.cert)
|
||||
else:
|
||||
context.load_cert_chain(certfile=config.cert[0], keyfile=config.cert[1])
|
||||
|
||||
return context
|
||||
|
||||
async def close(self) -> None:
|
||||
self.is_closed = True
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user