vault: set renew increment to lease duration (#26041)

When we renew Vault tokens, we use the lease duration to determine how often to
renew. But we also set an `increment` value which is never updated from the
initial 30s. For periodic tokens this is not a problem because the `increment`
field is ignored on renewal. But for non-periodic tokens this prevents the token
TTL from being properly incremented. This behavior has been in place since the
initial Vault client implementation in #1606, but before the switch to workload
identity, most (all?) tokens being created were periodic tokens, so this was
never detected.

Fix this bug by updating the request's `increment` field to the lease duration
on each renewal.

Also replace the `time.After` call in the backoff loop of the derive-token
caller with a safe timer, so that we don't allocate a new timer on every loop
iteration and have tighter control over when it is GC'd.

Ref: https://github.com/hashicorp/nomad/pull/1606
Ref: https://github.com/hashicorp/nomad/issues/25812
This commit is contained in:
Tim Gross
2025-06-13 13:50:54 -04:00
committed by GitHub
parent fedd042e69
commit 26004c5407
8 changed files with 97 additions and 57 deletions

3
.changelog/26041.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:bug
vault: Fixed a bug where non-periodic tokens would not have their TTL incremented to the lease duration
```

View File

@@ -35,8 +35,8 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) {
// Setup a test Vault client.
token := "1234"
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
return token, true, nil
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
return token, true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)

View File

@@ -1462,9 +1462,9 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) {
// Control when we get a Vault token
token := "1234"
waitCh := make(chan struct{})
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
<-waitCh
return token, true, nil
return token, true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
@@ -1571,8 +1571,8 @@ func TestTaskRunner_DisableFileForVaultToken(t *testing.T) {
// Setup a test Vault client
token := "1234"
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
return token, true, nil
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
return token, true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)
@@ -1639,13 +1639,13 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
// Fail on the first attempt to derive a vault token
token := "1234"
count := 0
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
if count > 0 {
return token, true, nil
return token, true, 30, nil
}
count++
return "", false, structs.NewRecoverableError(fmt.Errorf("want a retry"), true)
return "", false, 0, structs.NewRecoverableError(fmt.Errorf("want a retry"), true)
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)
@@ -1741,8 +1741,8 @@ func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
must.NoError(t, err)
vc.(*vaultclient.MockVaultClient).SetDeriveTokenWithJWTFn(
func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
return "", false, errors.New("unrecoverable")
func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "", false, 0, errors.New("unrecoverable")
},
)
@@ -2076,9 +2076,9 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
// Control when we get a Vault token
waitCh := make(chan struct{}, 1)
defer close(waitCh)
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
<-waitCh
return "1234", true, nil
return "1234", true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)

View File

@@ -238,6 +238,7 @@ func (h *vaultHook) run(token string) {
// updatedToken lets us store state between loops. If true, a new token
// has been retrieved and we need to apply the Vault change mode
var updatedToken bool
leaseDuration := 30
OUTER:
for {
@@ -255,7 +256,7 @@ OUTER:
if token == "" {
// Get a token
var exit bool
token, exit = h.deriveVaultToken()
token, leaseDuration, exit = h.deriveVaultToken()
if exit {
// Exit the manager
return
@@ -289,7 +290,10 @@ OUTER:
//
// If Vault is having availability issues or is overloaded, a large
// number of initial token renews can exacerbate the problem.
renewCh, err := h.client.RenewToken(token, 30)
if leaseDuration == 0 {
leaseDuration = 30
}
renewCh, err := h.client.RenewToken(token, leaseDuration)
// An error returned means the token is not being renewed
if err != nil {
@@ -358,13 +362,17 @@ OUTER:
// deriveVaultToken derives the Vault token using exponential backoffs. It
// returns the Vault token, its lease duration, and whether the manager should
// exit.
func (h *vaultHook) deriveVaultToken() (string, bool) {
func (h *vaultHook) deriveVaultToken() (string, int, bool) {
var attempts uint64
var backoff time.Duration
timer, stopTimer := helper.NewSafeTimer(0)
defer stopTimer()
for {
token, err := h.deriveVaultTokenJWT()
token, lease, err := h.deriveVaultTokenJWT()
if err == nil {
return token, false
return token, lease, false
}
// Check if we can't recover from the error
@@ -374,11 +382,12 @@ func (h *vaultHook) deriveVaultToken() (string, bool) {
structs.NewTaskEvent(structs.TaskKilling).
SetFailsTask().
SetDisplayMessage(fmt.Sprintf("Vault: failed to derive vault token: %v", err)))
return "", true
return "", 0, true
}
// Handle the retry case
backoff = helper.Backoff(vaultBackoffBaseline, vaultBackoffLimit, attempts)
timer.Reset(backoff)
attempts++
h.logger.Error("failed to derive Vault token", "error", err, "recoverable", true, "backoff", backoff)
@@ -386,14 +395,14 @@ func (h *vaultHook) deriveVaultToken() (string, bool) {
// Wait till retrying
select {
case <-h.ctx.Done():
return "", true
case <-time.After(backoff):
return "", 0, true
case <-timer.C:
}
}
}
// deriveVaultTokenJWT returns a Vault ACL token and its lease duration using
// JWT auth login.
func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
func (h *vaultHook) deriveVaultTokenJWT() (string, int, error) {
// Retrieve signed identity.
signed, err := h.widmgr.Get(structs.WIHandle{
IdentityName: h.widName,
@@ -401,13 +410,13 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
WorkloadType: structs.WorkloadTypeTask,
})
if err != nil {
return "", structs.NewRecoverableError(
return "", 0, structs.NewRecoverableError(
fmt.Errorf("failed to retrieve signed workload identity: %w", err),
true,
)
}
if signed == nil {
return "", structs.NewRecoverableError(
return "", 0, structs.NewRecoverableError(
errors.New("no signed workload identity available"),
false,
)
@@ -419,13 +428,13 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
}
// Derive Vault token with signed identity.
token, renewable, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{
token, renewable, leaseDuration, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{
JWT: signed.JWT,
Role: role,
Namespace: h.vaultBlock.Namespace,
})
if err != nil {
return "", structs.WrapRecoverable(
return "", 0, structs.WrapRecoverable(
fmt.Sprintf("failed to derive Vault token for identity %s: %v", h.widName, err),
err,
)
@@ -437,7 +446,7 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
h.allowTokenExpiration = true
}
return token, nil
return token, leaseDuration, nil
}
// writeToken writes the given token to disk

View File

@@ -460,10 +460,10 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
// Set unrecoverable error.
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
// Cancel the context to simulate the task being killed.
cancel()
return "", false, structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
return "", false, 0, structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
})
err := hook.Prestart(ctx, req, &resp)
@@ -509,16 +509,16 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
// Set recoverable error.
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "", false, structs.NewRecoverableError(errors.New("recoverable test error"), true)
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "", false, 0, structs.NewRecoverableError(errors.New("recoverable test error"), true)
})
go func() {
// Wait a bit for the first error then fix token derivation.
time.Sleep(time.Second)
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "secret", true, nil
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "secret", true, 30, nil
})
}()
@@ -555,8 +555,8 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
// Derive predictable token and fail renew request.
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "secret", true, nil
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "secret", true, 30, nil
})
mockVaultClient.SetRenewTokenError("secret", errors.New("test error"))

View File

@@ -50,8 +50,9 @@ type VaultClient interface {
Stop()
// DeriveTokenWithJWT returns a Vault ACL token using the JWT login
// endpoint, along with whether or not the token is renewable.
DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, error)
// endpoint, along with whether or not the token is renewable and its lease
// duration.
DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, int, error)
// RenewToken renews a token with the given increment and adds it to
// the min-heap for periodic renewal.
@@ -237,12 +238,12 @@ func (c *vaultClient) unlockAndUnset() {
}
// DeriveTokenWithJWT returns a Vault ACL token using the JWT login endpoint,
// along with whether or not the token is renewable and its lease duration.
func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) {
func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, int, error) {
if !c.config.IsEnabled() {
return "", false, fmt.Errorf("vault client not enabled")
return "", false, 0, fmt.Errorf("vault client not enabled")
}
if !c.isRunning() {
return "", false, fmt.Errorf("vault client is not running")
return "", false, 0, fmt.Errorf("vault client is not running")
}
c.lock.Lock()
@@ -263,20 +264,20 @@ func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginReques
},
)
if err != nil {
return "", false, fmt.Errorf("failed to login with JWT: %v", err)
return "", false, 0, fmt.Errorf("failed to login with JWT: %v", err)
}
if s == nil {
return "", false, errors.New("JWT login returned an empty secret")
return "", false, 0, errors.New("JWT login returned an empty secret")
}
if s.Auth == nil {
return "", false, errors.New("JWT login did not return a token")
return "", false, 0, errors.New("JWT login did not return a token")
}
for _, w := range s.Warnings {
c.logger.Warn("JWT login warning", "warning", w)
}
return s.Auth.ClientToken, s.Auth.Renewable, nil
return s.Auth.ClientToken, s.Auth.Renewable, s.Auth.LeaseDuration, nil
}
// RenewToken renews the supplied token for a given duration (in seconds) and
@@ -368,6 +369,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
} else {
// Don't set this if renewal fails
leaseDuration = renewResp.Auth.LeaseDuration
req.increment = leaseDuration
}
// Reset the token in the API client before returning

View File

@@ -9,6 +9,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
@@ -218,13 +219,14 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {
// Derive Vault token using signed JWT.
jwtStr := signedWIDs[0].JWT
token, renewable, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
token, renewable, leaseDuration, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Namespace: "default",
})
must.NoError(t, err)
must.NotEq(t, "", token)
must.True(t, renewable)
must.Eq(t, 72*60*60, leaseDuration) // token_period from role
// Verify token has expected properties.
v.Client.SetToken(token)
@@ -259,7 +261,7 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {
must.Eq(t, []any{"deny"}, (s.Data[pathDenied]).([]any))
// Derive Vault token with non-existing role.
token, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
token, _, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Role: "test",
Namespace: "default",
@@ -448,8 +450,14 @@ func TestVaultClient_SetUserAgent(t *testing.T) {
func TestVaultClient_RenewalConcurrent(t *testing.T) {
ci.Parallel(t)
// collects renewal requests that the mock Vault API gets
requestCh := make(chan string, 10)
// Create test server to mock the Vault API.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := io.ReadAll(r.Body)
requestCh <- string(b)
resp := vaultapi.Secret{
RequestID: uuid.Generate(),
LeaseID: uuid.Generate(),
@@ -458,7 +466,7 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) {
Auth: &vaultapi.SecretAuth{
ClientToken: uuid.Generate(),
Accessor: uuid.Generate(),
LeaseDuration: 300,
LeaseDuration: 1, // force a fast renewal
},
}
@@ -482,9 +490,9 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) {
vc.Start()
// Renew token multiple times in parallel.
requests := 100
expectedRenewals := 100
resultCh := make(chan any)
for i := 0; i < requests; i++ {
for range expectedRenewals {
go func() {
_, err := vc.RenewToken("token", 30)
resultCh <- err
@@ -494,12 +502,28 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) {
// Collect results with timeout.
timer, stop := helper.NewSafeTimer(3 * time.Second)
defer stop()
for i := 0; i < requests; i++ {
sawInitial := 0
sawRenew := 0
for {
select {
case got := <-requestCh:
switch got {
case `{"increment":1}`:
sawRenew++
case `{"increment":30}`:
sawInitial++
default:
t.Fatalf("unexpected request body: %q", got)
}
if sawInitial == expectedRenewals && sawRenew >= expectedRenewals {
return
}
case got := <-resultCh:
must.Nil(t, got, must.Sprintf("token renewal error: %v", got))
case <-timer.C:
t.Fatal("timeout waiting for token renewal")
t.Fatalf("timeout waiting for expected token renewals (initial: %d renewed: %d)",
sawInitial, sawRenew)
}
}
}
@@ -524,7 +548,7 @@ func TestVaultClient_NamespaceReset(t *testing.T) {
must.NoError(t, err)
vc.Start()
_, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
_, _, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: "bogus",
Namespace: "bar",
})

View File

@@ -35,20 +35,22 @@ type MockVaultClient struct {
// deriveTokenWithJWTFn allows the caller to control the DeriveTokenWithJWT
// function.
deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, error)
deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, int, error)
// renewable determines if the tokens returned should be marked as renewable
renewable bool
duration int
mu sync.Mutex
}
// NewMockVaultClient returns a MockVaultClient for testing
func NewMockVaultClient(_ string) (VaultClient, error) {
return &MockVaultClient{renewable: true}, nil
return &MockVaultClient{renewable: true, duration: 30}, nil
}
func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) {
func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, int, error) {
vc.mu.Lock()
defer vc.mu.Unlock()
@@ -65,7 +67,7 @@ func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginR
token = fmt.Sprintf("%s-%s", token, req.Role)
}
vc.jwtTokens[req.JWT] = token
return token, vc.renewable, nil
return token, vc.renewable, vc.duration, nil
}
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
@@ -161,7 +163,7 @@ func (vc *MockVaultClient) RenewTokenErrCh(token string) chan error {
}
// SetDeriveTokenWithJWTFn sets the function used to derive tokens using JWT.
func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, error)) {
func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, int, error)) {
vc.mu.Lock()
defer vc.mu.Unlock()
vc.deriveTokenWithJWTFn = f