mirror of
https://github.com/kemko/nomad.git
synced 2026-01-07 10:55:42 +03:00
Thread through whether DeriveToken error is recoverable or not
This commit is contained in:
@@ -1714,6 +1714,10 @@ func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vcli
|
||||
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err)
|
||||
return nil, fmt.Errorf("failed to derive vault tokens: %v", err)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", resp.Error)
|
||||
return nil, resp.Error
|
||||
}
|
||||
if resp.Tasks == nil {
|
||||
c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response")
|
||||
return nil, fmt.Errorf("failed to derive vault tokens: invalid response")
|
||||
|
||||
@@ -629,7 +629,7 @@ func (d *DockerDriver) recoverablePullError(err error, image string) error {
|
||||
if imageNotFoundMatcher.MatchString(err.Error()) {
|
||||
recoverable = false
|
||||
}
|
||||
return dstructs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable)
|
||||
return structs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable)
|
||||
}
|
||||
|
||||
func (d *DockerDriver) Periodic() (bool, time.Duration) {
|
||||
|
||||
@@ -37,26 +37,6 @@ func (r *WaitResult) String() string {
|
||||
r.ExitCode, r.Signal, r.Err)
|
||||
}
|
||||
|
||||
// RecoverableError wraps an error and marks whether it is recoverable and could
// be retried or it is fatal.
type RecoverableError struct {
	// Err is the wrapped error. It is stored exactly as passed to
	// NewRecoverableError and may therefore be nil.
	Err error

	// Recoverable is true when the operation that produced Err may be retried.
	Recoverable bool
}

// NewRecoverableError is used to wrap an error and mark it as recoverable or
// not. The returned wrapper is never nil, even when e is nil.
func NewRecoverableError(e error, recoverable bool) *RecoverableError {
	return &RecoverableError{
		Err:         e,
		Recoverable: recoverable,
	}
}

// Error returns the message of the wrapped error. It returns the empty string
// rather than panicking when the receiver or the wrapped error is nil.
func (r *RecoverableError) Error() string {
	// A wrapper built from a nil error (or a nil wrapper stored in an error
	// interface) has no message to report; dereferencing it would panic.
	if r == nil || r.Err == nil {
		return ""
	}
	return r.Err.Error()
}
|
||||
|
||||
// CheckResult encapsulates the result of a check
|
||||
type CheckResult struct {
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
cstructs "github.com/hashicorp/nomad/client/driver/structs"
|
||||
dstructs "github.com/hashicorp/nomad/client/driver/structs"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
)
|
||||
|
||||
@@ -34,7 +34,7 @@ func newRestartTracker(policy *structs.RestartPolicy, jobType string) *RestartTr
|
||||
}
|
||||
|
||||
type RestartTracker struct {
|
||||
waitRes *cstructs.WaitResult
|
||||
waitRes *dstructs.WaitResult
|
||||
startErr error
|
||||
restartTriggered bool // Whether the task has been signalled to be restarted
|
||||
count int // Current number of attempts.
|
||||
@@ -63,7 +63,7 @@ func (r *RestartTracker) SetStartError(err error) *RestartTracker {
|
||||
}
|
||||
|
||||
// SetWaitResult is used to mark the most recent wait result.
|
||||
func (r *RestartTracker) SetWaitResult(res *cstructs.WaitResult) *RestartTracker {
|
||||
func (r *RestartTracker) SetWaitResult(res *dstructs.WaitResult) *RestartTracker {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.waitRes = res
|
||||
@@ -149,7 +149,7 @@ func (r *RestartTracker) GetState() (string, time.Duration) {
|
||||
// infinitely try to start a task.
|
||||
func (r *RestartTracker) handleStartError() (string, time.Duration) {
|
||||
// If the error is not recoverable, do not restart.
|
||||
if rerr, ok := r.startErr.(*cstructs.RecoverableError); !(ok && rerr.Recoverable) {
|
||||
if rerr, ok := r.startErr.(*structs.RecoverableError); !(ok && rerr.Recoverable) {
|
||||
r.reason = ReasonUnrecoverableErrror
|
||||
return structs.TaskNotRestarting, 0
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Fail(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := testPolicy(true, structs.RestartPolicyModeFail)
|
||||
rt := newRestartTracker(p, structs.JobTypeSystem)
|
||||
recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true)
|
||||
recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true)
|
||||
for i := 0; i < p.Attempts; i++ {
|
||||
state, when := rt.SetStartError(recErr).GetState()
|
||||
if state != structs.TaskRestarting {
|
||||
@@ -129,7 +129,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Delay(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := testPolicy(true, structs.RestartPolicyModeDelay)
|
||||
rt := newRestartTracker(p, structs.JobTypeSystem)
|
||||
recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true)
|
||||
recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true)
|
||||
for i := 0; i < p.Attempts; i++ {
|
||||
state, when := rt.SetStartError(recErr).GetState()
|
||||
if state != structs.TaskRestarting {
|
||||
|
||||
@@ -509,10 +509,10 @@ OUTER:
|
||||
// restoring the TaskRunner
|
||||
if token == "" {
|
||||
// Get a token
|
||||
var ok bool
|
||||
token, ok = r.deriveVaultToken()
|
||||
if !ok {
|
||||
// We are shutting down
|
||||
var exit bool
|
||||
token, exit = r.deriveVaultToken()
|
||||
if exit {
|
||||
// Exit the manager
|
||||
return
|
||||
}
|
||||
|
||||
@@ -589,12 +589,20 @@ OUTER:
|
||||
// deriveVaultToken derives the Vault token using exponential backoffs. It
|
||||
// returns the Vault token and whether the manager should exit. It should exit
|
||||
// if the token could not be derived or we are shutting down
|
||||
func (r *TaskRunner) deriveVaultToken() (string, bool) {
|
||||
func (r *TaskRunner) deriveVaultToken() (token string, exit bool) {
|
||||
attempts := 0
|
||||
for {
|
||||
tokens, err := r.vaultClient.DeriveToken(r.alloc, []string{r.task.Name})
|
||||
if err == nil {
|
||||
return tokens[r.task.Name], true
|
||||
return tokens[r.task.Name], false
|
||||
}
|
||||
|
||||
// Check if we can't recover from the error
|
||||
if rerr, ok := err.(*structs.RecoverableError); !ok || !rerr.Recoverable {
|
||||
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v",
|
||||
r.task.Name, r.alloc.ID, err)
|
||||
r.Kill("vault", fmt.Sprintf("failed to derive token: %v", err))
|
||||
return "", true
|
||||
}
|
||||
|
||||
// Handle the retry case
|
||||
@@ -602,14 +610,15 @@ func (r *TaskRunner) deriveVaultToken() (string, bool) {
|
||||
if backoff > vaultBackoffLimit {
|
||||
backoff = vaultBackoffLimit
|
||||
}
|
||||
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v", r.task.Name, r.alloc.ID, err, backoff)
|
||||
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v",
|
||||
r.task.Name, r.alloc.ID, err, backoff)
|
||||
|
||||
attempts++
|
||||
|
||||
// Wait till retrying
|
||||
select {
|
||||
case <-r.waitCh:
|
||||
return "", false
|
||||
return "", true
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
@@ -706,7 +715,7 @@ func (r *TaskRunner) prestart(resultCh chan bool) {
|
||||
if err := getter.GetArtifact(r.getTaskEnv(), artifact, r.taskDir); err != nil {
|
||||
r.setState(structs.TaskStatePending,
|
||||
structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(err))
|
||||
r.restartTracker.SetStartError(dstructs.NewRecoverableError(err, true))
|
||||
r.restartTracker.SetStartError(structs.NewRecoverableError(err, true))
|
||||
goto RESTART
|
||||
}
|
||||
}
|
||||
|
||||
@@ -721,7 +721,7 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
|
||||
}
|
||||
|
||||
count++
|
||||
return nil, fmt.Errorf("Want a retry")
|
||||
return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
|
||||
}
|
||||
tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler
|
||||
go tr.Run()
|
||||
@@ -770,6 +770,49 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
|
||||
alloc := mock.Alloc()
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
task.Driver = "mock_driver"
|
||||
task.Config = map[string]interface{}{
|
||||
"exit_code": "0",
|
||||
"run_for": "10s",
|
||||
}
|
||||
task.Vault = &structs.Vault{
|
||||
Policies: []string{"default"},
|
||||
ChangeMode: structs.VaultChangeModeRestart,
|
||||
}
|
||||
|
||||
upd, tr := testTaskRunnerFromAlloc(false, alloc)
|
||||
tr.MarkReceived()
|
||||
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
|
||||
defer tr.ctx.AllocDir.Destroy()
|
||||
|
||||
// Error the token derivation
|
||||
vc := tr.vaultClient.(*vaultclient.MockVaultClient)
|
||||
vc.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
|
||||
go tr.Run()
|
||||
|
||||
// Wait for the task to start
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
if l := len(upd.events); l != 2 {
|
||||
return false, fmt.Errorf("Expect two events; got %v", l)
|
||||
}
|
||||
|
||||
if upd.events[0].Type != structs.TaskReceived {
|
||||
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
|
||||
}
|
||||
|
||||
if upd.events[1].Type != structs.TaskKilling {
|
||||
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskKilling)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}, func(err error) {
|
||||
t.Fatalf("err: %v", err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskRunner_Template_Block(t *testing.T) {
|
||||
alloc := mock.Alloc()
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/hashicorp/nomad/nomad/state"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/nomad/watch"
|
||||
"github.com/hashicorp/raft"
|
||||
vapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
@@ -940,22 +941,26 @@ func (b *batchFuture) Respond(index uint64, err error) {
|
||||
func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
|
||||
reply *structs.DeriveVaultTokenResponse) error {
|
||||
if done, err := n.srv.forward("Node.DeriveVaultToken", args, args, reply); done {
|
||||
return err
|
||||
reply.Error = structs.NewRecoverableError(err, err == structs.ErrNoLeader)
|
||||
return nil
|
||||
}
|
||||
defer metrics.MeasureSince([]string{"nomad", "client", "derive_vault_token"}, time.Now())
|
||||
|
||||
// Verify the arguments
|
||||
if args.NodeID == "" {
	// Report the failure through the reply, as the other argument checks do,
	// and stop processing: without the return the request would keep going
	// with an empty node ID.
	reply.Error = structs.NewRecoverableError(fmt.Errorf("missing node ID"), false)
	return nil
}
|
||||
if args.SecretID == "" {
|
||||
return fmt.Errorf("missing node SecretID")
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("missing node SecretID"), false)
|
||||
return nil
|
||||
}
|
||||
if args.AllocID == "" {
|
||||
return fmt.Errorf("missing allocation ID")
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("missing allocation ID"), false)
|
||||
return nil
|
||||
}
|
||||
if len(args.Tasks) == 0 {
|
||||
return fmt.Errorf("no tasks specified")
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("no tasks specified"), false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Verify the following:
|
||||
@@ -965,41 +970,51 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
|
||||
// tokens
|
||||
snap, err := n.srv.fsm.State().Snapshot()
|
||||
if err != nil {
|
||||
return err
|
||||
reply.Error = structs.NewRecoverableError(err, false)
|
||||
return nil
|
||||
}
|
||||
node, err := snap.NodeByID(args.NodeID)
|
||||
if err != nil {
|
||||
return err
|
||||
reply.Error = structs.NewRecoverableError(err, false)
|
||||
return nil
|
||||
}
|
||||
if node == nil {
|
||||
return fmt.Errorf("Node %q does not exist", args.NodeID)
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("Node %q does not exist", args.NodeID), false)
|
||||
return nil
|
||||
}
|
||||
if node.SecretID != args.SecretID {
|
||||
return fmt.Errorf("SecretID mismatch")
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("SecretID mismatch"), false)
|
||||
return nil
|
||||
}
|
||||
|
||||
alloc, err := snap.AllocByID(args.AllocID)
|
||||
if err != nil {
|
||||
return err
|
||||
reply.Error = structs.NewRecoverableError(err, false)
|
||||
return nil
|
||||
}
|
||||
if alloc == nil {
|
||||
return fmt.Errorf("Allocation %q does not exist", args.AllocID)
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("Allocation %q does not exist", args.AllocID), false)
|
||||
return nil
|
||||
}
|
||||
if alloc.NodeID != args.NodeID {
|
||||
return fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID)
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID), false)
|
||||
return nil
|
||||
}
|
||||
if alloc.TerminalStatus() {
|
||||
return fmt.Errorf("Can't request Vault token for terminal allocation")
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("Can't request Vault token for terminal allocation"), false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check the policies
|
||||
policies := alloc.Job.VaultPolicies()
|
||||
if policies == nil {
|
||||
return fmt.Errorf("Job doesn't require Vault policies")
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("Job doesn't require Vault policies"), false)
|
||||
return nil
|
||||
}
|
||||
tg, ok := policies[alloc.TaskGroup]
|
||||
if !ok {
|
||||
return fmt.Errorf("Task group does not require Vault policies")
|
||||
reply.Error = structs.NewRecoverableError(fmt.Errorf("Task group does not require Vault policies"), false)
|
||||
return nil
|
||||
}
|
||||
|
||||
var unneeded []string
|
||||
@@ -1011,8 +1026,10 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
|
||||
}
|
||||
|
||||
if len(unneeded) != 0 {
|
||||
return fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s",
|
||||
e := fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s",
|
||||
strings.Join(unneeded, ", "))
|
||||
reply.Error = structs.NewRecoverableError(e, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point the request is valid and we should contact Vault for
|
||||
@@ -1043,7 +1060,13 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
|
||||
|
||||
secret, err := n.srv.vault.CreateToken(ctx, alloc, task)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token for task %q: %v", task, err)
|
||||
wrapped := fmt.Errorf("failed to create token for task %q: %v", task, err)
|
||||
if rerr, ok := err.(*structs.RecoverableError); ok && rerr.Recoverable {
|
||||
// If the error is recoverable, propagate it
|
||||
return structs.NewRecoverableError(wrapped, true)
|
||||
}
|
||||
|
||||
return wrapped
|
||||
}
|
||||
|
||||
results[task] = secret
|
||||
@@ -1068,9 +1091,9 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
|
||||
}()
|
||||
|
||||
// Wait for everything to complete or for an error
|
||||
err = g.Wait()
|
||||
createErr := g.Wait()
|
||||
|
||||
// Commit to Raft before returning any of the tokens
|
||||
// Retrieve the results
|
||||
accessors := make([]*structs.VaultAccessor, 0, len(results))
|
||||
tokens := make(map[string]string, len(results))
|
||||
for task, secret := range results {
|
||||
@@ -1092,20 +1115,36 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
|
||||
}
|
||||
|
||||
// If there was an error revoke the created tokens
|
||||
if err != nil {
|
||||
var mErr multierror.Error
|
||||
mErr.Errors = append(mErr.Errors, err)
|
||||
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, false); err != nil {
|
||||
mErr.Errors = append(mErr.Errors, err)
|
||||
if createErr != nil {
|
||||
if revokeErr := n.srv.vault.RevokeTokens(context.Background(), accessors, false); revokeErr != nil {
|
||||
n.srv.logger.Printf("[ERR] nomad.node: Vault token revocation failed: %v", revokeErr)
|
||||
}
|
||||
return mErr.ErrorOrNil()
|
||||
|
||||
if rerr, ok := createErr.(*structs.RecoverableError); ok {
|
||||
reply.Error = rerr
|
||||
} else {
|
||||
reply.Error = structs.NewRecoverableError(createErr, false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Commit to Raft before returning any of the tokens
|
||||
req := structs.VaultAccessorsRequest{Accessors: accessors}
|
||||
_, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req)
|
||||
if err != nil {
|
||||
n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err)
|
||||
return err
|
||||
|
||||
// Determine if we can recover from the error
|
||||
retry := false
|
||||
switch err {
|
||||
case raft.ErrNotLeader, raft.ErrLeadershipLost, raft.ErrRaftShutdown, raft.ErrEnqueueTimeout:
|
||||
retry = true
|
||||
default:
|
||||
}
|
||||
|
||||
reply.Error = structs.NewRecoverableError(err, retry)
|
||||
return nil
|
||||
}
|
||||
|
||||
reply.Index = index
|
||||
|
||||
@@ -1822,18 +1822,23 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
|
||||
}
|
||||
|
||||
var resp structs.DeriveVaultTokenResponse
|
||||
err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "SecretID mismatch") {
|
||||
t.Fatalf("Expected SecretID mismatch: %v", err)
|
||||
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
|
||||
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "SecretID mismatch") {
|
||||
t.Fatalf("Expected SecretID mismatch: %v", resp.Error)
|
||||
}
|
||||
|
||||
// Put the correct SecretID
|
||||
req.SecretID = node.SecretID
|
||||
|
||||
// Now we should get an error about the allocation not running on the node
|
||||
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "not running on Node") {
|
||||
t.Fatalf("Expected not running on node error: %v", err)
|
||||
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "not running on Node") {
|
||||
t.Fatalf("Expected not running on node error: %v", resp.Error)
|
||||
}
|
||||
|
||||
// Update to be running on the node
|
||||
@@ -1843,9 +1848,11 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
|
||||
}
|
||||
|
||||
// Now we should get an error about the job not needing any Vault secrets
|
||||
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "does not require") {
|
||||
t.Fatalf("Expected no policies error: %v", err)
|
||||
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "does not require") {
|
||||
t.Fatalf("Expected no policies error: %v", resp.Error)
|
||||
}
|
||||
|
||||
// Update to be terminal
|
||||
@@ -1855,9 +1862,11 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
|
||||
}
|
||||
|
||||
// Now we should get an error about the allocation being terminal
|
||||
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err == nil || !strings.Contains(err.Error(), "terminal") {
|
||||
t.Fatalf("Expected terminal allocation error: %v", err)
|
||||
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "terminal") {
|
||||
t.Fatalf("Expected terminal allocation error: %v", resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1920,6 +1929,9 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) {
|
||||
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("bad: %v", resp.Error)
|
||||
}
|
||||
|
||||
// Check the state store and ensure that we created a VaultAccessor
|
||||
va, err := state.VaultAccessor(accessor)
|
||||
@@ -1947,3 +1959,59 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) {
|
||||
t.Fatalf("Got %#v; want %#v", va, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientEndpoint_DeriveVaultToken_VaultError(t *testing.T) {
|
||||
s1 := testServer(t, nil)
|
||||
defer s1.Shutdown()
|
||||
state := s1.fsm.State()
|
||||
codec := rpcClient(t, s1)
|
||||
testutil.WaitForLeader(t, s1.RPC)
|
||||
|
||||
// Enable vault and allow authenticated
|
||||
tr := true
|
||||
s1.config.VaultConfig.Enabled = &tr
|
||||
s1.config.VaultConfig.AllowUnauthenticated = &tr
|
||||
|
||||
// Replace the Vault Client on the server
|
||||
tvc := &TestVaultClient{}
|
||||
s1.vault = tvc
|
||||
|
||||
// Create the node
|
||||
node := mock.Node()
|
||||
if err := state.UpsertNode(2, node); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Create an alloc an allocation that has vault policies required
|
||||
alloc := mock.Alloc()
|
||||
alloc.NodeID = node.ID
|
||||
task := alloc.Job.TaskGroups[0].Tasks[0]
|
||||
tasks := []string{task.Name}
|
||||
task.Vault = &structs.Vault{Policies: []string{"a", "b"}}
|
||||
if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Return an error when creating the token
|
||||
tvc.SetCreateTokenError(alloc.ID, task.Name,
|
||||
structs.NewRecoverableError(fmt.Errorf("recover"), true))
|
||||
|
||||
req := &structs.DeriveVaultTokenRequest{
|
||||
NodeID: node.ID,
|
||||
SecretID: node.SecretID,
|
||||
AllocID: alloc.ID,
|
||||
Tasks: tasks,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
Region: "global",
|
||||
},
|
||||
}
|
||||
|
||||
var resp structs.DeriveVaultTokenResponse
|
||||
err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
|
||||
if err != nil {
|
||||
t.Fatalf("bad: %v", err)
|
||||
}
|
||||
if resp.Error == nil || !resp.Error.Recoverable {
|
||||
t.Fatalf("bad: %+v", resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,6 +389,11 @@ type VaultAccessor struct {
|
||||
type DeriveVaultTokenResponse struct {
|
||||
// Tasks is a mapping between the task name and the wrapped token
|
||||
Tasks map[string]string
|
||||
|
||||
// Error stores any error that occurred. Errors are stored here so we can
|
||||
// communicate whether it is retriable
|
||||
Error *RecoverableError
|
||||
|
||||
QueryMeta
|
||||
}
|
||||
|
||||
@@ -3688,3 +3693,27 @@ type KeyringResponse struct {
|
||||
type KeyringRequest struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
// RecoverableError wraps an error and marks whether it is recoverable and could
// be retried or it is fatal.
type RecoverableError struct {
	// Err is the message of the wrapped error.
	// NOTE(review): presumably kept as a string so the value can be encoded in
	// RPC responses - confirm.
	Err string

	// Recoverable is true when the operation that produced the error may be
	// retried.
	Recoverable bool
}

// NewRecoverableError is used to wrap an error and mark it as recoverable or
// not. It returns nil when e is nil.
//
// The nil result is a nil *RecoverableError: assigning it to a variable of
// type error yields a non-nil interface value, so check the pointer before
// converting it.
func NewRecoverableError(e error, recoverable bool) *RecoverableError {
	if e == nil {
		return nil
	}

	return &RecoverableError{
		Err:         e.Error(),
		Recoverable: recoverable,
	}
}

// Error returns the wrapped error message. It returns the empty string for a
// nil receiver so that a nil *RecoverableError held in an error interface does
// not panic when it is printed.
func (r *RecoverableError) Error() string {
	if r == nil {
		return ""
	}
	return r.Err
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -45,6 +46,14 @@ const (
|
||||
// vaultRevocationIntv is the interval at which Vault tokens that failed
|
||||
// initial revocation are retried
|
||||
vaultRevocationIntv = 5 * time.Minute
|
||||
|
||||
// Errors returned by Vault
|
||||
|
||||
// vaultErrInvalidRequest is returned if the request is invalid
|
||||
vaultErrInvalidRequest = "invalid request"
|
||||
|
||||
// vaultErrPermissionDenied is returned if the client is not authorized
|
||||
vaultErrPermissionDenied = "permission denied"
|
||||
)
|
||||
|
||||
// VaultClient is the Servers interface for interfacing with Vault
|
||||
@@ -104,8 +113,11 @@ type vaultClient struct {
|
||||
config *config.VaultConfig
|
||||
|
||||
// connEstablished marks whether we have an established connection to Vault.
|
||||
// It should be accessed using a helper and updated while holding the lock
|
||||
connEstablished int32
|
||||
connEstablished bool
|
||||
|
||||
// connEstablishedErr marks an error that can occur when establishing a
|
||||
// connection
|
||||
connEstablishedErr error
|
||||
|
||||
// token is the raw token used by the client
|
||||
token string
|
||||
@@ -202,7 +214,7 @@ func (v *vaultClient) flush() {
|
||||
|
||||
v.client = nil
|
||||
v.auth = nil
|
||||
v.connEstablished = 0
|
||||
v.connEstablished = false
|
||||
v.token = ""
|
||||
v.tokenData = nil
|
||||
v.revoking = make(map[*structs.VaultAccessor]time.Time)
|
||||
@@ -225,7 +237,7 @@ func (v *vaultClient) SetConfig(config *config.VaultConfig) error {
|
||||
|
||||
if v.config.IsEnabled() {
|
||||
// Stop accepting any new request
|
||||
atomic.StoreInt32(&v.connEstablished, 0)
|
||||
v.connEstablished = false
|
||||
|
||||
// Kill any background routine and create a new tomb
|
||||
v.tomb.Kill(nil)
|
||||
@@ -310,8 +322,8 @@ OUTER:
|
||||
case <-retryTimer.C:
|
||||
// Ensure the API is reachable
|
||||
if _, err := v.client.Sys().InitStatus(); err != nil {
|
||||
v.logger.Printf("[WARN] vault: failed to contact Vault API. Retrying in %v",
|
||||
v.config.ConnectionRetryIntv)
|
||||
v.logger.Printf("[WARN] vault: failed to contact Vault API. Retrying in %v: %v",
|
||||
v.config.ConnectionRetryIntv, err)
|
||||
retryTimer.Reset(v.config.ConnectionRetryIntv)
|
||||
continue OUTER
|
||||
}
|
||||
@@ -323,6 +335,10 @@ OUTER:
|
||||
// Retrieve our token, validate it and parse the lease duration
|
||||
if err := v.parseSelfToken(); err != nil {
|
||||
v.logger.Printf("[ERR] vault: failed to lookup self token and not retrying: %v", err)
|
||||
v.l.Lock()
|
||||
v.connEstablished = false
|
||||
v.connEstablishedErr = err
|
||||
v.l.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -339,7 +355,9 @@ OUTER:
|
||||
v.tomb.Go(wrapNilError(v.renewalLoop))
|
||||
}
|
||||
|
||||
atomic.StoreInt32(&v.connEstablished, 1)
|
||||
v.l.Lock()
|
||||
v.connEstablished = true
|
||||
v.l.Unlock()
|
||||
}
|
||||
|
||||
// renewalLoop runs the renew loop. This should only be called if we are given a
|
||||
@@ -407,7 +425,10 @@ func (v *vaultClient) renewalLoop() {
|
||||
// We have failed to renew the token past its expiration. Stop
|
||||
// renewing with Vault.
|
||||
v.logger.Printf("[ERR] vault: failed to renew Vault token before lease expiration. Shutting down Vault client")
|
||||
atomic.StoreInt32(&v.connEstablished, 0)
|
||||
v.l.Lock()
|
||||
v.connEstablished = false
|
||||
v.connEstablishedErr = err
|
||||
v.l.Unlock()
|
||||
return
|
||||
|
||||
} else if backoff > maxBackoff.Seconds() {
|
||||
@@ -521,36 +542,42 @@ func (v *vaultClient) parseSelfToken() error {
|
||||
}
|
||||
|
||||
// ConnectionEstablished returns whether a connection to Vault has been
|
||||
// established.
|
||||
func (v *vaultClient) ConnectionEstablished() bool {
|
||||
return atomic.LoadInt32(&v.connEstablished) == 1
|
||||
// established and any error that potentially caused it to be false
|
||||
func (v *vaultClient) ConnectionEstablished() (bool, error) {
|
||||
v.l.Lock()
|
||||
defer v.l.Unlock()
|
||||
return v.connEstablished, v.connEstablishedErr
|
||||
}
|
||||
|
||||
// Enabled returns whether the client is active
|
||||
func (v *vaultClient) Enabled() bool {
|
||||
v.l.Lock()
|
||||
defer v.l.Unlock()
|
||||
return v.config.IsEnabled()
|
||||
}
|
||||
|
||||
//
|
||||
// Active returns whether the client is active
|
||||
func (v *vaultClient) Active() bool {
|
||||
return atomic.LoadInt32(&v.active) == 1
|
||||
}
|
||||
|
||||
// CreateToken takes the allocation and task and returns an appropriate Vault
|
||||
// token. The call is rate limited and may be canceled with the passed policy
|
||||
// token. The call is rate limited and may be canceled with the passed policy.
|
||||
// When the error is recoverable, it will be of type RecoverableError
|
||||
func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) {
|
||||
if !v.Enabled() {
|
||||
return nil, fmt.Errorf("Vault integration disabled")
|
||||
}
|
||||
|
||||
if !v.Active() {
|
||||
return nil, fmt.Errorf("Vault client not active")
|
||||
return nil, structs.NewRecoverableError(fmt.Errorf("Vault client not active"), true)
|
||||
}
|
||||
|
||||
// Check if we have established a connection with Vault
|
||||
if !v.ConnectionEstablished() {
|
||||
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
|
||||
if established, err := v.ConnectionEstablished(); !established && err == nil {
|
||||
return nil, structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true)
|
||||
} else if !established {
|
||||
return nil, fmt.Errorf("Connection to Vault failed: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve the Vault block for the task
|
||||
@@ -596,7 +623,19 @@ func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, ta
|
||||
secret, err = v.auth.CreateWithRole(req, v.tokenData.Role)
|
||||
}
|
||||
|
||||
return secret, err
|
||||
// Determine whether it is unrecoverable
|
||||
if err != nil {
|
||||
eStr := err.Error()
|
||||
if strings.Contains(eStr, vaultErrInvalidRequest) ||
|
||||
strings.Contains(eStr, vaultErrPermissionDenied) {
|
||||
return secret, err
|
||||
}
|
||||
|
||||
// The error is recoverable
|
||||
return nil, structs.NewRecoverableError(err, true)
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// LookupToken takes a Vault token and does a lookup against Vault. The call is
|
||||
@@ -611,8 +650,10 @@ func (v *vaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secr
|
||||
}
|
||||
|
||||
// Check if we have established a connection with Vault
|
||||
if !v.ConnectionEstablished() {
|
||||
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
|
||||
if established, err := v.ConnectionEstablished(); !established && err == nil {
|
||||
return nil, structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true)
|
||||
} else if !established {
|
||||
return nil, fmt.Errorf("Connection to Vault failed: %v", err)
|
||||
}
|
||||
|
||||
// Ensure we are under our rate limit
|
||||
@@ -652,7 +693,7 @@ func (v *vaultClient) RevokeTokens(ctx context.Context, accessors []*structs.Vau
|
||||
|
||||
// Check if we have established a connection with Vault. If not just add it
|
||||
// to the queue
|
||||
if !v.ConnectionEstablished() {
|
||||
if established, err := v.ConnectionEstablished(); !established && err == nil {
|
||||
// Only bother tracking it for later revocation if the accessor was
|
||||
// committed
|
||||
if committed {
|
||||
@@ -709,8 +750,10 @@ func (v *vaultClient) parallelRevoke(ctx context.Context, accessors []*structs.V
|
||||
}
|
||||
|
||||
// Check if we have established a connection with Vault
|
||||
if !v.ConnectionEstablished() {
|
||||
return fmt.Errorf("Connection to Vault has not been established. Retry")
|
||||
if established, err := v.ConnectionEstablished(); !established && err == nil {
|
||||
return structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true)
|
||||
} else if !established {
|
||||
return fmt.Errorf("Connection to Vault failed: %v", err)
|
||||
}
|
||||
|
||||
g, pCtx := errgroup.WithContext(ctx)
|
||||
@@ -770,7 +813,7 @@ func (v *vaultClient) revokeDaemon() {
|
||||
case <-v.tomb.Dying():
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
if !v.ConnectionEstablished() {
|
||||
if established, _ := v.ConnectionEstablished(); !established {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package nomad
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -67,7 +68,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) {
|
||||
// Sleep a little while and check that no connection has been established.
|
||||
time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond)
|
||||
|
||||
if client.ConnectionEstablished() {
|
||||
if established, _ := client.ConnectionEstablished(); established {
|
||||
t.Fatalf("ConnectionEstablished() returned true before Vault server started")
|
||||
}
|
||||
|
||||
@@ -417,7 +418,7 @@ func TestVaultClient_CreateToken_Role(t *testing.T) {
|
||||
|
||||
// Set the configs token in a new test role
|
||||
v.Config.Token = testVaultRoleAndToken(v, t, 5)
|
||||
//testVaultRoleAndToken(v, t, 5)
|
||||
|
||||
// Start the client
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
client, err := NewVaultClient(v.Config, logger, nil)
|
||||
@@ -458,6 +459,74 @@ func TestVaultClient_CreateToken_Role(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultClient_CreateToken_Role_InvalidToken(t *testing.T) {
|
||||
v := testutil.NewTestVault(t).Start()
|
||||
defer v.Stop()
|
||||
|
||||
// Set the configs token in a new test role
|
||||
testVaultRoleAndToken(v, t, 5)
|
||||
v.Config.Token = "foo-bar"
|
||||
|
||||
// Start the client
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
client, err := NewVaultClient(v.Config, logger, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
client.SetActive(true)
|
||||
defer client.Stop()
|
||||
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
established, err := client.ConnectionEstablished()
|
||||
if established {
|
||||
return false, fmt.Errorf("Shouldn't establish")
|
||||
}
|
||||
|
||||
return err != nil, nil
|
||||
}, func(err error) {
|
||||
t.Fatalf("Connection not established")
|
||||
})
|
||||
|
||||
// Create an allocation that requires a Vault policy
|
||||
a := mock.Alloc()
|
||||
task := a.Job.TaskGroups[0].Tasks[0]
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
_, err = client.CreateToken(context.Background(), a, task.Name)
|
||||
if err == nil || !strings.Contains(err.Error(), "Connection to Vault failed") {
|
||||
t.Fatalf("CreateToken should have failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultClient_CreateToken_Prestart(t *testing.T) {
|
||||
v := testutil.NewTestVault(t)
|
||||
defer v.Stop()
|
||||
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
client, err := NewVaultClient(v.Config, logger, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
client.SetActive(true)
|
||||
defer client.Stop()
|
||||
|
||||
// Create an allocation that requires a Vault policy
|
||||
a := mock.Alloc()
|
||||
task := a.Job.TaskGroups[0].Tasks[0]
|
||||
task.Vault = &structs.Vault{Policies: []string{"default"}}
|
||||
|
||||
_, err = client.CreateToken(context.Background(), a, task.Name)
|
||||
if err == nil {
|
||||
t.Fatalf("CreateToken should have failed: %v", err)
|
||||
}
|
||||
|
||||
if rerr, ok := err.(*structs.RecoverableError); !ok {
|
||||
t.Fatalf("Err should have been type recoverable error")
|
||||
} else if ok && !rerr.Recoverable {
|
||||
t.Fatalf("Err should have been recoverable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVaultClient_RevokeTokens_PreEstablishs(t *testing.T) {
|
||||
v := testutil.NewTestVault(t)
|
||||
logger := log.New(os.Stderr, "", log.LstdFlags)
|
||||
@@ -559,7 +628,7 @@ func TestVaultClient_RevokeTokens(t *testing.T) {
|
||||
|
||||
func waitForConnection(v *vaultClient, t *testing.T) {
|
||||
testutil.WaitForResult(func() (bool, error) {
|
||||
return v.ConnectionEstablished(), nil
|
||||
return v.ConnectionEstablished()
|
||||
}, func(err error) {
|
||||
t.Fatalf("Connection not established")
|
||||
})
|
||||
|
||||
8
vendor/vendor.json
vendored
8
vendor/vendor.json
vendored
@@ -270,14 +270,14 @@
|
||||
{
|
||||
"checksumSHA1": "tdhmIGUaoOMEDymMC23qTS7bt0g=",
|
||||
"path": "github.com/docker/docker/pkg/ioutils",
|
||||
"revision": "52debcd58ac91bf68503ce60561536911b74ff05",
|
||||
"revisionTime": "2016-05-20T15:17:10Z"
|
||||
"revision": "da39e9a4f920a15683dd0f23923c302d4db6eed5",
|
||||
"revisionTime": "2016-05-28T08:11:04Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "tdhmIGUaoOMEDymMC23qTS7bt0g=",
|
||||
"path": "github.com/docker/docker/pkg/ioutils",
|
||||
"revision": "da39e9a4f920a15683dd0f23923c302d4db6eed5",
|
||||
"revisionTime": "2016-05-28T08:11:04Z"
|
||||
"revision": "52debcd58ac91bf68503ce60561536911b74ff05",
|
||||
"revisionTime": "2016-05-20T15:17:10Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "ndnAFCfsGC3upNQ6jAEwzxcurww=",
|
||||
|
||||
Reference in New Issue
Block a user