diff --git a/client/client.go b/client/client.go index 8379fd4bf..45c51fdbe 100644 --- a/client/client.go +++ b/client/client.go @@ -23,9 +23,11 @@ import ( "github.com/hashicorp/nomad/client/fingerprint" "github.com/hashicorp/nomad/client/rpcproxy" "github.com/hashicorp/nomad/client/stats" + "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" + vaultapi "github.com/hashicorp/vault/api" "github.com/mitchellh/hashstructure" ) @@ -147,6 +149,9 @@ type Client struct { shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex + + // client to interact with vault for token and secret renewals + vaultClient vaultclient.VaultClient } // NewClient is used to create a new client from the given configuration @@ -213,6 +218,11 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg return nil, fmt.Errorf("failed to create client Consul syncer: %v", err) } + // Setup the vault client for token and secret renewals + if err := c.setupVaultClient(); err != nil { + return nil, fmt.Errorf("failed to setup vault client: %v", err) + } + // Register and then start heartbeating to the servers. go c.registerAndHeartbeat() @@ -238,6 +248,11 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg // populated by periodically polling Consul, if available. go c.rpcProxy.Run() + // Start renewing tokens and secrets + if c.vaultClient != nil { + c.vaultClient.Start() + } + return c, nil } @@ -319,6 +334,11 @@ func (c *Client) Shutdown() error { return nil } + // Stop renewing tokens and secrets + if c.vaultClient != nil { + c.vaultClient.Stop() + } + // Destroy all the running allocations. if c.config.DevMode { c.allocLock.Lock() @@ -1275,6 +1295,116 @@ func (c *Client) addAlloc(alloc *structs.Allocation) error { return nil } +// setupVaultClient creates an object to periodically renew tokens and secrets +// with vault. +func (c *Client) setupVaultClient() error { + if c.config.VaultConfig == nil { + return fmt.Errorf("nil vault config") + } + + if !c.config.VaultConfig.Enabled { + return nil + } + + var err error + if c.vaultClient, err = + vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil { + return err + } + + if c.vaultClient == nil { + c.logger.Printf("[ERR] client: failed to create vault client") + return fmt.Errorf("failed to create vault client") + } + + return nil +} + +// deriveToken takes in an allocation and a set of tasks and derives vault +// tokens for each of the tasks, unwraps all of them using the supplied vault +// client and returns a map of unwrapped tokens, indexed by the task name. +func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) { + if alloc == nil { + return nil, fmt.Errorf("nil allocation") + } + + if taskNames == nil || len(taskNames) == 0 { + return nil, fmt.Errorf("missing task names") + } + + group := alloc.Job.LookupTaskGroup(alloc.TaskGroup) + if group == nil { + return nil, fmt.Errorf("group name in allocation is not present in job") + } + + verifiedTasks := []string{} + found := false + // Check if the given task names actually exist in the allocation + for _, taskName := range taskNames { + found = false + for _, task := range group.Tasks { + if task.Name == taskName { + found = true + } + } + if !found { + c.logger.Printf("[ERR] task %q not found in the allocation", taskName) + return nil, fmt.Errorf("task %q not found in the allocaition", taskName) + } + verifiedTasks = append(verifiedTasks, taskName) + } + + // DeriveVaultToken of nomad server can take in a set of tasks and + // creates tokens for all the tasks. + req := &structs.DeriveVaultTokenRequest{ + NodeID: c.Node().ID, + SecretID: c.Node().SecretID, + AllocID: alloc.ID, + Tasks: verifiedTasks, + QueryOptions: structs.QueryOptions{ + Region: c.Region(), + AllowStale: true, + }, + } + + // Derive the tokens + var resp structs.DeriveVaultTokenResponse + if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil { + c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err) + return nil, fmt.Errorf("failed to derive vault tokens: %v", err) + } + if resp.Tasks == nil { + c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response") + return nil, fmt.Errorf("failed to derive vault tokens: invalid response") + } + + unwrappedTokens := make(map[string]string) + + // Retrieve the wrapped tokens from the response and unwrap it + for _, taskName := range verifiedTasks { + // Get the wrapped token + wrappedToken, ok := resp.Tasks[taskName] + if !ok { + c.logger.Printf("[ERR] client.vault: wrapped token missing for task %q", taskName) + return nil, fmt.Errorf("wrapped token missing for task %q", taskName) + } + + // Unwrap the vault token + unwrapResp, err := vclient.Logical().Unwrap(wrappedToken) + if err != nil { + return nil, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err) + } + if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" { + return nil, fmt.Errorf("failed to unwrap the token for task %q", taskName) + } + + // Append the unwrapped token to the return value + unwrappedTokens[taskName] = unwrapResp.Auth.ClientToken + } + + return unwrappedTokens, nil +} + // setupConsulSyncer creates Client-mode consul.Syncer which periodically // executes callbacks on a fixed interval. // diff --git a/client/client_test.go b/client/client_test.go index 1bc984bc3..ba98dc054 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -85,6 +85,7 @@ func testServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string) { func testClient(t *testing.T, cb func(c *config.Config)) *Client { conf := config.DefaultConfig() + conf.VaultConfig.Enabled = false conf.DevMode = true if cb != nil { cb(conf) diff --git a/client/config/config.go b/client/config/config.go index 6ed0670fa..65cb822ab 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -149,6 +149,7 @@ func (c *Config) Copy() *Config { // DefaultConfig returns the default configuration func DefaultConfig() *Config { return &Config{ + VaultConfig: config.DefaultVaultConfig(), ConsulConfig: config.DefaultConsulConfig(), LogOutput: os.Stderr, Region: "global", diff --git a/client/task_runner.go b/client/task_runner.go index 91a26bb03..208f2ab25 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -12,7 +12,6 @@ import ( "time" "github.com/armon/go-metrics" - "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/driver" diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go new file mode 100644 index 000000000..a4799dce4 --- /dev/null +++ b/client/vaultclient/vaultclient.go @@ -0,0 +1,763 @@ +package vaultclient + +import ( + "container/heap" + "fmt" + "log" + "math/rand" + "strings" + "sync" + "time" + + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/nomad/structs/config" + vaultapi "github.com/hashicorp/vault/api" +) + +// TokenDeriverFunc takes in an allocation and a set of tasks and derives a +// wrapped token for all the tasks, from the nomad server. All the derived +// wrapped tokens will be unwrapped using the vault API client. +type TokenDeriverFunc func(*structs.Allocation, []string, *vaultapi.Client) (map[string]string, error) + +// The interface which nomad client uses to interact with vault and +// periodically renews the tokens and secrets. +type VaultClient interface { + // Start initiates the renewal loop of tokens and secrets + Start() + + // Stop terminates the renewal loop for tokens and secrets + Stop() + + // DeriveToken contacts the nomad server and fetches wrapped tokens for + // a set of tasks. The wrapped tokens will be unwrapped using vault and + // returned. + DeriveToken(*structs.Allocation, []string) (map[string]string, error) + + // GetConsulACL fetches the Consul ACL token required for the task + GetConsulACL(string, string) (*vaultapi.Secret, error) + + // RenewToken renews a token with the given increment and adds it to + // the min-heap for periodic renewal. + RenewToken(string, int) <-chan error + + // StopRenewToken removes the token from the min-heap, stopping its + // renewal. + StopRenewToken(string) error + + // RenewLease renews a vault secret's lease and adds the lease + // identifier to the min-heap for periodic renewal. + RenewLease(string, int) <-chan error + + // StopRenewLease removes a secret's lease ID from the min-heap, + // stopping its renewal. + StopRenewLease(string) error +} + +// Implementation of VaultClient interface to interact with vault and perform +// token and lease renewals periodically. +type vaultClient struct { + // tokenDeriver is a function pointer passed in by the client to derive + // tokens by making RPC calls to the nomad server. The wrapped tokens + // returned by the nomad server will be unwrapped by this function + // using the vault API client. + tokenDeriver TokenDeriverFunc + + // running indicates if the renewal loop is active or not + running bool + + // connEstablished marks whether the connection to vault was successful + // or not + connEstablished bool + + // tokenData is the data of the passed VaultClient token + token *tokenData + + // client is the API client to interact with vault + client *vaultapi.Client + + // updateCh is the channel to notify heap modifications to the renewal + // loop + updateCh chan struct{} + + // stopCh is the channel to trigger termination of renewal loop + stopCh chan struct{} + + // heap is the min-heap to keep track of both tokens and leases + heap *vaultClientHeap + + // config is the configuration to connect to vault + config *config.VaultConfig + + lock sync.RWMutex + logger *log.Logger +} + +// tokenData holds the relevant information about the Vault token passed to the +// client. +type tokenData struct { + CreationTTL int `mapstructure:"creation_ttl"` + TTL int `mapstructure:"ttl"` + Renewable bool `mapstructure:"renewable"` + Policies []string `mapstructure:"policies"` + Role string `mapstructure:"role"` + Root bool +} + +// vaultClientRenewalRequest is a request object for renewal of both tokens and +// secret's leases. +type vaultClientRenewalRequest struct { + // errCh is the channel into which any renewal error will be sent to + errCh chan error + + // id is an identifier which represents either a token or a lease + id string + + // increment is the duration for which the token or lease should be + // renewed for + increment int + + // isToken indicates whether the 'id' field is a token or not + isToken bool +} + +// Element representing an entry in the renewal heap +type vaultClientHeapEntry struct { + req *vaultClientRenewalRequest + next time.Time + index int +} + +// Wrapper around the actual heap to provide additional symantics on top of +// functions provided by the heap interface. In order to achieve that, an +// additional map is placed beside the actual heap. This map can be used to +// check if an entry is already present in the heap. +type vaultClientHeap struct { + heapMap map[string]*vaultClientHeapEntry + heap vaultDataHeapImp +} + +// Data type of the heap +type vaultDataHeapImp []*vaultClientHeapEntry + +// NewVaultClient returns a new vault client from the given config. +func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver TokenDeriverFunc) (*vaultClient, error) { + if config == nil { + return nil, fmt.Errorf("nil vault config") + } + + if !config.Enabled { + return nil, nil + } + + if config.TaskTokenTTL == "" { + return nil, fmt.Errorf("task_token_ttl not set") + } + + if logger == nil { + return nil, fmt.Errorf("nil logger") + } + + c := &vaultClient{ + config: config, + stopCh: make(chan struct{}), + // Update channel should be a buffered channel + updateCh: make(chan struct{}, 1), + heap: newVaultClientHeap(), + logger: logger, + } + + // Get the Vault API configuration + apiConf, err := config.ApiConfig() + if err != nil { + logger.Printf("[ERR] client.vault: failed to create vault API config: %v", err) + return nil, err + } + + // Create the Vault API client + client, err := vaultapi.NewClient(apiConf) + if err != nil { + logger.Printf("[ERR] client.vault: failed to create Vault client. Not retrying: %v", err) + return nil, err + } + + c.client = client + + return c, nil +} + +// newVaultClientHeap returns a new vault client heap with both the heap and a +// map which is a secondary index for heap elements, both initialized. +func newVaultClientHeap() *vaultClientHeap { + return &vaultClientHeap{ + heapMap: make(map[string]*vaultClientHeapEntry), + heap: make(vaultDataHeapImp, 0), + } +} + +// isTracked returns if a given identifier is already present in the heap and +// hence is being renewed. Lock should be held before calling this method. +func (c *vaultClient) isTracked(id string) bool { + if id == "" { + return false + } + + _, ok := c.heap.heapMap[id] + return ok +} + +// Starts the renewal loop of vault client +func (c *vaultClient) Start() { + if !c.config.Enabled || c.running { + return + } + + c.logger.Printf("[DEBUG] client.vault: establishing connection to vault") + go c.establishConnection() +} + +// ConnectionEstablished indicates whether VaultClient successfully established +// connection to vault or not +func (c *vaultClient) ConnectionEstablished() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.connEstablished +} + +// establishConnection is used to make first contact with Vault. This should be +// called in a go-routine since the connection is retried till the Vault Client +// is stopped or the connection is successfully made at which point the renew +// loop is started. +func (c *vaultClient) establishConnection() { + // Create the retry timer and set initial duration to zero so it fires + // immediately + retryTimer := time.NewTimer(0) + +OUTER: + for { + select { + case <-c.stopCh: + return + case <-retryTimer.C: + // Ensure the API is reachable + if _, err := c.client.Sys().InitStatus(); err != nil { + c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v", + c.config.ConnectionRetryIntv) + retryTimer.Reset(c.config.ConnectionRetryIntv) + continue OUTER + } + + break OUTER + } + } + + c.lock.Lock() + c.connEstablished = true + c.lock.Unlock() + + // Begin the renewal loop + go c.run() + c.logger.Printf("[DEBUG] client.vault: started") +} + +// Stops the renewal loop of vault client +func (c *vaultClient) Stop() { + if !c.config.Enabled || !c.running { + return + } + + c.lock.Lock() + defer c.lock.Unlock() + + c.running = false + close(c.stopCh) +} + +// DeriveToken takes in an allocation and a set of tasks and for each of the +// task, it derives a vault token from nomad server and unwraps it using vault. +// The return value is a map containing all the unwrapped tokens indexed by the +// task name. +func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + if !c.running { + return nil, fmt.Errorf("vault client is not running") + } + + return c.tokenDeriver(alloc, taskNames, c.client) +} + +// GetConsulACL creates a vault API client and reads from vault a consul ACL +// token used by the task. +func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) { + if token == "" { + return nil, fmt.Errorf("missing token") + } + if path == "" { + return nil, fmt.Errorf("missing consul ACL token vault path") + } + + if !c.ConnectionEstablished() { + return nil, fmt.Errorf("connection with vault is not yet established") + } + + c.lock.Lock() + defer c.lock.Unlock() + + // Use the token supplied to interact with vault + c.client.SetToken(token) + + // Reset the token before returning + defer c.client.SetToken("") + + // Read the consul ACL token and return the secret directly + return c.client.Logical().Read(path) +} + +// RenewToken renews the supplied token for a given duration (in seconds) and +// adds it to the min-heap so that it is renewed periodically by the renewal +// loop. Any error returned during renewal will be written to a buffered +// channel and the channel is returned instead of an actual error. This helps +// the caller be notified of a renewal failure asynchronously for appropriate +// actions to be taken. The caller of this function need not have to close the +// error channel. +func (c *vaultClient) RenewToken(token string, increment int) <-chan error { + // Create a buffered error channel + errCh := make(chan error, 1) + + if token == "" { + errCh <- fmt.Errorf("missing token") + close(errCh) + return errCh + } + if increment < 1 { + errCh <- fmt.Errorf("increment cannot be less than 1") + close(errCh) + return errCh + } + + // Create a renewal request and indicate that the identifier in the + // request is a token and not a lease + renewalReq := &vaultClientRenewalRequest{ + errCh: errCh, + id: token, + isToken: true, + increment: increment, + } + + // Perform the renewal of the token and send any error to the dedicated + // error channel. + if err := c.renew(renewalReq); err != nil { + c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err) + } + + return errCh +} + +// RenewLease renews the supplied lease identifier for a supplied duration (in +// seconds) and adds it to the min-heap so that it gets renewed periodically by +// the renewal loop. Any error returned during renewal will be written to a +// buffered channel and the channel is returned instead of an actual error. +// This helps the caller be notified of a renewal failure asynchronously for +// appropriate actions to be taken. The caller of this function need not have +// to close the error channel. +func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { + c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId) + // Create a buffered error channel + errCh := make(chan error, 1) + + if leaseId == "" { + errCh <- fmt.Errorf("missing lease ID") + close(errCh) + return errCh + } + + if increment < 1 { + errCh <- fmt.Errorf("increment cannot be less than 1") + close(errCh) + return errCh + } + + // Create a renewal request using the supplied lease and duration + renewalReq := &vaultClientRenewalRequest{ + errCh: errCh, + id: leaseId, + increment: increment, + } + + // Renew the secret and send any error to the dedicated error channel + if err := c.renew(renewalReq); err != nil { + c.logger.Printf("[ERR] client.vault: renewal of lease failed: %v", err) + } + + return errCh +} + +// renew is a common method to handle renewal of both tokens and secret leases. +// It invokes a token renewal or a secret's lease renewal. If renewal is +// successful, min-heap is updated based on the duration after which it needs +// renewal again. The next renewal time is randomly selected to avoid spikes in +// the number of APIs periodically. +func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { + c.lock.Lock() + defer c.lock.Unlock() + + if !c.running { + return fmt.Errorf("vault client is not running") + } + + if req == nil { + return fmt.Errorf("nil renewal request") + } + if req.id == "" { + return fmt.Errorf("missing id in renewal request") + } + if req.increment < 1 { + return fmt.Errorf("increment cannot be less than 1") + } + + var renewalErr error + leaseDuration := req.increment + if req.isToken { + // Reset the token in the API client before returning + defer c.client.SetToken("") + + // Set the token in the API client to the one that needs + // renewal + c.client.SetToken(req.id) + + // Renew the token + renewResp, err := c.client.Auth().Token().RenewSelf(req.increment) + if err != nil { + renewalErr = fmt.Errorf("failed to renew the vault token: %v", err) + } + if renewResp == nil || renewResp.Auth == nil { + renewalErr = fmt.Errorf("failed to renew the vault token") + } else { + // Don't set this if renewal fails + leaseDuration = renewResp.Auth.LeaseDuration + } + } else { + // Renew the secret + renewResp, err := c.client.Sys().Renew(req.id, req.increment) + if err != nil { + renewalErr = fmt.Errorf("failed to renew vault secret: %v", err) + } + if renewResp == nil { + renewalErr = fmt.Errorf("failed to renew vault secret") + } else { + // Don't set this if renewal fails + leaseDuration = renewResp.LeaseDuration + } + } + + duration := leaseDuration / 2 + switch { + case leaseDuration < 30: + // Don't bother about introducing randomness if the + // leaseDuration is too small. + default: + // Give a breathing space of 20 seconds + min := 10 + max := leaseDuration - min + rand.Seed(time.Now().Unix()) + duration = min + rand.Intn(max-min) + } + + // Determine the next renewal time + next := time.Now().Add(time.Duration(duration) * time.Second) + + fatal := false + if renewalErr != nil && + (strings.Contains(renewalErr.Error(), "lease not found or lease is not renewable") || + strings.Contains(renewalErr.Error(), "token not found")) { + fatal = true + } else if renewalErr != nil { + c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d", req.increment, leaseDuration, duration) + c.logger.Printf("[ERR] client.vault: renewal of lease or token failed due to a non-fatal error. Retrying at %v", next.String()) + } + + if c.isTracked(req.id) { + if fatal { + // If encountered with an error where in a lease or a + // token is not valid at all with vault, and if that + // item is tracked by the renewal loop, stop renewing + // it by removing the corresponding heap entry. + if err := c.heap.Remove(req.id); err != nil { + return fmt.Errorf("failed to remove heap entry. err: %v", err) + } + delete(c.heap.heapMap, req.id) + + // Report the fatal error to the client + req.errCh <- renewalErr + close(req.errCh) + + return renewalErr + } + + // If the identifier is already tracked, this indicates a + // subsequest renewal. In this case, update the existing + // element in the heap with the new renewal time. + if err := c.heap.Update(req, next); err != nil { + return fmt.Errorf("failed to update heap entry. err: %v", err) + } + + // There is no need to signal an update to the renewal loop + // here because this case is hit from the renewal loop itself. + } else { + if fatal { + // If encountered with an error where in a lease or a + // token is not valid at all with vault, and if that + // item is not tracked by renewal loop, don't add it. + + // Report the fatal error to the client + req.errCh <- renewalErr + close(req.errCh) + + return renewalErr + } + + // If the identifier is not already tracked, this is a first + // renewal request. In this case, add an entry into the heap + // with the next renewal time. + if err := c.heap.Push(req, next); err != nil { + return fmt.Errorf("failed to push an entry to heap. err: %v", err) + } + + // Signal an update for the renewal loop to trigger a fresh + // computation for the next best candidate for renewal. + if c.running { + select { + case c.updateCh <- struct{}{}: + default: + } + } + } + + return nil +} + +// run is the renewal loop which performs the periodic renewals of both the +// tokens and the secret leases. +func (c *vaultClient) run() { + if !c.config.Enabled { + return + } + + c.lock.Lock() + c.running = true + c.lock.Unlock() + + var renewalCh <-chan time.Time + for c.config.Enabled && c.running { + // Fetches the candidate for next renewal + renewalReq, renewalTime := c.nextRenewal() + if renewalTime.IsZero() { + // If the heap is empty, don't do anything + renewalCh = nil + } else { + now := time.Now() + if renewalTime.After(now) { + // Compute the duration after which the item + // needs renewal and set the renewalCh to fire + // at that time. + renewalDuration := renewalTime.Sub(time.Now()) + renewalCh = time.After(renewalDuration) + } else { + // If the renewals of multiple items are too + // close to each other and by the time the + // entry is fetched from heap it might be past + // the current time (by a small margin). In + // which case, fire immediately. + renewalCh = time.After(0) + } + } + + select { + case <-renewalCh: + if err := c.renew(renewalReq); err != nil { + c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err) + } + case <-c.updateCh: + continue + case <-c.stopCh: + c.logger.Printf("[DEBUG] client.vault: stopped") + return + } + } +} + +// StopRenewToken removes the item from the heap which represents the given +// token. +func (c *vaultClient) StopRenewToken(token string) error { + return c.stopRenew(token) +} + +// StopRenewLease removes the item from the heap which represents the given +// lease identifier. +func (c *vaultClient) StopRenewLease(leaseId string) error { + return c.stopRenew(leaseId) +} + +// stopRenew removes the given identifier from the heap and signals the renewal +// loop to compute the next best candidate for renewal. +func (c *vaultClient) stopRenew(id string) error { + c.lock.Lock() + defer c.lock.Unlock() + + if !c.isTracked(id) { + return nil + } + + // Remove the identifier from the heap + if err := c.heap.Remove(id); err != nil { + return fmt.Errorf("failed to remove heap entry: %v", err) + } + + // Delete the identifier from the map only after the it is removed from + // the heap. Heap's remove method relies on the heap map. + delete(c.heap.heapMap, id) + + // Signal an update to the renewal loop. + if c.running { + select { + case c.updateCh <- struct{}{}: + default: + } + } + + return nil +} + +// nextRenewal returns the root element of the min-heap, which represents the +// next element to be renewed and the time at which the renewal needs to be +// triggered. +func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) { + c.lock.RLock() + defer c.lock.RUnlock() + + if c.heap.Length() == 0 { + return nil, time.Time{} + } + + // Fetches the root element in the min-heap + nextEntry := c.heap.Peek() + if nextEntry == nil { + return nil, time.Time{} + } + + return nextEntry.req, nextEntry.next +} + +// Additional helper functions on top of interface methods + +// Length returns the number of elements in the heap +func (h *vaultClientHeap) Length() int { + return len(h.heap) +} + +// Returns the root node of the min-heap +func (h *vaultClientHeap) Peek() *vaultClientHeapEntry { + if len(h.heap) == 0 { + return nil + } + + return h.heap[0] +} + +// Push adds the secondary index and inserts an item into the heap +func (h *vaultClientHeap) Push(req *vaultClientRenewalRequest, next time.Time) error { + if req == nil { + return fmt.Errorf("nil request") + } + + if _, ok := h.heapMap[req.id]; ok { + return fmt.Errorf("entry %v already exists", req.id) + } + + heapEntry := &vaultClientHeapEntry{ + req: req, + next: next, + } + h.heapMap[req.id] = heapEntry + heap.Push(&h.heap, heapEntry) + return nil +} + +// Update will modify the existing item in the heap with the new data and the +// time, and fixes the heap. +func (h *vaultClientHeap) Update(req *vaultClientRenewalRequest, next time.Time) error { + if entry, ok := h.heapMap[req.id]; ok { + entry.req = req + entry.next = next + heap.Fix(&h.heap, entry.index) + return nil + } + + return fmt.Errorf("heap doesn't contain %v", req.id) +} + +// Remove will remove an identifier from the secondary index and deletes the +// corresponding node from the heap. +func (h *vaultClientHeap) Remove(id string) error { + if entry, ok := h.heapMap[id]; ok { + heap.Remove(&h.heap, entry.index) + delete(h.heapMap, id) + return nil + } + + return fmt.Errorf("heap doesn't contain entry for %v", id) +} + +// The heap interface requires the following methods to be implemented. +// * Push(x interface{}) // add x as element Len() +// * Pop() interface{} // remove and return element Len() - 1. +// * sort.Interface +// +// sort.Interface comprises of the following methods: +// * Len() int +// * Less(i, j int) bool +// * Swap(i, j int) + +// Part of sort.Interface +func (h vaultDataHeapImp) Len() int { return len(h) } + +// Part of sort.Interface +func (h vaultDataHeapImp) Less(i, j int) bool { + // Two zero times should return false. + // Otherwise, zero is "greater" than any other time. + // (To sort it at the end of the list.) + // Sort such that zero times are at the end of the list. + iZero, jZero := h[i].next.IsZero(), h[j].next.IsZero() + if iZero && jZero { + return false + } else if iZero { + return false + } else if jZero { + return true + } + + return h[i].next.Before(h[j].next) +} + +// Part of sort.Interface +func (h vaultDataHeapImp) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +// Part of heap.Interface +func (h *vaultDataHeapImp) Push(x interface{}) { + n := len(*h) + entry := x.(*vaultClientHeapEntry) + entry.index = n + *h = append(*h, entry) +} + +// Part of heap.Interface +func (h *vaultDataHeapImp) Pop() interface{} { + old := *h + n := len(old) + entry := old[n-1] + entry.index = -1 // for safety + *h = old[0 : n-1] + return entry +} diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go new file mode 100644 index 000000000..3ff1b128b --- /dev/null +++ b/client/vaultclient/vaultclient_test.go @@ -0,0 +1,223 @@ +package vaultclient + +import ( + "log" + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/testutil" + vaultapi "github.com/hashicorp/vault/api" +) + +func TestVaultClient_EstablishConnection(t *testing.T) { + v := testutil.NewTestVault(t) + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.TaskTokenTTL = "10s" + c, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while and check that no connection has been established. + time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond) + + if c.ConnectionEstablished() { + t.Fatalf("ConnectionEstablished() returned true before Vault server started") + } + + // Start Vault + v.Start() + defer v.Stop() + + testutil.WaitForResult(func() (bool, error) { + return c.ConnectionEstablished(), nil + }, func(err error) { + t.Fatalf("Connection not established") + }) +} + +func TestVaultClient_TokenRenewals(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.TaskTokenTTL = "10s" + c, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while to ensure that the renewal loop is active + time.Sleep(3 * time.Second) + + tcr := &vaultapi.TokenCreateRequest{ + Policies: []string{"foo", "bar"}, + TTL: "2s", + DisplayName: "derived-for-task", + Renewable: new(bool), + } + *tcr.Renewable = true + + num := 5 + tokens := make([]string, num) + for i := 0; i < num; i++ { + c.client.SetToken(v.Config.Token) + + if err := c.client.SetAddress(v.Config.Addr); err != nil { + t.Fatal(err) + } + + secret, err := c.client.Auth().Token().Create(tcr) + if err != nil { + t.Fatalf("failed to create vault token: %v", err) + } + + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatal("failed to derive a wrapped vault token") + } + + tokens[i] = secret.Auth.ClientToken + + errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration) + go func(errCh <-chan error) { + var err error + for { + select { + case err = <-errCh: + t.Fatalf("error while renewing the token: %v", err) + } + } + }(errCh) + } + + if c.heap.Length() != num { + t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length()) + } + + time.Sleep(5 * time.Second) + + for i := 0; i < num; i++ { + if err := c.StopRenewToken(tokens[i]); err != nil { + t.Fatal(err) + } + } + + if c.heap.Length() != 0 { + t.Fatal("bad: heap length: expected: 0, actual: %d", c.heap.Length()) + } +} + +func TestVaultClient_Heap(t *testing.T) { + conf := config.DefaultConfig() + conf.VaultConfig.Enabled = true + conf.VaultConfig.Token = "testvaulttoken" + conf.VaultConfig.TaskTokenTTL = "10s" + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + c, err := NewVaultClient(conf.VaultConfig, logger, nil) + if err != nil { + t.Fatal(err) + } + if c == nil { + t.Fatal("failed to create vault client") + } + + now := time.Now() + + renewalReq1 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id1", + increment: 10, + } + if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil { + t.Fatal(err) + } + if !c.isTracked("id1") { + t.Fatalf("id1 should have been tracked") + } + + renewalReq2 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id2", + increment: 10, + } + if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil { + t.Fatal(err) + } + if !c.isTracked("id2") { + t.Fatalf("id2 should have been tracked") + } + + renewalReq3 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id3", + increment: 10, + } + if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil { + t.Fatal(err) + } + if !c.isTracked("id3") { + t.Fatalf("id3 should have been tracked") + } + + // Reading elements should yield id2, id1 and id3 in order + req, _ := c.nextRenewal() + if req != renewalReq2 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req) + } + if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil { + t.Fatal(err) + } + + req, _ = c.nextRenewal() + if req != renewalReq1 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req) + } + if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil { + t.Fatal(err) + } + + req, _ = c.nextRenewal() + if req != renewalReq3 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req) + } + if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id1"); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id2"); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id3"); err != nil { + t.Fatal(err) + } + + if c.isTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + + if c.isTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + + if c.isTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + +}