From 1fca495a85c40afa339705bd44845834b0eb3a2c Mon Sep 17 00:00:00 2001 From: Seth Hoenig Date: Wed, 15 Jan 2020 09:56:48 -0600 Subject: [PATCH] client: set context timeout around SI token derivation The derivation of an SI token needs to be safeguarded by a context timeout, otherwise an unresponsive Consul could cause the siHook to block forever on Prestart. --- client/allocrunner/taskrunner/sids_hook.go | 32 ++++++++++----- .../allocrunner/taskrunner/sids_hook_test.go | 40 +++++++++++++++++-- 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/client/allocrunner/taskrunner/sids_hook.go b/client/allocrunner/taskrunner/sids_hook.go index 1db3713c0..77dd46ead 100644 --- a/client/allocrunner/taskrunner/sids_hook.go +++ b/client/allocrunner/taskrunner/sids_hook.go @@ -28,6 +28,11 @@ const ( // to retrieve a Consul SI token sidsBackoffLimit = 3 * time.Minute + // sidsDerivationTimeout limits the amount of time we may spend trying to + // derive a SI token. If the hook does not get a token within this amount of + // time, the result is a failure. + sidsDerivationTimeout = 5 * time.Minute + // sidsTokenFile is the name of the file holding the Consul SI token inside // the task's secret directory sidsTokenFile = "si_token" @@ -59,6 +64,11 @@ type sidsHook struct { // lifecycle is used to signal, restart, and kill a task lifecycle ti.TaskLifecycle + // derivationTimeout is the amount of time we may wait for Consul to successfully + // provide a SI token. 
Making this configurable for testing, otherwise + // default to sidsDerivationTimeout + derivationTimeout time.Duration + // logger is used to log logger hclog.Logger @@ -71,12 +81,13 @@ type sidsHook struct { func newSIDSHook(c sidsHookConfig) *sidsHook { return &sidsHook{ - alloc: c.alloc, - task: c.task, - sidsClient: c.sidsClient, - lifecycle: c.lifecycle, - logger: c.logger.Named(sidsHookName), - firstRun: true, + alloc: c.alloc, + task: c.task, + sidsClient: c.sidsClient, + lifecycle: c.lifecycle, + derivationTimeout: sidsDerivationTimeout, + logger: c.logger.Named(sidsHookName), + firstRun: true, } } @@ -163,18 +174,21 @@ func (h *sidsHook) recoverToken(dir string) (string, error) { // derive an SI token until a token is successfully created, or ctx is signaled // done. func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) { + ctx2, cancel := context.WithTimeout(ctx, h.derivationTimeout) + defer cancel() + tokenCh := make(chan string) // keep trying to get the token in the background - go h.tryDerive(ctx, tokenCh) + go h.tryDerive(ctx2, tokenCh) // wait until we get a token, or we get a signal to quit for { select { case token := <-tokenCh: return token, nil - case <-ctx.Done(): - return "", ctx.Err() + case <-ctx2.Done(): + return "", ctx2.Err() } } } diff --git a/client/allocrunner/taskrunner/sids_hook_test.go b/client/allocrunner/taskrunner/sids_hook_test.go index 415d1f980..5052fe9e2 100644 --- a/client/allocrunner/taskrunner/sids_hook_test.go +++ b/client/allocrunner/taskrunner/sids_hook_test.go @@ -37,8 +37,8 @@ func sidecar(task string) (string, structs.TaskKind) { func TestSIDSHook_recoverToken(t *testing.T) { t.Parallel() - r := require.New(t) + secrets := tmpDir(t) defer cleanupDir(t, secrets) @@ -63,8 +63,8 @@ func TestSIDSHook_recoverToken(t *testing.T) { func TestSIDSHook_recoverToken_empty(t *testing.T) { t.Parallel() - r := require.New(t) + secrets := tmpDir(t) defer cleanupDir(t, secrets) @@ -85,8 +85,8 @@ func 
TestSIDSHook_recoverToken_empty(t *testing.T) { func TestSIDSHook_deriveSIToken(t *testing.T) { t.Parallel() - r := require.New(t) + secrets := tmpDir(t) defer cleanupDir(t, secrets) @@ -108,6 +108,40 @@ func TestSIDSHook_deriveSIToken(t *testing.T) { r.True(helper.IsUUID(token), "token: %q", token) } +func TestSIDSHook_deriveSIToken_timeout(t *testing.T) { + t.Parallel() + r := require.New(t) + + secrets := tmpDir(t) + defer cleanupDir(t, secrets) + + taskName, taskKind := sidecar("task1") + + siClient := consul.NewMockServiceIdentitiesClient() + siClient.DeriveTokenFn = func(allocation *structs.Allocation, strings []string) (m map[string]string, err error) { + select { + // block forever, hopefully triggering a timeout in the caller + } + } + + h := newSIDSHook(sidsHookConfig{ + alloc: &structs.Allocation{ID: "a1"}, + task: &structs.Task{ + Name: taskName, + Kind: taskKind, + }, + logger: testlog.HCLogger(t), + sidsClient: siClient, + }) + + // set the timeout to a really small value for testing + h.derivationTimeout = time.Duration(1 * time.Millisecond) + + ctx := context.Background() + _, err := h.deriveSIToken(ctx) + r.EqualError(err, "context deadline exceeded") +} + func TestSIDSHook_computeBackoff(t *testing.T) { t.Parallel()