Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes #11453 -- include localKeyID when serializaing a key with a cert #186

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/rust/cryptography-x509/src/pkcs12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub const SHROUDED_KEY_BAG_OID: asn1::ObjectIdentifier =
asn1::oid!(1, 2, 840, 113549, 1, 12, 10, 1, 2);
pub const X509_CERTIFICATE_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 22, 1);
pub const FRIENDLY_NAME_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 20);
pub const LOCAL_KEY_ID_OID: asn1::ObjectIdentifier = asn1::oid!(1, 2, 840, 113549, 1, 9, 21);

#[derive(asn1::Asn1Write)]
pub struct Pfx<'a> {
Expand Down Expand Up @@ -46,6 +47,9 @@ pub struct Attribute<'a> {
pub enum AttributeSet<'a> {
#[defined_by(FRIENDLY_NAME_OID)]
FriendlyName(asn1::SetOfWriter<'a, Utf8StoredBMPString<'a>, [Utf8StoredBMPString<'a>; 1]>),

#[defined_by(LOCAL_KEY_ID_OID)]
LocalKeyId(asn1::SetOfWriter<'a, &'a [u8], [&'a [u8]; 1]>),
}

#[derive(asn1::Asn1DefinedByWrite)]
Expand Down
60 changes: 41 additions & 19 deletions src/rust/src/pkcs12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,38 +338,51 @@ fn pkcs12_kdf(
Ok(result)
}

fn friendly_name_attributes(
friendly_name: Option<&[u8]>,
fn pkcs12_attributes<'a>(
friendly_name: Option<&'a [u8]>,
local_key_id: Option<&'a [u8]>,
) -> CryptographyResult<
Option<
asn1::SetOfWriter<
'_,
cryptography_x509::pkcs12::Attribute<'_>,
Vec<cryptography_x509::pkcs12::Attribute<'_>>,
'a,
cryptography_x509::pkcs12::Attribute<'a>,
Vec<cryptography_x509::pkcs12::Attribute<'a>>,
>,
>,
> {
let mut attrs = vec![];
if let Some(name) = friendly_name {
let name_str = std::str::from_utf8(name).map_err(|_| {
pyo3::exceptions::PyValueError::new_err("friendly_name must be valid UTF-8")
})?;

Ok(Some(asn1::SetOfWriter::new(vec![
cryptography_x509::pkcs12::Attribute {
_attr_id: asn1::DefinedByMarker::marker(),
attr_values: cryptography_x509::pkcs12::AttributeSet::FriendlyName(
asn1::SetOfWriter::new([Utf8StoredBMPString::new(name_str)]),
),
},
])))
} else {
attrs.push(cryptography_x509::pkcs12::Attribute {
_attr_id: asn1::DefinedByMarker::marker(),
attr_values: cryptography_x509::pkcs12::AttributeSet::FriendlyName(
asn1::SetOfWriter::new([Utf8StoredBMPString::new(name_str)]),
),
});
}
if let Some(key_id) = local_key_id {
attrs.push(cryptography_x509::pkcs12::Attribute {
_attr_id: asn1::DefinedByMarker::marker(),
attr_values: cryptography_x509::pkcs12::AttributeSet::LocalKeyId(
asn1::SetOfWriter::new([key_id]),
),
});
}

if attrs.is_empty() {
Ok(None)
} else {
Ok(Some(asn1::SetOfWriter::new(attrs)))
}
}

fn cert_to_bag<'a>(
cert: &'a Certificate,
friendly_name: Option<&'a [u8]>,
local_key_id: Option<&'a [u8]>,
) -> CryptographyResult<cryptography_x509::pkcs12::SafeBag<'a>> {
Ok(cryptography_x509::pkcs12::SafeBag {
_bag_id: asn1::DefinedByMarker::marker(),
Expand All @@ -381,7 +394,7 @@ fn cert_to_bag<'a>(
)),
},
)),
attributes: friendly_name_attributes(friendly_name)?,
attributes: pkcs12_attributes(friendly_name, local_key_id)?,
})
}

Expand Down Expand Up @@ -499,6 +512,7 @@ fn serialize_key_and_certificates<'p>(
key_ciphertext,
);
let mut ca_certs = vec![];
let mut key_id = None;
if cert.is_some() || cas.is_some() {
let mut cert_bags = vec![];

Expand All @@ -515,9 +529,14 @@ fn serialize_key_and_certificates<'p>(
),
));
}
key_id = Some(cert.fingerprint(py, &types::SHA1.get(py)?.call0()?)?);
}

cert_bags.push(cert_to_bag(cert, name)?);
cert_bags.push(cert_to_bag(
cert,
name,
key_id.as_ref().map(|v| v.as_bytes()),
)?);
}

if let Some(cas) = cas {
Expand All @@ -527,10 +546,13 @@ fn serialize_key_and_certificates<'p>(

for cert in &ca_certs {
let bag = match cert {
CertificateOrPKCS12Certificate::Certificate(c) => cert_to_bag(c.get(), None)?,
CertificateOrPKCS12Certificate::Certificate(c) => {
cert_to_bag(c.get(), None, None)?
}
CertificateOrPKCS12Certificate::PKCS12Certificate(c) => cert_to_bag(
c.get().certificate.get(),
c.get().friendly_name.as_ref().map(|v| v.as_bytes(py)),
None,
)?,
};
cert_bags.push(bag);
Expand Down Expand Up @@ -627,7 +649,7 @@ fn serialize_key_and_certificates<'p>(
},
),
),
attributes: friendly_name_attributes(name)?,
attributes: pkcs12_attributes(name, key_id.as_ref().map(|v| v.as_bytes()))?,
}
} else {
let pkcs8_tlv = asn1::parse_single(&pkcs8_bytes)?;
Expand All @@ -637,7 +659,7 @@ fn serialize_key_and_certificates<'p>(
bag_value: asn1::Explicit::new(cryptography_x509::pkcs12::BagValue::KeyBag(
pkcs8_tlv,
)),
attributes: friendly_name_attributes(name)?,
attributes: pkcs12_attributes(name, key_id.as_ref().map(|v| v.as_bytes()))?,
}
};

Expand Down
6 changes: 3 additions & 3 deletions src/rust/src/x509/certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ impl Certificate {
)
}

fn fingerprint<'p>(
pub(crate) fn fingerprint<'p>(
&self,
py: pyo3::Python<'p>,
algorithm: &pyo3::Bound<'p, pyo3::PyAny>,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::PyAny>> {
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
let serialized = asn1::write_single(&self.raw.borrow_dependent())?;

let mut h = hashes::Hash::new(py, algorithm, None)?;
h.update_bytes(&serialized)?;
Ok(h.finalize(py)?.into_any())
h.finalize(py)
}

fn public_bytes<'p>(
Expand Down
24 changes: 24 additions & 0 deletions tests/hazmat/primitives/test_pkcs12.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,30 @@ def test_set_mac_key_certificate_mismatch(self, backend):
b"name", key, cacert, [], encryption
)

@pytest.mark.parametrize(
"encryption_algorithm",
[
serialization.NoEncryption(),
serialization.BestAvailableEncryption(b"password"),
],
)
def test_generate_localkeyid(self, backend, encryption_algorithm):
cert, key = _load_ca(backend)

p12 = serialize_key_and_certificates(
None, key, cert, None, encryption_algorithm
)
# Dirty, but does the trick. Should be there:
# * 2x if unencrypted (once for the key and once for the cert)
# * 1x if encrypted (the cert one is encrypted, but the key one is
# plaintext)
count = (
2
if isinstance(encryption_algorithm, serialization.NoEncryption)
else 1
)
assert p12.count(cert.fingerprint(hashes.SHA1())) == count


@pytest.mark.skip_fips(
reason="PKCS12 unsupported in FIPS mode. So much bad crypto in it."
Expand Down
Loading