From 1bec2425b0b9a8b4352720a7cdb140f6739d7ca1 Mon Sep 17 00:00:00 2001 From: Mahmood Ali Date: Thu, 21 May 2020 18:52:19 -0400 Subject: [PATCH] Add api/ package function to save snapshot --- api/ioutil.go | 81 ++++++++++++++++++++++++++++++++++++++++++ api/ioutil_test.go | 87 ++++++++++++++++++++++++++++++++++++++++++++++ api/operator.go | 28 +++++++++++++++ 3 files changed, 196 insertions(+) create mode 100644 api/ioutil.go create mode 100644 api/ioutil_test.go diff --git a/api/ioutil.go b/api/ioutil.go new file mode 100644 index 000000000..4f585dba0 --- /dev/null +++ b/api/ioutil.go @@ -0,0 +1,81 @@ +package api + +import ( + "crypto/md5" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "fmt" + "hash" + "io" + "strings" +) + +var errMismatchChecksum = fmt.Errorf("mismatch checksum") + +// checksumValidatingReader is a wrapper reader that validates +// the checksum of the underlying reader. +type checksumValidatingReader struct { + r io.ReadCloser + + // algo is the hash algorithm (e.g. `sha-256`) + algo string + + // checksum is the base64 component of checksum + checksum string + + // hash is the hashing function used to compute the checksum + hash hash.Hash +} + +// newChecksumValidatingReader returns a checksum-validating wrapper reader, according +// to a digest received in HTTP header +// +// The digest must be in the format "=" (e.g. "sha-256=gPelGB7..."). +// +// When the reader is fully consumed (i.e. EOT is encountered), if the checksum don't match, +// `Read` returns a checksum mismatch error. +func newChecksumValidatingReader(r io.ReadCloser, digest string) (io.ReadCloser, error) { + parts := strings.SplitN(digest, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid digest format") + } + + algo := parts[0] + var hash hash.Hash + switch algo { + case "sha-256": + hash = sha256.New() + case "sha-512": + hash = sha512.New() + case "md5": + hash = md5.New() + } + + return &checksumValidatingReader{ + r: r, + algo: algo, + checksum: parts[1], + hash: hash, + }, nil +} + +func (r *checksumValidatingReader) Read(b []byte) (int, error) { + n, err := r.r.Read(b) + if n != 0 { + r.hash.Write(b[:n]) + } + + if err == io.EOF || err == io.ErrClosedPipe { + found := base64.StdEncoding.EncodeToString(r.hash.Sum(nil)) + if found != r.checksum { + return n, errMismatchChecksum + } + } + + return n, err +} + +func (r *checksumValidatingReader) Close() error { + return r.r.Close() +} diff --git a/api/ioutil_test.go b/api/ioutil_test.go new file mode 100644 index 000000000..1871f410c --- /dev/null +++ b/api/ioutil_test.go @@ -0,0 +1,87 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "fmt" + "hash" + "io" + "io/ioutil" + "math/rand" + "testing" + "testing/iotest" + + "github.com/stretchr/testify/require" +) + +func TestChecksumValidatingReader(t *testing.T) { + data := make([]byte, 4096) + _, err := rand.Read(data) + require.NoError(t, err) + + cases := []struct { + algo string + hash hash.Hash + }{ + {"sha-256", sha256.New()}, + {"sha-512", sha512.New()}, + } + + for _, c := range cases { + t.Run("valid: "+c.algo, func(t *testing.T) { + _, err := c.hash.Write(data) + require.NoError(t, err) + + checksum := c.hash.Sum(nil) + digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum) + + r := iotest.HalfReader(bytes.NewReader(data)) + cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest) + require.NoError(t, err) + + _, err = io.Copy(ioutil.Discard, cr) + require.NoError(t, err) + }) + + t.Run("invalid: "+c.algo, func(t *testing.T) { + _, err := c.hash.Write(data) + require.NoError(t, err) + + checksum := c.hash.Sum(nil) + // mess up checksum + checksum[0]++ + digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum) + + r := iotest.HalfReader(bytes.NewReader(data)) + cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest) + require.NoError(t, err) + + _, err = io.Copy(ioutil.Discard, cr) + require.Error(t, err) + require.Equal(t, errMismatchChecksum, err) + }) + } +} + +func TestChecksumValidatingReader_PropagatesError(t *testing.T) { + + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + + expectedErr := fmt.Errorf("some error") + + go func() { + pw.Write([]byte("some input")) + pw.CloseWithError(expectedErr) + }() + + cr, err := newChecksumValidatingReader(pr, "sha-256=aaaa") + require.NoError(t, err) + + _, err = io.Copy(ioutil.Discard, cr) + require.Error(t, err) + require.Equal(t, expectedErr, err) +} diff --git a/api/operator.go b/api/operator.go index cb9060e87..febbd8bd1 100644 --- a/api/operator.go +++ b/api/operator.go @@ -1,6 +1,8 @@ package api import ( + "io" + "io/ioutil" "strconv" "time" ) @@ -194,6 +196,32 @@ func (op *Operator) SchedulerCASConfiguration(conf *SchedulerConfiguration, q *W return &out, wm, nil } +// Snapshot is used to capture a snapshot state of a running cluster. +// The returned reader that must be consumed fully +func (op *Operator) Snapshot(q *QueryOptions) (io.ReadCloser, error) { + r, err := op.c.newRequest("GET", "/v1/operator/snapshot") + if err != nil { + return nil, err + } + r.setQueryOptions(q) + _, resp, err := requireOK(op.c.doRequest(r)) + if err != nil { + return nil, err + } + + digest := resp.Header.Get("Digest") + + cr, err := newChecksumValidatingReader(resp.Body, digest) + if err != nil { + io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() + + return nil, err + } + + return cr, nil +} + type License struct { // The unique identifier of the license LicenseID string