// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package vaultclient import ( "container/heap" "context" "errors" "fmt" "math/rand" "strings" "sync" "time" hclog "github.com/hashicorp/go-hclog" metrics "github.com/hashicorp/go-metrics/compat" "github.com/hashicorp/nomad/helper/useragent" "github.com/hashicorp/nomad/nomad/structs/config" vaultapi "github.com/hashicorp/vault/api" ) // VaultClientFunc is the interface of a function that retreives the VaultClient // by cluster name. This function is injected into the allocrunner/taskrunner type VaultClientFunc func(string) (VaultClient, error) // JWTLoginRequest is used to derive a Vault ACL token using a JWT login // request. type JWTLoginRequest struct { // JWT is the signed JWT to be used for the login request. JWT string // Role is Vault ACL role to use for the login request. If empty, the // Nomad client's create_from_role value is used, or the Vault cluster // default role. Role string // Namespace is the Vault namespace to use for the login request. If empty, // the Nomad client's Vault configuration namespace will be used. Namespace string } // VaultClient is the interface which nomad client uses to interact with vault and // periodically renews the tokens and secrets. type VaultClient interface { // Start initiates the renewal loop of tokens and secrets Start() // Stop terminates the renewal loop for tokens and secrets Stop() // DeriveTokenWithJWT returns a Vault ACL token using the JWT login // endpoint, along with whether or not the token is renewable and its lease // duration. DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, int, error) // RenewToken renews a token with the given increment and adds it to // the min-heap for periodic renewal. RenewToken(string, int) (<-chan error, error) // StopRenewToken removes the token from the min-heap, stopping its // renewal. StopRenewToken(string) error } // Implementation of VaultClient interface to interact with vault and perform // token and lease renewals periodically. type vaultClient struct { // running indicates if the renewal loop is active or not running bool // client is the API client to interact with vault client *vaultapi.Client // updateCh is the channel to notify heap modifications to the renewal // loop updateCh chan struct{} // stopCh is the channel to trigger termination of renewal loop stopCh chan struct{} // heap is the min-heap to keep track of both tokens and leases heap *vaultClientHeap // config is the configuration to connect to vault config *config.VaultConfig lock sync.RWMutex logger hclog.Logger } // vaultClientRenewalRequest is a request object for renewal of both tokens and // secret's leases. type vaultClientRenewalRequest struct { // errCh is the channel into which any renewal error will be sent to errCh chan error // id is an identifier which represents either a token or a lease id string // increment is the duration for which the token or lease should be // renewed for increment int // isToken indicates whether the 'id' field is a token or not isToken bool } // Element representing an entry in the renewal heap type vaultClientHeapEntry struct { req *vaultClientRenewalRequest next time.Time index int } // Wrapper around the actual heap to provide additional semantics on top of // functions provided by the heap interface. In order to achieve that, an // additional map is placed beside the actual heap. This map can be used to // check if an entry is already present in the heap. type vaultClientHeap struct { heapMap map[string]*vaultClientHeapEntry heap vaultDataHeapImp } // Data type of the heap type vaultDataHeapImp []*vaultClientHeapEntry // NewVaultClient returns a new vault client from the given config. func NewVaultClient(config *config.VaultConfig, logger hclog.Logger) (*vaultClient, error) { if config == nil { return nil, fmt.Errorf("nil vault config") } logger = logger.Named("vault").With("name", config.Name) c := &vaultClient{ config: config, stopCh: make(chan struct{}), updateCh: make(chan struct{}, 1), // Update channel should be buffered. heap: newVaultClientHeap(), logger: logger, } if !config.IsEnabled() { return c, nil } // Get the Vault API configuration apiConf, err := config.ApiConfig() if err != nil { logger.Error("error creating vault API config", "error", err) return nil, err } // Create the Vault API client client, err := vaultapi.NewClient(apiConf) if err != nil { logger.Error("error creating vault client", "error", err) return nil, err } // Set our Nomad user agent useragent.SetHeaders(client) // SetHeaders above will replace all headers, make this call second if config.Namespace != "" { logger.Debug("configuring Vault namespace", "namespace", config.Namespace) client.SetNamespace(config.Namespace) } c.client = client return c, nil } // newVaultClientHeap returns a new vault client heap with both the heap and a // map which is a secondary index for heap elements, both initialized. func newVaultClientHeap() *vaultClientHeap { return &vaultClientHeap{ heapMap: make(map[string]*vaultClientHeapEntry), heap: make(vaultDataHeapImp, 0), } } // isTracked returns if a given identifier is already present in the heap and // hence is being renewed. Lock should be held before calling this method. func (c *vaultClient) isTracked(id string) bool { if id == "" { return false } _, ok := c.heap.heapMap[id] return ok } // isRunning returns true if the client is running. func (c *vaultClient) isRunning() bool { c.lock.RLock() defer c.lock.RUnlock() return c.running } // Start starts the renewal loop of vault client func (c *vaultClient) Start() { c.lock.Lock() defer c.lock.Unlock() if !c.config.IsEnabled() || c.running { return } c.running = true go c.run() } // Stop stops the renewal loop of vault client func (c *vaultClient) Stop() { c.lock.Lock() defer c.lock.Unlock() if !c.config.IsEnabled() || !c.running { return } c.running = false close(c.stopCh) } // unlockAndUnset is used to unset the vault token on the client, restore the // client's default configured namespace, and release the lock. Helper method // for deferring a call that does both. func (c *vaultClient) unlockAndUnset() { c.client.SetToken("") c.client.SetNamespace(c.config.Namespace) c.lock.Unlock() } // DeriveTokenWithJWT returns a Vault ACL token using the JWT login endpoint. func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, int, error) { if !c.config.IsEnabled() { return "", false, 0, fmt.Errorf("vault client not enabled") } if !c.isRunning() { return "", false, 0, fmt.Errorf("vault client is not running") } c.lock.Lock() defer c.unlockAndUnset() // Make sure the login request is not passing any token and that we're using // the expected namespace to login c.client.SetToken("") if req.Namespace != "" { c.client.SetNamespace(req.Namespace) } jwtLoginPath := fmt.Sprintf("auth/%s/login", c.config.JWTAuthBackendPath) s, err := c.client.Logical().WriteWithContext(ctx, jwtLoginPath, map[string]any{ "role": req.Role, "jwt": req.JWT, }, ) if err != nil { return "", false, 0, fmt.Errorf("failed to login with JWT: %v", err) } if s == nil { return "", false, 0, errors.New("JWT login returned an empty secret") } if s.Auth == nil { return "", false, 0, errors.New("JWT login did not return a token") } for _, w := range s.Warnings { c.logger.Warn("JWT login warning", "warning", w) } return s.Auth.ClientToken, s.Auth.Renewable, s.Auth.LeaseDuration, nil } // RenewToken renews the supplied token for a given duration (in seconds) and // adds it to the min-heap so that it is renewed periodically by the renewal // loop. Any error returned during renewal will be written to a buffered // channel and the channel is returned instead of an actual error. This helps // the caller be notified of a renewal failure asynchronously for appropriate // actions to be taken. The caller of this function need not have to close the // error channel. func (c *vaultClient) RenewToken(token string, increment int) (<-chan error, error) { if token == "" { err := fmt.Errorf("missing token") return nil, err } if increment < 1 { err := fmt.Errorf("increment cannot be less than 1") return nil, err } // Create a buffered error channel errCh := make(chan error, 1) // Create a renewal request and indicate that the identifier in the // request is a token and not a lease renewalReq := &vaultClientRenewalRequest{ errCh: errCh, id: token, isToken: true, increment: increment, } // Perform the renewal of the token and send any error to the dedicated // error channel. if err := c.renew(renewalReq); err != nil { c.logger.Error("error during renewal of token", "error", err) metrics.IncrCounter([]string{"client", "vault", "renew_token_failure"}, 1) return nil, err } return errCh, nil } // renew is a common method to handle renewal of both tokens and secret leases. // It invokes a token renewal or a secret's lease renewal. If renewal is // successful, min-heap is updated based on the duration after which it needs // renewal again. The next renewal time is randomly selected to avoid spikes in // the number of APIs periodically. func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { c.lock.Lock() defer c.lock.Unlock() if req == nil { return fmt.Errorf("nil renewal request") } if req.errCh == nil { return fmt.Errorf("renewal request error channel nil") } if !c.config.IsEnabled() { close(req.errCh) return fmt.Errorf("vault client not enabled") } if !c.running { close(req.errCh) return fmt.Errorf("vault client is not running") } if req.id == "" { close(req.errCh) return fmt.Errorf("missing id in renewal request") } if req.increment < 1 { close(req.errCh) return fmt.Errorf("increment cannot be less than 1") } var renewalErr error leaseDuration := req.increment if req.isToken { // Set the token in the API client to the one that needs renewal c.client.SetToken(req.id) // Renew the token renewResp, err := c.client.Auth().Token().RenewSelf(req.increment) if err != nil { renewalErr = fmt.Errorf("failed to renew the vault token: %v", err) } else if renewResp == nil || renewResp.Auth == nil { renewalErr = fmt.Errorf("failed to renew the vault token") } else { // Don't set this if renewal fails leaseDuration = renewResp.Auth.LeaseDuration req.increment = leaseDuration } // Reset the token in the API client before returning c.client.SetToken("") } else { // Renew the secret renewResp, err := c.client.Sys().Renew(req.id, req.increment) if err != nil { renewalErr = fmt.Errorf("failed to renew vault secret: %v", err) } else if renewResp == nil { renewalErr = fmt.Errorf("failed to renew vault secret") } else { // Don't set this if renewal fails leaseDuration = renewResp.LeaseDuration } } // Determine the next renewal time renewalDuration := renewalTime(rand.Intn, leaseDuration) next := time.Now().Add(renewalDuration) fatal := false if renewalErr != nil { // These errors aren't wrapped by the Vault SDK, so we have to read the // error messages. Unfortunately we can't easily enumerate non-fatal // errors so we have a large set here. These can be found at in // vault/expiration.go. // Current as of vault commit 52ba156d47da170bf40471fe57d72522030bdc7e errMsg := renewalErr.Error() if strings.Contains(errMsg, "no namespace") || strings.Contains(errMsg, "cannot renew a token across namespaces") || strings.Contains(errMsg, "invalid lease ID") || strings.Contains(errMsg, "lease expired") || strings.Contains(errMsg, "lease is not renewable") || strings.Contains(errMsg, "lease not found") || strings.Contains(errMsg, "permission denied") || strings.Contains(errMsg, "token not found") { fatal = true } else { c.logger.Debug("renewal error details", "req.increment", req.increment, "lease_duration", leaseDuration, "renewal_duration", renewalDuration) c.logger.Error("error during renewal of lease or token failed due to a non-fatal error; retrying", "error", renewalErr, "period", next) } } if c.isTracked(req.id) { if fatal { // If encountered with an error where in a lease or a // token is not valid at all with vault, and if that // item is tracked by the renewal loop, stop renewing // it by removing the corresponding heap entry. if err := c.heap.Remove(req.id); err != nil { return fmt.Errorf("failed to remove heap entry: %v", err) } // Report the fatal error to the client req.errCh <- renewalErr close(req.errCh) return renewalErr } // If the identifier is already tracked, this indicates a // subsequest renewal. In this case, update the existing // element in the heap with the new renewal time. if err := c.heap.Update(req, next); err != nil { return fmt.Errorf("failed to update heap entry. err: %v", err) } // There is no need to signal an update to the renewal loop // here because this case is hit from the renewal loop itself. } else { if fatal { // If encountered with an error where in a lease or a // token is not valid at all with vault, and if that // item is not tracked by renewal loop, don't add it. // Report the fatal error to the client req.errCh <- renewalErr close(req.errCh) return renewalErr } // If the identifier is not already tracked, this is a first // renewal request. In this case, add an entry into the heap // with the next renewal time. if err := c.heap.Push(req, next); err != nil { return fmt.Errorf("failed to push an entry to heap. err: %v", err) } // Signal an update for the renewal loop to trigger a fresh // computation for the next best candidate for renewal. if c.running { select { case c.updateCh <- struct{}{}: default: } } } return nil } // run is the renewal loop which performs the periodic renewals of both the // tokens and the secret leases. func (c *vaultClient) run() { if !c.config.IsEnabled() { return } var renewalCh <-chan time.Time for c.config.IsEnabled() && c.isRunning() { // Fetches the candidate for next renewal renewalReq, renewalTime := c.nextRenewal() if renewalTime.IsZero() { // If the heap is empty, don't do anything renewalCh = nil } else { now := time.Now() if renewalTime.After(now) { // Compute the duration after which the item // needs renewal and set the renewalCh to fire // at that time. renewalDuration := time.Until(renewalTime) renewalCh = time.After(renewalDuration) } else { // If the renewals of multiple items are too // close to each other and by the time the // entry is fetched from heap it might be past // the current time (by a small margin). In // which case, fire immediately. renewalCh = time.After(0) } } select { case <-renewalCh: if err := c.renew(renewalReq); err != nil { c.logger.Error("error renewing token", "error", err) metrics.IncrCounter([]string{"client", "vault", "renew_token_error"}, 1) } case <-c.updateCh: continue case <-c.stopCh: c.logger.Debug("stopped") return } } } // StopRenewToken removes the item from the heap which represents the given // token. func (c *vaultClient) StopRenewToken(token string) error { return c.stopRenew(token) } // stopRenew removes the given identifier from the heap and signals the renewal // loop to compute the next best candidate for renewal. func (c *vaultClient) stopRenew(id string) error { c.lock.Lock() defer c.lock.Unlock() if !c.isTracked(id) { return nil } if err := c.heap.Remove(id); err != nil { return fmt.Errorf("failed to remove heap entry: %v", err) } // Signal an update to the renewal loop. if c.running { select { case c.updateCh <- struct{}{}: default: } } return nil } // nextRenewal returns the root element of the min-heap, which represents the // next element to be renewed and the time at which the renewal needs to be // triggered. func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) { c.lock.RLock() defer c.lock.RUnlock() if c.heap.Length() == 0 { return nil, time.Time{} } // Fetches the root element in the min-heap nextEntry := c.heap.Peek() if nextEntry == nil { return nil, time.Time{} } return nextEntry.req, nextEntry.next } // Additional helper functions on top of interface methods // Length returns the number of elements in the heap func (h *vaultClientHeap) Length() int { return len(h.heap) } // Returns the root node of the min-heap func (h *vaultClientHeap) Peek() *vaultClientHeapEntry { if len(h.heap) == 0 { return nil } return h.heap[0] } // Push adds the secondary index and inserts an item into the heap func (h *vaultClientHeap) Push(req *vaultClientRenewalRequest, next time.Time) error { if req == nil { return fmt.Errorf("nil request") } if _, ok := h.heapMap[req.id]; ok { return fmt.Errorf("entry %v already exists", req.id) } heapEntry := &vaultClientHeapEntry{ req: req, next: next, } h.heapMap[req.id] = heapEntry heap.Push(&h.heap, heapEntry) return nil } // Update will modify the existing item in the heap with the new data and the // time, and fixes the heap. func (h *vaultClientHeap) Update(req *vaultClientRenewalRequest, next time.Time) error { if entry, ok := h.heapMap[req.id]; ok { entry.req = req entry.next = next heap.Fix(&h.heap, entry.index) return nil } return fmt.Errorf("heap doesn't contain %v", req.id) } // Remove will remove an identifier from the secondary index and deletes the // corresponding node from the heap. func (h *vaultClientHeap) Remove(id string) error { if entry, ok := h.heapMap[id]; ok { heap.Remove(&h.heap, entry.index) delete(h.heapMap, id) return nil } return fmt.Errorf("heap doesn't contain entry for %v", id) } // The heap interface requires the following methods to be implemented. // * Push(x interface{}) // add x as element Len() // * Pop() interface{} // remove and return element Len() - 1. // * sort.Interface // // sort.Interface comprises of the following methods: // * Len() int // * Less(i, j int) bool // * Swap(i, j int) // Part of sort.Interface func (h vaultDataHeapImp) Len() int { return len(h) } // Part of sort.Interface func (h vaultDataHeapImp) Less(i, j int) bool { // Two zero times should return false. // Otherwise, zero is "greater" than any other time. // (To sort it at the end of the list.) // Sort such that zero times are at the end of the list. iZero, jZero := h[i].next.IsZero(), h[j].next.IsZero() if iZero && jZero { return false } else if iZero { return false } else if jZero { return true } return h[i].next.Before(h[j].next) } // Part of sort.Interface func (h vaultDataHeapImp) Swap(i, j int) { h[i], h[j] = h[j], h[i] h[i].index = i h[j].index = j } // Part of heap.Interface func (h *vaultDataHeapImp) Push(x interface{}) { n := len(*h) entry := x.(*vaultClientHeapEntry) entry.index = n *h = append(*h, entry) } // Part of heap.Interface func (h *vaultDataHeapImp) Pop() interface{} { old := *h n := len(old) entry := old[n-1] entry.index = -1 // for safety *h = old[0 : n-1] return entry } // randIntn is the function in math/rand needed by renewalTime. A type is used // to ease deterministic testing. type randIntn func(int) int // renewalTime returns when a token should be renewed given its leaseDuration // and a randomizer to provide jitter. // // Leases < 1m will be not jitter. func renewalTime(dice randIntn, leaseDuration int) time.Duration { // Start trying to renew at half the lease duration to allow ample time // for latency and retries. renew := leaseDuration / 2 // Don't bother about introducing randomness if the // leaseDuration is too small. const cutoff = 30 if renew < cutoff { return time.Duration(renew) * time.Second } // jitter is the amount +/- to vary the renewal time const jitter = 10 min := renew - jitter renew = min + dice(jitter*2) return time.Duration(renew) * time.Second }