mirror of
https://github.com/kemko/nomad.git
synced 2026-01-03 08:55:43 +03:00
Some users with batch workloads or short-lived prestart tasks want to derive a Vault token, use it, and then allow it to expire without requiring a constant refresh. Add the `vault.allow_token_expiration` field, which works only with the Workload Identity workflow and not the legacy workflow. When set to true, this disables the client's renewal loop in the `vault_hook`. When Vault revokes the token lease, the token will no longer be valid. The client will also now detect if the Vault auth configuration does not allow renewals and will disable the renewal loop automatically. Note this should only be used when a secret is requested from Vault once at the start of a task or in a short-lived prestart task. Long-running tasks should never set `allow_token_expiration=true` if they obtain Vault secrets via `template` blocks, as the Vault token will expire and the template runner will continue to make failing requests to Vault until the `vault_retry` attempts are exhausted. Fixes: https://github.com/hashicorp/nomad/issues/8690
226 lines
5.9 KiB
Go
226 lines
5.9 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package vaultclient
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/hashicorp/nomad/helper/uuid"
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
vaultapi "github.com/hashicorp/vault/api"
|
|
)
|
|
|
|
// MockVaultClient is used for testing the vaultclient integration and is safe
|
|
// for concurrent access.
|
|
type MockVaultClient struct {
|
|
// legacyTokens stores the tokens per task derived using the legacy flow.
|
|
legacyTokens map[string]string
|
|
|
|
// jwtTokens stores the tokens derived using the JWT flow.
|
|
jwtTokens map[string]string
|
|
|
|
// stoppedTokens tracks the tokens that have stopped renewing
|
|
stoppedTokens []string
|
|
|
|
// renewTokens are the tokens that have been renewed and their error
|
|
// channels
|
|
renewTokens map[string]chan error
|
|
|
|
// renewTokenErrors is used to return an error when the RenewToken is called
|
|
// with the given token
|
|
renewTokenErrors map[string]error
|
|
|
|
// deriveTokenErrors maps an allocation ID and tasks to an error when the
|
|
// token is derived
|
|
deriveTokenErrors map[string]map[string]error
|
|
|
|
// DeriveTokenFn allows the caller to control the DeriveToken function. If
|
|
// not set an error is returned if found in DeriveTokenErrors and otherwise
|
|
// a token is generated and returned
|
|
DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
|
|
|
|
// deriveTokenWithJWTFn allows the caller to control the DeriveTokenWithJWT
|
|
// function.
|
|
deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, error)
|
|
|
|
// renewable determines if the tokens returned should be marked as renewable
|
|
renewable bool
|
|
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// NewMockVaultClient returns a MockVaultClient for testing
|
|
func NewMockVaultClient(_ string) (VaultClient, error) {
|
|
return &MockVaultClient{renewable: true}, nil
|
|
}
|
|
|
|
func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
|
|
if vc.deriveTokenWithJWTFn != nil {
|
|
return vc.deriveTokenWithJWTFn(ctx, req)
|
|
}
|
|
|
|
if vc.jwtTokens == nil {
|
|
vc.jwtTokens = make(map[string]string)
|
|
}
|
|
|
|
token := uuid.Generate()
|
|
if req.Role != "" {
|
|
token = fmt.Sprintf("%s-%s", token, req.Role)
|
|
}
|
|
vc.jwtTokens[req.JWT] = token
|
|
return token, vc.renewable, nil
|
|
}
|
|
|
|
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
|
|
if vc.DeriveTokenFn != nil {
|
|
return vc.DeriveTokenFn(a, tasks)
|
|
}
|
|
|
|
tokens := make(map[string]string, len(tasks))
|
|
for _, task := range tasks {
|
|
if tasks, ok := vc.deriveTokenErrors[a.ID]; ok {
|
|
if err, ok := tasks[task]; ok {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
tokens[task] = uuid.Generate()
|
|
}
|
|
|
|
vc.legacyTokens = tokens
|
|
return tokens, nil
|
|
}
|
|
|
|
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
|
|
if vc.deriveTokenErrors == nil {
|
|
vc.deriveTokenErrors = make(map[string]map[string]error, 10)
|
|
}
|
|
|
|
if _, ok := vc.deriveTokenErrors[allocID]; !ok {
|
|
vc.deriveTokenErrors[allocID] = make(map[string]error, 10)
|
|
}
|
|
|
|
for _, task := range tasks {
|
|
vc.deriveTokenErrors[allocID][task] = err
|
|
}
|
|
}
|
|
|
|
func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
|
|
if err, ok := vc.renewTokenErrors[token]; ok {
|
|
return nil, err
|
|
}
|
|
|
|
renewCh := make(chan error)
|
|
if vc.renewTokens == nil {
|
|
vc.renewTokens = make(map[string]chan error, 10)
|
|
}
|
|
vc.renewTokens[token] = renewCh
|
|
return renewCh, nil
|
|
}
|
|
|
|
func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
|
|
if vc.renewTokenErrors == nil {
|
|
vc.renewTokenErrors = make(map[string]error, 10)
|
|
}
|
|
|
|
vc.renewTokenErrors[token] = err
|
|
}
|
|
|
|
func (vc *MockVaultClient) StopRenewToken(token string) error {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
|
|
vc.stoppedTokens = append(vc.stoppedTokens, token)
|
|
return nil
|
|
}
|
|
|
|
// Start is a no-op on the mock client.
func (vc *MockVaultClient) Start() {}
|
|
|
|
// Stop is a no-op on the mock client.
func (vc *MockVaultClient) Stop() {}
|
|
|
|
func (vc *MockVaultClient) SetRenewable(renewable bool) {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
vc.renewable = renewable
|
|
}
|
|
|
|
// GetConsulACL is a stub: it ignores both arguments and always returns a nil
// secret together with a nil error.
func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) {
	return nil, nil
}
|
|
|
|
// LegacyTokens returns the tokens generated using the legacy flow.
|
|
func (vc *MockVaultClient) LegacyTokens() map[string]string {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
return vc.legacyTokens
|
|
}
|
|
|
|
// JWTTotkens returns the tokens generated suing the JWT flow.
|
|
func (vc *MockVaultClient) JWTTokens() map[string]string {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
return vc.jwtTokens
|
|
}
|
|
|
|
// StoppedTokens tracks the tokens that have stopped renewing
|
|
func (vc *MockVaultClient) StoppedTokens() []string {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
return vc.stoppedTokens
|
|
}
|
|
|
|
// RenewTokens are the tokens that have been renewed and their error
|
|
// channels
|
|
func (vc *MockVaultClient) RenewTokens() map[string]chan error {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
return vc.renewTokens
|
|
}
|
|
|
|
// RenewTokenErrCh returns the error channel for the given token renewal
|
|
// process.
|
|
func (vc *MockVaultClient) RenewTokenErrCh(token string) chan error {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
return vc.renewTokens[token]
|
|
}
|
|
|
|
// RenewTokenErrors is used to return an error when the RenewToken is called
|
|
// with the given token
|
|
func (vc *MockVaultClient) RenewTokenErrors() map[string]error {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
return vc.renewTokenErrors
|
|
}
|
|
|
|
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
|
|
// token is derived
|
|
func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
return vc.deriveTokenErrors
|
|
}
|
|
|
|
// SetDeriveTokenWithJWTFn sets the function used to derive tokens using JWT.
|
|
func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, error)) {
|
|
vc.mu.Lock()
|
|
defer vc.mu.Unlock()
|
|
vc.deriveTokenWithJWTFn = f
|
|
}
|