Merge pull request #8047 from hashicorp/f-snapshot-save

API for atomic snapshot backups
This commit is contained in:
Mahmood Ali
2020-06-01 07:55:16 -04:00
committed by GitHub
31 changed files with 2121 additions and 25 deletions

81
api/ioutil.go Normal file
View File

@@ -0,0 +1,81 @@
package api
import (
"crypto/md5"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"hash"
"io"
"strings"
)
var errMismatchChecksum = fmt.Errorf("checksum mismatch")
// checksumValidatingReader is a wrapper reader that validates
// the checksum of the underlying reader.
type checksumValidatingReader struct {
r io.ReadCloser
// algo is the hash algorithm (e.g. `sha-256`)
algo string
// checksum is the base64-encoded checksum to verify against
checksum string
// hash is the hashing function used to compute the checksum
hash hash.Hash
}
// newChecksumValidatingReader returns a checksum-validating wrapper reader, according
// to a digest received in HTTP header
//
// The digest must be in the format "<algo>=<base64 of hash>" (e.g. "sha-256=gPelGB7...").
//
// When the reader is fully consumed (i.e. EOF is encountered), if the checksum doesn't match,
// `Read` returns a checksum mismatch error.
func newChecksumValidatingReader(r io.ReadCloser, digest string) (io.ReadCloser, error) {
parts := strings.SplitN(digest, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid digest format")
}
algo := parts[0]
var hash hash.Hash
switch algo {
case "sha-256":
hash = sha256.New()
case "sha-512":
hash = sha512.New()
case "md5":
hash = md5.New()
}
return &checksumValidatingReader{
r: r,
algo: algo,
checksum: parts[1],
hash: hash,
}, nil
}
func (r *checksumValidatingReader) Read(b []byte) (int, error) {
n, err := r.r.Read(b)
if n != 0 {
r.hash.Write(b[:n])
}
if err == io.EOF || err == io.ErrClosedPipe {
found := base64.StdEncoding.EncodeToString(r.hash.Sum(nil))
if found != r.checksum {
return n, errMismatchChecksum
}
}
return n, err
}
func (r *checksumValidatingReader) Close() error {
return r.r.Close()
}
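For reference, a minimal sketch (not part of the change) of producing a digest string in the "<algo>=<base64 of hash>" shape this reader expects; the server side builds the HTTP Digest header the same way from a SHA-256 of the snapshot stream:

package main

import (
    "crypto/sha256"
    "encoding/base64"
    "fmt"
)

func main() {
    payload := []byte("snapshot bytes")
    // Hash the payload, base64-encode the raw sum, and prefix the algorithm name.
    sum := sha256.Sum256(payload)
    digest := "sha-256=" + base64.StdEncoding.EncodeToString(sum[:])
    fmt.Println(digest) // prints "sha-256=" followed by 44 base64 characters
}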

87
api/ioutil_test.go Normal file
View File

@@ -0,0 +1,87 @@
package api
import (
"bytes"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"hash"
"io"
"io/ioutil"
"math/rand"
"testing"
"testing/iotest"
"github.com/stretchr/testify/require"
)
func TestChecksumValidatingReader(t *testing.T) {
data := make([]byte, 4096)
_, err := rand.Read(data)
require.NoError(t, err)
cases := []struct {
algo string
hash hash.Hash
}{
{"sha-256", sha256.New()},
{"sha-512", sha512.New()},
}
for _, c := range cases {
t.Run("valid: "+c.algo, func(t *testing.T) {
_, err := c.hash.Write(data)
require.NoError(t, err)
checksum := c.hash.Sum(nil)
digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum)
r := iotest.HalfReader(bytes.NewReader(data))
cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest)
require.NoError(t, err)
_, err = io.Copy(ioutil.Discard, cr)
require.NoError(t, err)
})
t.Run("invalid: "+c.algo, func(t *testing.T) {
_, err := c.hash.Write(data)
require.NoError(t, err)
checksum := c.hash.Sum(nil)
// mess up checksum
checksum[0]++
digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum)
r := iotest.HalfReader(bytes.NewReader(data))
cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest)
require.NoError(t, err)
_, err = io.Copy(ioutil.Discard, cr)
require.Error(t, err)
require.Equal(t, errMismatchChecksum, err)
})
}
}
func TestChecksumValidatingReader_PropagatesError(t *testing.T) {
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
expectedErr := fmt.Errorf("some error")
go func() {
pw.Write([]byte("some input"))
pw.CloseWithError(expectedErr)
}()
cr, err := newChecksumValidatingReader(pr, "sha-256=aaaa")
require.NoError(t, err)
_, err = io.Copy(ioutil.Discard, cr)
require.Error(t, err)
require.Equal(t, expectedErr, err)
}

View File

@@ -1,6 +1,8 @@
package api
import (
"io"
"io/ioutil"
"strconv"
"time"
)
@@ -194,6 +196,32 @@ func (op *Operator) SchedulerCASConfiguration(conf *SchedulerConfiguration, q *W
return &out, wm, nil
}
// Snapshot is used to capture a snapshot of the state of a running cluster.
// The returned reader must be fully consumed.
func (op *Operator) Snapshot(q *QueryOptions) (io.ReadCloser, error) {
r, err := op.c.newRequest("GET", "/v1/operator/snapshot")
if err != nil {
return nil, err
}
r.setQueryOptions(q)
_, resp, err := requireOK(op.c.doRequest(r))
if err != nil {
return nil, err
}
digest := resp.Header.Get("Digest")
cr, err := newChecksumValidatingReader(resp.Body, digest)
if err != nil {
io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
return nil, err
}
return cr, nil
}
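For illustration, a minimal sketch of consuming this new API method from client code and writing the stream to disk; the agent address (api.DefaultConfig) and the output file name are assumptions, not part of this change. The Digest returned by the server is verified transparently as the body is read:

package main

import (
    "io"
    "log"
    "os"

    "github.com/hashicorp/nomad/api"
)

func main() {
    client, err := api.NewClient(api.DefaultConfig())
    if err != nil {
        log.Fatal(err)
    }

    // AllowStale: false asks the leader for the snapshot.
    snap, err := client.Operator().Snapshot(&api.QueryOptions{AllowStale: false})
    if err != nil {
        log.Fatal(err)
    }
    defer snap.Close()

    out, err := os.Create("backup.snap")
    if err != nil {
        log.Fatal(err)
    }
    defer out.Close()

    // io.Copy surfaces a checksum mismatch error if the stream fails validation.
    if _, err := io.Copy(out, snap); err != nil {
        log.Fatal(err)
    }
}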
type License struct {
// The unique identifier of the license
LicenseID string

View File

@@ -318,6 +318,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {
s.mux.HandleFunc("/v1/operator/raft/", s.wrap(s.OperatorRequest))
s.mux.HandleFunc("/v1/operator/autopilot/configuration", s.wrap(s.OperatorAutopilotConfiguration))
s.mux.HandleFunc("/v1/operator/autopilot/health", s.wrap(s.OperatorServerHealth))
s.mux.HandleFunc("/v1/operator/snapshot", s.wrap(s.SnapshotRequest))
s.mux.HandleFunc("/v1/system/gc", s.wrap(s.GarbageCollectRequest))
s.mux.HandleFunc("/v1/system/reconcile/summaries", s.wrap(s.ReconcileJobSummaries))

View File

@@ -1,6 +1,9 @@
package agent
import (
"context"
"io"
"net"
"net/http"
"strings"
@@ -9,6 +12,7 @@ import (
"time"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/nomad/api"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/raft"
@@ -283,3 +287,88 @@ func (s *HTTPServer) schedulerUpdateConfig(resp http.ResponseWriter, req *http.R
setIndex(resp, reply.Index)
return reply, nil
}
func (s *HTTPServer) SnapshotRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
switch req.Method {
case "GET":
return s.snapshotSaveRequest(resp, req)
default:
return nil, CodedError(405, ErrInvalidMethod)
}
}
func (s *HTTPServer) snapshotSaveRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
args := &structs.SnapshotSaveRequest{}
if s.parse(resp, req, &args.Region, &args.QueryOptions) {
return nil, nil
}
var handler structs.StreamingRpcHandler
var handlerErr error
if server := s.agent.Server(); server != nil {
handler, handlerErr = server.StreamingRpcHandler("Operator.SnapshotSave")
} else if client := s.agent.Client(); client != nil {
handler, handlerErr = client.RemoteStreamingRpcHandler("Operator.SnapshotSave")
} else {
handlerErr = fmt.Errorf("misconfigured connection")
}
if handlerErr != nil {
return nil, CodedError(500, handlerErr.Error())
}
httpPipe, handlerPipe := net.Pipe()
decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle)
encoder := codec.NewEncoder(httpPipe, structs.MsgpackHandle)
// Create a goroutine that closes the pipe if the connection closes.
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
go func() {
<-ctx.Done()
httpPipe.Close()
}()
errCh := make(chan HTTPCodedError, 1)
go func() {
defer cancel()
// Send the request
if err := encoder.Encode(args); err != nil {
errCh <- CodedError(500, err.Error())
return
}
var res structs.SnapshotSaveResponse
if err := decoder.Decode(&res); err != nil {
errCh <- CodedError(500, err.Error())
return
}
if res.ErrorMsg != "" {
errCh <- CodedError(res.ErrorCode, res.ErrorMsg)
return
}
resp.Header().Add("Digest", res.SnapshotChecksum)
_, err := io.Copy(resp, httpPipe)
if err != nil &&
err != io.EOF &&
!strings.Contains(err.Error(), "closed") &&
!strings.Contains(err.Error(), "EOF") {
errCh <- CodedError(500, err.Error())
return
}
errCh <- nil
}()
handler(handlerPipe)
cancel()
codedErr := <-errCh
return nil, codedErr
}
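The same endpoint can also be exercised over plain HTTP; a throwaway sketch (the agent address is an assumption) that downloads the snapshot and re-computes the Digest header it announces:

package main

import (
    "crypto/sha256"
    "encoding/base64"
    "fmt"
    "io"
    "log"
    "net/http"
)

func main() {
    resp, err := http.Get("http://127.0.0.1:4646/v1/operator/snapshot")
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    // Hash the body as it streams; the result should match the Digest header.
    hash := sha256.New()
    if _, err := io.Copy(hash, resp.Body); err != nil {
        log.Fatal(err)
    }
    computed := "sha-256=" + base64.StdEncoding.EncodeToString(hash.Sum(nil))

    fmt.Println("Digest header:", resp.Header.Get("Digest"))
    fmt.Println("computed:     ", computed)
}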

View File

@@ -2,9 +2,15 @@ package agent
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path"
"strings"
"testing"
"time"
@@ -382,3 +388,39 @@ func TestOperator_SchedulerCASConfiguration(t *testing.T) {
require.False(reply.SchedulerConfig.PreemptionConfig.BatchSchedulerEnabled)
})
}
func TestOperator_SnapshotSaveRequest(t *testing.T) {
t.Parallel()
// Nomad cluster topology (not specific to this test)
dir, err := ioutil.TempDir("", "nomadtest-operator-")
require.NoError(t, err)
defer os.RemoveAll(dir)
httpTest(t, func(c *Config) {
c.Server.BootstrapExpect = 1
c.DevMode = false
c.DataDir = path.Join(dir, "server")
c.AdvertiseAddrs.HTTP = "127.0.0.1"
c.AdvertiseAddrs.RPC = "127.0.0.1"
c.AdvertiseAddrs.Serf = "127.0.0.1"
}, func(s *TestAgent) {
req, _ := http.NewRequest("GET", "/v1/operator/snapshot", nil)
resp := httptest.NewRecorder()
_, err := s.Server.SnapshotRequest(resp, req)
require.NoError(t, err)
require.Equal(t, 200, resp.Code)
digest := resp.Header().Get("Digest")
require.NotEmpty(t, digest)
require.Contains(t, digest, "sha-256=")
hash := sha256.New()
_, err = io.Copy(hash, resp.Body)
require.NoError(t, err)
expectedChecksum := "sha-256=" + base64.StdEncoding.EncodeToString(hash.Sum(nil))
require.Equal(t, digest, expectedChecksum)
})
}

View File

@@ -126,8 +126,15 @@ func (a *TestAgent) Start() *TestAgent {
i := 10
advertiseAddrs := *a.Config.AdvertiseAddrs
RETRY:
i--
// Clear out the advertise addresses such that through retries we
// re-normalize the addresses correctly instead of using the values from the
// last port selection that had a port conflict.
newAddrs := advertiseAddrs
a.Config.AdvertiseAddrs = &newAddrs
a.pickRandomPorts(a.Config)
if a.Config.NodeName == "" {
a.Config.NodeName = fmt.Sprintf("Node %d", a.Config.Ports.RPC)
@@ -312,15 +319,6 @@ func (a *TestAgent) pickRandomPorts(c *Config) {
c.Ports.RPC = ports[1]
c.Ports.Serf = ports[2]
// Clear out the advertise addresses such that through retries we
// re-normalize the addresses correctly instead of using the values from the
// last port selection that had a port conflict.
if c.AdvertiseAddrs != nil {
c.AdvertiseAddrs.HTTP = ""
c.AdvertiseAddrs.RPC = ""
c.AdvertiseAddrs.Serf = ""
}
if err := c.normalizeAddrs(); err != nil {
a.T.Fatalf("error normalizing config: %v", err)
}

View File

@@ -502,6 +502,22 @@ func Commands(metaPtr *Meta, agentUi cli.Ui) map[string]cli.CommandFactory {
}, nil
},
"operator snapshot": func() (cli.Command, error) {
return &OperatorSnapshotCommand{
Meta: meta,
}, nil
},
"operator snapshot save": func() (cli.Command, error) {
return &OperatorSnapshotSaveCommand{
Meta: meta,
}, nil
},
"operator snapshot inspect": func() (cli.Command, error) {
return &OperatorSnapshotInspectCommand{
Meta: meta,
}, nil
},
"plan": func() (cli.Command, error) {
return &JobPlanCommand{
Meta: meta,

View File

@@ -0,0 +1,50 @@
package command
import (
"strings"
"github.com/mitchellh/cli"
)
type OperatorSnapshotCommand struct {
Meta
}
func (f *OperatorSnapshotCommand) Help() string {
helpText := `
Usage: nomad operator snapshot <subcommand> [options]
This command has subcommands for saving and inspecting the state
of the Nomad servers for disaster recovery. These are atomic, point-in-time
snapshots which include jobs, nodes, allocations, periodic jobs, and ACLs.
If ACLs are enabled, a management token must be supplied in order to perform
snapshot operations.
Create a snapshot:
$ nomad operator snapshot save backup.snap
Inspect a snapshot:
$ nomad operator snapshot inspect backup.snap
Run a daemon process that locally saves a snapshot every hour (available only in
Nomad Enterprise):
$ nomad operator snapshot agent
Please see the individual subcommand help for detailed usage information.
`
return strings.TrimSpace(helpText)
}
func (f *OperatorSnapshotCommand) Synopsis() string {
return "Saves and inspects snapshots of Nomad server state"
}
func (f *OperatorSnapshotCommand) Name() string { return "operator snapshot" }
func (f *OperatorSnapshotCommand) Run(args []string) int {
return cli.RunResultHelp
}

View File

@@ -0,0 +1,74 @@
package command
import (
"fmt"
"os"
"strings"
"github.com/hashicorp/nomad/helper/snapshot"
"github.com/posener/complete"
)
type OperatorSnapshotInspectCommand struct {
Meta
}
func (c *OperatorSnapshotInspectCommand) Help() string {
helpText := `
Usage: nomad operator snapshot inspect [options] FILE
Displays information about a snapshot file on disk.
To inspect the file "backup.snap":
$ nomad operator snapshot inspect backup.snap
`
return strings.TrimSpace(helpText)
}
func (c *OperatorSnapshotInspectCommand) AutocompleteFlags() complete.Flags {
return complete.Flags{}
}
func (c *OperatorSnapshotInspectCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictNothing
}
func (c *OperatorSnapshotInspectCommand) Synopsis() string {
return "Displays information about a Nomad snapshot file"
}
func (c *OperatorSnapshotInspectCommand) Name() string { return "operator snapshot inspect" }
func (c *OperatorSnapshotInspectCommand) Run(args []string) int {
// Check that we got exactly one argument.
if len(args) != 1 {
c.Ui.Error("This command takes one argument: <filename>")
c.Ui.Error(commandErrorText(c))
return 1
}
path := args[0]
f, err := os.Open(path)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error opening snapshot file: %s", err))
return 1
}
defer f.Close()
meta, err := snapshot.Verify(f)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error verifying snapshot: %s", err))
return 1
}
output := []string{
fmt.Sprintf("ID|%s", meta.ID),
fmt.Sprintf("Size|%d", meta.Size),
fmt.Sprintf("Index|%d", meta.Index),
fmt.Sprintf("Term|%d", meta.Term),
fmt.Sprintf("Version|%d", meta.Version),
}
c.Ui.Output(formatList(output))
return 0
}

View File

@@ -0,0 +1,99 @@
package command
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/hashicorp/nomad/command/agent"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
)
func TestOperatorSnapshotInspect_Works(t *testing.T) {
t.Parallel()
snapPath := generateSnapshotFile(t)
ui := new(cli.MockUi)
cmd := &OperatorSnapshotInspectCommand{Meta: Meta{Ui: ui}}
code := cmd.Run([]string{snapPath})
require.Zero(t, code)
output := ui.OutputWriter.String()
for _, key := range []string{
"ID",
"Size",
"Index",
"Term",
"Version",
} {
require.Contains(t, output, key)
}
}
func TestOperatorSnapshotInspect_HandlesFailure(t *testing.T) {
t.Parallel()
tmpDir, err := ioutil.TempDir("", "nomad-clitests-")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
err = ioutil.WriteFile(
filepath.Join(tmpDir, "invalid.snap"),
[]byte("invalid data"),
0600)
require.NoError(t, err)
t.Run("not found", func(t *testing.T) {
ui := new(cli.MockUi)
cmd := &OperatorSnapshotInspectCommand{Meta: Meta{Ui: ui}}
code := cmd.Run([]string{filepath.Join(tmpDir, "foo")})
require.NotZero(t, code)
require.Contains(t, ui.ErrorWriter.String(), "no such file")
})
t.Run("invalid file", func(t *testing.T) {
ui := new(cli.MockUi)
cmd := &OperatorSnapshotInspectCommand{Meta: Meta{Ui: ui}}
code := cmd.Run([]string{filepath.Join(tmpDir, "invalid.snap")})
require.NotZero(t, code)
require.Contains(t, ui.ErrorWriter.String(), "Error verifying snapshot")
})
}
func generateSnapshotFile(t *testing.T) string {
tmpDir, err := ioutil.TempDir("", "nomad-tempdir")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(tmpDir) })
srv, _, url := testServer(t, false, func(c *agent.Config) {
c.DevMode = false
c.DataDir = filepath.Join(tmpDir, "server")
c.AdvertiseAddrs.HTTP = "127.0.0.1"
c.AdvertiseAddrs.RPC = "127.0.0.1"
c.AdvertiseAddrs.Serf = "127.0.0.1"
})
defer srv.Shutdown()
ui := new(cli.MockUi)
cmd := &OperatorSnapshotSaveCommand{Meta: Meta{Ui: ui}}
dest := filepath.Join(tmpDir, "backup.snap")
code := cmd.Run([]string{
"--address=" + url,
dest,
})
require.Zero(t, code)
return dest
}

View File

@@ -0,0 +1,142 @@
package command
import (
"fmt"
"io"
"os"
"strings"
"time"
"github.com/hashicorp/nomad/api"
"github.com/posener/complete"
)
type OperatorSnapshotSaveCommand struct {
Meta
}
func (c *OperatorSnapshotSaveCommand) Help() string {
helpText := `
Usage: nomad operator snapshot save [options] <filename>
Retrieves an atomic, point-in-time snapshot of the state of the Nomad servers
which includes jobs, nodes, allocations, periodic jobs, and ACLs.
If ACLs are enabled, a management token must be supplied in order to perform
snapshot operations.
To create a snapshot from the leader server and save it to "backup.snap":
$ nomad operator snapshot save backup.snap
To create a potentially stale snapshot from any available server (useful if no
leader is available):
$ nomad operator snapshot save -stale backup.snap
General Options:
` + generalOptionsUsage() + `
Snapshot Save Options:
-stale=[true|false]
The -stale argument defaults to "false", which means the leader provides the
result. If the cluster is in an outage state without a leader, you may need
to set -stale to "true" to get a snapshot from a non-leader server.
`
return strings.TrimSpace(helpText)
}
func (c *OperatorSnapshotSaveCommand) AutocompleteFlags() complete.Flags {
return mergeAutocompleteFlags(c.Meta.AutocompleteFlags(FlagSetClient),
complete.Flags{
"-stale": complete.PredictAnything,
})
}
func (c *OperatorSnapshotSaveCommand) AutocompleteArgs() complete.Predictor {
return complete.PredictNothing
}
func (c *OperatorSnapshotSaveCommand) Synopsis() string {
return "Saves snapshot of Nomad server state"
}
func (c *OperatorSnapshotSaveCommand) Name() string { return "operator snapshot save" }
func (c *OperatorSnapshotSaveCommand) Run(args []string) int {
var stale bool
flags := c.Meta.FlagSet(c.Name(), FlagSetClient)
flags.Usage = func() { c.Ui.Output(c.Help()) }
flags.BoolVar(&stale, "stale", false, "")
if err := flags.Parse(args); err != nil {
c.Ui.Error(fmt.Sprintf("Failed to parse args: %v", err))
return 1
}
// Check for misuse
// Check that we either got no filename or exactly one.
args = flags.Args()
if len(args) > 1 {
c.Ui.Error("This command takes either no arguments or one: <filename>")
c.Ui.Error(commandErrorText(c))
return 1
}
now := time.Now()
filename := fmt.Sprintf("nomad-state-%04d%02d%0d-%d.snap", now.Year(), now.Month(), now.Day(), now.Unix())
if len(args) == 1 {
filename = args[0]
}
if _, err := os.Lstat(filename); err == nil {
c.Ui.Error(fmt.Sprintf("Destination file already exists: %q", filename))
c.Ui.Error(commandErrorText(c))
return 1
} else if !os.IsNotExist(err) {
c.Ui.Error(fmt.Sprintf("Unexpected failure checking %q: %v", filename, err))
return 1
}
// Set up a client.
client, err := c.Meta.Client()
if err != nil {
c.Ui.Error(fmt.Sprintf("Error initializing client: %s", err))
return 1
}
tmpFile, err := os.Create(filename + ".tmp")
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to create file: %v", err))
return 1
}
// Fetch the current configuration.
q := &api.QueryOptions{
AllowStale: stale,
}
snapIn, err := client.Operator().Snapshot(q)
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to get snapshot file: %v", err))
return 1
}
defer snapIn.Close()
_, err = io.Copy(tmpFile, snapIn)
if err != nil {
c.Ui.Error(fmt.Sprintf("Filed to download snapshot file: %v", err))
return 1
}
err = os.Rename(tmpFile.Name(), filename)
if err != nil {
c.Ui.Error(fmt.Sprintf("Filed to finalize snapshot file: %v", err))
return 1
}
c.Ui.Output(fmt.Sprintf("State file written to %v", filename))
return 0
}

View File

@@ -0,0 +1,51 @@
package command
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"github.com/hashicorp/nomad/command/agent"
"github.com/hashicorp/nomad/helper/snapshot"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
)
func TestOperatorSnapshotSave_Works(t *testing.T) {
t.Parallel()
tmpDir, err := ioutil.TempDir("", "nomad-tempdir")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
srv, _, url := testServer(t, false, func(c *agent.Config) {
c.DevMode = false
c.DataDir = filepath.Join(tmpDir, "server")
c.AdvertiseAddrs.HTTP = "127.0.0.1"
c.AdvertiseAddrs.RPC = "127.0.0.1"
c.AdvertiseAddrs.Serf = "127.0.0.1"
})
defer srv.Shutdown()
ui := new(cli.MockUi)
cmd := &OperatorSnapshotSaveCommand{Meta: Meta{Ui: ui}}
dest := filepath.Join(tmpDir, "backup.snap")
code := cmd.Run([]string{
"--address=" + url,
dest,
})
require.Zero(t, code)
require.Contains(t, ui.OutputWriter.String(), "State file written to "+dest)
f, err := os.Open(dest)
require.NoError(t, err)
meta, err := snapshot.Verify(f)
require.NoError(t, err)
require.NotZero(t, meta.Index)
}

234
helper/snapshot/archive.go Normal file
View File

@@ -0,0 +1,234 @@
// The archive utilities manage the internal format of a snapshot, which is a
// tar file with the following contents:
//
// meta.json - JSON-encoded snapshot metadata from Raft
// state.bin - Encoded snapshot data from Raft
// SHA256SUMS - SHA-256 sums of the above two files
//
// The integrity information is automatically created and checked, and a failure
// there just looks like an error to the caller.
package snapshot
import (
"archive/tar"
"bufio"
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"hash"
"io"
"io/ioutil"
"time"
"github.com/hashicorp/raft"
)
// hashList manages a list of filenames and their hashes.
type hashList struct {
hashes map[string]hash.Hash
}
// newHashList returns a new hashList.
func newHashList() *hashList {
return &hashList{
hashes: make(map[string]hash.Hash),
}
}
// Add creates a new hash for the given file.
func (hl *hashList) Add(file string) hash.Hash {
if existing, ok := hl.hashes[file]; ok {
return existing
}
h := sha256.New()
hl.hashes[file] = h
return h
}
// Encode takes the current sum of all the hashes and saves the hash list as a
// SHA256SUMS-style text file.
func (hl *hashList) Encode(w io.Writer) error {
for file, h := range hl.hashes {
if _, err := fmt.Fprintf(w, "%x %s\n", h.Sum([]byte{}), file); err != nil {
return err
}
}
return nil
}
// DecodeAndVerify reads a SHA256SUMS-style text file and checks the results
// against the current sums for all the hashes.
func (hl *hashList) DecodeAndVerify(r io.Reader) error {
// Read the file and make sure everything in there has a matching hash.
seen := make(map[string]struct{})
s := bufio.NewScanner(r)
for s.Scan() {
sha := make([]byte, sha256.Size)
var file string
if _, err := fmt.Sscanf(s.Text(), "%x %s", &sha, &file); err != nil {
return err
}
h, ok := hl.hashes[file]
if !ok {
return fmt.Errorf("list missing hash for %q", file)
}
if !bytes.Equal(sha, h.Sum([]byte{})) {
return fmt.Errorf("hash check failed for %q", file)
}
seen[file] = struct{}{}
}
if err := s.Err(); err != nil {
return err
}
// Make sure everything we had a hash for was seen.
for file := range hl.hashes {
if _, ok := seen[file]; !ok {
return fmt.Errorf("file missing for %q", file)
}
}
return nil
}
// write takes a writer and creates an archive with the snapshot metadata,
// the snapshot itself, and adds some integrity checking information.
func write(out io.Writer, metadata *raft.SnapshotMeta, snap io.Reader) error {
// Start a new tarball.
now := time.Now()
archive := tar.NewWriter(out)
// Create a hash list that we will use to write a SHA256SUMS file into
// the archive.
hl := newHashList()
// Encode the snapshot metadata, which we need to feed back during a
// restore.
metaHash := hl.Add("meta.json")
var metaBuffer bytes.Buffer
enc := json.NewEncoder(&metaBuffer)
if err := enc.Encode(metadata); err != nil {
return fmt.Errorf("failed to encode snapshot metadata: %v", err)
}
if err := archive.WriteHeader(&tar.Header{
Name: "meta.json",
Mode: 0600,
Size: int64(metaBuffer.Len()),
ModTime: now,
}); err != nil {
return fmt.Errorf("failed to write snapshot metadata header: %v", err)
}
if _, err := io.Copy(archive, io.TeeReader(&metaBuffer, metaHash)); err != nil {
return fmt.Errorf("failed to write snapshot metadata: %v", err)
}
// Copy the snapshot data given the size from the metadata.
snapHash := hl.Add("state.bin")
if err := archive.WriteHeader(&tar.Header{
Name: "state.bin",
Mode: 0600,
Size: metadata.Size,
ModTime: now,
}); err != nil {
return fmt.Errorf("failed to write snapshot data header: %v", err)
}
if _, err := io.CopyN(archive, io.TeeReader(snap, snapHash), metadata.Size); err != nil {
return fmt.Errorf("failed to write snapshot metadata: %v", err)
}
// Create a SHA256SUMS file that we can use to verify on restore.
var shaBuffer bytes.Buffer
if err := hl.Encode(&shaBuffer); err != nil {
return fmt.Errorf("failed to encode snapshot hashes: %v", err)
}
if err := archive.WriteHeader(&tar.Header{
Name: "SHA256SUMS",
Mode: 0600,
Size: int64(shaBuffer.Len()),
ModTime: now,
}); err != nil {
return fmt.Errorf("failed to write snapshot hashes header: %v", err)
}
if _, err := io.Copy(archive, &shaBuffer); err != nil {
return fmt.Errorf("failed to write snapshot metadata: %v", err)
}
// Finalize the archive.
if err := archive.Close(); err != nil {
return fmt.Errorf("failed to finalize snapshot: %v", err)
}
return nil
}
// read takes a reader and extracts the snapshot metadata and the snapshot
// itself, and also checks the integrity of the data.
func read(in io.Reader, metadata *raft.SnapshotMeta, snap io.Writer) error {
// Start a new tar reader.
archive := tar.NewReader(in)
// Create a hash list that we will use to compare with the SHA256SUMS
// file in the archive.
hl := newHashList()
// Populate the hashes for all the files we expect to see. The check at
// the end will make sure these are all present in the SHA256SUMS file
// and that the hashes match.
metaHash := hl.Add("meta.json")
snapHash := hl.Add("state.bin")
// Look through the archive for the pieces we care about.
var shaBuffer bytes.Buffer
for {
hdr, err := archive.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed reading snapshot: %v", err)
}
switch hdr.Name {
case "meta.json":
// Previously we used json.Decode to decode the archive stream. There are
// edge cases in which it doesn't read all the bytes from the stream, even
// though the JSON object is still parsed properly. Since we simultaneously
// fed everything to metaHash, our hash ended up being different from the one
// calculated when creating the snapshot, which in turn made the snapshot
// verification fail. By explicitly reading the whole thing first we ensure
// that we calculate the correct hash independent of how json.Decode works
// internally.
buf, err := ioutil.ReadAll(io.TeeReader(archive, metaHash))
if err != nil {
return fmt.Errorf("failed to read snapshot metadata: %v", err)
}
if err := json.Unmarshal(buf, &metadata); err != nil {
return fmt.Errorf("failed to decode snapshot metadata: %v", err)
}
case "state.bin":
if _, err := io.Copy(io.MultiWriter(snap, snapHash), archive); err != nil {
return fmt.Errorf("failed to read or write snapshot data: %v", err)
}
case "SHA256SUMS":
if _, err := io.Copy(&shaBuffer, archive); err != nil {
return fmt.Errorf("failed to read snapshot hashes: %v", err)
}
default:
return fmt.Errorf("unexpected file %q in snapshot", hdr.Name)
}
}
// Verify all the hashes.
if err := hl.DecodeAndVerify(&shaBuffer); err != nil {
return fmt.Errorf("failed checking integrity of snapshot: %v", err)
}
return nil
}
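As a quick sanity check of the layout described at the top of this file, a throwaway sketch (the file name is an assumption) that lists the entries of a saved snapshot; the save path in helper/snapshot/snapshot.go wraps the tar in gzip, so it is decompressed first:

package main

import (
    "archive/tar"
    "compress/gzip"
    "fmt"
    "io"
    "log"
    "os"
)

func main() {
    f, err := os.Open("backup.snap")
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    gz, err := gzip.NewReader(f)
    if err != nil {
        log.Fatal(err)
    }
    defer gz.Close()

    // Expect exactly meta.json, state.bin, and SHA256SUMS.
    tr := tar.NewReader(gz)
    for {
        hdr, err := tr.Next()
        if err == io.EOF {
            break
        }
        if err != nil {
            log.Fatal(err)
        }
        fmt.Printf("%-12s %d bytes\n", hdr.Name, hdr.Size)
    }
}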

View File

@@ -0,0 +1,153 @@
package snapshot
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"io/ioutil"
"os"
"reflect"
"strings"
"testing"
"github.com/hashicorp/raft"
)
func TestArchive(t *testing.T) {
// Create some fake snapshot data.
metadata := raft.SnapshotMeta{
Index: 2005,
Term: 2011,
Configuration: raft.Configuration{
Servers: []raft.Server{
raft.Server{
Suffrage: raft.Voter,
ID: raft.ServerID("hello"),
Address: raft.ServerAddress("127.0.0.1:8300"),
},
},
},
Size: 1024,
}
var snap bytes.Buffer
var expected bytes.Buffer
both := io.MultiWriter(&snap, &expected)
if _, err := io.Copy(both, io.LimitReader(rand.Reader, 1024)); err != nil {
t.Fatalf("err: %v", err)
}
// Write out the snapshot.
var archive bytes.Buffer
if err := write(&archive, &metadata, &snap); err != nil {
t.Fatalf("err: %v", err)
}
// Read the snapshot back.
var newMeta raft.SnapshotMeta
var newSnap bytes.Buffer
if err := read(&archive, &newMeta, &newSnap); err != nil {
t.Fatalf("err: %v", err)
}
// Check the contents.
if !reflect.DeepEqual(newMeta, metadata) {
t.Fatalf("bad: %#v", newMeta)
}
var buf bytes.Buffer
if _, err := io.Copy(&buf, &newSnap); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf.Bytes(), expected.Bytes()) {
t.Fatalf("snapshot contents didn't match")
}
}
func TestArchive_GoodData(t *testing.T) {
paths := []string{
"./testdata/snapshot/spaces-meta.tar",
}
for i, p := range paths {
f, err := os.Open(p)
if err != nil {
t.Fatalf("err: %v", err)
}
defer f.Close()
var metadata raft.SnapshotMeta
err = read(f, &metadata, ioutil.Discard)
if err != nil {
t.Fatalf("case %d: should've read the snapshot, but didn't: %v", i, err)
}
}
}
func TestArchive_BadData(t *testing.T) {
cases := []struct {
Name string
Error string
}{
{"./testdata/snapshot/empty.tar", "failed checking integrity of snapshot"},
{"./testdata/snapshot/extra.tar", "unexpected file \"nope\""},
{"./testdata/snapshot/missing-meta.tar", "hash check failed for \"meta.json\""},
{"./testdata/snapshot/missing-state.tar", "hash check failed for \"state.bin\""},
{"./testdata/snapshot/missing-sha.tar", "file missing"},
{"./testdata/snapshot/corrupt-meta.tar", "hash check failed for \"meta.json\""},
{"./testdata/snapshot/corrupt-state.tar", "hash check failed for \"state.bin\""},
{"./testdata/snapshot/corrupt-sha.tar", "list missing hash for \"nope\""},
}
for i, c := range cases {
f, err := os.Open(c.Name)
if err != nil {
t.Fatalf("err: %v", err)
}
defer f.Close()
var metadata raft.SnapshotMeta
err = read(f, &metadata, ioutil.Discard)
if err == nil || !strings.Contains(err.Error(), c.Error) {
t.Fatalf("case %d (%s): %v", i, c.Name, err)
}
}
}
func TestArchive_hashList(t *testing.T) {
hl := newHashList()
for i := 0; i < 16; i++ {
h := hl.Add(fmt.Sprintf("file-%d", i))
if _, err := io.CopyN(h, rand.Reader, 32); err != nil {
t.Fatalf("err: %v", err)
}
}
// Do a normal round trip.
var buf bytes.Buffer
if err := hl.Encode(&buf); err != nil {
t.Fatalf("err: %v", err)
}
if err := hl.DecodeAndVerify(&buf); err != nil {
t.Fatalf("err: %v", err)
}
// Have a local hash that isn't in the file.
buf.Reset()
if err := hl.Encode(&buf); err != nil {
t.Fatalf("err: %v", err)
}
hl.Add("nope")
err := hl.DecodeAndVerify(&buf)
if err == nil || !strings.Contains(err.Error(), "file missing for \"nope\"") {
t.Fatalf("err: %v", err)
}
// Have a hash in the file that we haven't seen locally.
buf.Reset()
if err := hl.Encode(&buf); err != nil {
t.Fatalf("err: %v", err)
}
delete(hl.hashes, "nope")
err = hl.DecodeAndVerify(&buf)
if err == nil || !strings.Contains(err.Error(), "list missing hash for \"nope\"") {
t.Fatalf("err: %v", err)
}
}

245
helper/snapshot/snapshot.go Normal file
View File

@@ -0,0 +1,245 @@
// snapshot manages the interactions between Nomad and Raft in order to take
// and restore snapshots for disaster recovery. The internal format of a
// snapshot is simply a tar file, as described in archive.go.
package snapshot
import (
"compress/gzip"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/raft"
)
// Snapshot is a structure that holds state about a temporary file that is used
// to hold a snapshot. By using an intermediate file we avoid holding everything
// in memory.
type Snapshot struct {
file *os.File
index uint64
checksum string
}
// New takes a state snapshot of the given Raft instance into a temporary file
// and returns an object that gives access to the file as an io.Reader. You must
// arrange to call Close() on the returned object or else you will leak a
// temporary file.
func New(logger hclog.Logger, r *raft.Raft) (*Snapshot, error) {
// Take the snapshot.
future := r.Snapshot()
if err := future.Error(); err != nil {
return nil, fmt.Errorf("Raft error when taking snapshot: %v", err)
}
// Open up the snapshot.
metadata, snap, err := future.Open()
if err != nil {
return nil, fmt.Errorf("failed to open snapshot: %v:", err)
}
defer func() {
if err := snap.Close(); err != nil {
logger.Error("Failed to close Raft snapshot", "error", err)
}
}()
// Make a scratch file to receive the contents so that we don't buffer
// everything in memory. This gets deleted in Close() since we keep it
// around for re-reading.
archive, err := ioutil.TempFile("", "snapshot")
if err != nil {
return nil, fmt.Errorf("failed to create snapshot file: %v", err)
}
// If anything goes wrong after this point, we will attempt to clean up
// the temp file. The happy path will disarm this.
var keep bool
defer func() {
if keep {
return
}
if err := os.Remove(archive.Name()); err != nil {
logger.Error("Failed to clean up temp snapshot", "error", err)
}
}()
hash := sha256.New()
out := io.MultiWriter(hash, archive)
// Wrap the file writer in a gzip compressor.
compressor := gzip.NewWriter(out)
// Write the archive.
if err := write(compressor, metadata, snap); err != nil {
return nil, fmt.Errorf("failed to write snapshot file: %v", err)
}
// Finish the compressed stream.
if err := compressor.Close(); err != nil {
return nil, fmt.Errorf("failed to compress snapshot file: %v", err)
}
// Sync the compressed file and rewind it so it's ready to be streamed
// out by the caller.
if err := archive.Sync(); err != nil {
return nil, fmt.Errorf("failed to sync snapshot: %v", err)
}
if _, err := archive.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to rewind snapshot: %v", err)
}
checksum := "sha-256=" + base64.StdEncoding.EncodeToString(hash.Sum(nil))
keep = true
return &Snapshot{archive, metadata.Index, checksum}, nil
}
// Index returns the index of the snapshot. This is safe to call on a nil
// snapshot; it will just return 0.
func (s *Snapshot) Index() uint64 {
if s == nil {
return 0
}
return s.index
}
func (s *Snapshot) Checksum() string {
if s == nil {
return ""
}
return s.checksum
}
// Read passes through to the underlying snapshot file. This is safe to call on
// a nil snapshot; it will just return io.EOF.
func (s *Snapshot) Read(p []byte) (n int, err error) {
if s == nil {
return 0, io.EOF
}
return s.file.Read(p)
}
// Close closes the snapshot and removes any temporary storage associated with
// it. You must arrange to call this whenever New() has been called
// successfully. This is safe to call on a nil snapshot.
func (s *Snapshot) Close() error {
if s == nil {
return nil
}
if err := s.file.Close(); err != nil {
return err
}
return os.Remove(s.file.Name())
}
// Verify takes the snapshot from the reader and verifies its contents.
func Verify(in io.Reader) (*raft.SnapshotMeta, error) {
// Wrap the reader in a gzip decompressor.
decomp, err := gzip.NewReader(in)
if err != nil {
return nil, fmt.Errorf("failed to decompress snapshot: %v", err)
}
defer decomp.Close()
// Read the archive, throwing away the snapshot data.
var metadata raft.SnapshotMeta
if err := read(decomp, &metadata, ioutil.Discard); err != nil {
return nil, fmt.Errorf("failed to read snapshot file: %v", err)
}
if err := concludeGzipRead(decomp); err != nil {
return nil, err
}
return &metadata, nil
}
// concludeGzipRead should be invoked after you think you've consumed all of
// the data from the gzip stream. It will error if the stream was corrupt.
//
// The docs for gzip.Reader say: "Clients should treat data returned by Read as
// tentative until they receive the io.EOF marking the end of the data."
func concludeGzipRead(decomp *gzip.Reader) error {
extra, err := ioutil.ReadAll(decomp) // ReadAll consumes the EOF
if err != nil {
return err
} else if len(extra) != 0 {
return fmt.Errorf("%d unread uncompressed bytes remain", len(extra))
}
return nil
}
type readWrapper struct {
in io.Reader
c int
}
func (r *readWrapper) Read(b []byte) (int, error) {
n, err := r.in.Read(b)
r.c += n
if err != nil && err != io.EOF {
return n, fmt.Errorf("failed to read after %v: %v", r.c, err)
}
return n, err
}
// Restore takes the snapshot from the reader and attempts to apply it to the
// given Raft instance.
func Restore(logger hclog.Logger, in io.Reader, r *raft.Raft) error {
// Wrap the reader in a gzip decompressor.
decomp, err := gzip.NewReader(&readWrapper{in, 0})
if err != nil {
return fmt.Errorf("failed to decompress snapshot: %v", err)
}
defer func() {
if err := decomp.Close(); err != nil {
logger.Error("Failed to close snapshot decompressor", "error", err)
}
}()
// Make a scratch file to receive the contents of the snapshot data so
// we can avoid buffering in memory.
snap, err := ioutil.TempFile("", "snapshot")
if err != nil {
return fmt.Errorf("failed to create temp snapshot file: %v", err)
}
defer func() {
if err := snap.Close(); err != nil {
logger.Error("Failed to close temp snapshot", "error", err)
}
if err := os.Remove(snap.Name()); err != nil {
logger.Error("Failed to clean up temp snapshot", "error", err)
}
}()
// Read the archive.
var metadata raft.SnapshotMeta
if err := read(decomp, &metadata, snap); err != nil {
return fmt.Errorf("failed to read snapshot file: %v", err)
}
if err := concludeGzipRead(decomp); err != nil {
return err
}
// Sync and rewind the file so it's ready to be read again.
if err := snap.Sync(); err != nil {
return fmt.Errorf("failed to sync temp snapshot: %v", err)
}
if _, err := snap.Seek(0, 0); err != nil {
return fmt.Errorf("failed to rewind temp snapshot: %v", err)
}
// Feed the snapshot into Raft.
if err := r.Restore(&metadata, snap, 0); err != nil {
return fmt.Errorf("Raft error when restoring snapshot: %v", err)
}
return nil
}
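Tying the three entry points together, a minimal sketch of a save/verify/restore cycle from calling code; the Raft instances, logger, and file name are assumptions. Verify and Restore both consume a reader from the beginning, so the stream is persisted to a seekable file first:

package snapshotexample

import (
    "io"
    "os"

    "github.com/hashicorp/go-hclog"
    "github.com/hashicorp/nomad/helper/snapshot"
    "github.com/hashicorp/raft"
)

func backupAndRestore(logger hclog.Logger, source, target *raft.Raft) error {
    // Take the snapshot from the source cluster into a temp file managed by Snapshot.
    snap, err := snapshot.New(logger, source)
    if err != nil {
        return err
    }
    defer snap.Close()

    f, err := os.Create("backup.snap")
    if err != nil {
        return err
    }
    defer f.Close()
    if _, err := io.Copy(f, snap); err != nil {
        return err
    }

    // Verify the archive integrity before attempting a restore.
    if _, err := f.Seek(0, 0); err != nil {
        return err
    }
    meta, err := snapshot.Verify(f)
    if err != nil {
        return err
    }
    logger.Info("verified snapshot", "index", meta.Index, "term", meta.Term)

    // Feed the same file into the target cluster.
    if _, err := f.Seek(0, 0); err != nil {
        return err
    }
    return snapshot.Restore(logger, f, target)
}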

View File

@@ -0,0 +1,349 @@
package snapshot
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/require"
)
// MockFSM is a simple FSM for testing that simply stores its logs in a slice of
// byte slices.
type MockFSM struct {
sync.Mutex
logs [][]byte
}
// MockSnapshot is a snapshot sink for testing that encodes the contents of a
// MockFSM using msgpack.
type MockSnapshot struct {
logs [][]byte
maxIndex int
}
// See raft.FSM.
func (m *MockFSM) Apply(log *raft.Log) interface{} {
m.Lock()
defer m.Unlock()
m.logs = append(m.logs, log.Data)
return len(m.logs)
}
// See raft.FSM.
func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) {
m.Lock()
defer m.Unlock()
return &MockSnapshot{m.logs, len(m.logs)}, nil
}
// See raft.FSM.
func (m *MockFSM) Restore(in io.ReadCloser) error {
m.Lock()
defer m.Unlock()
defer in.Close()
dec := codec.NewDecoder(in, structs.MsgpackHandle)
m.logs = nil
return dec.Decode(&m.logs)
}
// See raft.SnapshotSink.
func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error {
enc := codec.NewEncoder(sink, structs.MsgpackHandle)
if err := enc.Encode(m.logs[:m.maxIndex]); err != nil {
sink.Cancel()
return err
}
sink.Close()
return nil
}
// See raft.SnapshotSink.
func (m *MockSnapshot) Release() {
}
// makeRaft returns a Raft and its FSM, with snapshots based in the given dir.
func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) {
snaps, err := raft.NewFileSnapshotStore(dir, 5, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
fsm := &MockFSM{}
store := raft.NewInmemStore()
addr, trans := raft.NewInmemTransport("")
config := raft.DefaultConfig()
config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr))
var members raft.Configuration
members.Servers = append(members.Servers, raft.Server{
Suffrage: raft.Voter,
ID: config.LocalID,
Address: addr,
})
err = raft.BootstrapCluster(config, store, store, snaps, trans, members)
if err != nil {
t.Fatalf("err: %v", err)
}
raft, err := raft.NewRaft(config, fsm, store, store, snaps, trans)
if err != nil {
t.Fatalf("err: %v", err)
}
timeout := time.After(10 * time.Second)
for {
if raft.Leader() != "" {
break
}
select {
case <-raft.LeaderCh():
case <-time.After(1 * time.Second):
// Need to poll because we might have missed the first
// go with the leader channel.
case <-timeout:
t.Fatalf("timed out waiting for leader")
}
}
return raft, fsm
}
func TestSnapshot(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
// Make a Raft and populate it with some data. We tee everything we
// apply off to a buffer for checking post-snapshot.
var expected []bytes.Buffer
entries := 64 * 1024
before, _ := makeRaft(t, filepath.Join(dir, "before"))
defer before.Shutdown()
for i := 0; i < entries; i++ {
var log bytes.Buffer
var copy bytes.Buffer
both := io.MultiWriter(&log, &copy)
if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
t.Fatalf("err: %v", err)
}
future := before.Apply(log.Bytes(), time.Second)
if err := future.Error(); err != nil {
t.Fatalf("err: %v", err)
}
expected = append(expected, copy)
}
// Take a snapshot.
logger := testutil.Logger(t)
snap, err := New(logger, before)
if err != nil {
t.Fatalf("err: %v", err)
}
defer snap.Close()
// Verify the snapshot. We have to rewind it after for the restore.
metadata, err := Verify(snap)
if err != nil {
t.Fatalf("err: %v", err)
}
if _, err := snap.file.Seek(0, 0); err != nil {
t.Fatalf("err: %v", err)
}
if int(metadata.Index) != entries+2 {
t.Fatalf("bad: %d", metadata.Index)
}
if metadata.Term != 2 {
t.Fatalf("bad: %d", metadata.Index)
}
if metadata.Version != raft.SnapshotVersionMax {
t.Fatalf("bad: %d", metadata.Version)
}
// Make a new, independent Raft.
after, fsm := makeRaft(t, filepath.Join(dir, "after"))
defer after.Shutdown()
// Put some initial data in there that the snapshot should overwrite.
for i := 0; i < 16; i++ {
var log bytes.Buffer
if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
t.Fatalf("err: %v", err)
}
future := after.Apply(log.Bytes(), time.Second)
if err := future.Error(); err != nil {
t.Fatalf("err: %v", err)
}
}
// Restore the snapshot.
if err := Restore(logger, snap, after); err != nil {
t.Fatalf("err: %v", err)
}
// Compare the contents.
fsm.Lock()
defer fsm.Unlock()
if len(fsm.logs) != len(expected) {
t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
}
for i := range fsm.logs {
if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
t.Fatalf("bad: log %d doesn't match", i)
}
}
}
func TestSnapshot_Nil(t *testing.T) {
var snap *Snapshot
if idx := snap.Index(); idx != 0 {
t.Fatalf("bad: %d", idx)
}
n, err := snap.Read(make([]byte, 16))
if n != 0 || err != io.EOF {
t.Fatalf("bad: %d %v", n, err)
}
if err := snap.Close(); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestSnapshot_BadVerify(t *testing.T) {
buf := bytes.NewBuffer([]byte("nope"))
_, err := Verify(buf)
if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
t.Fatalf("err: %v", err)
}
}
func TestSnapshot_TruncatedVerify(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
// Make a Raft and populate it with some data. We tee everything we
// apply off to a buffer for checking post-snapshot.
var expected []bytes.Buffer
entries := 64 * 1024
before, _ := makeRaft(t, filepath.Join(dir, "before"))
defer before.Shutdown()
for i := 0; i < entries; i++ {
var log bytes.Buffer
var copy bytes.Buffer
both := io.MultiWriter(&log, &copy)
_, err := io.CopyN(both, rand.Reader, 256)
require.NoError(t, err)
future := before.Apply(log.Bytes(), time.Second)
require.NoError(t, future.Error())
expected = append(expected, copy)
}
// Take a snapshot.
logger := testutil.Logger(t)
snap, err := New(logger, before)
require.NoError(t, err)
defer snap.Close()
var data []byte
{
var buf bytes.Buffer
_, err = io.Copy(&buf, snap)
require.NoError(t, err)
data = buf.Bytes()
}
for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
// Lop off part of the end.
buf := bytes.NewReader(data[0 : len(data)-removeBytes])
_, err = Verify(buf)
require.Error(t, err)
})
}
}
func TestSnapshot_BadRestore(t *testing.T) {
dir := testutil.TempDir(t, "snapshot")
defer os.RemoveAll(dir)
// Make a Raft and populate it with some data.
before, _ := makeRaft(t, filepath.Join(dir, "before"))
defer before.Shutdown()
for i := 0; i < 16*1024; i++ {
var log bytes.Buffer
if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
t.Fatalf("err: %v", err)
}
future := before.Apply(log.Bytes(), time.Second)
if err := future.Error(); err != nil {
t.Fatalf("err: %v", err)
}
}
// Take a snapshot.
logger := testutil.Logger(t)
snap, err := New(logger, before)
if err != nil {
t.Fatalf("err: %v", err)
}
// Make a new, independent Raft.
after, fsm := makeRaft(t, filepath.Join(dir, "after"))
defer after.Shutdown()
// Put some initial data in there that should not be harmed by the
// failed restore attempt.
var expected []bytes.Buffer
for i := 0; i < 16; i++ {
var log bytes.Buffer
var copy bytes.Buffer
both := io.MultiWriter(&log, &copy)
if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
t.Fatalf("err: %v", err)
}
future := after.Apply(log.Bytes(), time.Second)
if err := future.Error(); err != nil {
t.Fatalf("err: %v", err)
}
expected = append(expected, copy)
}
// Attempt to restore a truncated version of the snapshot. This is
// expected to fail.
err = Restore(logger, io.LimitReader(snap, 512), after)
if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
t.Fatalf("err: %v", err)
}
// Compare the contents to make sure the aborted restore didn't harm
// anything.
fsm.Lock()
defer fsm.Unlock()
if len(fsm.logs) != len(expected) {
t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
}
for i := range fsm.logs {
if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
t.Fatalf("bad: log %d doesn't match", i)
}
}
}

Binary files not shown (the nine snapshot test archives under helper/snapshot/testdata/snapshot/).

View File

@@ -2,11 +2,14 @@ package nomad
import (
"fmt"
"io"
"net"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/consul/agent/consul/autopilot"
"github.com/hashicorp/nomad/helper/snapshot"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/raft"
"github.com/hashicorp/serf/serf"
@@ -18,6 +21,10 @@ type Operator struct {
logger log.Logger
}
func (op *Operator) register() {
op.srv.streamingRpcs.Register("Operator.SnapshotSave", op.snapshotSave)
}
// RaftGetConfiguration is used to retrieve the current Raft configuration.
func (op *Operator) RaftGetConfiguration(args *structs.GenericRequest, reply *structs.RaftConfigurationResponse) error {
if done, err := op.srv.forward("Operator.RaftGetConfiguration", args, args, reply); done {
@@ -355,3 +362,111 @@ func (op *Operator) SchedulerGetConfiguration(args *structs.GenericRequest, repl
return nil
}
func (op *Operator) forwardStreamingRPC(region string, method string, args interface{}, in io.ReadWriteCloser) error {
server, err := op.srv.findRegionServer(region)
if err != nil {
return err
}
return op.forwardStreamingRPCToServer(server, method, args, in)
}
func (op *Operator) forwardStreamingRPCToServer(server *serverParts, method string, args interface{}, in io.ReadWriteCloser) error {
srvConn, err := op.srv.streamingRpc(server, method)
if err != nil {
return err
}
defer srvConn.Close()
outEncoder := codec.NewEncoder(srvConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
return err
}
structs.Bridge(in, srvConn)
return nil
}
func (op *Operator) snapshotSave(conn io.ReadWriteCloser) {
defer conn.Close()
var args structs.SnapshotSaveRequest
var reply structs.SnapshotSaveResponse
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
handleFailure := func(code int, err error) {
encoder.Encode(&structs.SnapshotSaveResponse{
ErrorCode: code,
ErrorMsg: err.Error(),
})
}
if err := decoder.Decode(&args); err != nil {
handleFailure(500, err)
return
}
// Forward to appropriate region
if args.Region != op.srv.Region() {
err := op.forwardStreamingRPC(args.Region, "Operator.SnapshotSave", args, conn)
if err != nil {
handleFailure(500, err)
}
return
}
// forward to leader
if !args.AllowStale {
remoteServer, err := op.srv.getLeaderForRPC()
if err != nil {
handleFailure(500, err)
return
}
if remoteServer != nil {
err := op.forwardStreamingRPCToServer(remoteServer, "Operator.SnapshotSave", args, conn)
if err != nil {
handleFailure(500, err)
}
return
}
}
// Check agent permissions
if aclObj, err := op.srv.ResolveToken(args.AuthToken); err != nil {
code := 500
if err == structs.ErrTokenNotFound {
code = 400
}
handleFailure(code, err)
return
} else if aclObj != nil && !aclObj.IsManagement() {
handleFailure(403, structs.ErrPermissionDenied)
return
}
op.srv.setQueryMeta(&reply.QueryMeta)
// Take the snapshot and capture the index.
snap, err := snapshot.New(op.logger.Named("snapshot"), op.srv.raft)
reply.SnapshotChecksum = snap.Checksum()
reply.Index = snap.Index()
if err != nil {
handleFailure(500, err)
return
}
defer snap.Close()
enc := codec.NewEncoder(conn, structs.MsgpackHandle)
if err := enc.Encode(&reply); err != nil {
handleFailure(500, fmt.Errorf("failed to encode response: %v", err))
return
}
if snap != nil {
if _, err := io.Copy(conn, snap); err != nil {
handleFailure(500, fmt.Errorf("failed to stream snapshot: %v", err))
}
}
}
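For reference, the client side of this streaming exchange is small; a minimal sketch over an already-established "Operator.SnapshotSave" connection (conn and the output writer are assumptions; this mirrors what the HTTP handler above and the RPC tests below do):

package example

import (
    "fmt"
    "io"

    "github.com/hashicorp/go-msgpack/codec"
    "github.com/hashicorp/nomad/nomad/structs"
)

// saveSnapshot drives one Operator.SnapshotSave exchange over conn and writes
// the resulting snapshot archive to out.
func saveSnapshot(conn io.ReadWriteCloser, out io.Writer) error {
    encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
    decoder := codec.NewDecoder(conn, structs.MsgpackHandle)

    // Send the request header.
    var req structs.SnapshotSaveRequest
    req.Region = "global"
    if err := encoder.Encode(&req); err != nil {
        return err
    }

    // Read the response header; errors come back in-band in the response
    // rather than as RPC errors.
    var resp structs.SnapshotSaveResponse
    if err := decoder.Decode(&resp); err != nil {
        return err
    }
    if resp.ErrorMsg != "" {
        return fmt.Errorf("remote error %d: %s", resp.ErrorCode, resp.ErrorMsg)
    }

    // Everything after the header is the snapshot file content; resp.SnapshotChecksum
    // can be used to validate it, as the HTTP handler does via the Digest header.
    _, err := io.Copy(out, conn)
    return err
}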

View File

@@ -1,14 +1,24 @@
package nomad
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"path"
"reflect"
"strings"
"testing"
"github.com/hashicorp/go-msgpack/codec"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/nomad/acl"
"github.com/hashicorp/nomad/helper/freeport"
"github.com/hashicorp/nomad/helper/snapshot"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
@@ -521,3 +531,186 @@ func TestOperator_SchedulerSetConfiguration_ACL(t *testing.T) {
}
}
func TestOperator_SnapshotSave(t *testing.T) {
t.Parallel()
// Nomad cluster topology (not specific to this test)
dir, err := ioutil.TempDir("", "nomadtest-operator-")
require.NoError(t, err)
defer os.RemoveAll(dir)
server1, cleanupLS := TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "server1")
})
defer cleanupLS()
server2, cleanupRS := TestServer(t, func(c *Config) {
c.BootstrapExpect = 2
c.DevMode = false
c.DataDir = path.Join(dir, "server2")
})
defer cleanupRS()
remoteRegionServer, cleanupRRS := TestServer(t, func(c *Config) {
c.Region = "two"
c.DevMode = false
c.DataDir = path.Join(dir, "remote_region_server")
})
defer cleanupRRS()
TestJoin(t, server1, server2)
TestJoin(t, server1, remoteRegionServer)
testutil.WaitForLeader(t, server1.RPC)
testutil.WaitForLeader(t, server2.RPC)
testutil.WaitForLeader(t, remoteRegionServer.RPC)
leader, nonLeader := server1, server2
if server2.IsLeader() {
leader, nonLeader = server2, server1
}
///////// Actually run query now
cases := []struct {
name string
server *Server
}{
{"leader", leader},
{"non_leader", nonLeader},
{"remote_region", remoteRegionServer},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
handler, err := c.server.StreamingRpcHandler("Operator.SnapshotSave")
require.NoError(t, err)
p1, p2 := net.Pipe()
defer p1.Close()
defer p2.Close()
// start handler
go handler(p2)
var req structs.SnapshotSaveRequest
var resp structs.SnapshotSaveResponse
req.Region = "global"
// send request
encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
err = encoder.Encode(&req)
require.NoError(t, err)
decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
err = decoder.Decode(&resp)
require.NoError(t, err)
require.Empty(t, resp.ErrorMsg)
require.NotZero(t, resp.Index)
require.NotEmpty(t, resp.SnapshotChecksum)
require.Contains(t, resp.SnapshotChecksum, "sha-256=")
index := resp.Index
snap, err := ioutil.TempFile("", "nomadtests-snapshot-")
require.NoError(t, err)
defer os.Remove(snap.Name())
hash := sha256.New()
_, err = io.Copy(io.MultiWriter(snap, hash), p1)
require.NoError(t, err)
expectedChecksum := "sha-256=" + base64.StdEncoding.EncodeToString(hash.Sum(nil))
require.Equal(t, expectedChecksum, resp.SnapshotChecksum)
_, err = snap.Seek(0, 0)
require.NoError(t, err)
meta, err := snapshot.Verify(snap)
require.NoError(t, err)
require.NotZerof(t, meta.Term, "snapshot term")
require.Equal(t, index, meta.Index)
})
}
}
func TestOperator_SnapshotSave_ACL(t *testing.T) {
t.Parallel()
// Nomad cluster topology (not specific to this test)
dir, err := ioutil.TempDir("", "nomadtest-operator-")
require.NoError(t, err)
defer os.RemoveAll(dir)
s, root, cleanupLS := TestACLServer(t, func(c *Config) {
c.BootstrapExpect = 1
c.DevMode = false
c.DataDir = path.Join(dir, "server1")
})
defer cleanupLS()
testutil.WaitForLeader(t, s.RPC)
deniedToken := mock.CreatePolicyAndToken(t, s.fsm.State(), 1001, "test-invalid", mock.NodePolicy(acl.PolicyWrite))
///////// Actually run query now
cases := []struct {
name string
token string
errCode int
err error
}{
{"root", root.SecretID, 0, nil},
{"no_permission_token", deniedToken.SecretID, 403, structs.ErrPermissionDenied},
{"invalid token", uuid.Generate(), 400, structs.ErrTokenNotFound},
{"unauthenticated", "", 403, structs.ErrPermissionDenied},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
handler, err := s.StreamingRpcHandler("Operator.SnapshotSave")
require.NoError(t, err)
p1, p2 := net.Pipe()
defer p1.Close()
defer p2.Close()
// start handler
go handler(p2)
var req structs.SnapshotSaveRequest
var resp structs.SnapshotSaveResponse
req.Region = "global"
req.AuthToken = c.token
// send request
encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
err = encoder.Encode(&req)
require.NoError(t, err)
decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
err = decoder.Decode(&resp)
require.NoError(t, err)
// streaming errors appear as a response rather than a returned error
if c.err != nil {
require.Equal(t, c.err.Error(), resp.ErrorMsg)
require.Equal(t, c.errCode, resp.ErrorCode)
return
}
require.NotZero(t, resp.Index)
require.NotEmpty(t, resp.SnapshotChecksum)
require.Contains(t, resp.SnapshotChecksum, "sha-256=")
io.Copy(ioutil.Discard, p1)
})
}
}

View File

@@ -507,8 +507,6 @@ func (r *rpcHandler) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCt
// forward is used to forward to a remote region or to forward to the local leader
// Returns a bool of if forwarding was performed, as well as any error
func (r *rpcHandler) forward(method string, info structs.RPCInfo, args interface{}, reply interface{}) (bool, error) {
var firstCheck time.Time
region := info.RequestRegion()
if region == "" {
return true, fmt.Errorf("missing region for target RPC")
@@ -527,21 +525,41 @@ func (r *rpcHandler) forward(method string, info structs.RPCInfo, args interface
return false, nil
}
remoteServer, err := r.getLeaderForRPC()
if err != nil {
return true, err
}
// we are the leader
if remoteServer == nil {
return false, nil
}
// forward to leader
info.SetForwarded()
err = r.forwardLeader(remoteServer, method, args, reply)
return true, err
}
// getLeaderForRPC returns the server info of the currently known leader, or
// nil if this server is the current leader. If the local server is the leader
// it blocks until it is ready to handle consistent RPC invocations. If leader
// is not known or consistency isn't guaranteed, an error is returned.
func (r *rpcHandler) getLeaderForRPC() (*serverParts, error) {
var firstCheck time.Time
CHECK_LEADER:
// Find the leader
isLeader, remoteServer := r.getLeader()
// Handle the case we are the leader
if isLeader && r.Server.isReadyForConsistentReads() {
return false, nil
return nil, nil
}
// Handle the case of a known leader
if remoteServer != nil {
// Mark that we are forwarding the RPC
info.SetForwarded()
err := r.forwardLeader(remoteServer, method, args, reply)
return true, err
return remoteServer, nil
}
// Gate the request until there is a leader
@@ -559,10 +577,11 @@ CHECK_LEADER:
// hold time exceeded without being ready to respond
if isLeader {
return true, structs.ErrNotReadyForConsistentReads
return nil, structs.ErrNotReadyForConsistentReads
}
return true, structs.ErrNoLeader
return nil, structs.ErrNoLeader
}
// getLeader returns if the current node is the leader, and if not
@@ -607,21 +626,27 @@ func (r *rpcHandler) forwardServer(server *serverParts, method string, args inte
return r.connPool.RPC(r.config.Region, server.Addr, server.MajorVersion, method, args, reply)
}
// forwardRegion is used to forward an RPC call to a remote region, or fail if no servers
func (r *rpcHandler) forwardRegion(region, method string, args interface{}, reply interface{}) error {
// Bail if we can't find any servers
func (r *rpcHandler) findRegionServer(region string) (*serverParts, error) {
r.peerLock.RLock()
defer r.peerLock.RUnlock()
servers := r.peers[region]
if len(servers) == 0 {
r.peerLock.RUnlock()
r.logger.Warn("no path found to region", "region", region)
return structs.ErrNoRegionPath
return nil, structs.ErrNoRegionPath
}
// Select a random addr
offset := rand.Intn(len(servers))
server := servers[offset]
r.peerLock.RUnlock()
return servers[offset], nil
}
// forwardRegion is used to forward an RPC call to a remote region, or fail if no servers
func (r *rpcHandler) forwardRegion(region, method string, args interface{}, reply interface{}) error {
server, err := r.findRegionServer(region)
if err != nil {
return err
}
// Forward to remote Nomad
metrics.IncrCounter([]string{"nomad", "rpc", "cross-region", region}, 1)

View File

@@ -1131,6 +1131,8 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
s.staticEndpoints.CSIPlugin = &CSIPlugin{srv: s, logger: s.logger.Named("csi_plugin")}
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
s.staticEndpoints.Operator = &Operator{srv: s, logger: s.logger.Named("operator")}
s.staticEndpoints.Operator.register()
s.staticEndpoints.Periodic = &Periodic{srv: s, logger: s.logger.Named("periodic")}
s.staticEndpoints.Plan = &Plan{srv: s, logger: s.logger.Named("plan")}
s.staticEndpoints.Region = &Region{srv: s, logger: s.logger.Named("region")}

View File

@@ -224,3 +224,25 @@ type SchedulerSetConfigRequest struct {
// WriteRequest holds the ACL token to go along with this request.
WriteRequest
}
// SnapshotSaveRequest is used by the Operator endpoint to get a Raft snapshot
type SnapshotSaveRequest struct {
QueryOptions
}
// SnapshotSaveResponse is the header for the streaming snapshot endpoint and
// is followed by the snapshot file content.
type SnapshotSaveResponse struct {
// SnapshotChecksum is the checksum of the snapshot file, in the format
// `<algo>=<base64>` (e.g. `sha-256=...`)
SnapshotChecksum string
// ErrorCode is an http error code if an error is found, e.g. 403 for permission errors
ErrorCode int `codec:",omitempty"`
// ErrorMsg is the error message if an error is found, e.g. "Permission Denied"
ErrorMsg string `codec:",omitempty"`
QueryMeta
}