diff --git a/client/client.go b/client/client.go index da4dd6469..37224cbf4 100644 --- a/client/client.go +++ b/client/client.go @@ -1311,7 +1311,8 @@ func (c *Client) setupVaultClient() error { } var err error - if c.vaultClient, err = vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.tokenDeriver); err != nil { + if c.vaultClient, err = + vaultclient.NewVaultClient(c.config.VaultConfig, c.logger, c.deriveToken); err != nil { return err } @@ -1323,7 +1324,10 @@ func (c *Client) setupVaultClient() error { return nil } -func (c *Client) tokenDeriver(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) { +// deriveToken takes in an allocation and a set of tasks and derives vault +// tokens for each of the tasks, unwraps all of them using the supplied vault +// client and returns a map of unwrapped tokens, indexed by the task name. +func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) { if alloc == nil { return nil, fmt.Errorf("nil allocation") } diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index 38fdc613c..3ff1b128b 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -7,9 +7,6 @@ import ( "time" "github.com/hashicorp/nomad/client/config" - "github.com/hashicorp/nomad/client/rpcproxy" - "github.com/hashicorp/nomad/nomad" - "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" vaultapi "github.com/hashicorp/vault/api" ) @@ -20,12 +17,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) { logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) v.Config.ConnectionRetryIntv = 100 * time.Millisecond v.Config.TaskTokenTTL = "10s" - node := &structs.Node{} - connPool := &nomad.ConnPool{} - rpcProxy := &rpcproxy.RPCProxy{} - var rpcHandler config.RPCHandler - c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler, - connPool, rpcProxy) + c, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } @@ -58,12 +50,7 @@ func TestVaultClient_TokenRenewals(t *testing.T) { logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) v.Config.ConnectionRetryIntv = 100 * time.Millisecond v.Config.TaskTokenTTL = "10s" - node := &structs.Node{} - connPool := &nomad.ConnPool{} - rpcProxy := &rpcproxy.RPCProxy{} - var rpcHandler config.RPCHandler - c, err := NewVaultClient(node, "global", v.Config, logger, rpcHandler, - connPool, rpcProxy) + c, err := NewVaultClient(v.Config, logger, nil) if err != nil { t.Fatalf("failed to build vault client: %v", err) } @@ -138,12 +125,7 @@ func TestVaultClient_Heap(t *testing.T) { conf.VaultConfig.TaskTokenTTL = "10s" logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) - node := &structs.Node{} - connPool := &nomad.ConnPool{} - rpcProxy := &rpcproxy.RPCProxy{} - var rpcHandler config.RPCHandler - c, err := NewVaultClient(node, "global", conf.VaultConfig, logger, rpcHandler, - connPool, rpcProxy) + c, err := NewVaultClient(conf.VaultConfig, logger, nil) if err != nil { t.Fatal(err) }