PYTHON-3814 add types to pyopenssl_context.py (#1341)
This commit is contained in:
parent
dc63c5d9b8
commit
94fd83e92e
@ -22,6 +22,7 @@ import sys as _sys
|
||||
import time as _time
|
||||
from errno import EINTR as _EINTR
|
||||
from ipaddress import ip_address as _ip_address
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, Union
|
||||
|
||||
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
|
||||
from OpenSSL import SSL as _SSL
|
||||
@ -39,6 +40,14 @@ from pymongo.socket_checker import SocketChecker as _SocketChecker
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import socket
|
||||
from ssl import VerifyMode
|
||||
|
||||
from cryptography.x509 import Certificate
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
try:
|
||||
import certifi
|
||||
|
||||
@ -73,7 +82,7 @@ _REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()}
|
||||
|
||||
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
|
||||
# not permitted for SNI hostname.
|
||||
def _is_ip_address(address):
|
||||
def _is_ip_address(address: Any) -> bool:
|
||||
try:
|
||||
_ip_address(address)
|
||||
return True
|
||||
@ -86,7 +95,7 @@ def _is_ip_address(address):
|
||||
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
|
||||
|
||||
|
||||
def _ragged_eof(exc):
|
||||
def _ragged_eof(exc: BaseException) -> bool:
|
||||
"""Return True if the OpenSSL.SSL.SysCallError is a ragged EOF."""
|
||||
return exc.args == (-1, "Unexpected EOF")
|
||||
|
||||
@ -95,12 +104,14 @@ def _ragged_eof(exc):
|
||||
# https://github.com/pyca/pyopenssl/issues/176
|
||||
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
|
||||
class _sslConn(_SSL.Connection):
|
||||
def __init__(self, ctx, sock, suppress_ragged_eofs):
|
||||
def __init__(
|
||||
self, ctx: _SSL.Context, sock: Optional[socket.socket], suppress_ragged_eofs: bool
|
||||
):
|
||||
self.socket_checker = _SocketChecker()
|
||||
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||
super().__init__(ctx, sock)
|
||||
|
||||
def _call(self, call, *args, **kwargs):
|
||||
def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
|
||||
timeout = self.gettimeout()
|
||||
if timeout:
|
||||
start = _time.monotonic()
|
||||
@ -127,10 +138,10 @@ class _sslConn(_SSL.Connection):
|
||||
raise _socket.timeout("timed out")
|
||||
continue
|
||||
|
||||
def do_handshake(self, *args, **kwargs):
|
||||
def do_handshake(self, *args: Any, **kwargs: Any) -> None:
|
||||
return self._call(super().do_handshake, *args, **kwargs)
|
||||
|
||||
def recv(self, *args, **kwargs):
|
||||
def recv(self, *args: Any, **kwargs: Any) -> bytes:
|
||||
try:
|
||||
return self._call(super().recv, *args, **kwargs)
|
||||
except _SSL.SysCallError as exc:
|
||||
@ -139,7 +150,7 @@ class _sslConn(_SSL.Connection):
|
||||
return b""
|
||||
raise
|
||||
|
||||
def recv_into(self, *args, **kwargs):
|
||||
def recv_into(self, *args: Any, **kwargs: Any) -> int:
|
||||
try:
|
||||
return self._call(super().recv_into, *args, **kwargs)
|
||||
except _SSL.SysCallError as exc:
|
||||
@ -148,7 +159,7 @@ class _sslConn(_SSL.Connection):
|
||||
return 0
|
||||
raise
|
||||
|
||||
def sendall(self, buf, flags=0):
|
||||
def sendall(self, buf: bytes, flags: int = 0) -> None: # type: ignore[override]
|
||||
view = memoryview(buf)
|
||||
total_length = len(buf)
|
||||
total_sent = 0
|
||||
@ -172,9 +183,9 @@ class _sslConn(_SSL.Connection):
|
||||
class _CallbackData:
|
||||
"""Data class which is passed to the OCSP callback."""
|
||||
|
||||
def __init__(self):
|
||||
self.trusted_ca_certs = None
|
||||
self.check_ocsp_endpoint = None
|
||||
def __init__(self) -> None:
|
||||
self.trusted_ca_certs: Optional[List[Certificate]] = None
|
||||
self.check_ocsp_endpoint: Optional[bool] = None
|
||||
self.ocsp_response_cache = _OCSPCache()
|
||||
|
||||
|
||||
@ -185,7 +196,7 @@ class SSLContext:
|
||||
|
||||
__slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")
|
||||
|
||||
def __init__(self, protocol):
|
||||
def __init__(self, protocol: int):
|
||||
self._protocol = protocol
|
||||
self._ctx = _SSL.Context(self._protocol)
|
||||
self._callback_data = _CallbackData()
|
||||
@ -198,58 +209,67 @@ class SSLContext:
|
||||
self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)
|
||||
|
||||
@property
|
||||
def protocol(self):
|
||||
def protocol(self) -> int:
|
||||
"""The protocol version chosen when constructing the context.
|
||||
This attribute is read-only.
|
||||
"""
|
||||
return self._protocol
|
||||
|
||||
def __get_verify_mode(self):
|
||||
def __get_verify_mode(self) -> VerifyMode:
|
||||
"""Whether to try to verify other peers' certificates and how to
|
||||
behave if verification fails. This attribute must be one of
|
||||
ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
|
||||
"""
|
||||
return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]
|
||||
|
||||
def __set_verify_mode(self, value):
|
||||
def __set_verify_mode(self, value: VerifyMode) -> None:
|
||||
"""Setter for verify_mode."""
|
||||
|
||||
def _cb(connobj, x509obj, errnum, errdepth, retcode):
|
||||
def _cb(
|
||||
connobj: _SSL.Connection,
|
||||
x509obj: _crypto.X509,
|
||||
errnum: int,
|
||||
errdepth: int,
|
||||
retcode: int,
|
||||
) -> bool:
|
||||
# It seems we don't need to do anything here. Twisted doesn't,
|
||||
# and OpenSSL's SSL_CTX_set_verify let's you pass NULL
|
||||
# for the callback option. It's weird that PyOpenSSL requires
|
||||
# this.
|
||||
return retcode
|
||||
# This is optional in pyopenssl >= 20 and can be removed once minimum
|
||||
# supported version is bumped
|
||||
# See: pyopenssl.org/en/latest/changelog.html#id47
|
||||
return bool(retcode)
|
||||
|
||||
self._ctx.set_verify(_VERIFY_MAP[value], _cb)
|
||||
|
||||
verify_mode = property(__get_verify_mode, __set_verify_mode)
|
||||
|
||||
def __get_check_hostname(self):
|
||||
def __get_check_hostname(self) -> bool:
|
||||
return self._check_hostname
|
||||
|
||||
def __set_check_hostname(self, value):
|
||||
def __set_check_hostname(self, value: Any) -> None:
|
||||
validate_boolean("check_hostname", value)
|
||||
self._check_hostname = value
|
||||
|
||||
check_hostname = property(__get_check_hostname, __set_check_hostname)
|
||||
|
||||
def __get_check_ocsp_endpoint(self):
|
||||
def __get_check_ocsp_endpoint(self) -> Optional[bool]:
|
||||
return self._callback_data.check_ocsp_endpoint
|
||||
|
||||
def __set_check_ocsp_endpoint(self, value):
|
||||
def __set_check_ocsp_endpoint(self, value: bool) -> None:
|
||||
validate_boolean("check_ocsp", value)
|
||||
self._callback_data.check_ocsp_endpoint = value
|
||||
|
||||
check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)
|
||||
|
||||
def __get_options(self):
|
||||
def __get_options(self) -> None:
|
||||
# Calling set_options adds the option to the existing bitmask and
|
||||
# returns the new bitmask.
|
||||
# https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
|
||||
return self._ctx.set_options(0)
|
||||
|
||||
def __set_options(self, value):
|
||||
def __set_options(self, value: int) -> None:
|
||||
# Explcitly convert to int, since newer CPython versions
|
||||
# use enum.IntFlag for options. The values are the same
|
||||
# regardless of implementation.
|
||||
@ -257,7 +277,12 @@ class SSLContext:
|
||||
|
||||
options = property(__get_options, __set_options)
|
||||
|
||||
def load_cert_chain(self, certfile, keyfile=None, password=None):
|
||||
def load_cert_chain(
|
||||
self,
|
||||
certfile: Union[str, bytes],
|
||||
keyfile: Union[str, bytes, None] = None,
|
||||
password: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Load a private key and the corresponding certificate. The certfile
|
||||
string must be the path to a single file in PEM format containing the
|
||||
certificate as well as any number of CA certificates needed to
|
||||
@ -270,10 +295,11 @@ class SSLContext:
|
||||
# Password callback MUST be set first or it will be ignored.
|
||||
if password:
|
||||
|
||||
def _pwcb(max_length, prompt_twice, user_data):
|
||||
def _pwcb(max_length: int, prompt_twice: bool, user_data: bytes) -> bytes:
|
||||
# XXX:We could check the password length against what OpenSSL
|
||||
# tells us is the max, but we can't raise an exception, so...
|
||||
# warn?
|
||||
assert password is not None
|
||||
return password.encode("utf-8")
|
||||
|
||||
self._ctx.set_passwd_cb(_pwcb)
|
||||
@ -281,7 +307,9 @@ class SSLContext:
|
||||
self._ctx.use_privatekey_file(keyfile or certfile)
|
||||
self._ctx.check_privatekey()
|
||||
|
||||
def load_verify_locations(self, cafile=None, capath=None):
|
||||
def load_verify_locations(
|
||||
self, cafile: Optional[str] = None, capath: Optional[str] = None
|
||||
) -> None:
|
||||
"""Load a set of "certification authority"(CA) certificates used to
|
||||
validate other peers' certificates when `~verify_mode` is other than
|
||||
ssl.CERT_NONE.
|
||||
@ -289,9 +317,10 @@ class SSLContext:
|
||||
self._ctx.load_verify_locations(cafile, capath)
|
||||
# Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
|
||||
if not hasattr(_SSL.Connection, "get_verified_chain"):
|
||||
assert cafile is not None
|
||||
self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)
|
||||
|
||||
def _load_certifi(self):
|
||||
def _load_certifi(self) -> None:
|
||||
"""Attempt to load CA certs from certifi."""
|
||||
if _HAVE_CERTIFI:
|
||||
self.load_verify_locations(certifi.where())
|
||||
@ -303,7 +332,7 @@ class SSLContext:
|
||||
"the tlsCAFile option"
|
||||
)
|
||||
|
||||
def _load_wincerts(self, store):
|
||||
def _load_wincerts(self, store: str) -> None:
|
||||
"""Attempt to load CA certs from Windows trust store."""
|
||||
cert_store = self._ctx.get_cert_store()
|
||||
oid = _stdlibssl.Purpose.SERVER_AUTH.oid
|
||||
@ -314,7 +343,7 @@ class SSLContext:
|
||||
_crypto.X509.from_cryptography(_load_der_x509_certificate(cert))
|
||||
)
|
||||
|
||||
def load_default_certs(self):
|
||||
def load_default_certs(self) -> None:
|
||||
"""A PyOpenSSL version of load_default_certs from CPython."""
|
||||
# PyOpenSSL is incapable of loading CA certs from Windows, and mostly
|
||||
# incapable on macOS.
|
||||
@ -330,7 +359,7 @@ class SSLContext:
|
||||
self._load_certifi()
|
||||
self._ctx.set_default_verify_paths()
|
||||
|
||||
def set_default_verify_paths(self):
|
||||
def set_default_verify_paths(self) -> None:
|
||||
"""Specify that the platform provided CA certificates are to be used
|
||||
for verification purposes.
|
||||
"""
|
||||
@ -340,13 +369,13 @@ class SSLContext:
|
||||
|
||||
def wrap_socket(
|
||||
self,
|
||||
sock,
|
||||
server_side=False,
|
||||
do_handshake_on_connect=True,
|
||||
suppress_ragged_eofs=True,
|
||||
server_hostname=None,
|
||||
session=None,
|
||||
):
|
||||
sock: socket.socket,
|
||||
server_side: bool = False,
|
||||
do_handshake_on_connect: bool = True,
|
||||
suppress_ragged_eofs: bool = True,
|
||||
server_hostname: Optional[str] = None,
|
||||
session: Optional[_SSL.Session] = None,
|
||||
) -> _sslConn:
|
||||
"""Wrap an existing Python socket connection and return a TLS socket
|
||||
object.
|
||||
"""
|
||||
|
||||
@ -42,7 +42,7 @@ def check_ocsp(host, port, capath):
|
||||
s = socket.socket()
|
||||
s.connect((host, port))
|
||||
try:
|
||||
s = ctx.wrap_socket(s, server_hostname=host)
|
||||
s = ctx.wrap_socket(s, server_hostname=host) # type: ignore[assignment]
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user