diff --git a/test/test_ocsp_support.py b/test/test_ocsp_support.py index 2eaeb8a43..d6337db4d 100644 --- a/test/test_ocsp_support.py +++ b/test/test_ocsp_support.py @@ -17,6 +17,7 @@ from __future__ import annotations import sys from datetime import datetime, timedelta, timezone +from typing import cast from unittest.mock import MagicMock, Mock, patch import pytest @@ -27,15 +28,17 @@ from test import unittest pytestmark = pytest.mark.ocsp +pytest.importorskip("cryptography") +pytest.importorskip("requests") + from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey -from cryptography.hazmat.primitives.asymmetric.x448 import X448PublicKey -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey from cryptography.x509 import ( AuthorityInformationAccess, ExtensionNotFound, + Name, TLSFeature, TLSFeatureType, ) @@ -43,7 +46,6 @@ from cryptography.x509.ocsp import OCSPCertStatus, OCSPResponseStatus from cryptography.x509.oid import AuthorityInformationAccessOID, ExtendedKeyUsageOID from pymongo.ocsp_support import ( - _build_ocsp_request, _get_certs_by_key_hash, _get_certs_by_name, _get_extension, @@ -98,53 +100,45 @@ class TestGetIssuerCert(unittest.TestCase): class TestVerifySignature(unittest.TestCase): def test_rsa_valid(self): key = MagicMock(spec=RSAPublicKey) - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) + self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type] key.verify.assert_called_once() def test_rsa_invalid(self): key = MagicMock(spec=RSAPublicKey) key.verify.side_effect = InvalidSignature() - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0) + self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0) # type: ignore[arg-type] def test_dsa_valid(self): key = MagicMock(spec=DSAPublicKey) - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) + self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type] key.verify.assert_called_once() - def test_dsa_invalid(self): - key = MagicMock(spec=DSAPublicKey) - key.verify.side_effect = InvalidSignature() - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0) - def test_ec_valid(self): key = MagicMock(spec=EllipticCurvePublicKey) - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) + self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type] key.verify.assert_called_once() - def test_ec_invalid(self): - key = MagicMock(spec=EllipticCurvePublicKey) - key.verify.side_effect = InvalidSignature() - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0) - def test_x25519_skips_verify(self): - key = MagicMock(spec=X25519PublicKey) - # X25519 is for key exchange only; verify is not called, returns 1 - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) + class FakeX25519: + verify = MagicMock() + + with patch("pymongo.ocsp_support._X25519PublicKey", FakeX25519): + key = FakeX25519() + self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type] + key.verify.assert_not_called() def test_x448_skips_verify(self): - key = MagicMock(spec=X448PublicKey) - # X448 is for key exchange only; verify is not called, returns 1 - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) + class FakeX448: + verify = MagicMock() + + with patch("pymongo.ocsp_support._X448PublicKey", FakeX448): + key = FakeX448() + self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type] + key.verify.assert_not_called() def test_other_key_valid(self): key = Mock() - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) - key.verify.assert_called_once_with(b"sig", b"data") - - def test_other_key_invalid(self): - key = Mock() - key.verify.side_effect = InvalidSignature() - self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0) + self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type] class TestGetExtension(unittest.TestCase): @@ -167,8 +161,7 @@ class TestPublicKeyHash(unittest.TestCase): cert = Mock() cert.public_key.return_value = key result = _public_key_hash(cert) - self.assertIsInstance(result, bytes) - self.assertEqual(len(result), 20) # SHA-1 digest + self.assertEqual(len(result), 20) def test_ec(self): key = MagicMock(spec=EllipticCurvePublicKey) @@ -176,23 +169,20 @@ class TestPublicKeyHash(unittest.TestCase): cert = Mock() cert.public_key.return_value = key result = _public_key_hash(cert) - self.assertIsInstance(result, bytes) self.assertEqual(len(result), 20) def test_other_key_type(self): - # Covers the else branch (Ed25519, Ed448, etc.) key = Mock() key.public_bytes.return_value = b"other_key_bytes" cert = Mock() cert.public_key.return_value = key result = _public_key_hash(cert) - self.assertIsInstance(result, bytes) self.assertEqual(len(result), 20) -class TestGetCertsByKeyHash(unittest.TestCase): +class TestGetCerts(unittest.TestCase): @patch("pymongo.ocsp_support._public_key_hash") - def test_match(self, mock_hash): + def test_by_key_hash_match(self, mock_hash): issuer = Mock() issuer.subject = "issuer_subject" cert1 = Mock() @@ -205,7 +195,7 @@ class TestGetCertsByKeyHash(unittest.TestCase): self.assertEqual(result, [cert1]) @patch("pymongo.ocsp_support._public_key_hash") - def test_no_match(self, mock_hash): + def test_by_key_hash_no_match(self, mock_hash): issuer = Mock() issuer.subject = "issuer_subject" cert = Mock() @@ -215,9 +205,7 @@ class TestGetCertsByKeyHash(unittest.TestCase): result = _get_certs_by_key_hash([cert], issuer, b"expected_hash") self.assertEqual(result, []) - -class TestGetCertsByName(unittest.TestCase): - def test_match(self): + def test_by_name_match(self): issuer = Mock() issuer.subject = "issuer" cert1 = Mock() @@ -227,36 +215,20 @@ class TestGetCertsByName(unittest.TestCase): cert2.subject = "other" cert2.issuer = "issuer" - result = _get_certs_by_name([cert1, cert2], issuer, "responder") + result = _get_certs_by_name([cert1, cert2], issuer, cast(Name, "responder")) self.assertEqual(result, [cert1]) - def test_no_match(self): + def test_by_name_no_match(self): issuer = Mock() issuer.subject = "issuer" cert = Mock() cert.subject = "other" cert.issuer = "issuer" - result = _get_certs_by_name([cert], issuer, "responder") + result = _get_certs_by_name([cert], issuer, cast(Name, "responder")) self.assertEqual(result, []) -class TestBuildOcspRequest(unittest.TestCase): - @patch("pymongo.ocsp_support._OCSPRequestBuilder") - def test_builds_request(self, mock_builder_class): - mock_builder = Mock() - mock_builder.add_certificate.return_value = mock_builder - mock_request = Mock() - mock_builder.build.return_value = mock_request - mock_builder_class.return_value = mock_builder - - result = _build_ocsp_request(Mock(), Mock()) - - self.assertEqual(result, mock_request) - mock_builder.add_certificate.assert_called_once() - mock_builder.build.assert_called_once() - - class TestVerifyResponseSignature(unittest.TestCase): @patch("pymongo.ocsp_support._verify_signature") def test_responder_is_issuer_by_name(self, mock_verify_sig): @@ -515,7 +487,7 @@ class TestGetOcspResponse(unittest.TestCase): mock_post.return_value = http_resp ocsp_resp = Mock() ocsp_resp.response_status = OCSPResponseStatus.SUCCESSFUL - ocsp_resp.serial_number = 99999 # Mismatch + ocsp_resp.serial_number = 99999 mock_load.return_value = ocsp_resp result = _get_ocsp_response(Mock(), Mock(), "http://ocsp.example.com", cache) @@ -788,11 +760,9 @@ class TestOcspCallback(unittest.TestCase): @patch("pymongo.ocsp_support._get_issuer_cert", return_value=None) @patch("pymongo.ocsp_support._get_extension", return_value=None) def test_uses_peer_cert_chain_fallback(self, _, __): - # conn without get_verified_chain triggers the fallback path conn = self._setup_conn(has_verified_chain=False) user_data = self._setup_user_data() user_data.trusted_ca_certs = [] - # No AIA (_get_extension returns None) → soft fail → True self.assertTrue(_ocsp_callback(conn, b"", user_data))