mirror of
https://github.com/kemko/nomad.git
synced 2026-01-06 18:35:44 +03:00
Employ DeriveVaultToken API and flesh-up DeriveToken
This commit is contained in:
@@ -248,7 +248,9 @@ func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logg
|
||||
go c.rpcProxy.Run()
|
||||
|
||||
// Start renewing tokens and secrets
|
||||
go c.vaultClient.Start()
|
||||
if c.vaultClient != nil {
|
||||
go c.vaultClient.Start()
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
@@ -1298,15 +1300,27 @@ func (c *Client) setupVaultClient() error {
|
||||
if c.config.VaultConfig == nil {
|
||||
return fmt.Errorf("nil vault config")
|
||||
}
|
||||
|
||||
if !c.config.VaultConfig.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.config.VaultConfig.Token == "" {
|
||||
return fmt.Errorf("vault token not set")
|
||||
}
|
||||
|
||||
var err error
|
||||
if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger); err != nil {
|
||||
if c.vaultClient, err = vaultclient.NewVaultClient(c.Node(), c.Region(),
|
||||
c.config.VaultConfig, c.logger, c.config.RPCHandler, c.connPool,
|
||||
c.rpcProxy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.vaultClient == nil {
|
||||
c.logger.Printf("[ERR] client: failed to create vault client")
|
||||
return fmt.Errorf("failed to create vault client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/client/driver"
|
||||
|
||||
@@ -4,9 +4,15 @@ import (
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
clientconfig "github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/client/rpcproxy"
|
||||
"github.com/hashicorp/nomad/nomad"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/nomad/structs/config"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
@@ -21,9 +27,10 @@ type VaultClient interface {
|
||||
// Stops the renewal loop for tokens and secrets
|
||||
Stop()
|
||||
|
||||
// Contacts the nomad server and fetches a wrapped token. The wrapped
|
||||
// token will be unwrapped by contacting vault and returned.
|
||||
DeriveToken() (string, error)
|
||||
// Contacts the nomad server and fetches wrapped tokens for a set of
|
||||
// tasks. The wrapped tokens will be unwrapped using vault and
|
||||
// returned.
|
||||
DeriveToken(*structs.Allocation, []string) (map[string]string, error)
|
||||
|
||||
// Fetch the Consul ACL token required for the task
|
||||
GetConsulACL(string, string) (*vaultapi.Secret, error)
|
||||
@@ -39,13 +46,19 @@ type VaultClient interface {
|
||||
// min-heap for periodic renewal.
|
||||
RenewLease(string, int) <-chan error
|
||||
|
||||
// Removes a secret's lease id from the min-heap, stopping its renewal.
|
||||
// Removes a secret's lease ID from the min-heap, stopping its renewal.
|
||||
StopRenewLease(string) error
|
||||
}
|
||||
|
||||
// Implementation of VaultClient interface to interact with vault and perform
|
||||
// token and lease renewals periodically.
|
||||
type vaultClient struct {
|
||||
// Client's region
|
||||
region string
|
||||
|
||||
// The node in which this vault client is running in
|
||||
node *structs.Node
|
||||
|
||||
// running indicates if the renewal loop is active or not
|
||||
running bool
|
||||
|
||||
@@ -73,6 +86,10 @@ type vaultClient struct {
|
||||
|
||||
lock sync.RWMutex
|
||||
logger *log.Logger
|
||||
|
||||
rpcHandler clientconfig.RPCHandler
|
||||
rpcProxy *rpcproxy.RPCProxy
|
||||
connPool *nomad.ConnPool
|
||||
}
|
||||
|
||||
// tokenData holds the relevant information about the Vault token passed to the
|
||||
@@ -122,13 +139,26 @@ type vaultClientHeap struct {
|
||||
type vaultDataHeapImp []*vaultClientHeapEntry
|
||||
|
||||
// NewVaultClient returns a new vault client from the given config.
|
||||
func NewVaultClient(config *config.VaultConfig, logger *log.Logger) (*vaultClient, error) {
|
||||
func NewVaultClient(node *structs.Node, region string, config *config.VaultConfig,
|
||||
logger *log.Logger, rpcHandler clientconfig.RPCHandler, connPool *nomad.ConnPool,
|
||||
rpcProxy *rpcproxy.RPCProxy) (*vaultClient, error) {
|
||||
if !config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if node == nil {
|
||||
return nil, fmt.Errorf("nil node")
|
||||
}
|
||||
|
||||
if region == "" {
|
||||
return nil, fmt.Errorf("missing region")
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("nil vault config")
|
||||
}
|
||||
|
||||
// Creation of a vault client requires that the token is supplied via
|
||||
// config.
|
||||
// Creation of a vault client requires a token
|
||||
if config.Token == "" {
|
||||
return nil, fmt.Errorf("vault token not set")
|
||||
}
|
||||
@@ -141,34 +171,44 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger) (*vaultClien
|
||||
return nil, fmt.Errorf("nil logger")
|
||||
}
|
||||
|
||||
if connPool == nil {
|
||||
return nil, fmt.Errorf("nil connection pool")
|
||||
}
|
||||
|
||||
if rpcProxy == nil {
|
||||
return nil, fmt.Errorf("nil rpc proxy")
|
||||
}
|
||||
|
||||
c := &vaultClient{
|
||||
config: config,
|
||||
stopCh: make(chan struct{}),
|
||||
rpcHandler: rpcHandler,
|
||||
connPool: connPool,
|
||||
rpcProxy: rpcProxy,
|
||||
region: region,
|
||||
node: node,
|
||||
config: config,
|
||||
stopCh: make(chan struct{}),
|
||||
// Update channel should be a buffered channel
|
||||
updateCh: make(chan struct{}, 1),
|
||||
heap: NewVaultClientHeap(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if !c.config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Get the Vault API configuration
|
||||
apiConf, err := config.ApiConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create vault API config: %v", err)
|
||||
logger.Printf("[ERR] client/vaultclient: failed to create vault API config: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the Vault API client
|
||||
client, err := vaultapi.NewClient(apiConf)
|
||||
if err != nil {
|
||||
logger.Printf("[ERR] vault: failed to create Vault client. Not retrying: %v", err)
|
||||
logger.Printf("[ERR] client/vaultclient: failed to create Vault client. Not retrying: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the token and store the client
|
||||
client.SetToken(c.config.Token)
|
||||
|
||||
c.client = client
|
||||
|
||||
return c, nil
|
||||
@@ -184,7 +224,7 @@ func NewVaultClientHeap() *vaultClientHeap {
|
||||
}
|
||||
|
||||
// IsTracked returns if a given identifier is already present in the heap and
|
||||
// hence is being renewed.
|
||||
// hence is being renewed. Lock should be held before calling this method.
|
||||
func (c *vaultClient) IsTracked(id string) bool {
|
||||
if id == "" {
|
||||
return false
|
||||
@@ -200,7 +240,7 @@ func (c *vaultClient) Start() {
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Printf("[INFO] vaultclient: establishing connection to vault")
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: establishing connection to vault")
|
||||
go c.establishConnection()
|
||||
}
|
||||
|
||||
@@ -213,7 +253,7 @@ func (c *vaultClient) ConnectionEstablished() bool {
|
||||
}
|
||||
|
||||
// establishConnection is used to make first contact with Vault. This should be
|
||||
// called in a go-routine since the connection is retried til the Vault Client
|
||||
// called in a go-routine since the connection is retried till the Vault Client
|
||||
// is stopped or the connection is successfully made at which point the renew
|
||||
// loop is started.
|
||||
func (c *vaultClient) establishConnection() {
|
||||
@@ -229,7 +269,7 @@ OUTER:
|
||||
case <-retryTimer.C:
|
||||
// Ensure the API is reachable
|
||||
if _, err := c.client.Sys().InitStatus(); err != nil {
|
||||
c.logger.Printf("[WARN] vaultclient: failed to contact Vault API. Retrying in %v",
|
||||
c.logger.Printf("[WARN] client/vaultclient: failed to contact Vault API. Retrying in %v",
|
||||
c.config.ConnectionRetryIntv)
|
||||
retryTimer.Reset(c.config.ConnectionRetryIntv)
|
||||
continue OUTER
|
||||
@@ -245,38 +285,42 @@ OUTER:
|
||||
|
||||
// Retrieve our token, validate it and parse the lease duration
|
||||
if err := c.parseSelfToken(); err != nil {
|
||||
c.logger.Printf("[ERR] vaultclient: failed to lookup self token and not retrying: %v", err)
|
||||
c.logger.Printf("[ERR] client/vaultclient: failed to lookup self token and not retrying: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Begin the renewal loop
|
||||
go c.run()
|
||||
c.logger.Printf("[INFO] vaultclient: started")
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: started")
|
||||
|
||||
// If we are given a non-root token, start renewing it
|
||||
if c.token.Renewable {
|
||||
c.logger.Printf("[INFO] vaultclient: not renewing token as it is not renewable")
|
||||
// If we are given a token that needs renewal, place it in the renewal
|
||||
// loop.
|
||||
|
||||
// Root tokens can also have a TTL
|
||||
if c.token.Root && c.token.TTL == 0 {
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: not renewing token as it is a non-expiring root token")
|
||||
} else {
|
||||
c.logger.Printf("[INFO] vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second)
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: token lease duration is %v", time.Duration(c.token.CreationTTL)*time.Second)
|
||||
|
||||
// Add the VaultClient's token to the renewal loop
|
||||
// Renew the token and place it in renewal min-heap
|
||||
errCh := c.RenewToken(c.config.Token, c.token.CreationTTL)
|
||||
|
||||
// Catch the renewal error of VaultClient's token.
|
||||
go func(errCh <-chan error) {
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case err = <-errCh:
|
||||
c.logger.Printf("[ERR] vaultclient: error while renewing the vault client's token: %v", err)
|
||||
c.logger.Printf("[ERR] client/vaultclient: error while renewing the vault client's token: %v", err)
|
||||
}
|
||||
}
|
||||
}(errCh)
|
||||
}
|
||||
}
|
||||
|
||||
// parseSelfToken looks up the Vault token in Vault and parses its data storing
|
||||
// it in the client. If the token is not valid for Nomads purposes an error is
|
||||
// returned.
|
||||
// parseSelfToken looks up the VaultClient's token in vault and parses its data
|
||||
// storing it in the client. If the token is not valid for Nomads purposes an
|
||||
// error is returned.
|
||||
func (c *vaultClient) parseSelfToken() error {
|
||||
// Get the initial lease duration
|
||||
auth := c.client.Auth().Token()
|
||||
@@ -333,69 +377,127 @@ func (c *vaultClient) Stop() {
|
||||
close(c.stopCh)
|
||||
}
|
||||
|
||||
// DeriveToken contacts the nomad server and fetches a wrapped token. Then it
|
||||
// contacts vault to unwrap the token and returns the unwrapped token.
|
||||
func (c *vaultClient) DeriveToken() (string, error) {
|
||||
// TODO: Replace this code with an actual call to the nomad server.
|
||||
// This is a sample code which directly fetches a wrapped token from
|
||||
// vault and unwraps it for time being.
|
||||
tcr := &vaultapi.TokenCreateRequest{
|
||||
Policies: []string{"foo", "bar"},
|
||||
TTL: "10s",
|
||||
DisplayName: "derived-token",
|
||||
Renewable: new(bool),
|
||||
}
|
||||
*tcr.Renewable = true
|
||||
// DeriveToken takes in an allocation and a set of tasks and for each of the
|
||||
// task, it derives a vault token from nomad server and unwraps it using vault.
|
||||
// The return value is a map containing all the unwrapped tokens indexed by the
|
||||
// task name.
|
||||
func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) {
|
||||
var result *multierror.Error
|
||||
|
||||
// Set the TTL for the wrapped token
|
||||
wrapLookupFunc := func(method, path string) string {
|
||||
if method == "POST" && path == "auth/token/create" {
|
||||
return "60s"
|
||||
if !c.running {
|
||||
result = multierror.Append(fmt.Errorf("vault client is not running"))
|
||||
return nil, result
|
||||
}
|
||||
|
||||
if alloc == nil {
|
||||
result = multierror.Append(fmt.Errorf("nil allocation"))
|
||||
return nil, result
|
||||
}
|
||||
if taskNames == nil || len(taskNames) == 0 {
|
||||
result = multierror.Append(fmt.Errorf("missing task names"))
|
||||
return nil, result
|
||||
}
|
||||
|
||||
found := false
|
||||
verifiedTasks := []string{}
|
||||
// Check if the given task names actually exist in the allocation
|
||||
for _, taskName := range taskNames {
|
||||
found = false
|
||||
for _, group := range alloc.Job.TaskGroups {
|
||||
for _, task := range group.Tasks {
|
||||
if task.Name == taskName {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if found {
|
||||
verifiedTasks = append(verifiedTasks, taskName)
|
||||
} else {
|
||||
// Append the error for an invalid task name, but don't
|
||||
// break out of the loop. Continue to process other
|
||||
// tasks.
|
||||
result = multierror.Append(result, fmt.Errorf("task %s not found in the allocation", taskName))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
c.client.SetWrappingLookupFunc(wrapLookupFunc)
|
||||
|
||||
// Create a wrapped token
|
||||
secret, err := c.client.Auth().Token().Create(tcr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create vault token: %v", err)
|
||||
}
|
||||
if secret == nil || secret.WrapInfo == nil || secret.WrapInfo.Token == "" ||
|
||||
secret.WrapInfo.WrappedAccessor == "" {
|
||||
return "", fmt.Errorf("failed to derive a wrapped vault token")
|
||||
}
|
||||
|
||||
wrappedToken := secret.WrapInfo.Token
|
||||
|
||||
// Unwrap the vault token
|
||||
unwrapResp, err := c.client.Logical().Unwrap(wrappedToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unwrap the token: %v", err)
|
||||
}
|
||||
if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" {
|
||||
return "", fmt.Errorf("failed to unwrap the token")
|
||||
// DeriveVaultToken of nomad server can take in a set of tasks and
|
||||
// creates tokens for all the tasks.
|
||||
req := &structs.DeriveVaultTokenRequest{
|
||||
NodeID: c.node.ID,
|
||||
SecretID: c.node.SecretID,
|
||||
AllocID: alloc.ID,
|
||||
Tasks: verifiedTasks,
|
||||
QueryOptions: structs.QueryOptions{
|
||||
Region: c.region,
|
||||
AllowStale: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Return the unwrapped token
|
||||
return unwrapResp.Auth.ClientToken, nil
|
||||
// Derive the tokens
|
||||
var resp structs.DeriveVaultTokenResponse
|
||||
if err := c.RPC("Node.DeriveVaultToken", &req, &resp); err != nil {
|
||||
c.logger.Printf("[ERR] client/vaultclient: failed to derive vault tokens: %v", err)
|
||||
result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: %v", err))
|
||||
return nil, result
|
||||
}
|
||||
if resp.Tasks == nil {
|
||||
c.logger.Printf("[ERR] client/vaultclient: failed to derive vault token: invalid response")
|
||||
result = multierror.Append(result, fmt.Errorf("failed to derive vault tokens: invalid response"))
|
||||
return nil, result
|
||||
}
|
||||
|
||||
unwrappedTokens := make(map[string]string)
|
||||
|
||||
// Retrieve the wrapped tokens from the response and unwrap it using
|
||||
// the VaultClient's token, which is cached at the API client.
|
||||
for _, taskName := range verifiedTasks {
|
||||
// Get the wrapped token
|
||||
wrappedToken, ok := resp.Tasks[taskName]
|
||||
if !ok {
|
||||
c.logger.Printf("[ERR] client/vaultclient: wrapped token missing for task %q", taskName)
|
||||
result = multierror.Append(result, fmt.Errorf("wrapped token missing for task %q", taskName))
|
||||
return nil, result
|
||||
}
|
||||
|
||||
// Unwrap the vault token
|
||||
unwrapResp, err := c.client.Logical().Unwrap(wrappedToken)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q: %v", taskName, err))
|
||||
return nil, result
|
||||
}
|
||||
if unwrapResp == nil || unwrapResp.Auth == nil || unwrapResp.Auth.ClientToken == "" {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to unwrap the token for task %q", taskName))
|
||||
return nil, result
|
||||
}
|
||||
|
||||
// Append the unwrapped token to the return value
|
||||
unwrappedTokens[taskName] = unwrapResp.Auth.ClientToken
|
||||
}
|
||||
|
||||
return unwrappedTokens, nil
|
||||
}
|
||||
|
||||
// GetConsulACL creates a vault API client and reads from vault a consul ACL
|
||||
// token used by the task.
|
||||
func (c *vaultClient) GetConsulACL(token, vaultPath string) (*vaultapi.Secret, error) {
|
||||
func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) {
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("missing token")
|
||||
}
|
||||
if vaultPath == "" {
|
||||
return nil, fmt.Errorf("missing vault path")
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("missing consul ACL token vault path")
|
||||
}
|
||||
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
// Use the token supplied to interact with vault
|
||||
c.client.SetToken(token)
|
||||
|
||||
// Restore the token in client to VaultClient's token
|
||||
defer c.client.SetToken(c.config.Token)
|
||||
|
||||
// Read the consul ACL token and return the secret directly
|
||||
return c.client.Logical().Read(vaultPath)
|
||||
return c.client.Logical().Read(path)
|
||||
}
|
||||
|
||||
// RenewToken renews the supplied token and adds it to the min-heap so that it
|
||||
@@ -411,6 +513,10 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
|
||||
errCh <- fmt.Errorf("missing token")
|
||||
return errCh
|
||||
}
|
||||
if increment < 1 {
|
||||
errCh <- fmt.Errorf("increment cannot be less than 1")
|
||||
return errCh
|
||||
}
|
||||
|
||||
// Create a renewal request and indicate that the identifier in the
|
||||
// request is a token and not a lease
|
||||
@@ -436,8 +542,8 @@ func (c *vaultClient) RenewToken(token string, increment int) <-chan error {
|
||||
// channel and the channel is returned instead of an actual error. This helps
|
||||
// the caller be notified of a renewal failure asynchronously for appropriate
|
||||
// actions to be taken.
|
||||
func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error {
|
||||
c.logger.Printf("[INFO] vaultclient: renewing lease %q", leaseId)
|
||||
func (c *vaultClient) RenewLease(leaseId string, increment int) <-chan error {
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: renewing lease %q", leaseId)
|
||||
// Create a buffered error channel
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
@@ -446,8 +552,8 @@ func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error
|
||||
return errCh
|
||||
}
|
||||
|
||||
if leaseDuration == 0 {
|
||||
errCh <- fmt.Errorf("missing lease duration")
|
||||
if increment < 1 {
|
||||
errCh <- fmt.Errorf("increment cannot be less than 1")
|
||||
return errCh
|
||||
}
|
||||
|
||||
@@ -455,7 +561,7 @@ func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error
|
||||
renewalReq := &vaultClientRenewalRequest{
|
||||
errCh: make(chan error, 1),
|
||||
id: leaseId,
|
||||
increment: leaseDuration,
|
||||
increment: increment,
|
||||
}
|
||||
|
||||
// Renew the secret and send any error to the dedicated error channel
|
||||
@@ -467,12 +573,11 @@ func (c *vaultClient) RenewLease(leaseId string, leaseDuration int) <-chan error
|
||||
}
|
||||
|
||||
// renew is a common method to handle renewal of both tokens and secret leases.
|
||||
// It creates a vault API client and invokes either a token renewal request or
|
||||
// a secret renewal request. If renewal is successful, min-heap is updated
|
||||
// based on the duration after which it needs its renewal again. The duration
|
||||
// is set to half the lease duration present in the renewal response.
|
||||
// It invokes a token renewal or a secret's lease renewal. If renewal is
|
||||
// successful, min-heap is updated based on the duration after which it needs
|
||||
// renewal again. The next renewal time is randomly selected to avoid spikes in
|
||||
// the number of APIs periodically.
|
||||
func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
|
||||
c.logger.Printf("[INFO] vaultclient: ~~~~~~~Renewing %s~~~~~~~~", req.id)
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
@@ -486,11 +591,12 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
|
||||
if req.id == "" {
|
||||
return fmt.Errorf("missing id in renewal request")
|
||||
}
|
||||
if req.increment == 0 {
|
||||
return fmt.Errorf("missing increment in renewal request")
|
||||
if req.increment < 1 {
|
||||
return fmt.Errorf("increment cannot be less than 1")
|
||||
}
|
||||
|
||||
var duration time.Duration
|
||||
var renewalErr error
|
||||
leaseDuration := req.increment
|
||||
if req.isToken {
|
||||
// Reset the token in the API client to that of VaultClient
|
||||
// before returning
|
||||
@@ -503,30 +609,45 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
|
||||
// Renew the token
|
||||
renewResp, err := c.client.Auth().Token().RenewSelf(req.increment)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to renew the vault token: %v", err)
|
||||
renewalErr = fmt.Errorf("failed to renew the vault token: %v", err)
|
||||
}
|
||||
if renewResp == nil || renewResp.Auth == nil {
|
||||
return fmt.Errorf("failed to renew the vault token")
|
||||
renewalErr = fmt.Errorf("failed to renew the vault token")
|
||||
} else {
|
||||
// Don't set this if renewal fails
|
||||
leaseDuration = renewResp.Auth.LeaseDuration
|
||||
}
|
||||
|
||||
// Set the next renewal time to half the lease duration
|
||||
duration = time.Duration(renewResp.Auth.LeaseDuration) * time.Second / 2
|
||||
} else {
|
||||
// Renew the secret
|
||||
renewResp, err := c.client.Sys().Renew(req.id, req.increment)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to renew vault secret: %v", err)
|
||||
renewalErr = fmt.Errorf("failed to renew vault secret: %v", err)
|
||||
}
|
||||
if renewResp == nil {
|
||||
return fmt.Errorf("failed to renew vault secret")
|
||||
renewalErr = fmt.Errorf("failed to renew vault secret")
|
||||
} else {
|
||||
// Don't set this if renewal fails
|
||||
leaseDuration = renewResp.LeaseDuration
|
||||
}
|
||||
|
||||
// Set the next renewal time to half the lease duration
|
||||
duration = time.Duration(renewResp.LeaseDuration) * time.Second / 2
|
||||
}
|
||||
|
||||
duration := leaseDuration / 2
|
||||
switch {
|
||||
case leaseDuration < 30:
|
||||
// Don't bother about introducing randomness if the
|
||||
// leaseDuration is too small.
|
||||
default:
|
||||
// Give a breathing space of 20 seconds
|
||||
min := 10
|
||||
max := leaseDuration - min
|
||||
rand.Seed(time.Now().Unix())
|
||||
duration = min + rand.Intn(max-min)
|
||||
}
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: req.increment: %d, leaseDuration: %d, duration: %d",
|
||||
req.increment, leaseDuration, duration)
|
||||
|
||||
// Determine the next renewal time
|
||||
next := time.Now().Add(duration)
|
||||
next := time.Now().Add(time.Duration(duration) * time.Second)
|
||||
|
||||
if c.IsTracked(req.id) {
|
||||
// If the identifier is already tracked, this indicates a
|
||||
@@ -556,14 +677,15 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// Returning the renewal error here ensures that an entry is either
|
||||
// added or updated in the min-heap. This is done to not starve other
|
||||
// entries in heap.
|
||||
return renewalErr
|
||||
}
|
||||
|
||||
// run is the renewal loop which performs the periodic renewals of both the
|
||||
// tokens and the secret leases.
|
||||
func (c *vaultClient) run() {
|
||||
var renewalCh <-chan time.Time
|
||||
|
||||
if !c.config.Enabled {
|
||||
return
|
||||
}
|
||||
@@ -572,6 +694,7 @@ func (c *vaultClient) run() {
|
||||
c.running = true
|
||||
c.lock.Unlock()
|
||||
|
||||
var renewalCh <-chan time.Time
|
||||
for c.config.Enabled && c.running {
|
||||
// Fetches the candidate for next renewal
|
||||
renewalReq, renewalTime := c.nextRenewal()
|
||||
@@ -604,7 +727,7 @@ func (c *vaultClient) run() {
|
||||
case <-c.updateCh:
|
||||
continue
|
||||
case <-c.stopCh:
|
||||
c.logger.Printf("[INFO] vaultclient: stopped")
|
||||
c.logger.Printf("[DEBUG] client/vaultclient: stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -672,6 +795,27 @@ func (c *vaultClient) nextRenewal() (*vaultClientRenewalRequest, time.Time) {
|
||||
return nextEntry.req, nextEntry.next
|
||||
}
|
||||
|
||||
// RPC is used to forward an RPC call to a nomad server, or fail if no servers
|
||||
func (c *vaultClient) RPC(method string, args interface{}, reply interface{}) error {
|
||||
// Invoke the RPCHandler if it exists
|
||||
if c.rpcHandler != nil {
|
||||
return c.rpcHandler.RPC(method, args, reply)
|
||||
}
|
||||
|
||||
// Pick a server to request from
|
||||
server := c.rpcProxy.FindServer()
|
||||
if server == nil {
|
||||
return fmt.Errorf("no known servers")
|
||||
}
|
||||
|
||||
// Make the RPC request
|
||||
if err := c.connPool.RPC(c.region, server.Addr, structs.ApiMajorVersion, method, args, reply); err != nil {
|
||||
c.rpcProxy.NotifyFailedServer(server)
|
||||
return fmt.Errorf("RPC failed to server %s: %v", server.Addr, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Additional helper functions on top of interface methods
|
||||
|
||||
// Length returns the number of elements in the heap
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/client/rpcproxy"
|
||||
"github.com/hashicorp/nomad/nomad"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
"github.com/hashicorp/nomad/testutil"
|
||||
vaultapi "github.com/hashicorp/vault/api"
|
||||
)
|
||||
@@ -17,8 +20,12 @@ func TestVaultClient_EstablishConnection(t *testing.T) {
|
||||
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
|
||||
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
|
||||
v.Config.TaskTokenTTL = "10s"
|
||||
|
||||
c, err := NewVaultClient(v.Config, logger)
|
||||
node := &structs.Node{}
|
||||
connPool := &nomad.ConnPool{}
|
||||
rpcProxy := &rpcproxy.RPCProxy{}
|
||||
var rpcHandler config.RPCHandler
|
||||
c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler,
|
||||
connPool, rpcProxy)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
@@ -48,11 +55,15 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
|
||||
v := testutil.NewTestVault(t).Start()
|
||||
defer v.Stop()
|
||||
|
||||
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
|
||||
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
|
||||
v.Config.TaskTokenTTL = "10s"
|
||||
|
||||
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
|
||||
c, err := NewVaultClient(v.Config, logger)
|
||||
node := &structs.Node{}
|
||||
connPool := &nomad.ConnPool{}
|
||||
rpcProxy := &rpcproxy.RPCProxy{}
|
||||
var rpcHandler config.RPCHandler
|
||||
c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler,
|
||||
connPool, rpcProxy)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build vault client: %v", err)
|
||||
}
|
||||
@@ -61,7 +72,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
|
||||
defer c.Stop()
|
||||
|
||||
// Sleep a little while to ensure that the renewal loop is active
|
||||
time.Sleep(2 * time.Second)
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
tcr := &vaultapi.TokenCreateRequest{
|
||||
Policies: []string{"foo", "bar"},
|
||||
@@ -71,7 +82,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
|
||||
}
|
||||
*tcr.Renewable = true
|
||||
|
||||
num := 10
|
||||
num := 5
|
||||
tokens := make([]string, num)
|
||||
for i := 0; i < num; i++ {
|
||||
c.client.SetToken(v.Config.Token)
|
||||
@@ -107,7 +118,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
|
||||
t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length())
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
for i := 0; i < num; i++ {
|
||||
if err := c.StopRenewToken(tokens[i]); err != nil {
|
||||
@@ -122,14 +133,23 @@ func TestVaultClient_TokenRenewals(t *testing.T) {
|
||||
|
||||
func TestVaultClient_Heap(t *testing.T) {
|
||||
conf := config.DefaultConfig()
|
||||
conf.VaultConfig.Enabled = true
|
||||
conf.VaultConfig.Token = "testvaulttoken"
|
||||
conf.VaultConfig.TaskTokenTTL = "10s"
|
||||
|
||||
logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags)
|
||||
c, err := NewVaultClient(conf.VaultConfig, logger)
|
||||
node := &structs.Node{}
|
||||
connPool := &nomad.ConnPool{}
|
||||
rpcProxy := &rpcproxy.RPCProxy{}
|
||||
var rpcHandler config.RPCHandler
|
||||
c, err := NewVaultClient(node, "global", conf.VaultConfig, logger, rpcHandler,
|
||||
connPool, rpcProxy)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c == nil {
|
||||
t.Fatal("failed to create vault client")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user