tr/service_hook: prevent Update from running before Poststart has finished

This commit is contained in:
Nick Ethier
2020-04-02 12:17:36 -04:00
parent 4b070e8e18
commit 88438e8982
2 changed files with 35 additions and 0 deletions

View File

@@ -48,6 +48,10 @@ type serviceHook struct {
networks structs.Networks
taskEnv *taskenv.TaskEnv
// initialRegistrations tracks if Poststart has completed, initializing
// fields required in other lifecycle funcs
initialRegistration bool
// Since Update() may be called concurrently with any other hook all
// hook methods must be fully serialized
mu sync.Mutex
@@ -91,12 +95,17 @@ func (h *serviceHook) Poststart(ctx context.Context, req *interfaces.TaskPoststa
// Create task services struct with request's driver metadata
workloadServices := h.getWorkloadServices()
defer func() { h.initialRegistration = true }()
return h.consul.RegisterWorkload(workloadServices)
}
func (h *serviceHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequest, _ *interfaces.TaskUpdateResponse) error {
h.mu.Lock()
defer h.mu.Unlock()
if !h.initialRegistration {
// no op since initial registration has not finished
return nil
}
// Create old task services struct with request's driver metadata as it
// can't change due to Updates

View File

@@ -1,7 +1,14 @@
package taskrunner
import (
"context"
"testing"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/consul"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/stretchr/testify/require"
)
// Statically assert the stats hook implements the expected interfaces
@@ -9,3 +16,22 @@ var _ interfaces.TaskPoststartHook = (*serviceHook)(nil)
var _ interfaces.TaskExitedHook = (*serviceHook)(nil)
var _ interfaces.TaskPreKillHook = (*serviceHook)(nil)
var _ interfaces.TaskUpdateHook = (*serviceHook)(nil)
func TestUpdate_beforePoststart(t *testing.T) {
alloc := mock.Alloc()
logger := testlog.HCLogger(t)
c := consul.NewMockConsulServiceClient(t, logger)
hook := newServiceHook(serviceHookConfig{
alloc: alloc,
task: alloc.LookupTask("web"),
consul: c,
logger: logger,
})
require.NoError(t, hook.Update(context.Background(), &interfaces.TaskUpdateRequest{Alloc: alloc}, &interfaces.TaskUpdateResponse{}))
require.Len(t, c.GetOps(), 0)
require.NoError(t, hook.Poststart(context.Background(), &interfaces.TaskPoststartRequest{}, &interfaces.TaskPoststartResponse{}))
require.Len(t, c.GetOps(), 1)
require.NoError(t, hook.Update(context.Background(), &interfaces.TaskUpdateRequest{Alloc: alloc}, &interfaces.TaskUpdateResponse{}))
require.Len(t, c.GetOps(), 2)
}