Fix Vault parsing of booleans

This commit is contained in:
Alex Dadgar
2016-10-10 18:04:39 -07:00
parent bf112a51c3
commit 9ff2bf0bff
16 changed files with 208 additions and 59 deletions

View File

@@ -424,7 +424,7 @@ func runnerConfig(config *config.Config, vaultToken string) (*ctconf.Config, err
}
// Setup the Vault config
if config.VaultConfig != nil && config.VaultConfig.Enabled {
if config.VaultConfig != nil && config.VaultConfig.IsEnabled() {
conf.Vault = &ctconf.VaultConfig{
Address: config.VaultConfig.Addr,
Token: vaultToken,
@@ -433,9 +433,10 @@ func runnerConfig(config *config.Config, vaultToken string) (*ctconf.Config, err
set([]string{"vault", "vault.address", "vault.token", "vault.renew_token"})
if strings.HasPrefix(config.VaultConfig.Addr, "https") || config.VaultConfig.TLSCertFile != "" {
verify := config.VaultConfig.TLSSkipVerify == nil || !*config.VaultConfig.TLSSkipVerify
conf.Vault.SSL = &ctconf.SSLConfig{
Enabled: true,
Verify: !config.VaultConfig.TLSSkipVerify,
Verify: !verify,
Cert: config.VaultConfig.TLSCertFile,
Key: config.VaultConfig.TLSKeyFile,
CaCert: config.VaultConfig.TLSCaFile,

View File

@@ -30,7 +30,7 @@ func NewVaultFingerprint(logger *log.Logger) Fingerprint {
}
func (f *VaultFingerprint) Fingerprint(config *client.Config, node *structs.Node) (bool, error) {
if config.VaultConfig == nil || !config.VaultConfig.Enabled {
if config.VaultConfig == nil || !config.VaultConfig.IsEnabled() {
return false, nil
}

View File

@@ -155,7 +155,7 @@ func NewVaultClient(config *config.VaultConfig, logger *log.Logger, tokenDeriver
tokenDeriver: tokenDeriver,
}
if !config.Enabled {
if !config.IsEnabled() {
return c, nil
}
@@ -200,7 +200,7 @@ func (c *vaultClient) isTracked(id string) bool {
// Starts the renewal loop of vault client
func (c *vaultClient) Start() {
if !c.config.Enabled || c.running {
if !c.config.IsEnabled() || c.running {
return
}
@@ -213,7 +213,7 @@ func (c *vaultClient) Start() {
// Stops the renewal loop of vault client
func (c *vaultClient) Stop() {
if !c.config.Enabled || !c.running {
if !c.config.IsEnabled() || !c.running {
return
}
@@ -229,7 +229,7 @@ func (c *vaultClient) Stop() {
// The return value is a map containing all the unwrapped tokens indexed by the
// task name.
func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) {
if !c.config.Enabled {
if !c.config.IsEnabled() {
return nil, fmt.Errorf("vault client not enabled")
}
if !c.running {
@@ -242,7 +242,7 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string)
// GetConsulACL creates a vault API client and reads from vault a consul ACL
// token used by the task.
func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error) {
if !c.config.Enabled {
if !c.config.IsEnabled() {
return nil, fmt.Errorf("vault client not enabled")
}
if token == "" {
@@ -350,7 +350,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
c.lock.Lock()
defer c.lock.Unlock()
if !c.config.Enabled {
if !c.config.IsEnabled() {
return fmt.Errorf("vault client not enabled")
}
if !c.running {
@@ -495,12 +495,12 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
// run is the renewal loop which performs the periodic renewals of both the
// tokens and the secret leases.
func (c *vaultClient) run() {
if !c.config.Enabled {
if !c.config.IsEnabled() {
return
}
var renewalCh <-chan time.Time
for c.config.Enabled && c.running {
for c.config.IsEnabled() && c.running {
// Fetches the candidate for next renewal
renewalReq, renewalTime := c.nextRenewal()
if renewalTime.IsZero() {

View File

@@ -52,7 +52,7 @@ func makeAgent(t testing.TB, cb func(*Config)) (string, *Agent) {
}
conf.NodeName = fmt.Sprintf("Node %d", conf.Ports.RPC)
conf.Consul = sconfig.DefaultConsulConfig()
conf.Vault.Enabled = false
conf.Vault.Enabled = new(bool)
// Tighten the Serf timing
config.SerfConfig.MemberlistConfig.SuspicionMult = 2

View File

@@ -21,7 +21,7 @@ import (
"github.com/hashicorp/go-checkpoint"
"github.com/hashicorp/go-syslog"
"github.com/hashicorp/logutils"
"github.com/hashicorp/nomad/helper/flag-slice"
"github.com/hashicorp/nomad/helper/flag-helpers"
"github.com/hashicorp/nomad/helper/gated-writer"
"github.com/hashicorp/nomad/nomad/structs/config"
"github.com/hashicorp/scada-client/scada"
@@ -79,8 +79,8 @@ func (c *Command) readConfig() *Config {
// Server-only options
flags.IntVar(&cmdConfig.Server.BootstrapExpect, "bootstrap-expect", 0, "")
flags.BoolVar(&cmdConfig.Server.RejoinAfterLeave, "rejoin", false, "")
flags.Var((*sliceflag.StringFlag)(&cmdConfig.Server.StartJoin), "join", "")
flags.Var((*sliceflag.StringFlag)(&cmdConfig.Server.RetryJoin), "retry-join", "")
flags.Var((*flaghelper.StringFlag)(&cmdConfig.Server.StartJoin), "join", "")
flags.Var((*flaghelper.StringFlag)(&cmdConfig.Server.RetryJoin), "retry-join", "")
flags.IntVar(&cmdConfig.Server.RetryMaxAttempts, "retry-max", 0, "")
flags.StringVar(&cmdConfig.Server.RetryInterval, "retry-interval", "", "")
@@ -89,12 +89,12 @@ func (c *Command) readConfig() *Config {
flags.StringVar(&cmdConfig.Client.AllocDir, "alloc-dir", "", "")
flags.StringVar(&cmdConfig.Client.NodeClass, "node-class", "", "")
flags.StringVar(&servers, "servers", "", "")
flags.Var((*sliceflag.StringFlag)(&meta), "meta", "")
flags.Var((*flaghelper.StringFlag)(&meta), "meta", "")
flags.StringVar(&cmdConfig.Client.NetworkInterface, "network-interface", "", "")
flags.IntVar(&cmdConfig.Client.NetworkSpeed, "network-speed", 0, "")
// General options
flags.Var((*sliceflag.StringFlag)(&configPath), "config", "config")
flags.Var((*flaghelper.StringFlag)(&configPath), "config", "config")
flags.StringVar(&cmdConfig.BindAddr, "bind", "", "")
flags.StringVar(&cmdConfig.Region, "region", "", "")
flags.StringVar(&cmdConfig.DataDir, "data-dir", "", "")
@@ -108,8 +108,14 @@ func (c *Command) readConfig() *Config {
flags.StringVar(&cmdConfig.Atlas.Token, "atlas-token", "", "")
// Vault options
flags.BoolVar(&cmdConfig.Vault.Enabled, "vault-enabled", false, "")
flags.BoolVar(&cmdConfig.Vault.AllowUnauthenticated, "vault-allow-unauthenticated", false, "")
flags.Var((flaghelper.FuncBoolVar)(func(b bool) error {
cmdConfig.Vault.Enabled = &b
return nil
}), "vault-enabled", "")
flags.Var((flaghelper.FuncBoolVar)(func(b bool) error {
cmdConfig.Vault.AllowUnauthenticated = &b
return nil
}), "vault-allow-unauthenticated", "")
flags.StringVar(&cmdConfig.Vault.Token, "vault-token", "", "")
flags.StringVar(&cmdConfig.Vault.Addr, "vault-address", "", "")

View File

@@ -124,14 +124,14 @@ func TestConfig_Parse(t *testing.T) {
},
Vault: &config.VaultConfig{
Addr: "127.0.0.1:9500",
AllowUnauthenticated: true,
Enabled: false,
AllowUnauthenticated: &trueValue,
Enabled: &falseValue,
TLSCaFile: "/path/to/ca/file",
TLSCaPath: "/path/to/ca",
TLSCertFile: "/path/to/cert/file",
TLSKeyFile: "/path/to/key/file",
TLSServerName: "foobar",
TLSSkipVerify: true,
TLSSkipVerify: &trueValue,
TaskTokenTTL: "1s",
Token: "12345",
},

View File

@@ -12,6 +12,12 @@ import (
"github.com/hashicorp/nomad/nomad/structs/config"
)
var (
// trueValue/falseValue are used to get a pointer to a boolean
trueValue = true
falseValue = false
)
func TestConfig_Merge(t *testing.T) {
c1 := &Config{
Region: "global",
@@ -97,14 +103,14 @@ func TestConfig_Merge(t *testing.T) {
},
Vault: &config.VaultConfig{
Token: "1",
AllowUnauthenticated: false,
AllowUnauthenticated: &falseValue,
TaskTokenTTL: "1",
Addr: "1",
TLSCaFile: "1",
TLSCaPath: "1",
TLSCertFile: "1",
TLSKeyFile: "1",
TLSSkipVerify: false,
TLSSkipVerify: &falseValue,
TLSServerName: "1",
},
Consul: &config.ConsulConfig{
@@ -225,14 +231,14 @@ func TestConfig_Merge(t *testing.T) {
},
Vault: &config.VaultConfig{
Token: "2",
AllowUnauthenticated: true,
AllowUnauthenticated: &trueValue,
TaskTokenTTL: "2",
Addr: "2",
TLSCaFile: "2",
TLSCaPath: "2",
TLSCertFile: "2",
TLSKeyFile: "2",
TLSSkipVerify: true,
TLSSkipVerify: &trueValue,
TLSServerName: "2",
},
Consul: &config.ConsulConfig{

View File

@@ -11,3 +11,17 @@ server {
# Self-elect, should be 3 or 5 for production
bootstrap_expect = 1
}
vault {
address = "https://10.0.0.231:8200"
token = "6e073f4b-4a6d-1fde-812e-7ff65dd3f4fa"
#allow_unauthenticated = true
task_token_ttl = "5m"
#enabled = true
#tls_ca_file = "/etc/ssl/cluster/ca.pem"
#tls_ca_path = "/etc/ssl/cluster"
#tls_cert_file = "/etc/ssl/cluster/cert.pem"
#tls_key_file = "/etc/ssl/cluster/key.pem"
tls_server_name = "vault"
tls_skip_verify = true
}

View File

@@ -0,0 +1,60 @@
package flaghelper
import (
"strconv"
"strings"
"time"
)
// StringFlag implements the flag.Value interface and allows multiple
// calls to the same variable to append a list.
type StringFlag []string
func (s *StringFlag) String() string {
return strings.Join(*s, ",")
}
func (s *StringFlag) Set(value string) error {
*s = append(*s, value)
return nil
}
// FuncVar is a type of flag that accepts a function that is the string
// given
// by the user.
type FuncVar func(s string) error
func (f FuncVar) Set(s string) error { return f(s) }
func (f FuncVar) String() string { return "" }
func (f FuncVar) IsBoolFlag() bool { return false }
// FuncBoolVar is a type of flag that accepts a function, converts the
// user's
// value to a bool, and then calls the given function.
type FuncBoolVar func(b bool) error
func (f FuncBoolVar) Set(s string) error {
v, err := strconv.ParseBool(s)
if err != nil {
return err
}
return f(v)
}
func (f FuncBoolVar) String() string { return "" }
func (f FuncBoolVar) IsBoolFlag() bool { return true }
// FuncDurationVar is a type of flag that
// accepts a function, converts the
// user's value to a duration, and then
// calls the given function.
type FuncDurationVar func(d time.Duration) error
func (f FuncDurationVar) Set(s string) error {
v, err := time.ParseDuration(s)
if err != nil {
return err
}
return f(v)
}
func (f FuncDurationVar) String() string { return "" }
func (f FuncDurationVar) IsBoolFlag() bool { return false }

View File

@@ -1,4 +1,4 @@
package sliceflag
package flaghelper
import (
"flag"

View File

@@ -1,16 +0,0 @@
package sliceflag
import "strings"
// StringFlag implements the flag.Value interface and allows multiple
// calls to the same variable to append a list.
type StringFlag []string
func (s *StringFlag) String() string {
return strings.Join(*s, ",")
}
func (s *StringFlag) Set(value string) error {
*s = append(*s, value)
return nil
}

View File

@@ -84,12 +84,12 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis
policies := args.Job.VaultPolicies()
if len(policies) != 0 {
vconf := j.srv.config.VaultConfig
if !vconf.Enabled {
if !vconf.IsEnabled() {
return fmt.Errorf("Vault not enabled and Vault policies requested")
}
// Have to check if the user has permissions
if !vconf.AllowUnauthenticated {
if !vconf.AllowsUnauthenticated() {
if args.Job.VaultToken == "" {
return fmt.Errorf("Vault policies requested but missing Vault Token")
}

View File

@@ -23,7 +23,7 @@ const (
type VaultConfig struct {
// Enabled enables or disables Vault support.
Enabled bool `mapstructure:"enabled"`
Enabled *bool `mapstructure:"enabled"`
// Token is the Vault token given to Nomad such that it can
// derive child tokens. Nomad will renew this token at half its lease
@@ -33,7 +33,7 @@ type VaultConfig struct {
// AllowUnauthenticated allows users to submit jobs requiring Vault tokens
// without providing a Vault token proving they have access to these
// policies.
AllowUnauthenticated bool `mapstructure:"allow_unauthenticated"`
AllowUnauthenticated *bool `mapstructure:"allow_unauthenticated"`
// TaskTokenTTL is the TTL of the tokens created by Nomad Servers and used
// by the client. There should be a minimum time value such that the client
@@ -63,7 +63,7 @@ type VaultConfig struct {
TLSKeyFile string `mapstructure:"tls_key_file"`
// TLSSkipVerify enables or disables SSL verification
TLSSkipVerify bool `mapstructure:"tls_skip_verify"`
TLSSkipVerify *bool `mapstructure:"tls_skip_verify"`
// TLSServerName, if set, is used to set the SNI host when connecting via TLS.
TLSServerName string `mapstructure:"tls_server_name"`
@@ -73,12 +73,22 @@ type VaultConfig struct {
// `vault` configuration.
func DefaultVaultConfig() *VaultConfig {
return &VaultConfig{
AllowUnauthenticated: false,
Addr: "https://vault.service.consul:8200",
ConnectionRetryIntv: DefaultVaultConnectRetryIntv,
Addr: "https://vault.service.consul:8200",
ConnectionRetryIntv: DefaultVaultConnectRetryIntv,
}
}
// IsEnabled returns whether the config enables Vault integration
func (a *VaultConfig) IsEnabled() bool {
return a.Enabled != nil && *a.Enabled
}
// AllowsUnauthenticated returns whether the config allows unauthenticated
// access to Vault
func (a *VaultConfig) AllowsUnauthenticated() bool {
return a.AllowUnauthenticated != nil && *a.AllowUnauthenticated
}
// Merge merges two Vault configurations together.
func (a *VaultConfig) Merge(b *VaultConfig) *VaultConfig {
result := *a
@@ -110,10 +120,16 @@ func (a *VaultConfig) Merge(b *VaultConfig) *VaultConfig {
if b.TLSServerName != "" {
result.TLSServerName = b.TLSServerName
}
if b.AllowUnauthenticated != nil {
result.AllowUnauthenticated = b.AllowUnauthenticated
}
if b.TLSSkipVerify != nil {
result.TLSSkipVerify = b.TLSSkipVerify
}
if b.Enabled != nil {
result.Enabled = b.Enabled
}
result.AllowUnauthenticated = b.AllowUnauthenticated
result.TLSSkipVerify = b.TLSSkipVerify
result.Enabled = b.Enabled
return &result
}
@@ -127,8 +143,13 @@ func (c *VaultConfig) ApiConfig() (*vault.Config, error) {
ClientCert: c.TLSCertFile,
ClientKey: c.TLSKeyFile,
TLSServerName: c.TLSServerName,
Insecure: c.TLSSkipVerify,
}
if c.TLSSkipVerify != nil {
tlsConf.Insecure = *c.TLSSkipVerify
} else {
tlsConf.Insecure = false
}
if err := conf.ConfigureTLS(tlsConf); err != nil {
return nil, err
}

View File

@@ -0,0 +1,56 @@
package config
import (
"reflect"
"testing"
)
func TestVaultConfig_Merge(t *testing.T) {
trueValue, falseValue := true, false
c1 := &VaultConfig{
Enabled: &falseValue,
Token: "1",
AllowUnauthenticated: &trueValue,
TaskTokenTTL: "1",
Addr: "1",
TLSCaFile: "1",
TLSCaPath: "1",
TLSCertFile: "1",
TLSKeyFile: "1",
TLSSkipVerify: &trueValue,
TLSServerName: "1",
}
c2 := &VaultConfig{
Enabled: &trueValue,
Token: "2",
AllowUnauthenticated: &falseValue,
TaskTokenTTL: "2",
Addr: "2",
TLSCaFile: "2",
TLSCaPath: "2",
TLSCertFile: "2",
TLSKeyFile: "2",
TLSSkipVerify: nil,
TLSServerName: "2",
}
e := &VaultConfig{
Enabled: &trueValue,
Token: "2",
AllowUnauthenticated: &falseValue,
TaskTokenTTL: "2",
Addr: "2",
TLSCaFile: "2",
TLSCaPath: "2",
TLSCertFile: "2",
TLSKeyFile: "2",
TLSSkipVerify: &trueValue,
TLSServerName: "2",
}
result := c1.Merge(c2)
if !reflect.DeepEqual(result, e) {
t.Fatalf("bad:\n%#v\n%#v", result, e)
}
}

View File

@@ -159,7 +159,7 @@ func NewVaultClient(c *config.VaultConfig, logger *log.Logger, purgeFn PurgeVaul
tomb: &tomb.Tomb{},
}
if v.config.Enabled {
if v.config.IsEnabled() {
if err := v.buildClient(); err != nil {
return nil, err
}
@@ -223,7 +223,7 @@ func (v *vaultClient) SetConfig(config *config.VaultConfig) error {
// Store the new config
v.config = config
if v.config.Enabled {
if v.config.IsEnabled() {
// Stop accepting any new request
atomic.StoreInt32(&v.connEstablished, 0)
@@ -529,7 +529,7 @@ func (v *vaultClient) ConnectionEstablished() bool {
func (v *vaultClient) Enabled() bool {
v.l.Lock()
defer v.l.Unlock()
return v.config.Enabled
return v.config.IsEnabled()
}
//

View File

@@ -61,6 +61,7 @@ func NewTestVault(t *testing.T) *TestVault {
}
client.SetToken(token)
enable := true
tv := &TestVault{
cmd: cmd,
t: t,
@@ -69,7 +70,7 @@ func NewTestVault(t *testing.T) *TestVault {
RootToken: token,
Client: client,
Config: &config.VaultConfig{
Enabled: true,
Enabled: &enable,
Token: token,
Addr: http,
},