From 6b83d070706203966fbc5cebec9c2805f052f078 Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Wed, 17 Aug 2016 23:28:48 -0400 Subject: [PATCH 1/9] VaultClient for Nomad Client --- client/client.go | 35 ++ client/client_test.go | 1 + client/config/config.go | 1 + client/vaultclient/vaultclient.go | 789 +++++++++++++++++++++++++ client/vaultclient/vaultclient_test.go | 221 +++++++ 5 files changed, 1047 insertions(+) create mode 100644 client/vaultclient/vaultclient.go create mode 100644 client/vaultclient/vaultclient_test.go diff --git a/client/client.go b/client/client.go index 8379fd4bf..04897d0eb 100644 --- a/client/client.go +++ b/client/client.go @@ -23,6 +23,7 @@ import ( "github.com/hashicorp/nomad/client/fingerprint" "github.com/hashicorp/nomad/client/rpcproxy" "github.com/hashicorp/nomad/client/stats" + "github.com/hashicorp/nomad/client/vaultclient" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" @@ -147,6 +148,9 @@ type Client struct { shutdown bool shutdownCh chan struct{} shutdownLock sync.Mutex + + // client to interact with vault for token and secret renewals + vaultClient vaultclient.VaultClient } // NewClient is used to create a new client from the given configuration @@ -213,6 +217,11 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg return nil, fmt.Errorf("failed to create client Consul syncer: %v", err) } + // Setup the vault client for token and secret renewals + if err := c.setupVaultClient(); err != nil { + return nil, fmt.Errorf("failed to setup vault client: %v", err) + } + // Register and then start heartbeating to the servers. go c.registerAndHeartbeat() @@ -238,6 +247,9 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg // populated by periodically polling Consul, if available. 
go c.rpcProxy.Run() + // Start renewing tokens and secrets + go c.vaultClient.Start() + return c, nil } @@ -319,6 +331,11 @@ func (c *Client) Shutdown() error { return nil } + // Stop renewing tokens and secrets + if c.vaultClient != nil { + c.vaultClient.Stop() + } + // Destroy all the running allocations. if c.config.DevMode { c.allocLock.Lock() @@ -1275,6 +1292,24 @@ func (c *Client) addAlloc(alloc *structs.Allocation) error { return nil } +// setupVaultClient creates an object to periodically renew tokens and secrets +// with vault. +func (c *Client) setupVaultClient() error { + if c.config.VaultConfig == nil { + return fmt.Errorf("nil vault config") + } + if c.config.VaultConfig.Token == "" { + return fmt.Errorf("vault token not set") + } + + var err error + if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger); err != nil { + return err + } + + return nil +} + // setupConsulSyncer creates Client-mode consul.Syncer which periodically // executes callbacks on a fixed interval. 
// diff --git a/client/client_test.go b/client/client_test.go index 1bc984bc3..ba98dc054 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -85,6 +85,7 @@ func testServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string) { func testClient(t *testing.T, cb func(c *config.Config)) *Client { conf := config.DefaultConfig() + conf.VaultConfig.Enabled = false conf.DevMode = true if cb != nil { cb(conf) diff --git a/client/config/config.go b/client/config/config.go index 6ed0670fa..65cb822ab 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -149,6 +149,7 @@ func (c *Config) Copy() *Config { // DefaultConfig returns the default configuration func DefaultConfig() *Config { return &Config{ + VaultConfig: config.DefaultVaultConfig(), ConsulConfig: config.DefaultConsulConfig(), LogOutput: os.Stderr, Region: "global", diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go new file mode 100644 index 000000000..32600cead --- /dev/null +++ b/client/vaultclient/vaultclient.go @@ -0,0 +1,789 @@ +package vaultclient + +import ( + "container/heap" + "fmt" + "log" + "sync" + "time" + + "github.com/hashicorp/nomad/nomad/structs/config" + vaultapi "github.com/hashicorp/vault/api" + "github.com/mitchellh/mapstructure" +) + +// The interface which nomad client uses to interact with vault and +// periodically renews the tokens and secrets. +type VaultClient interface { + // Starts the renewal loop of tokens and secrets + Start() + + // Stops the renewal loop for tokens and secrets + Stop() + + // Contacts the nomad server and fetches a wrapped token. The wrapped + // token will be unwrapped by contacting vault and returned. + DeriveToken() (string, error) + + // Fetch the Consul ACL token required for the task + GetConsulACL(string, string) (*vaultapi.Secret, error) + + // Renews a token with the given increment and adds it to the min-heap + // for periodic renewal. 
+ RenewToken(string, int) <-chan error + + // Removes the token from the min-heap, stopping its renewal. + StopRenewToken(string) error + + // Renews a vault secret's lease and add the lease identifier to the + // min-heap for periodic renewal. + RenewLease(string, int) <-chan error + + // Removes a secret's lease id from the min-heap, stopping its renewal. + StopRenewLease(string) error +} + +// Implementation of VaultClient interface to interact with vault and perform +// token and lease renewals periodically. +type vaultClient struct { + // running indicates if the renewal loop is active or not + running bool + + // connEstablished marks whether the connection to vault was successful + // or not + connEstablished bool + + // tokenData is the data of the passed VaultClient token + token *tokenData + + // API client to interact with vault + client *vaultapi.Client + + // Channel to notify heap modifications to the renewal loop + updateCh chan struct{} + + // Channel to trigger termination of renewal loop + stopCh chan struct{} + + // Min-Heap to keep track of both tokens and leases + heap *vaultClientHeap + + // Configuration to connect to vault + config *config.VaultConfig + + lock sync.RWMutex + logger *log.Logger +} + +// tokenData holds the relevant information about the Vault token passed to the +// client. +type tokenData struct { + CreationTTL int `mapstructure:"creation_ttl"` + TTL int `mapstructure:"ttl"` + Renewable bool `mapstructure:"renewable"` + Policies []string `mapstructure:"policies"` + Role string `mapstructure:"role"` + Root bool +} + +// Request object for renewals. This can be used for both token renewals and +// secret's lease renewals. 
+type vaultClientRenewalRequest struct {
+	// Channel on which errors from renewing this request are delivered
+	errCh chan error
+
+	// This can either be a token or a lease identifier
+	id string
+
+	// Increment requested from vault when the token or lease is renewed
+	increment int
+
+	// Indicates whether the 'id' field is a token or not
+	isToken bool
+}
+
+// Element representing an entry in the renewal heap. 'next' is the time at
+// which the request is due for renewal and 'index' is the entry's current
+// position in the heap slice, maintained by the heap's Swap/Push/Pop.
+type vaultClientHeapEntry struct {
+	req   *vaultClientRenewalRequest
+	next  time.Time
+	index int
+}
+
+// Wrapper around the actual heap to provide additional semantics on top of
+// functions provided by the heap interface. In order to achieve that, an
+// additional map, keyed by token or lease identifier, is placed beside the
+// actual heap. This map can be used to check if an entry is already present
+// in the heap.
+type vaultClientHeap struct {
+	heapMap map[string]*vaultClientHeapEntry
+	heap    vaultDataHeapImp
+}
+
+// Data type of the heap
+type vaultDataHeapImp []*vaultClientHeapEntry
+
+// NewVaultClient returns a new vault client from the given config.
+//
+// An error is returned when the config or logger is nil, when the token or
+// task_token_ttl is not set, or when the vault API client cannot be created.
+// When the config is valid but vault integration is disabled, it returns
+// (nil, nil).
+//
+// NOTE(review): callers that store the result in a VaultClient interface
+// value will hold a non-nil interface wrapping a nil pointer in the disabled
+// case; confirm that callers check config.Enabled rather than comparing the
+// interface against nil.
+func NewVaultClient(config *config.VaultConfig, logger *log.Logger) (*vaultClient, error) {
+	if config == nil {
+		return nil, fmt.Errorf("nil vault config")
+	}
+
+	// Creation of a vault client requires that the token is supplied via
+	// config.
+	if config.Token == "" {
+		return nil, fmt.Errorf("vault token not set")
+	}
+
+	if config.TaskTokenTTL == "" {
+		return nil, fmt.Errorf("task_token_ttl not set")
+	}
+
+	if logger == nil {
+		return nil, fmt.Errorf("nil logger")
+	}
+
+	// updateCh is buffered by one so that heap modifications can signal the
+	// renewal loop without blocking.
+	c := &vaultClient{
+		config:   config,
+		stopCh:   make(chan struct{}),
+		updateCh: make(chan struct{}, 1),
+		heap:     NewVaultClientHeap(),
+		logger:   logger,
+	}
+
+	// Nothing more to set up when vault integration is disabled
+	if !c.config.Enabled {
+		return nil, nil
+	}
+
+	// Get the Vault API configuration
+	apiConf, err := config.ApiConfig()
+	if err != nil {
+		return nil, fmt.Errorf("failed to create vault API config: %v", err)
+	}
+
+	// Create the Vault API client
+	client, err := vaultapi.NewClient(apiConf)
+	if err != nil {
+		logger.Printf("[ERR] vault: failed to create Vault client. Not retrying: %v", err)
+		return nil, err
+	}
+
+	// Set the token and store the client
+	client.SetToken(c.config.Token)
+
+	c.client = client
+
+	return c, nil
+}
+
+// NewVaultClientHeap returns a new vault client heap with both the heap and a
+// map which is a secondary index for heap elements, both initialized.
+func NewVaultClientHeap() *vaultClientHeap {
+	return &vaultClientHeap{
+		heapMap: make(map[string]*vaultClientHeapEntry),
+		heap:    make(vaultDataHeapImp, 0),
+	}
+}
+
+// IsTracked returns if a given identifier is already present in the heap and
+// hence is being renewed.
+func (c *vaultClient) IsTracked(id string) bool {
+	if id == "" {
+		return false
+	}
+
+	// No locking here: renew and stopRenew call this while already holding
+	// c.lock, which is not reentrant.
+	_, ok := c.heap.heapMap[id]
+	return ok
+}
+
+// Start begins connecting to vault in the background. It is a no-op when
+// vault integration is disabled or the renewal loop is already running.
+func (c *vaultClient) Start() {
+	if !c.config.Enabled {
+		return
+	}
+
+	// 'running' is written by other goroutines under c.lock, so read it
+	// under the lock as well.
+	c.lock.RLock()
+	running := c.running
+	c.lock.RUnlock()
+	if running {
+		return
+	}
+
+	c.logger.Printf("[INFO] vaultclient: establishing connection to vault")
+	go c.establishConnection()
+}
+
+// ConnectionEstablished indicates whether VaultClient successfully established
+// connection to vault or not
+func (c *vaultClient) ConnectionEstablished() bool {
+	c.lock.RLock()
+	defer c.lock.RUnlock()
+	return c.connEstablished
+}
+
+// establishConnection is used to make first contact with Vault. This should be
+// called in a go-routine since the connection is retried until the Vault
+// Client is stopped or the connection is successfully made, at which point the
+// client's own token is validated and the renewal loop is started. If the
+// token is renewable it is also handed to the renewal loop.
+func (c *vaultClient) establishConnection() {
+	// Create the retry timer and set initial duration to zero so it fires
+	// immediately
+	retryTimer := time.NewTimer(0)
+	defer retryTimer.Stop()
+
+OUTER:
+	for {
+		select {
+		case <-c.stopCh:
+			return
+		case <-retryTimer.C:
+			// Ensure the API is reachable
+			if _, err := c.client.Sys().InitStatus(); err != nil {
+				c.logger.Printf("[WARN] vaultclient: failed to contact Vault API. Retrying in %v",
+					c.config.ConnectionRetryIntv)
+				retryTimer.Reset(c.config.ConnectionRetryIntv)
+				continue OUTER
+			}
+
+			break OUTER
+		}
+	}
+
+	c.lock.Lock()
+	c.connEstablished = true
+	c.lock.Unlock()
+
+	// Retrieve our token, validate it and parse the lease duration
+	if err := c.parseSelfToken(); err != nil {
+		c.logger.Printf("[ERR] vaultclient: failed to lookup self token and not retrying: %v", err)
+		return
+	}
+
+	// Mark the client as running before the loop goroutine is scheduled.
+	// renew() rejects requests while 'running' is false, so leaving this to
+	// run() alone would let the RenewToken call below race with the loop's
+	// startup and fail spuriously.
+	c.lock.Lock()
+	select {
+	case <-c.stopCh:
+		// Stopped while connecting; do not start the loop
+		c.lock.Unlock()
+		return
+	default:
+	}
+	c.running = true
+	c.lock.Unlock()
+
+	// Begin the renewal loop
+	go c.run()
+	c.logger.Printf("[INFO] vaultclient: started")
+
+	// Only a renewable token is handed to the renewal loop
+	if !c.token.Renewable {
+		c.logger.Printf("[INFO] vaultclient: not renewing token as it is not renewable")
+		return
+	}
+
+	c.logger.Printf("[INFO] vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second)
+
+	// Add the VaultClient's token to the renewal loop
+	errCh := c.RenewToken(c.config.Token, c.token.CreationTTL)
+
+	// Catch the renewal errors of VaultClient's token until the client is
+	// stopped, so that this goroutine does not outlive the client.
+	go func(errCh <-chan error) {
+		for {
+			select {
+			case err := <-errCh:
+				c.logger.Printf("[ERR] vaultclient: error while renewing the vault client's token: %v", err)
+			case <-c.stopCh:
+				return
+			}
+		}
+	}(errCh)
+}
+
+// parseSelfToken looks up the Vault token in Vault and parses its data storing
+// it in the client. If the token is not valid for Nomad's purposes an error is
+// returned.
+func (c *vaultClient) parseSelfToken() error {
+	// Get the initial lease duration
+	auth := c.client.Auth().Token()
+	self, err := auth.LookupSelf()
+	if err != nil {
+		return fmt.Errorf("failed to lookup VaultClient's token: %v", err)
+	}
+
+	// Read and parse the fields
+	var data tokenData
+	if err := mapstructure.WeakDecode(self.Data, &data); err != nil {
+		return fmt.Errorf("failed to parse Vault token's data block: %v", err)
+	}
+
+	// A token is treated as root when it carries the "root" policy
+	root := false
+	for _, p := range data.Policies {
+		if p == "root" {
+			root = true
+			break
+		}
+	}
+
+	// Every requirement below is waived for a root token
+	if !data.Renewable && !root {
+		return fmt.Errorf("vault token is not renewable or root")
+	}
+
+	if data.CreationTTL == 0 && !root {
+		return fmt.Errorf("invalid lease duration of zero")
+	}
+
+	if data.TTL == 0 && !root {
+		return fmt.Errorf("token TTL is zero")
+	}
+
+	if !root && data.Role == "" {
+		return fmt.Errorf("token role name must be set when not using a root token")
+	}
+
+	data.Root = root
+	c.token = &data
+	return nil
+}
+
+// Stop terminates the connection attempt and the renewal loop of the vault
+// client. It is a no-op when vault integration is disabled and is safe to
+// call more than once.
+func (c *vaultClient) Stop() {
+	if !c.config.Enabled {
+		return
+	}
+
+	c.lock.Lock()
+	defer c.lock.Unlock()
+
+	c.running = false
+
+	// stopCh is also watched by establishConnection, so it has to be closed
+	// even when the renewal loop has not started yet. Closing an already
+	// closed channel panics, hence the check; the lock makes the check and
+	// the close atomic.
+	select {
+	case <-c.stopCh:
+	default:
+		close(c.stopCh)
+	}
+}
+
+// DeriveToken contacts the nomad server and fetches a wrapped token. Then it
+// contacts vault to unwrap the token and returns the unwrapped token.
+func (c *vaultClient) DeriveToken() (string, error) {
+	// TODO: Replace this code with an actual call to the nomad server.
+	// This is a sample code which directly fetches a wrapped token from
+	// vault and unwraps it for time being.
+	tcr := &vaultapi.TokenCreateRequest{
+		Policies:    []string{"foo", "bar"},
+		TTL:         "10s",
+		DisplayName: "derived-token",
+		Renewable:   new(bool),
+	}
+	*tcr.Renewable = true
+
+	// Set the TTL for the wrapped token. The lookup function is installed
+	// on the shared API client and is not removed afterwards.
+	wrapLookupFunc := func(method, path string) string {
+		if method == "POST" && path == "auth/token/create" {
+			return "60s"
+		}
+		return ""
+	}
+	c.client.SetWrappingLookupFunc(wrapLookupFunc)
+
+	// Create a wrapped token
+	secret, err := c.client.Auth().Token().Create(tcr)
+	if err != nil {
+		return "", fmt.Errorf("failed to create vault token: %v", err)
+	}
+	if secret == nil || secret.WrapInfo == nil || secret.WrapInfo.Token == "" ||
+		secret.WrapInfo.WrappedAccessor == "" {
+		return "", fmt.Errorf("failed to derive a wrapped vault token")
+	}
+
+	wrappedToken := secret.WrapInfo.Token
+
+	// Unwrap the vault token
+	unwrapResp, err := c.client.Logical().Unwrap(wrappedToken)
+	if err != nil {
+		return "", fmt.Errorf("failed to unwrap the token: %v", err)
+	}
+	if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" {
+		return "", fmt.Errorf("failed to unwrap the token")
+	}
+
+	// Return the unwrapped token
+	return unwrapResp.Auth.ClientToken, nil
+}
+
+// GetConsulACL reads from vault, using the supplied task token, the consul
+// ACL token used by the task. An error is returned if either argument is
+// empty or if the read fails.
+func (c *vaultClient) GetConsulACL(token, vaultPath string) (*vaultapi.Secret, error) {
+	if token == "" {
+		return nil, fmt.Errorf("missing token")
+	}
+	if vaultPath == "" {
+		return nil, fmt.Errorf("missing vault path")
+	}
+
+	// The API client is shared with the renewal loop, which swaps tokens on
+	// it under c.lock. Hold the same lock while the task's token is
+	// installed and put the VaultClient's own token back before returning,
+	// otherwise later lease renewals would be made with the task's token.
+	c.lock.Lock()
+	defer c.lock.Unlock()
+	defer c.client.SetToken(c.config.Token)
+
+	// Use the token supplied to interact with vault
+	c.client.SetToken(token)
+
+	// Read the consul ACL token and return the secret directly
+	return c.client.Logical().Read(vaultPath)
+}
+
+// RenewToken renews the supplied token and adds it to the min-heap so that it
+// is renewed periodically by the renewal loop. Any error returned during
+// renewal will be written to a buffered channel and the channel is returned
+// instead of an actual error. This helps the caller be notified of a renewal
+// failure asynchronously for appropriate actions to be taken.
+func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
+	// Create a buffered error channel
+	errCh := make(chan error, 1)
+
+	if token == "" {
+		errCh <- fmt.Errorf("missing token")
+		return errCh
+	}
+
+	// Create a renewal request and indicate that the identifier in the
+	// request is a token and not a lease
+	renewalReq := &vaultClientRenewalRequest{
+		errCh:     errCh,
+		id:        token,
+		isToken:   true,
+		increment: increment,
+	}
+
+	// Perform the renewal of the token and send any error to the dedicated
+	// error channel.
+	if err := c.renew(renewalReq); err != nil {
+		errCh <- err
+	}
+
+	return errCh
+}
+
+// RenewLease renews the supplied lease identifier for a supplied duration and
+// adds it to the min-heap so that it gets renewed periodically by the renewal
+// loop. Any error returned during renewal will be written to a buffered
+// channel and the channel is returned instead of an actual error. This helps
+// the caller be notified of a renewal failure asynchronously for appropriate
+// actions to be taken.
+func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error {
+	c.logger.Printf("[INFO] vaultclient: renewing lease %q", leaseId)
+	// Create a buffered error channel
+	errCh := make(chan error, 1)
+
+	if leaseId == "" {
+		errCh <- fmt.Errorf("missing lease ID")
+		return errCh
+	}
+
+	if leaseDuration == 0 {
+		errCh <- fmt.Errorf("missing lease duration")
+		return errCh
+	}
+
+	// Create a renewal request using the supplied lease and duration. The
+	// request must carry the same channel that is returned to the caller:
+	// the renewal loop reports later failures on the request's channel, and
+	// a separate channel here would leave the caller listening on one that
+	// never receives them.
+	renewalReq := &vaultClientRenewalRequest{
+		errCh:     errCh,
+		id:        leaseId,
+		increment: leaseDuration,
+	}
+
+	// Renew the secret and send any error to the dedicated error channel
+	if err := c.renew(renewalReq); err != nil {
+		errCh <- err
+	}
+
+	return errCh
+}
+
+// renew is a common method to handle renewal of both tokens and secret leases.
+// It uses the vault API client to invoke either a token renewal request or
+// a secret renewal request. If renewal is successful, min-heap is updated
+// based on the duration after which it needs its renewal again. The duration
+// is set to half the lease duration present in the renewal response.
+func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
+	// Validate the pointer before anything dereferences it
+	if req == nil {
+		return fmt.Errorf("nil renewal request")
+	}
+
+	c.lock.Lock()
+	defer c.lock.Unlock()
+
+	if !c.running {
+		return fmt.Errorf("vault client is not running")
+	}
+
+	if req.id == "" {
+		return fmt.Errorf("missing id in renewal request")
+	}
+	if req.increment == 0 {
+		return fmt.Errorf("missing increment in renewal request")
+	}
+
+	// For token requests the identifier is the token itself, which is a
+	// secret and must not end up in the logs.
+	if req.isToken {
+		c.logger.Printf("[DEBUG] vaultclient: renewing token")
+	} else {
+		c.logger.Printf("[DEBUG] vaultclient: renewing lease %q", req.id)
+	}
+
+	var duration time.Duration
+	if req.isToken {
+		// Reset the token in the API client to that of VaultClient
+		// before returning
+		defer c.client.SetToken(c.config.Token)
+
+		// Set the token in the API client to the one that needs
+		// renewal
+		c.client.SetToken(req.id)
+
+		// Renew the token
+		renewResp, err := c.client.Auth().Token().RenewSelf(req.increment)
+		if err != nil {
+			return fmt.Errorf("failed to renew the vault token: %v", err)
+		}
+		if renewResp == nil || renewResp.Auth == nil {
+			return fmt.Errorf("failed to renew the vault token")
+		}
+
+		// Set the next renewal time to half the lease duration
+		duration = time.Duration(renewResp.Auth.LeaseDuration) * time.Second / 2
+	} else {
+		// Renew the secret
+		renewResp, err := c.client.Sys().Renew(req.id, req.increment)
+		if err != nil {
+			return fmt.Errorf("failed to renew vault secret: %v", err)
+		}
+		if renewResp == nil {
+			return fmt.Errorf("failed to renew vault secret")
+		}
+
+		// Set the next renewal time to half the lease duration
+		duration = time.Duration(renewResp.LeaseDuration) * time.Second / 2
+	}
+
+	// Determine the next renewal time
+	next := time.Now().Add(duration)
+
+	if c.IsTracked(req.id) {
+		// If the identifier is already tracked, this indicates a
+		// subsequent renewal. In this case, update the existing
+		// element in the heap with the new renewal time.
+
+		// There is no need to signal an update to the renewal loop
+		// here because this case is hit from the renewal loop itself.
+		if err := c.heap.Update(req, next); err != nil {
+			return fmt.Errorf("failed to update heap entry. err: %v", err)
+		}
+		return nil
+	}
+
+	// If the identifier is not already tracked, this is a first
+	// renewal request. In this case, add an entry into the heap
+	// with the next renewal time.
+	if err := c.heap.Push(req, next); err != nil {
+		return fmt.Errorf("failed to push an entry to heap. err: %v", err)
+	}
+
+	// Signal an update for the renewal loop to trigger a fresh
+	// computation for the next best candidate for renewal. The send is
+	// non-blocking: a signal that is already pending is enough.
+	select {
+	case c.updateCh <- struct{}{}:
+	default:
+	}
+
+	return nil
+}
+
+// run is the renewal loop which performs the periodic renewals of both the
+// tokens and the secret leases. It returns when the client is stopped.
+func (c *vaultClient) run() {
+	if !c.config.Enabled {
+		return
+	}
+
+	// Mark the loop as running, unless a stop was requested before the
+	// loop got scheduled; in that case 'running' must stay false.
+	c.lock.Lock()
+	select {
+	case <-c.stopCh:
+		c.lock.Unlock()
+		return
+	default:
+	}
+	c.running = true
+	c.lock.Unlock()
+
+	// 'running' is cleared by Stop under c.lock, so read it the same way
+	isRunning := func() bool {
+		c.lock.RLock()
+		defer c.lock.RUnlock()
+		return c.running
+	}
+
+	for isRunning() {
+		// A nil channel blocks forever in the select below, which is
+		// what is wanted while the heap is empty.
+		var renewalCh <-chan time.Time
+
+		// Fetches the candidate for next renewal
+		renewalReq, renewalTime := c.nextRenewal()
+		if !renewalTime.IsZero() {
+			now := time.Now()
+			if renewalTime.After(now) {
+				// Compute the duration after which the item
+				// needs renewal and set the renewalCh to fire
+				// at that time.
+				renewalCh = time.After(renewalTime.Sub(now))
+			} else {
+				// If the renewals of multiple items are too
+				// close to each other, the entry fetched from
+				// the heap might already be past the current
+				// time (by a small margin). In which case,
+				// fire immediately.
+				renewalCh = time.After(0)
+			}
+		}
+
+		select {
+		case <-renewalCh:
+			if err := c.renew(renewalReq); err != nil {
+				// The request's channel is only buffered by one.
+				// Keep watching stopCh so that an unread error
+				// cannot keep the loop from shutting down.
+				select {
+				case renewalReq.errCh <- err:
+				case <-c.stopCh:
+					c.logger.Printf("[INFO] vaultclient: stopped")
+					return
+				}
+			}
+		case <-c.updateCh:
+			continue
+		case <-c.stopCh:
+			c.logger.Printf("[INFO] vaultclient: stopped")
+			return
+		}
+	}
+}
+
+// StopRenewToken removes the item from the heap which represents the given
+// token.
+func (c *vaultClient) StopRenewToken(token string) error {
+	return c.stopRenew(token)
+}
+
+// StopRenewLease removes the item from the heap which represents the given
+// lease identifier.
+func (c *vaultClient) StopRenewLease(leaseId string) error {
+	return c.stopRenew(leaseId)
+}
+
+// stopRenew removes the given identifier from the heap and signals the renewal
+// loop to compute the next best candidate for renewal. Identifiers that are
+// not being tracked are ignored and nil is returned for them.
+func (c *vaultClient) stopRenew(id string) error {
+	c.lock.Lock()
+	defer c.lock.Unlock()
+
+	if !c.IsTracked(id) {
+		return nil
+	}
+
+	// Remove the identifier from the heap. The heap's Remove method also
+	// drops the identifier from the heap map, so no separate delete is
+	// needed here.
+	if err := c.heap.Remove(id); err != nil {
+		return fmt.Errorf("failed to remove heap entry: %v", err)
+	}
+
+	// Signal an update to the renewal loop. The send is non-blocking: a
+	// signal that is already pending is enough.
+	if c.running {
+		select {
+		case c.updateCh <- struct{}{}:
+		default:
+		}
+	}
+
+	return nil
+}
+
+// nextRenewal returns the root element of the min-heap, which represents the
+// next element to be renewed and the time at which the renewal needs to be
+// triggered.
+func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) { + c.lock.RLock() + defer c.lock.RUnlock() + + if c.heap.Length() == 0 { + return nil, time.Time{} + } + + // Fetches the root element in the min-heap + nextEntry := c.heap.Peek() + if nextEntry == nil { + return nil, time.Time{} + } + + return nextEntry.req, nextEntry.next +} + +// Additional helper functions on top of interface methods + +// Length returns the number of elements in the heap +func (h *vaultClientHeap) Length() int { + return len(h.heap) +} + +// Returns the root node of the min-heap +func (h *vaultClientHeap) Peek() *vaultClientHeapEntry { + if len(h.heap) == 0 { + return nil + } + + return h.heap[0] +} + +// Push adds the secondary index and inserts an item into the heap +func (h *vaultClientHeap) Push(req *vaultClientRenewalRequest, next time.Time) error { + if req == nil { + return fmt.Errorf("nil request") + } + + if _, ok := h.heapMap[req.id]; ok { + return fmt.Errorf("entry %v already exists", req.id) + } + + heapEntry := &vaultClientHeapEntry{ + req: req, + next: next, + } + h.heapMap[req.id] = heapEntry + heap.Push(&h.heap, heapEntry) + return nil +} + +// Update will modify the existing item in the heap with the new data and the +// time, and fixes the heap. +func (h *vaultClientHeap) Update(req *vaultClientRenewalRequest, next time.Time) error { + if entry, ok := h.heapMap[req.id]; ok { + entry.req = req + entry.next = next + heap.Fix(&h.heap, entry.index) + return nil + } + + return fmt.Errorf("heap doesn't contain %v", req.id) +} + +// Remove will remove an identifier from the secondary index and deletes the +// corresponding node from the heap. +func (h *vaultClientHeap) Remove(id string) error { + if entry, ok := h.heapMap[id]; ok { + heap.Remove(&h.heap, entry.index) + delete(h.heapMap, id) + return nil + } + + return fmt.Errorf("heap doesn't contain entry for %v", id) +} + +// The heap interface requires the following methods to be implemented. 
+// * Push(x interface{}) // add x as element Len() +// * Pop() interface{} // remove and return element Len() - 1. +// * sort.Interface +// +// sort.Interface comprises of the following methods: +// * Len() int +// * Less(i, j int) bool +// * Swap(i, j int) + +// Part of sort.Interface +func (h vaultDataHeapImp) Len() int { return len(h) } + +// Part of sort.Interface +func (h vaultDataHeapImp) Less(i, j int) bool { + // Two zero times should return false. + // Otherwise, zero is "greater" than any other time. + // (To sort it at the end of the list.) + // Sort such that zero times are at the end of the list. + iZero, jZero := h[i].next.IsZero(), h[j].next.IsZero() + if iZero && jZero { + return false + } else if iZero { + return false + } else if jZero { + return true + } + + return h[i].next.Before(h[j].next) +} + +// Part of sort.Interface +func (h vaultDataHeapImp) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +// Part of heap.Interface +func (h *vaultDataHeapImp) Push(x interface{}) { + n := len(*h) + entry := x.(*vaultClientHeapEntry) + entry.index = n + *h = append(*h, entry) +} + +// Part of heap.Interface +func (h *vaultDataHeapImp) Pop() interface{} { + old := *h + n := len(old) + entry := old[n-1] + entry.index = -1 // for safety + *h = old[0 : n-1] + return entry +} diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go new file mode 100644 index 000000000..d36c2e711 --- /dev/null +++ b/client/vaultclient/vaultclient_test.go @@ -0,0 +1,221 @@ +package vaultclient + +import ( + "log" + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/testutil" + vaultapi "github.com/hashicorp/vault/api" +) + +func TestVaultClient_EstablishConnection(t *testing.T) { + v := testutil.NewTestVault(t) + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + 
v.Config.TaskTokenTTL = "10s" + + c, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while and check that no connection has been established. + time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond) + + if c.ConnectionEstablished() { + t.Fatalf("ConnectionEstablished() returned true before Vault server started") + } + + // Start Vault + v.Start() + defer v.Stop() + + testutil.WaitForResult(func() (bool, error) { + return c.ConnectionEstablished(), nil + }, func(err error) { + t.Fatalf("Connection not established") + }) +} + +func TestVaultClient_TokenRenewals(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.TaskTokenTTL = "10s" + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + c, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while to ensure that the renewal loop is active + time.Sleep(2 * time.Second) + + tcr := &vaultapi.TokenCreateRequest{ + Policies: []string{"foo", "bar"}, + TTL: "2s", + DisplayName: "derived-for-task", + Renewable: new(bool), + } + *tcr.Renewable = true + + num := 10 + tokens := make([]string, num) + for i := 0; i < num; i++ { + c.client.SetToken(v.Config.Token) + + if err := c.client.SetAddress(v.Config.Addr); err != nil { + t.Fatal(err) + } + + secret, err := c.client.Auth().Token().Create(tcr) + if err != nil { + t.Fatalf("failed to create vault token: %v", err) + } + + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatal("failed to derive a wrapped vault token") + } + + tokens[i] = secret.Auth.ClientToken + + errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration) + go func(errCh <-chan error) { + var err error + for { + select { + 
case err = <-errCh: + t.Fatalf("error while renewing the token: %v", err) + } + } + }(errCh) + } + + if c.heap.Length() != num { + t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length()) + } + + time.Sleep(10 * time.Second) + + for i := 0; i < num; i++ { + if err := c.StopRenewToken(tokens[i]); err != nil { + t.Fatal(err) + } + } + + if c.heap.Length() != 0 { + t.Fatal("bad: heap length: expected: 0, actual: %d", c.heap.Length()) + } +} + +func TestVaultClient_Heap(t *testing.T) { + conf := config.DefaultConfig() + conf.VaultConfig.Token = "testvaulttoken" + conf.VaultConfig.TaskTokenTTL = "10s" + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + c, err := NewVaultClient(conf.VaultConfig, logger) + if err != nil { + t.Fatal(err) + } + + now := time.Now() + + renewalReq1 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id1", + increment: 10, + } + if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil { + t.Fatal(err) + } + if !c.IsTracked("id1") { + t.Fatalf("id1 should have been tracked") + } + + renewalReq2 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id2", + increment: 10, + } + if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil { + t.Fatal(err) + } + if !c.IsTracked("id2") { + t.Fatalf("id2 should have been tracked") + } + + renewalReq3 := &vaultClientRenewalRequest{ + errCh: make(chan error, 1), + id: "id3", + increment: 10, + } + if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil { + t.Fatal(err) + } + if !c.IsTracked("id3") { + t.Fatalf("id3 should have been tracked") + } + + // Reading elements should yield id2, id1 and id3 in order + req, _ := c.nextRenewal() + if req != renewalReq2 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req) + } + if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil { + t.Fatal(err) + } + + req, _ = c.nextRenewal() + if req != renewalReq1 { + t.Fatalf("bad: 
expected: %#v, actual: %#v", renewalReq1, req) + } + if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil { + t.Fatal(err) + } + + req, _ = c.nextRenewal() + if req != renewalReq3 { + t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req) + } + if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id1"); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id2"); err != nil { + t.Fatal(err) + } + + if err := c.StopRenewToken("id3"); err != nil { + t.Fatal(err) + } + + if c.IsTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + + if c.IsTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + + if c.IsTracked("id1") { + t.Fatalf("id1 should not have been tracked") + } + +} From 04fab3bc8189336939b96e918a94fcf915574a8b Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Tue, 23 Aug 2016 17:10:00 -0400 Subject: [PATCH 2/9] Employ DeriveVaultToken API and flesh-up DeriveToken --- client/client.go | 18 +- client/task_runner.go | 1 - client/vaultclient/vaultclient.go | 350 +++++++++++++++++-------- client/vaultclient/vaultclient_test.go | 38 ++- 4 files changed, 292 insertions(+), 115 deletions(-) diff --git a/client/client.go b/client/client.go index 04897d0eb..785b608f3 100644 --- a/client/client.go +++ b/client/client.go @@ -248,7 +248,9 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg go c.rpcProxy.Run() // Start renewing tokens and secrets - go c.vaultClient.Start() + if c.vaultClient != nil { + go c.vaultClient.Start() + } return c, nil } @@ -1298,15 +1300,27 @@ func (c *Client) setupVaultClient() error { if c.config.VaultConfig == nil { return fmt.Errorf("nil vault config") } + + if !c.config.VaultConfig.Enabled { + return nil + } + if c.config.VaultConfig.Token == "" { return fmt.Errorf("vault token not set") } var err error - if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, 
c.logger); err != nil { + if c.vaultClient, err = vaultclient.NewVaultClient(c.Node(), c.Region(), + c.config.VaultConfig, c.logger, c.config.RPCHandler, c.connPool, + c.rpcProxy); err != nil { return err } + if c.vaultClient == nil { + c.logger.Printf("[ERR] client: failed to create vault client") + return fmt.Errorf("failed to create vault client") + } + return nil } diff --git a/client/task_runner.go b/client/task_runner.go index 0544a1eae..d57d6ae95 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -12,7 +12,6 @@ import ( "time" "github.com/armon/go-metrics" - "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/driver" diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index 32600cead..34905dfa5 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -4,9 +4,15 @@ import ( "container/heap" "fmt" "log" + "math/rand" "sync" "time" + "github.com/hashicorp/go-multierror" + clientconfig "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/client/rpcproxy" + "github.com/hashicorp/nomad/nomad" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" vaultapi "github.com/hashicorp/vault/api" "github.com/mitchellh/mapstructure" @@ -21,9 +27,10 @@ type VaultClient interface { // Stops the renewal loop for tokens and secrets Stop() - // Contacts the nomad server and fetches a wrapped token. The wrapped - // token will be unwrapped by contacting vault and returned. - DeriveToken() (string, error) + // Contacts the nomad server and fetches wrapped tokens for a set of + // tasks. The wrapped tokens will be unwrapped using vault and + // returned. 
+ DeriveToken(*structs.Allocation, []string) (map[string]string, error) // Fetch the Consul ACL token required for the task GetConsulACL(string, string) (*vaultapi.Secret, error) @@ -39,13 +46,19 @@ type VaultClient interface { // min-heap for periodic renewal. RenewLease(string, int) <-chan error - // Removes a secret's lease id from the min-heap, stopping its renewal. + // Removes a secret's lease ID from the min-heap, stopping its renewal. StopRenewLease(string) error } // Implementation of VaultClient interface to interact with vault and perform // token and lease renewals periodically. type vaultClient struct { + // Client's region + region string + + // The node in which this vault client is running in + node *structs.Node + // running indicates if the renewal loop is active or not running bool @@ -73,6 +86,10 @@ type vaultClient struct { lock sync.RWMutex logger *log.Logger + + rpcHandler clientconfig.RPCHandler + rpcProxy *rpcproxy.RPCProxy + connPool *nomad.ConnPool } // tokenData holds the relevant information about the Vault token passed to the @@ -122,13 +139,26 @@ type vaultClientHeap struct { type vaultDataHeapImp []*vaultClientHeapEntry // NewVaultClient returns a new vault client from the given config. -func NewVaultClient(config *config.VaultConfig, logger *log.Logger) (*vaultClient, error) { +func NewVaultClient(node *structs.Node, region string, config *config.VaultConfig, + logger *log.Logger, rpcHandler clientconfig.RPCHandler, connPool *nomad.ConnPool, + rpcProxy *rpcproxy.RPCProxy) (*vaultClient, error) { + if !config.Enabled { + return nil, nil + } + + if node == nil { + return nil, fmt.Errorf("nil node") + } + + if region == "" { + return nil, fmt.Errorf("missing region") + } + if config == nil { return nil, fmt.Errorf("nil vault config") } - // Creation of a vault client requires that the token is supplied via - // config. 
+ // Creation of a vault client requires a token if config.Token == "" { return nil, fmt.Errorf("vault token not set") } @@ -141,34 +171,44 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger) (*vaultClien return nil, fmt.Errorf("nil logger") } + if connPool == nil { + return nil, fmt.Errorf("nil connection pool") + } + + if rpcProxy == nil { + return nil, fmt.Errorf("nil rpc proxy") + } + c := &vaultClient{ - config: config, - stopCh: make(chan struct{}), + rpcHandler: rpcHandler, + connPool: connPool, + rpcProxy: rpcProxy, + region: region, + node: node, + config: config, + stopCh: make(chan struct{}), + // Update channel should be a buffered channel updateCh: make(chan struct{}, 1), heap: NewVaultClientHeap(), logger: logger, } - if !c.config.Enabled { - return nil, nil - } - // Get the Vault API configuration apiConf, err := config.ApiConfig() if err != nil { - return nil, fmt.Errorf("failed to create vault API config: %v", err) + logger.Printf("[ERR] client/vaultclient: failed to create vault API config: %v", err) + return nil, err } // Create the Vault API client client, err := vaultapi.NewClient(apiConf) if err != nil { - logger.Printf("[ERR] vault: failed to create Vault client. Not retrying: %v", err) + logger.Printf("[ERR] client/vaultclient: failed to create Vault client. Not retrying: %v", err) return nil, err } // Set the token and store the client client.SetToken(c.config.Token) - c.client = client return c, nil @@ -184,7 +224,7 @@ func NewVaultClientHeap() *vaultClientHeap { } // IsTracked returns if a given identifier is already present in the heap and -// hence is being renewed. +// hence is being renewed. Lock should be held before calling this method. 
func (c *vaultClient) IsTracked(id string) bool { if id == "" { return false @@ -200,7 +240,7 @@ func (c *vaultClient) Start() { return } - c.logger.Printf("[INFO] vaultclient: establishing connection to vault") + c.logger.Printf("[DEBUG] client/vaultclient: establishing connection to vault") go c.establishConnection() } @@ -213,7 +253,7 @@ func (c *vaultClient) ConnectionEstablished() bool { } // establishConnection is used to make first contact with Vault. This should be -// called in a go-routine since the connection is retried til the Vault Client +// called in a go-routine since the connection is retried till the Vault Client // is stopped or the connection is successfully made at which point the renew // loop is started. func (c *vaultClient) establishConnection() { @@ -229,7 +269,7 @@ OUTER: case <-retryTimer.C: // Ensure the API is reachable if _, err := c.client.Sys().InitStatus(); err != nil { - c.logger.Printf("[WARN] vaultclient: failed to contact Vault API. Retrying in %v", + c.logger.Printf("[WARN] client/vaultclient: failed to contact Vault API. Retrying in %v", c.config.ConnectionRetryIntv) retryTimer.Reset(c.config.ConnectionRetryIntv) continue OUTER @@ -245,38 +285,42 @@ OUTER: // Retrieve our token, validate it and parse the lease duration if err := c.parseSelfToken(); err != nil { - c.logger.Printf("[ERR] vaultclient: failed to lookup self token and not retrying: %v", err) + c.logger.Printf("[ERR] client/vaultclient: failed to lookup self token and not retrying: %v", err) return } // Begin the renewal loop go c.run() - c.logger.Printf("[INFO] vaultclient: started") + c.logger.Printf("[DEBUG] client/vaultclient: started") - // If we are given a non-root token, start renewing it - if c.token.Renewable { - c.logger.Printf("[INFO] vaultclient: not renewing token as it is not renewable") + // If we are given a token that needs renewal, place it in the renewal + // loop. 
+ + // Root tokens can also have a TTL + if c.token.Root && c.token.TTL == 0 { + c.logger.Printf("[DEBUG] client/vaultclient: not renewing token as it is a non-expiring root token") } else { - c.logger.Printf("[INFO] vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second) + c.logger.Printf("[DEBUG] client/vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second) - // Add the VaultClient's token to the renewal loop + // Renew the token and place it in renewal min-heap errCh := c.RenewToken(c.config.Token, c.token.CreationTTL) + // Catch the renewal error of VaultClient's token. go func(errCh <-chan error) { var err error for { select { case err = <-errCh: - c.logger.Printf("[ERR] vaultclient: error while renewing the vault client's token: %v", err) + c.logger.Printf("[ERR] client/vaultclient: error while renewing the vault client's token: %v", err) } } }(errCh) } } -// parseSelfToken looks up the Vault token in Vault and parses its data storing -// it in the client. If the token is not valid for Nomads purposes an error is -// returned. +// parseSelfToken looks up the VaultClient's token in vault and parses its data +// storing it in the client. If the token is not valid for Nomads purposes an +// error is returned. func (c *vaultClient) parseSelfToken() error { // Get the initial lease duration auth := c.client.Auth().Token() @@ -333,69 +377,127 @@ func (c *vaultClient) Stop() { close(c.stopCh) } -// DeriveToken contacts the nomad server and fetches a wrapped token. Then it -// contacts vault to unwrap the token and returns the unwrapped token. -func (c *vaultClient) DeriveToken() (string, error) { - // TODO: Replace this code with an actual call to the nomad server. - // This is a sample code which directly fetches a wrapped token from - // vault and unwraps it for time being. 
- tcr := &vaultapi.TokenCreateRequest{ - Policies: []string{"foo", "bar"}, - TTL: "10s", - DisplayName: "derived-token", - Renewable: new(bool), - } - *tcr.Renewable = true +// DeriveToken takes in an allocation and a set of tasks and for each of the +// task, it derives a vault token from nomad server and unwraps it using vault. +// The return value is a map containing all the unwrapped tokens indexed by the +// task name. +func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + var result *multierror.Error - // Set the TTL for the wrapped token - wrapLookupFunc := func(method, path string) string { - if method == "POST" && path == "auth/token/create" { - return "60s" + if !c.running { + result = multierror.Append(fmt.Errorf("vault client is not running")) + return nil, result + } + + if alloc == nil { + result = multierror.Append(fmt.Errorf("nil allocation")) + return nil, result + } + if taskNames == nil || len(taskNames) == 0 { + result = multierror.Append(fmt.Errorf("missing task names")) + return nil, result + } + + found := false + verifiedTasks := []string{} + // Check if the given task names actually exist in the allocation + for _, taskName := range taskNames { + found = false + for _, group := range alloc.Job.TaskGroups { + for _, task := range group.Tasks { + if task.Name == taskName { + found = true + } + } + } + if found { + verifiedTasks = append(verifiedTasks, taskName) + } else { + // Append the error for an invalid task name, but don't + // break out of the loop. Continue to process other + // tasks. 
+ result = multierror.Append(result, fmt.Errorf("task %s not found in the allocation", taskName)) } - return "" - } - c.client.SetWrappingLookupFunc(wrapLookupFunc) - - // Create a wrapped token - secret, err := c.client.Auth().Token().Create(tcr) - if err != nil { - return "", fmt.Errorf("failed to create vault token: %v", err) - } - if secret == nil || secret.WrapInfo == nil || secret.WrapInfo.Token == "" || - secret.WrapInfo.WrappedAccessor == "" { - return "", fmt.Errorf("failed to derive a wrapped vault token") } - wrappedToken := secret.WrapInfo.Token - - // Unwrap the vault token - unwrapResp, err := c.client.Logical().Unwrap(wrappedToken) - if err != nil { - return "", fmt.Errorf("failed to unwrap the token: %v", err) - } - if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" { - return "", fmt.Errorf("failed to unwrap the token") + // DeriveVaultToken of nomad server can take in a set of tasks and + // creates tokens for all the tasks. + req := &structs.DeriveVaultTokenRequest{ + NodeID: c.node.ID, + SecretID: c.node.SecretID, + AllocID: alloc.ID, + Tasks: verifiedTasks, + QueryOptions: structs.QueryOptions{ + Region: c.region, + AllowStale: true, + }, } - // Return the unwrapped token - return unwrapResp.Auth.ClientToken, nil + // Derive the tokens + var resp structs.DeriveVaultTokenResponse + if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil { + c.logger.Printf("[ERR] client/vaultclient: failed to derive vault tokens: %v", err) + result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: %v", err)) + return nil, result + } + if resp.Tasks == nil { + c.logger.Printf("[ERR] client/vaultclient: failed to derive vault token: invalid response") + result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: invalid response")) + return nil, result + } + + unwrappedTokens := make(map[string]string) + + // Retrieve the wrapped tokens from the response and unwrap it using + // the 
VaultClient's token, which is cached at the API client. + for _, taskName := range verifiedTasks { + // Get the wrapped token + wrappedToken, ok := resp.Tasks[taskName] + if !ok { + c.logger.Printf("[ERR] client/vaultclient: wrapped token missing for task %q", taskName) + result = multierror.Append(result, fmt.Errorf("wrapped token missing for task %q", taskName)) + return nil, result + } + + // Unwrap the vault token + unwrapResp, err := c.client.Logical().Unwrap(wrappedToken) + if err != nil { + result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err)) + return nil, result + } + if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" { + result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q", taskName)) + return nil, result + } + + // Append the unwrapped token to the return value + unwrappedTokens[taskName] = unwrapResp.Auth.ClientToken + } + + return unwrappedTokens, nil } // GetConsulACL creates a vault API client and reads from vault a consul ACL // token used by the task. 
-func (c *vaultClient) GetConsulACL(token, vaultPath string) (*vaultapi.Secret, error) { +func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) { if token == "" { return nil, fmt.Errorf("missing token") } - if vaultPath == "" { - return nil, fmt.Errorf("missing vault path") + if path == "" { + return nil, fmt.Errorf("missing consul ACL token vault path") } + c.lock.Lock() + defer c.lock.Unlock() + // Use the token supplied to interact with vault c.client.SetToken(token) + // Restore the token in client to VaultClient's token + defer c.client.SetToken(c.config.Token) + // Read the consul ACL token and return the secret directly - return c.client.Logical().Read(vaultPath) + return c.client.Logical().Read(path) } // RenewToken renews the supplied token and adds it to the min-heap so that it @@ -411,6 +513,10 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error { errCh <- fmt.Errorf("missing token") return errCh } + if increment < 1 { + errCh <- fmt.Errorf("increment cannot be less than 1") + return errCh + } // Create a renewal request and indicate that the identifier in the // request is a token and not a lease @@ -436,8 +542,8 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error { // channel and the channel is returned instead of an actual error. This helps // the caller be notified of a renewal failure asynchronously for appropriate // actions to be taken. 
-func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error { - c.logger.Printf("[INFO] vaultclient: renewing lease %q", leaseId) +func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { + c.logger.Printf("[DEBUG] client/vaultclient: renewing lease %q", leaseId) // Create a buffered error channel errCh := make(chan error, 1) @@ -446,8 +552,8 @@ func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error return errCh } - if leaseDuration == 0 { - errCh <- fmt.Errorf("missing lease duration") + if increment < 1 { + errCh <- fmt.Errorf("increment cannot be less than 1") return errCh } @@ -455,7 +561,7 @@ func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error renewalReq := &vaultClientRenewalRequest{ errCh: make(chan error, 1), id: leaseId, - increment: leaseDuration, + increment: increment, } // Renew the secret and send any error to the dedicated error channel @@ -467,12 +573,11 @@ func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error } // renew is a common method to handle renewal of both tokens and secret leases. -// It creates a vault API client and invokes either a token renewal request or -// a secret renewal request. If renewal is successful, min-heap is updated -// based on the duration after which it needs its renewal again. The duration -// is set to half the lease duration present in the renewal response. +// It invokes a token renewal or a secret's lease renewal. If renewal is +// successful, min-heap is updated based on the duration after which it needs +// renewal again. The next renewal time is randomly selected to avoid spikes in +// the number of APIs periodically. 
func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { - c.logger.Printf("[INFO] vaultclient: ~~~~~~~Renewing %s~~~~~~~~", req.id) c.lock.Lock() defer c.lock.Unlock() @@ -486,11 +591,12 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { if req.id == "" { return fmt.Errorf("missing id in renewal request") } - if req.increment == 0 { - return fmt.Errorf("missing increment in renewal request") + if req.increment < 1 { + return fmt.Errorf("increment cannot be less than 1") } - var duration time.Duration + var renewalErr error + leaseDuration := req.increment if req.isToken { // Reset the token in the API client to that of VaultClient // before returning @@ -503,30 +609,45 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { // Renew the token renewResp, err := c.client.Auth().Token().RenewSelf(req.increment) if err != nil { - return fmt.Errorf("failed to renew the vault token: %v", err) + renewalErr = fmt.Errorf("failed to renew the vault token: %v", err) } if renewResp == nil || renewResp.Auth == nil { - return fmt.Errorf("failed to renew the vault token") + renewalErr = fmt.Errorf("failed to renew the vault token") + } else { + // Don't set this if renewal fails + leaseDuration = renewResp.Auth.LeaseDuration } - - // Set the next renewal time to half the lease duration - duration = time.Duration(renewResp.Auth.LeaseDuration) * time.Second / 2 } else { // Renew the secret renewResp, err := c.client.Sys().Renew(req.id, req.increment) if err != nil { - return fmt.Errorf("failed to renew vault secret: %v", err) + renewalErr = fmt.Errorf("failed to renew vault secret: %v", err) } if renewResp == nil { - return fmt.Errorf("failed to renew vault secret") + renewalErr = fmt.Errorf("failed to renew vault secret") + } else { + // Don't set this if renewal fails + leaseDuration = renewResp.LeaseDuration } - - // Set the next renewal time to half the lease duration - duration = time.Duration(renewResp.LeaseDuration) * 
time.Second / 2 } + duration := leaseDuration / 2 + switch { + case leaseDuration < 30: + // Don't bother about introducing randomness if the + // leaseDuration is too small. + default: + // Give a breathing space of 20 seconds + min := 10 + max := leaseDuration - min + rand.Seed(time.Now().Unix()) + duration = min + rand.Intn(max-min) + } + c.logger.Printf("[DEBUG] client/vaultclient: req.increment: %d, leaseDuration: %d, duration: %d", + req.increment, leaseDuration, duration) + // Determine the next renewal time - next := time.Now().Add(duration) + next := time.Now().Add(time.Duration(duration) * time.Second) if c.IsTracked(req.id) { // If the identifier is already tracked, this indicates a @@ -556,14 +677,15 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { } } - return nil + // Returning the renewal error here ensures that an entry is either + // added or updated in the min-heap. This is done to not starve other + // entries in heap. + return renewalErr } // run is the renewal loop which performs the periodic renewals of both the // tokens and the secret leases. 
func (c *vaultClient) run() { - var renewalCh <-chan time.Time - if !c.config.Enabled { return } @@ -572,6 +694,7 @@ func (c *vaultClient) run() { c.running = true c.lock.Unlock() + var renewalCh <-chan time.Time for c.config.Enabled && c.running { // Fetches the candidate for next renewal renewalReq, renewalTime := c.nextRenewal() @@ -604,7 +727,7 @@ func (c *vaultClient) run() { case <-c.updateCh: continue case <-c.stopCh: - c.logger.Printf("[INFO] vaultclient: stopped") + c.logger.Printf("[DEBUG] client/vaultclient: stopped") return } } @@ -672,6 +795,27 @@ func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) { return nextEntry.req, nextEntry.next } +// RPC is used to forward an RPC call to a nomad server, or fail if no servers +func (c *vaultClient) RPC(method string, args interface{}, reply interface{}) error { + // Invoke the RPCHandler if it exists + if c.rpcHandler != nil { + return c.rpcHandler.RPC(method, args, reply) + } + + // Pick a server to request from + server := c.rpcProxy.FindServer() + if server == nil { + return fmt.Errorf("no known servers") + } + + // Make the RPC request + if err := c.connPool.RPC(c.region, server.Addr, structs.ApiMajorVersion, method, args, reply); err != nil { + c.rpcProxy.NotifyFailedServer(server) + return fmt.Errorf("RPC failed to server %s: %v", server.Addr, err) + } + return nil +} + // Additional helper functions on top of interface methods // Length returns the number of elements in the heap diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index d36c2e711..065bce76e 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -7,6 +7,9 @@ import ( "time" "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/client/rpcproxy" + "github.com/hashicorp/nomad/nomad" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" vaultapi "github.com/hashicorp/vault/api" ) @@ -17,8 
+20,12 @@ func TestVaultClient_EstablishConnection(t *testing.T) { logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) v.Config.ConnectionRetryIntv = 100 * time.Millisecond v.Config.TaskTokenTTL = "10s" - - c, err := NewVaultClient(v.Config, logger) + node := &structs.Node{} + connPool := &nomad.ConnPool{} + rpcProxy := &rpcproxy.RPCProxy{} + var rpcHandler config.RPCHandler + c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler, + connPool, rpcProxy) if err != nil { t.Fatalf("failed to build vault client: %v", err) } @@ -48,11 +55,15 @@ func TestVaultClient_TokenRenewals(t *testing.T) { v := testutil.NewTestVault(t).Start() defer v.Stop() + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) v.Config.ConnectionRetryIntv = 100 * time.Millisecond v.Config.TaskTokenTTL = "10s" - - logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) - c, err := NewVaultClient(v.Config, logger) + node := &structs.Node{} + connPool := &nomad.ConnPool{} + rpcProxy := &rpcproxy.RPCProxy{} + var rpcHandler config.RPCHandler + c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler, + connPool, rpcProxy) if err != nil { t.Fatalf("failed to build vault client: %v", err) } @@ -61,7 +72,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) { defer c.Stop() // Sleep a little while to ensure that the renewal loop is active - time.Sleep(2 * time.Second) + time.Sleep(3 * time.Second) tcr := &vaultapi.TokenCreateRequest{ Policies: []string{"foo", "bar"}, @@ -71,7 +82,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) { } *tcr.Renewable = true - num := 10 + num := 5 tokens := make([]string, num) for i := 0; i < num; i++ { c.client.SetToken(v.Config.Token) @@ -107,7 +118,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) { t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length()) } - time.Sleep(10 * time.Second) + time.Sleep(5 * time.Second) for i := 0; i < num; i++ { if err := 
c.StopRenewToken(tokens[i]); err != nil { @@ -122,14 +133,23 @@ func TestVaultClient_TokenRenewals(t *testing.T) { func TestVaultClient_Heap(t *testing.T) { conf := config.DefaultConfig() + conf.VaultConfig.Enabled = true conf.VaultConfig.Token = "testvaulttoken" conf.VaultConfig.TaskTokenTTL = "10s" logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) - c, err := NewVaultClient(conf.VaultConfig, logger) + node := &structs.Node{} + connPool := &nomad.ConnPool{} + rpcProxy := &rpcproxy.RPCProxy{} + var rpcHandler config.RPCHandler + c, err := NewVaultClient(node, "global", conf.VaultConfig, logger, rpcHandler, + connPool, rpcProxy) if err != nil { t.Fatal(err) } + if c == nil { + t.Fatal("failed to create vault client") + } now := time.Now() From 7f919c9d745061227e4cf9bd83736cefbbbbc212 Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Mon, 29 Aug 2016 12:37:39 -0400 Subject: [PATCH 3/9] Address review feedback --- client/vaultclient/vaultclient.go | 222 ++++++++----------------- client/vaultclient/vaultclient_test.go | 12 +- 2 files changed, 71 insertions(+), 163 deletions(-) diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index 34905dfa5..adc45fded 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -8,14 +8,12 @@ import ( "sync" "time" - "github.com/hashicorp/go-multierror" clientconfig "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/client/rpcproxy" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" vaultapi "github.com/hashicorp/vault/api" - "github.com/mitchellh/mapstructure" ) // The interface which nomad client uses to interact with vault and @@ -158,11 +156,6 @@ func NewVaultClient(node *structs.Node, region string, config *config.VaultConfi return nil, fmt.Errorf("nil vault config") } - // Creation of a vault client requires a token - if config.Token == "" { - return nil, 
fmt.Errorf("vault token not set") - } - if config.TaskTokenTTL == "" { return nil, fmt.Errorf("task_token_ttl not set") } @@ -189,43 +182,41 @@ func NewVaultClient(node *structs.Node, region string, config *config.VaultConfi stopCh: make(chan struct{}), // Update channel should be a buffered channel updateCh: make(chan struct{}, 1), - heap: NewVaultClientHeap(), + heap: newVaultClientHeap(), logger: logger, } // Get the Vault API configuration apiConf, err := config.ApiConfig() if err != nil { - logger.Printf("[ERR] client/vaultclient: failed to create vault API config: %v", err) + logger.Printf("[ERR] client.vault: failed to create vault API config: %v", err) return nil, err } // Create the Vault API client client, err := vaultapi.NewClient(apiConf) if err != nil { - logger.Printf("[ERR] client/vaultclient: failed to create Vault client. Not retrying: %v", err) + logger.Printf("[ERR] client.vault: failed to create Vault client. Not retrying: %v", err) return nil, err } - // Set the token and store the client - client.SetToken(c.config.Token) c.client = client return c, nil } -// NewVaultClientHeap returns a new vault client heap with both the heap and a +// newVaultClientHeap returns a new vault client heap with both the heap and a // map which is a secondary index for heap elements, both initialized. -func NewVaultClientHeap() *vaultClientHeap { +func newVaultClientHeap() *vaultClientHeap { return &vaultClientHeap{ heapMap: make(map[string]*vaultClientHeapEntry), heap: make(vaultDataHeapImp, 0), } } -// IsTracked returns if a given identifier is already present in the heap and +// isTracked returns if a given identifier is already present in the heap and // hence is being renewed. Lock should be held before calling this method. 
-func (c *vaultClient) IsTracked(id string) bool { +func (c *vaultClient) isTracked(id string) bool { if id == "" { return false } @@ -240,7 +231,7 @@ func (c *vaultClient) Start() { return } - c.logger.Printf("[DEBUG] client/vaultclient: establishing connection to vault") + c.logger.Printf("[DEBUG] client.vault: establishing connection to vault") go c.establishConnection() } @@ -269,7 +260,7 @@ OUTER: case <-retryTimer.C: // Ensure the API is reachable if _, err := c.client.Sys().InitStatus(); err != nil { - c.logger.Printf("[WARN] client/vaultclient: failed to contact Vault API. Retrying in %v", + c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v", c.config.ConnectionRetryIntv) retryTimer.Reset(c.config.ConnectionRetryIntv) continue OUTER @@ -283,85 +274,9 @@ OUTER: c.connEstablished = true c.lock.Unlock() - // Retrieve our token, validate it and parse the lease duration - if err := c.parseSelfToken(); err != nil { - c.logger.Printf("[ERR] client/vaultclient: failed to lookup self token and not retrying: %v", err) - return - } - // Begin the renewal loop go c.run() - c.logger.Printf("[DEBUG] client/vaultclient: started") - - // If we are given a token that needs renewal, place it in the renewal - // loop. - - // Root tokens can also have a TTL - if c.token.Root && c.token.TTL == 0 { - c.logger.Printf("[DEBUG] client/vaultclient: not renewing token as it is a non-expiring root token") - } else { - c.logger.Printf("[DEBUG] client/vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second) - - // Renew the token and place it in renewal min-heap - errCh := c.RenewToken(c.config.Token, c.token.CreationTTL) - - // Catch the renewal error of VaultClient's token. 
- go func(errCh <-chan error) { - var err error - for { - select { - case err = <-errCh: - c.logger.Printf("[ERR] client/vaultclient: error while renewing the vault client's token: %v", err) - } - } - }(errCh) - } -} - -// parseSelfToken looks up the VaultClient's token in vault and parses its data -// storing it in the client. If the token is not valid for Nomads purposes an -// error is returned. -func (c *vaultClient) parseSelfToken() error { - // Get the initial lease duration - auth := c.client.Auth().Token() - self, err := auth.LookupSelf() - if err != nil { - return fmt.Errorf("failed to lookup VaultClient's token: %v", err) - } - - // Read and parse the fields - var data tokenData - if err := mapstructure.WeakDecode(self.Data, &data); err != nil { - return fmt.Errorf("failed to parse Vault token's data block: %v", err) - } - - root := false - for _, p := range data.Policies { - if p == "root" { - root = true - break - } - } - - if !data.Renewable && !root { - return fmt.Errorf("vault token is not renewable or root") - } - - if data.CreationTTL == 0 && !root { - return fmt.Errorf("invalid lease duration of zero") - } - - if data.TTL == 0 && !root { - return fmt.Errorf("token TTL is zero") - } - - if !root && data.Role == "" { - return fmt.Errorf("token role name must be set when not using a root token") - } - - data.Root = root - c.token = &data - return nil + c.logger.Printf("[DEBUG] client.vault: started") } // Stops the renewal loop of vault client @@ -382,42 +297,39 @@ func (c *vaultClient) Stop() { // The return value is a map containing all the unwrapped tokens indexed by the // task name. 
func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { - var result *multierror.Error - if !c.running { - result = multierror.Append(fmt.Errorf("vault client is not running")) - return nil, result + return nil, fmt.Errorf("vault client is not running") } if alloc == nil { - result = multierror.Append(fmt.Errorf("nil allocation")) - return nil, result - } - if taskNames == nil || len(taskNames) == 0 { - result = multierror.Append(fmt.Errorf("missing task names")) - return nil, result + return nil, fmt.Errorf("nil allocation") + } + + if taskNames == nil || len(taskNames) == 0 { + return nil, fmt.Errorf("missing task names") } - found := false verifiedTasks := []string{} - // Check if the given task names actually exist in the allocation - for _, taskName := range taskNames { - found = false - for _, group := range alloc.Job.TaskGroups { - for _, task := range group.Tasks { - if task.Name == taskName { - found = true + found := false + // Check if the given task names actually exist in the allocation under + // the correct group name + for _, group := range alloc.Job.TaskGroups { + // Refer only to the group belonging to the allocation + if group.Name == alloc.TaskGroup { + for _, taskName := range taskNames { + found = false + for _, task := range group.Tasks { + if task.Name == taskName { + found = true + } } + if !found { + c.logger.Printf("[ERR] task %q not found in the allocation", taskName) + return nil, fmt.Errorf("task %q not found in the allocaition", taskName) + } + verifiedTasks = append(verifiedTasks, taskName) } } - if found { - verifiedTasks = append(verifiedTasks, taskName) - } else { - // Append the error for an invalid task name, but don't - // break out of the loop. Continue to process other - // tasks. 
- result = multierror.Append(result, fmt.Errorf("task %s not found in the allocation", taskName)) - } } // DeriveVaultToken of nomad server can take in a set of tasks and @@ -436,38 +348,32 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) // Derive the tokens var resp structs.DeriveVaultTokenResponse if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil { - c.logger.Printf("[ERR] client/vaultclient: failed to derive vault tokens: %v", err) - result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: %v", err)) - return nil, result + c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err) + return nil, fmt.Errorf("failed to derive vault tokens: %v", err) } if resp.Tasks == nil { - c.logger.Printf("[ERR] client/vaultclient: failed to derive vault token: invalid response") - result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: invalid response")) - return nil, result + c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response") + return nil, fmt.Errorf("failed to derive vault tokens: invalid response") } unwrappedTokens := make(map[string]string) - // Retrieve the wrapped tokens from the response and unwrap it using - // the VaultClient's token, which is cached at the API client. 
+ // Retrieve the wrapped tokens from the response and unwrap it for _, taskName := range verifiedTasks { // Get the wrapped token wrappedToken, ok := resp.Tasks[taskName] if !ok { - c.logger.Printf("[ERR] client/vaultclient: wrapped token missing for task %q", taskName) - result = multierror.Append(result, fmt.Errorf("wrapped token missing for task %q", taskName)) - return nil, result + c.logger.Printf("[ERR] client.vault: wrapped token missing for task %q", taskName) + return nil, fmt.Errorf("wrapped token missing for task %q", taskName) } // Unwrap the vault token unwrapResp, err := c.client.Logical().Unwrap(wrappedToken) if err != nil { - result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err)) - return nil, result + return nil, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err) } if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" { - result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q", taskName)) - return nil, result + return nil, fmt.Errorf("failed to unwrap the token for task %q", taskName) } // Append the unwrapped token to the return value @@ -493,18 +399,19 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) // Use the token supplied to interact with vault c.client.SetToken(token) - // Restore the token in client to VaultClient's token - defer c.client.SetToken(c.config.Token) + // Reset the token before returning + defer c.client.SetToken("") // Read the consul ACL token and return the secret directly return c.client.Logical().Read(path) } -// RenewToken renews the supplied token and adds it to the min-heap so that it -// is renewed periodically by the renewal loop. Any error returned during -// renewal will be written to a buffered channel and the channel is returned -// instead of an actual error. 
This helps the caller be notified of a renewal -// failure asynchronously for appropriate actions to be taken. +// RenewToken renews the supplied token for a given duration (in seconds) and +// adds it to the min-heap so that it is renewed periodically by the renewal +// loop. Any error returned during renewal will be written to a buffered +// channel and the channel is returned instead of an actual error. This helps +// the caller be notified of a renewal failure asynchronously for appropriate +// actions to be taken. func (c *vaultClient) RenewToken(token string, increment int) <-chan error { // Create a buffered error channel errCh := make(chan error, 1) @@ -531,19 +438,20 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error { // error channel. if err := c.renew(renewalReq); err != nil { errCh <- err + close(errCh) } return errCh } -// RenewLease renews the supplied lease identifier for a supplied duration and -// adds it to the min-heap so that it gets renewed periodically by the renewal -// loop. Any error returned during renewal will be written to a buffered -// channel and the channel is returned instead of an actual error. This helps -// the caller be notified of a renewal failure asynchronously for appropriate -// actions to be taken. +// RenewLease renews the supplied lease identifier for a supplied duration (in +// seconds) and adds it to the min-heap so that it gets renewed periodically by +// the renewal loop. Any error returned during renewal will be written to a +// buffered channel and the channel is returned instead of an actual error. +// This helps the caller be notified of a renewal failure asynchronously for +// appropriate actions to be taken. 
func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { - c.logger.Printf("[DEBUG] client/vaultclient: renewing lease %q", leaseId) + c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId) // Create a buffered error channel errCh := make(chan error, 1) @@ -567,6 +475,7 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { // Renew the secret and send any error to the dedicated error channel if err := c.renew(renewalReq); err != nil { errCh <- err + close(errCh) } return errCh @@ -598,9 +507,8 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { var renewalErr error leaseDuration := req.increment if req.isToken { - // Reset the token in the API client to that of VaultClient - // before returning - defer c.client.SetToken(c.config.Token) + // Reset the token in the API client before returning + defer c.client.SetToken("") // Set the token in the API client to the one that needs // renewal @@ -643,13 +551,13 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { rand.Seed(time.Now().Unix()) duration = min + rand.Intn(max-min) } - c.logger.Printf("[DEBUG] client/vaultclient: req.increment: %d, leaseDuration: %d, duration: %d", + c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d", req.increment, leaseDuration, duration) // Determine the next renewal time next := time.Now().Add(time.Duration(duration) * time.Second) - if c.IsTracked(req.id) { + if c.isTracked(req.id) { // If the identifier is already tracked, this indicates a // subsequest renewal. In this case, update the existing // element in the heap with the new renewal time. 
@@ -727,7 +635,7 @@ func (c *vaultClient) run() { case <-c.updateCh: continue case <-c.stopCh: - c.logger.Printf("[DEBUG] client/vaultclient: stopped") + c.logger.Printf("[DEBUG] client.vault: stopped") return } } @@ -751,7 +659,7 @@ func (c *vaultClient) stopRenew(id string) error { c.lock.Lock() defer c.lock.Unlock() - if !c.IsTracked(id) { + if !c.isTracked(id) { return nil } diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index 065bce76e..38fdc613c 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -161,7 +161,7 @@ func TestVaultClient_Heap(t *testing.T) { if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil { t.Fatal(err) } - if !c.IsTracked("id1") { + if !c.isTracked("id1") { t.Fatalf("id1 should have been tracked") } @@ -173,7 +173,7 @@ func TestVaultClient_Heap(t *testing.T) { if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil { t.Fatal(err) } - if !c.IsTracked("id2") { + if !c.isTracked("id2") { t.Fatalf("id2 should have been tracked") } @@ -185,7 +185,7 @@ func TestVaultClient_Heap(t *testing.T) { if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil { t.Fatal(err) } - if !c.IsTracked("id3") { + if !c.isTracked("id3") { t.Fatalf("id3 should have been tracked") } @@ -226,15 +226,15 @@ func TestVaultClient_Heap(t *testing.T) { t.Fatal(err) } - if c.IsTracked("id1") { + if c.isTracked("id1") { t.Fatalf("id1 should not have been tracked") } - if c.IsTracked("id1") { + if c.isTracked("id1") { t.Fatalf("id1 should not have been tracked") } - if c.IsTracked("id1") { + if c.isTracked("id1") { t.Fatalf("id1 should not have been tracked") } From 603d7b09d8ed004850b04908747f335816daa29f Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Mon, 29 Aug 2016 16:34:39 -0400 Subject: [PATCH 4/9] Use Job.LookupTaskGroup --- client/vaultclient/vaultclient.go | 33 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 
17 deletions(-) diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index adc45fded..d1e2f5f38 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -309,27 +309,26 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) return nil, fmt.Errorf("missing task names") } + group := alloc.Job.LookupTaskGroup(alloc.TaskGroup) + if group == nil { + return nil, fmt.Errorf("group name in allocation is not present in job") + } + verifiedTasks := []string{} found := false - // Check if the given task names actually exist in the allocation under - // the correct group name - for _, group := range alloc.Job.TaskGroups { - // Refer only to the group belonging to the allocation - if group.Name == alloc.TaskGroup { - for _, taskName := range taskNames { - found = false - for _, task := range group.Tasks { - if task.Name == taskName { - found = true - } - } - if !found { - c.logger.Printf("[ERR] task %q not found in the allocation", taskName) - return nil, fmt.Errorf("task %q not found in the allocaition", taskName) - } - verifiedTasks = append(verifiedTasks, taskName) + // Check if the given task names actually exist in the allocation + for _, taskName := range taskNames { + found = false + for _, task := range group.Tasks { + if task.Name == taskName { + found = true } } + if !found { + c.logger.Printf("[ERR] task %q not found in the allocation", taskName) + return nil, fmt.Errorf("task %q not found in the allocaition", taskName) + } + verifiedTasks = append(verifiedTasks, taskName) } // DeriveVaultToken of nomad server can take in a set of tasks and From 72d2e9d2dd002a95707b82636fec061b1d4e6093 Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Mon, 29 Aug 2016 17:07:23 -0400 Subject: [PATCH 5/9] tokenDeriver function pointer to derive tokens. Remove rpc*, connPool, node and region from vaultclient. 
--- client/client.go | 87 ++++++++++++++++++- client/vaultclient/vaultclient.go | 138 ++---------------------------- 2 files changed, 91 insertions(+), 134 deletions(-) diff --git a/client/client.go b/client/client.go index 785b608f3..da4dd6469 100644 --- a/client/client.go +++ b/client/client.go @@ -27,6 +27,7 @@ import ( "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" + vaultapi "github.com/hashicorp/vault/api" "github.com/mitchellh/hashstructure" ) @@ -1310,9 +1311,7 @@ func (c *Client) setupVaultClient() error { } var err error - if c.vaultClient, err = vaultclient.NewVaultClient(c.Node(), c.Region(), - c.config.VaultConfig, c.logger, c.config.RPCHandler, c.connPool, - c.rpcProxy); err != nil { + if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.tokenDeriver); err != nil { return err } @@ -1324,6 +1323,88 @@ func (c *Client) setupVaultClient() error { return nil } +func (c *Client) tokenDeriver(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) { + if alloc == nil { + return nil, fmt.Errorf("nil allocation") + } + + if taskNames == nil || len(taskNames) == 0 { + return nil, fmt.Errorf("missing task names") + } + + group := alloc.Job.LookupTaskGroup(alloc.TaskGroup) + if group == nil { + return nil, fmt.Errorf("group name in allocation is not present in job") + } + + verifiedTasks := []string{} + found := false + // Check if the given task names actually exist in the allocation + for _, taskName := range taskNames { + found = false + for _, task := range group.Tasks { + if task.Name == taskName { + found = true + } + } + if !found { + c.logger.Printf("[ERR] task %q not found in the allocation", taskName) + return nil, fmt.Errorf("task %q not found in the allocaition", taskName) + } + verifiedTasks = append(verifiedTasks, taskName) + } + + // DeriveVaultToken of nomad server can take in a set of tasks 
and + // creates tokens for all the tasks. + req := &structs.DeriveVaultTokenRequest{ + NodeID: c.Node().ID, + SecretID: c.Node().SecretID, + AllocID: alloc.ID, + Tasks: verifiedTasks, + QueryOptions: structs.QueryOptions{ + Region: c.Region(), + AllowStale: true, + }, + } + + // Derive the tokens + var resp structs.DeriveVaultTokenResponse + if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil { + c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err) + return nil, fmt.Errorf("failed to derive vault tokens: %v", err) + } + if resp.Tasks == nil { + c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response") + return nil, fmt.Errorf("failed to derive vault tokens: invalid response") + } + + unwrappedTokens := make(map[string]string) + + // Retrieve the wrapped tokens from the response and unwrap it + for _, taskName := range verifiedTasks { + // Get the wrapped token + wrappedToken, ok := resp.Tasks[taskName] + if !ok { + c.logger.Printf("[ERR] client.vault: wrapped token missing for task %q", taskName) + return nil, fmt.Errorf("wrapped token missing for task %q", taskName) + } + + // Unwrap the vault token + unwrapResp, err := vclient.Logical().Unwrap(wrappedToken) + if err != nil { + return nil, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err) + } + if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" { + return nil, fmt.Errorf("failed to unwrap the token for task %q", taskName) + } + + // Append the unwrapped token to the return value + unwrappedTokens[taskName] = unwrapResp.Auth.ClientToken + } + + return unwrappedTokens, nil +} + // setupConsulSyncer creates Client-mode consul.Syncer which periodically // executes callbacks on a fixed interval. 
// diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index d1e2f5f38..f271e26fb 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -16,6 +16,8 @@ import ( vaultapi "github.com/hashicorp/vault/api" ) +type TokenDeriverFunc func(*structs.Allocation, []string, *vaultapi.Client) (map[string]string, error) + // The interface which nomad client uses to interact with vault and // periodically renews the tokens and secrets. type VaultClient interface { @@ -51,11 +53,7 @@ type VaultClient interface { // Implementation of VaultClient interface to interact with vault and perform // token and lease renewals periodically. type vaultClient struct { - // Client's region - region string - - // The node in which this vault client is running in - node *structs.Node + tokenDeriver TokenDeriverFunc // running indicates if the renewal loop is active or not running bool @@ -137,21 +135,11 @@ type vaultClientHeap struct { type vaultDataHeapImp []*vaultClientHeapEntry // NewVaultClient returns a new vault client from the given config. 
-func NewVaultClient(node *structs.Node, region string, config *config.VaultConfig, - logger *log.Logger, rpcHandler clientconfig.RPCHandler, connPool *nomad.ConnPool, - rpcProxy *rpcproxy.RPCProxy) (*vaultClient, error) { +func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver TokenDeriverFunc) (*vaultClient, error) { if !config.Enabled { return nil, nil } - if node == nil { - return nil, fmt.Errorf("nil node") - } - - if region == "" { - return nil, fmt.Errorf("missing region") - } - if config == nil { return nil, fmt.Errorf("nil vault config") } @@ -164,22 +152,9 @@ func NewVaultClient(node *structs.Node, region string, config *config.VaultConfi return nil, fmt.Errorf("nil logger") } - if connPool == nil { - return nil, fmt.Errorf("nil connection pool") - } - - if rpcProxy == nil { - return nil, fmt.Errorf("nil rpc proxy") - } - c := &vaultClient{ - rpcHandler: rpcHandler, - connPool: connPool, - rpcProxy: rpcProxy, - region: region, - node: node, - config: config, - stopCh: make(chan struct{}), + config: config, + stopCh: make(chan struct{}), // Update channel should be a buffered channel updateCh: make(chan struct{}, 1), heap: newVaultClientHeap(), @@ -301,85 +276,7 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) return nil, fmt.Errorf("vault client is not running") } - if alloc == nil { - return nil, fmt.Errorf("nil allocation") - } - - if taskNames == nil || len(taskNames) == 0 { - return nil, fmt.Errorf("missing task names") - } - - group := alloc.Job.LookupTaskGroup(alloc.TaskGroup) - if group == nil { - return nil, fmt.Errorf("group name in allocation is not present in job") - } - - verifiedTasks := []string{} - found := false - // Check if the given task names actually exist in the allocation - for _, taskName := range taskNames { - found = false - for _, task := range group.Tasks { - if task.Name == taskName { - found = true - } - } - if !found { - c.logger.Printf("[ERR] task %q not found in 
the allocation", taskName) - return nil, fmt.Errorf("task %q not found in the allocaition", taskName) - } - verifiedTasks = append(verifiedTasks, taskName) - } - - // DeriveVaultToken of nomad server can take in a set of tasks and - // creates tokens for all the tasks. - req := &structs.DeriveVaultTokenRequest{ - NodeID: c.node.ID, - SecretID: c.node.SecretID, - AllocID: alloc.ID, - Tasks: verifiedTasks, - QueryOptions: structs.QueryOptions{ - Region: c.region, - AllowStale: true, - }, - } - - // Derive the tokens - var resp structs.DeriveVaultTokenResponse - if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil { - c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err) - return nil, fmt.Errorf("failed to derive vault tokens: %v", err) - } - if resp.Tasks == nil { - c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response") - return nil, fmt.Errorf("failed to derive vault tokens: invalid response") - } - - unwrappedTokens := make(map[string]string) - - // Retrieve the wrapped tokens from the response and unwrap it - for _, taskName := range verifiedTasks { - // Get the wrapped token - wrappedToken, ok := resp.Tasks[taskName] - if !ok { - c.logger.Printf("[ERR] client.vault: wrapped token missing for task %q", taskName) - return nil, fmt.Errorf("wrapped token missing for task %q", taskName) - } - - // Unwrap the vault token - unwrapResp, err := c.client.Logical().Unwrap(wrappedToken) - if err != nil { - return nil, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err) - } - if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" { - return nil, fmt.Errorf("failed to unwrap the token for task %q", taskName) - } - - // Append the unwrapped token to the return value - unwrappedTokens[taskName] = unwrapResp.Auth.ClientToken - } - - return unwrappedTokens, nil + return c.tokenDeriver(alloc, taskNames, c.client) } // GetConsulACL creates a vault API client and reads 
from vault a consul ACL @@ -702,27 +599,6 @@ func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) { return nextEntry.req, nextEntry.next } -// RPC is used to forward an RPC call to a nomad server, or fail if no servers -func (c *vaultClient) RPC(method string, args interface{}, reply interface{}) error { - // Invoke the RPCHandler if it exists - if c.rpcHandler != nil { - return c.rpcHandler.RPC(method, args, reply) - } - - // Pick a server to request from - server := c.rpcProxy.FindServer() - if server == nil { - return fmt.Errorf("no known servers") - } - - // Make the RPC request - if err := c.connPool.RPC(c.region, server.Addr, structs.ApiMajorVersion, method, args, reply); err != nil { - c.rpcProxy.NotifyFailedServer(server) - return fmt.Errorf("RPC failed to server %s: %v", server.Addr, err) - } - return nil -} - // Additional helper functions on top of interface methods // Length returns the number of elements in the heap From dd26f9b4bf0bc369ce0cac3718cb20f1ecff78da Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Mon, 29 Aug 2016 21:30:06 -0400 Subject: [PATCH 6/9] Fix tests --- client/client.go | 8 ++++++-- client/vaultclient/vaultclient_test.go | 24 +++--------------------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/client/client.go b/client/client.go index da4dd6469..37224cbf4 100644 --- a/client/client.go +++ b/client/client.go @@ -1311,7 +1311,8 @@ func (c *Client) setupVaultClient() error { } var err error - if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.tokenDeriver); err != nil { + if c.vaultClient, err = + vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil { return err } @@ -1323,7 +1324,10 @@ func (c *Client) setupVaultClient() error { return nil } -func (c *Client) tokenDeriver(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) { +// deriveToken takes in an allocation and a set of 
tasks and derives vault +// tokens for each of the tasks, unwraps all of them using the supplied vault +// client and returns a map of unwrapped tokens, indexed by the task name. +func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) { if alloc == nil { return nil, fmt.Errorf("nil allocation") } diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index 38fdc613c..3ff1b128b 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -7,9 +7,6 @@ import ( "time" "github.com/hashicorp/nomad/client/config" - "github.com/hashicorp/nomad/client/rpcproxy" - "github.com/hashicorp/nomad/nomad" - "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" vaultapi "github.com/hashicorp/vault/api" ) @@ -20,12 +17,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) { logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) v.Config.ConnectionRetryIntv = 100 * time.Millisecond v.Config.TaskTokenTTL = "10s" - node := &structs.Node{} - connPool := &nomad.ConnPool{} - rpcProxy := &rpcproxy.RPCProxy{} - var rpcHandler config.RPCHandler - c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler, - connPool, rpcProxy) + c, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } @@ -58,12 +50,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) { logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) v.Config.ConnectionRetryIntv = 100 * time.Millisecond v.Config.TaskTokenTTL = "10s" - node := &structs.Node{} - connPool := &nomad.ConnPool{} - rpcProxy := &rpcproxy.RPCProxy{} - var rpcHandler config.RPCHandler - c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler, - connPool, rpcProxy) + c, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: 
%v", err) } @@ -138,12 +125,7 @@ func TestVaultClient_Heap(t *testing.T) { conf.VaultConfig.TaskTokenTTL = "10s" logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) - node := &structs.Node{} - connPool := &nomad.ConnPool{} - rpcProxy := &rpcproxy.RPCProxy{} - var rpcHandler config.RPCHandler - c, err := NewVaultClient(node, "global", conf.VaultConfig, logger, rpcHandler, - connPool, rpcProxy) + c, err := NewVaultClient(conf.VaultConfig, logger, nil) if err != nil { t.Fatal(err) } From 082d5e58a4ce959d9c4b67bbae56887c23373a5e Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Tue, 30 Aug 2016 12:46:59 -0400 Subject: [PATCH 7/9] Return only fatal error to renewal error channel --- client/client.go | 2 +- client/vaultclient/vaultclient.go | 72 +++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/client/client.go b/client/client.go index 37224cbf4..fbd8a0ba5 100644 --- a/client/client.go +++ b/client/client.go @@ -250,7 +250,7 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg // Start renewing tokens and secrets if c.vaultClient != nil { - go c.vaultClient.Start() + c.vaultClient.Start() } return c, nil diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index f271e26fb..2cf398a7d 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -5,17 +5,18 @@ import ( "fmt" "log" "math/rand" + "strings" "sync" "time" - clientconfig "github.com/hashicorp/nomad/client/config" - "github.com/hashicorp/nomad/client/rpcproxy" - "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" vaultapi "github.com/hashicorp/vault/api" ) +// TokenDeriverFunc takes in an allocation and a set of tasks and derives a +// wrapped token for all the tasks, from the nomad server. All the derived +// wrapped tokens will be unwrapped using the vault API client. 
type TokenDeriverFunc func(*structs.Allocation, []string, *vaultapi.Client) (map[string]string, error) // The interface which nomad client uses to interact with vault and @@ -53,6 +54,10 @@ type VaultClient interface { // Implementation of VaultClient interface to interact with vault and perform // token and lease renewals periodically. type vaultClient struct { + // tokenDeriver is a function pointer passed in by the client to derive + // tokens by making RPC calls to the nomad server. The wrapped tokens + // returned by the nomad server will be unwrapped by this function + // using the vault API client. tokenDeriver TokenDeriverFunc // running indicates if the renewal loop is active or not @@ -82,10 +87,6 @@ type vaultClient struct { lock sync.RWMutex logger *log.Logger - - rpcHandler clientconfig.RPCHandler - rpcProxy *rpcproxy.RPCProxy - connPool *nomad.ConnPool } // tokenData holds the relevant information about the Vault token passed to the @@ -333,8 +334,7 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error { // Perform the renewal of the token and send any error to the dedicated // error channel. 
if err := c.renew(renewalReq); err != nil { - errCh <- err - close(errCh) + c.logger.Printf("[ERR] Renewal of token failed: %v", err) } return errCh @@ -370,8 +370,7 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { // Renew the secret and send any error to the dedicated error channel if err := c.renew(renewalReq); err != nil { - errCh <- err - close(errCh) + c.logger.Printf("[ERR] Renewal of lease failed: %v", err) } return errCh @@ -453,17 +452,55 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { // Determine the next renewal time next := time.Now().Add(time.Duration(duration) * time.Second) + fatal := false + if renewalErr != nil && + (strings.Contains(renewalErr.Error(), "lease not found or lease is not renewable") || + strings.Contains(renewalErr.Error(), "token not found")) { + fatal = true + } else if renewalErr != nil { + c.logger.Printf("[ERR] renewal of lease or token failed due to a non-fatal error. Retrying at %v", next.String()) + } + if c.isTracked(req.id) { + if fatal { + // If encountered with an error where in a lease or a + // token is not valid at all with vault, and if that + // item is tracked by the renewal loop, stop renewing + // it by removing the corresponding heap entry. + if err := c.heap.Remove(req.id); err != nil { + return fmt.Errorf("failed to remove heap entry. err: %v", err) + } + delete(c.heap.heapMap, req.id) + + // Report the fatal error to the client + req.errCh <- renewalErr + close(req.errCh) + + return renewalErr + } + // If the identifier is already tracked, this indicates a // subsequest renewal. In this case, update the existing // element in the heap with the new renewal time. - - // There is no need to signal an update to the renewal loop - // here because this case is hit from the renewal loop itself. if err := c.heap.Update(req, next); err != nil { return fmt.Errorf("failed to update heap entry. 
err: %v", err) } + + // There is no need to signal an update to the renewal loop + // here because this case is hit from the renewal loop itself. } else { + if fatal { + // If encountered with an error where in a lease or a + // token is not valid at all with vault, and if that + // item is not tracked by renewal loop, don't add it. + + // Report the fatal error to the client + req.errCh <- renewalErr + close(req.errCh) + + return renewalErr + } + // If the identifier is not already tracked, this is a first // renewal request. In this case, add an entry into the heap // with the next renewal time. @@ -481,10 +518,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { } } - // Returning the renewal error here ensures that an entry is either - // added or updated in the min-heap. This is done to not starve other - // entries in heap. - return renewalErr + return nil } // run is the renewal loop which performs the periodic renewals of both the @@ -526,7 +560,7 @@ func (c *vaultClient) run() { select { case <-renewalCh: if err := c.renew(renewalReq); err != nil { - renewalReq.errCh <- err + c.logger.Printf("[ERR] Renewal of token failed: %v", err) } case <-c.updateCh: continue From 68b1b30bf5b6d6c76e5a97fb20063b07452e89fc Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Tue, 30 Aug 2016 13:08:13 -0400 Subject: [PATCH 8/9] Addressed review feedback --- client/client.go | 4 -- client/vaultclient/vaultclient.go | 72 ++++++++++++++++++------------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/client/client.go b/client/client.go index fbd8a0ba5..45c51fdbe 100644 --- a/client/client.go +++ b/client/client.go @@ -1306,10 +1306,6 @@ func (c *Client) setupVaultClient() error { return nil } - if c.config.VaultConfig.Token == "" { - return fmt.Errorf("vault token not set") - } - var err error if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil { diff --git 
a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index 2cf398a7d..c552dc8d5 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -22,32 +22,34 @@ type TokenDeriverFunc func(*structs.Allocation, []string, *vaultapi.Client) (map // The interface which nomad client uses to interact with vault and // periodically renews the tokens and secrets. type VaultClient interface { - // Starts the renewal loop of tokens and secrets + // Start initiates the renewal loop of tokens and secrets Start() - // Stops the renewal loop for tokens and secrets + // Stop terminates the renewal loop for tokens and secrets Stop() - // Contacts the nomad server and fetches wrapped tokens for a set of - // tasks. The wrapped tokens will be unwrapped using vault and + // DeriveToken contacts the nomad server and fetches wrapped tokens for + // a set of tasks. The wrapped tokens will be unwrapped using vault and // returned. DeriveToken(*structs.Allocation, []string) (map[string]string, error) - // Fetch the Consul ACL token required for the task + // GetConsulACL fetches the Consul ACL token required for the task GetConsulACL(string, string) (*vaultapi.Secret, error) - // Renews a token with the given increment and adds it to the min-heap - // for periodic renewal. + // RenewToken renews a token with the given increment and adds it to + // the min-heap for periodic renewal. RenewToken(string, int) <-chan error - // Removes the token from the min-heap, stopping its renewal. + // StopRenewToken removes the token from the min-heap, stopping its + // renewal. StopRenewToken(string) error - // Renews a vault secret's lease and add the lease identifier to the - // min-heap for periodic renewal. + // RenewLease renews a vault secret's lease and adds the lease + // identifier to the min-heap for periodic renewal. RenewLease(string, int) <-chan error - // Removes a secret's lease ID from the min-heap, stopping its renewal. 
+ // StopRenewLease removes a secret's lease ID from the min-heap, + // stopping its renewal. StopRenewLease(string) error } @@ -70,19 +72,20 @@ type vaultClient struct { // tokenData is the data of the passed VaultClient token token *tokenData - // API client to interact with vault + // client is the API client to interact with vault client *vaultapi.Client - // Channel to notify heap modifications to the renewal loop + // updateCh is the channel to notify heap modifications to the renewal + // loop updateCh chan struct{} - // Channel to trigger termination of renewal loop + // stopCh is the channel to trigger termination of renewal loop stopCh chan struct{} - // Min-Heap to keep track of both tokens and leases + // heap is the min-heap to keep track of both tokens and leases heap *vaultClientHeap - // Configuration to connect to vault + // config is the configuration to connect to vault config *config.VaultConfig lock sync.RWMutex @@ -100,19 +103,20 @@ type tokenData struct { Root bool } -// Request object for renewals. This can be used for both token renewals and -// secret's lease renewals. +// vaultClientRenewalRequest is a request object for renewal of both tokens and +// secret's leases. type vaultClientRenewalRequest struct { - // Channel into which any renewal error will be sent down to + // errCh is the channel into which any renewal error will be sent to errCh chan error - // This can either be a token or a lease identifier + // id is an identifier which represents either a token or a lease id string - // Duration for which the token or lease should be renewed for + // increment is the duration for which the token or lease should be + // renewed for increment int - // Indicates whether the 'id' field is a token or not + // isToken indicates whether the 'id' field is a token or not isToken bool } @@ -137,14 +141,14 @@ type vaultDataHeapImp []*vaultClientHeapEntry // NewVaultClient returns a new vault client from the given config. 
func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver TokenDeriverFunc) (*vaultClient, error) { - if !config.Enabled { - return nil, nil - } - if config == nil { return nil, fmt.Errorf("nil vault config") } + if !config.Enabled { + return nil, nil + } + if config.TaskTokenTTL == "" { return nil, fmt.Errorf("task_token_ttl not set") } @@ -290,6 +294,10 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) return nil, fmt.Errorf("missing consul ACL token vault path") } + if !c.ConnectionEstablished() { + return nil, fmt.Errorf("connection with vault is not yet established") + } + c.lock.Lock() defer c.lock.Unlock() @@ -308,17 +316,20 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) // loop. Any error returned during renewal will be written to a buffered // channel and the channel is returned instead of an actual error. This helps // the caller be notified of a renewal failure asynchronously for appropriate -// actions to be taken. +// actions to be taken. The caller of this function need not have to close the +// error channel. func (c *vaultClient) RenewToken(token string, increment int) <-chan error { // Create a buffered error channel errCh := make(chan error, 1) if token == "" { errCh <- fmt.Errorf("missing token") + close(errCh) return errCh } if increment < 1 { errCh <- fmt.Errorf("increment cannot be less than 1") + close(errCh) return errCh } @@ -345,7 +356,8 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error { // the renewal loop. Any error returned during renewal will be written to a // buffered channel and the channel is returned instead of an actual error. // This helps the caller be notified of a renewal failure asynchronously for -// appropriate actions to be taken. +// appropriate actions to be taken. The caller of this function need not have +// to close the error channel. 
func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId) // Create a buffered error channel @@ -353,17 +365,19 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { if leaseId == "" { errCh <- fmt.Errorf("missing lease ID") + close(errCh) return errCh } if increment < 1 { errCh <- fmt.Errorf("increment cannot be less than 1") + close(errCh) return errCh } // Create a renewal request using the supplied lease and duration renewalReq := &vaultClientRenewalRequest{ - errCh: make(chan error, 1), + errCh: errCh, id: leaseId, increment: increment, } From 13d97f01bbe85b97c1b765fda35021a6e93edbc0 Mon Sep 17 00:00:00 2001 From: vishalnayak Date: Tue, 30 Aug 2016 13:14:34 -0400 Subject: [PATCH 9/9] Print debug message only when error is non-nil --- client/vaultclient/vaultclient.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index c552dc8d5..a4799dce4 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -345,7 +345,7 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error { // Perform the renewal of the token and send any error to the dedicated // error channel. 
if err := c.renew(renewalReq); err != nil { - c.logger.Printf("[ERR] Renewal of token failed: %v", err) + c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err) } return errCh @@ -384,7 +384,7 @@ func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error { // Renew the secret and send any error to the dedicated error channel if err := c.renew(renewalReq); err != nil { - c.logger.Printf("[ERR] Renewal of lease failed: %v", err) + c.logger.Printf("[ERR] client.vault: renewal of lease failed: %v", err) } return errCh @@ -460,8 +460,6 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { rand.Seed(time.Now().Unix()) duration = min + rand.Intn(max-min) } - c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d", - req.increment, leaseDuration, duration) // Determine the next renewal time next := time.Now().Add(time.Duration(duration) * time.Second) @@ -472,7 +470,8 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { strings.Contains(renewalErr.Error(), "token not found")) { fatal = true } else if renewalErr != nil { - c.logger.Printf("[ERR] renewal of lease or token failed due to a non-fatal error. Retrying at %v", next.String()) + c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d", req.increment, leaseDuration, duration) + c.logger.Printf("[ERR] client.vault: renewal of lease or token failed due to a non-fatal error. Retrying at %v", next.String()) } if c.isTracked(req.id) { @@ -574,7 +573,7 @@ func (c *vaultClient) run() { select { case <-renewalCh: if err := c.renew(renewalReq); err != nil { - c.logger.Printf("[ERR] Renewal of token failed: %v", err) + c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err) } case <-c.updateCh: continue