diff --git a/client/consul_template.go b/client/consul_template.go index f306306ba..bc5741520 100644 --- a/client/consul_template.go +++ b/client/consul_template.go @@ -424,7 +424,7 @@ func runnerConfig(config *config.Config, vaultToken string) (*ctconf.Config, err } // Setup the Vault config - if config.VaultConfig != nil && config.VaultConfig.Enabled { + if config.VaultConfig != nil && config.VaultConfig.IsEnabled() { conf.Vault = &ctconf.VaultConfig{ Address: config.VaultConfig.Addr, Token: vaultToken, @@ -433,9 +433,10 @@ func runnerConfig(config *config.Config, vaultToken string) (*ctconf.Config, err set([]string{"vault", "vault.address", "vault.token", "vault.renew_token"}) if strings.HasPrefix(config.VaultConfig.Addr, "https") || config.VaultConfig.TLSCertFile != "" { + verify := config.VaultConfig.TLSSkipVerify == nil || !*config.VaultConfig.TLSSkipVerify conf.Vault.SSL = &ctconf.SSLConfig{ Enabled: true, - Verify: !config.VaultConfig.TLSSkipVerify, + Verify: !verify, Cert: config.VaultConfig.TLSCertFile, Key: config.VaultConfig.TLSKeyFile, CaCert: config.VaultConfig.TLSCaFile, diff --git a/client/fingerprint/vault.go b/client/fingerprint/vault.go index 4a98ee336..acc331c0f 100644 --- a/client/fingerprint/vault.go +++ b/client/fingerprint/vault.go @@ -30,7 +30,7 @@ func NewVaultFingerprint(logger *log.Logger) Fingerprint { } func (f *VaultFingerprint) Fingerprint(config *client.Config, node *structs.Node) (bool, error) { - if config.VaultConfig == nil || !config.VaultConfig.Enabled { + if config.VaultConfig == nil || !config.VaultConfig.IsEnabled() { return false, nil } diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index 8f24b8273..2f9f50be2 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -155,7 +155,7 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver tokenDeriver: tokenDeriver, } - if !config.Enabled { + if !config.IsEnabled() { return c, nil } @@ -200,7 +200,7 @@ func (c *vaultClient) isTracked(id string) bool { // Starts the renewal loop of vault client func (c *vaultClient) Start() { - if !c.config.Enabled || c.running { + if !c.config.IsEnabled() || c.running { return } @@ -213,7 +213,7 @@ func (c *vaultClient) Start() { // Stops the renewal loop of vault client func (c *vaultClient) Stop() { - if !c.config.Enabled || !c.running { + if !c.config.IsEnabled() || !c.running { return } @@ -229,7 +229,7 @@ func (c *vaultClient) Stop() { // The return value is a map containing all the unwrapped tokens indexed by the // task name. func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { - if !c.config.Enabled { + if !c.config.IsEnabled() { return nil, fmt.Errorf("vault client not enabled") } if !c.running { @@ -242,7 +242,7 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) // GetConsulACL creates a vault API client and reads from vault a consul ACL // token used by the task. func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) { - if !c.config.Enabled { + if !c.config.IsEnabled() { return nil, fmt.Errorf("vault client not enabled") } if token == "" { @@ -350,7 +350,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { c.lock.Lock() defer c.lock.Unlock() - if !c.config.Enabled { + if !c.config.IsEnabled() { return fmt.Errorf("vault client not enabled") } if !c.running { @@ -495,12 +495,12 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { // run is the renewal loop which performs the periodic renewals of both the // tokens and the secret leases. func (c *vaultClient) run() { - if !c.config.Enabled { + if !c.config.IsEnabled() { return } var renewalCh <-chan time.Time - for c.config.Enabled && c.running { + for c.config.IsEnabled() && c.running { // Fetches the candidate for next renewal renewalReq, renewalTime := c.nextRenewal() if renewalTime.IsZero() { diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index a42dbbab5..aeb445d0a 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -52,7 +52,7 @@ func makeAgent(t testing.TB, cb func(*Config)) (string, *Agent) { } conf.NodeName = fmt.Sprintf("Node %d", conf.Ports.RPC) conf.Consul = sconfig.DefaultConsulConfig() - conf.Vault.Enabled = false + conf.Vault.Enabled = new(bool) // Tighten the Serf timing config.SerfConfig.MemberlistConfig.SuspicionMult = 2 diff --git a/command/agent/command.go b/command/agent/command.go index 8e999c60a..e16fec80b 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -21,7 +21,7 @@ import ( "github.com/hashicorp/go-checkpoint" "github.com/hashicorp/go-syslog" "github.com/hashicorp/logutils" - "github.com/hashicorp/nomad/helper/flag-slice" + "github.com/hashicorp/nomad/helper/flag-helpers" "github.com/hashicorp/nomad/helper/gated-writer" "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/scada-client/scada" @@ -79,8 +79,8 @@ func (c *Command) readConfig() *Config { // Server-only options flags.IntVar(&cmdConfig.Server.BootstrapExpect, "bootstrap-expect", 0, "") flags.BoolVar(&cmdConfig.Server.RejoinAfterLeave, "rejoin", false, "") - flags.Var((*sliceflag.StringFlag)(&cmdConfig.Server.StartJoin), "join", "") - flags.Var((*sliceflag.StringFlag)(&cmdConfig.Server.RetryJoin), "retry-join", "") + flags.Var((*flaghelper.StringFlag)(&cmdConfig.Server.StartJoin), "join", "") + flags.Var((*flaghelper.StringFlag)(&cmdConfig.Server.RetryJoin), "retry-join", "") flags.IntVar(&cmdConfig.Server.RetryMaxAttempts, "retry-max", 0, "") flags.StringVar(&cmdConfig.Server.RetryInterval, "retry-interval", "", "") @@ -89,12 +89,12 @@ func (c *Command) readConfig() *Config { flags.StringVar(&cmdConfig.Client.AllocDir, "alloc-dir", "", "") flags.StringVar(&cmdConfig.Client.NodeClass, "node-class", "", "") flags.StringVar(&servers, "servers", "", "") - flags.Var((*sliceflag.StringFlag)(&meta), "meta", "") + flags.Var((*flaghelper.StringFlag)(&meta), "meta", "") flags.StringVar(&cmdConfig.Client.NetworkInterface, "network-interface", "", "") flags.IntVar(&cmdConfig.Client.NetworkSpeed, "network-speed", 0, "") // General options - flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config") + flags.Var((*flaghelper.StringFlag)(&configPath), "config", "config") flags.StringVar(&cmdConfig.BindAddr, "bind", "", "") flags.StringVar(&cmdConfig.Region, "region", "", "") flags.StringVar(&cmdConfig.DataDir, "data-dir", "", "") @@ -108,8 +108,14 @@ func (c *Command) readConfig() *Config { flags.StringVar(&cmdConfig.Atlas.Token, "atlas-token", "", "") // Vault options - flags.BoolVar(&cmdConfig.Vault.Enabled, "vault-enabled", false, "") - flags.BoolVar(&cmdConfig.Vault.AllowUnauthenticated, "vault-allow-unauthenticated", false, "") + flags.Var((flaghelper.FuncBoolVar)(func(b bool) error { + cmdConfig.Vault.Enabled = &b + return nil + }), "vault-enabled", "") + flags.Var((flaghelper.FuncBoolVar)(func(b bool) error { + cmdConfig.Vault.AllowUnauthenticated = &b + return nil + }), "vault-allow-unauthenticated", "") flags.StringVar(&cmdConfig.Vault.Token, "vault-token", "", "") flags.StringVar(&cmdConfig.Vault.Addr, "vault-address", "", "") diff --git a/command/agent/config_parse_test.go b/command/agent/config_parse_test.go index 1809fffee..951c60ee7 100644 --- a/command/agent/config_parse_test.go +++ b/command/agent/config_parse_test.go @@ -124,14 +124,14 @@ func TestConfig_Parse(t *testing.T) { }, Vault: &config.VaultConfig{ Addr: "127.0.0.1:9500", - AllowUnauthenticated: true, - Enabled: false, + AllowUnauthenticated: &trueValue, + Enabled: &falseValue, TLSCaFile: "/path/to/ca/file", TLSCaPath: "/path/to/ca", TLSCertFile: "/path/to/cert/file", TLSKeyFile: "/path/to/key/file", TLSServerName: "foobar", - TLSSkipVerify: true, + TLSSkipVerify: &trueValue, TaskTokenTTL: "1s", Token: "12345", }, diff --git a/command/agent/config_test.go b/command/agent/config_test.go index d83ad140a..ef3ab9a14 100644 --- a/command/agent/config_test.go +++ b/command/agent/config_test.go @@ -12,6 +12,12 @@ import ( "github.com/hashicorp/nomad/nomad/structs/config" ) +var ( + // trueValue/falseValue are used to get a pointer to a boolean + trueValue = true + falseValue = false +) + func TestConfig_Merge(t *testing.T) { c1 := &Config{ Region: "global", @@ -97,14 +103,14 @@ func TestConfig_Merge(t *testing.T) { }, Vault: &config.VaultConfig{ Token: "1", - AllowUnauthenticated: false, + AllowUnauthenticated: &falseValue, TaskTokenTTL: "1", Addr: "1", TLSCaFile: "1", TLSCaPath: "1", TLSCertFile: "1", TLSKeyFile: "1", - TLSSkipVerify: false, + TLSSkipVerify: &falseValue, TLSServerName: "1", }, Consul: &config.ConsulConfig{ @@ -225,14 +231,14 @@ func TestConfig_Merge(t *testing.T) { }, Vault: &config.VaultConfig{ Token: "2", - AllowUnauthenticated: true, + AllowUnauthenticated: &trueValue, TaskTokenTTL: "2", Addr: "2", TLSCaFile: "2", TLSCaPath: "2", TLSCertFile: "2", TLSKeyFile: "2", - TLSSkipVerify: true, + TLSSkipVerify: &trueValue, TLSServerName: "2", }, Consul: &config.ConsulConfig{ diff --git a/demo/vagrant/server.hcl b/demo/vagrant/server.hcl index 653b2c037..cc098c035 100644 --- a/demo/vagrant/server.hcl +++ b/demo/vagrant/server.hcl @@ -11,3 +11,17 @@ server { # Self-elect, should be 3 or 5 for production bootstrap_expect = 1 } + +vault { + address = "https://10.0.0.231:8200" + token = "6e073f4b-4a6d-1fde-812e-7ff65dd3f4fa" + #allow_unauthenticated = true + task_token_ttl = "5m" + #enabled = true + #tls_ca_file = "/etc/ssl/cluster/ca.pem" + #tls_ca_path = "/etc/ssl/cluster" + #tls_cert_file = "/etc/ssl/cluster/cert.pem" + #tls_key_file = "/etc/ssl/cluster/key.pem" + tls_server_name = "vault" + tls_skip_verify = true +} diff --git a/helper/flag-helpers/flag.go b/helper/flag-helpers/flag.go new file mode 100644 index 000000000..10a5644e2 --- /dev/null +++ b/helper/flag-helpers/flag.go @@ -0,0 +1,60 @@ +package flaghelper + +import ( + "strconv" + "strings" + "time" +) + +// StringFlag implements the flag.Value interface and allows multiple +// calls to the same variable to append a list. +type StringFlag []string + +func (s *StringFlag) String() string { + return strings.Join(*s, ",") +} + +func (s *StringFlag) Set(value string) error { + *s = append(*s, value) + return nil +} + +// FuncVar is a type of flag that accepts a function that is the string +// given +// by the user. +type FuncVar func(s string) error + +func (f FuncVar) Set(s string) error { return f(s) } +func (f FuncVar) String() string { return "" } +func (f FuncVar) IsBoolFlag() bool { return false } + +// FuncBoolVar is a type of flag that accepts a function, converts the +// user's +// value to a bool, and then calls the given function. +type FuncBoolVar func(b bool) error + +func (f FuncBoolVar) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + return f(v) +} +func (f FuncBoolVar) String() string { return "" } +func (f FuncBoolVar) IsBoolFlag() bool { return true } + +// FuncDurationVar is a type of flag that +// accepts a function, converts the +// user's value to a duration, and then +// calls the given function. +type FuncDurationVar func(d time.Duration) error + +func (f FuncDurationVar) Set(s string) error { + v, err := time.ParseDuration(s) + if err != nil { + return err + } + return f(v) +} +func (f FuncDurationVar) String() string { return "" } +func (f FuncDurationVar) IsBoolFlag() bool { return false } diff --git a/helper/flag-slice/flag_test.go b/helper/flag-helpers/flag_test.go similarity index 96% rename from helper/flag-slice/flag_test.go rename to helper/flag-helpers/flag_test.go index f72e1d960..7893c0e15 100644 --- a/helper/flag-slice/flag_test.go +++ b/helper/flag-helpers/flag_test.go @@ -1,4 +1,4 @@ -package sliceflag +package flaghelper import ( "flag" diff --git a/helper/flag-slice/flag.go b/helper/flag-slice/flag.go deleted file mode 100644 index da75149dc..000000000 --- a/helper/flag-slice/flag.go +++ /dev/null @@ -1,16 +0,0 @@ -package sliceflag - -import "strings" - -// StringFlag implements the flag.Value interface and allows multiple -// calls to the same variable to append a list. -type StringFlag []string - -func (s *StringFlag) String() string { - return strings.Join(*s, ",") -} - -func (s *StringFlag) Set(value string) error { - *s = append(*s, value) - return nil -} diff --git a/nomad/job_endpoint.go b/nomad/job_endpoint.go index 874af7a05..fd27eb606 100644 --- a/nomad/job_endpoint.go +++ b/nomad/job_endpoint.go @@ -84,12 +84,12 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis policies := args.Job.VaultPolicies() if len(policies) != 0 { vconf := j.srv.config.VaultConfig - if !vconf.Enabled { + if !vconf.IsEnabled() { return fmt.Errorf("Vault not enabled and Vault policies requested") } // Have to check if the user has permissions - if !vconf.AllowUnauthenticated { + if !vconf.AllowsUnauthenticated() { if args.Job.VaultToken == "" { return fmt.Errorf("Vault policies requested but missing Vault Token") } diff --git a/nomad/structs/config/vault.go b/nomad/structs/config/vault.go index d97909f18..2ad10c4e9 100644 --- a/nomad/structs/config/vault.go +++ b/nomad/structs/config/vault.go @@ -23,7 +23,7 @@ const ( type VaultConfig struct { // Enabled enables or disables Vault support. - Enabled bool `mapstructure:"enabled"` + Enabled *bool `mapstructure:"enabled"` // Token is the Vault token given to Nomad such that it can // derive child tokens. Nomad will renew this token at half its lease @@ -33,7 +33,7 @@ type VaultConfig struct { // AllowUnauthenticated allows users to submit jobs requiring Vault tokens // without providing a Vault token proving they have access to these // policies. - AllowUnauthenticated bool `mapstructure:"allow_unauthenticated"` + AllowUnauthenticated *bool `mapstructure:"allow_unauthenticated"` // TaskTokenTTL is the TTL of the tokens created by Nomad Servers and used // by the client. There should be a minimum time value such that the client @@ -63,7 +63,7 @@ type VaultConfig struct { TLSKeyFile string `mapstructure:"tls_key_file"` // TLSSkipVerify enables or disables SSL verification - TLSSkipVerify bool `mapstructure:"tls_skip_verify"` + TLSSkipVerify *bool `mapstructure:"tls_skip_verify"` // TLSServerName, if set, is used to set the SNI host when connecting via TLS. TLSServerName string `mapstructure:"tls_server_name"` @@ -73,12 +73,22 @@ type VaultConfig struct { // `vault` configuration. func DefaultVaultConfig() *VaultConfig { return &VaultConfig{ - AllowUnauthenticated: false, - Addr: "https://vault.service.consul:8200", - ConnectionRetryIntv: DefaultVaultConnectRetryIntv, + Addr: "https://vault.service.consul:8200", + ConnectionRetryIntv: DefaultVaultConnectRetryIntv, } } +// IsEnabled returns whether the config enables Vault integration +func (a *VaultConfig) IsEnabled() bool { + return a.Enabled != nil && *a.Enabled +} + +// AllowsUnauthenticated returns whether the config allows unauthenticated +// access to Vault +func (a *VaultConfig) AllowsUnauthenticated() bool { + return a.AllowUnauthenticated != nil && *a.AllowUnauthenticated +} + // Merge merges two Vault configurations together. func (a *VaultConfig) Merge(b *VaultConfig) *VaultConfig { result := *a @@ -110,10 +120,16 @@ func (a *VaultConfig) Merge(b *VaultConfig) *VaultConfig { if b.TLSServerName != "" { result.TLSServerName = b.TLSServerName } + if b.AllowUnauthenticated != nil { + result.AllowUnauthenticated = b.AllowUnauthenticated + } + if b.TLSSkipVerify != nil { + result.TLSSkipVerify = b.TLSSkipVerify + } + if b.Enabled != nil { + result.Enabled = b.Enabled + } - result.AllowUnauthenticated = b.AllowUnauthenticated - result.TLSSkipVerify = b.TLSSkipVerify - result.Enabled = b.Enabled return &result } @@ -127,8 +143,13 @@ func (c *VaultConfig) ApiConfig() (*vault.Config, error) { ClientCert: c.TLSCertFile, ClientKey: c.TLSKeyFile, TLSServerName: c.TLSServerName, - Insecure: c.TLSSkipVerify, } + if c.TLSSkipVerify != nil { + tlsConf.Insecure = *c.TLSSkipVerify + } else { + tlsConf.Insecure = false + } + if err := conf.ConfigureTLS(tlsConf); err != nil { return nil, err } diff --git a/nomad/structs/config/vault_test.go b/nomad/structs/config/vault_test.go new file mode 100644 index 000000000..29f7ea590 --- /dev/null +++ b/nomad/structs/config/vault_test.go @@ -0,0 +1,56 @@ +package config + +import ( + "reflect" + "testing" +) + +func TestVaultConfig_Merge(t *testing.T) { + trueValue, falseValue := true, false + c1 := &VaultConfig{ + Enabled: &falseValue, + Token: "1", + AllowUnauthenticated: &trueValue, + TaskTokenTTL: "1", + Addr: "1", + TLSCaFile: "1", + TLSCaPath: "1", + TLSCertFile: "1", + TLSKeyFile: "1", + TLSSkipVerify: &trueValue, + TLSServerName: "1", + } + + c2 := &VaultConfig{ + Enabled: &trueValue, + Token: "2", + AllowUnauthenticated: &falseValue, + TaskTokenTTL: "2", + Addr: "2", + TLSCaFile: "2", + TLSCaPath: "2", + TLSCertFile: "2", + TLSKeyFile: "2", + TLSSkipVerify: nil, + TLSServerName: "2", + } + + e := &VaultConfig{ + Enabled: &trueValue, + Token: "2", + AllowUnauthenticated: &falseValue, + TaskTokenTTL: "2", + Addr: "2", + TLSCaFile: "2", + TLSCaPath: "2", + TLSCertFile: "2", + TLSKeyFile: "2", + TLSSkipVerify: &trueValue, + TLSServerName: "2", + } + + result := c1.Merge(c2) + if !reflect.DeepEqual(result, e) { + t.Fatalf("bad:\n%#v\n%#v", result, e) + } +} diff --git a/nomad/vault.go b/nomad/vault.go index 73b71a02c..026e8343d 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -159,7 +159,7 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger, purgeFn PurgeVaul tomb: &tomb.Tomb{}, } - if v.config.Enabled { + if v.config.IsEnabled() { if err := v.buildClient(); err != nil { return nil, err } @@ -223,7 +223,7 @@ func (v *vaultClient) SetConfig(config *config.VaultConfig) error { // Store the new config v.config = config - if v.config.Enabled { + if v.config.IsEnabled() { // Stop accepting any new request atomic.StoreInt32(&v.connEstablished, 0) @@ -529,7 +529,7 @@ func (v *vaultClient) ConnectionEstablished() bool { func (v *vaultClient) Enabled() bool { v.l.Lock() defer v.l.Unlock() - return v.config.Enabled + return v.config.IsEnabled() } // diff --git a/testutil/vault.go b/testutil/vault.go index 9b52bf118..23241194d 100644 --- a/testutil/vault.go +++ b/testutil/vault.go @@ -61,6 +61,7 @@ func NewTestVault(t *testing.T) *TestVault { } client.SetToken(token) + enable := true tv := &TestVault{ cmd: cmd, t: t, @@ -69,7 +70,7 @@ func NewTestVault(t *testing.T) *TestVault { RootToken: token, Client: client, Config: &config.VaultConfig{ - Enabled: true, + Enabled: &enable, Token: token, Addr: http, },