Noah + Copilot feedback
This commit is contained in:
parent
b483a48e42
commit
87bc26d12b
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -27,15 +28,17 @@ from test import unittest
|
||||
|
||||
pytestmark = pytest.mark.ocsp
|
||||
|
||||
pytest.importorskip("cryptography")
|
||||
pytest.importorskip("requests")
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.x448 import X448PublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey
|
||||
from cryptography.x509 import (
|
||||
AuthorityInformationAccess,
|
||||
ExtensionNotFound,
|
||||
Name,
|
||||
TLSFeature,
|
||||
TLSFeatureType,
|
||||
)
|
||||
@ -43,7 +46,6 @@ from cryptography.x509.ocsp import OCSPCertStatus, OCSPResponseStatus
|
||||
from cryptography.x509.oid import AuthorityInformationAccessOID, ExtendedKeyUsageOID
|
||||
|
||||
from pymongo.ocsp_support import (
|
||||
_build_ocsp_request,
|
||||
_get_certs_by_key_hash,
|
||||
_get_certs_by_name,
|
||||
_get_extension,
|
||||
@ -98,53 +100,45 @@ class TestGetIssuerCert(unittest.TestCase):
|
||||
class TestVerifySignature(unittest.TestCase):
|
||||
def test_rsa_valid(self):
|
||||
key = MagicMock(spec=RSAPublicKey)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type]
|
||||
key.verify.assert_called_once()
|
||||
|
||||
def test_rsa_invalid(self):
|
||||
key = MagicMock(spec=RSAPublicKey)
|
||||
key.verify.side_effect = InvalidSignature()
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0) # type: ignore[arg-type]
|
||||
|
||||
def test_dsa_valid(self):
|
||||
key = MagicMock(spec=DSAPublicKey)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type]
|
||||
key.verify.assert_called_once()
|
||||
|
||||
def test_dsa_invalid(self):
|
||||
key = MagicMock(spec=DSAPublicKey)
|
||||
key.verify.side_effect = InvalidSignature()
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0)
|
||||
|
||||
def test_ec_valid(self):
|
||||
key = MagicMock(spec=EllipticCurvePublicKey)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type]
|
||||
key.verify.assert_called_once()
|
||||
|
||||
def test_ec_invalid(self):
|
||||
key = MagicMock(spec=EllipticCurvePublicKey)
|
||||
key.verify.side_effect = InvalidSignature()
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0)
|
||||
|
||||
def test_x25519_skips_verify(self):
|
||||
key = MagicMock(spec=X25519PublicKey)
|
||||
# X25519 is for key exchange only; verify is not called, returns 1
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
|
||||
class FakeX25519:
|
||||
verify = MagicMock()
|
||||
|
||||
with patch("pymongo.ocsp_support._X25519PublicKey", FakeX25519):
|
||||
key = FakeX25519()
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type]
|
||||
key.verify.assert_not_called()
|
||||
|
||||
def test_x448_skips_verify(self):
|
||||
key = MagicMock(spec=X448PublicKey)
|
||||
# X448 is for key exchange only; verify is not called, returns 1
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
|
||||
class FakeX448:
|
||||
verify = MagicMock()
|
||||
|
||||
with patch("pymongo.ocsp_support._X448PublicKey", FakeX448):
|
||||
key = FakeX448()
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type]
|
||||
key.verify.assert_not_called()
|
||||
|
||||
def test_other_key_valid(self):
|
||||
key = Mock()
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1)
|
||||
key.verify.assert_called_once_with(b"sig", b"data")
|
||||
|
||||
def test_other_key_invalid(self):
|
||||
key = Mock()
|
||||
key.verify.side_effect = InvalidSignature()
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 0)
|
||||
self.assertEqual(_verify_signature(key, b"sig", Mock(), b"data"), 1) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestGetExtension(unittest.TestCase):
|
||||
@ -167,8 +161,7 @@ class TestPublicKeyHash(unittest.TestCase):
|
||||
cert = Mock()
|
||||
cert.public_key.return_value = key
|
||||
result = _public_key_hash(cert)
|
||||
self.assertIsInstance(result, bytes)
|
||||
self.assertEqual(len(result), 20) # SHA-1 digest
|
||||
self.assertEqual(len(result), 20)
|
||||
|
||||
def test_ec(self):
|
||||
key = MagicMock(spec=EllipticCurvePublicKey)
|
||||
@ -176,23 +169,20 @@ class TestPublicKeyHash(unittest.TestCase):
|
||||
cert = Mock()
|
||||
cert.public_key.return_value = key
|
||||
result = _public_key_hash(cert)
|
||||
self.assertIsInstance(result, bytes)
|
||||
self.assertEqual(len(result), 20)
|
||||
|
||||
def test_other_key_type(self):
|
||||
# Covers the else branch (Ed25519, Ed448, etc.)
|
||||
key = Mock()
|
||||
key.public_bytes.return_value = b"other_key_bytes"
|
||||
cert = Mock()
|
||||
cert.public_key.return_value = key
|
||||
result = _public_key_hash(cert)
|
||||
self.assertIsInstance(result, bytes)
|
||||
self.assertEqual(len(result), 20)
|
||||
|
||||
|
||||
class TestGetCertsByKeyHash(unittest.TestCase):
|
||||
class TestGetCerts(unittest.TestCase):
|
||||
@patch("pymongo.ocsp_support._public_key_hash")
|
||||
def test_match(self, mock_hash):
|
||||
def test_by_key_hash_match(self, mock_hash):
|
||||
issuer = Mock()
|
||||
issuer.subject = "issuer_subject"
|
||||
cert1 = Mock()
|
||||
@ -205,7 +195,7 @@ class TestGetCertsByKeyHash(unittest.TestCase):
|
||||
self.assertEqual(result, [cert1])
|
||||
|
||||
@patch("pymongo.ocsp_support._public_key_hash")
|
||||
def test_no_match(self, mock_hash):
|
||||
def test_by_key_hash_no_match(self, mock_hash):
|
||||
issuer = Mock()
|
||||
issuer.subject = "issuer_subject"
|
||||
cert = Mock()
|
||||
@ -215,9 +205,7 @@ class TestGetCertsByKeyHash(unittest.TestCase):
|
||||
result = _get_certs_by_key_hash([cert], issuer, b"expected_hash")
|
||||
self.assertEqual(result, [])
|
||||
|
||||
|
||||
class TestGetCertsByName(unittest.TestCase):
|
||||
def test_match(self):
|
||||
def test_by_name_match(self):
|
||||
issuer = Mock()
|
||||
issuer.subject = "issuer"
|
||||
cert1 = Mock()
|
||||
@ -227,36 +215,20 @@ class TestGetCertsByName(unittest.TestCase):
|
||||
cert2.subject = "other"
|
||||
cert2.issuer = "issuer"
|
||||
|
||||
result = _get_certs_by_name([cert1, cert2], issuer, "responder")
|
||||
result = _get_certs_by_name([cert1, cert2], issuer, cast(Name, "responder"))
|
||||
self.assertEqual(result, [cert1])
|
||||
|
||||
def test_no_match(self):
|
||||
def test_by_name_no_match(self):
|
||||
issuer = Mock()
|
||||
issuer.subject = "issuer"
|
||||
cert = Mock()
|
||||
cert.subject = "other"
|
||||
cert.issuer = "issuer"
|
||||
|
||||
result = _get_certs_by_name([cert], issuer, "responder")
|
||||
result = _get_certs_by_name([cert], issuer, cast(Name, "responder"))
|
||||
self.assertEqual(result, [])
|
||||
|
||||
|
||||
class TestBuildOcspRequest(unittest.TestCase):
|
||||
@patch("pymongo.ocsp_support._OCSPRequestBuilder")
|
||||
def test_builds_request(self, mock_builder_class):
|
||||
mock_builder = Mock()
|
||||
mock_builder.add_certificate.return_value = mock_builder
|
||||
mock_request = Mock()
|
||||
mock_builder.build.return_value = mock_request
|
||||
mock_builder_class.return_value = mock_builder
|
||||
|
||||
result = _build_ocsp_request(Mock(), Mock())
|
||||
|
||||
self.assertEqual(result, mock_request)
|
||||
mock_builder.add_certificate.assert_called_once()
|
||||
mock_builder.build.assert_called_once()
|
||||
|
||||
|
||||
class TestVerifyResponseSignature(unittest.TestCase):
|
||||
@patch("pymongo.ocsp_support._verify_signature")
|
||||
def test_responder_is_issuer_by_name(self, mock_verify_sig):
|
||||
@ -515,7 +487,7 @@ class TestGetOcspResponse(unittest.TestCase):
|
||||
mock_post.return_value = http_resp
|
||||
ocsp_resp = Mock()
|
||||
ocsp_resp.response_status = OCSPResponseStatus.SUCCESSFUL
|
||||
ocsp_resp.serial_number = 99999 # Mismatch
|
||||
ocsp_resp.serial_number = 99999
|
||||
mock_load.return_value = ocsp_resp
|
||||
|
||||
result = _get_ocsp_response(Mock(), Mock(), "http://ocsp.example.com", cache)
|
||||
@ -788,11 +760,9 @@ class TestOcspCallback(unittest.TestCase):
|
||||
@patch("pymongo.ocsp_support._get_issuer_cert", return_value=None)
|
||||
@patch("pymongo.ocsp_support._get_extension", return_value=None)
|
||||
def test_uses_peer_cert_chain_fallback(self, _, __):
|
||||
# conn without get_verified_chain triggers the fallback path
|
||||
conn = self._setup_conn(has_verified_chain=False)
|
||||
user_data = self._setup_user_data()
|
||||
user_data.trusted_ca_certs = []
|
||||
# No AIA (_get_extension returns None) → soft fail → True
|
||||
self.assertTrue(_ocsp_callback(conn, b"", user_data))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user