PYTHON-3767 add types to ocsp_support.py (#1262)

This commit is contained in:
Iris 2023-06-27 13:13:25 -07:00 committed by GitHub
parent 1e14e89d0e
commit 5397d74668
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,11 +13,13 @@
# permissions and limitations under the License.
"""Support for requesting and verifying OCSP responses."""
from __future__ import annotations
import logging as _logging
import re as _re
from datetime import datetime as _datetime
from datetime import timezone
from typing import TYPE_CHECKING, Iterable, List, Optional, Type, Union, cast
from cryptography.exceptions import InvalidSignature as _InvalidSignature
from cryptography.hazmat.backends import default_backend as _default_backend
@ -51,6 +53,26 @@ from requests.exceptions import RequestException as _RequestException
from pymongo import _csot
if TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed448, ed25519, rsa
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
from cryptography.hazmat.primitives.hashes import HashAlgorithm
from cryptography.x509 import Certificate, Name
from cryptography.x509.extensions import Extension, ExtensionTypeVar
from cryptography.x509.ocsp import OCSPRequest, OCSPResponse
from OpenSSL.SSL import Connection
from pymongo.ocsp_cache import _OCSPCache
from pymongo.pyopenssl_context import _CallbackData
CertificateIssuerPublicKeyTypes = Union[
dsa.DSAPublicKey,
rsa.RSAPublicKey,
ec.EllipticCurvePublicKey,
ed25519.Ed25519PublicKey,
ed448.Ed448PublicKey,
]
# Note: the functions in this module generally return 1 or 0. The reason
# is simple. The entry point, ocsp_callback, is registered as a callback
# with OpenSSL through PyOpenSSL. The callback must return 1 (success) or
@ -63,7 +85,7 @@ _CERT_REGEX = _re.compile(
)
def _load_trusted_ca_certs(cafile):
def _load_trusted_ca_certs(cafile: str) -> List[Certificate]:
"""Parse the tlsCAFile into a list of certificates."""
with open(cafile, "rb") as f:
data = f.read()
@ -76,7 +98,9 @@ def _load_trusted_ca_certs(cafile):
return trusted_ca_certs
def _get_issuer_cert(cert, chain, trusted_ca_certs):
def _get_issuer_cert(
cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[List[Certificate]]
) -> Optional[Certificate]:
issuer_name = cert.issuer
for candidate in chain:
if candidate.subject == issuer_name:
@ -93,16 +117,21 @@ def _get_issuer_cert(cert, chain, trusted_ca_certs):
return None
def _verify_signature(key, signature, algorithm, data):
def _verify_signature(
key: CertificateIssuerPublicKeyTypes,
signature: bytes,
algorithm: Union[Prehashed, HashAlgorithm, None],
data: bytes,
) -> int:
# See cryptography.x509.Certificate.public_key
# for the public key types.
try:
if isinstance(key, _RSAPublicKey):
key.verify(signature, data, _PKCS1v15(), algorithm)
key.verify(signature, data, _PKCS1v15(), algorithm) # type: ignore[arg-type]
elif isinstance(key, _DSAPublicKey):
key.verify(signature, data, algorithm)
key.verify(signature, data, algorithm) # type: ignore[arg-type]
elif isinstance(key, _EllipticCurvePublicKey):
key.verify(signature, data, _ECDSA(algorithm))
key.verify(signature, data, _ECDSA(algorithm)) # type: ignore[arg-type]
else:
key.verify(signature, data)
except _InvalidSignature:
@ -110,14 +139,16 @@ def _verify_signature(key, signature, algorithm, data):
return 1
def _get_extension(cert, klass):
def _get_extension(
cert: Certificate, klass: Type[ExtensionTypeVar]
) -> Optional[Extension[ExtensionTypeVar]]:
try:
return cert.extensions.get_extension_for_class(klass)
except _ExtensionNotFound:
return None
def _public_key_hash(cert):
def _public_key_hash(cert: Certificate) -> bytes:
public_key = cert.public_key()
# https://tools.ietf.org/html/rfc2560#section-4.2.1
# "KeyHash ::= OCTET STRING -- SHA-1 hash of responder's public key
@ -134,7 +165,9 @@ def _public_key_hash(cert):
return digest.finalize()
def _get_certs_by_key_hash(certificates, issuer, responder_key_hash):
def _get_certs_by_key_hash(
certificates: Iterable[Certificate], issuer: Certificate, responder_key_hash: Optional[bytes]
) -> List[Certificate]:
return [
cert
for cert in certificates
@ -142,7 +175,9 @@ def _get_certs_by_key_hash(certificates, issuer, responder_key_hash):
]
def _get_certs_by_name(certificates, issuer, responder_name):
def _get_certs_by_name(
certificates: Iterable[Certificate], issuer: Certificate, responder_name: Optional[Name]
) -> List[Certificate]:
return [
cert
for cert in certificates
@ -150,7 +185,7 @@ def _get_certs_by_name(certificates, issuer, responder_name):
]
def _verify_response_signature(issuer, response):
def _verify_response_signature(issuer: Certificate, response: OCSPResponse) -> int:
# Response object will have a responder_name or responder_key_hash
# not both.
name = response.responder_name
@ -185,7 +220,7 @@ def _verify_response_signature(issuer, response):
_LOGGER.debug("Delegate not authorized for OCSP signing")
return 0
if not _verify_signature(
issuer.public_key(),
cast(CertificateIssuerPublicKeyTypes, issuer.public_key()),
responder_cert.signature,
responder_cert.signature_hash_algorithm,
responder_cert.tbs_certificate_bytes,
@ -194,7 +229,7 @@ def _verify_response_signature(issuer, response):
return 0
# RFC6960, Section 3.2, Number 2
ret = _verify_signature(
responder_cert.public_key(),
cast(CertificateIssuerPublicKeyTypes, responder_cert.public_key()),
response.signature,
response.signature_hash_algorithm,
response.tbs_response_bytes,
@ -204,14 +239,14 @@ def _verify_response_signature(issuer, response):
return ret
def _build_ocsp_request(cert, issuer):
def _build_ocsp_request(cert: Certificate, issuer: Certificate) -> OCSPRequest:
# https://cryptography.io/en/latest/x509/ocsp/#creating-requests
builder = _OCSPRequestBuilder()
builder = builder.add_certificate(cert, issuer, _SHA1())
return builder.build()
def _verify_response(issuer, response):
def _verify_response(issuer: Certificate, response: OCSPResponse) -> int:
_LOGGER.debug("Verifying response")
# RFC6960, Section 3.2, Number 2, 3 and 4 happen here.
res = _verify_response_signature(issuer, response)
@ -232,7 +267,9 @@ def _verify_response(issuer, response):
return 1
def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache):
def _get_ocsp_response(
cert: Certificate, issuer: Certificate, uri: Union[str, bytes], ocsp_response_cache: _OCSPCache
) -> Optional[OCSPResponse]:
ocsp_request = _build_ocsp_request(cert, issuer)
try:
ocsp_response = ocsp_response_cache[ocsp_request]
@ -275,30 +312,32 @@ def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache):
return ocsp_response
def _ocsp_callback(conn, ocsp_bytes, user_data):
def _ocsp_callback(conn: Connection, ocsp_bytes: bytes, user_data: Optional[_CallbackData]) -> bool:
"""Callback for use with OpenSSL.SSL.Context.set_ocsp_client_callback."""
cert = conn.get_peer_certificate()
if cert is None:
# always pass in user_data but OpenSSL requires it be optional
assert user_data
pycert = conn.get_peer_certificate()
if pycert is None:
_LOGGER.debug("No peer cert?")
return 0
cert = cert.to_cryptography()
return False
cert = pycert.to_cryptography()
# Use the verified chain when available (pyopenssl>=20.0).
if hasattr(conn, "get_verified_chain"):
chain = conn.get_verified_chain()
pychain = conn.get_verified_chain()
trusted_ca_certs = None
else:
chain = conn.get_peer_cert_chain()
pychain = conn.get_peer_cert_chain()
trusted_ca_certs = user_data.trusted_ca_certs
if not chain:
if not pychain:
_LOGGER.debug("No peer cert chain?")
return 0
chain = [cer.to_cryptography() for cer in chain]
return False
chain = [cer.to_cryptography() for cer in pychain]
issuer = _get_issuer_cert(cert, chain, trusted_ca_certs)
must_staple = False
# https://tools.ietf.org/html/rfc7633#section-4.2.3.1
ext = _get_extension(cert, _TLSFeature)
if ext is not None:
for feature in ext.value:
ext_tls = _get_extension(cert, _TLSFeature)
if ext_tls is not None:
for feature in ext_tls.value:
if feature == _TLSFeatureType.status_request:
_LOGGER.debug("Peer presented a must-staple cert")
must_staple = True
@ -310,29 +349,29 @@ def _ocsp_callback(conn, ocsp_bytes, user_data):
_LOGGER.debug("Peer did not staple an OCSP response")
if must_staple:
_LOGGER.debug("Must-staple cert with no stapled response, hard fail.")
return 0
return False
if not user_data.check_ocsp_endpoint:
_LOGGER.debug("OCSP endpoint checking is disabled, soft fail.")
# No stapled OCSP response, checking responder URI disabled, soft fail.
return 1
return True
# https://tools.ietf.org/html/rfc6960#section-3.1
ext = _get_extension(cert, _AuthorityInformationAccess)
if ext is None:
ext_aia = _get_extension(cert, _AuthorityInformationAccess)
if ext_aia is None:
_LOGGER.debug("No authority access information, soft fail")
# No stapled OCSP response, no responder URI, soft fail.
return 1
return True
uris = [
desc.access_location.value
for desc in ext.value
for desc in ext_aia.value
if desc.access_method == _AuthorityInformationAccessOID.OCSP
]
if not uris:
_LOGGER.debug("No OCSP URI, soft fail")
# No responder URI, soft fail.
return 1
return True
if issuer is None:
_LOGGER.debug("No issuer cert?")
return 0
return False
_LOGGER.debug("Requesting OCSP data")
# When requesting data from an OCSP endpoint we only fail on
# successful, valid responses with a certificate status of REVOKED.
@ -346,28 +385,28 @@ def _ocsp_callback(conn, ocsp_bytes, user_data):
continue
_LOGGER.debug("OCSP cert status: %r", response.certificate_status)
if response.certificate_status == _OCSPCertStatus.GOOD:
return 1
return True
if response.certificate_status == _OCSPCertStatus.REVOKED:
return 0
return False
# Soft fail if we couldn't get a definitive status.
_LOGGER.debug("No definitive OCSP cert status, soft fail")
return 1
return True
_LOGGER.debug("Peer stapled an OCSP response")
if issuer is None:
_LOGGER.debug("No issuer cert?")
return 0
return False
response = _load_der_ocsp_response(ocsp_bytes)
_LOGGER.debug("OCSP response status: %r", response.response_status)
# This happens in _request_ocsp when there is no stapled response so
# we know if we can compare serial numbers for the request and response.
if response.response_status != _OCSPResponseStatus.SUCCESSFUL:
return 0
return False
if not _verify_response(issuer, response):
return 0
return False
# Cache the verified, stapled response.
ocsp_response_cache[_build_ocsp_request(cert, issuer)] = response
_LOGGER.debug("OCSP cert status: %r", response.certificate_status)
if response.certificate_status == _OCSPCertStatus.REVOKED:
return 0
return 1
return False
return True