From 26004c54076ef4bb6e75711f5fd39287d44749d6 Mon Sep 17 00:00:00 2001 From: Tim Gross Date: Fri, 13 Jun 2025 13:50:54 -0400 Subject: [PATCH] vault: set renew increment to lease duration (#26041) When we renew Vault tokens, we use the lease duration to determine how often to renew. But we also set an `increment` value which is never updated from the initial 30s. For periodic tokens this is not a problem because the `increment` field is ignored on renewal. But for non-periodic tokens this prevents the token TTL from being properly incremented. This behavior has been in place since the initial Vault client implementation in #1606 but before the switch to workload identity most (all?) tokens being created were periodic tokens so this was never detected. Fix this bug by updating the request's `increment` field to the lease duration on each renewal. Also switch out a `time.After` call in the backoff loop of the derive token caller with a safe timer so that we don't have to allocate a new timer per loop, and have tighter control over when that's GC'd. 
Ref: https://github.com/hashicorp/nomad/pull/1606 Ref: https://github.com/hashicorp/nomad/issues/25812 --- .changelog/26041.txt | 3 ++ .../taskrunner/task_runner_linux_test.go | 4 +- .../taskrunner/task_runner_test.go | 22 +++++----- client/allocrunner/taskrunner/vault_hook.go | 37 ++++++++++------- .../allocrunner/taskrunner/vault_hook_test.go | 16 ++++---- client/vaultclient/vaultclient.go | 20 +++++----- client/vaultclient/vaultclient_test.go | 40 +++++++++++++++---- client/vaultclient/vaultclient_testing.go | 12 +++--- 8 files changed, 97 insertions(+), 57 deletions(-) create mode 100644 .changelog/26041.txt diff --git a/.changelog/26041.txt b/.changelog/26041.txt new file mode 100644 index 000000000..7e5593595 --- /dev/null +++ b/.changelog/26041.txt @@ -0,0 +1,3 @@ +```release-note:bug +vault: Fixed a bug where non-periodic tokens would not have their TTL incremented to the lease duration +``` diff --git a/client/allocrunner/taskrunner/task_runner_linux_test.go b/client/allocrunner/taskrunner/task_runner_linux_test.go index 67452245e..e773393e6 100644 --- a/client/allocrunner/taskrunner/task_runner_linux_test.go +++ b/client/allocrunner/taskrunner/task_runner_linux_test.go @@ -35,8 +35,8 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) { // Setup a test Vault client. 
token := "1234" - handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) { - return token, true, nil + handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) { + return token, true, 30, nil } vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster) must.NoError(t, err) diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index 89c1dd914..5f103107c 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -1462,9 +1462,9 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) { // Control when we get a Vault token token := "1234" waitCh := make(chan struct{}) - handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) { + handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) { <-waitCh - return token, true, nil + return token, true, 30, nil } vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster) @@ -1571,8 +1571,8 @@ func TestTaskRunner_DisableFileForVaultToken(t *testing.T) { // Setup a test Vault client token := "1234" - handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) { - return token, true, nil + handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) { + return token, true, 30, nil } vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster) must.NoError(t, err) @@ -1639,13 +1639,13 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) { // Fail on the first attempt to derive a vault token token := "1234" count := 0 - handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) { + handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) { if count > 0 { - return token, true, nil + return 
token, true, 30, nil } count++ - return "", false, structs.NewRecoverableError(fmt.Errorf("want a retry"), true) + return "", false, 0, structs.NewRecoverableError(fmt.Errorf("want a retry"), true) } vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster) must.NoError(t, err) @@ -1741,8 +1741,8 @@ func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) { must.NoError(t, err) vc.(*vaultclient.MockVaultClient).SetDeriveTokenWithJWTFn( - func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) { - return "", false, errors.New("unrecoverable") + func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) { + return "", false, 0, errors.New("unrecoverable") }, ) @@ -2076,9 +2076,9 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) { // Control when we get a Vault token waitCh := make(chan struct{}, 1) defer close(waitCh) - handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) { + handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) { <-waitCh - return "1234", true, nil + return "1234", true, 30, nil } vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster) must.NoError(t, err) diff --git a/client/allocrunner/taskrunner/vault_hook.go b/client/allocrunner/taskrunner/vault_hook.go index 44764e12c..3a03f178f 100644 --- a/client/allocrunner/taskrunner/vault_hook.go +++ b/client/allocrunner/taskrunner/vault_hook.go @@ -238,6 +238,7 @@ func (h *vaultHook) run(token string) { // updatedToken lets us store state between loops. 
If true, a new token // has been retrieved and we need to apply the Vault change mode var updatedToken bool + leaseDuration := 30 OUTER: for { @@ -255,7 +256,7 @@ OUTER: if token == "" { // Get a token var exit bool - token, exit = h.deriveVaultToken() + token, leaseDuration, exit = h.deriveVaultToken() if exit { // Exit the manager return @@ -289,7 +290,10 @@ OUTER: // // If Vault is having availability issues or is overloaded, a large // number of initial token renews can exacerbate the problem. - renewCh, err := h.client.RenewToken(token, 30) + if leaseDuration == 0 { + leaseDuration = 30 + } + renewCh, err := h.client.RenewToken(token, leaseDuration) // An error returned means the token is not being renewed if err != nil { @@ -358,13 +362,17 @@ OUTER: // deriveVaultToken derives the Vault token using exponential backoffs. It // returns the Vault token and whether the manager should exit. -func (h *vaultHook) deriveVaultToken() (string, bool) { +func (h *vaultHook) deriveVaultToken() (string, int, bool) { var attempts uint64 var backoff time.Duration + + timer, stopTimer := helper.NewSafeTimer(0) + defer stopTimer() + for { - token, err := h.deriveVaultTokenJWT() + token, lease, err := h.deriveVaultTokenJWT() if err == nil { - return token, false + return token, lease, false } // Check if we can't recover from the error @@ -374,11 +382,12 @@ func (h *vaultHook) deriveVaultToken() (string, bool) { structs.NewTaskEvent(structs.TaskKilling). SetFailsTask(). 
SetDisplayMessage(fmt.Sprintf("Vault: failed to derive vault token: %v", err))) - return "", true + return "", 0, true } // Handle the retry case backoff = helper.Backoff(vaultBackoffBaseline, vaultBackoffLimit, attempts) + timer.Reset(backoff) attempts++ h.logger.Error("failed to derive Vault token", "error", err, "recoverable", true, "backoff", backoff) @@ -386,14 +395,14 @@ func (h *vaultHook) deriveVaultToken() (string, bool) { // Wait till retrying select { case <-h.ctx.Done(): - return "", true - case <-time.After(backoff): + return "", 0, true + case <-timer.C: } } } // deriveVaultTokenJWT returns a Vault ACL token using JWT auth login. -func (h *vaultHook) deriveVaultTokenJWT() (string, error) { +func (h *vaultHook) deriveVaultTokenJWT() (string, int, error) { // Retrieve signed identity. signed, err := h.widmgr.Get(structs.WIHandle{ IdentityName: h.widName, @@ -401,13 +410,13 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) { WorkloadType: structs.WorkloadTypeTask, }) if err != nil { - return "", structs.NewRecoverableError( + return "", 0, structs.NewRecoverableError( fmt.Errorf("failed to retrieve signed workload identity: %w", err), true, ) } if signed == nil { - return "", structs.NewRecoverableError( + return "", 0, structs.NewRecoverableError( errors.New("no signed workload identity available"), false, ) @@ -419,13 +428,13 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) { } // Derive Vault token with signed identity. 
- token, renewable, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{ + token, renewable, leaseDuration, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{ JWT: signed.JWT, Role: role, Namespace: h.vaultBlock.Namespace, }) if err != nil { - return "", structs.WrapRecoverable( + return "", 0, structs.WrapRecoverable( fmt.Sprintf("failed to derive Vault token for identity %s: %v", h.widName, err), err, ) @@ -437,7 +446,7 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) { h.allowTokenExpiration = true } - return token, nil + return token, leaseDuration, nil } // writeToken writes the given token to disk diff --git a/client/allocrunner/taskrunner/vault_hook_test.go b/client/allocrunner/taskrunner/vault_hook_test.go index 7e4be5e40..00d825a42 100644 --- a/client/allocrunner/taskrunner/vault_hook_test.go +++ b/client/allocrunner/taskrunner/vault_hook_test.go @@ -460,10 +460,10 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) { // Set unrecoverable error. mockVaultClient.SetDeriveTokenWithJWTFn( - func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) { + func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) { // Cancel the context to simulate the task being killed. cancel() - return "", false, structs.NewRecoverableError(errors.New("unrecoverable test error"), false) + return "", false, 0, structs.NewRecoverableError(errors.New("unrecoverable test error"), false) }) err := hook.Prestart(ctx, req, &resp) @@ -509,16 +509,16 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) { // Set recoverable error. 
mockVaultClient.SetDeriveTokenWithJWTFn( - func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) { - return "", false, structs.NewRecoverableError(errors.New("recoverable test error"), true) + func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) { + return "", false, 0, structs.NewRecoverableError(errors.New("recoverable test error"), true) }) go func() { // Wait a bit for the first error then fix token renewal. time.Sleep(time.Second) mockVaultClient.SetDeriveTokenWithJWTFn( - func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) { - return "secret", true, nil + func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) { + return "secret", true, 30, nil }) }() @@ -555,8 +555,8 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) { // Derive predictable token and fail renew request. mockVaultClient.SetDeriveTokenWithJWTFn( - func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) { - return "secret", true, nil + func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) { + return "secret", true, 30, nil }) mockVaultClient.SetRenewTokenError("secret", errors.New("test error")) diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index 88a107bef..a1afe22de 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -50,8 +50,9 @@ type VaultClient interface { Stop() // DeriveTokenWithJWT returns a Vault ACL token using the JWT login - // endpoint, along with whether or not the token is renewable. - DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, error) + // endpoint, along with whether or not the token is renewable and its lease + // duration. + DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, int, error) // RenewToken renews a token with the given increment and adds it to // the min-heap for periodic renewal. 
@@ -237,12 +238,12 @@ func (c *vaultClient) unlockAndUnset() { } // DeriveTokenWithJWT returns a Vault ACL token using the JWT login endpoint. -func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) { +func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, int, error) { if !c.config.IsEnabled() { - return "", false, fmt.Errorf("vault client not enabled") + return "", false, 0, fmt.Errorf("vault client not enabled") } if !c.isRunning() { - return "", false, fmt.Errorf("vault client is not running") + return "", false, 0, fmt.Errorf("vault client is not running") } c.lock.Lock() @@ -263,20 +264,20 @@ func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginReques }, ) if err != nil { - return "", false, fmt.Errorf("failed to login with JWT: %v", err) + return "", false, 0, fmt.Errorf("failed to login with JWT: %v", err) } if s == nil { - return "", false, errors.New("JWT login returned an empty secret") + return "", false, 0, errors.New("JWT login returned an empty secret") } if s.Auth == nil { - return "", false, errors.New("JWT login did not return a token") + return "", false, 0, errors.New("JWT login did not return a token") } for _, w := range s.Warnings { c.logger.Warn("JWT login warning", "warning", w) } - return s.Auth.ClientToken, s.Auth.Renewable, nil + return s.Auth.ClientToken, s.Auth.Renewable, s.Auth.LeaseDuration, nil } // RenewToken renews the supplied token for a given duration (in seconds) and @@ -368,6 +369,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { } else { // Don't set this if renewal fails leaseDuration = renewResp.Auth.LeaseDuration + req.increment = leaseDuration } // Reset the token in the API client before returning diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index 2b222f608..1dcfe8ac9 100644 --- a/client/vaultclient/vaultclient_test.go +++ 
b/client/vaultclient/vaultclient_test.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "testing" @@ -218,13 +219,14 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) { // Derive Vault token using signed JWT. jwtStr := signedWIDs[0].JWT - token, renewable, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ + token, renewable, leaseDuration, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ JWT: jwtStr, Namespace: "default", }) must.NoError(t, err) must.NotEq(t, "", token) must.True(t, renewable) + must.Eq(t, 72*60*60, leaseDuration) // token_period from role // Verify token has expected properties. v.Client.SetToken(token) @@ -259,7 +261,7 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) { must.Eq(t, []any{"deny"}, (s.Data[pathDenied]).([]any)) // Derive Vault token with non-existing role. - token, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ + token, _, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ JWT: jwtStr, Role: "test", Namespace: "default", @@ -448,8 +450,14 @@ func TestVaultClient_SetUserAgent(t *testing.T) { func TestVaultClient_RenewalConcurrent(t *testing.T) { ci.Parallel(t) + // collects renewal requests that the mock Vault API gets + requestCh := make(chan string, 10) + // Create test server to mock the Vault API. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + requestCh <- string(b) + resp := vaultapi.Secret{ RequestID: uuid.Generate(), LeaseID: uuid.Generate(), @@ -458,7 +466,7 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) { Auth: &vaultapi.SecretAuth{ ClientToken: uuid.Generate(), Accessor: uuid.Generate(), - LeaseDuration: 300, + LeaseDuration: 1, // force a fast renewal }, } @@ -482,9 +490,9 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) { vc.Start() // Renew token multiple times in parallel. 
- requests := 100 + expectedRenewals := 100 resultCh := make(chan any) - for i := 0; i < requests; i++ { + for range expectedRenewals { go func() { _, err := vc.RenewToken("token", 30) resultCh <- err @@ -494,12 +502,28 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) { // Collect results with timeout. timer, stop := helper.NewSafeTimer(3 * time.Second) defer stop() - for i := 0; i < requests; i++ { + + sawInitial := 0 + sawRenew := 0 + for { select { + case got := <-requestCh: + switch got { + case `{"increment":1}`: + sawRenew++ + case `{"increment":30}`: + sawInitial++ + default: + t.Fatalf("unexpected request body: %q", got) + } + if sawInitial == expectedRenewals && sawRenew >= expectedRenewals { + return + } case got := <-resultCh: must.Nil(t, got, must.Sprintf("token renewal error: %v", got)) case <-timer.C: - t.Fatal("timeout waiting for token renewal") + t.Fatalf("timeout waiting for expected token renewals (initial: %d renewed: %d)", + sawInitial, sawRenew) } } } @@ -524,7 +548,7 @@ func TestVaultClient_NamespaceReset(t *testing.T) { must.NoError(t, err) vc.Start() - _, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ + _, _, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{ JWT: "bogus", Namespace: "bar", }) diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go index 2516ac40d..65d91805a 100644 --- a/client/vaultclient/vaultclient_testing.go +++ b/client/vaultclient/vaultclient_testing.go @@ -35,20 +35,22 @@ type MockVaultClient struct { // deriveTokenWithJWTFn allows the caller to control the DeriveTokenWithJWT // function. 
- deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, error) + deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, int, error) // renewable determines if the tokens returned should be marked as renewable renewable bool + duration int + mu sync.Mutex } // NewMockVaultClient returns a MockVaultClient for testing func NewMockVaultClient(_ string) (VaultClient, error) { - return &MockVaultClient{renewable: true}, nil + return &MockVaultClient{renewable: true, duration: 30}, nil } -func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) { +func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, int, error) { vc.mu.Lock() defer vc.mu.Unlock() @@ -65,7 +67,7 @@ func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginR token = fmt.Sprintf("%s-%s", token, req.Role) } vc.jwtTokens[req.JWT] = token - return token, vc.renewable, nil + return token, vc.renewable, vc.duration, nil } func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) { @@ -161,7 +163,7 @@ func (vc *MockVaultClient) RenewTokenErrCh(token string) chan error { } // SetDeriveTokenWithJWTFn sets the function used to derive tokens using JWT. -func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, error)) { +func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, int, error)) { vc.mu.Lock() defer vc.mu.Unlock() vc.deriveTokenWithJWTFn = f