diff --git a/pkg/verify/signature.go b/pkg/verify/signature.go index 054d887..0edadbe 100644 --- a/pkg/verify/signature.go +++ b/pkg/verify/signature.go @@ -23,6 +23,7 @@ import ( "fmt" "hash" "io" + "slices" in_toto "github.com/in-toto/attestation/go/v1" "github.com/secure-systems-lab/go-securesystemslib/dsse" @@ -146,56 +147,40 @@ func verifyEnvelopeWithArtifact(verifier signature.Verifier, envelope EnvelopeCo if err = limitSubjects(statement); err != nil { return err } - - var artifactDigestAlgorithm string - var artifactDigest []byte - - // Determine artifact digest algorithm by looking at the first subject's - // digests. This assumes that if a statement contains multiple subjects, - // they all use the same digest algorithm(s). + // Sanity check (no subjects) if len(statement.Subject) == 0 { return errors.New("no subjects found in statement") } - if len(statement.Subject[0].Digest) == 0 { - return errors.New("no digests found in statement") - } - // Select the strongest digest algorithm available. - for _, alg := range []string{"sha512", "sha384", "sha256"} { - if _, ok := statement.Subject[0].Digest[alg]; ok { - artifactDigestAlgorithm = alg - continue - } - } - if artifactDigestAlgorithm == "" { - return errors.New("could not verify artifact: unsupported digest algorithm") + // determine which hash functions to use + hashFuncs, err := determineHashFunctions(statement) + if err != nil { + return fmt.Errorf("could not verify artifact: unable to determine hash functions: %w", err) } // Compute digest of the artifact. - var hasher hash.Hash - switch artifactDigestAlgorithm { - case "sha512": - hasher = crypto.SHA512.New() - case "sha384": - hasher = crypto.SHA384.New() - case "sha256": - hasher = crypto.SHA256.New() - } + hasher := newMultihasher(hashFuncs) _, err = io.Copy(hasher, artifact) if err != nil { return fmt.Errorf("could not verify artifact: unable to calculate digest: %w", err) } - artifactDigest = hasher.Sum(nil) + artifactDigests := hasher.Sum(nil) // Look for artifact digest in statement for _, subject := range statement.Subject { for alg, digest := range subject.Digest { - hexdigest, err := hex.DecodeString(digest) + hf, err := algStringToHashFunc(alg) if err != nil { - return fmt.Errorf("could not verify artifact: unable to decode subject digest: %w", err) + continue } - if alg == artifactDigestAlgorithm && bytes.Equal(artifactDigest, hexdigest) { - return nil + if artifactDigest, ok := artifactDigests[hf]; ok { + hexdigest, err := hex.DecodeString(digest) + if err != nil { + continue + } + if bytes.Equal(artifactDigest, hexdigest) { + return nil + } } } } @@ -303,3 +288,88 @@ func (m *multihasher) Sum(b []byte) map[crypto.Hash][]byte { } return sums } + +func algStringToHashFunc(alg string) (crypto.Hash, error) { + switch alg { + case "sha256": + return crypto.SHA256, nil + case "sha384": + return crypto.SHA384, nil + case "sha512": + return crypto.SHA512, nil + default: + return 0, errors.New("unsupported digest algorithm") + } +} + +func determineHashFunctions(statement *in_toto.Statement) ([]crypto.Hash, error) { + if len(statement.Subject) == 0 { + return nil, errors.New("no subjects found in statement") + } + + algorithmCounts := supportedHashFuncsCount() + for _, subject := range statement.Subject { + for alg := range subject.Digest { + hf, err := algStringToHashFunc(alg) + if err != nil { + continue + } + algorithmCounts[hf]++ + } + } + anyCompatibleAlgorithms := false + var mostCommonHashFunc crypto.Hash + largestCount := 0 + seenHashFuncs := make([]crypto.Hash, 0) + for hf, count := range algorithmCounts { + if count > 0 { + anyCompatibleAlgorithms = true + if !slices.Contains(seenHashFuncs, hf) { + seenHashFuncs = append(seenHashFuncs, hf) + } + } + // if this algorithm is supported by all subjects, we can use it alone + if count == len(statement.Subject) { + return []crypto.Hash{hf}, nil + } + if count > largestCount { + largestCount = count + mostCommonHashFunc = hf + } + } + if !anyCompatibleAlgorithms { + return nil, errors.New("no supported digest algorithms found in statement") + } + + // If we didn't find a common algorithm, see if we cover more digests by using all seen algorithms + countWithAllAlgorithms := 0 + for _, subject := range statement.Subject { + for alg := range subject.Digest { + hf, err := algStringToHashFunc(alg) + if err != nil { + continue + } + if slices.Contains(seenHashFuncs, hf) { + countWithAllAlgorithms++ + break + } + } + } + // No need to calculate all digests if the most common one covers the same number of subjects + if countWithAllAlgorithms > largestCount { + return seenHashFuncs, nil + } + return []crypto.Hash{mostCommonHashFunc}, nil +} + +func supportedHashFuncsCount() map[crypto.Hash]int { + counts := make(map[crypto.Hash]int) + for _, hf := range supportedHashFuncs() { + counts[hf] = 0 + } + return counts +} + +func supportedHashFuncs() []crypto.Hash { + return []crypto.Hash{crypto.SHA512, crypto.SHA384, crypto.SHA256} +}