diff --git a/client/allocrunner/taskrunner/service_hook.go b/client/allocrunner/taskrunner/service_hook.go index 02b8d75b7..eaf5cdb89 100644 --- a/client/allocrunner/taskrunner/service_hook.go +++ b/client/allocrunner/taskrunner/service_hook.go @@ -48,6 +48,10 @@ type serviceHook struct { networks structs.Networks taskEnv *taskenv.TaskEnv + // initialRegistration tracks whether Poststart has completed, which + // initializes the fields required by other lifecycle funcs + initialRegistration bool + // Since Update() may be called concurrently with any other hook all // hook methods must be fully serialized mu sync.Mutex @@ -87,6 +91,7 @@ func (h *serviceHook) Poststart(ctx context.Context, req *interfaces.TaskPoststa h.driverExec = req.DriverExec h.driverNet = req.DriverNetwork h.taskEnv = req.TaskEnv + h.initialRegistration = true // Create task services struct with request's driver metadata workloadServices := h.getWorkloadServices() @@ -97,11 +102,27 @@ func (h *serviceHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequest, _ *interfaces.TaskUpdateResponse) error { h.mu.Lock() defer h.mu.Unlock() + if !h.initialRegistration { + // Initial registration has not finished, so skip Consul operations + // and only update the hook's stored fields + return h.updateHookFields(req) + } // Create old task services struct with request's driver metadata as it // can't change due to Updates oldWorkloadServices := h.getWorkloadServices() + if err := h.updateHookFields(req); err != nil { + return err + } + + // Create new task services struct with those new values + newWorkloadServices := h.getWorkloadServices() + + return h.consul.UpdateWorkload(oldWorkloadServices, newWorkloadServices) +} + +func (h *serviceHook) updateHookFields(req *interfaces.TaskUpdateRequest) error { // Store new updated values out of request canary := false if req.Alloc.DeploymentStatus != nil { @@ -125,10 +146,7 @@ func (h *serviceHook) Update(ctx 
context.Context, req *interfaces.TaskUpdateRequ h.networks = networks h.canary = canary - // Create new task services struct with those new values - newWorkloadServices := h.getWorkloadServices() - - return h.consul.UpdateWorkload(oldWorkloadServices, newWorkloadServices) + return nil } func (h *serviceHook) PreKilling(ctx context.Context, req *interfaces.TaskPreKillRequest, resp *interfaces.TaskPreKillResponse) error { @@ -167,7 +185,7 @@ func (h *serviceHook) deregister() { // destroyed, so remove both variations of the service workloadServices.Canary = !workloadServices.Canary h.consul.RemoveWorkload(workloadServices) - + h.initialRegistration = false } func (h *serviceHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error { diff --git a/client/allocrunner/taskrunner/service_hook_test.go b/client/allocrunner/taskrunner/service_hook_test.go index 4c246cb91..c9c753f7b 100644 --- a/client/allocrunner/taskrunner/service_hook_test.go +++ b/client/allocrunner/taskrunner/service_hook_test.go @@ -1,7 +1,14 @@ package taskrunner import ( + "context" + "testing" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/stretchr/testify/require" ) // Statically assert the stats hook implements the expected interfaces @@ -9,3 +16,39 @@ var _ interfaces.TaskPoststartHook = (*serviceHook)(nil) var _ interfaces.TaskExitedHook = (*serviceHook)(nil) var _ interfaces.TaskPreKillHook = (*serviceHook)(nil) var _ interfaces.TaskUpdateHook = (*serviceHook)(nil) + +func TestUpdate_beforePoststart(t *testing.T) { + alloc := mock.Alloc() + logger := testlog.HCLogger(t) + c := consul.NewMockConsulServiceClient(t, logger) + + hook := newServiceHook(serviceHookConfig{ + alloc: alloc, + task: alloc.LookupTask("web"), + consul: c, + logger: logger, + }) + require.NoError(t, 
hook.Update(context.Background(), &interfaces.TaskUpdateRequest{Alloc: alloc}, &interfaces.TaskUpdateResponse{})) + require.Len(t, c.GetOps(), 0) + require.NoError(t, hook.Poststart(context.Background(), &interfaces.TaskPoststartRequest{}, &interfaces.TaskPoststartResponse{})) + require.Len(t, c.GetOps(), 1) + require.NoError(t, hook.Update(context.Background(), &interfaces.TaskUpdateRequest{Alloc: alloc}, &interfaces.TaskUpdateResponse{})) + require.Len(t, c.GetOps(), 2) + + // When a task exits it could be restarted with new driver info + // so Update should again wait on Poststart. + + require.NoError(t, hook.Exited(context.Background(), &interfaces.TaskExitedRequest{}, &interfaces.TaskExitedResponse{})) + require.Len(t, c.GetOps(), 4) + require.NoError(t, hook.Update(context.Background(), &interfaces.TaskUpdateRequest{Alloc: alloc}, &interfaces.TaskUpdateResponse{})) + require.Len(t, c.GetOps(), 4) + require.NoError(t, hook.Poststart(context.Background(), &interfaces.TaskPoststartRequest{}, &interfaces.TaskPoststartResponse{})) + require.Len(t, c.GetOps(), 5) + require.NoError(t, hook.Update(context.Background(), &interfaces.TaskUpdateRequest{Alloc: alloc}, &interfaces.TaskUpdateResponse{})) + require.Len(t, c.GetOps(), 6) + require.NoError(t, hook.PreKilling(context.Background(), &interfaces.TaskPreKillRequest{}, &interfaces.TaskPreKillResponse{})) + require.Len(t, c.GetOps(), 8) + require.NoError(t, hook.Update(context.Background(), &interfaces.TaskUpdateRequest{Alloc: alloc}, &interfaces.TaskUpdateResponse{})) + require.Len(t, c.GetOps(), 8) + +}