From 8d69cc0fe85e03b2eb1ac754377e96b4973fc1fb Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Mon, 19 Aug 2024 18:38:38 -0400 Subject: [PATCH] fixes #11453 -- include localKeyID when serializaing a key with a cert --- src/rust/cryptography-x509/src/pkcs12.rs | 4 ++ src/rust/src/pkcs12.rs | 60 ++++++++++++++++-------- src/rust/src/x509/certificate.rs | 6 +-- tests/hazmat/primitives/test_pkcs12.py | 24 ++++++++++ 4 files changed, 72 insertions(+), 22 deletions(-) diff --git a/src/rust/cryptography-x509/src/pkcs12.rs b/src/rust/cryptography-x509/src/pkcs12.rs index fdcbc91ef802..f8f518a4b615 100644 --- a/src/rust/cryptography-x509/src/pkcs12.rs +++ b/src/rust/cryptography-x509/src/pkcs12.rs @@ -11,6 +11,7 @@ pub const SHROUDED_KEY_BAG_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 12, 10, 1, 2); pub const X509_CERTIFICATE_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 22, 1); pub const FRIENDLY_NAME_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 20); +pub const LOCAL_KEY_ID_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 21); #[derive(asn1::Asn1Write)] pub struct Pfx<'a> { @@ -46,6 +47,9 @@ pub struct Attribute<'a> { pub enum AttributeSet<'a> { #[defined_by(FRIENDLY_NAME_OID)] FriendlyName(asn1::SetOfWriter<'a, Utf8StoredBMPString<'a>, [Utf8StoredBMPString<'a>; 1]>), + + #[defined_by(LOCAL_KEY_ID_OID)] + LocalKeyId(asn1::SetOfWriter<'a, &'a [u8], [&'a [u8]; 1]>), } #[derive(asn1::Asn1DefinedByWrite)] diff --git a/src/rust/src/pkcs12.rs b/src/rust/src/pkcs12.rs index 45f8855bacf3..c8d334ecfa29 100644 --- a/src/rust/src/pkcs12.rs +++ b/src/rust/src/pkcs12.rs @@ -338,38 +338,51 @@ fn pkcs12_kdf( Ok(result) } -fn friendly_name_attributes( - friendly_name: Option<&[u8]>, +fn pkcs12_attributes<'a>( + friendly_name: Option<&'a [u8]>, + local_key_id: Option<&'a [u8]>, ) -> CryptographyResult< Option< asn1::SetOfWriter< - '_, - cryptography_x509::pkcs12::Attribute<'_>, - Vec>, + 'a, + cryptography_x509::pkcs12::Attribute<'a>, + Vec>, >, >, > { + let mut attrs = vec![]; if let Some(name) = friendly_name { let name_str = std::str::from_utf8(name).map_err(|_| { pyo3::exceptions::PyValueError::new_err("friendly_name must be valid UTF-8") })?; - Ok(Some(asn1::SetOfWriter::new(vec![ - cryptography_x509::pkcs12::Attribute { - _attr_id: asn1::DefinedByMarker::marker(), - attr_values: cryptography_x509::pkcs12::AttributeSet::FriendlyName( - asn1::SetOfWriter::new([Utf8StoredBMPString::new(name_str)]), - ), - }, - ]))) - } else { + attrs.push(cryptography_x509::pkcs12::Attribute { + _attr_id: asn1::DefinedByMarker::marker(), + attr_values: cryptography_x509::pkcs12::AttributeSet::FriendlyName( + asn1::SetOfWriter::new([Utf8StoredBMPString::new(name_str)]), + ), + }); + } + if let Some(key_id) = local_key_id { + attrs.push(cryptography_x509::pkcs12::Attribute { + _attr_id: asn1::DefinedByMarker::marker(), + attr_values: cryptography_x509::pkcs12::AttributeSet::LocalKeyId( + asn1::SetOfWriter::new([key_id]), + ), + }); + } + + if attrs.is_empty() { Ok(None) + } else { + Ok(Some(asn1::SetOfWriter::new(attrs))) } } fn cert_to_bag<'a>( cert: &'a Certificate, friendly_name: Option<&'a [u8]>, + local_key_id: Option<&'a [u8]>, ) -> CryptographyResult> { Ok(cryptography_x509::pkcs12::SafeBag { _bag_id: asn1::DefinedByMarker::marker(), @@ -381,7 +394,7 @@ fn cert_to_bag<'a>( )), }, )), - attributes: friendly_name_attributes(friendly_name)?, + attributes: pkcs12_attributes(friendly_name, local_key_id)?, }) } @@ -499,6 +512,7 @@ fn serialize_key_and_certificates<'p>( key_ciphertext, ); let mut ca_certs = vec![]; + let mut key_id = None; if cert.is_some() || cas.is_some() { let mut cert_bags = vec![]; @@ -515,9 +529,14 @@ fn serialize_key_and_certificates<'p>( ), )); } + key_id = Some(cert.fingerprint(py, &types::SHA1.get(py)?.call0()?)?); } - cert_bags.push(cert_to_bag(cert, name)?); + cert_bags.push(cert_to_bag( + cert, + name, + key_id.as_ref().map(|v| v.as_bytes()), + )?); } if let Some(cas) = cas { @@ -527,10 +546,13 @@ fn serialize_key_and_certificates<'p>( for cert in &ca_certs { let bag = match cert { - CertificateOrPKCS12Certificate::Certificate(c) => cert_to_bag(c.get(), None)?, + CertificateOrPKCS12Certificate::Certificate(c) => { + cert_to_bag(c.get(), None, None)? + } CertificateOrPKCS12Certificate::PKCS12Certificate(c) => cert_to_bag( c.get().certificate.get(), c.get().friendly_name.as_ref().map(|v| v.as_bytes(py)), + None, )?, }; cert_bags.push(bag); @@ -627,7 +649,7 @@ fn serialize_key_and_certificates<'p>( }, ), ), - attributes: friendly_name_attributes(name)?, + attributes: pkcs12_attributes(name, key_id.as_ref().map(|v| v.as_bytes()))?, } } else { let pkcs8_tlv = asn1::parse_single(&pkcs8_bytes)?; @@ -637,7 +659,7 @@ fn serialize_key_and_certificates<'p>( bag_value: asn1::Explicit::new(cryptography_x509::pkcs12::BagValue::KeyBag( pkcs8_tlv, )), - attributes: friendly_name_attributes(name)?, + attributes: pkcs12_attributes(name, key_id.as_ref().map(|v| v.as_bytes()))?, } }; diff --git a/src/rust/src/x509/certificate.rs b/src/rust/src/x509/certificate.rs index 075c258074ef..454f63ad5119 100644 --- a/src/rust/src/x509/certificate.rs +++ b/src/rust/src/x509/certificate.rs @@ -84,16 +84,16 @@ impl Certificate { ) } - fn fingerprint<'p>( + pub(crate) fn fingerprint<'p>( &self, py: pyo3::Python<'p>, algorithm: &pyo3::Bound<'p, pyo3::PyAny>, - ) -> CryptographyResult> { + ) -> CryptographyResult> { let serialized = asn1::write_single(&self.raw.borrow_dependent())?; let mut h = hashes::Hash::new(py, algorithm, None)?; h.update_bytes(&serialized)?; - Ok(h.finalize(py)?.into_any()) + h.finalize(py) } fn public_bytes<'p>( diff --git a/tests/hazmat/primitives/test_pkcs12.py b/tests/hazmat/primitives/test_pkcs12.py index d0645d9e9941..99bb122c1f1e 100644 --- a/tests/hazmat/primitives/test_pkcs12.py +++ b/tests/hazmat/primitives/test_pkcs12.py @@ -697,6 +697,30 @@ def test_set_mac_key_certificate_mismatch(self, backend): b"name", key, cacert, [], encryption ) + @pytest.mark.parametrize( + "encryption_algorithm", + [ + serialization.NoEncryption(), + serialization.BestAvailableEncryption(b"password"), + ], + ) + def test_generate_localkeyid(self, backend, encryption_algorithm): + cert, key = _load_ca(backend) + + p12 = serialize_key_and_certificates( + None, key, cert, None, encryption_algorithm + ) + # Dirty, but does the trick. Should be there: + # * 2x if unencrypted (once for the key and once for the cert) + # * 1x if encrypted (the cert one is encrypted, but the key one is + # plaintext) + count = ( + 2 + if isinstance(encryption_algorithm, serialization.NoEncryption) + else 1 + ) + assert p12.count(cert.fingerprint(hashes.SHA1())) == count + @pytest.mark.skip_fips( reason="PKCS12 unsupported in FIPS mode. So much bad crypto in it."