Skip to content

Commit

Permalink
Merge pull request canonical#9513 from mvo5/snapshot-content-hash
Browse files Browse the repository at this point in the history
snapshotstate: detect duplicated snapshot imports
  • Loading branch information
stolowski authored Jan 14, 2021
2 parents 127cf1e + 476f180 commit f51367f
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 28 deletions.
41 changes: 41 additions & 0 deletions client/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ package client
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
"sort"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -91,6 +93,21 @@ func (sh *Snapshot) IsValid() bool {
return !(sh == nil || sh.SetID == 0 || sh.Snap == "" || sh.Revision.Unset() || len(sh.SHA3_384) == 0 || sh.Time.IsZero())
}

// ContentHash returns a hash that can be used to identify the snapshot
// by its content, leaving out metadata like "time" or "set-id".
func (sh *Snapshot) ContentHash() ([]byte, error) {
	// Work on a copy so the caller's snapshot is left untouched, and
	// blank out the fields that do not describe actual content.
	normalized := *sh
	normalized.SetID = 0
	normalized.Time = time.Time{}
	normalized.Auto = false

	hasher := sha256.New()
	// Stream the JSON encoding straight into the hash; json encodes
	// map keys in sorted order so the result is deterministic.
	if err := json.NewEncoder(hasher).Encode(&normalized); err != nil {
		return nil, err
	}
	return hasher.Sum(nil), nil
}

// A SnapshotSet is a set of snapshots created by a single "snap save".
type SnapshotSet struct {
ID uint64 `json:"id"`
Expand Down Expand Up @@ -120,6 +137,30 @@ func (ss SnapshotSet) Size() int64 {
return sum
}

// bySnap implements sort.Interface, ordering snapshots by the name of
// the snap they belong to.
type bySnap []*Snapshot

func (s bySnap) Len() int           { return len(s) }
func (s bySnap) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
func (s bySnap) Less(i, j int) bool { return s[i].Snap < s[j].Snap }

// ContentHash returns a hash that can be used to identify the SnapshotSet by
// its content.
func (ss SnapshotSet) ContentHash() ([]byte, error) {
	// Hash the snapshots in a stable, name-sorted order so the result
	// does not depend on the order in which they were collected.
	ordered := make([]*Snapshot, len(ss.Snapshots))
	copy(ordered, ss.Snapshots)
	sort.Sort(bySnap(ordered))

	hasher := sha256.New()
	for _, snapshot := range ordered {
		contentHash, err := snapshot.ContentHash()
		if err != nil {
			return nil, err
		}
		// hash.Hash.Write never returns an error
		hasher.Write(contentHash)
	}
	return hasher.Sum(nil), nil
}

// SnapshotSets lists the snapshot sets in the system that belong to the
// given set (if non-zero) and are for the given snaps (if non-empty).
func (client *Client) SnapshotSets(setID uint64, snapNames []string) ([]SnapshotSet, error) {
Expand Down
80 changes: 80 additions & 0 deletions client/snapshot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package client_test

import (
"crypto/sha256"
"io/ioutil"
"net/http"
"net/url"
Expand Down Expand Up @@ -217,3 +218,82 @@ func (cs *clientSuite) TestClientSnapshotImport(c *check.C) {
c.Check(string(d), check.Equals, fakeSnapshotData)
}
}

// TestClientSnapshotContentHash verifies that Snapshot.ContentHash ignores
// pure metadata (time, set ID) but changes when real content changes.
func (cs *clientSuite) TestClientSnapshotContentHash(c *check.C) {
	now := time.Now()
	revno := snap.R(1)
	sums := map[string]string{"user/foo.tgz": "some long hash"}

	sh1 := &client.Snapshot{SetID: 1, Time: now, Snap: "asnap", Revision: revno, SHA3_384: sums}
	// sh1, sh1LaterTime are the same except time
	// (use an explicit duration; a bare integer would be nanoseconds)
	sh1LaterTime := &client.Snapshot{SetID: 1, Time: now.Add(10 * time.Second), Snap: "asnap", Revision: revno, SHA3_384: sums}
	// sh1, sh2 are the same except setID
	sh2 := &client.Snapshot{SetID: 2, Time: now, Snap: "asnap", Revision: revno, SHA3_384: sums}

	h1, err := sh1.ContentHash()
	c.Assert(err, check.IsNil)
	// content hash uses sha256 internally
	c.Check(h1, check.HasLen, sha256.Size)

	// same except time means same hash
	hLaterTime, err := sh1LaterTime.ContentHash()
	c.Assert(err, check.IsNil)
	c.Check(h1, check.DeepEquals, hLaterTime)

	// same except set-id means same hash
	h2, err := sh2.ContentHash()
	c.Assert(err, check.IsNil)
	c.Check(h1, check.DeepEquals, h2)

	// sh3 is actually different
	sh3 := &client.Snapshot{SetID: 1, Time: now, Snap: "other-snap", Revision: revno, SHA3_384: sums}
	h3, err := sh3.ContentHash()
	c.Assert(err, check.IsNil)
	c.Check(h1, check.Not(check.DeepEquals), h3)

	// identical to sh1 except for sha3_384 sums
	sums4 := map[string]string{"user/foo.tgz": "some other hash"}
	sh4 := &client.Snapshot{SetID: 1, Time: now, Snap: "asnap", Revision: revno, SHA3_384: sums4}
	// same except sha3_384 means different hash
	h4, err := sh4.ContentHash()
	c.Assert(err, check.IsNil)
	c.Check(h4, check.Not(check.DeepEquals), h1)
}

// TestClientSnapshotSetContentHash verifies that SnapshotSet.ContentHash is
// order- and set-id-independent but sensitive to real content differences.
func (cs *clientSuite) TestClientSnapshotSetContentHash(c *check.C) {
	sums := map[string]string{"user/foo.tgz": "some long hash"}
	ss1 := client.SnapshotSet{Snapshots: []*client.Snapshot{
		{SetID: 1, Snap: "snap2", Size: 2, SHA3_384: sums},
		{SetID: 1, Snap: "snap1", Size: 1, SHA3_384: sums},
		{SetID: 1, Snap: "snap3", Size: 3, SHA3_384: sums},
	}}
	// ss2 is the same as ss1 but in a different order with a different
	// setID (neither matters for the content hash)
	ss2 := client.SnapshotSet{Snapshots: []*client.Snapshot{
		{SetID: 2, Snap: "snap3", Size: 3, SHA3_384: sums},
		{SetID: 2, Snap: "snap2", Size: 2, SHA3_384: sums},
		{SetID: 2, Snap: "snap1", Size: 1, SHA3_384: sums},
	}}

	h1, err := ss1.ContentHash()
	c.Assert(err, check.IsNil)
	// content hash uses sha256 internally
	c.Check(h1, check.HasLen, sha256.Size)

	// h1 and h2 have the same hash
	h2, err := ss2.ContentHash()
	c.Assert(err, check.IsNil)
	c.Check(h2, check.DeepEquals, h1)

	// ss3 is different because the size of snap3 is different; keep the
	// sha3_384 sums identical to ss1 so the size really is the only
	// difference being exercised here
	ss3 := client.SnapshotSet{Snapshots: []*client.Snapshot{
		{SetID: 1, Snap: "snap2", Size: 2, SHA3_384: sums},
		{SetID: 1, Snap: "snap3", Size: 666666666, SHA3_384: sums},
		{SetID: 1, Snap: "snap1", Size: 1, SHA3_384: sums},
	}}
	// h1 and h3 are different
	h3, err := ss3.ContentHash()
	c.Assert(err, check.IsNil)
	c.Check(h3, check.Not(check.DeepEquals), h1)
}
109 changes: 102 additions & 7 deletions overlord/snapshotstate/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,15 @@ func Import(ctx context.Context, id uint64, r io.Reader) (snapNames []string, er
defer tr.Cancel()

// Unpack and validate the streamed data
snapNames, err = unpackVerifySnapshotImport(r, id)
//
// XXX: this will leak snapshot IDs, i.e. we allocate a new
// snapshot ID before but then we error here because of e.g.
// duplicated import attempts
snapNames, err = unpackVerifySnapshotImport(ctx, r, id)
if err != nil {
if _, ok := err.(DuplicatedSnapshotImportError); ok {
return nil, err
}
return nil, fmt.Errorf("%s: %v", errPrefix, err)
}
if err := tr.Commit(); err != nil {
Expand All @@ -684,7 +691,43 @@ func writeOneSnapshotFile(targetPath string, tr io.Reader) error {
return nil
}

func unpackVerifySnapshotImport(r io.Reader, realSetID uint64) (snapNames []string, err error) {
// DuplicatedSnapshotImportError is returned by a snapshot import when
// the imported content matches a snapshot set that is already present
// on the system.
type DuplicatedSnapshotImportError struct {
	// SetID is the id of the already existing snapshot set.
	SetID uint64
}

// Error implements the error interface.
func (e DuplicatedSnapshotImportError) Error() string {
	return fmt.Sprintf("cannot import snapshot, already available as snapshot id %v", e.SetID)
}

// checkDuplicatedSnapshotSetWithContentHash returns a
// DuplicatedSnapshotImportError if an existing snapshot set has the
// given content hash, nil if none matches, or a different error if the
// existing snapshot sets cannot be inspected.
func checkDuplicatedSnapshotSetWithContentHash(ctx context.Context, contentHash []byte) error {
	setsByID := make(map[uint64]client.SnapshotSet)

	// XXX: deal with import in progress here

	// collect the snapshots of all current sets, grouped by set id
	if err := Iter(ctx, func(reader *Reader) error {
		set := setsByID[reader.SetID]
		set.Snapshots = append(set.Snapshots, &reader.Snapshot)
		setsByID[reader.SetID] = set
		return nil
	}); err != nil {
		return fmt.Errorf("cannot calculate snapshot set hashes: %v", err)
	}

	// compare each existing set's content hash against the imported one
	for setID, set := range setsByID {
		hash, err := set.ContentHash()
		if err != nil {
			return fmt.Errorf("cannot calculate content hash for %v: %v", setID, err)
		}
		if bytes.Equal(hash, contentHash) {
			return DuplicatedSnapshotImportError{setID}
		}
	}
	return nil
}

func unpackVerifySnapshotImport(ctx context.Context, r io.Reader, realSetID uint64) (snapNames []string, err error) {
var exportFound bool

tr := tar.NewReader(r)
Expand All @@ -706,6 +749,21 @@ func unpackVerifySnapshotImport(r io.Reader, realSetID uint64) (snapNames []stri
return nil, errors.New("unexpected directory in import file")
}

if header.Name == "content.json" {
var ej contentJSON
dec := json.NewDecoder(tr)
if err := dec.Decode(&ej); err != nil {
return nil, err
}
// XXX: this is potentially slow as it needs
// to open all snapshots files and read a
// small amount of data from them
if err := checkDuplicatedSnapshotSetWithContentHash(ctx, ej.ContentHash); err != nil {
return nil, err
}
continue
}

if header.Name == "export.json" {
// XXX: read into memory and validate once we
// hashes in export.json
Expand Down Expand Up @@ -759,6 +817,9 @@ type SnapshotExport struct {
// open snapshot files
snapshotFiles []*os.File

// contentHash of the full snapshot
contentHash []byte

// remember setID mostly for nicer errors
setID uint64

Expand All @@ -770,6 +831,7 @@ type SnapshotExport struct {
// Close()ed after use to avoid leaking file descriptors.
func NewSnapshotExport(ctx context.Context, setID uint64) (se *SnapshotExport, err error) {
var snapshotFiles []*os.File
var snapshotSet client.SnapshotSet

defer func() {
// cleanup any open FDs if anything goes wrong
Expand All @@ -786,9 +848,13 @@ func NewSnapshotExport(ctx context.Context, setID uint64) (se *SnapshotExport, e
// files are getting opened.
err = Iter(ctx, func(reader *Reader) error {
if reader.SetID == setID {
// Duplicate the file descriptor of the reader we were handed as
// Iter() closes those as soon as this unnamed returns. We
// re-package the file descriptor into snapshotFiles below.
snapshotSet.Snapshots = append(snapshotSet.Snapshots, &reader.Snapshot)

// Duplicate the file descriptor of the reader
// we were handed as Iter() closes those as
// soon as this unnamed returns. We re-package
// the file descriptor into snapshotFiles
// below.
fd, err := syscall.Dup(int(reader.Fd()))
if err != nil {
return fmt.Errorf("cannot duplicate descriptor: %v", err)
Expand All @@ -808,7 +874,11 @@ func NewSnapshotExport(ctx context.Context, setID uint64) (se *SnapshotExport, e
return nil, fmt.Errorf("no snapshot data found for %v", setID)
}

se = &SnapshotExport{snapshotFiles: snapshotFiles, setID: setID}
h, err := snapshotSet.ContentHash()
if err != nil {
return nil, fmt.Errorf("cannot calculate content hash for snapshot export %v: %v", setID, err)
}
se = &SnapshotExport{snapshotFiles: snapshotFiles, setID: setID, contentHash: h}

// ensure we never leak FDs even if the user does not call close
runtime.SetFinalizer(se, (*SnapshotExport).Close)
Expand Down Expand Up @@ -847,11 +917,36 @@ func (se *SnapshotExport) Close() {
se.snapshotFiles = nil
}

// contentJSON is the payload of the "content.json" entry in a snapshot
// export stream: it carries the content hash of the exported snapshot
// set so that a later import can detect duplicates.
type contentJSON struct {
	ContentHash []byte `json:"content-hash"`
}

func (se *SnapshotExport) StreamTo(w io.Writer) error {
// write out a tar
var files []string
tw := tar.NewWriter(w)
defer tw.Close()

// export contentHash as content.json
h, err := json.Marshal(contentJSON{se.contentHash})
if err != nil {
return err
}
hdr := &tar.Header{
Typeflag: tar.TypeReg,
Name: "content.json",
Size: int64(len(h)),
Mode: 0640,
ModTime: timeNow(),
}
if err := tw.WriteHeader(hdr); err != nil {
return err
}
if _, err := tw.Write(h); err != nil {
return err
}

// write out the individual snapshots
for _, snapshotFile := range se.snapshotFiles {
stat, err := snapshotFile.Stat()
if err != nil {
Expand Down Expand Up @@ -889,7 +984,7 @@ func (se *SnapshotExport) StreamTo(w io.Writer) error {
if err != nil {
return fmt.Errorf("cannot marshal meta-data: %v", err)
}
hdr := &tar.Header{
hdr = &tar.Header{
Typeflag: tar.TypeReg,
Name: "export.json",
Size: int64(len(metaDataBuf)),
Expand Down
Loading

0 comments on commit f51367f

Please sign in to comment.