diff --git a/pymongo/ocsp_support.py b/pymongo/ocsp_support.py index dd070748a..fa9bd1b7e 100644 --- a/pymongo/ocsp_support.py +++ b/pymongo/ocsp_support.py @@ -13,11 +13,13 @@ # permissions and limitations under the License. """Support for requesting and verifying OCSP responses.""" +from __future__ import annotations import logging as _logging import re as _re from datetime import datetime as _datetime from datetime import timezone +from typing import TYPE_CHECKING, Iterable, List, Optional, Type, Union, cast from cryptography.exceptions import InvalidSignature as _InvalidSignature from cryptography.hazmat.backends import default_backend as _default_backend @@ -51,6 +53,26 @@ from requests.exceptions import RequestException as _RequestException from pymongo import _csot +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed448, ed25519, rsa + from cryptography.hazmat.primitives.asymmetric.utils import Prehashed + from cryptography.hazmat.primitives.hashes import HashAlgorithm + from cryptography.x509 import Certificate, Name + from cryptography.x509.extensions import Extension, ExtensionTypeVar + from cryptography.x509.ocsp import OCSPRequest, OCSPResponse + from OpenSSL.SSL import Connection + + from pymongo.ocsp_cache import _OCSPCache + from pymongo.pyopenssl_context import _CallbackData + + CertificateIssuerPublicKeyTypes = Union[ + dsa.DSAPublicKey, + rsa.RSAPublicKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PublicKey, + ed448.Ed448PublicKey, + ] + # Note: the functions in this module generally return 1 or 0. The reason # is simple. The entry point, ocsp_callback, is registered as a callback # with OpenSSL through PyOpenSSL. The callback must return 1 (success) or @@ -63,7 +85,7 @@ _CERT_REGEX = _re.compile( ) -def _load_trusted_ca_certs(cafile): +def _load_trusted_ca_certs(cafile: str) -> List[Certificate]: """Parse the tlsCAFile into a list of certificates.""" with open(cafile, "rb") as f: data = f.read() @@ -76,7 +98,9 @@ def _load_trusted_ca_certs(cafile): return trusted_ca_certs -def _get_issuer_cert(cert, chain, trusted_ca_certs): +def _get_issuer_cert( + cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[List[Certificate]] +) -> Optional[Certificate]: issuer_name = cert.issuer for candidate in chain: if candidate.subject == issuer_name: @@ -93,16 +117,21 @@ def _get_issuer_cert(cert, chain, trusted_ca_certs): return None -def _verify_signature(key, signature, algorithm, data): +def _verify_signature( + key: CertificateIssuerPublicKeyTypes, + signature: bytes, + algorithm: Union[Prehashed, HashAlgorithm, None], + data: bytes, +) -> int: # See cryptography.x509.Certificate.public_key # for the public key types. try: if isinstance(key, _RSAPublicKey): - key.verify(signature, data, _PKCS1v15(), algorithm) + key.verify(signature, data, _PKCS1v15(), algorithm) # type: ignore[arg-type] elif isinstance(key, _DSAPublicKey): - key.verify(signature, data, algorithm) + key.verify(signature, data, algorithm) # type: ignore[arg-type] elif isinstance(key, _EllipticCurvePublicKey): - key.verify(signature, data, _ECDSA(algorithm)) + key.verify(signature, data, _ECDSA(algorithm)) # type: ignore[arg-type] else: key.verify(signature, data) except _InvalidSignature: @@ -110,14 +139,16 @@ def _verify_signature(key, signature, algorithm, data): return 1 -def _get_extension(cert, klass): +def _get_extension( + cert: Certificate, klass: Type[ExtensionTypeVar] +) -> Optional[Extension[ExtensionTypeVar]]: try: return cert.extensions.get_extension_for_class(klass) except _ExtensionNotFound: return None -def _public_key_hash(cert): +def _public_key_hash(cert: Certificate) -> bytes: public_key = cert.public_key() # https://tools.ietf.org/html/rfc2560#section-4.2.1 # "KeyHash ::= OCTET STRING -- SHA-1 hash of responder's public key @@ -134,7 +165,9 @@ def _public_key_hash(cert): return digest.finalize() -def _get_certs_by_key_hash(certificates, issuer, responder_key_hash): +def _get_certs_by_key_hash( + certificates: Iterable[Certificate], issuer: Certificate, responder_key_hash: Optional[bytes] +) -> List[Certificate]: return [ cert for cert in certificates @@ -142,7 +175,9 @@ def _get_certs_by_key_hash(certificates, issuer, responder_key_hash): ] -def _get_certs_by_name(certificates, issuer, responder_name): +def _get_certs_by_name( + certificates: Iterable[Certificate], issuer: Certificate, responder_name: Optional[Name] +) -> List[Certificate]: return [ cert for cert in certificates @@ -150,7 +185,7 @@ def _get_certs_by_name(certificates, issuer, responder_name): ] -def _verify_response_signature(issuer, response): +def _verify_response_signature(issuer: Certificate, response: OCSPResponse) -> int: # Response object will have a responder_name or responder_key_hash # not both. name = response.responder_name @@ -185,7 +220,7 @@ def _verify_response_signature(issuer, response): _LOGGER.debug("Delegate not authorized for OCSP signing") return 0 if not _verify_signature( - issuer.public_key(), + cast(CertificateIssuerPublicKeyTypes, issuer.public_key()), responder_cert.signature, responder_cert.signature_hash_algorithm, responder_cert.tbs_certificate_bytes, @@ -194,7 +229,7 @@ def _verify_response_signature(issuer, response): return 0 # RFC6960, Section 3.2, Number 2 ret = _verify_signature( - responder_cert.public_key(), + cast(CertificateIssuerPublicKeyTypes, responder_cert.public_key()), response.signature, response.signature_hash_algorithm, response.tbs_response_bytes, @@ -204,14 +239,14 @@ def _verify_response_signature(issuer, response): return ret -def _build_ocsp_request(cert, issuer): +def _build_ocsp_request(cert: Certificate, issuer: Certificate) -> OCSPRequest: # https://cryptography.io/en/latest/x509/ocsp/#creating-requests builder = _OCSPRequestBuilder() builder = builder.add_certificate(cert, issuer, _SHA1()) return builder.build() -def _verify_response(issuer, response): +def _verify_response(issuer: Certificate, response: OCSPResponse) -> int: _LOGGER.debug("Verifying response") # RFC6960, Section 3.2, Number 2, 3 and 4 happen here. res = _verify_response_signature(issuer, response) @@ -232,7 +267,9 @@ def _verify_response(issuer, response): return 1 -def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache): +def _get_ocsp_response( + cert: Certificate, issuer: Certificate, uri: Union[str, bytes], ocsp_response_cache: _OCSPCache +) -> Optional[OCSPResponse]: ocsp_request = _build_ocsp_request(cert, issuer) try: ocsp_response = ocsp_response_cache[ocsp_request] @@ -275,30 +312,32 @@ def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache): return ocsp_response -def _ocsp_callback(conn, ocsp_bytes, user_data): +def _ocsp_callback(conn: Connection, ocsp_bytes: bytes, user_data: Optional[_CallbackData]) -> bool: """Callback for use with OpenSSL.SSL.Context.set_ocsp_client_callback.""" - cert = conn.get_peer_certificate() - if cert is None: + # always pass in user_data but OpenSSL requires it be optional + assert user_data + pycert = conn.get_peer_certificate() + if pycert is None: _LOGGER.debug("No peer cert?") - return 0 - cert = cert.to_cryptography() + return False + cert = pycert.to_cryptography() # Use the verified chain when available (pyopenssl>=20.0). if hasattr(conn, "get_verified_chain"): - chain = conn.get_verified_chain() + pychain = conn.get_verified_chain() trusted_ca_certs = None else: - chain = conn.get_peer_cert_chain() + pychain = conn.get_peer_cert_chain() trusted_ca_certs = user_data.trusted_ca_certs - if not chain: + if not pychain: _LOGGER.debug("No peer cert chain?") - return 0 - chain = [cer.to_cryptography() for cer in chain] + return False + chain = [cer.to_cryptography() for cer in pychain] issuer = _get_issuer_cert(cert, chain, trusted_ca_certs) must_staple = False # https://tools.ietf.org/html/rfc7633#section-4.2.3.1 - ext = _get_extension(cert, _TLSFeature) - if ext is not None: - for feature in ext.value: + ext_tls = _get_extension(cert, _TLSFeature) + if ext_tls is not None: + for feature in ext_tls.value: if feature == _TLSFeatureType.status_request: _LOGGER.debug("Peer presented a must-staple cert") must_staple = True @@ -310,29 +349,29 @@ def _ocsp_callback(conn, ocsp_bytes, user_data): _LOGGER.debug("Peer did not staple an OCSP response") if must_staple: _LOGGER.debug("Must-staple cert with no stapled response, hard fail.") - return 0 + return False if not user_data.check_ocsp_endpoint: _LOGGER.debug("OCSP endpoint checking is disabled, soft fail.") # No stapled OCSP response, checking responder URI disabled, soft fail. - return 1 + return True # https://tools.ietf.org/html/rfc6960#section-3.1 - ext = _get_extension(cert, _AuthorityInformationAccess) - if ext is None: + ext_aia = _get_extension(cert, _AuthorityInformationAccess) + if ext_aia is None: _LOGGER.debug("No authority access information, soft fail") # No stapled OCSP response, no responder URI, soft fail. - return 1 + return True uris = [ desc.access_location.value - for desc in ext.value + for desc in ext_aia.value if desc.access_method == _AuthorityInformationAccessOID.OCSP ] if not uris: _LOGGER.debug("No OCSP URI, soft fail") # No responder URI, soft fail. - return 1 + return True if issuer is None: _LOGGER.debug("No issuer cert?") - return 0 + return False _LOGGER.debug("Requesting OCSP data") # When requesting data from an OCSP endpoint we only fail on # successful, valid responses with a certificate status of REVOKED. @@ -346,28 +385,28 @@ def _ocsp_callback(conn, ocsp_bytes, user_data): continue _LOGGER.debug("OCSP cert status: %r", response.certificate_status) if response.certificate_status == _OCSPCertStatus.GOOD: - return 1 + return True if response.certificate_status == _OCSPCertStatus.REVOKED: - return 0 + return False # Soft fail if we couldn't get a definitive status. _LOGGER.debug("No definitive OCSP cert status, soft fail") - return 1 + return True _LOGGER.debug("Peer stapled an OCSP response") if issuer is None: _LOGGER.debug("No issuer cert?") - return 0 + return False response = _load_der_ocsp_response(ocsp_bytes) _LOGGER.debug("OCSP response status: %r", response.response_status) # This happens in _request_ocsp when there is no stapled response so # we know if we can compare serial numbers for the request and response. if response.response_status != _OCSPResponseStatus.SUCCESSFUL: - return 0 + return False if not _verify_response(issuer, response): - return 0 + return False # Cache the verified, stapled response. ocsp_response_cache[_build_ocsp_request(cert, issuer)] = response _LOGGER.debug("OCSP cert status: %r", response.certificate_status) if response.certificate_status == _OCSPCertStatus.REVOKED: - return 0 - return 1 + return False + return True