PYTHON-3767 add types to ocsp_support.py (#1262)
This commit is contained in:
parent
1e14e89d0e
commit
5397d74668
@ -13,11 +13,13 @@
|
||||
# permissions and limitations under the License.
|
||||
|
||||
"""Support for requesting and verifying OCSP responses."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging as _logging
|
||||
import re as _re
|
||||
from datetime import datetime as _datetime
|
||||
from datetime import timezone
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Type, Union, cast
|
||||
|
||||
from cryptography.exceptions import InvalidSignature as _InvalidSignature
|
||||
from cryptography.hazmat.backends import default_backend as _default_backend
|
||||
@ -51,6 +53,26 @@ from requests.exceptions import RequestException as _RequestException
|
||||
|
||||
from pymongo import _csot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed448, ed25519, rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
|
||||
from cryptography.hazmat.primitives.hashes import HashAlgorithm
|
||||
from cryptography.x509 import Certificate, Name
|
||||
from cryptography.x509.extensions import Extension, ExtensionTypeVar
|
||||
from cryptography.x509.ocsp import OCSPRequest, OCSPResponse
|
||||
from OpenSSL.SSL import Connection
|
||||
|
||||
from pymongo.ocsp_cache import _OCSPCache
|
||||
from pymongo.pyopenssl_context import _CallbackData
|
||||
|
||||
CertificateIssuerPublicKeyTypes = Union[
|
||||
dsa.DSAPublicKey,
|
||||
rsa.RSAPublicKey,
|
||||
ec.EllipticCurvePublicKey,
|
||||
ed25519.Ed25519PublicKey,
|
||||
ed448.Ed448PublicKey,
|
||||
]
|
||||
|
||||
# Note: the functions in this module generally return 1 or 0. The reason
|
||||
# is simple. The entry point, ocsp_callback, is registered as a callback
|
||||
# with OpenSSL through PyOpenSSL. The callback must return 1 (success) or
|
||||
@ -63,7 +85,7 @@ _CERT_REGEX = _re.compile(
|
||||
)
|
||||
|
||||
|
||||
def _load_trusted_ca_certs(cafile):
|
||||
def _load_trusted_ca_certs(cafile: str) -> List[Certificate]:
|
||||
"""Parse the tlsCAFile into a list of certificates."""
|
||||
with open(cafile, "rb") as f:
|
||||
data = f.read()
|
||||
@ -76,7 +98,9 @@ def _load_trusted_ca_certs(cafile):
|
||||
return trusted_ca_certs
|
||||
|
||||
|
||||
def _get_issuer_cert(cert, chain, trusted_ca_certs):
|
||||
def _get_issuer_cert(
|
||||
cert: Certificate, chain: Iterable[Certificate], trusted_ca_certs: Optional[List[Certificate]]
|
||||
) -> Optional[Certificate]:
|
||||
issuer_name = cert.issuer
|
||||
for candidate in chain:
|
||||
if candidate.subject == issuer_name:
|
||||
@ -93,16 +117,21 @@ def _get_issuer_cert(cert, chain, trusted_ca_certs):
|
||||
return None
|
||||
|
||||
|
||||
def _verify_signature(key, signature, algorithm, data):
|
||||
def _verify_signature(
|
||||
key: CertificateIssuerPublicKeyTypes,
|
||||
signature: bytes,
|
||||
algorithm: Union[Prehashed, HashAlgorithm, None],
|
||||
data: bytes,
|
||||
) -> int:
|
||||
# See cryptography.x509.Certificate.public_key
|
||||
# for the public key types.
|
||||
try:
|
||||
if isinstance(key, _RSAPublicKey):
|
||||
key.verify(signature, data, _PKCS1v15(), algorithm)
|
||||
key.verify(signature, data, _PKCS1v15(), algorithm) # type: ignore[arg-type]
|
||||
elif isinstance(key, _DSAPublicKey):
|
||||
key.verify(signature, data, algorithm)
|
||||
key.verify(signature, data, algorithm) # type: ignore[arg-type]
|
||||
elif isinstance(key, _EllipticCurvePublicKey):
|
||||
key.verify(signature, data, _ECDSA(algorithm))
|
||||
key.verify(signature, data, _ECDSA(algorithm)) # type: ignore[arg-type]
|
||||
else:
|
||||
key.verify(signature, data)
|
||||
except _InvalidSignature:
|
||||
@ -110,14 +139,16 @@ def _verify_signature(key, signature, algorithm, data):
|
||||
return 1
|
||||
|
||||
|
||||
def _get_extension(cert, klass):
|
||||
def _get_extension(
|
||||
cert: Certificate, klass: Type[ExtensionTypeVar]
|
||||
) -> Optional[Extension[ExtensionTypeVar]]:
|
||||
try:
|
||||
return cert.extensions.get_extension_for_class(klass)
|
||||
except _ExtensionNotFound:
|
||||
return None
|
||||
|
||||
|
||||
def _public_key_hash(cert):
|
||||
def _public_key_hash(cert: Certificate) -> bytes:
|
||||
public_key = cert.public_key()
|
||||
# https://tools.ietf.org/html/rfc2560#section-4.2.1
|
||||
# "KeyHash ::= OCTET STRING -- SHA-1 hash of responder's public key
|
||||
@ -134,7 +165,9 @@ def _public_key_hash(cert):
|
||||
return digest.finalize()
|
||||
|
||||
|
||||
def _get_certs_by_key_hash(certificates, issuer, responder_key_hash):
|
||||
def _get_certs_by_key_hash(
|
||||
certificates: Iterable[Certificate], issuer: Certificate, responder_key_hash: Optional[bytes]
|
||||
) -> List[Certificate]:
|
||||
return [
|
||||
cert
|
||||
for cert in certificates
|
||||
@ -142,7 +175,9 @@ def _get_certs_by_key_hash(certificates, issuer, responder_key_hash):
|
||||
]
|
||||
|
||||
|
||||
def _get_certs_by_name(certificates, issuer, responder_name):
|
||||
def _get_certs_by_name(
|
||||
certificates: Iterable[Certificate], issuer: Certificate, responder_name: Optional[Name]
|
||||
) -> List[Certificate]:
|
||||
return [
|
||||
cert
|
||||
for cert in certificates
|
||||
@ -150,7 +185,7 @@ def _get_certs_by_name(certificates, issuer, responder_name):
|
||||
]
|
||||
|
||||
|
||||
def _verify_response_signature(issuer, response):
|
||||
def _verify_response_signature(issuer: Certificate, response: OCSPResponse) -> int:
|
||||
# Response object will have a responder_name or responder_key_hash
|
||||
# not both.
|
||||
name = response.responder_name
|
||||
@ -185,7 +220,7 @@ def _verify_response_signature(issuer, response):
|
||||
_LOGGER.debug("Delegate not authorized for OCSP signing")
|
||||
return 0
|
||||
if not _verify_signature(
|
||||
issuer.public_key(),
|
||||
cast(CertificateIssuerPublicKeyTypes, issuer.public_key()),
|
||||
responder_cert.signature,
|
||||
responder_cert.signature_hash_algorithm,
|
||||
responder_cert.tbs_certificate_bytes,
|
||||
@ -194,7 +229,7 @@ def _verify_response_signature(issuer, response):
|
||||
return 0
|
||||
# RFC6960, Section 3.2, Number 2
|
||||
ret = _verify_signature(
|
||||
responder_cert.public_key(),
|
||||
cast(CertificateIssuerPublicKeyTypes, responder_cert.public_key()),
|
||||
response.signature,
|
||||
response.signature_hash_algorithm,
|
||||
response.tbs_response_bytes,
|
||||
@ -204,14 +239,14 @@ def _verify_response_signature(issuer, response):
|
||||
return ret
|
||||
|
||||
|
||||
def _build_ocsp_request(cert, issuer):
|
||||
def _build_ocsp_request(cert: Certificate, issuer: Certificate) -> OCSPRequest:
|
||||
# https://cryptography.io/en/latest/x509/ocsp/#creating-requests
|
||||
builder = _OCSPRequestBuilder()
|
||||
builder = builder.add_certificate(cert, issuer, _SHA1())
|
||||
return builder.build()
|
||||
|
||||
|
||||
def _verify_response(issuer, response):
|
||||
def _verify_response(issuer: Certificate, response: OCSPResponse) -> int:
|
||||
_LOGGER.debug("Verifying response")
|
||||
# RFC6960, Section 3.2, Number 2, 3 and 4 happen here.
|
||||
res = _verify_response_signature(issuer, response)
|
||||
@ -232,7 +267,9 @@ def _verify_response(issuer, response):
|
||||
return 1
|
||||
|
||||
|
||||
def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache):
|
||||
def _get_ocsp_response(
|
||||
cert: Certificate, issuer: Certificate, uri: Union[str, bytes], ocsp_response_cache: _OCSPCache
|
||||
) -> Optional[OCSPResponse]:
|
||||
ocsp_request = _build_ocsp_request(cert, issuer)
|
||||
try:
|
||||
ocsp_response = ocsp_response_cache[ocsp_request]
|
||||
@ -275,30 +312,32 @@ def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache):
|
||||
return ocsp_response
|
||||
|
||||
|
||||
def _ocsp_callback(conn, ocsp_bytes, user_data):
|
||||
def _ocsp_callback(conn: Connection, ocsp_bytes: bytes, user_data: Optional[_CallbackData]) -> bool:
|
||||
"""Callback for use with OpenSSL.SSL.Context.set_ocsp_client_callback."""
|
||||
cert = conn.get_peer_certificate()
|
||||
if cert is None:
|
||||
# always pass in user_data but OpenSSL requires it be optional
|
||||
assert user_data
|
||||
pycert = conn.get_peer_certificate()
|
||||
if pycert is None:
|
||||
_LOGGER.debug("No peer cert?")
|
||||
return 0
|
||||
cert = cert.to_cryptography()
|
||||
return False
|
||||
cert = pycert.to_cryptography()
|
||||
# Use the verified chain when available (pyopenssl>=20.0).
|
||||
if hasattr(conn, "get_verified_chain"):
|
||||
chain = conn.get_verified_chain()
|
||||
pychain = conn.get_verified_chain()
|
||||
trusted_ca_certs = None
|
||||
else:
|
||||
chain = conn.get_peer_cert_chain()
|
||||
pychain = conn.get_peer_cert_chain()
|
||||
trusted_ca_certs = user_data.trusted_ca_certs
|
||||
if not chain:
|
||||
if not pychain:
|
||||
_LOGGER.debug("No peer cert chain?")
|
||||
return 0
|
||||
chain = [cer.to_cryptography() for cer in chain]
|
||||
return False
|
||||
chain = [cer.to_cryptography() for cer in pychain]
|
||||
issuer = _get_issuer_cert(cert, chain, trusted_ca_certs)
|
||||
must_staple = False
|
||||
# https://tools.ietf.org/html/rfc7633#section-4.2.3.1
|
||||
ext = _get_extension(cert, _TLSFeature)
|
||||
if ext is not None:
|
||||
for feature in ext.value:
|
||||
ext_tls = _get_extension(cert, _TLSFeature)
|
||||
if ext_tls is not None:
|
||||
for feature in ext_tls.value:
|
||||
if feature == _TLSFeatureType.status_request:
|
||||
_LOGGER.debug("Peer presented a must-staple cert")
|
||||
must_staple = True
|
||||
@ -310,29 +349,29 @@ def _ocsp_callback(conn, ocsp_bytes, user_data):
|
||||
_LOGGER.debug("Peer did not staple an OCSP response")
|
||||
if must_staple:
|
||||
_LOGGER.debug("Must-staple cert with no stapled response, hard fail.")
|
||||
return 0
|
||||
return False
|
||||
if not user_data.check_ocsp_endpoint:
|
||||
_LOGGER.debug("OCSP endpoint checking is disabled, soft fail.")
|
||||
# No stapled OCSP response, checking responder URI disabled, soft fail.
|
||||
return 1
|
||||
return True
|
||||
# https://tools.ietf.org/html/rfc6960#section-3.1
|
||||
ext = _get_extension(cert, _AuthorityInformationAccess)
|
||||
if ext is None:
|
||||
ext_aia = _get_extension(cert, _AuthorityInformationAccess)
|
||||
if ext_aia is None:
|
||||
_LOGGER.debug("No authority access information, soft fail")
|
||||
# No stapled OCSP response, no responder URI, soft fail.
|
||||
return 1
|
||||
return True
|
||||
uris = [
|
||||
desc.access_location.value
|
||||
for desc in ext.value
|
||||
for desc in ext_aia.value
|
||||
if desc.access_method == _AuthorityInformationAccessOID.OCSP
|
||||
]
|
||||
if not uris:
|
||||
_LOGGER.debug("No OCSP URI, soft fail")
|
||||
# No responder URI, soft fail.
|
||||
return 1
|
||||
return True
|
||||
if issuer is None:
|
||||
_LOGGER.debug("No issuer cert?")
|
||||
return 0
|
||||
return False
|
||||
_LOGGER.debug("Requesting OCSP data")
|
||||
# When requesting data from an OCSP endpoint we only fail on
|
||||
# successful, valid responses with a certificate status of REVOKED.
|
||||
@ -346,28 +385,28 @@ def _ocsp_callback(conn, ocsp_bytes, user_data):
|
||||
continue
|
||||
_LOGGER.debug("OCSP cert status: %r", response.certificate_status)
|
||||
if response.certificate_status == _OCSPCertStatus.GOOD:
|
||||
return 1
|
||||
return True
|
||||
if response.certificate_status == _OCSPCertStatus.REVOKED:
|
||||
return 0
|
||||
return False
|
||||
# Soft fail if we couldn't get a definitive status.
|
||||
_LOGGER.debug("No definitive OCSP cert status, soft fail")
|
||||
return 1
|
||||
return True
|
||||
|
||||
_LOGGER.debug("Peer stapled an OCSP response")
|
||||
if issuer is None:
|
||||
_LOGGER.debug("No issuer cert?")
|
||||
return 0
|
||||
return False
|
||||
response = _load_der_ocsp_response(ocsp_bytes)
|
||||
_LOGGER.debug("OCSP response status: %r", response.response_status)
|
||||
# This happens in _request_ocsp when there is no stapled response so
|
||||
# we know if we can compare serial numbers for the request and response.
|
||||
if response.response_status != _OCSPResponseStatus.SUCCESSFUL:
|
||||
return 0
|
||||
return False
|
||||
if not _verify_response(issuer, response):
|
||||
return 0
|
||||
return False
|
||||
# Cache the verified, stapled response.
|
||||
ocsp_response_cache[_build_ocsp_request(cert, issuer)] = response
|
||||
_LOGGER.debug("OCSP cert status: %r", response.certificate_status)
|
||||
if response.certificate_status == _OCSPCertStatus.REVOKED:
|
||||
return 0
|
||||
return 1
|
||||
return False
|
||||
return True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user