This commit is contained in:
Alex Dadgar
2016-10-18 11:22:16 -07:00
parent e34902ae8a
commit ae288a3ee6
4 changed files with 549 additions and 32 deletions

View File

@@ -101,6 +101,19 @@ func (r *RestartTracker) GetState() (string, time.Duration) {
r.lock.Lock()
defer r.lock.Unlock()
// Clear out the existing state
defer func() {
r.startErr = nil
r.waitRes = nil
r.restartTriggered = false
}()
// Hot path if a restart was triggered
if r.restartTriggered {
r.reason = ""
return structs.TaskRestarting, 0
}
// Hot path if no attempts are expected
if r.policy.Attempts == 0 {
r.reason = ReasonNoRestartsAllowed
@@ -121,25 +134,13 @@ func (r *RestartTracker) GetState() (string, time.Duration) {
r.startTime = now
}
var state string
var dur time.Duration
if r.startErr != nil {
state, dur = r.handleStartError()
return r.handleStartError()
} else if r.waitRes != nil {
state, dur = r.handleWaitResult()
} else if r.restartTriggered {
state, dur = structs.TaskRestarting, 0
r.reason = ""
} else {
state, dur = "", 0
return r.handleWaitResult()
}
// Clear out the existing state
r.startErr = nil
r.waitRes = nil
r.restartTriggered = false
return state, dur
return "", 0
}
// handleStartError returns the new state and potential wait duration for

View File

@@ -378,6 +378,7 @@ func (r *TaskRunner) Run() {
// NewTaskRunner
if r.task.Vault != nil {
// Start the go-routine to get a Vault token
r.vaultFuture.Clear()
go r.vaultManager(r.recoveredVaultToken)
}
@@ -578,12 +579,11 @@ func (r *TaskRunner) deriveVaultToken() (string, bool) {
for {
tokens, err := r.vaultClient.DeriveToken(r.alloc, []string{r.task.Name})
if err != nil {
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v", r.task.Name, r.alloc.ID, err)
backoff := (1 << (2 * uint64(attempts))) * vaultBackoffBaseline
if backoff > vaultBackoffLimit {
backoff = vaultBackoffLimit
}
r.logger.Printf("[ERR] client: failed to derive Vault token for task %v on alloc %q: %v; retrying in %v", r.task.Name, r.alloc.ID, err, backoff)
attempts++
@@ -591,11 +591,11 @@ func (r *TaskRunner) deriveVaultToken() (string, bool) {
select {
case <-r.waitCh:
return "", false
case <-time.After(backoff * time.Second):
case <-time.After(backoff):
}
} else {
return tokens[r.task.Name], true
}
return tokens[r.task.Name], true
}
}
@@ -646,7 +646,9 @@ func (r *TaskRunner) prestart(resultCh chan bool) {
if r.task.Vault != nil {
// Wait for the token
r.logger.Printf("[DEBUG] client: waiting for Vault token for task %v in alloc %q", r.task.Name, r.alloc.ID)
tokenCh := r.vaultFuture.Wait()
r.logger.Printf("[DEBUG] client: retrieved Vault token for task %v in alloc %q", r.task.Name, r.alloc.ID)
select {
case <-tokenCh:
@@ -1174,6 +1176,7 @@ func (r *TaskRunner) UnblockStart(source string) {
}
r.logger.Printf("[DEBUG] client: unblocking task %v for alloc %q: %v", r.task.Name, r.alloc.ID, source)
r.unblocked = true
close(r.unblockCh)
}

View File

@@ -2,6 +2,7 @@ package client
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
@@ -221,24 +222,65 @@ func TestTaskRunner_Update(t *testing.T) {
}
func TestTaskRunner_SaveRestoreState(t *testing.T) {
ctestutil.ExecCompatible(t)
upd, tr := testTaskRunner(false)
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "5s",
}
// Change command to ensure we run for a bit
tr.task.Config["command"] = "/bin/sleep"
tr.task.Config["args"] = []string{"10"}
// Give it a Vault token
task.Vault = &structs.Vault{Policies: []string{"default"}}
upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
go tr.Run()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
// Snapshot state
time.Sleep(2 * time.Second)
// Wait for the task to be running and then snapshot the state
testutil.WaitForResult(func() (bool, error) {
if l := len(upd.events); l != 2 {
return false, fmt.Errorf("Expect two events; got %v", l)
}
if upd.events[0].Type != structs.TaskReceived {
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
if err := tr.SaveState(); err != nil {
t.Fatalf("err: %v", err)
}
// Read the token from the file system
secretDir, err := tr.ctx.AllocDir.GetSecretDir(task.Name)
if err != nil {
t.Fatalf("failed to determine task %s secret dir: %v", err)
}
tokenPath := filepath.Join(secretDir, vaultTokenFile)
data, err := ioutil.ReadFile(tokenPath)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
token := string(data)
if len(token) == 0 {
t.Fatalf("Token not written to disk")
}
// Create a new task runner
tr2 := NewTaskRunner(tr.logger, tr.config, upd.Update,
tr.ctx, tr.alloc, &structs.Task{Name: tr.task.Name}, tr.vaultClient)
tr2.restartTracker = noRestartsTracker()
if err := tr2.RestoreState(); err != nil {
t.Fatalf("err: %v", err)
}
@@ -246,11 +288,16 @@ func TestTaskRunner_SaveRestoreState(t *testing.T) {
defer tr2.Destroy(structs.NewTaskEvent(structs.TaskKilled))
// Destroy and wait
testutil.WaitForResult(func() (bool, error) {
return tr2.handle != nil, fmt.Errorf("RestoreState() didn't open handle")
}, func(err error) {
t.Fatalf("err: %v", err)
})
select {
case <-tr2.WaitCh():
case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
t.Fatalf("timeout")
}
// Check that we recovered the token
if act := tr2.vaultFuture.Get(); act != token {
t.Fatalf("Vault token not properly recovered")
}
}
func TestTaskRunner_Download_List(t *testing.T) {
@@ -558,3 +605,463 @@ func TestTaskRunner_SignalFailure(t *testing.T) {
t.Fatalf("Didn't receive error")
}
}
func TestTaskRunner_BlockForVault(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "1s",
}
task.Vault = &structs.Vault{Policies: []string{"default"}}
upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()
// Control when we get a Vault token
token := "1234"
waitCh := make(chan struct{})
handler := func(*structs.Allocation, []string) (map[string]string, error) {
<-waitCh
return map[string]string{task.Name: token}, nil
}
tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler
go tr.Run()
select {
case <-tr.WaitCh():
t.Fatalf("premature exit")
case <-time.After(1 * time.Second):
}
if len(upd.events) != 1 {
t.Fatalf("should have 1 updates: %#v", upd.events)
}
if upd.state != structs.TaskStatePending {
t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStatePending)
}
if upd.events[0].Type != structs.TaskReceived {
t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
// Unblock
close(waitCh)
select {
case <-tr.WaitCh():
case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
t.Fatalf("timeout")
}
if len(upd.events) != 3 {
t.Fatalf("should have 3 updates: %#v", upd.events)
}
if upd.state != structs.TaskStateDead {
t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead)
}
if upd.events[0].Type != structs.TaskReceived {
t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
if upd.events[2].Type != structs.TaskTerminated {
t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskTerminated)
}
// Check that the token is on disk
secretDir, err := tr.ctx.AllocDir.GetSecretDir(task.Name)
if err != nil {
t.Fatalf("failed to determine task %s secret dir: %v", err)
}
// Read the token from the file system
tokenPath := filepath.Join(secretDir, vaultTokenFile)
data, err := ioutil.ReadFile(tokenPath)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if act := string(data); act != token {
t.Fatalf("Token didn't get written to disk properly, got %q; want %q", act, token)
}
}
func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "1s",
}
task.Vault = &structs.Vault{Policies: []string{"default"}}
upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()
// Control when we get a Vault token
token := "1234"
count := 0
handler := func(*structs.Allocation, []string) (map[string]string, error) {
if count > 0 {
return map[string]string{task.Name: token}, nil
}
count++
return nil, fmt.Errorf("Want a retry")
}
tr.vaultClient.(*vaultclient.MockVaultClient).DeriveTokenFn = handler
go tr.Run()
select {
case <-tr.WaitCh():
case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
t.Fatalf("timeout")
}
if len(upd.events) != 3 {
t.Fatalf("should have 3 updates: %#v", upd.events)
}
if upd.state != structs.TaskStateDead {
t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead)
}
if upd.events[0].Type != structs.TaskReceived {
t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
if upd.events[2].Type != structs.TaskTerminated {
t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskTerminated)
}
// Check that the token is on disk
secretDir, err := tr.ctx.AllocDir.GetSecretDir(task.Name)
if err != nil {
t.Fatalf("failed to determine task %s secret dir: %v", err)
}
// Read the token from the file system
tokenPath := filepath.Join(secretDir, vaultTokenFile)
data, err := ioutil.ReadFile(tokenPath)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if act := string(data); act != token {
t.Fatalf("Token didn't get written to disk properly, got %q; want %q", act, token)
}
}
func TestTaskRunner_Template_Block(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "1s",
}
task.Templates = []*structs.Template{
{
EmbeddedTmpl: "{{key \"foo\"}}",
DestPath: "local/test",
ChangeMode: structs.TemplateChangeModeNoop,
},
}
upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()
go tr.Run()
select {
case <-tr.WaitCh():
t.Fatalf("premature exit")
case <-time.After(1 * time.Second):
}
if len(upd.events) != 1 {
t.Fatalf("should have 1 updates: %#v", upd.events)
}
if upd.state != structs.TaskStatePending {
t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStatePending)
}
if upd.events[0].Type != structs.TaskReceived {
t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
// Unblock
tr.UnblockStart("test")
select {
case <-tr.WaitCh():
case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
t.Fatalf("timeout")
}
if len(upd.events) != 3 {
t.Fatalf("should have 3 updates: %#v", upd.events)
}
if upd.state != structs.TaskStateDead {
t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead)
}
if upd.events[0].Type != structs.TaskReceived {
t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
if upd.events[2].Type != structs.TaskTerminated {
t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskTerminated)
}
}
func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "1s",
}
task.Templates = []*structs.Template{
{
EmbeddedTmpl: "{{key \"foo\"}}",
DestPath: "local/test",
ChangeMode: structs.TemplateChangeModeNoop,
},
}
task.Vault = &structs.Vault{Policies: []string{"default"}}
_, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()
go tr.Run()
// Wait for a Vault token
var token string
testutil.WaitForResult(func() (bool, error) {
if token = tr.vaultFuture.Get(); token == "" {
return false, fmt.Errorf("No Vault token")
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
// Error the token renewal
vc := tr.vaultClient.(*vaultclient.MockVaultClient)
renewalCh, ok := vc.RenewTokens[token]
if !ok {
t.Fatalf("no renewal channel")
}
originalManager := tr.templateManager
renewalCh <- fmt.Errorf("Test killing")
close(renewalCh)
// Wait for a new Vault token
var token2 string
testutil.WaitForResult(func() (bool, error) {
if token2 = tr.vaultFuture.Get(); token2 == "" || token2 == token {
return false, fmt.Errorf("No new Vault token")
}
if originalManager == tr.templateManager {
return false, fmt.Errorf("Template manager not updated")
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
}
func TestTaskRunner_VaultManager_Restart(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "10s",
}
task.Vault = &structs.Vault{
Policies: []string{"default"},
ChangeMode: structs.VaultChangeModeRestart,
}
upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()
go tr.Run()
// Wait for the task to start
testutil.WaitForResult(func() (bool, error) {
if l := len(upd.events); l != 2 {
return false, fmt.Errorf("Expect two events; got %v", l)
}
if upd.events[0].Type != structs.TaskReceived {
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
// Error the token renewal
vc := tr.vaultClient.(*vaultclient.MockVaultClient)
renewalCh, ok := vc.RenewTokens[tr.vaultFuture.Get()]
if !ok {
t.Fatalf("no renewal channel")
}
renewalCh <- fmt.Errorf("Test killing")
close(renewalCh)
// Ensure a restart
testutil.WaitForResult(func() (bool, error) {
if l := len(upd.events); l != 7 {
return false, fmt.Errorf("Expect seven events; got %#v", upd.events)
}
if upd.events[0].Type != structs.TaskReceived {
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
if upd.events[2].Type != structs.TaskRestartSignal {
return false, fmt.Errorf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskRestartSignal)
}
if upd.events[3].Type != structs.TaskKilling {
return false, fmt.Errorf("Fourth Event was %v; want %v", upd.events[3].Type, structs.TaskKilling)
}
if upd.events[4].Type != structs.TaskKilled {
return false, fmt.Errorf("Fifth Event was %v; want %v", upd.events[4].Type, structs.TaskKilled)
}
if upd.events[5].Type != structs.TaskRestarting {
return false, fmt.Errorf("Sixth Event was %v; want %v", upd.events[5].Type, structs.TaskRestarting)
}
if upd.events[6].Type != structs.TaskStarted {
return false, fmt.Errorf("Seventh Event was %v; want %v", upd.events[6].Type, structs.TaskStarted)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
}
func TestTaskRunner_VaultManager_Signal(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "10s",
}
task.Vault = &structs.Vault{
Policies: []string{"default"},
ChangeMode: structs.VaultChangeModeSignal,
ChangeSignal: "SIGUSR1",
}
upd, tr := testTaskRunnerFromAlloc(false, alloc)
tr.MarkReceived()
defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled))
defer tr.ctx.AllocDir.Destroy()
go tr.Run()
// Wait for the task to start
testutil.WaitForResult(func() (bool, error) {
if l := len(upd.events); l != 2 {
return false, fmt.Errorf("Expect two events; got %v", l)
}
if upd.events[0].Type != structs.TaskReceived {
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
// Error the token renewal
vc := tr.vaultClient.(*vaultclient.MockVaultClient)
renewalCh, ok := vc.RenewTokens[tr.vaultFuture.Get()]
if !ok {
t.Fatalf("no renewal channel")
}
renewalCh <- fmt.Errorf("Test killing")
close(renewalCh)
// Ensure a restart
testutil.WaitForResult(func() (bool, error) {
if l := len(upd.events); l != 3 {
return false, fmt.Errorf("Expect three events; got %#v", upd.events)
}
if upd.events[0].Type != structs.TaskReceived {
return false, fmt.Errorf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived)
}
if upd.events[1].Type != structs.TaskStarted {
return false, fmt.Errorf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted)
}
if upd.events[2].Type != structs.TaskSignaling {
return false, fmt.Errorf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskSignaling)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
}

View File

@@ -21,12 +21,18 @@ type MockVaultClient struct {
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
// token is derived
DeriveTokenErrors map[string]map[string]error
DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
}
// NewMockVaultClient returns a MockVaultClient for testing
func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
if vc.DeriveTokenFn != nil {
return vc.DeriveTokenFn(a, tasks)
}
tokens := make(map[string]string, len(tasks))
for _, task := range tasks {
if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {