Merge pull request #1606 from hashicorp/f-vault-client

VaultClient for Nomad client's interactions with Vault
This commit is contained in:
Vishal Nayak
2016-08-30 13:13:54 -04:00
committed by GitHub
6 changed files with 1118 additions and 1 deletions

View File

@@ -23,9 +23,11 @@ import (
"github.com/hashicorp/nomad/client/fingerprint"
"github.com/hashicorp/nomad/client/rpcproxy"
"github.com/hashicorp/nomad/client/stats"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/command/agent/consul"
"github.com/hashicorp/nomad/nomad"
"github.com/hashicorp/nomad/nomad/structs"
vaultapi "github.com/hashicorp/vault/api"
"github.com/mitchellh/hashstructure"
)
@@ -147,6 +149,9 @@ type Client struct {
shutdown bool
shutdownCh chan struct{}
shutdownLock sync.Mutex
// client to interact with vault for token and secret renewals
vaultClient vaultclient.VaultClient
}
// NewClient is used to create a new client from the given configuration
@@ -213,6 +218,11 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg
return nil, fmt.Errorf("failed to create client Consul syncer: %v", err)
}
// Setup the vault client for token and secret renewals
if err := c.setupVaultClient(); err != nil {
return nil, fmt.Errorf("failed to setup vault client: %v", err)
}
// Register and then start heartbeating to the servers.
go c.registerAndHeartbeat()
@@ -238,6 +248,11 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg
// populated by periodically polling Consul, if available.
go c.rpcProxy.Run()
// Start renewing tokens and secrets
if c.vaultClient != nil {
c.vaultClient.Start()
}
return c, nil
}
@@ -319,6 +334,11 @@ func (c *Client) Shutdown() error {
return nil
}
// Stop renewing tokens and secrets
if c.vaultClient != nil {
c.vaultClient.Stop()
}
// Destroy all the running allocations.
if c.config.DevMode {
c.allocLock.Lock()
@@ -1275,6 +1295,116 @@ func (c *Client) addAlloc(alloc *structs.Allocation) error {
return nil
}
// setupVaultClient creates an object to periodically renew tokens and secrets
// with vault.
func (c *Client) setupVaultClient() error {
if c.config.VaultConfig == nil {
return fmt.Errorf("nil vault config")
}
if !c.config.VaultConfig.Enabled {
return nil
}
var err error
if c.vaultClient, err =
vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil {
return err
}
if c.vaultClient == nil {
c.logger.Printf("[ERR] client: failed to create vault client")
return fmt.Errorf("failed to create vault client")
}
return nil
}
// deriveToken takes in an allocation and a set of tasks and derives vault
// tokens for each of the tasks, unwraps all of them using the supplied vault
// client and returns a map of unwrapped tokens, indexed by the task name.
func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) {
if alloc == nil {
return nil, fmt.Errorf("nil allocation")
}
if taskNames == nil || len(taskNames) == 0 {
return nil, fmt.Errorf("missing task names")
}
group := alloc.Job.LookupTaskGroup(alloc.TaskGroup)
if group == nil {
return nil, fmt.Errorf("group name in allocation is not present in job")
}
verifiedTasks := []string{}
found := false
// Check if the given task names actually exist in the allocation
for _, taskName := range taskNames {
found = false
for _, task := range group.Tasks {
if task.Name == taskName {
found = true
}
}
if !found {
c.logger.Printf("[ERR] task %q not found in the allocation", taskName)
return nil, fmt.Errorf("task %q not found in the allocaition", taskName)
}
verifiedTasks = append(verifiedTasks, taskName)
}
// DeriveVaultToken of nomad server can take in a set of tasks and
// creates tokens for all the tasks.
req := &structs.DeriveVaultTokenRequest{
NodeID: c.Node().ID,
SecretID: c.Node().SecretID,
AllocID: alloc.ID,
Tasks: verifiedTasks,
QueryOptions: structs.QueryOptions{
Region: c.Region(),
AllowStale: true,
},
}
// Derive the tokens
var resp structs.DeriveVaultTokenResponse
if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil {
c.logger.Printf("[ERR] client.vault: failed to derive vault tokens: %v", err)
return nil, fmt.Errorf("failed to derive vault tokens: %v", err)
}
if resp.Tasks == nil {
c.logger.Printf("[ERR] client.vault: failed to derive vault token: invalid response")
return nil, fmt.Errorf("failed to derive vault tokens: invalid response")
}
unwrappedTokens := make(map[string]string)
// Retrieve the wrapped tokens from the response and unwrap it
for _, taskName := range verifiedTasks {
// Get the wrapped token
wrappedToken, ok := resp.Tasks[taskName]
if !ok {
c.logger.Printf("[ERR] client.vault: wrapped token missing for task %q", taskName)
return nil, fmt.Errorf("wrapped token missing for task %q", taskName)
}
// Unwrap the vault token
unwrapResp, err := vclient.Logical().Unwrap(wrappedToken)
if err != nil {
return nil, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err)
}
if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" {
return nil, fmt.Errorf("failed to unwrap the token for task %q", taskName)
}
// Append the unwrapped token to the return value
unwrappedTokens[taskName] = unwrapResp.Auth.ClientToken
}
return unwrappedTokens, nil
}
// setupConsulSyncer creates Client-mode consul.Syncer which periodically
// executes callbacks on a fixed interval.
//

View File

@@ -85,6 +85,7 @@ func testServer(t *testing.T, cb func(*nomad.Config)) (*nomad.Server, string) {
func testClient(t *testing.T, cb func(c *config.Config)) *Client {
conf := config.DefaultConfig()
conf.VaultConfig.Enabled = false
conf.DevMode = true
if cb != nil {
cb(conf)

View File

@@ -149,6 +149,7 @@ func (c *Config) Copy() *Config {
// DefaultConfig returns the default configuration
func DefaultConfig() *Config {
return &Config{
VaultConfig: config.DefaultVaultConfig(),
ConsulConfig: config.DefaultConsulConfig(),
LogOutput: os.Stderr,
Region: "global",

View File

@@ -12,7 +12,6 @@ import (
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/client/driver"

View File

@@ -0,0 +1,763 @@
package vaultclient
import (
"container/heap"
"fmt"
"log"
"math/rand"
"strings"
"sync"
"time"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
vaultapi "github.com/hashicorp/vault/api"
)
// TokenDeriverFunc takes in an allocation and a set of tasks and derives a
// wrapped token for all the tasks, from the nomad server. All the derived
// wrapped tokens will be unwrapped using the vault API client.
type TokenDeriverFunc func(*structs.Allocation, []string, *vaultapi.Client) (map[string]string, error)
// The interface which nomad client uses to interact with vault and
// periodically renews the tokens and secrets.
type VaultClient interface {
// Start initiates the renewal loop of tokens and secrets
Start()
// Stop terminates the renewal loop for tokens and secrets
Stop()
// DeriveToken contacts the nomad server and fetches wrapped tokens for
// a set of tasks. The wrapped tokens will be unwrapped using vault and
// returned.
DeriveToken(*structs.Allocation, []string) (map[string]string, error)
// GetConsulACL fetches the Consul ACL token required for the task
GetConsulACL(string, string) (*vaultapi.Secret, error)
// RenewToken renews a token with the given increment and adds it to
// the min-heap for periodic renewal.
RenewToken(string, int) <-chan error
// StopRenewToken removes the token from the min-heap, stopping its
// renewal.
StopRenewToken(string) error
// RenewLease renews a vault secret's lease and adds the lease
// identifier to the min-heap for periodic renewal.
RenewLease(string, int) <-chan error
// StopRenewLease removes a secret's lease ID from the min-heap,
// stopping its renewal.
StopRenewLease(string) error
}
// Implementation of VaultClient interface to interact with vault and perform
// token and lease renewals periodically.
type vaultClient struct {
// tokenDeriver is a function pointer passed in by the client to derive
// tokens by making RPC calls to the nomad server. The wrapped tokens
// returned by the nomad server will be unwrapped by this function
// using the vault API client.
tokenDeriver TokenDeriverFunc
// running indicates if the renewal loop is active or not
running bool
// connEstablished marks whether the connection to vault was successful
// or not
connEstablished bool
// tokenData is the data of the passed VaultClient token
token *tokenData
// client is the API client to interact with vault
client *vaultapi.Client
// updateCh is the channel to notify heap modifications to the renewal
// loop
updateCh chan struct{}
// stopCh is the channel to trigger termination of renewal loop
stopCh chan struct{}
// heap is the min-heap to keep track of both tokens and leases
heap *vaultClientHeap
// config is the configuration to connect to vault
config *config.VaultConfig
lock sync.RWMutex
logger *log.Logger
}
// tokenData holds the relevant information about the Vault token passed to the
// client.
type tokenData struct {
CreationTTL int `mapstructure:"creation_ttl"`
TTL int `mapstructure:"ttl"`
Renewable bool `mapstructure:"renewable"`
Policies []string `mapstructure:"policies"`
Role string `mapstructure:"role"`
Root bool
}
// vaultClientRenewalRequest is a request object for renewal of both tokens and
// secret's leases.
type vaultClientRenewalRequest struct {
// errCh is the channel into which any renewal error will be sent to
errCh chan error
// id is an identifier which represents either a token or a lease
id string
// increment is the duration for which the token or lease should be
// renewed for
increment int
// isToken indicates whether the 'id' field is a token or not
isToken bool
}
// Element representing an entry in the renewal heap
type vaultClientHeapEntry struct {
req *vaultClientRenewalRequest
next time.Time
index int
}
// Wrapper around the actual heap to provide additional symantics on top of
// functions provided by the heap interface. In order to achieve that, an
// additional map is placed beside the actual heap. This map can be used to
// check if an entry is already present in the heap.
type vaultClientHeap struct {
heapMap map[string]*vaultClientHeapEntry
heap vaultDataHeapImp
}
// Data type of the heap
type vaultDataHeapImp []*vaultClientHeapEntry
// NewVaultClient returns a new vault client from the given config.
func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver TokenDeriverFunc) (*vaultClient, error) {
if config == nil {
return nil, fmt.Errorf("nil vault config")
}
if !config.Enabled {
return nil, nil
}
if config.TaskTokenTTL == "" {
return nil, fmt.Errorf("task_token_ttl not set")
}
if logger == nil {
return nil, fmt.Errorf("nil logger")
}
c := &vaultClient{
config: config,
stopCh: make(chan struct{}),
// Update channel should be a buffered channel
updateCh: make(chan struct{}, 1),
heap: newVaultClientHeap(),
logger: logger,
}
// Get the Vault API configuration
apiConf, err := config.ApiConfig()
if err != nil {
logger.Printf("[ERR] client.vault: failed to create vault API config: %v", err)
return nil, err
}
// Create the Vault API client
client, err := vaultapi.NewClient(apiConf)
if err != nil {
logger.Printf("[ERR] client.vault: failed to create Vault client. Not retrying: %v", err)
return nil, err
}
c.client = client
return c, nil
}
// newVaultClientHeap returns a new vault client heap with both the heap and a
// map which is a secondary index for heap elements, both initialized.
func newVaultClientHeap() *vaultClientHeap {
return &vaultClientHeap{
heapMap: make(map[string]*vaultClientHeapEntry),
heap: make(vaultDataHeapImp, 0),
}
}
// isTracked returns if a given identifier is already present in the heap and
// hence is being renewed. Lock should be held before calling this method.
func (c *vaultClient) isTracked(id string) bool {
if id == "" {
return false
}
_, ok := c.heap.heapMap[id]
return ok
}
// Starts the renewal loop of vault client
func (c *vaultClient) Start() {
if !c.config.Enabled || c.running {
return
}
c.logger.Printf("[DEBUG] client.vault: establishing connection to vault")
go c.establishConnection()
}
// ConnectionEstablished indicates whether VaultClient successfully established
// connection to vault or not
func (c *vaultClient) ConnectionEstablished() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return c.connEstablished
}
// establishConnection is used to make first contact with Vault. This should be
// called in a go-routine since the connection is retried till the Vault Client
// is stopped or the connection is successfully made at which point the renew
// loop is started.
func (c *vaultClient) establishConnection() {
// Create the retry timer and set initial duration to zero so it fires
// immediately
retryTimer := time.NewTimer(0)
OUTER:
for {
select {
case <-c.stopCh:
return
case <-retryTimer.C:
// Ensure the API is reachable
if _, err := c.client.Sys().InitStatus(); err != nil {
c.logger.Printf("[WARN] client.vault: failed to contact Vault API. Retrying in %v",
c.config.ConnectionRetryIntv)
retryTimer.Reset(c.config.ConnectionRetryIntv)
continue OUTER
}
break OUTER
}
}
c.lock.Lock()
c.connEstablished = true
c.lock.Unlock()
// Begin the renewal loop
go c.run()
c.logger.Printf("[DEBUG] client.vault: started")
}
// Stops the renewal loop of vault client
func (c *vaultClient) Stop() {
if !c.config.Enabled || !c.running {
return
}
c.lock.Lock()
defer c.lock.Unlock()
c.running = false
close(c.stopCh)
}
// DeriveToken takes in an allocation and a set of tasks and for each of the
// task, it derives a vault token from nomad server and unwraps it using vault.
// The return value is a map containing all the unwrapped tokens indexed by the
// task name.
func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) {
if !c.running {
return nil, fmt.Errorf("vault client is not running")
}
return c.tokenDeriver(alloc, taskNames, c.client)
}
// GetConsulACL creates a vault API client and reads from vault a consul ACL
// token used by the task.
func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) {
if token == "" {
return nil, fmt.Errorf("missing token")
}
if path == "" {
return nil, fmt.Errorf("missing consul ACL token vault path")
}
if !c.ConnectionEstablished() {
return nil, fmt.Errorf("connection with vault is not yet established")
}
c.lock.Lock()
defer c.lock.Unlock()
// Use the token supplied to interact with vault
c.client.SetToken(token)
// Reset the token before returning
defer c.client.SetToken("")
// Read the consul ACL token and return the secret directly
return c.client.Logical().Read(path)
}
// RenewToken renews the supplied token for a given duration (in seconds) and
// adds it to the min-heap so that it is renewed periodically by the renewal
// loop. Any error returned during renewal will be written to a buffered
// channel and the channel is returned instead of an actual error. This helps
// the caller be notified of a renewal failure asynchronously for appropriate
// actions to be taken. The caller of this function need not have to close the
// error channel.
func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
// Create a buffered error channel
errCh := make(chan error, 1)
if token == "" {
errCh <- fmt.Errorf("missing token")
close(errCh)
return errCh
}
if increment < 1 {
errCh <- fmt.Errorf("increment cannot be less than 1")
close(errCh)
return errCh
}
// Create a renewal request and indicate that the identifier in the
// request is a token and not a lease
renewalReq := &vaultClientRenewalRequest{
errCh: errCh,
id: token,
isToken: true,
increment: increment,
}
// Perform the renewal of the token and send any error to the dedicated
// error channel.
if err := c.renew(renewalReq); err != nil {
c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err)
}
return errCh
}
// RenewLease renews the supplied lease identifier for a supplied duration (in
// seconds) and adds it to the min-heap so that it gets renewed periodically by
// the renewal loop. Any error returned during renewal will be written to a
// buffered channel and the channel is returned instead of an actual error.
// This helps the caller be notified of a renewal failure asynchronously for
// appropriate actions to be taken. The caller of this function need not have
// to close the error channel.
func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
c.logger.Printf("[DEBUG] client.vault: renewing lease %q", leaseId)
// Create a buffered error channel
errCh := make(chan error, 1)
if leaseId == "" {
errCh <- fmt.Errorf("missing lease ID")
close(errCh)
return errCh
}
if increment < 1 {
errCh <- fmt.Errorf("increment cannot be less than 1")
close(errCh)
return errCh
}
// Create a renewal request using the supplied lease and duration
renewalReq := &vaultClientRenewalRequest{
errCh: errCh,
id: leaseId,
increment: increment,
}
// Renew the secret and send any error to the dedicated error channel
if err := c.renew(renewalReq); err != nil {
c.logger.Printf("[ERR] client.vault: renewal of lease failed: %v", err)
}
return errCh
}
// renew is a common method to handle renewal of both tokens and secret leases.
// It invokes a token renewal or a secret's lease renewal. If renewal is
// successful, min-heap is updated based on the duration after which it needs
// renewal again. The next renewal time is randomly selected to avoid spikes in
// the number of APIs periodically.
func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
c.lock.Lock()
defer c.lock.Unlock()
if !c.running {
return fmt.Errorf("vault client is not running")
}
if req == nil {
return fmt.Errorf("nil renewal request")
}
if req.id == "" {
return fmt.Errorf("missing id in renewal request")
}
if req.increment < 1 {
return fmt.Errorf("increment cannot be less than 1")
}
var renewalErr error
leaseDuration := req.increment
if req.isToken {
// Reset the token in the API client before returning
defer c.client.SetToken("")
// Set the token in the API client to the one that needs
// renewal
c.client.SetToken(req.id)
// Renew the token
renewResp, err := c.client.Auth().Token().RenewSelf(req.increment)
if err != nil {
renewalErr = fmt.Errorf("failed to renew the vault token: %v", err)
}
if renewResp == nil || renewResp.Auth == nil {
renewalErr = fmt.Errorf("failed to renew the vault token")
} else {
// Don't set this if renewal fails
leaseDuration = renewResp.Auth.LeaseDuration
}
} else {
// Renew the secret
renewResp, err := c.client.Sys().Renew(req.id, req.increment)
if err != nil {
renewalErr = fmt.Errorf("failed to renew vault secret: %v", err)
}
if renewResp == nil {
renewalErr = fmt.Errorf("failed to renew vault secret")
} else {
// Don't set this if renewal fails
leaseDuration = renewResp.LeaseDuration
}
}
duration := leaseDuration / 2
switch {
case leaseDuration < 30:
// Don't bother about introducing randomness if the
// leaseDuration is too small.
default:
// Give a breathing space of 20 seconds
min := 10
max := leaseDuration - min
rand.Seed(time.Now().Unix())
duration = min + rand.Intn(max-min)
}
// Determine the next renewal time
next := time.Now().Add(time.Duration(duration) * time.Second)
fatal := false
if renewalErr != nil &&
(strings.Contains(renewalErr.Error(), "lease not found or lease is not renewable") ||
strings.Contains(renewalErr.Error(), "token not found")) {
fatal = true
} else if renewalErr != nil {
c.logger.Printf("[DEBUG] client.vault: req.increment: %d, leaseDuration: %d, duration: %d", req.increment, leaseDuration, duration)
c.logger.Printf("[ERR] client.vault: renewal of lease or token failed due to a non-fatal error. Retrying at %v", next.String())
}
if c.isTracked(req.id) {
if fatal {
// If encountered with an error where in a lease or a
// token is not valid at all with vault, and if that
// item is tracked by the renewal loop, stop renewing
// it by removing the corresponding heap entry.
if err := c.heap.Remove(req.id); err != nil {
return fmt.Errorf("failed to remove heap entry. err: %v", err)
}
delete(c.heap.heapMap, req.id)
// Report the fatal error to the client
req.errCh <- renewalErr
close(req.errCh)
return renewalErr
}
// If the identifier is already tracked, this indicates a
// subsequest renewal. In this case, update the existing
// element in the heap with the new renewal time.
if err := c.heap.Update(req, next); err != nil {
return fmt.Errorf("failed to update heap entry. err: %v", err)
}
// There is no need to signal an update to the renewal loop
// here because this case is hit from the renewal loop itself.
} else {
if fatal {
// If encountered with an error where in a lease or a
// token is not valid at all with vault, and if that
// item is not tracked by renewal loop, don't add it.
// Report the fatal error to the client
req.errCh <- renewalErr
close(req.errCh)
return renewalErr
}
// If the identifier is not already tracked, this is a first
// renewal request. In this case, add an entry into the heap
// with the next renewal time.
if err := c.heap.Push(req, next); err != nil {
return fmt.Errorf("failed to push an entry to heap. err: %v", err)
}
// Signal an update for the renewal loop to trigger a fresh
// computation for the next best candidate for renewal.
if c.running {
select {
case c.updateCh <- struct{}{}:
default:
}
}
}
return nil
}
// run is the renewal loop which performs the periodic renewals of both the
// tokens and the secret leases.
func (c *vaultClient) run() {
if !c.config.Enabled {
return
}
c.lock.Lock()
c.running = true
c.lock.Unlock()
var renewalCh <-chan time.Time
for c.config.Enabled && c.running {
// Fetches the candidate for next renewal
renewalReq, renewalTime := c.nextRenewal()
if renewalTime.IsZero() {
// If the heap is empty, don't do anything
renewalCh = nil
} else {
now := time.Now()
if renewalTime.After(now) {
// Compute the duration after which the item
// needs renewal and set the renewalCh to fire
// at that time.
renewalDuration := renewalTime.Sub(time.Now())
renewalCh = time.After(renewalDuration)
} else {
// If the renewals of multiple items are too
// close to each other and by the time the
// entry is fetched from heap it might be past
// the current time (by a small margin). In
// which case, fire immediately.
renewalCh = time.After(0)
}
}
select {
case <-renewalCh:
if err := c.renew(renewalReq); err != nil {
c.logger.Printf("[ERR] client.vault: renewal of token failed: %v", err)
}
case <-c.updateCh:
continue
case <-c.stopCh:
c.logger.Printf("[DEBUG] client.vault: stopped")
return
}
}
}
// StopRenewToken removes the item from the heap which represents the given
// token.
func (c *vaultClient) StopRenewToken(token string) error {
return c.stopRenew(token)
}
// StopRenewLease removes the item from the heap which represents the given
// lease identifier.
func (c *vaultClient) StopRenewLease(leaseId string) error {
return c.stopRenew(leaseId)
}
// stopRenew removes the given identifier from the heap and signals the renewal
// loop to compute the next best candidate for renewal.
func (c *vaultClient) stopRenew(id string) error {
c.lock.Lock()
defer c.lock.Unlock()
if !c.isTracked(id) {
return nil
}
// Remove the identifier from the heap
if err := c.heap.Remove(id); err != nil {
return fmt.Errorf("failed to remove heap entry: %v", err)
}
// Delete the identifier from the map only after the it is removed from
// the heap. Heap's remove method relies on the heap map.
delete(c.heap.heapMap, id)
// Signal an update to the renewal loop.
if c.running {
select {
case c.updateCh <- struct{}{}:
default:
}
}
return nil
}
// nextRenewal returns the root element of the min-heap, which represents the
// next element to be renewed and the time at which the renewal needs to be
// triggered.
func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) {
c.lock.RLock()
defer c.lock.RUnlock()
if c.heap.Length() == 0 {
return nil, time.Time{}
}
// Fetches the root element in the min-heap
nextEntry := c.heap.Peek()
if nextEntry == nil {
return nil, time.Time{}
}
return nextEntry.req, nextEntry.next
}
// Additional helper functions on top of interface methods
// Length returns the number of elements in the heap
func (h *vaultClientHeap) Length() int {
return len(h.heap)
}
// Returns the root node of the min-heap
func (h *vaultClientHeap) Peek() *vaultClientHeapEntry {
if len(h.heap) == 0 {
return nil
}
return h.heap[0]
}
// Push adds the secondary index and inserts an item into the heap
func (h *vaultClientHeap) Push(req *vaultClientRenewalRequest, next time.Time) error {
if req == nil {
return fmt.Errorf("nil request")
}
if _, ok := h.heapMap[req.id]; ok {
return fmt.Errorf("entry %v already exists", req.id)
}
heapEntry := &vaultClientHeapEntry{
req: req,
next: next,
}
h.heapMap[req.id] = heapEntry
heap.Push(&h.heap, heapEntry)
return nil
}
// Update will modify the existing item in the heap with the new data and the
// time, and fixes the heap.
func (h *vaultClientHeap) Update(req *vaultClientRenewalRequest, next time.Time) error {
if entry, ok := h.heapMap[req.id]; ok {
entry.req = req
entry.next = next
heap.Fix(&h.heap, entry.index)
return nil
}
return fmt.Errorf("heap doesn't contain %v", req.id)
}
// Remove will remove an identifier from the secondary index and deletes the
// corresponding node from the heap.
func (h *vaultClientHeap) Remove(id string) error {
if entry, ok := h.heapMap[id]; ok {
heap.Remove(&h.heap, entry.index)
delete(h.heapMap, id)
return nil
}
return fmt.Errorf("heap doesn't contain entry for %v", id)
}
// The heap interface requires the following methods to be implemented.
// * Push(x interface{}) // add x as element Len()
// * Pop() interface{} // remove and return element Len() - 1.
// * sort.Interface
//
// sort.Interface comprises of the following methods:
// * Len() int
// * Less(i, j int) bool
// * Swap(i, j int)
// Part of sort.Interface
func (h vaultDataHeapImp) Len() int { return len(h) }
// Part of sort.Interface
func (h vaultDataHeapImp) Less(i, j int) bool {
// Two zero times should return false.
// Otherwise, zero is "greater" than any other time.
// (To sort it at the end of the list.)
// Sort such that zero times are at the end of the list.
iZero, jZero := h[i].next.IsZero(), h[j].next.IsZero()
if iZero && jZero {
return false
} else if iZero {
return false
} else if jZero {
return true
}
return h[i].next.Before(h[j].next)
}
// Part of sort.Interface
func (h vaultDataHeapImp) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
h[i].index = i
h[j].index = j
}
// Part of heap.Interface
func (h *vaultDataHeapImp) Push(x interface{}) {
n := len(*h)
entry := x.(*vaultClientHeapEntry)
entry.index = n
*h = append(*h, entry)
}
// Part of heap.Interface
func (h *vaultDataHeapImp) Pop() interface{} {
old := *h
n := len(old)
entry := old[n-1]
entry.index = -1 // for safety
*h = old[0 : n-1]
return entry
}

View File

@@ -0,0 +1,223 @@
package vaultclient
import (
"log"
"os"
"testing"
"time"
"github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/testutil"
vaultapi "github.com/hashicorp/vault/api"
)
func TestVaultClient_EstablishConnection(t *testing.T) {
v := testutil.NewTestVault(t)
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
v.Config.TaskTokenTTL = "10s"
c, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
c.Start()
defer c.Stop()
// Sleep a little while and check that no connection has been established.
time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond)
if c.ConnectionEstablished() {
t.Fatalf("ConnectionEstablished() returned true before Vault server started")
}
// Start Vault
v.Start()
defer v.Stop()
testutil.WaitForResult(func() (bool, error) {
return c.ConnectionEstablished(), nil
}, func(err error) {
t.Fatalf("Connection not established")
})
}
func TestVaultClient_TokenRenewals(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
v.Config.TaskTokenTTL = "10s"
c, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
c.Start()
defer c.Stop()
// Sleep a little while to ensure that the renewal loop is active
time.Sleep(3 * time.Second)
tcr := &vaultapi.TokenCreateRequest{
Policies: []string{"foo", "bar"},
TTL: "2s",
DisplayName: "derived-for-task",
Renewable: new(bool),
}
*tcr.Renewable = true
num := 5
tokens := make([]string, num)
for i := 0; i < num; i++ {
c.client.SetToken(v.Config.Token)
if err := c.client.SetAddress(v.Config.Addr); err != nil {
t.Fatal(err)
}
secret, err := c.client.Auth().Token().Create(tcr)
if err != nil {
t.Fatalf("failed to create vault token: %v", err)
}
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
t.Fatal("failed to derive a wrapped vault token")
}
tokens[i] = secret.Auth.ClientToken
errCh := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
go func(errCh <-chan error) {
var err error
for {
select {
case err = <-errCh:
t.Fatalf("error while renewing the token: %v", err)
}
}
}(errCh)
}
if c.heap.Length() != num {
t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length())
}
time.Sleep(5 * time.Second)
for i := 0; i < num; i++ {
if err := c.StopRenewToken(tokens[i]); err != nil {
t.Fatal(err)
}
}
if c.heap.Length() != 0 {
t.Fatal("bad: heap length: expected: 0, actual: %d", c.heap.Length())
}
}
func TestVaultClient_Heap(t *testing.T) {
conf := config.DefaultConfig()
conf.VaultConfig.Enabled = true
conf.VaultConfig.Token = "testvaulttoken"
conf.VaultConfig.TaskTokenTTL = "10s"
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
c, err := NewVaultClient(conf.VaultConfig, logger, nil)
if err != nil {
t.Fatal(err)
}
if c == nil {
t.Fatal("failed to create vault client")
}
now := time.Now()
renewalReq1 := &vaultClientRenewalRequest{
errCh: make(chan error, 1),
id: "id1",
increment: 10,
}
if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
t.Fatal(err)
}
if !c.isTracked("id1") {
t.Fatalf("id1 should have been tracked")
}
renewalReq2 := &vaultClientRenewalRequest{
errCh: make(chan error, 1),
id: "id2",
increment: 10,
}
if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
t.Fatal(err)
}
if !c.isTracked("id2") {
t.Fatalf("id2 should have been tracked")
}
renewalReq3 := &vaultClientRenewalRequest{
errCh: make(chan error, 1),
id: "id3",
increment: 10,
}
if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
t.Fatal(err)
}
if !c.isTracked("id3") {
t.Fatalf("id3 should have been tracked")
}
// Reading elements should yield id2, id1 and id3 in order
req, _ := c.nextRenewal()
if req != renewalReq2 {
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req)
}
if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil {
t.Fatal(err)
}
req, _ = c.nextRenewal()
if req != renewalReq1 {
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req)
}
if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil {
t.Fatal(err)
}
req, _ = c.nextRenewal()
if req != renewalReq3 {
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req)
}
if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil {
t.Fatal(err)
}
if err := c.StopRenewToken("id1"); err != nil {
t.Fatal(err)
}
if err := c.StopRenewToken("id2"); err != nil {
t.Fatal(err)
}
if err := c.StopRenewToken("id3"); err != nil {
t.Fatal(err)
}
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
}