diff --git a/demo/vagrant/server.hcl b/demo/vagrant/server.hcl index 70ad75d7e..653b2c037 100644 --- a/demo/vagrant/server.hcl +++ b/demo/vagrant/server.hcl @@ -4,13 +4,6 @@ log_level = "DEBUG" # Setup data dir data_dir = "/tmp/server1" -vault { - enabled = true - # address = "127.0.0.1:8200" - token_role_name = "foobar" - # periodic_token = "09e54c4d-a9b6-f1b8-fb41-e87a263d4da9" -} - # Enable the server server { enabled = true diff --git a/nomad/server_test.go b/nomad/server_test.go index 0312ae0a4..b28058bab 100644 --- a/nomad/server_test.go +++ b/nomad/server_test.go @@ -57,6 +57,9 @@ func testServer(t *testing.T, cb func(*Config)) *Server { config.RaftConfig.ElectionTimeout = 50 * time.Millisecond config.RaftTimeout = 500 * time.Millisecond + // Disable Vault + config.VaultConfig.Enabled = false + // Invoke the callback if any if cb != nil { cb(config) diff --git a/nomad/structs/config/vault.go b/nomad/structs/config/vault.go index f2eae626a..c93de23d4 100644 --- a/nomad/structs/config/vault.go +++ b/nomad/structs/config/vault.go @@ -1,6 +1,16 @@ package config -import vault "github.com/hashicorp/vault/api" +import ( + "time" + + vault "github.com/hashicorp/vault/api" +) + +const ( + // DefaultVaultConnectRetryIntv is the retry interval between trying to + // connect to Vault + DefaultVaultConnectRetryIntv = 30 * time.Second +) // VaultConfig contains the configuration information necessary to // communicate with Vault in order to: @@ -30,9 +40,14 @@ type VaultConfig struct { // does not have to renew with Vault at a very high frequency TaskTokenTTL string `mapstructure:"task_token_ttl"` - // Addr is the address of the local Vault agent + // Addr is the address of the local Vault agent. This should be a complete + // URL such as "http://vault.example.com" Addr string `mapstructure:"address"` + // ConnectionRetryIntv is the interval to wait before re-attempting to + // connect to Vault. 
+ ConnectionRetryIntv time.Duration + // TLSCaFile is the path to a PEM-encoded CA cert file to use to verify the // Vault server SSL certificate. TLSCaFile string `mapstructure:"tls_ca_file"` @@ -58,9 +73,10 @@ type VaultConfig struct { // `vault` configuration. func DefaultVaultConfig() *VaultConfig { return &VaultConfig{ - Enabled: false, + Enabled: true, AllowUnauthenticated: false, - Addr: "vault.service.consul:8200", + Addr: "https://vault.service.consul:8200", + ConnectionRetryIntv: DefaultVaultConnectRetryIntv, } } @@ -77,6 +93,9 @@ func (a *VaultConfig) Merge(b *VaultConfig) *VaultConfig { if b.Addr != "" { result.Addr = b.Addr } + if b.ConnectionRetryIntv.Nanoseconds() != 0 { + result.ConnectionRetryIntv = b.ConnectionRetryIntv + } if b.TLSCaFile != "" { result.TLSCaFile = b.TLSCaFile } @@ -100,16 +119,9 @@ func (a *VaultConfig) Merge(b *VaultConfig) *VaultConfig { } // ApiConfig() returns a usable Vault config that can be passed directly to -// hashicorp/vault/api. If readEnv is true, the environment is read for -// appropriate Vault variables. -func (c *VaultConfig) ApiConfig(readEnv bool) (*vault.Config, error) { +// hashicorp/vault/api. +func (c *VaultConfig) ApiConfig() (*vault.Config, error) { conf := vault.DefaultConfig() - if readEnv { - if err := conf.ReadEnvironment(); err != nil { - return nil, err - } - } - tlsConf := &vault.TLSConfig{ CACert: c.TLSCaFile, CAPath: c.TLSCaPath, @@ -122,6 +134,7 @@ func (c *VaultConfig) ApiConfig(readEnv bool) (*vault.Config, error) { return nil, err } + conf.Address = c.Addr return conf, nil } diff --git a/nomad/vault.go b/nomad/vault.go index 48324449f..04d0cd939 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -36,6 +36,17 @@ type VaultClient interface { Stop() } +// tokenData holds the relevant information about the Vault token passed to the +// client. 
+type tokenData struct { + CreationTTL int `mapstructure:"creation_ttl"` + TTL int `mapstructure:"ttl"` + Renewable bool `mapstructure:"renewable"` + Policies []string `mapstructure:"policies"` + Role string `mapstructure:"role"` + Root bool +} + // vaultClient is the Servers implementation of the VaultClient interface. The // client renews the PeriodicToken given in the Vault configuration and provides // the Server with the ability to create child tokens and lookup the permissions @@ -43,31 +54,43 @@ type VaultClient interface { type vaultClient struct { // client is the Vault API client client *vapi.Client - auth *vapi.TokenAuth - logger *log.Logger - // running returns whether the renewal goroutine is running - running bool - shutdownCh chan struct{} - l sync.Mutex + // auth is the Vault token auth API client + auth *vapi.TokenAuth + + // config is the user passed Vault config + config *config.VaultConfig + + // renewalRunning marks whether the renewal goroutine is running + renewalRunning bool + + // establishingConn marks whether we are trying to establishe a connection to Vault + establishingConn bool + + // connEstablished marks whether we have an established connection to Vault + connEstablished bool + + // tokenData is the data of the passed Vault token + token *tokenData // enabled indicates whether the vaultClient is enabled. If it is not the // token lookup and create methods will return errors. enabled bool - // tokenRole is the role in which child tokens will be created from. - tokenRole string - // childTTL is the TTL for child tokens. childTTL string - // leaseDuration is the lease duration of our token in seconds - leaseDuration int - // lastRenewed is the time the token was last renewed lastRenewed time.Time + + shutdownCh chan struct{} + l sync.Mutex + logger *log.Logger } +// NewVaultClient returns a Vault client from the given config. If the client +// couldn't be made an error is returned. 
If an error is not returned, Shutdown +// is expected to be called to clean up any created goroutine func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, error) { if c == nil { return nil, fmt.Errorf("must pass valid VaultConfig") @@ -78,9 +101,9 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er } v := &vaultClient{ - enabled: c.Enabled, - tokenRole: c.TokenRoleName, - logger: logger, + enabled: c.Enabled, + config: c, + logger: logger, } // If vault is not enabled do not configure an API client or start any token @@ -89,23 +112,29 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er return v, nil } + // Validate we have the required fields. + if c.Token == "" { + return nil, errors.New("Vault token must be set") + } else if c.Addr == "" { + return nil, errors.New("Vault address must be set") + } + // Parse the TTL if it is set - if c.ChildTokenTTL != "" { - d, err := time.ParseDuration(c.ChildTokenTTL) + if c.TaskTokenTTL != "" { + d, err := time.ParseDuration(c.TaskTokenTTL) if err != nil { - return nil, fmt.Errorf("failed to parse ChildTokenTTL %q: %v", c.ChildTokenTTL, err) + return nil, fmt.Errorf("failed to parse TaskTokenTTL %q: %v", c.TaskTokenTTL, err) } - // TODO this should be a config validation problem as well if d.Nanoseconds() < minimumTokenTTL.Nanoseconds() { return nil, fmt.Errorf("ChildTokenTTL is less than minimum allowed of %v", minimumTokenTTL) } - v.childTTL = c.ChildTokenTTL + v.childTTL = c.TaskTokenTTL } // Get the Vault API configuration - apiConf, err := c.ApiConfig(true) + apiConf, err := c.ApiConfig() if err != nil { return nil, fmt.Errorf("Failed to create Vault API config: %v", err) } @@ -113,96 +142,84 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger) (*vaultClient, er // Create the Vault API client client, err := vapi.NewClient(apiConf) if err != nil { - return nil, fmt.Errorf("Failed to create Vault API client: %v", err) + 
v.logger.Printf("[ERR] vault: failed to create Vault client. Not retrying: %v", err) + return nil, err } - // Set the wrapping function such that token creation is wrapped - client.SetWrappingLookupFunc(v.getWrappingFn()) - // Set the token and store the client - if client.Token() == "" { - client.SetToken(c.PeriodicToken) - } - + client.SetToken(v.config.Token) v.client = client v.auth = client.Auth().Token() - // Validate we have the required fields. This is done after we create the - // client since these fields can be read from environment variables - if c.TokenRoleName == "" { - return nil, errors.New("Vault token role name must be set in config") - } else if client.Token() == "" { - return nil, errors.New("Vault periodic token must be set in config or in $VAULT_TOKEN") - } else if apiConf.Address == "" { - return nil, errors.New("Vault address must be set in config or in $VAULT_ADDR") - } - - // Retrieve our token, validate it and parse the lease duration - leaseDuration, err := v.parseSelfToken() - if err != nil { - return nil, err - } - v.leaseDuration = leaseDuration - - v.logger.Printf("[DEBUG] vault: token lease duration is %v", - time.Duration(v.leaseDuration)*time.Second) - // Prepare and launch the token renewal goroutine v.shutdownCh = make(chan struct{}) - v.running = true - go v.run() - + go v.establishConnection() return v, nil } -// getWrappingFn returns an appropriate wrapping function for Nomad Servers -func (v *vaultClient) getWrappingFn() func(operation, path string) string { - createPath := fmt.Sprintf("auth/token/create/%s", v.tokenRole) - return func(operation, path string) string { - // Only wrap the token create operation - if operation != "POST" || path != createPath { - return "" +// establishConnection is used to make first contact with Vault. This should be +// called in a go-routine since the connection is retried til the Vault Client +// is stopped or the connection is successfully made at which point the renew +// loop is started. 
+func (v *vaultClient) establishConnection() { + v.l.Lock() + v.establishingConn = true + v.l.Unlock() + + // Create the retry timer and set initial duration to zero so it fires + // immediately + retryTimer := time.NewTimer(0) + +OUTER: + for { + select { + case <-v.shutdownCh: + return + case <-retryTimer.C: + // Ensure the API is reachable + if _, err := v.client.Sys().InitStatus(); err != nil { + v.logger.Printf("[WARN] vault: failed to contact Vault API. Retrying in %v", + v.config.ConnectionRetryIntv) + retryTimer.Reset(v.config.ConnectionRetryIntv) + continue OUTER + } + + break OUTER } + } - return vaultTokenCreateTTL + v.l.Lock() + v.connEstablished = true + v.establishingConn = false + v.l.Unlock() + + // Retrieve our token, validate it and parse the lease duration + if err := v.parseSelfToken(); err != nil { + v.logger.Printf("[ERR] vault: failed to lookup self token and not retrying: %v", err) + return + } + + // Set the wrapping function such that token creation is wrapped now + // that we know our role + v.client.SetWrappingLookupFunc(v.getWrappingFn()) + + // If we are given a non-root token, start renewing it + if v.token.Root { + v.logger.Printf("[DEBUG] vault: not renewing token as it is root") + } else { + v.logger.Printf("[DEBUG] vault: token lease duration is %v", + time.Duration(v.token.CreationTTL)*time.Second) + go v.renewalLoop() } } -func (v *vaultClient) parseSelfToken() (int, error) { - // Get the initial lease duration - auth := v.client.Auth().Token() - self, err := auth.LookupSelf() - if err != nil { - return 0, fmt.Errorf("failed to lookup Vault periodic token: %v", err) - } +// renewalLoop runs the renew loop. This should only be called if we are given a +// non-root token. 
+func (v *vaultClient) renewalLoop() { + v.l.Lock() + v.renewalRunning = true + v.l.Unlock() - // Read and parse the fields - var data struct { - CreationTTL int `mapstructure:"creation_ttl"` - TTL int `mapstructure:"ttl"` - Renewable bool `mapstructure:"renewable"` - } - if err := mapstructure.WeakDecode(self.Data, &data); err != nil { - return 0, fmt.Errorf("failed to parse Vault token's data block: %v", err) - } - - if !data.Renewable { - return 0, fmt.Errorf("Vault token is not renewable") - } - - if data.CreationTTL == 0 { - return 0, fmt.Errorf("invalid lease duration of zero") - } - - if data.TTL == 0 { - return 0, fmt.Errorf("token TTL is zero") - } - - return data.CreationTTL, nil -} - -// run runs the renew loop -func (v *vaultClient) run() { // Create the renewal timer and set initial duration to zero so it fires // immediately authRenewTimer := time.NewTimer(0) @@ -216,8 +233,9 @@ func (v *vaultClient) run() { case <-v.shutdownCh: return case <-authRenewTimer.C: + // Renew the token and determine the new expiration err := v.renew() - currentExpiration := v.lastRenewed.Add(time.Duration(v.leaseDuration) * time.Second) + currentExpiration := v.lastRenewed.Add(time.Duration(v.token.CreationTTL) * time.Second) // Successfully renewed if err == nil { @@ -266,8 +284,8 @@ func (v *vaultClient) run() { v.l.Lock() defer v.l.Unlock() v.logger.Printf("[ERR] vault: failed to renew Vault token before lease expiration. Renew loop exiting") - if v.running { - v.running = false + if v.renewalRunning { + v.renewalRunning = false close(v.shutdownCh) } @@ -285,9 +303,11 @@ func (v *vaultClient) run() { } } +// renew attempts to renew our Vault token. If the renewal fails, an error is +// returned. 
This method updates the lastRenewed time func (v *vaultClient) renew() error { // Attempt to renew the token - secret, err := v.auth.RenewSelf(v.leaseDuration) + secret, err := v.auth.RenewSelf(v.token.CreationTTL) if err != nil { return err } @@ -304,7 +324,71 @@ func (v *vaultClient) renew() error { return nil } -// Stop stops token renewal. +// getWrappingFn returns an appropriate wrapping function for Nomad Servers +func (v *vaultClient) getWrappingFn() func(operation, path string) string { + createPath := "auth/token/create" + if !v.token.Root { + createPath = fmt.Sprintf("auth/token/create/%s", v.token.Role) + } + + return func(operation, path string) string { + // Only wrap the token create operation + if operation != "POST" || path != createPath { + return "" + } + + return vaultTokenCreateTTL + } +} + +// parseSelfToken looks up the Vault token in Vault and parses its data storing +// it in the client. If the token is not valid for Nomads purposes an error is +// returned. +func (v *vaultClient) parseSelfToken() error { + // Get the initial lease duration + auth := v.client.Auth().Token() + self, err := auth.LookupSelf() + if err != nil { + return fmt.Errorf("failed to lookup Vault periodic token: %v", err) + } + + // Read and parse the fields + var data tokenData + if err := mapstructure.WeakDecode(self.Data, &data); err != nil { + return fmt.Errorf("failed to parse Vault token's data block: %v", err) + } + + root := false + for _, p := range data.Policies { + if p == "root" { + root = true + break + } + } + + if !data.Renewable && !root { + return fmt.Errorf("Vault token is not renewable or root") + } + + if data.CreationTTL == 0 && !root { + return fmt.Errorf("invalid lease duration of zero") + } + + if data.TTL == 0 && !root { + return fmt.Errorf("token TTL is zero") + } + + if !root && data.Role == "" { + return fmt.Errorf("token role name must be set when not using a root token") + } + + data.Root = root + v.token = &data + return nil +} + +// Stop 
stops any goroutine that may be running, either for establishing a Vault
+// connection or token renewal.
 func (v *vaultClient) Stop() {
 	// Nothing to do
 	if !v.enabled {
@@ -313,12 +397,21 @@ func (v *vaultClient) Stop() {
 	v.l.Lock()
 	defer v.l.Unlock()
 
-	if !v.running {
+	if !v.renewalRunning && !v.establishingConn {
 		return
 	}
 
 	close(v.shutdownCh)
-	v.running = false
+	v.renewalRunning = false
+	v.establishingConn = false
+}
+
+// ConnectionEstablished returns whether a connection to Vault has been
+// established.
+func (v *vaultClient) ConnectionEstablished() bool {
+	v.l.Lock()
+	defer v.l.Unlock()
+	return v.connEstablished
 }
 
 func (v *vaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) {
diff --git a/nomad/vault_test.go b/nomad/vault_test.go
new file mode 100644
index 000000000..4b8abdd9e
--- /dev/null
+++ b/nomad/vault_test.go
@@ -0,0 +1,81 @@
+package nomad
+
+import (
+	"log"
+	"os"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/hashicorp/nomad/nomad/structs/config"
+	"github.com/hashicorp/nomad/testutil"
+)
+
+func TestVaultClient_BadConfig(t *testing.T) {
+	conf := &config.VaultConfig{}
+	logger := log.New(os.Stderr, "", log.LstdFlags)
+
+	// Should be no error since Vault is not enabled
+	client, err := NewVaultClient(conf, logger)
+	if err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+
+	if client.ConnectionEstablished() {
+		t.Fatalf("bad")
+	}
+
+	conf.Enabled = true
+	_, err = NewVaultClient(conf, logger)
+	if err == nil || !strings.Contains(err.Error(), "token must be set") {
+		t.Fatalf("Expected token unset error: %v", err)
+	}
+
+	conf.Token = "123"
+	_, err = NewVaultClient(conf, logger)
+	if err == nil || !strings.Contains(err.Error(), "address must be set") {
+		t.Fatalf("Expected address unset error: %v", err)
+	}
+}
+
+// Test that the Vault Client can establish a connection even if it is started
+// before Vault is available.
+func TestVaultClient_EstablishConnection(t *testing.T) {
+	v := testutil.NewTestVault(t)
+	defer v.Stop()
+
+	logger := log.New(os.Stderr, "", log.LstdFlags)
+	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
+	client, err := NewVaultClient(v.Config, logger)
+	if err != nil {
+		t.Fatalf("failed to build vault client: %v", err)
+	}
+
+	// Sleep a little while and check that no connection has been established.
+	time.Sleep(100 * time.Duration(testutil.TestMultiplier()) * time.Millisecond)
+
+	if client.ConnectionEstablished() {
+		t.Fatalf("ConnectionEstablished() returned true before Vault server started")
+	}
+
+	// Start Vault
+	v.Start()
+
+	testutil.WaitForResult(func() (bool, error) {
+		return client.ConnectionEstablished(), nil
+	}, func(err error) {
+		t.Fatalf("Connection not established")
+	})
+
+	// Ensure that since we are using a root token that we haven't started the
+	// renewal loop.
+	if client.renewalRunning {
+		t.Fatalf("No renewal loop should be running")
+	}
+}
+
+func TestVaultClient_RenewalLoop(t *testing.T) {
+	v := testutil.NewTestVault(t).Start()
+	defer v.Stop()
+
+}
diff --git a/testutil/vault.go b/testutil/vault.go
new file mode 100644
index 000000000..1899c87ea
--- /dev/null
+++ b/testutil/vault.go
@@ -0,0 +1,120 @@
+package testutil
+
+import (
+	"fmt"
+	"os"
+	"os/exec"
+	"testing"
+
+	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/hashicorp/nomad/nomad/structs/config"
+	vapi "github.com/hashicorp/vault/api"
+)
+
+// TestVault is a test helper. It uses a fork/exec model to create a test Vault
+// server instance in the background and can be initialized with policies, roles
+// and backends mounted. The test Vault instances can be used to run a unit test
+// and offers an easy API to tear itself down on test end. The only
+// prerequisite is that the Vault binary is on the $PATH.
+
+const (
+	// vaultStartPort is the starting port we use to bind Vault servers to
+	vaultStartPort uint64 = 40000
+)
+
+// vaultPortOffset is added to vaultStartPort by getPort to pick the next port.
+var vaultPortOffset uint64
+
+// TestVault wraps a test Vault server launched in dev mode, suitable for
+// testing.
+type TestVault struct {
+	cmd *exec.Cmd
+	t   *testing.T
+
+	Addr      string
+	HTTPAddr  string
+	RootToken string
+	Config    *config.VaultConfig
+	Client    *vapi.Client
+}
+
+// NewTestVault returns a new TestVault instance that has yet to be started
+func NewTestVault(t *testing.T) *TestVault {
+	port := getPort()
+	token := structs.GenerateUUID()
+	bind := fmt.Sprintf("-dev-listen-address=127.0.0.1:%d", port)
+	http := fmt.Sprintf("http://127.0.0.1:%d", port)
+	root := fmt.Sprintf("-dev-root-token-id=%s", token)
+
+	cmd := exec.Command("vault", "server", "-dev", bind, root)
+	cmd.Stdout = os.Stdout
+	cmd.Stderr = os.Stderr
+
+	// Build the config
+	conf := vapi.DefaultConfig()
+	conf.Address = http
+
+	// Make the client and set the token to the root token
+	client, err := vapi.NewClient(conf)
+	if err != nil {
+		t.Fatalf("failed to build Vault API client: %v", err)
+	}
+	client.SetToken(token)
+
+	tv := &TestVault{
+		cmd:       cmd,
+		t:         t,
+		Addr:      bind,
+		HTTPAddr:  http,
+		RootToken: token,
+		Client:    client,
+		Config: &config.VaultConfig{
+			Enabled: true,
+			Token:   token,
+			Addr:    http,
+		},
+	}
+
+	return tv
+}
+
+// Start starts the test Vault server and waits for it to respond to its HTTP
+// API
+func (tv *TestVault) Start() *TestVault {
+	if err := tv.cmd.Start(); err != nil {
+		tv.t.Fatalf("failed to start vault: %v", err)
+	}
+
+	tv.waitForAPI()
+	return tv
+}
+
+// Stop stops the test Vault server
+func (tv *TestVault) Stop() {
+	if err := tv.cmd.Process.Kill(); err != nil {
+		tv.t.Errorf("err: %s", err)
+	}
+	tv.cmd.Wait()
+}
+
+// waitForAPI waits for the Vault HTTP endpoint to start
+// responding. This is an indication that the agent has started.
+func (tv *TestVault) waitForAPI() {
+	WaitForResult(func() (bool, error) {
+		inited, err := tv.Client.Sys().InitStatus()
+		if err != nil {
+			return false, err
+		}
+		return inited, nil
+	}, func(err error) {
+		defer tv.Stop()
+		tv.t.Fatalf("err: %s", err)
+	})
+}
+
+// getPort returns the next port to bind Vault against. It is not safe for concurrent use.
+func getPort() uint64 {
+	p := vaultStartPort + vaultPortOffset
+	vaultPortOffset++
+	return p
+}