mirror of
https://github.com/kemko/nomad.git
synced 2026-01-03 17:05:43 +03:00
Merge pull request #8047 from hashicorp/f-snapshot-save
API for atomic snapshot backups
This commit is contained in:
81
api/ioutil.go
Normal file
81
api/ioutil.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// errMismatchChecksum is returned by checksumValidatingReader.Read when the
// underlying stream ends and the computed checksum differs from the expected
// one.
var errMismatchChecksum = fmt.Errorf("mismatch checksum")

// checksumValidatingReader is a wrapper reader that validates
// the checksum of the underlying reader.
type checksumValidatingReader struct {
	r io.ReadCloser

	// algo is the hash algorithm (e.g. `sha-256`)
	algo string

	// checksum is the base64 component of checksum
	checksum string

	// hash is the hashing function used to compute the checksum
	hash hash.Hash
}

// newChecksumValidatingReader returns a checksum-validating wrapper reader,
// according to a digest received in an HTTP header.
//
// The digest must be in the format "<algo>=<base64 of hash>"
// (e.g. "sha-256=gPelGB7..."). Supported algorithms are "sha-256", "sha-512"
// and "md5"; an error is returned for a malformed digest or for any other
// algorithm.
//
// When the reader is fully consumed (i.e. EOF is encountered), if the
// checksums don't match, `Read` returns a checksum mismatch error.
func newChecksumValidatingReader(r io.ReadCloser, digest string) (io.ReadCloser, error) {
	parts := strings.SplitN(digest, "=", 2)
	if len(parts) != 2 {
		return nil, fmt.Errorf("invalid digest format")
	}

	algo := parts[0]
	var h hash.Hash
	switch algo {
	case "sha-256":
		h = sha256.New()
	case "sha-512":
		h = sha512.New()
	case "md5":
		h = md5.New()
	default:
		// Without a hash the first Read would dereference a nil interface,
		// so reject unknown algorithms up front.
		return nil, fmt.Errorf("unsupported hash algorithm: %q", algo)
	}

	return &checksumValidatingReader{
		r:        r,
		algo:     algo,
		checksum: parts[1],
		hash:     h,
	}, nil
}

// Read reads from the underlying reader, feeding every byte read into the
// hash. When the underlying reader reports io.EOF or io.ErrClosedPipe, the
// computed checksum is compared with the expected one and errMismatchChecksum
// is returned instead of the original error if they differ. Any other error
// is passed through unchanged.
func (r *checksumValidatingReader) Read(b []byte) (int, error) {
	n, err := r.r.Read(b)
	if n != 0 {
		// hash.Hash documents that Write never returns an error.
		r.hash.Write(b[:n])
	}

	if err == io.EOF || err == io.ErrClosedPipe {
		found := base64.StdEncoding.EncodeToString(r.hash.Sum(nil))
		if found != r.checksum {
			return n, errMismatchChecksum
		}
	}

	return n, err
}

// Close closes the underlying reader.
func (r *checksumValidatingReader) Close() error {
	return r.r.Close()
}
|
||||
87
api/ioutil_test.go
Normal file
87
api/ioutil_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChecksumValidatingReader(t *testing.T) {
|
||||
data := make([]byte, 4096)
|
||||
_, err := rand.Read(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
algo string
|
||||
hash hash.Hash
|
||||
}{
|
||||
{"sha-256", sha256.New()},
|
||||
{"sha-512", sha512.New()},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run("valid: "+c.algo, func(t *testing.T) {
|
||||
_, err := c.hash.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
checksum := c.hash.Sum(nil)
|
||||
digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum)
|
||||
|
||||
r := iotest.HalfReader(bytes.NewReader(data))
|
||||
cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.Copy(ioutil.Discard, cr)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid: "+c.algo, func(t *testing.T) {
|
||||
_, err := c.hash.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
checksum := c.hash.Sum(nil)
|
||||
// mess up checksum
|
||||
checksum[0]++
|
||||
digest := c.algo + "=" + base64.StdEncoding.EncodeToString(checksum)
|
||||
|
||||
r := iotest.HalfReader(bytes.NewReader(data))
|
||||
cr, err := newChecksumValidatingReader(ioutil.NopCloser(r), digest)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.Copy(ioutil.Discard, cr)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, errMismatchChecksum, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumValidatingReader_PropagatesError(t *testing.T) {
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
defer pw.Close()
|
||||
|
||||
expectedErr := fmt.Errorf("some error")
|
||||
|
||||
go func() {
|
||||
pw.Write([]byte("some input"))
|
||||
pw.CloseWithError(expectedErr)
|
||||
}()
|
||||
|
||||
cr, err := newChecksumValidatingReader(pr, "sha-256=aaaa")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.Copy(ioutil.Discard, cr)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, expectedErr, err)
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
@@ -194,6 +196,32 @@ func (op *Operator) SchedulerCASConfiguration(conf *SchedulerConfiguration, q *W
|
||||
return &out, wm, nil
|
||||
}
|
||||
|
||||
// Snapshot is used to capture a snapshot state of a running cluster.
|
||||
// The returned reader that must be consumed fully
|
||||
func (op *Operator) Snapshot(q *QueryOptions) (io.ReadCloser, error) {
|
||||
r, err := op.c.newRequest("GET", "/v1/operator/snapshot")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.setQueryOptions(q)
|
||||
_, resp, err := requireOK(op.c.doRequest(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
digest := resp.Header.Get("Digest")
|
||||
|
||||
cr, err := newChecksumValidatingReader(resp.Body, digest)
|
||||
if err != nil {
|
||||
io.Copy(ioutil.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cr, nil
|
||||
}
|
||||
|
||||
type License struct {
|
||||
// The unique identifier of the license
|
||||
LicenseID string
|
||||
|
||||
@@ -318,6 +318,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) {
|
||||
s.mux.HandleFunc("/v1/operator/raft/", s.wrap(s.OperatorRequest))
|
||||
s.mux.HandleFunc("/v1/operator/autopilot/configuration", s.wrap(s.OperatorAutopilotConfiguration))
|
||||
s.mux.HandleFunc("/v1/operator/autopilot/health", s.wrap(s.OperatorServerHealth))
|
||||
s.mux.HandleFunc("/v1/operator/snapshot", s.wrap(s.SnapshotRequest))
|
||||
|
||||
s.mux.HandleFunc("/v1/system/gc", s.wrap(s.GarbageCollectRequest))
|
||||
s.mux.HandleFunc("/v1/system/reconcile/summaries", s.wrap(s.ReconcileJobSummaries))
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -9,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/go-msgpack/codec"
|
||||
"github.com/hashicorp/nomad/api"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/raft"
|
||||
@@ -283,3 +287,88 @@ func (s *HTTPServer) schedulerUpdateConfig(resp http.ResponseWriter, req *http.R
|
||||
setIndex(resp, reply.Index)
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
func (s *HTTPServer) SnapshotRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||||
switch req.Method {
|
||||
case "GET":
|
||||
return s.snapshotSaveRequest(resp, req)
|
||||
default:
|
||||
return nil, CodedError(405, ErrInvalidMethod)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// snapshotSaveRequest streams a snapshot to the HTTP response by invoking the
// "Operator.SnapshotSave" streaming RPC over an in-memory pipe.
//
// One end of the pipe is handed to the RPC handler; the other end is used here
// to send the msgpack-encoded request, decode the response header, and copy
// the remaining bytes to resp. The checksum reported by the RPC is exposed to
// the HTTP client in the "Digest" response header.
//
// The first return value is always nil; the second is the coded error
// produced while talking to the handler, if any.
func (s *HTTPServer) snapshotSaveRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
	args := &structs.SnapshotSaveRequest{}
	// NOTE(review): a true result from parse ends the request with no value
	// and no error; presumably parse has already handled the response —
	// confirm against parse.
	if s.parse(resp, req, &args.Region, &args.QueryOptions) {
		return nil, nil
	}

	var handler structs.StreamingRpcHandler
	var handlerErr error

	// Prefer the local server's handler; otherwise go through the client's
	// remote streaming handler. An agent with neither cannot serve this.
	if server := s.agent.Server(); server != nil {
		handler, handlerErr = server.StreamingRpcHandler("Operator.SnapshotSave")
	} else if client := s.agent.Client(); client != nil {
		handler, handlerErr = client.RemoteStreamingRpcHandler("Operator.SnapshotSave")
	} else {
		handlerErr = fmt.Errorf("misconfigured connection")
	}

	if handlerErr != nil {
		return nil, CodedError(500, handlerErr.Error())
	}

	// httpPipe is this function's end of the connection; handlerPipe is given
	// to the RPC handler below.
	httpPipe, handlerPipe := net.Pipe()
	decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle)
	encoder := codec.NewEncoder(httpPipe, structs.MsgpackHandle)

	// Create a goroutine that closes the pipe if the connection closes.
	ctx, cancel := context.WithCancel(req.Context())
	defer cancel()
	go func() {
		<-ctx.Done()
		httpPipe.Close()
	}()

	// Buffered so the goroutine below can always deliver its single result
	// without blocking.
	errCh := make(chan HTTPCodedError, 1)
	go func() {
		// Cancelling closes httpPipe via the goroutine above once this side
		// is finished.
		defer cancel()

		// Send the request
		if err := encoder.Encode(args); err != nil {
			errCh <- CodedError(500, err.Error())
			return
		}

		// The handler first answers with a response header describing
		// success or failure and the snapshot checksum.
		var res structs.SnapshotSaveResponse
		if err := decoder.Decode(&res); err != nil {
			errCh <- CodedError(500, err.Error())
			return
		}

		if res.ErrorMsg != "" {
			errCh <- CodedError(res.ErrorCode, res.ErrorMsg)
			return
		}

		resp.Header().Add("Digest", res.SnapshotChecksum)

		// Stream the rest of the pipe to the HTTP client. Errors that merely
		// indicate the pipe ended or was closed are not reported.
		_, err := io.Copy(resp, httpPipe)
		if err != nil &&
			err != io.EOF &&
			!strings.Contains(err.Error(), "closed") &&
			!strings.Contains(err.Error(), "EOF") {
			errCh <- CodedError(500, err.Error())
			return
		}

		errCh <- nil
	}()

	// Run the handler synchronously; when it returns, cancel to close our end
	// of the pipe and then collect the goroutine's result.
	handler(handlerPipe)
	cancel()
	codedErr := <-errCh

	return nil, codedErr
}
|
||||
|
||||
@@ -2,9 +2,15 @@ package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -382,3 +388,39 @@ func TestOperator_SchedulerCASConfiguration(t *testing.T) {
|
||||
require.False(reply.SchedulerConfig.PreemptionConfig.BatchSchedulerEnabled)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOperator_SnapshotSaveRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
////// Nomad clusters topology - not specific to test
|
||||
dir, err := ioutil.TempDir("", "nomadtest-operator-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
httpTest(t, func(c *Config) {
|
||||
c.Server.BootstrapExpect = 1
|
||||
c.DevMode = false
|
||||
c.DataDir = path.Join(dir, "server")
|
||||
c.AdvertiseAddrs.HTTP = "127.0.0.1"
|
||||
c.AdvertiseAddrs.RPC = "127.0.0.1"
|
||||
c.AdvertiseAddrs.Serf = "127.0.0.1"
|
||||
}, func(s *TestAgent) {
|
||||
req, _ := http.NewRequest("GET", "/v1/operator/snapshot", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
_, err := s.Server.SnapshotRequest(resp, req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.Code)
|
||||
|
||||
digest := resp.Header().Get("Digest")
|
||||
require.NotEmpty(t, digest)
|
||||
require.Contains(t, digest, "sha-256=")
|
||||
|
||||
hash := sha256.New()
|
||||
_, err = io.Copy(hash, resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedChecksum := "sha-256=" + base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||
require.Equal(t, digest, expectedChecksum)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -126,8 +126,15 @@ func (a *TestAgent) Start() *TestAgent {
|
||||
|
||||
i := 10
|
||||
|
||||
advertiseAddrs := *a.Config.AdvertiseAddrs
|
||||
RETRY:
|
||||
i--
|
||||
|
||||
// Clear out the advertise addresses such that through retries we
|
||||
// re-normalize the addresses correctly instead of using the values from the
|
||||
// last port selection that had a port conflict.
|
||||
newAddrs := advertiseAddrs
|
||||
a.Config.AdvertiseAddrs = &newAddrs
|
||||
a.pickRandomPorts(a.Config)
|
||||
if a.Config.NodeName == "" {
|
||||
a.Config.NodeName = fmt.Sprintf("Node %d", a.Config.Ports.RPC)
|
||||
@@ -312,15 +319,6 @@ func (a *TestAgent) pickRandomPorts(c *Config) {
|
||||
c.Ports.RPC = ports[1]
|
||||
c.Ports.Serf = ports[2]
|
||||
|
||||
// Clear out the advertise addresses such that through retries we
|
||||
// re-normalize the addresses correctly instead of using the values from the
|
||||
// last port selection that had a port conflict.
|
||||
if c.AdvertiseAddrs != nil {
|
||||
c.AdvertiseAddrs.HTTP = ""
|
||||
c.AdvertiseAddrs.RPC = ""
|
||||
c.AdvertiseAddrs.Serf = ""
|
||||
}
|
||||
|
||||
if err := c.normalizeAddrs(); err != nil {
|
||||
a.T.Fatalf("error normalizing config: %v", err)
|
||||
}
|
||||
|
||||
@@ -502,6 +502,22 @@ func Commands(metaPtr *Meta, agentUi cli.Ui) map[string]cli.CommandFactory {
|
||||
}, nil
|
||||
},
|
||||
|
||||
"operator snapshot": func() (cli.Command, error) {
|
||||
return &OperatorSnapshotCommand{
|
||||
Meta: meta,
|
||||
}, nil
|
||||
},
|
||||
"operator snapshot save": func() (cli.Command, error) {
|
||||
return &OperatorSnapshotSaveCommand{
|
||||
Meta: meta,
|
||||
}, nil
|
||||
},
|
||||
"operator snapshot inspect": func() (cli.Command, error) {
|
||||
return &OperatorSnapshotInspectCommand{
|
||||
Meta: meta,
|
||||
}, nil
|
||||
},
|
||||
|
||||
"plan": func() (cli.Command, error) {
|
||||
return &JobPlanCommand{
|
||||
Meta: meta,
|
||||
|
||||
50
command/operator_snapshot.go
Normal file
50
command/operator_snapshot.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/cli"
|
||||
)
|
||||
|
||||
// OperatorSnapshotCommand is the parent "operator snapshot" command. It does
// no work itself; running it asks the CLI to display help.
type OperatorSnapshotCommand struct {
	Meta
}

// Help returns the long-form usage text for the command, with surrounding
// whitespace trimmed.
func (f *OperatorSnapshotCommand) Help() string {
	helpText := `
Usage: nomad operator snapshot <subcommand> [options]

This command has subcommands for saving and inspecting the state
of the Nomad servers for disaster recovery. These are atomic, point-in-time
snapshots which include jobs, nodes, allocations, periodic jobs, and ACLs.

If ACLs are enabled, a management token must be supplied in order to perform
snapshot operations.

Create a snapshot:

$ nomad operator snapshot save backup.snap

Inspect a snapshot:

$ nomad operator snapshot inspect backup.snap

Run a daemon process that locally saves a snapshot every hour (available only in
Nomad Enterprise) :

$ nomad operator snapshot agent

Please see the individual subcommand help for detailed usage information.
`
	return strings.TrimSpace(helpText)
}

// Synopsis returns the one-line description of the command.
func (f *OperatorSnapshotCommand) Synopsis() string {
	return "Saves and inspects snapshots of Nomad server state"
}

// Name returns the full name of the command.
func (f *OperatorSnapshotCommand) Name() string { return "operator snapshot" }

// Run ignores its arguments and returns cli.RunResultHelp.
func (f *OperatorSnapshotCommand) Run(args []string) int {
	return cli.RunResultHelp
}
|
||||
74
command/operator_snapshot_inspect.go
Normal file
74
command/operator_snapshot_inspect.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/nomad/helper/snapshot"
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
// OperatorSnapshotInspectCommand implements "operator snapshot inspect",
// which prints the metadata of a snapshot file on disk.
type OperatorSnapshotInspectCommand struct {
	Meta
}

// Help returns the long-form usage text for the command, with surrounding
// whitespace trimmed.
func (c *OperatorSnapshotInspectCommand) Help() string {
	helpText := `
Usage: nomad operator snapshot inspect [options] FILE

Displays information about a snapshot file on disk.

To inspect the file "backup.snap":
$ nomad operator snapshot inspect backup.snap
`
	return strings.TrimSpace(helpText)
}

// AutocompleteFlags returns an empty flag set; the command offers no flag
// completions.
func (c *OperatorSnapshotInspectCommand) AutocompleteFlags() complete.Flags {
	return complete.Flags{}
}

// AutocompleteArgs returns complete.PredictNothing; no argument completion
// is offered.
func (c *OperatorSnapshotInspectCommand) AutocompleteArgs() complete.Predictor {
	return complete.PredictNothing
}

// Synopsis returns the one-line description of the command.
func (c *OperatorSnapshotInspectCommand) Synopsis() string {
	return "Displays information about a Nomad snapshot file"
}

// Name returns the full name of the command.
func (c *OperatorSnapshotInspectCommand) Name() string { return "operator snapshot inspect" }
|
||||
|
||||
func (c *OperatorSnapshotInspectCommand) Run(args []string) int {
|
||||
// Check that we either got no filename or exactly one.
|
||||
if len(args) != 1 {
|
||||
c.Ui.Error("This command takes one argument: <filename>")
|
||||
c.Ui.Error(commandErrorText(c))
|
||||
return 1
|
||||
}
|
||||
|
||||
path := args[0]
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error opening snapshot file: %s", err))
|
||||
return 1
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
meta, err := snapshot.Verify(f)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error verifying snapshot: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
output := []string{
|
||||
fmt.Sprintf("ID|%s", meta.ID),
|
||||
fmt.Sprintf("Size|%d", meta.Size),
|
||||
fmt.Sprintf("Index|%d", meta.Index),
|
||||
fmt.Sprintf("Term|%d", meta.Term),
|
||||
fmt.Sprintf("Version|%d", meta.Version),
|
||||
}
|
||||
|
||||
c.Ui.Output(formatList(output))
|
||||
return 0
|
||||
}
|
||||
99
command/operator_snapshot_inspect_test.go
Normal file
99
command/operator_snapshot_inspect_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/nomad/command/agent"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOperatorSnapshotInspect_Works(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
snapPath := generateSnapshotFile(t)
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
cmd := &OperatorSnapshotInspectCommand{Meta: Meta{Ui: ui}}
|
||||
|
||||
code := cmd.Run([]string{snapPath})
|
||||
require.Zero(t, code)
|
||||
|
||||
output := ui.OutputWriter.String()
|
||||
for _, key := range []string{
|
||||
"ID",
|
||||
"Size",
|
||||
"Index",
|
||||
"Term",
|
||||
"Version",
|
||||
} {
|
||||
require.Contains(t, output, key)
|
||||
}
|
||||
|
||||
}
|
||||
func TestOperatorSnapshotInspect_HandlesFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir, err := ioutil.TempDir("", "nomad-clitests-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = ioutil.WriteFile(
|
||||
filepath.Join(tmpDir, "invalid.snap"),
|
||||
[]byte("invalid data"),
|
||||
0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
ui := new(cli.MockUi)
|
||||
cmd := &OperatorSnapshotInspectCommand{Meta: Meta{Ui: ui}}
|
||||
|
||||
code := cmd.Run([]string{filepath.Join(tmpDir, "foo")})
|
||||
require.NotZero(t, code)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "no such file")
|
||||
})
|
||||
|
||||
t.Run("invalid file", func(t *testing.T) {
|
||||
ui := new(cli.MockUi)
|
||||
cmd := &OperatorSnapshotInspectCommand{Meta: Meta{Ui: ui}}
|
||||
|
||||
code := cmd.Run([]string{filepath.Join(tmpDir, "invalid.snap")})
|
||||
require.NotZero(t, code)
|
||||
require.Contains(t, ui.ErrorWriter.String(), "Error verifying snapshot")
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func generateSnapshotFile(t *testing.T) string {
|
||||
|
||||
tmpDir, err := ioutil.TempDir("", "nomad-tempdir")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
srv, _, url := testServer(t, false, func(c *agent.Config) {
|
||||
c.DevMode = false
|
||||
c.DataDir = filepath.Join(tmpDir, "server")
|
||||
|
||||
c.AdvertiseAddrs.HTTP = "127.0.0.1"
|
||||
c.AdvertiseAddrs.RPC = "127.0.0.1"
|
||||
c.AdvertiseAddrs.Serf = "127.0.0.1"
|
||||
})
|
||||
|
||||
defer srv.Shutdown()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
cmd := &OperatorSnapshotSaveCommand{Meta: Meta{Ui: ui}}
|
||||
|
||||
dest := filepath.Join(tmpDir, "backup.snap")
|
||||
code := cmd.Run([]string{
|
||||
"--address=" + url,
|
||||
dest,
|
||||
})
|
||||
require.Zero(t, code)
|
||||
|
||||
return dest
|
||||
}
|
||||
142
command/operator_snapshot_save.go
Normal file
142
command/operator_snapshot_save.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/nomad/api"
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
// OperatorSnapshotSaveCommand implements "operator snapshot save", which
// downloads a snapshot of the server state into a local file.
type OperatorSnapshotSaveCommand struct {
	Meta
}
|
||||
|
||||
func (c *OperatorSnapshotSaveCommand) Help() string {
|
||||
helpText := `
|
||||
Usage: nomad operator snapshot save [options] <filename>
|
||||
|
||||
Retrieves an atomic, point-in-time snapshot of the state of the Nomad servers
|
||||
which includes jobs, nodes, allocations, periodic jobs, and ACLs.
|
||||
|
||||
If ACLs are enabled, a management token must be supplied in order to perform
|
||||
snapshot operations.
|
||||
|
||||
To create a snapshot from the leader server and save it to "backup.snap":
|
||||
|
||||
$ nomad snapshot save backup.snap
|
||||
|
||||
To create a potentially stale snapshot from any available server (useful if no
|
||||
leader is available):
|
||||
|
||||
General Options:
|
||||
|
||||
` + generalOptionsUsage() + `
|
||||
|
||||
Snapshot Save Options:
|
||||
|
||||
-stale=[true|false]
|
||||
The -stale argument defaults to "false" which means the leader provides the
|
||||
result. If the cluster is in an outage state without a leader, you may need
|
||||
to set -stale to "true" to get the configuration from a non-leader server.
|
||||
`
|
||||
return strings.TrimSpace(helpText)
|
||||
}
|
||||
|
||||
// AutocompleteFlags returns the client flag completions merged with a
// completion for the command-specific "-stale" flag.
func (c *OperatorSnapshotSaveCommand) AutocompleteFlags() complete.Flags {
	return mergeAutocompleteFlags(c.Meta.AutocompleteFlags(FlagSetClient),
		complete.Flags{
			"-stale": complete.PredictAnything,
		})
}

// AutocompleteArgs returns complete.PredictNothing; no argument completion
// is offered.
func (c *OperatorSnapshotSaveCommand) AutocompleteArgs() complete.Predictor {
	return complete.PredictNothing
}

// Synopsis returns the one-line description of the command.
func (c *OperatorSnapshotSaveCommand) Synopsis() string {
	return "Saves snapshot of Nomad server state"
}

// Name returns the full name of the command.
func (c *OperatorSnapshotSaveCommand) Name() string { return "operator snapshot save" }
|
||||
|
||||
func (c *OperatorSnapshotSaveCommand) Run(args []string) int {
|
||||
var stale bool
|
||||
|
||||
flags := c.Meta.FlagSet(c.Name(), FlagSetClient)
|
||||
flags.Usage = func() { c.Ui.Output(c.Help()) }
|
||||
|
||||
flags.BoolVar(&stale, "stale", false, "")
|
||||
if err := flags.Parse(args); err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Failed to parse args: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Check for misuse
|
||||
// Check that we either got no filename or exactly one.
|
||||
args = flags.Args()
|
||||
if len(args) > 1 {
|
||||
c.Ui.Error("This command takes either no arguments or one: <filename>")
|
||||
c.Ui.Error(commandErrorText(c))
|
||||
return 1
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
filename := fmt.Sprintf("nomad-state-%04d%02d%0d-%d.snap", now.Year(), now.Month(), now.Day(), now.Unix())
|
||||
|
||||
if len(args) == 1 {
|
||||
filename = args[0]
|
||||
}
|
||||
|
||||
if _, err := os.Lstat(filename); err == nil {
|
||||
c.Ui.Error(fmt.Sprintf("Destination file already exists: %q", filename))
|
||||
c.Ui.Error(commandErrorText(c))
|
||||
return 1
|
||||
} else if !os.IsNotExist(err) {
|
||||
c.Ui.Error(fmt.Sprintf("Unexpected failure checking %q: %v", filename, err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Set up a client.
|
||||
client, err := c.Meta.Client()
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Error initializing client: %s", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
tmpFile, err := os.Create(filename + ".tmp")
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Failed to create file: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
// Fetch the current configuration.
|
||||
q := &api.QueryOptions{
|
||||
AllowStale: stale,
|
||||
}
|
||||
snapIn, err := client.Operator().Snapshot(q)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Failed to get snapshot file: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
defer snapIn.Close()
|
||||
|
||||
_, err = io.Copy(tmpFile, snapIn)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Filed to download snapshot file: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
err = os.Rename(tmpFile.Name(), filename)
|
||||
if err != nil {
|
||||
c.Ui.Error(fmt.Sprintf("Filed to finalize snapshot file: %v", err))
|
||||
return 1
|
||||
}
|
||||
|
||||
c.Ui.Output(fmt.Sprintf("State file written to %v", filename))
|
||||
return 0
|
||||
}
|
||||
51
command/operator_snapshot_save_test.go
Normal file
51
command/operator_snapshot_save_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/nomad/command/agent"
|
||||
"github.com/hashicorp/nomad/helper/snapshot"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOperatorSnapshotSave_Works(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir, err := ioutil.TempDir("", "nomad-tempdir")
|
||||
require.NoError(t, err)
|
||||
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
srv, _, url := testServer(t, false, func(c *agent.Config) {
|
||||
c.DevMode = false
|
||||
c.DataDir = filepath.Join(tmpDir, "server")
|
||||
|
||||
c.AdvertiseAddrs.HTTP = "127.0.0.1"
|
||||
c.AdvertiseAddrs.RPC = "127.0.0.1"
|
||||
c.AdvertiseAddrs.Serf = "127.0.0.1"
|
||||
})
|
||||
|
||||
defer srv.Shutdown()
|
||||
|
||||
ui := new(cli.MockUi)
|
||||
cmd := &OperatorSnapshotSaveCommand{Meta: Meta{Ui: ui}}
|
||||
|
||||
dest := filepath.Join(tmpDir, "backup.snap")
|
||||
code := cmd.Run([]string{
|
||||
"--address=" + url,
|
||||
dest,
|
||||
})
|
||||
require.Zero(t, code)
|
||||
require.Contains(t, ui.OutputWriter.String(), "State file written to "+dest)
|
||||
|
||||
f, err := os.Open(dest)
|
||||
require.NoError(t, err)
|
||||
|
||||
meta, err := snapshot.Verify(f)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, meta.Index)
|
||||
}
|
||||
234
helper/snapshot/archive.go
Normal file
234
helper/snapshot/archive.go
Normal file
@@ -0,0 +1,234 @@
|
||||
// The archive utilities manage the internal format of a snapshot, which is a
|
||||
// tar file with the following contents:
|
||||
//
|
||||
// meta.json - JSON-encoded snapshot metadata from Raft
|
||||
// state.bin - Encoded snapshot data from Raft
|
||||
// SHA256SUMS - SHA-256 sums of the above two files
|
||||
//
|
||||
// The integrity information is automatically created and checked, and a failure
|
||||
// there just looks like an error to the caller.
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/raft"
|
||||
)
|
||||
|
||||
// hashList keeps one running SHA-256 hash per file name.
type hashList struct {
	hashes map[string]hash.Hash
}

// newHashList returns an empty hashList ready for use.
func newHashList() *hashList {
	hl := new(hashList)
	hl.hashes = map[string]hash.Hash{}
	return hl
}

// Add returns the hash registered for file, registering a new SHA-256 hash
// first if the file has none yet.
func (hl *hashList) Add(file string) hash.Hash {
	h, found := hl.hashes[file]
	if !found {
		h = sha256.New()
		hl.hashes[file] = h
	}
	return h
}

// Encode writes one "<hex sum> <file>" line per registered file to w, using
// the current sum of each hash, in the style of a SHA256SUMS text file. It
// stops at the first write error and returns it.
func (hl *hashList) Encode(w io.Writer) error {
	for name, digest := range hl.hashes {
		_, err := fmt.Fprintf(w, "%x %s\n", digest.Sum(nil), name)
		if err != nil {
			return err
		}
	}
	return nil
}

// DecodeAndVerify reads a SHA256SUMS-style text file from r and compares each
// entry with the current sum of the hash registered under the same name. It
// fails if a line cannot be parsed, if an entry names an unregistered file,
// if a sum differs, or if a registered file has no entry in the list.
func (hl *hashList) DecodeAndVerify(r io.Reader) error {
	// First pass: every listed entry must match a registered hash.
	verified := make(map[string]struct{})
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		var (
			sum  = make([]byte, sha256.Size)
			name string
		)
		if _, err := fmt.Sscanf(scanner.Text(), "%x %s", &sum, &name); err != nil {
			return err
		}

		expected, known := hl.hashes[name]
		switch {
		case !known:
			return fmt.Errorf("list missing hash for %q", name)
		case !bytes.Equal(sum, expected.Sum(nil)):
			return fmt.Errorf("hash check failed for %q", name)
		}
		verified[name] = struct{}{}
	}
	if err := scanner.Err(); err != nil {
		return err
	}

	// Second pass: every registered hash must have been listed.
	for name := range hl.hashes {
		if _, ok := verified[name]; ok {
			continue
		}
		return fmt.Errorf("file missing for %q", name)
	}

	return nil
}
|
||||
|
||||
// write takes a writer and creates an archive with the snapshot metadata,
|
||||
// the snapshot itself, and adds some integrity checking information.
|
||||
func write(out io.Writer, metadata *raft.SnapshotMeta, snap io.Reader) error {
|
||||
// Start a new tarball.
|
||||
now := time.Now()
|
||||
archive := tar.NewWriter(out)
|
||||
|
||||
// Create a hash list that we will use to write a SHA256SUMS file into
|
||||
// the archive.
|
||||
hl := newHashList()
|
||||
|
||||
// Encode the snapshot metadata, which we need to feed back during a
|
||||
// restore.
|
||||
metaHash := hl.Add("meta.json")
|
||||
var metaBuffer bytes.Buffer
|
||||
enc := json.NewEncoder(&metaBuffer)
|
||||
if err := enc.Encode(metadata); err != nil {
|
||||
return fmt.Errorf("failed to encode snapshot metadata: %v", err)
|
||||
}
|
||||
if err := archive.WriteHeader(&tar.Header{
|
||||
Name: "meta.json",
|
||||
Mode: 0600,
|
||||
Size: int64(metaBuffer.Len()),
|
||||
ModTime: now,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to write snapshot metadata header: %v", err)
|
||||
}
|
||||
if _, err := io.Copy(archive, io.TeeReader(&metaBuffer, metaHash)); err != nil {
|
||||
return fmt.Errorf("failed to write snapshot metadata: %v", err)
|
||||
}
|
||||
|
||||
// Copy the snapshot data given the size from the metadata.
|
||||
snapHash := hl.Add("state.bin")
|
||||
if err := archive.WriteHeader(&tar.Header{
|
||||
Name: "state.bin",
|
||||
Mode: 0600,
|
||||
Size: metadata.Size,
|
||||
ModTime: now,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to write snapshot data header: %v", err)
|
||||
}
|
||||
if _, err := io.CopyN(archive, io.TeeReader(snap, snapHash), metadata.Size); err != nil {
|
||||
return fmt.Errorf("failed to write snapshot metadata: %v", err)
|
||||
}
|
||||
|
||||
// Create a SHA256SUMS file that we can use to verify on restore.
|
||||
var shaBuffer bytes.Buffer
|
||||
if err := hl.Encode(&shaBuffer); err != nil {
|
||||
return fmt.Errorf("failed to encode snapshot hashes: %v", err)
|
||||
}
|
||||
if err := archive.WriteHeader(&tar.Header{
|
||||
Name: "SHA256SUMS",
|
||||
Mode: 0600,
|
||||
Size: int64(shaBuffer.Len()),
|
||||
ModTime: now,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to write snapshot hashes header: %v", err)
|
||||
}
|
||||
if _, err := io.Copy(archive, &shaBuffer); err != nil {
|
||||
return fmt.Errorf("failed to write snapshot metadata: %v", err)
|
||||
}
|
||||
|
||||
// Finalize the archive.
|
||||
if err := archive.Close(); err != nil {
|
||||
return fmt.Errorf("failed to finalize snapshot: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// read takes a reader and extracts the snapshot metadata and the snapshot
|
||||
// itself, and also checks the integrity of the data. You must arrange to call
|
||||
// Close() on the returned object or else you will leak a temporary file.
|
||||
func read(in io.Reader, metadata *raft.SnapshotMeta, snap io.Writer) error {
|
||||
// Start a new tar reader.
|
||||
archive := tar.NewReader(in)
|
||||
|
||||
// Create a hash list that we will use to compare with the SHA256SUMS
|
||||
// file in the archive.
|
||||
hl := newHashList()
|
||||
|
||||
// Populate the hashes for all the files we expect to see. The check at
|
||||
// the end will make sure these are all present in the SHA256SUMS file
|
||||
// and that the hashes match.
|
||||
metaHash := hl.Add("meta.json")
|
||||
snapHash := hl.Add("state.bin")
|
||||
|
||||
// Look through the archive for the pieces we care about.
|
||||
var shaBuffer bytes.Buffer
|
||||
for {
|
||||
hdr, err := archive.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading snapshot: %v", err)
|
||||
}
|
||||
|
||||
switch hdr.Name {
|
||||
case "meta.json":
|
||||
// Previously we used json.Decode to decode the archive stream. There are
|
||||
// edgecases in which it doesn't read all the bytes from the stream, even
|
||||
// though the json object is still being parsed properly. Since we
|
||||
// simultaneously feeded everything to metaHash, our hash ended up being
|
||||
// different than what we calculated when creating the snapshot. Which in
|
||||
// turn made the snapshot verification fail. By explicitly reading the
|
||||
// whole thing first we ensure that we calculate the correct hash
|
||||
// independent of how json.Decode works internally.
|
||||
buf, err := ioutil.ReadAll(io.TeeReader(archive, metaHash))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read snapshot metadata: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(buf, &metadata); err != nil {
|
||||
return fmt.Errorf("failed to decode snapshot metadata: %v", err)
|
||||
}
|
||||
|
||||
case "state.bin":
|
||||
if _, err := io.Copy(io.MultiWriter(snap, snapHash), archive); err != nil {
|
||||
return fmt.Errorf("failed to read or write snapshot data: %v", err)
|
||||
}
|
||||
|
||||
case "SHA256SUMS":
|
||||
if _, err := io.Copy(&shaBuffer, archive); err != nil {
|
||||
return fmt.Errorf("failed to read snapshot hashes: %v", err)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected file %q in snapshot", hdr.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all the hashes.
|
||||
if err := hl.DecodeAndVerify(&shaBuffer); err != nil {
|
||||
return fmt.Errorf("failed checking integrity of snapshot: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
153
helper/snapshot/archive_test.go
Normal file
153
helper/snapshot/archive_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/raft"
|
||||
)
|
||||
|
||||
func TestArchive(t *testing.T) {
|
||||
// Create some fake snapshot data.
|
||||
metadata := raft.SnapshotMeta{
|
||||
Index: 2005,
|
||||
Term: 2011,
|
||||
Configuration: raft.Configuration{
|
||||
Servers: []raft.Server{
|
||||
raft.Server{
|
||||
Suffrage: raft.Voter,
|
||||
ID: raft.ServerID("hello"),
|
||||
Address: raft.ServerAddress("127.0.0.1:8300"),
|
||||
},
|
||||
},
|
||||
},
|
||||
Size: 1024,
|
||||
}
|
||||
var snap bytes.Buffer
|
||||
var expected bytes.Buffer
|
||||
both := io.MultiWriter(&snap, &expected)
|
||||
if _, err := io.Copy(both, io.LimitReader(rand.Reader, 1024)); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Write out the snapshot.
|
||||
var archive bytes.Buffer
|
||||
if err := write(&archive, &metadata, &snap); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Read the snapshot back.
|
||||
var newMeta raft.SnapshotMeta
|
||||
var newSnap bytes.Buffer
|
||||
if err := read(&archive, &newMeta, &newSnap); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Check the contents.
|
||||
if !reflect.DeepEqual(newMeta, metadata) {
|
||||
t.Fatalf("bad: %#v", newMeta)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, &newSnap); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(buf.Bytes(), expected.Bytes()) {
|
||||
t.Fatalf("snapshot contents didn't match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArchive_GoodData(t *testing.T) {
|
||||
paths := []string{
|
||||
"./testdata/snapshot/spaces-meta.tar",
|
||||
}
|
||||
for i, p := range paths {
|
||||
f, err := os.Open(p)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var metadata raft.SnapshotMeta
|
||||
err = read(f, &metadata, ioutil.Discard)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: should've read the snapshot, but didn't: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestArchive_BadData(t *testing.T) {
|
||||
cases := []struct {
|
||||
Name string
|
||||
Error string
|
||||
}{
|
||||
{"./testdata/snapshot/empty.tar", "failed checking integrity of snapshot"},
|
||||
{"./testdata/snapshot/extra.tar", "unexpected file \"nope\""},
|
||||
{"./testdata/snapshot/missing-meta.tar", "hash check failed for \"meta.json\""},
|
||||
{"./testdata/snapshot/missing-state.tar", "hash check failed for \"state.bin\""},
|
||||
{"./testdata/snapshot/missing-sha.tar", "file missing"},
|
||||
{"./testdata/snapshot/corrupt-meta.tar", "hash check failed for \"meta.json\""},
|
||||
{"./testdata/snapshot/corrupt-state.tar", "hash check failed for \"state.bin\""},
|
||||
{"./testdata/snapshot/corrupt-sha.tar", "list missing hash for \"nope\""},
|
||||
}
|
||||
for i, c := range cases {
|
||||
f, err := os.Open(c.Name)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var metadata raft.SnapshotMeta
|
||||
err = read(f, &metadata, ioutil.Discard)
|
||||
if err == nil || !strings.Contains(err.Error(), c.Error) {
|
||||
t.Fatalf("case %d (%s): %v", i, c.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestArchive_hashList(t *testing.T) {
|
||||
hl := newHashList()
|
||||
for i := 0; i < 16; i++ {
|
||||
h := hl.Add(fmt.Sprintf("file-%d", i))
|
||||
if _, err := io.CopyN(h, rand.Reader, 32); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Do a normal round trip.
|
||||
var buf bytes.Buffer
|
||||
if err := hl.Encode(&buf); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if err := hl.DecodeAndVerify(&buf); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Have a local hash that isn't in the file.
|
||||
buf.Reset()
|
||||
if err := hl.Encode(&buf); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
hl.Add("nope")
|
||||
err := hl.DecodeAndVerify(&buf)
|
||||
if err == nil || !strings.Contains(err.Error(), "file missing for \"nope\"") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Have a hash in the file that we haven't seen locally.
|
||||
buf.Reset()
|
||||
if err := hl.Encode(&buf); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
delete(hl.hashes, "nope")
|
||||
err = hl.DecodeAndVerify(&buf)
|
||||
if err == nil || !strings.Contains(err.Error(), "list missing hash for \"nope\"") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
245
helper/snapshot/snapshot.go
Normal file
245
helper/snapshot/snapshot.go
Normal file
@@ -0,0 +1,245 @@
|
||||
// Package snapshot manages the interactions between Nomad and Raft in order to take
|
||||
// and restore snapshots for disaster recovery. The internal format of a
|
||||
// snapshot is simply a tar file, as described in archive.go.
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/raft"
|
||||
)
|
||||
|
||||
// Snapshot holds the state of a snapshot archive that lives in a temporary
// file. Going through an intermediate file means the archive never has to be
// held in memory in its entirety.
type Snapshot struct {
	// file is the temporary file that contains the snapshot archive.
	file *os.File

	// index is the Raft index taken from the snapshot's metadata.
	index uint64

	// checksum is the digest of the archive as written to file, in the form
	// "sha-256=<base64>".
	checksum string
}
|
||||
|
||||
// New takes a state snapshot of the given Raft instance into a temporary file
|
||||
// and returns an object that gives access to the file as an io.Reader. You must
|
||||
// arrange to call Close() on the returned object or else you will leak a
|
||||
// temporary file.
|
||||
func New(logger hclog.Logger, r *raft.Raft) (*Snapshot, error) {
|
||||
// Take the snapshot.
|
||||
future := r.Snapshot()
|
||||
if err := future.Error(); err != nil {
|
||||
return nil, fmt.Errorf("Raft error when taking snapshot: %v", err)
|
||||
}
|
||||
|
||||
// Open up the snapshot.
|
||||
metadata, snap, err := future.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open snapshot: %v:", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := snap.Close(); err != nil {
|
||||
logger.Error("Failed to close Raft snapshot", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Make a scratch file to receive the contents so that we don't buffer
|
||||
// everything in memory. This gets deleted in Close() since we keep it
|
||||
// around for re-reading.
|
||||
archive, err := ioutil.TempFile("", "snapshot")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create snapshot file: %v", err)
|
||||
}
|
||||
|
||||
// If anything goes wrong after this point, we will attempt to clean up
|
||||
// the temp file. The happy path will disarm this.
|
||||
var keep bool
|
||||
defer func() {
|
||||
if keep {
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.Remove(archive.Name()); err != nil {
|
||||
logger.Error("Failed to clean up temp snapshot", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
hash := sha256.New()
|
||||
out := io.MultiWriter(hash, archive)
|
||||
|
||||
// Wrap the file writer in a gzip compressor.
|
||||
compressor := gzip.NewWriter(out)
|
||||
|
||||
// Write the archive.
|
||||
if err := write(compressor, metadata, snap); err != nil {
|
||||
return nil, fmt.Errorf("failed to write snapshot file: %v", err)
|
||||
}
|
||||
|
||||
// Finish the compressed stream.
|
||||
if err := compressor.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to compress snapshot file: %v", err)
|
||||
}
|
||||
|
||||
// Sync the compressed file and rewind it so it's ready to be streamed
|
||||
// out by the caller.
|
||||
if err := archive.Sync(); err != nil {
|
||||
return nil, fmt.Errorf("failed to sync snapshot: %v", err)
|
||||
}
|
||||
if _, err := archive.Seek(0, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to rewind snapshot: %v", err)
|
||||
}
|
||||
|
||||
checksum := "sha-256=" + base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||
|
||||
keep = true
|
||||
return &Snapshot{archive, metadata.Index, checksum}, nil
|
||||
}
|
||||
|
||||
// Index returns the index of the snapshot. This is safe to call on a nil
|
||||
// snapshot, it will just return 0.
|
||||
func (s *Snapshot) Index() uint64 {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.index
|
||||
}
|
||||
|
||||
func (s *Snapshot) Checksum() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.checksum
|
||||
}
|
||||
|
||||
// Read passes through to the underlying snapshot file. This is safe to call on
|
||||
// a nil snapshot, it will just return an EOF.
|
||||
func (s *Snapshot) Read(p []byte) (n int, err error) {
|
||||
if s == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return s.file.Read(p)
|
||||
}
|
||||
|
||||
// Close closes the snapshot and removes any temporary storage associated with
|
||||
// it. You must arrange to call this whenever NewSnapshot() has been called
|
||||
// successfully. This is safe to call on a nil snapshot.
|
||||
func (s *Snapshot) Close() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Remove(s.file.Name())
|
||||
}
|
||||
|
||||
// Verify takes the snapshot from the reader and verifies its contents.
|
||||
func Verify(in io.Reader) (*raft.SnapshotMeta, error) {
|
||||
// Wrap the reader in a gzip decompressor.
|
||||
decomp, err := gzip.NewReader(in)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decompress snapshot: %v", err)
|
||||
}
|
||||
defer decomp.Close()
|
||||
|
||||
// Read the archive, throwing away the snapshot data.
|
||||
var metadata raft.SnapshotMeta
|
||||
if err := read(decomp, &metadata, ioutil.Discard); err != nil {
|
||||
return nil, fmt.Errorf("failed to read snapshot file: %v", err)
|
||||
}
|
||||
|
||||
if err := concludeGzipRead(decomp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// concludeGzipRead drains whatever is left of a gzip stream that the caller
// believes has already been fully consumed. It returns an error if reading
// the remainder fails, or if any uncompressed bytes were still pending.
//
// The docs for gzip.Reader say: "Clients should treat data returned by Read as
// tentative until they receive the io.EOF marking the end of the data."
func concludeGzipRead(decomp *gzip.Reader) error {
	// ReadAll treats io.EOF as a normal end of input, so any error it reports
	// is a genuine failure.
	leftover, err := ioutil.ReadAll(decomp)
	if err != nil {
		return err
	}
	if pending := len(leftover); pending != 0 {
		return fmt.Errorf("%d unread uncompressed bytes remain", pending)
	}
	return nil
}
|
||||
|
||||
// readWrapper wraps a reader, keeps a running total of the bytes read through
// it, and annotates read failures with that total.
type readWrapper struct {
	// in is the underlying reader.
	in io.Reader

	// c is the number of bytes read from in so far.
	c int
}

// Read reads from the underlying reader and adds the bytes read to the
// running total. A nil error and io.EOF are passed through unchanged; any
// other error is replaced by one that also reports the running total.
func (r *readWrapper) Read(b []byte) (int, error) {
	n, err := r.in.Read(b)
	r.c += n

	switch err {
	case nil, io.EOF:
		return n, err
	default:
		return n, fmt.Errorf("failed to read after %v: %v", r.c, err)
	}
}
|
||||
|
||||
// Restore takes the snapshot from the reader and attempts to apply it to the
|
||||
// given Raft instance.
|
||||
func Restore(logger hclog.Logger, in io.Reader, r *raft.Raft) error {
|
||||
// Wrap the reader in a gzip decompressor.
|
||||
decomp, err := gzip.NewReader(&readWrapper{in, 0})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decompress snapshot: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := decomp.Close(); err != nil {
|
||||
logger.Error("Failed to close snapshot decompressor", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Make a scratch file to receive the contents of the snapshot data so
|
||||
// we can avoid buffering in memory.
|
||||
snap, err := ioutil.TempFile("", "snapshot")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp snapshot file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := snap.Close(); err != nil {
|
||||
logger.Error("Failed to close temp snapshot", "error", err)
|
||||
}
|
||||
if err := os.Remove(snap.Name()); err != nil {
|
||||
logger.Error("Failed to clean up temp snapshot", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Read the archive.
|
||||
var metadata raft.SnapshotMeta
|
||||
if err := read(decomp, &metadata, snap); err != nil {
|
||||
return fmt.Errorf("failed to read snapshot file: %v", err)
|
||||
}
|
||||
|
||||
if err := concludeGzipRead(decomp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sync and rewind the file so it's ready to be read again.
|
||||
if err := snap.Sync(); err != nil {
|
||||
return fmt.Errorf("failed to sync temp snapshot: %v", err)
|
||||
}
|
||||
if _, err := snap.Seek(0, 0); err != nil {
|
||||
return fmt.Errorf("failed to rewind temp snapshot: %v", err)
|
||||
}
|
||||
|
||||
// Feed the snapshot into Raft.
|
||||
if err := r.Restore(&metadata, snap, 0); err != nil {
|
||||
return fmt.Errorf("Raft error when restoring snapshot: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
349
helper/snapshot/snapshot_test.go
Normal file
349
helper/snapshot/snapshot_test.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/consul/sdk/testutil"
|
||||
"github.com/hashicorp/go-msgpack/codec"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/raft"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// MockFSM is a minimal FSM for tests: every applied log entry is appended to
// an in-memory slice of byte slices. The embedded mutex guards that slice.
type MockFSM struct {
	sync.Mutex

	// logs holds the data of each applied log entry, in apply order.
	logs [][]byte
}
|
||||
|
||||
// MockSnapshot is the test snapshot produced by MockFSM. When persisted it
// encodes the captured logs using msgpack.
type MockSnapshot struct {
	// logs is the log slice captured from the FSM.
	logs [][]byte

	// maxIndex is the number of entries of logs that belong to the snapshot.
	maxIndex int
}
|
||||
|
||||
// See raft.FSM.
|
||||
func (m *MockFSM) Apply(log *raft.Log) interface{} {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.logs = append(m.logs, log.Data)
|
||||
return len(m.logs)
|
||||
}
|
||||
|
||||
// See raft.FSM.
|
||||
func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return &MockSnapshot{m.logs, len(m.logs)}, nil
|
||||
}
|
||||
|
||||
// See raft.FSM.
|
||||
func (m *MockFSM) Restore(in io.ReadCloser) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
defer in.Close()
|
||||
dec := codec.NewDecoder(in, structs.MsgpackHandle)
|
||||
|
||||
m.logs = nil
|
||||
return dec.Decode(&m.logs)
|
||||
}
|
||||
|
||||
// See raft.SnapshotSink.
|
||||
func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error {
|
||||
enc := codec.NewEncoder(sink, structs.MsgpackHandle)
|
||||
if err := enc.Encode(m.logs[:m.maxIndex]); err != nil {
|
||||
sink.Cancel()
|
||||
return err
|
||||
}
|
||||
sink.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// See raft.SnapshotSink.
|
||||
func (m *MockSnapshot) Release() {
|
||||
}
|
||||
|
||||
// makeRaft returns a Raft and its FSM, with snapshots based in the given dir.
|
||||
func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) {
|
||||
snaps, err := raft.NewFileSnapshotStore(dir, 5, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
fsm := &MockFSM{}
|
||||
store := raft.NewInmemStore()
|
||||
addr, trans := raft.NewInmemTransport("")
|
||||
|
||||
config := raft.DefaultConfig()
|
||||
config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr))
|
||||
|
||||
var members raft.Configuration
|
||||
members.Servers = append(members.Servers, raft.Server{
|
||||
Suffrage: raft.Voter,
|
||||
ID: config.LocalID,
|
||||
Address: addr,
|
||||
})
|
||||
|
||||
err = raft.BootstrapCluster(config, store, store, snaps, trans, members)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
raft, err := raft.NewRaft(config, fsm, store, store, snaps, trans)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
timeout := time.After(10 * time.Second)
|
||||
for {
|
||||
if raft.Leader() != "" {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-raft.LeaderCh():
|
||||
case <-time.After(1 * time.Second):
|
||||
// Need to poll because we might have missed the first
|
||||
// go with the leader channel.
|
||||
case <-timeout:
|
||||
t.Fatalf("timed out waiting for leader")
|
||||
}
|
||||
}
|
||||
|
||||
return raft, fsm
|
||||
}
|
||||
|
||||
func TestSnapshot(t *testing.T) {
|
||||
dir := testutil.TempDir(t, "snapshot")
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Make a Raft and populate it with some data. We tee everything we
|
||||
// apply off to a buffer for checking post-snapshot.
|
||||
var expected []bytes.Buffer
|
||||
entries := 64 * 1024
|
||||
before, _ := makeRaft(t, filepath.Join(dir, "before"))
|
||||
defer before.Shutdown()
|
||||
for i := 0; i < entries; i++ {
|
||||
var log bytes.Buffer
|
||||
var copy bytes.Buffer
|
||||
both := io.MultiWriter(&log, ©)
|
||||
if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
future := before.Apply(log.Bytes(), time.Second)
|
||||
if err := future.Error(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
expected = append(expected, copy)
|
||||
}
|
||||
|
||||
// Take a snapshot.
|
||||
logger := testutil.Logger(t)
|
||||
snap, err := New(logger, before)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer snap.Close()
|
||||
|
||||
// Verify the snapshot. We have to rewind it after for the restore.
|
||||
metadata, err := Verify(snap)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if _, err := snap.file.Seek(0, 0); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if int(metadata.Index) != entries+2 {
|
||||
t.Fatalf("bad: %d", metadata.Index)
|
||||
}
|
||||
if metadata.Term != 2 {
|
||||
t.Fatalf("bad: %d", metadata.Index)
|
||||
}
|
||||
if metadata.Version != raft.SnapshotVersionMax {
|
||||
t.Fatalf("bad: %d", metadata.Version)
|
||||
}
|
||||
|
||||
// Make a new, independent Raft.
|
||||
after, fsm := makeRaft(t, filepath.Join(dir, "after"))
|
||||
defer after.Shutdown()
|
||||
|
||||
// Put some initial data in there that the snapshot should overwrite.
|
||||
for i := 0; i < 16; i++ {
|
||||
var log bytes.Buffer
|
||||
if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
future := after.Apply(log.Bytes(), time.Second)
|
||||
if err := future.Error(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Restore the snapshot.
|
||||
if err := Restore(logger, snap, after); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Compare the contents.
|
||||
fsm.Lock()
|
||||
defer fsm.Unlock()
|
||||
if len(fsm.logs) != len(expected) {
|
||||
t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
|
||||
}
|
||||
for i := range fsm.logs {
|
||||
if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
|
||||
t.Fatalf("bad: log %d doesn't match", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot_Nil(t *testing.T) {
|
||||
var snap *Snapshot
|
||||
|
||||
if idx := snap.Index(); idx != 0 {
|
||||
t.Fatalf("bad: %d", idx)
|
||||
}
|
||||
|
||||
n, err := snap.Read(make([]byte, 16))
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("bad: %d %v", n, err)
|
||||
}
|
||||
|
||||
if err := snap.Close(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot_BadVerify(t *testing.T) {
|
||||
buf := bytes.NewBuffer([]byte("nope"))
|
||||
_, err := Verify(buf)
|
||||
if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot_TruncatedVerify(t *testing.T) {
|
||||
dir := testutil.TempDir(t, "snapshot")
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Make a Raft and populate it with some data. We tee everything we
|
||||
// apply off to a buffer for checking post-snapshot.
|
||||
var expected []bytes.Buffer
|
||||
entries := 64 * 1024
|
||||
before, _ := makeRaft(t, filepath.Join(dir, "before"))
|
||||
defer before.Shutdown()
|
||||
for i := 0; i < entries; i++ {
|
||||
var log bytes.Buffer
|
||||
var copy bytes.Buffer
|
||||
both := io.MultiWriter(&log, ©)
|
||||
|
||||
_, err := io.CopyN(both, rand.Reader, 256)
|
||||
require.NoError(t, err)
|
||||
|
||||
future := before.Apply(log.Bytes(), time.Second)
|
||||
require.NoError(t, future.Error())
|
||||
expected = append(expected, copy)
|
||||
}
|
||||
|
||||
// Take a snapshot.
|
||||
logger := testutil.Logger(t)
|
||||
snap, err := New(logger, before)
|
||||
require.NoError(t, err)
|
||||
defer snap.Close()
|
||||
|
||||
var data []byte
|
||||
{
|
||||
var buf bytes.Buffer
|
||||
_, err = io.Copy(&buf, snap)
|
||||
require.NoError(t, err)
|
||||
data = buf.Bytes()
|
||||
}
|
||||
|
||||
for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
|
||||
t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
|
||||
// Lop off part of the end.
|
||||
buf := bytes.NewReader(data[0 : len(data)-removeBytes])
|
||||
|
||||
_, err = Verify(buf)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshot_BadRestore(t *testing.T) {
|
||||
dir := testutil.TempDir(t, "snapshot")
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// Make a Raft and populate it with some data.
|
||||
before, _ := makeRaft(t, filepath.Join(dir, "before"))
|
||||
defer before.Shutdown()
|
||||
for i := 0; i < 16*1024; i++ {
|
||||
var log bytes.Buffer
|
||||
if _, err := io.CopyN(&log, rand.Reader, 256); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
future := before.Apply(log.Bytes(), time.Second)
|
||||
if err := future.Error(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Take a snapshot.
|
||||
logger := testutil.Logger(t)
|
||||
snap, err := New(logger, before)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Make a new, independent Raft.
|
||||
after, fsm := makeRaft(t, filepath.Join(dir, "after"))
|
||||
defer after.Shutdown()
|
||||
|
||||
// Put some initial data in there that should not be harmed by the
|
||||
// failed restore attempt.
|
||||
var expected []bytes.Buffer
|
||||
for i := 0; i < 16; i++ {
|
||||
var log bytes.Buffer
|
||||
var copy bytes.Buffer
|
||||
both := io.MultiWriter(&log, ©)
|
||||
if _, err := io.CopyN(both, rand.Reader, 256); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
future := after.Apply(log.Bytes(), time.Second)
|
||||
if err := future.Error(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
expected = append(expected, copy)
|
||||
}
|
||||
|
||||
// Attempt to restore a truncated version of the snapshot. This is
|
||||
// expected to fail.
|
||||
err = Restore(logger, io.LimitReader(snap, 512), after)
|
||||
if err == nil || !strings.Contains(err.Error(), "unexpected EOF") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Compare the contents to make sure the aborted restore didn't harm
|
||||
// anything.
|
||||
fsm.Lock()
|
||||
defer fsm.Unlock()
|
||||
if len(fsm.logs) != len(expected) {
|
||||
t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected))
|
||||
}
|
||||
for i := range fsm.logs {
|
||||
if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) {
|
||||
t.Fatalf("bad: log %d doesn't match", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
BIN
helper/snapshot/testdata/snapshot/corrupt-meta.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/corrupt-meta.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/corrupt-sha.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/corrupt-sha.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/corrupt-state.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/corrupt-state.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/empty.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/empty.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/extra.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/extra.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/missing-meta.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/missing-meta.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/missing-sha.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/missing-sha.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/missing-state.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/missing-state.tar
vendored
Normal file
Binary file not shown.
BIN
helper/snapshot/testdata/snapshot/spaces-meta.tar
vendored
Normal file
BIN
helper/snapshot/testdata/snapshot/spaces-meta.tar
vendored
Normal file
Binary file not shown.
@@ -2,11 +2,14 @@ package nomad
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-msgpack/codec"
|
||||
|
||||
"github.com/hashicorp/consul/agent/consul/autopilot"
|
||||
"github.com/hashicorp/nomad/helper/snapshot"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/raft"
|
||||
"github.com/hashicorp/serf/serf"
|
||||
@@ -18,6 +21,10 @@ type Operator struct {
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func (op *Operator) register() {
|
||||
op.srv.streamingRpcs.Register("Operator.SnapshotSave", op.snapshotSave)
|
||||
}
|
||||
|
||||
// RaftGetConfiguration is used to retrieve the current Raft configuration.
|
||||
func (op *Operator) RaftGetConfiguration(args *structs.GenericRequest, reply *structs.RaftConfigurationResponse) error {
|
||||
if done, err := op.srv.forward("Operator.RaftGetConfiguration", args, args, reply); done {
|
||||
@@ -355,3 +362,111 @@ func (op *Operator) SchedulerGetConfiguration(args *structs.GenericRequest, repl
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (op *Operator) forwardStreamingRPC(region string, method string, args interface{}, in io.ReadWriteCloser) error {
|
||||
server, err := op.srv.findRegionServer(region)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return op.forwardStreamingRPCToServer(server, method, args, in)
|
||||
}
|
||||
|
||||
func (op *Operator) forwardStreamingRPCToServer(server *serverParts, method string, args interface{}, in io.ReadWriteCloser) error {
|
||||
srvConn, err := op.srv.streamingRpc(server, method)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srvConn.Close()
|
||||
|
||||
outEncoder := codec.NewEncoder(srvConn, structs.MsgpackHandle)
|
||||
if err := outEncoder.Encode(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
structs.Bridge(in, srvConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (op *Operator) snapshotSave(conn io.ReadWriteCloser) {
|
||||
defer conn.Close()
|
||||
|
||||
var args structs.SnapshotSaveRequest
|
||||
var reply structs.SnapshotSaveResponse
|
||||
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
|
||||
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
|
||||
|
||||
handleFailure := func(code int, err error) {
|
||||
encoder.Encode(&structs.SnapshotSaveResponse{
|
||||
ErrorCode: code,
|
||||
ErrorMsg: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
if err := decoder.Decode(&args); err != nil {
|
||||
handleFailure(500, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward to appropriate region
|
||||
if args.Region != op.srv.Region() {
|
||||
err := op.forwardStreamingRPC(args.Region, "Operator.SnapshotSave", args, conn)
|
||||
if err != nil {
|
||||
handleFailure(500, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// forward to leader
|
||||
if !args.AllowStale {
|
||||
remoteServer, err := op.srv.getLeaderForRPC()
|
||||
if err != nil {
|
||||
handleFailure(500, err)
|
||||
return
|
||||
}
|
||||
if remoteServer != nil {
|
||||
err := op.forwardStreamingRPCToServer(remoteServer, "Operator.SnapshotSave", args, conn)
|
||||
if err != nil {
|
||||
handleFailure(500, err)
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Check agent permissions
|
||||
if aclObj, err := op.srv.ResolveToken(args.AuthToken); err != nil {
|
||||
code := 500
|
||||
if err == structs.ErrTokenNotFound {
|
||||
code = 400
|
||||
}
|
||||
handleFailure(code, err)
|
||||
return
|
||||
} else if aclObj != nil && !aclObj.IsManagement() {
|
||||
handleFailure(403, structs.ErrPermissionDenied)
|
||||
return
|
||||
}
|
||||
|
||||
op.srv.setQueryMeta(&reply.QueryMeta)
|
||||
|
||||
// Take the snapshot and capture the index.
|
||||
snap, err := snapshot.New(op.logger.Named("snapshot"), op.srv.raft)
|
||||
reply.SnapshotChecksum = snap.Checksum()
|
||||
reply.Index = snap.Index()
|
||||
if err != nil {
|
||||
handleFailure(500, err)
|
||||
return
|
||||
}
|
||||
defer snap.Close()
|
||||
|
||||
enc := codec.NewEncoder(conn, structs.MsgpackHandle)
|
||||
if err := enc.Encode(&reply); err != nil {
|
||||
handleFailure(500, fmt.Errorf("failed to encode response: %v", err))
|
||||
return
|
||||
}
|
||||
if snap != nil {
|
||||
if _, err := io.Copy(conn, snap); err != nil {
|
||||
handleFailure(500, fmt.Errorf("failed to stream snapshot: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
package nomad
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-msgpack/codec"
|
||||
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
|
||||
"github.com/hashicorp/nomad/acl"
|
||||
"github.com/hashicorp/nomad/helper/freeport"
|
||||
"github.com/hashicorp/nomad/helper/snapshot"
|
||||
"github.com/hashicorp/nomad/helper/uuid"
|
||||
"github.com/hashicorp/nomad/nomad/mock"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/testutil"
|
||||
@@ -521,3 +531,186 @@ func TestOperator_SchedulerSetConfiguration_ACL(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestOperator_SnapshotSave(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
////// Nomad clusters topology - not specific to test
|
||||
dir, err := ioutil.TempDir("", "nomadtest-operator-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
server1, cleanupLS := TestServer(t, func(c *Config) {
|
||||
c.BootstrapExpect = 2
|
||||
c.DevMode = false
|
||||
c.DataDir = path.Join(dir, "server1")
|
||||
})
|
||||
defer cleanupLS()
|
||||
|
||||
server2, cleanupRS := TestServer(t, func(c *Config) {
|
||||
c.BootstrapExpect = 2
|
||||
c.DevMode = false
|
||||
c.DataDir = path.Join(dir, "server2")
|
||||
})
|
||||
defer cleanupRS()
|
||||
|
||||
remoteRegionServer, cleanupRRS := TestServer(t, func(c *Config) {
|
||||
c.Region = "two"
|
||||
c.DevMode = false
|
||||
c.DataDir = path.Join(dir, "remote_region_server")
|
||||
})
|
||||
defer cleanupRRS()
|
||||
|
||||
TestJoin(t, server1, server2)
|
||||
TestJoin(t, server1, remoteRegionServer)
|
||||
testutil.WaitForLeader(t, server1.RPC)
|
||||
testutil.WaitForLeader(t, server2.RPC)
|
||||
testutil.WaitForLeader(t, remoteRegionServer.RPC)
|
||||
|
||||
leader, nonLeader := server1, server2
|
||||
if server2.IsLeader() {
|
||||
leader, nonLeader = server2, server1
|
||||
}
|
||||
|
||||
///////// Actually run query now
|
||||
cases := []struct {
|
||||
name string
|
||||
server *Server
|
||||
}{
|
||||
{"leader", leader},
|
||||
{"non_leader", nonLeader},
|
||||
{"remote_region", remoteRegionServer},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
handler, err := c.server.StreamingRpcHandler("Operator.SnapshotSave")
|
||||
require.NoError(t, err)
|
||||
|
||||
p1, p2 := net.Pipe()
|
||||
defer p1.Close()
|
||||
defer p2.Close()
|
||||
|
||||
// start handler
|
||||
go handler(p2)
|
||||
|
||||
var req structs.SnapshotSaveRequest
|
||||
var resp structs.SnapshotSaveResponse
|
||||
|
||||
req.Region = "global"
|
||||
|
||||
// send request
|
||||
encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
|
||||
err = encoder.Encode(&req)
|
||||
require.NoError(t, err)
|
||||
|
||||
decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
|
||||
err = decoder.Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.ErrorMsg)
|
||||
|
||||
require.NotZero(t, resp.Index)
|
||||
require.NotEmpty(t, resp.SnapshotChecksum)
|
||||
require.Contains(t, resp.SnapshotChecksum, "sha-256=")
|
||||
|
||||
index := resp.Index
|
||||
|
||||
snap, err := ioutil.TempFile("", "nomadtests-snapshot-")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(snap.Name())
|
||||
|
||||
hash := sha256.New()
|
||||
_, err = io.Copy(io.MultiWriter(snap, hash), p1)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedChecksum := "sha-256=" + base64.StdEncoding.EncodeToString(hash.Sum(nil))
|
||||
|
||||
require.Equal(t, expectedChecksum, resp.SnapshotChecksum)
|
||||
|
||||
_, err = snap.Seek(0, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
meta, err := snapshot.Verify(snap)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotZerof(t, meta.Term, "snapshot term")
|
||||
require.Equal(t, index, meta.Index)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperator_SnapshotSave_ACL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
////// Nomad clusters topology - not specific to test
|
||||
dir, err := ioutil.TempDir("", "nomadtest-operator-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
s, root, cleanupLS := TestACLServer(t, func(c *Config) {
|
||||
c.BootstrapExpect = 1
|
||||
c.DevMode = false
|
||||
c.DataDir = path.Join(dir, "server1")
|
||||
})
|
||||
defer cleanupLS()
|
||||
|
||||
testutil.WaitForLeader(t, s.RPC)
|
||||
|
||||
deniedToken := mock.CreatePolicyAndToken(t, s.fsm.State(), 1001, "test-invalid", mock.NodePolicy(acl.PolicyWrite))
|
||||
|
||||
///////// Actually run query now
|
||||
cases := []struct {
|
||||
name string
|
||||
token string
|
||||
errCode int
|
||||
err error
|
||||
}{
|
||||
{"root", root.SecretID, 0, nil},
|
||||
{"no_permission_token", deniedToken.SecretID, 403, structs.ErrPermissionDenied},
|
||||
{"invalid token", uuid.Generate(), 400, structs.ErrTokenNotFound},
|
||||
{"unauthenticated", "", 403, structs.ErrPermissionDenied},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
handler, err := s.StreamingRpcHandler("Operator.SnapshotSave")
|
||||
require.NoError(t, err)
|
||||
|
||||
p1, p2 := net.Pipe()
|
||||
defer p1.Close()
|
||||
defer p2.Close()
|
||||
|
||||
// start handler
|
||||
go handler(p2)
|
||||
|
||||
var req structs.SnapshotSaveRequest
|
||||
var resp structs.SnapshotSaveResponse
|
||||
|
||||
req.Region = "global"
|
||||
req.AuthToken = c.token
|
||||
|
||||
// send request
|
||||
encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
|
||||
err = encoder.Encode(&req)
|
||||
require.NoError(t, err)
|
||||
|
||||
decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
|
||||
err = decoder.Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// streaming errors appear as a response rather than a returned error
|
||||
if c.err != nil {
|
||||
require.Equal(t, c.err.Error(), resp.ErrorMsg)
|
||||
require.Equal(t, c.errCode, resp.ErrorCode)
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
require.NotZero(t, resp.Index)
|
||||
require.NotEmpty(t, resp.SnapshotChecksum)
|
||||
require.Contains(t, resp.SnapshotChecksum, "sha-256=")
|
||||
|
||||
io.Copy(ioutil.Discard, p1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
57
nomad/rpc.go
57
nomad/rpc.go
@@ -507,8 +507,6 @@ func (r *rpcHandler) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCt
|
||||
// forward is used to forward to a remote region or to forward to the local leader
|
||||
// Returns a bool of if forwarding was performed, as well as any error
|
||||
func (r *rpcHandler) forward(method string, info structs.RPCInfo, args interface{}, reply interface{}) (bool, error) {
|
||||
var firstCheck time.Time
|
||||
|
||||
region := info.RequestRegion()
|
||||
if region == "" {
|
||||
return true, fmt.Errorf("missing region for target RPC")
|
||||
@@ -527,21 +525,41 @@ func (r *rpcHandler) forward(method string, info structs.RPCInfo, args interface
|
||||
return false, nil
|
||||
}
|
||||
|
||||
remoteServer, err := r.getLeaderForRPC()
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
// we are the leader
|
||||
if remoteServer == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// forward to leader
|
||||
info.SetForwarded()
|
||||
err = r.forwardLeader(remoteServer, method, args, reply)
|
||||
return true, err
|
||||
}
|
||||
|
||||
// getLeaderForRPC returns the server info of the currently known leader, or
|
||||
// nil if this server is the current leader. If the local server is the leader
|
||||
// it blocks until it is ready to handle consistent RPC invocations. If leader
|
||||
// is not known or consistency isn't guaranteed, an error is returned.
|
||||
func (r *rpcHandler) getLeaderForRPC() (*serverParts, error) {
|
||||
var firstCheck time.Time
|
||||
|
||||
CHECK_LEADER:
|
||||
// Find the leader
|
||||
isLeader, remoteServer := r.getLeader()
|
||||
|
||||
// Handle the case we are the leader
|
||||
if isLeader && r.Server.isReadyForConsistentReads() {
|
||||
return false, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Handle the case of a known leader
|
||||
if remoteServer != nil {
|
||||
// Mark that we are forwarding the RPC
|
||||
info.SetForwarded()
|
||||
err := r.forwardLeader(remoteServer, method, args, reply)
|
||||
return true, err
|
||||
return remoteServer, nil
|
||||
}
|
||||
|
||||
// Gate the request until there is a leader
|
||||
@@ -559,10 +577,11 @@ CHECK_LEADER:
|
||||
|
||||
// hold time exceeded without being ready to respond
|
||||
if isLeader {
|
||||
return true, structs.ErrNotReadyForConsistentReads
|
||||
return nil, structs.ErrNotReadyForConsistentReads
|
||||
}
|
||||
|
||||
return true, structs.ErrNoLeader
|
||||
return nil, structs.ErrNoLeader
|
||||
|
||||
}
|
||||
|
||||
// getLeader returns if the current node is the leader, and if not
|
||||
@@ -607,21 +626,27 @@ func (r *rpcHandler) forwardServer(server *serverParts, method string, args inte
|
||||
return r.connPool.RPC(r.config.Region, server.Addr, server.MajorVersion, method, args, reply)
|
||||
}
|
||||
|
||||
// forwardRegion is used to forward an RPC call to a remote region, or fail if no servers
|
||||
func (r *rpcHandler) forwardRegion(region, method string, args interface{}, reply interface{}) error {
|
||||
// Bail if we can't find any servers
|
||||
func (r *rpcHandler) findRegionServer(region string) (*serverParts, error) {
|
||||
r.peerLock.RLock()
|
||||
defer r.peerLock.RUnlock()
|
||||
|
||||
servers := r.peers[region]
|
||||
if len(servers) == 0 {
|
||||
r.peerLock.RUnlock()
|
||||
r.logger.Warn("no path found to region", "region", region)
|
||||
return structs.ErrNoRegionPath
|
||||
return nil, structs.ErrNoRegionPath
|
||||
}
|
||||
|
||||
// Select a random addr
|
||||
offset := rand.Intn(len(servers))
|
||||
server := servers[offset]
|
||||
r.peerLock.RUnlock()
|
||||
return servers[offset], nil
|
||||
}
|
||||
|
||||
// forwardRegion is used to forward an RPC call to a remote region, or fail if no servers
|
||||
func (r *rpcHandler) forwardRegion(region, method string, args interface{}, reply interface{}) error {
|
||||
server, err := r.findRegionServer(region)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Forward to remote Nomad
|
||||
metrics.IncrCounter([]string{"nomad", "rpc", "cross-region", region}, 1)
|
||||
|
||||
@@ -1131,6 +1131,8 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
|
||||
s.staticEndpoints.CSIPlugin = &CSIPlugin{srv: s, logger: s.logger.Named("csi_plugin")}
|
||||
s.staticEndpoints.Deployment = &Deployment{srv: s, logger: s.logger.Named("deployment")}
|
||||
s.staticEndpoints.Operator = &Operator{srv: s, logger: s.logger.Named("operator")}
|
||||
s.staticEndpoints.Operator.register()
|
||||
|
||||
s.staticEndpoints.Periodic = &Periodic{srv: s, logger: s.logger.Named("periodic")}
|
||||
s.staticEndpoints.Plan = &Plan{srv: s, logger: s.logger.Named("plan")}
|
||||
s.staticEndpoints.Region = &Region{srv: s, logger: s.logger.Named("region")}
|
||||
|
||||
@@ -224,3 +224,25 @@ type SchedulerSetConfigRequest struct {
|
||||
// WriteRequest holds the ACL token to go along with this request.
|
||||
WriteRequest
|
||||
}
|
||||
|
||||
// SnapshotSaveRequest is used by the Operator endpoint to get a Raft snapshot
|
||||
type SnapshotSaveRequest struct {
|
||||
QueryOptions
|
||||
}
|
||||
|
||||
// SnapshotSaveResponse is the header for the streaming snapshot endpoint,
|
||||
// and followed by the snapshot file content.
|
||||
type SnapshotSaveResponse struct {
|
||||
|
||||
// SnapshotChecksum returns the checksum of snapshot file in the format
|
||||
// `<algo>=<base64>` (e.g. `sha-256=...`)
|
||||
SnapshotChecksum string
|
||||
|
||||
// ErrorCode is an http error code if an error is found, e.g. 403 for permission errors
|
||||
ErrorCode int `codec:",omitempty"`
|
||||
|
||||
// ErrorMsg is the error message if an error is found, e.g. "Permission Denied"
|
||||
ErrorMsg string `codec:",omitempty"`
|
||||
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user