Files
nomad/client/vaultclient/vaultclient_testing.go
Tim Gross 18fdda6242 vault: fix namespace reset for clients with unset namespace (#23491)
The Vault "logical" API doesn't allow configuring the namespace on a per-request
basis. Instead, it's set on the client. Our `vaultclient` wrapper locks access
to the API client and sets the namespace (and token, if applicable) for each
request, and then resets the namespace and unlocks the API client.

The logic for resetting the namespace incorrectly assumed that, if the Vault
configuration didn't set the namespace, it was canonicalized to the non-empty
string `"default"`. When the configuration value is empty, this results in the
API client's namespace getting "stuck" whenever a job uses a non-default
namespace. Update the logic to always reset the namespace from the
configuration, rather than accepting the "previous" namespace from the caller.

This changeset also removes some long-dead code in the Vault client wrapper.

Fixes: https://github.com/hashicorp/nomad/issues/22230
Ref: https://hashicorp.atlassian.net/browse/NET-10207
2024-07-03 10:13:20 -04:00

223 lines
5.7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package vaultclient
import (
"context"
"fmt"
"sync"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/structs"
)
// MockVaultClient is used for testing the vaultclient integration and is safe
// for concurrent access: every method acquires mu before reading or writing
// the mock's state.
type MockVaultClient struct {
	// legacyTokens stores the tokens produced by the most recent DeriveToken
	// call that generated tokens itself (the legacy flow), keyed by task name.
	legacyTokens map[string]string

	// jwtTokens stores the tokens derived using the JWT flow, keyed by the
	// JWT from the login request.
	jwtTokens map[string]string

	// stoppedTokens tracks the tokens passed to StopRenewToken, in call order.
	stoppedTokens []string

	// renewTokens maps each token passed to RenewToken to the error channel
	// that was returned for it.
	renewTokens map[string]chan error

	// renewTokenErrors is used to return an error when RenewToken is called
	// with the given token.
	renewTokenErrors map[string]error

	// deriveTokenErrors maps an allocation ID and task name to the error that
	// DeriveToken returns for that task.
	deriveTokenErrors map[string]map[string]error

	// DeriveTokenFn allows the caller to control the DeriveToken function. If
	// not set, DeriveToken returns the error registered in deriveTokenErrors
	// (see SetDeriveTokenError) for the first matching task, and otherwise
	// generates a token per task. It is read and invoked while mu is held, so
	// it must not call back into this mock. There is no setter for it in this
	// file; callers presumably assign it directly, which should happen before
	// the mock is used from multiple goroutines.
	DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)

	// deriveTokenWithJWTFn allows the caller to control the DeriveTokenWithJWT
	// function; set it with SetDeriveTokenWithJWTFn. It is invoked while mu is
	// held, so it must not call back into this mock.
	deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, error)

	// renewable determines whether the tokens generated by DeriveTokenWithJWT
	// are reported as renewable. NewMockVaultClient initializes it to true.
	renewable bool

	// mu guards all of the fields above.
	mu sync.Mutex
}
// NewMockVaultClient builds a MockVaultClient for tests. The string argument
// is ignored. Tokens generated by the JWT flow are reported as renewable until
// SetRenewable says otherwise.
func NewMockVaultClient(_ string) (VaultClient, error) {
	mock := new(MockVaultClient)
	mock.renewable = true
	return mock, nil
}
// DeriveTokenWithJWT simulates a Vault JWT login.
//
// If a custom function was installed with SetDeriveTokenWithJWTFn, it is
// called (with mu held) and its results are returned unchanged. Otherwise a
// random UUID token is generated, suffixed with "-<role>" when req.Role is
// set, recorded under req.JWT, and returned together with the current
// renewable setting and a nil error.
func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if override := vc.deriveTokenWithJWTFn; override != nil {
		return override(ctx, req)
	}

	issued := uuid.Generate()
	if role := req.Role; role != "" {
		issued = fmt.Sprintf("%s-%s", issued, role)
	}

	if vc.jwtTokens == nil {
		vc.jwtTokens = make(map[string]string)
	}
	vc.jwtTokens[req.JWT] = issued

	return issued, vc.renewable, nil
}
// DeriveToken simulates the legacy token-derivation flow for the given tasks
// of allocation a.
//
// If DeriveTokenFn is set, the call is delegated to it entirely (while mu is
// held). Otherwise the tasks are processed in order: the first task with an
// error registered for a.ID via SetDeriveTokenError aborts the call with that
// error and nothing is recorded. If no task has a registered error, a fresh
// token is generated for each task. The resulting task-name-to-token map is
// stored for LegacyTokens and returned.
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if vc.DeriveTokenFn != nil {
		return vc.DeriveTokenFn(a, tasks)
	}

	tokens := make(map[string]string, len(tasks))
	for _, task := range tasks {
		// A missing allocation yields a nil inner map, and looking up a key in
		// a nil map reports !ok, so a single lookup covers both levels.
		if err, ok := vc.deriveTokenErrors[a.ID][task]; ok {
			return nil, err
		}
		tokens[task] = uuid.Generate()
	}

	vc.legacyTokens = tokens
	return tokens, nil
}
// SetDeriveTokenError registers err as the result DeriveToken produces for
// each of the named tasks of allocation allocID. These registrations are only
// consulted when DeriveTokenFn is unset. A later call for the same
// allocation/task pair overwrites the earlier registration.
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if vc.deriveTokenErrors == nil {
		vc.deriveTokenErrors = make(map[string]map[string]error, 10)
	}

	perTask, exists := vc.deriveTokenErrors[allocID]
	if !exists {
		perTask = make(map[string]error, 10)
		vc.deriveTokenErrors[allocID] = perTask
	}

	for _, name := range tasks {
		perTask[name] = err
	}
}
// RenewToken simulates starting renewal of token. The interval is ignored.
//
// If an error was registered for token via SetRenewTokenError, that error is
// returned with a nil channel. Otherwise a new unbuffered error channel is
// created, recorded under token (replacing any earlier one), and returned.
func (vc *MockVaultClient) RenewToken(token string, interval int) (<-chan error, error) {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if failure, registered := vc.renewTokenErrors[token]; registered {
		return nil, failure
	}

	if vc.renewTokens == nil {
		vc.renewTokens = make(map[string]chan error, 10)
	}
	errCh := make(chan error)
	vc.renewTokens[token] = errCh

	return errCh, nil
}
// SetRenewTokenError makes every subsequent RenewToken call for token fail
// with err. Registering a nil err still short-circuits RenewToken, which then
// returns a nil channel and a nil error.
func (vc *MockVaultClient) SetRenewTokenError(token string, err error) {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	registry := vc.renewTokenErrors
	if registry == nil {
		registry = make(map[string]error, 10)
		vc.renewTokenErrors = registry
	}
	registry[token] = err
}
// StopRenewToken appends token to the list reported by StoppedTokens. It
// always returns nil.
func (vc *MockVaultClient) StopRenewToken(token string) error {
	vc.mu.Lock()
	vc.stoppedTokens = append(vc.stoppedTokens, token)
	vc.mu.Unlock()
	return nil
}
// Start is a no-op; the mock runs no background work.
func (*MockVaultClient) Start() {}
// Stop is a no-op; the mock runs no background work.
func (*MockVaultClient) Stop() {}
// SetRenewable controls the renewable flag that DeriveTokenWithJWT returns
// alongside each token it generates itself.
func (vc *MockVaultClient) SetRenewable(renewable bool) {
	vc.mu.Lock()
	vc.renewable = renewable
	vc.mu.Unlock()
}
// LegacyTokens reports the task-name-to-token map from the most recent
// DeriveToken call that generated tokens itself (rather than delegating to
// DeriveTokenFn). It returns nil if there has been no such call.
func (vc *MockVaultClient) LegacyTokens() map[string]string {
	vc.mu.Lock()
	latest := vc.legacyTokens
	vc.mu.Unlock()
	return latest
}
// JWTTokens returns the tokens generated using the JWT flow, keyed by the JWT
// from each login request. It returns nil if no JWT token has been derived.
//
// The result is a copy taken under mu. DeriveTokenWithJWT writes into the
// internal map in place, so handing that map out would let callers race with
// concurrent derivations.
func (vc *MockVaultClient) JWTTokens() map[string]string {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if vc.jwtTokens == nil {
		return nil
	}
	snapshot := make(map[string]string, len(vc.jwtTokens))
	for jwt, token := range vc.jwtTokens {
		snapshot[jwt] = token
	}
	return snapshot
}
// StoppedTokens lists, in call order, every token that has been passed to
// StopRenewToken.
func (vc *MockVaultClient) StoppedTokens() []string {
	vc.mu.Lock()
	stopped := vc.stoppedTokens
	vc.mu.Unlock()
	return stopped
}
// RenewTokens returns the tokens passed to RenewToken, each mapped to the
// error channel RenewToken returned for it. It returns nil if RenewToken has
// never created a channel.
//
// The map is a copy taken under mu, because RenewToken inserts into the
// internal map in place. The channels themselves are the same ones RenewToken
// handed out, so sending on one still reaches the renewal's consumer.
func (vc *MockVaultClient) RenewTokens() map[string]chan error {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if vc.renewTokens == nil {
		return nil
	}
	snapshot := make(map[string]chan error, len(vc.renewTokens))
	for token, ch := range vc.renewTokens {
		snapshot[token] = ch
	}
	return snapshot
}
// RenewTokenErrCh looks up the error channel that RenewToken handed out for
// token. It returns nil when RenewToken has not created a channel for it.
func (vc *MockVaultClient) RenewTokenErrCh(token string) chan error {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	ch, found := vc.renewTokens[token]
	if !found {
		return nil
	}
	return ch
}
// RenewTokenErrors returns the errors registered via SetRenewTokenError, keyed
// by token. RenewToken returns the registered error for any matching token.
// It returns nil if no error has been registered.
//
// The result is a copy taken under mu. SetRenewTokenError writes into the
// internal map in place, so handing that map out would let callers race with
// it.
func (vc *MockVaultClient) RenewTokenErrors() map[string]error {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if vc.renewTokenErrors == nil {
		return nil
	}
	snapshot := make(map[string]error, len(vc.renewTokenErrors))
	for token, err := range vc.renewTokenErrors {
		snapshot[token] = err
	}
	return snapshot
}
// DeriveTokenErrors returns the errors registered via SetDeriveTokenError,
// keyed by allocation ID and then by task name. DeriveToken returns these
// errors when it derives tokens for a matching task. It returns nil if nothing
// has been registered.
//
// The result is a deep copy taken under mu. SetDeriveTokenError mutates both
// the outer and the per-allocation maps in place, so sharing either level
// would let callers race with it.
func (vc *MockVaultClient) DeriveTokenErrors() map[string]map[string]error {
	vc.mu.Lock()
	defer vc.mu.Unlock()

	if vc.deriveTokenErrors == nil {
		return nil
	}
	snapshot := make(map[string]map[string]error, len(vc.deriveTokenErrors))
	for allocID, taskErrs := range vc.deriveTokenErrors {
		if taskErrs == nil {
			snapshot[allocID] = nil
			continue
		}
		inner := make(map[string]error, len(taskErrs))
		for task, err := range taskErrs {
			inner[task] = err
		}
		snapshot[allocID] = inner
	}
	return snapshot
}
// SetDeriveTokenWithJWTFn overrides DeriveTokenWithJWT. While f is non-nil,
// every DeriveTokenWithJWT call is delegated to it; passing nil restores the
// default token-generating behavior.
func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, error)) {
	vc.mu.Lock()
	vc.deriveTokenWithJWTFn = f
	vc.mu.Unlock()
}