Add api/ package function to save snapshot

This commit is contained in:
Mahmood Ali
2020-05-21 18:52:19 -04:00
parent f4fcc1c02c
commit 1bec2425b0
3 changed files with 196 additions and 0 deletions

81
api/ioutil.go Normal file
View File

@@ -0,0 +1,81 @@
package api
import (
"crypto/md5"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"hash"
"io"
"strings"
)
var errMismatchChecksum = fmt.Errorf("mismatch checksum")
// checksumValidatingReader is a wrapper reader that validates
// the checksum of the underlying reader.
type checksumValidatingReader struct {
r io.ReadCloser
// algo is the hash algorithm (e.g. `sha-256`)
algo string
// checksum is the base64 component of checksum
checksum string
// hash is the hashing function used to compute the checksum
hash hash.Hash
}
// newChecksumValidatingReader returns a checksum-validating wrapper reader, according
// to a digest received in HTTP header
//
// The digest must be in the format "<algo>=<base64 of hash>" (e.g. "sha-256=gPelGB7...").
//
// When the reader is fully consumed (i.e. EOT is encountered), if the checksum don't match,
// `Read` returns a checksum mismatch error.
func newChecksumValidatingReader(r io.ReadCloser, digest string) (io.ReadCloser, error) {
parts := strings.SplitN(digest, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid digest format")
}
algo := parts[0]
var hash hash.Hash
switch algo {
case "sha-256":
hash = sha256.New()
case "sha-512":
hash = sha512.New()
case "md5":
hash = md5.New()
}
return &checksumValidatingReader{
r: r,
algo: algo,
checksum: parts[1],
hash: hash,
}, nil
}
func (r *checksumValidatingReader) Read(b []byte) (int, error) {
n, err := r.r.Read(b)
if n != 0 {
r.hash.Write(b[:n])
}
if err == io.EOF || err == io.ErrClosedPipe {
found := base64.StdEncoding.EncodeToString(r.hash.Sum(nil))
if found != r.checksum {
return n, errMismatchChecksum
}
}
return n, err
}
func (r *checksumValidatingReader) Close() error {
return r.r.Close()
}

87
api/ioutil_test.go Normal file
View File

@@ -0,0 +1,87 @@
package api
import (
"bytes"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"hash"
"io"
"io/ioutil"
"math/rand"
"testing"
"testing/iotest"
"github.com/stretchr/testify/require"
)
func TestChecksumValidatingReader(t *testing.T) {
data := make([]byte, 4096)
_, err := rand.Read(data)
require.NoError(t, err)
cases := []struct {
algo string
hash hash.Hash
}{
{"sha-256", sha256.New()},
{"sha-512", sha512.New()},
}
for _, c := range cases {
t.Run("valid: "+c.algo, func(t *testing.T) {
_, err := c.hash.Write(data)
require.NoError(t, err)
checksum := c.hash.Sum(nil)
digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum)
r := iotest.HalfReader(bytes.NewReader(data))
cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest)
require.NoError(t, err)
_, err = io.Copy(ioutil.Discard, cr)
require.NoError(t, err)
})
t.Run("invalid: "+c.algo, func(t *testing.T) {
_, err := c.hash.Write(data)
require.NoError(t, err)
checksum := c.hash.Sum(nil)
// mess up checksum
checksum[0]++
digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum)
r := iotest.HalfReader(bytes.NewReader(data))
cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest)
require.NoError(t, err)
_, err = io.Copy(ioutil.Discard, cr)
require.Error(t, err)
require.Equal(t, errMismatchChecksum, err)
})
}
}
func TestChecksumValidatingReader_PropagatesError(t *testing.T) {
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
expectedErr := fmt.Errorf("some error")
go func() {
pw.Write([]byte("some input"))
pw.CloseWithError(expectedErr)
}()
cr, err := newChecksumValidatingReader(pr, "sha-256=aaaa")
require.NoError(t, err)
_, err = io.Copy(ioutil.Discard, cr)
require.Error(t, err)
require.Equal(t, expectedErr, err)
}

View File

@@ -1,6 +1,8 @@
package api
import (
"io"
"io/ioutil"
"strconv"
"time"
)
@@ -194,6 +196,32 @@ func (op *Operator) SchedulerCASConfiguration(conf *SchedulerConfiguration, q *W
return &out, wm, nil
}
// Snapshot is used to capture a snapshot state of a running cluster.
// The returned reader that must be consumed fully
func (op *Operator) Snapshot(q *QueryOptions) (io.ReadCloser, error) {
r, err := op.c.newRequest("GET", "/v1/operator/snapshot")
if err != nil {
return nil, err
}
r.setQueryOptions(q)
_, resp, err := requireOK(op.c.doRequest(r))
if err != nil {
return nil, err
}
digest := resp.Header.Get("Digest")
cr, err := newChecksumValidatingReader(resp.Body, digest)
if err != nil {
io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
return nil, err
}
return cr, nil
}
type License struct {
// The unique identifier of the license
LicenseID string