diff --git a/client/allocrunner/taskrunner/state/state.go b/client/allocrunner/taskrunner/state/state.go index 83481c1ac..738080ba8 100644 --- a/client/allocrunner/taskrunner/state/state.go +++ b/client/allocrunner/taskrunner/state/state.go @@ -63,7 +63,13 @@ type HookState struct { // Prestart is true if the hook has run Prestart successfully and does // not need to run again PrestartDone bool - Data map[string]string + + // Data allows hooks to persist arbitrary state. + Data map[string]string + + // Environment variables set by the hook that will continue to be set + // even if PrestartDone=true. + Env map[string]string } func (h *HookState) Copy() *HookState { diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index e5fc922c0..58ae5e6c6 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -130,6 +130,7 @@ func (tr *TaskRunner) prestart() error { } name := pre.Name() + // Build the request req := interfaces.TaskPrestartRequest{ Task: tr.Task(), @@ -148,6 +149,8 @@ func (tr *TaskRunner) prestart() error { if origHookState != nil { if origHookState.PrestartDone { tr.logger.Trace("skipping done prestart hook", "name", pre.Name()) + // Always set env vars from hooks + tr.envBuilder.SetHookEnv(name, origHookState.Env) continue } @@ -175,6 +178,7 @@ func (tr *TaskRunner) prestart() error { hookState := &state.HookState{ Data: resp.HookData, PrestartDone: resp.Done, + Env: resp.Env, } // Store and persist local state if the hook state has changed @@ -190,9 +194,7 @@ func (tr *TaskRunner) prestart() error { } // Store the environment variables returned by the hook - if len(resp.Env) != 0 { - tr.envBuilder.SetGenericEnv(resp.Env) - } + tr.envBuilder.SetHookEnv(name, resp.Env) // Store the resources if len(resp.Devices) != 0 { diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index 6b7a9977d..19a42aa71 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" "github.com/hashicorp/nomad/client/config" consulapi "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/devicemanager" @@ -238,7 +239,7 @@ func TestTaskRunner_DevicePropogation(t *testing.T) { dm.ReserveF = func(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) { res := &device.ContainerReservation{ Envs: map[string]string{ - "123": "456", + "ABC": "123", }, Mounts: []*device.Mount{ { @@ -287,5 +288,64 @@ func TestTaskRunner_DevicePropogation(t *testing.T) { require.Equal(driverCfg.Devices[0].Permissions, "123") require.Len(driverCfg.Mounts, 1) require.Equal(driverCfg.Mounts[0].TaskPath, "foo") - require.Contains(driverCfg.Env, "123") + require.Contains(driverCfg.Env, "ABC") +} + +// mockEnvHook is a test hook that sets an env var and done=true. It fails if +// it's called more than once. +type mockEnvHook struct { + called int +} + +func (*mockEnvHook) Name() string { + return "mock_env_hook" +} + +func (h *mockEnvHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { + h.called++ + + resp.Done = true + resp.Env = map[string]string{ + "mock_hook": "1", + } + + return nil +} + +// TestTaskRunner_Restore_HookEnv asserts that re-running prestart hooks with +// hook environments set restores the environment without re-running done +// hooks. +func TestTaskRunner_Restore_HookEnv(t *testing.T) { + t.Parallel() + require := require.New(t) + + alloc := mock.BatchAlloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + conf, cleanup := testTaskRunnerConfig(t, alloc, task.Name) + conf.StateDB = cstate.NewMemDB() // "persist" state between prestart calls + defer cleanup() + + tr, err := NewTaskRunner(conf) + require.NoError(err) + + // Override the default hooks to only run the mock hook + mockHook := &mockEnvHook{} + tr.runnerHooks = []interfaces.TaskHook{mockHook} + + // Manually run prestart hooks + require.NoError(tr.prestart()) + + // Assert env was called + require.Equal(1, mockHook.called) + + // Re-running prestart hooks should *not* call done mock hook + require.NoError(tr.prestart()) + + // Assert env was called + require.Equal(1, mockHook.called) + + // Assert the env is still set + env := tr.envBuilder.Build().All() + require.Contains(env, "mock_hook") + require.Equal("1", env["mock_hook"]) } diff --git a/client/allocrunner/taskrunner/validate_hook_test.go b/client/allocrunner/taskrunner/validate_hook_test.go index 823f6cdcb..c4850bd04 100644 --- a/client/allocrunner/taskrunner/validate_hook_test.go +++ b/client/allocrunner/taskrunner/validate_hook_test.go @@ -52,12 +52,12 @@ func TestTaskRunner_Validate_ServiceName(t *testing.T) { require.NoError(t, validateTask(task, builder.Build(), conf)) // Add an env var that should validate - builder.SetGenericEnv(map[string]string{"FOO": "bar"}) + builder.SetHookEnv("test", map[string]string{"FOO": "bar"}) task.Services[0].Name = "${FOO}" require.NoError(t, validateTask(task, builder.Build(), conf)) // Add an env var that should *not* validate - builder.SetGenericEnv(map[string]string{"BAD": "invalid/in/consul"}) + builder.SetHookEnv("test", map[string]string{"BAD": "invalid/in/consul"}) task.Services[0].Name = "${BAD}" require.Error(t, validateTask(task, builder.Build(), conf)) } diff --git a/client/driver/env/env.go b/client/driver/env/env.go index b9acc5d01..e29c3c8b3 100644 --- a/client/driver/env/env.go +++ b/client/driver/env/env.go @@ -297,23 +297,31 @@ type Builder struct { // and affect network env vars. networks []*structs.NetworkResource + // hookEnvs are env vars set by hooks and stored by hook name to + // support adding/removing vars from multiple hooks (eg HookA adds A:1, + // HookB adds A:2, HookA removes A, A should equal 2) + hookEnvs map[string]map[string]string + + // hookNames is a slice of hooks in hookEnvs to apply hookEnvs in the + // order the hooks are run. + hookNames []string + mu *sync.RWMutex } // NewBuilder creates a new task environment builder. func NewBuilder(node *structs.Node, alloc *structs.Allocation, task *structs.Task, region string) *Builder { - b := &Builder{ - region: region, - mu: &sync.RWMutex{}, - } + b := NewEmptyBuilder() + b.region = region return b.setTask(task).setAlloc(alloc).setNode(node) } // NewEmptyBuilder creates a new environment builder. func NewEmptyBuilder() *Builder { return &Builder{ - mu: &sync.RWMutex{}, - envvars: make(map[string]string), + mu: &sync.RWMutex{}, + hookEnvs: map[string]map[string]string{}, + envvars: make(map[string]string), } } @@ -406,7 +414,14 @@ func (b *Builder) Build() *TaskEnv { envMap[k] = hargs.ReplaceEnv(v, nodeAttrs, envMap) } - // Copy template env vars third as they override task env vars + // Copy hook env vars in the order the hooks were run + for _, h := range b.hookNames { + for k, v := range b.hookEnvs[h] { + envMap[k] = hargs.ReplaceEnv(v, nodeAttrs, envMap) + } + } + + // Copy template env vars as they override task env vars for k, v := range b.templateEnv { envMap[k] = v } @@ -428,12 +443,17 @@ func (b *Builder) UpdateTask(alloc *structs.Allocation, task *structs.Task) *Bui return b.setTask(task).setAlloc(alloc) } -func (b *Builder) SetGenericEnv(envs map[string]string) *Builder { +// SetHookEnv sets environment variables from a hook. Variables are +// Last-Write-Wins, so if a hook writes a variable that's also written by a +// later hook, the later hooks value always gets used. +func (b *Builder) SetHookEnv(hook string, envs map[string]string) *Builder { b.mu.Lock() defer b.mu.Unlock() - for k, v := range envs { - b.envvars[k] = v + + if _, exists := b.hookEnvs[hook]; !exists { + b.hookNames = append(b.hookNames, hook) } + b.hookEnvs[hook] = envs return b } diff --git a/client/driver/env/env_test.go b/client/driver/env/env_test.go index 9fb27a1c1..9d035d194 100644 --- a/client/driver/env/env_test.go +++ b/client/driver/env/env_test.go @@ -14,6 +14,7 @@ import ( cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -511,6 +512,42 @@ func TestEnvironment_Envvars(t *testing.T) { } } +// TestEnvironment_HookVars asserts hook env vars are LWW and deletes of later +// writes allow earlier hook's values to be visible. +func TestEnvironment_HookVars(t *testing.T) { + n := mock.Node() + a := mock.Alloc() + builder := NewBuilder(n, a, a.Job.TaskGroups[0].Tasks[0], "global") + + // Add vars from two hooks and assert the second one wins on + // conflicting keys. + builder.SetHookEnv("hookA", map[string]string{ + "foo": "bar", + "baz": "quux", + }) + builder.SetHookEnv("hookB", map[string]string{ + "foo": "123", + "hookB": "wins", + }) + + { + out := builder.Build().All() + assert.Equal(t, "123", out["foo"]) + assert.Equal(t, "quux", out["baz"]) + assert.Equal(t, "wins", out["hookB"]) + } + + // Asserting overwriting hook vars allows the first hooks original + // value to be used. + builder.SetHookEnv("hookB", nil) + { + out := builder.Build().All() + assert.Equal(t, "bar", out["foo"]) + assert.Equal(t, "quux", out["baz"]) + assert.NotContains(t, out, "hookB") + } +} + func TestEnvironment_Interpolate(t *testing.T) { n := mock.Node() n.Attributes["arch"] = "x86"