From 94fd83e92ec55a70a41a650cd4002f74ace01e7f Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Thu, 3 Aug 2023 15:40:58 -0700 Subject: [PATCH] PYTHON-3814 add types to pyopenssl_context.py (#1341) --- pymongo/pyopenssl_context.py | 103 ++++++++++++++++++++++------------- tools/ocsptest.py | 2 +- 2 files changed, 67 insertions(+), 38 deletions(-) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 140e6ba84..e2bbb8000 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -22,6 +22,7 @@ import sys as _sys import time as _time from errno import EINTR as _EINTR from ipaddress import ip_address as _ip_address +from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, Union from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate from OpenSSL import SSL as _SSL @@ -39,6 +40,14 @@ from pymongo.socket_checker import SocketChecker as _SocketChecker from pymongo.socket_checker import _errno_from_exception from pymongo.write_concern import validate_boolean +if TYPE_CHECKING: + import socket + from ssl import VerifyMode + + from cryptography.x509 import Certificate + +_T = TypeVar("_T") + try: import certifi @@ -73,7 +82,7 @@ _REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()} # For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are # not permitted for SNI hostname. -def _is_ip_address(address): +def _is_ip_address(address: Any) -> bool: try: _ip_address(address) return True @@ -86,7 +95,7 @@ def _is_ip_address(address): BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError) -def _ragged_eof(exc): +def _ragged_eof(exc: BaseException) -> bool: """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF.""" return exc.args == (-1, "Unexpected EOF") @@ -95,12 +104,14 @@ def _ragged_eof(exc): # https://github.com/pyca/pyopenssl/issues/176 # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets class _sslConn(_SSL.Connection): - def __init__(self, ctx, sock, suppress_ragged_eofs): + def __init__( + self, ctx: _SSL.Context, sock: Optional[socket.socket], suppress_ragged_eofs: bool + ): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs super().__init__(ctx, sock) - def _call(self, call, *args, **kwargs): + def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: timeout = self.gettimeout() if timeout: start = _time.monotonic() @@ -127,10 +138,10 @@ class _sslConn(_SSL.Connection): raise _socket.timeout("timed out") continue - def do_handshake(self, *args, **kwargs): + def do_handshake(self, *args: Any, **kwargs: Any) -> None: return self._call(super().do_handshake, *args, **kwargs) - def recv(self, *args, **kwargs): + def recv(self, *args: Any, **kwargs: Any) -> bytes: try: return self._call(super().recv, *args, **kwargs) except _SSL.SysCallError as exc: @@ -139,7 +150,7 @@ class _sslConn(_SSL.Connection): return b"" raise - def recv_into(self, *args, **kwargs): + def recv_into(self, *args: Any, **kwargs: Any) -> int: try: return self._call(super().recv_into, *args, **kwargs) except _SSL.SysCallError as exc: @@ -148,7 +159,7 @@ class _sslConn(_SSL.Connection): return 0 raise - def sendall(self, buf, flags=0): + def sendall(self, buf: bytes, flags: int = 0) -> None: # type: ignore[override] view = memoryview(buf) total_length = len(buf) total_sent = 0 @@ -172,9 +183,9 @@ class _sslConn(_SSL.Connection): class _CallbackData: """Data class which is passed to the OCSP callback.""" - def __init__(self): - self.trusted_ca_certs = None - self.check_ocsp_endpoint = None + def __init__(self) -> None: + self.trusted_ca_certs: Optional[List[Certificate]] = None + self.check_ocsp_endpoint: Optional[bool] = None self.ocsp_response_cache = _OCSPCache() @@ -185,7 +196,7 @@ class SSLContext: __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname") - def __init__(self, protocol): + def __init__(self, protocol: int): self._protocol = protocol self._ctx = _SSL.Context(self._protocol) self._callback_data = _CallbackData() @@ -198,58 +209,67 @@ class SSLContext: self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data) @property - def protocol(self): + def protocol(self) -> int: """The protocol version chosen when constructing the context. This attribute is read-only. """ return self._protocol - def __get_verify_mode(self): + def __get_verify_mode(self) -> VerifyMode: """Whether to try to verify other peers' certificates and how to behave if verification fails. This attribute must be one of ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED. """ return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()] - def __set_verify_mode(self, value): + def __set_verify_mode(self, value: VerifyMode) -> None: """Setter for verify_mode.""" - def _cb(connobj, x509obj, errnum, errdepth, retcode): + def _cb( + connobj: _SSL.Connection, + x509obj: _crypto.X509, + errnum: int, + errdepth: int, + retcode: int, + ) -> bool: # It seems we don't need to do anything here. Twisted doesn't, # and OpenSSL's SSL_CTX_set_verify let's you pass NULL # for the callback option. It's weird that PyOpenSSL requires # this. - return retcode + # This is optional in pyopenssl >= 20 and can be removed once minimum + # supported version is bumped + # See: pyopenssl.org/en/latest/changelog.html#id47 + return bool(retcode) self._ctx.set_verify(_VERIFY_MAP[value], _cb) verify_mode = property(__get_verify_mode, __set_verify_mode) - def __get_check_hostname(self): + def __get_check_hostname(self) -> bool: return self._check_hostname - def __set_check_hostname(self, value): + def __set_check_hostname(self, value: Any) -> None: validate_boolean("check_hostname", value) self._check_hostname = value check_hostname = property(__get_check_hostname, __set_check_hostname) - def __get_check_ocsp_endpoint(self): + def __get_check_ocsp_endpoint(self) -> Optional[bool]: return self._callback_data.check_ocsp_endpoint - def __set_check_ocsp_endpoint(self, value): + def __set_check_ocsp_endpoint(self, value: bool) -> None: validate_boolean("check_ocsp", value) self._callback_data.check_ocsp_endpoint = value check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint) - def __get_options(self): + def __get_options(self) -> None: # Calling set_options adds the option to the existing bitmask and # returns the new bitmask. # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options return self._ctx.set_options(0) - def __set_options(self, value): + def __set_options(self, value: int) -> None: # Explcitly convert to int, since newer CPython versions # use enum.IntFlag for options. The values are the same # regardless of implementation. @@ -257,7 +277,12 @@ class SSLContext: options = property(__get_options, __set_options) - def load_cert_chain(self, certfile, keyfile=None, password=None): + def load_cert_chain( + self, + certfile: Union[str, bytes], + keyfile: Union[str, bytes, None] = None, + password: Optional[str] = None, + ) -> None: """Load a private key and the corresponding certificate. The certfile string must be the path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to @@ -270,10 +295,11 @@ class SSLContext: # Password callback MUST be set first or it will be ignored. if password: - def _pwcb(max_length, prompt_twice, user_data): + def _pwcb(max_length: int, prompt_twice: bool, user_data: bytes) -> bytes: # XXX:We could check the password length against what OpenSSL # tells us is the max, but we can't raise an exception, so... # warn? + assert password is not None return password.encode("utf-8") self._ctx.set_passwd_cb(_pwcb) @@ -281,7 +307,9 @@ class SSLContext: self._ctx.use_privatekey_file(keyfile or certfile) self._ctx.check_privatekey() - def load_verify_locations(self, cafile=None, capath=None): + def load_verify_locations( + self, cafile: Optional[str] = None, capath: Optional[str] = None + ) -> None: """Load a set of "certification authority"(CA) certificates used to validate other peers' certificates when `~verify_mode` is other than ssl.CERT_NONE. @@ -289,9 +317,10 @@ class SSLContext: self._ctx.load_verify_locations(cafile, capath) # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20). if not hasattr(_SSL.Connection, "get_verified_chain"): + assert cafile is not None self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile) - def _load_certifi(self): + def _load_certifi(self) -> None: """Attempt to load CA certs from certifi.""" if _HAVE_CERTIFI: self.load_verify_locations(certifi.where()) @@ -303,7 +332,7 @@ class SSLContext: "the tlsCAFile option" ) - def _load_wincerts(self, store): + def _load_wincerts(self, store: str) -> None: """Attempt to load CA certs from Windows trust store.""" cert_store = self._ctx.get_cert_store() oid = _stdlibssl.Purpose.SERVER_AUTH.oid @@ -314,7 +343,7 @@ class SSLContext: _crypto.X509.from_cryptography(_load_der_x509_certificate(cert)) ) - def load_default_certs(self): + def load_default_certs(self) -> None: """A PyOpenSSL version of load_default_certs from CPython.""" # PyOpenSSL is incapable of loading CA certs from Windows, and mostly # incapable on macOS. @@ -330,7 +359,7 @@ class SSLContext: self._load_certifi() self._ctx.set_default_verify_paths() - def set_default_verify_paths(self): + def set_default_verify_paths(self) -> None: """Specify that the platform provided CA certificates are to be used for verification purposes. """ @@ -340,13 +369,13 @@ class SSLContext: def wrap_socket( self, - sock, - server_side=False, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, - server_hostname=None, - session=None, - ): + sock: socket.socket, + server_side: bool = False, + do_handshake_on_connect: bool = True, + suppress_ragged_eofs: bool = True, + server_hostname: Optional[str] = None, + session: Optional[_SSL.Session] = None, + ) -> _sslConn: """Wrap an existing Python socket connection and return a TLS socket object. """ diff --git a/tools/ocsptest.py b/tools/ocsptest.py index 14df8a8fe..bba84252b 100644 --- a/tools/ocsptest.py +++ b/tools/ocsptest.py @@ -42,7 +42,7 @@ def check_ocsp(host, port, capath): s = socket.socket() s.connect((host, port)) try: - s = ctx.wrap_socket(s, server_hostname=host) + s = ctx.wrap_socket(s, server_hostname=host) # type: ignore[assignment] finally: s.close()