Thread through whether DeriveToken error is recoverable or not

This commit is contained in:
Alex Dadgar
2016-10-22 18:08:30 -07:00
parent 0e296f4811
commit 42f7bc8e81
13 changed files with 389 additions and 105 deletions

View File

@@ -1714,6 +1714,10 @@ func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vcli
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err)
return nil, fmt.Errorf("failed to derive vault tokens: %v", err)
}
if resp.Error != nil {
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", resp.Error)
return nil, resp.Error
}
if resp.Tasks == nil {
c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response")
return nil, fmt.Errorf("failed to derive vault tokens: invalid response")

View File

@@ -629,7 +629,7 @@ func (d *DockerDriver) recoverablePullError(err error, image string) error {
if imageNotFoundMatcher.MatchString(err.Error()) {
recoverable = false
}
return dstructs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable)
return structs.NewRecoverableError(fmt.Errorf("Failed to pull `%s`: %s", image, err), recoverable)
}
func (d *DockerDriver) Periodic() (bool, time.Duration) {

View File

@@ -37,26 +37,6 @@ func (r *WaitResult) String() string {
r.ExitCode, r.Signal, r.Err)
}
// RecoverableError wraps an error and marks whether it is recoverable and could
// be retried or it is fatal.
type RecoverableError struct {
Err error
Recoverable bool
}
// NewRecoverableError is used to wrap an error and mark it as recoverable or
// not.
func NewRecoverableError(e error, recoverable bool) *RecoverableError {
return &RecoverableError{
Err: e,
Recoverable: recoverable,
}
}
func (r *RecoverableError) Error() string {
return r.Err.Error()
}
// CheckResult encapsulates the result of a check
type CheckResult struct {

View File

@@ -6,7 +6,7 @@ import (
"sync"
"time"
cstructs "github.com/hashicorp/nomad/client/driver/structs"
dstructs "github.com/hashicorp/nomad/client/driver/structs"
"github.com/hashicorp/nomad/nomad/structs"
)
@@ -34,7 +34,7 @@ func newRestartTracker(policy *structs.RestartPolicy, jobType string) *RestartTr
}
type RestartTracker struct {
waitRes *cstructs.WaitResult
waitRes *dstructs.WaitResult
startErr error
restartTriggered bool // Whether the task has been signalled to be restarted
count int // Current number of attempts.
@@ -63,7 +63,7 @@ func (r *RestartTracker) SetStartError(err error) *RestartTracker {
}
// SetWaitResult is used to mark the most recent wait result.
func (r *RestartTracker) SetWaitResult(res *cstructs.WaitResult) *RestartTracker {
func (r *RestartTracker) SetWaitResult(res *dstructs.WaitResult) *RestartTracker {
r.lock.Lock()
defer r.lock.Unlock()
r.waitRes = res
@@ -149,7 +149,7 @@ func (r *RestartTracker) GetState() (string, time.Duration) {
// infinitely try to start a task.
func (r *RestartTracker) handleStartError() (string, time.Duration) {
// If the error is not recoverable, do not restart.
if rerr, ok := r.startErr.(*cstructs.RecoverableError); !(ok && rerr.Recoverable) {
if rerr, ok := r.startErr.(*structs.RecoverableError); !(ok && rerr.Recoverable) {
r.reason = ReasonUnrecoverableErrror
return structs.TaskNotRestarting, 0
}

View File

@@ -108,7 +108,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Fail(t *testing.T) {
t.Parallel()
p := testPolicy(true, structs.RestartPolicyModeFail)
rt := newRestartTracker(p, structs.JobTypeSystem)
recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true)
recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true)
for i := 0; i < p.Attempts; i++ {
state, when := rt.SetStartError(recErr).GetState()
if state != structs.TaskRestarting {
@@ -129,7 +129,7 @@ func TestClient_RestartTracker_StartError_Recoverable_Delay(t *testing.T) {
t.Parallel()
p := testPolicy(true, structs.RestartPolicyModeDelay)
rt := newRestartTracker(p, structs.JobTypeSystem)
recErr := cstructs.NewRecoverableError(fmt.Errorf("foo"), true)
recErr := structs.NewRecoverableError(fmt.Errorf("foo"), true)
for i := 0; i < p.Attempts; i++ {
state, when := rt.SetStartError(recErr).GetState()
if state != structs.TaskRestarting {

View File

@@ -509,10 +509,10 @@ OUTER:
// restoring the TaskRunner
if token == "" {
// Get a token
var ok bool
token, ok = r.deriveVaultToken()
if !ok {
// We are shutting down
var exit bool
token, exit = r.deriveVaultToken()
if exit {
// Exit the manager
return
}
@@ -589,12 +589,20 @@ OUTER:
// deriveVaultToken derives the Vault token using exponential backoffs. It
// returns the Vault token and whether the token is valid. If it is not valid we
// are shutting down
func (r *TaskRunner) deriveVaultToken() (string, bool) {
func (r *TaskRunner) deriveVaultToken() (token string, exit bool) {
attempts := 0
for {
tokens, err := r.vaultClient.DeriveToken(r.alloc, []string{r.task.Name})
if err == nil {
return tokens[r.task.Name], true
return tokens[r.task.Name], false
}
// Check if we can't recover from the error
if rerr, ok := err.(*structs.RecoverableError); !ok || !rerr.Recoverable {
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v",
r.task.Name, r.alloc.ID, err)
r.Kill("vault", fmt.Sprintf("failed to derive token: %v", err))
return "", true
}
// Handle the retry case
@@ -602,14 +610,15 @@ func (r *TaskRunner) deriveVaultToken() (string, bool) {
if backoff > vaultBackoffLimit {
backoff = vaultBackoffLimit
}
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v", r.task.Name, r.alloc.ID, err, backoff)
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v",
r.task.Name, r.alloc.ID, err, backoff)
attempts++
// Wait till retrying
select {
case <-r.waitCh:
return "", false
return "", true
case <-time.After(backoff):
}
}
@@ -706,7 +715,7 @@ func (r *TaskRunner) prestart(resultCh chan bool) {
if err := getter.GetArtifact(r.getTaskEnv(), artifact, r.taskDir); err != nil {
r.setState(structs.TaskStatePending,
structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(err))
r.restartTracker.SetStartError(dstructs.NewRecoverableError(err, true))
r.restartTracker.SetStartError(structs.NewRecoverableError(err, true))
goto RESTART
}
}

View File

@@ -721,7 +721,7 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
}
count++
return nil, fmt.Errorf("Want a retry")
return nil, structs.NewRecoverableError(fmt.Errorf("Want a retry"), true)
}
tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler
go tr.Run()
@@ -770,6 +770,49 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
}
}
func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "10s",
}
task.Vault = &structs.Vault{
Policies: []string{"default"},
ChangeMode: structs.VaultChangeModeRestart,
}
upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()
// Error the token derivation
vc := tr.vaultClient.(*vaultclient.MockVaultClient)
vc.SetDeriveTokenError(alloc.ID, []string{task.Name}, fmt.Errorf("Non recoverable"))
go tr.Run()
// Wait for the task to start
testutil.WaitForResult(func() (bool, error) {
if l := len(upd.events); l != 2 {
return false, fmt.Errorf("Expect two events; got %v", l)
}
if upd.events[0].Type != structs.TaskReceived {
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskKilling {
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskKilling)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
}
func TestTaskRunner_Template_Block(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]

View File

@@ -15,6 +15,7 @@ import (
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/watch"
"github.com/hashicorp/raft"
vapi "github.com/hashicorp/vault/api"
)
@@ -940,22 +941,26 @@ func (b *batchFuture) Respond(index uint64, err error) {
func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
reply *structs.DeriveVaultTokenResponse) error {
if done, err := n.srv.forward("Node.DeriveVaultToken", args, args, reply); done {
return err
reply.Error = structs.NewRecoverableError(err, err == structs.ErrNoLeader)
return nil
}
defer metrics.MeasureSince([]string{"nomad", "client", "derive_vault_token"}, time.Now())
// Verify the arguments
if args.NodeID == "" {
return fmt.Errorf("missing node ID")
reply.Error = structs.NewRecoverableError(fmt.Errorf("missing node ID"), false)
}
if args.SecretID == "" {
return fmt.Errorf("missing node SecretID")
reply.Error = structs.NewRecoverableError(fmt.Errorf("missing node SecretID"), false)
return nil
}
if args.AllocID == "" {
return fmt.Errorf("missing allocation ID")
reply.Error = structs.NewRecoverableError(fmt.Errorf("missing allocation ID"), false)
return nil
}
if len(args.Tasks) == 0 {
return fmt.Errorf("no tasks specified")
reply.Error = structs.NewRecoverableError(fmt.Errorf("no tasks specified"), false)
return nil
}
// Verify the following:
@@ -965,41 +970,51 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
// tokens
snap, err := n.srv.fsm.State().Snapshot()
if err != nil {
return err
reply.Error = structs.NewRecoverableError(err, false)
return nil
}
node, err := snap.NodeByID(args.NodeID)
if err != nil {
return err
reply.Error = structs.NewRecoverableError(err, false)
return nil
}
if node == nil {
return fmt.Errorf("Node %q does not exist", args.NodeID)
reply.Error = structs.NewRecoverableError(fmt.Errorf("Node %q does not exist", args.NodeID), false)
return nil
}
if node.SecretID != args.SecretID {
return fmt.Errorf("SecretID mismatch")
reply.Error = structs.NewRecoverableError(fmt.Errorf("SecretID mismatch"), false)
return nil
}
alloc, err := snap.AllocByID(args.AllocID)
if err != nil {
return err
reply.Error = structs.NewRecoverableError(err, false)
return nil
}
if alloc == nil {
return fmt.Errorf("Allocation %q does not exist", args.AllocID)
reply.Error = structs.NewRecoverableError(fmt.Errorf("Allocation %q does not exist", args.AllocID), false)
return nil
}
if alloc.NodeID != args.NodeID {
return fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID)
reply.Error = structs.NewRecoverableError(fmt.Errorf("Allocation %q not running on Node %q", args.AllocID, args.NodeID), false)
return nil
}
if alloc.TerminalStatus() {
return fmt.Errorf("Can't request Vault token for terminal allocation")
reply.Error = structs.NewRecoverableError(fmt.Errorf("Can't request Vault token for terminal allocation"), false)
return nil
}
// Check the policies
policies := alloc.Job.VaultPolicies()
if policies == nil {
return fmt.Errorf("Job doesn't require Vault policies")
reply.Error = structs.NewRecoverableError(fmt.Errorf("Job doesn't require Vault policies"), false)
return nil
}
tg, ok := policies[alloc.TaskGroup]
if !ok {
return fmt.Errorf("Task group does not require Vault policies")
reply.Error = structs.NewRecoverableError(fmt.Errorf("Task group does not require Vault policies"), false)
return nil
}
var unneeded []string
@@ -1011,8 +1026,10 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
}
if len(unneeded) != 0 {
return fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s",
e := fmt.Errorf("Requested Vault tokens for tasks without defined Vault policies: %s",
strings.Join(unneeded, ", "))
reply.Error = structs.NewRecoverableError(e, false)
return nil
}
// At this point the request is valid and we should contact Vault for
@@ -1043,7 +1060,13 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
secret, err := n.srv.vault.CreateToken(ctx, alloc, task)
if err != nil {
return fmt.Errorf("failed to create token for task %q: %v", task, err)
wrapped := fmt.Errorf("failed to create token for task %q: %v", task, err)
if rerr, ok := err.(*structs.RecoverableError); ok && rerr.Recoverable {
// If the error is recoverable, propogate it
return structs.NewRecoverableError(wrapped, true)
}
return wrapped
}
results[task] = secret
@@ -1068,9 +1091,9 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
}()
// Wait for everything to complete or for an error
err = g.Wait()
createErr := g.Wait()
// Commit to Raft before returning any of the tokens
// Retrieve the results
accessors := make([]*structs.VaultAccessor, 0, len(results))
tokens := make(map[string]string, len(results))
for task, secret := range results {
@@ -1092,20 +1115,36 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
}
// If there was an error revoke the created tokens
if err != nil {
var mErr multierror.Error
mErr.Errors = append(mErr.Errors, err)
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, false); err != nil {
mErr.Errors = append(mErr.Errors, err)
if createErr != nil {
if revokeErr := n.srv.vault.RevokeTokens(context.Background(), accessors, false); revokeErr != nil {
n.srv.logger.Printf("[ERR] nomad.node: Vault token revocation failed: %v", revokeErr)
}
return mErr.ErrorOrNil()
if rerr, ok := createErr.(*structs.RecoverableError); ok {
reply.Error = rerr
} else {
reply.Error = structs.NewRecoverableError(createErr, false)
}
return nil
}
// Commit to Raft before returning any of the tokens
req := structs.VaultAccessorsRequest{Accessors: accessors}
_, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err)
return err
// Determine if we can recover from the error
retry := false
switch err {
case raft.ErrNotLeader, raft.ErrLeadershipLost, raft.ErrRaftShutdown, raft.ErrEnqueueTimeout:
retry = true
default:
}
reply.Error = structs.NewRecoverableError(err, retry)
return nil
}
reply.Index = index

View File

@@ -1822,18 +1822,23 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
}
var resp structs.DeriveVaultTokenResponse
err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "SecretID mismatch") {
t.Fatalf("Expected SecretID mismatch: %v", err)
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
t.Fatalf("bad: %v", err)
}
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "SecretID mismatch") {
t.Fatalf("Expected SecretID mismatch: %v", resp.Error)
}
// Put the correct SecretID
req.SecretID = node.SecretID
// Now we should get an error about the allocation not running on the node
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "not running on Node") {
t.Fatalf("Expected not running on node error: %v", err)
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
t.Fatalf("bad: %v", err)
}
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "not running on Node") {
t.Fatalf("Expected not running on node error: %v", resp.Error)
}
// Update to be running on the node
@@ -1843,9 +1848,11 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
}
// Now we should get an error about the job not needing any Vault secrets
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "does not require") {
t.Fatalf("Expected no policies error: %v", err)
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
t.Fatalf("bad: %v", err)
}
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "does not require") {
t.Fatalf("Expected no policies error: %v", resp.Error)
}
// Update to be terminal
@@ -1855,9 +1862,11 @@ func TestClientEndpoint_DeriveVaultToken_Bad(t *testing.T) {
}
// Now we should get an error about the job not needing any Vault secrets
err = msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err == nil || !strings.Contains(err.Error(), "terminal") {
t.Fatalf("Expected terminal allocation error: %v", err)
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
t.Fatalf("bad: %v", err)
}
if resp.Error == nil || !strings.Contains(resp.Error.Error(), "terminal") {
t.Fatalf("Expected terminal allocation error: %v", resp.Error)
}
}
@@ -1920,6 +1929,9 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) {
if err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp); err != nil {
t.Fatalf("bad: %v", err)
}
if resp.Error != nil {
t.Fatalf("bad: %v", resp.Error)
}
// Check the state store and ensure that we created a VaultAccessor
va, err := state.VaultAccessor(accessor)
@@ -1947,3 +1959,59 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) {
t.Fatalf("Got %#v; want %#v", va, expected)
}
}
func TestClientEndpoint_DeriveVaultToken_VaultError(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
state := s1.fsm.State()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Enable vault and allow authenticated
tr := true
s1.config.VaultConfig.Enabled = &tr
s1.config.VaultConfig.AllowUnauthenticated = &tr
// Replace the Vault Client on the server
tvc := &TestVaultClient{}
s1.vault = tvc
// Create the node
node := mock.Node()
if err := state.UpsertNode(2, node); err != nil {
t.Fatalf("err: %v", err)
}
// Create an alloc an allocation that has vault policies required
alloc := mock.Alloc()
alloc.NodeID = node.ID
task := alloc.Job.TaskGroups[0].Tasks[0]
tasks := []string{task.Name}
task.Vault = &structs.Vault{Policies: []string{"a", "b"}}
if err := state.UpsertAllocs(3, []*structs.Allocation{alloc}); err != nil {
t.Fatalf("err: %v", err)
}
// Return an error when creating the token
tvc.SetCreateTokenError(alloc.ID, task.Name,
structs.NewRecoverableError(fmt.Errorf("recover"), true))
req := &structs.DeriveVaultTokenRequest{
NodeID: node.ID,
SecretID: node.SecretID,
AllocID: alloc.ID,
Tasks: tasks,
QueryOptions: structs.QueryOptions{
Region: "global",
},
}
var resp structs.DeriveVaultTokenResponse
err := msgpackrpc.CallWithCodec(codec, "Node.DeriveVaultToken", req, &resp)
if err != nil {
t.Fatalf("bad: %v", err)
}
if resp.Error == nil || !resp.Error.Recoverable {
t.Fatalf("bad: %+v", resp.Error)
}
}

View File

@@ -389,6 +389,11 @@ type VaultAccessor struct {
type DeriveVaultTokenResponse struct {
// Tasks is a mapping between the task name and the wrapped token
Tasks map[string]string
// Error stores any error that occured. Errors are stored here so we can
// communicate whether it is retriable
Error *RecoverableError
QueryMeta
}
@@ -3688,3 +3693,27 @@ type KeyringResponse struct {
type KeyringRequest struct {
Key string
}
// RecoverableError wraps an error and marks whether it is recoverable and could
// be retried or it is fatal.
type RecoverableError struct {
Err string
Recoverable bool
}
// NewRecoverableError is used to wrap an error and mark it as recoverable or
// not.
func NewRecoverableError(e error, recoverable bool) *RecoverableError {
if e == nil {
return nil
}
return &RecoverableError{
Err: e.Error(),
Recoverable: recoverable,
}
}
func (r *RecoverableError) Error() string {
return r.Err
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"log"
"math/rand"
"strings"
"sync"
"sync/atomic"
"time"
@@ -45,6 +46,14 @@ const (
// vaultRevocationIntv is the interval at which Vault tokens that failed
// initial revocation are retried
vaultRevocationIntv = 5 * time.Minute
// Errors returned by Vault
// vaultErrInvalidRequest is returned if the request is invalid
vaultErrInvalidRequest = "invalid request"
// vaultErrPermissionDenied is returned if the client is not authorized
vaultErrPermissionDenied = "permission denied"
)
// VaultClient is the Servers interface for interfacing with Vault
@@ -104,8 +113,11 @@ type vaultClient struct {
config *config.VaultConfig
// connEstablished marks whether we have an established connection to Vault.
// It should be accessed using a helper and updated atomically
connEstablished int32
connEstablished bool
// connEstablishedErr marks an error that can occur when establishing a
// connection
connEstablishedErr error
// token is the raw token used by the client
token string
@@ -202,7 +214,7 @@ func (v *vaultClient) flush() {
v.client = nil
v.auth = nil
v.connEstablished = 0
v.connEstablished = false
v.token = ""
v.tokenData = nil
v.revoking = make(map[*structs.VaultAccessor]time.Time)
@@ -225,7 +237,7 @@ func (v *vaultClient) SetConfig(config *config.VaultConfig) error {
if v.config.IsEnabled() {
// Stop accepting any new request
atomic.StoreInt32(&v.connEstablished, 0)
v.connEstablished = false
// Kill any background routine and create a new tomb
v.tomb.Kill(nil)
@@ -310,8 +322,8 @@ OUTER:
case <-retryTimer.C:
// Ensure the API is reachable
if _, err := v.client.Sys().InitStatus(); err != nil {
v.logger.Printf("[WARN] vault: failed to contact Vault API. Retrying in %v",
v.config.ConnectionRetryIntv)
v.logger.Printf("[WARN] vault: failed to contact Vault API. Retrying in %v: %v",
v.config.ConnectionRetryIntv, err)
retryTimer.Reset(v.config.ConnectionRetryIntv)
continue OUTER
}
@@ -323,6 +335,10 @@ OUTER:
// Retrieve our token, validate it and parse the lease duration
if err := v.parseSelfToken(); err != nil {
v.logger.Printf("[ERR] vault: failed to lookup self token and not retrying: %v", err)
v.l.Lock()
v.connEstablished = false
v.connEstablishedErr = err
v.l.Unlock()
return
}
@@ -339,7 +355,9 @@ OUTER:
v.tomb.Go(wrapNilError(v.renewalLoop))
}
atomic.StoreInt32(&v.connEstablished, 1)
v.l.Lock()
v.connEstablished = true
v.l.Unlock()
}
// renewalLoop runs the renew loop. This should only be called if we are given a
@@ -407,7 +425,10 @@ func (v *vaultClient) renewalLoop() {
// We have failed to renew the token past its expiration. Stop
// renewing with Vault.
v.logger.Printf("[ERR] vault: failed to renew Vault token before lease expiration. Shutting down Vault client")
atomic.StoreInt32(&v.connEstablished, 0)
v.l.Lock()
v.connEstablished = false
v.connEstablishedErr = err
v.l.Unlock()
return
} else if backoff > maxBackoff.Seconds() {
@@ -521,36 +542,42 @@ func (v *vaultClient) parseSelfToken() error {
}
// ConnectionEstablished returns whether a connection to Vault has been
// established.
func (v *vaultClient) ConnectionEstablished() bool {
return atomic.LoadInt32(&v.connEstablished) == 1
// established and any error that potentially caused it to be false
func (v *vaultClient) ConnectionEstablished() (bool, error) {
v.l.Lock()
defer v.l.Unlock()
return v.connEstablished, v.connEstablishedErr
}
// Enabled returns whether the client is active
func (v *vaultClient) Enabled() bool {
v.l.Lock()
defer v.l.Unlock()
return v.config.IsEnabled()
}
//
// Active returns whether the client is active
func (v *vaultClient) Active() bool {
return atomic.LoadInt32(&v.active) == 1
}
// CreateToken takes the allocation and task and returns an appropriate Vault
// token. The call is rate limited and may be canceled with the passed policy
// token. The call is rate limited and may be canceled with the passed policy.
// When the error is recoverable, it will be of type RecoverableError
func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, task string) (*vapi.Secret, error) {
if !v.Enabled() {
return nil, fmt.Errorf("Vault integration disabled")
}
if !v.Active() {
return nil, fmt.Errorf("Vault client not active")
return nil, structs.NewRecoverableError(fmt.Errorf("Vault client not active"), true)
}
// Check if we have established a connection with Vault
if !v.ConnectionEstablished() {
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
if established, err := v.ConnectionEstablished(); !established && err == nil {
return nil, structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true)
} else if !established {
return nil, fmt.Errorf("Connection to Vault failed: %v", err)
}
// Retrieve the Vault block for the task
@@ -596,7 +623,19 @@ func (v *vaultClient) CreateToken(ctx context.Context, a *structs.Allocation, ta
secret, err = v.auth.CreateWithRole(req, v.tokenData.Role)
}
return secret, err
// Determine whether it is unrecoverable
if err != nil {
eStr := err.Error()
if strings.Contains(eStr, vaultErrInvalidRequest) ||
strings.Contains(eStr, vaultErrPermissionDenied) {
return secret, err
}
// The error is recoverable
return nil, structs.NewRecoverableError(err, true)
}
return secret, nil
}
// LookupToken takes a Vault token and does a lookup against Vault. The call is
@@ -611,8 +650,10 @@ func (v *vaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secr
}
// Check if we have established a connection with Vault
if !v.ConnectionEstablished() {
return nil, fmt.Errorf("Connection to Vault has not been established. Retry")
if established, err := v.ConnectionEstablished(); !established && err == nil {
return nil, structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true)
} else if !established {
return nil, fmt.Errorf("Connection to Vault failed: %v", err)
}
// Ensure we are under our rate limit
@@ -652,7 +693,7 @@ func (v *vaultClient) RevokeTokens(ctx context.Context, accessors []*structs.Vau
// Check if we have established a connection with Vault. If not just add it
// to the queue
if !v.ConnectionEstablished() {
if established, err := v.ConnectionEstablished(); !established && err == nil {
// Only bother tracking it for later revocation if the accessor was
// committed
if committed {
@@ -709,8 +750,10 @@ func (v *vaultClient) parallelRevoke(ctx context.Context, accessors []*structs.V
}
// Check if we have established a connection with Vault
if !v.ConnectionEstablished() {
return fmt.Errorf("Connection to Vault has not been established. Retry")
if established, err := v.ConnectionEstablished(); !established && err == nil {
return structs.NewRecoverableError(fmt.Errorf("Connection to Vault has not been established"), true)
} else if !established {
return fmt.Errorf("Connection to Vault failed: %v", err)
}
g, pCtx := errgroup.WithContext(ctx)
@@ -770,7 +813,7 @@ func (v *vaultClient) revokeDaemon() {
case <-v.tomb.Dying():
return
case now := <-ticker.C:
if !v.ConnectionEstablished() {
if established, _ := v.ConnectionEstablished(); !established {
continue
}

View File

@@ -3,6 +3,7 @@ package nomad
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"reflect"
@@ -67,7 +68,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) {
// Sleep a little while and check that no connection has been established.
time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond)
if client.ConnectionEstablished() {
if established, _ := client.ConnectionEstablished(); established {
t.Fatalf("ConnectionEstablished() returned true before Vault server started")
}
@@ -417,7 +418,7 @@ func TestVaultClient_CreateToken_Role(t *testing.T) {
// Set the configs token in a new test role
v.Config.Token = testVaultRoleAndToken(v, t, 5)
//testVaultRoleAndToken(v, t, 5)
// Start the client
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger, nil)
@@ -458,6 +459,74 @@ func TestVaultClient_CreateToken_Role(t *testing.T) {
}
}
func TestVaultClient_CreateToken_Role_InvalidToken(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
// Set the configs token in a new test role
testVaultRoleAndToken(v, t, 5)
v.Config.Token = "foo-bar"
// Start the client
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
testutil.WaitForResult(func() (bool, error) {
established, err := client.ConnectionEstablished()
if established {
return false, fmt.Errorf("Shouldn't establish")
}
return err != nil, nil
}, func(err error) {
t.Fatalf("Connection not established")
})
// Create an allocation that requires a Vault policy
a := mock.Alloc()
task := a.Job.TaskGroups[0].Tasks[0]
task.Vault = &structs.Vault{Policies: []string{"default"}}
_, err = client.CreateToken(context.Background(), a, task.Name)
if err == nil || !strings.Contains(err.Error(), "Connection to Vault failed") {
t.Fatalf("CreateToken should have failed: %v", err)
}
}
func TestVaultClient_CreateToken_Prestart(t *testing.T) {
v := testutil.NewTestVault(t)
defer v.Stop()
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
// Create an allocation that requires a Vault policy
a := mock.Alloc()
task := a.Job.TaskGroups[0].Tasks[0]
task.Vault = &structs.Vault{Policies: []string{"default"}}
_, err = client.CreateToken(context.Background(), a, task.Name)
if err == nil {
t.Fatalf("CreateToken should have failed: %v", err)
}
if rerr, ok := err.(*structs.RecoverableError); !ok {
t.Fatalf("Err should have been type recoverable error")
} else if ok && !rerr.Recoverable {
t.Fatalf("Err should have been recoverable")
}
}
func TestVaultClient_RevokeTokens_PreEstablishs(t *testing.T) {
v := testutil.NewTestVault(t)
logger := log.New(os.Stderr, "", log.LstdFlags)
@@ -559,7 +628,7 @@ func TestVaultClient_RevokeTokens(t *testing.T) {
func waitForConnection(v *vaultClient, t *testing.T) {
testutil.WaitForResult(func() (bool, error) {
return v.ConnectionEstablished(), nil
return v.ConnectionEstablished()
}, func(err error) {
t.Fatalf("Connection not established")
})

8
vendor/vendor.json vendored
View File

@@ -270,14 +270,14 @@
{
"checksumSHA1": "tdhmIGUaoOMEDymMC23qTS7bt0g=",
"path": "github.com/docker/docker/pkg/ioutils",
"revision": "52debcd58ac91bf68503ce60561536911b74ff05",
"revisionTime": "2016-05-20T15:17:10Z"
"revision": "da39e9a4f920a15683dd0f23923c302d4db6eed5",
"revisionTime": "2016-05-28T08:11:04Z"
},
{
"checksumSHA1": "tdhmIGUaoOMEDymMC23qTS7bt0g=",
"path": "github.com/docker/docker/pkg/ioutils",
"revision": "da39e9a4f920a15683dd0f23923c302d4db6eed5",
"revisionTime": "2016-05-28T08:11:04Z"
"revision": "52debcd58ac91bf68503ce60561536911b74ff05",
"revisionTime": "2016-05-20T15:17:10Z"
},
{
"checksumSHA1": "ndnAFCfsGC3upNQ6jAEwzxcurww=",