diff --git a/client/task_runner.go b/client/task_runner.go index 4c068c570..78c609d61 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -13,8 +13,10 @@ import ( "time" "github.com/armon/go-metrics" + "github.com/golang/snappy" "github.com/hashicorp/consul-template/signals" "github.com/hashicorp/go-multierror" + "github.com/hashicorp/nomad/client/allocdir" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/driver" "github.com/hashicorp/nomad/client/getter" @@ -85,6 +87,9 @@ type TaskRunner struct { // downloaded artifactsDownloaded bool + // payloadRendered tracks whether the payload has been rendered to disk + payloadRendered bool + // vaultFuture is the means to wait for and get a Vault token vaultFuture *tokenFuture @@ -129,6 +134,7 @@ type taskRunnerState struct { Task *structs.Task HandleID string ArtifactDownloaded bool + PayloadRendered bool } // TaskStateUpdater is used to signal that tasks state has changed. @@ -231,6 +237,7 @@ func (r *TaskRunner) RestoreState() error { r.task = snap.Task } r.artifactsDownloaded = snap.ArtifactDownloaded + r.payloadRendered = snap.PayloadRendered if err := r.setTaskEnv(); err != nil { return fmt.Errorf("client: failed to create task environment for task %q in allocation %q: %v", @@ -293,6 +300,7 @@ func (r *TaskRunner) SaveState() error { Task: r.task, Version: r.config.Version, ArtifactDownloaded: r.artifactsDownloaded, + PayloadRendered: r.payloadRendered, } r.handleLock.Lock() if r.handle != nil { @@ -704,6 +712,31 @@ func (r *TaskRunner) prestart(resultCh chan bool) { return } + // If the job is a dispatch job and there is a payload write it to disk + requirePayload := len(r.alloc.Job.Payload) != 0 && + (r.task.DispatchInput != nil && r.task.DispatchInput.File != "") + if !r.payloadRendered && requirePayload { + renderTo := filepath.Join(r.taskDir, allocdir.TaskLocal, r.task.DispatchInput.File) + decoded, err := snappy.Decode(nil, r.alloc.Job.Payload) + if err != nil { + r.setState( + structs.TaskStateDead, + structs.NewTaskEvent(structs.TaskSetupFailure).SetSetupError(err).SetFailsTask()) + resultCh <- false + return + } + + if err := ioutil.WriteFile(renderTo, decoded, 0777); err != nil { + r.setState( + structs.TaskStateDead, + structs.NewTaskEvent(structs.TaskSetupFailure).SetSetupError(err).SetFailsTask()) + resultCh <- false + return + } + + r.payloadRendered = true + } + for { // Download the task's artifacts if !r.artifactsDownloaded && len(r.task.Artifacts) > 0 { diff --git a/client/task_runner_test.go b/client/task_runner_test.go index f34c9460e..abed11996 100644 --- a/client/task_runner_test.go +++ b/client/task_runner_test.go @@ -8,10 +8,12 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "syscall" "testing" "time" + "github.com/golang/snappy" "github.com/hashicorp/nomad/client/allocdir" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/driver" @@ -1244,3 +1246,66 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) { t.Fatalf("err: %v", err) }) } + +// Test that the payload is written to disk +func TestTaskRunner_SimpleRun_Dispatch(t *testing.T) { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "exit_code": "0", + "run_for": "1s", + } + fileName := "test" + task.DispatchInput = &structs.DispatchInputConfig{ + File: fileName, + } + alloc.Job.Constructor = &structs.ConstructorConfig{} + + // Add an encrypted payload + expected := []byte("hello world") + compressed := snappy.Encode(nil, expected) + alloc.Job.Payload = compressed + + upd, tr := testTaskRunnerFromAlloc(false, alloc) + tr.MarkReceived() + defer tr.Destroy(structs.NewTaskEvent(structs.TaskKilled)) + defer tr.ctx.AllocDir.Destroy() + go tr.Run() + + select { + case <-tr.WaitCh(): + case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second): + t.Fatalf("timeout") + } + + if len(upd.events) != 3 { + t.Fatalf("should have 3 updates: %#v", upd.events) + } + + if upd.state != structs.TaskStateDead { + t.Fatalf("TaskState %v; want %v", upd.state, structs.TaskStateDead) + } + + if upd.events[0].Type != structs.TaskReceived { + t.Fatalf("First Event was %v; want %v", upd.events[0].Type, structs.TaskReceived) + } + + if upd.events[1].Type != structs.TaskStarted { + t.Fatalf("Second Event was %v; want %v", upd.events[1].Type, structs.TaskStarted) + } + + if upd.events[2].Type != structs.TaskTerminated { + t.Fatalf("Third Event was %v; want %v", upd.events[2].Type, structs.TaskTerminated) + } + + // Check that the file was written to disk properly + payloadPath := filepath.Join(tr.taskDir, allocdir.TaskLocal, fileName) + data, err := ioutil.ReadFile(payloadPath) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + if !reflect.DeepEqual(data, expected) { + t.Fatalf("Bad; got %v; want %v", string(data), string(expected)) + } +}