config: Validate keyring config to catch invalid provider types. (#26673)

This commit is contained in:
James Rasell
2025-09-02 11:07:49 +01:00
committed by GitHub
parent 267dc72f4e
commit cddc1b0127
5 changed files with 113 additions and 1 deletions

3
.changelog/26673.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
config: Validate the `keyring` configuration block label against supported values on agent startup
```

View File

@@ -546,6 +546,13 @@ func (c *Command) IsValidConfig(config, cmdConfig *Config) bool {
c.Ui.Warn("Please remove deprecated protocol_version field from config.") c.Ui.Warn("Please remove deprecated protocol_version field from config.")
} }
for _, keyring := range config.KEKProviders {
if err := keyring.Validate(); err != nil {
c.Ui.Error(fmt.Sprintf("keyring %q invalid: %v", keyring.Name, err))
return false
}
}
return true return true
} }
@@ -1580,7 +1587,7 @@ Client Options:
-network-interface -network-interface
Forces the network fingerprinter to use the specified network interface. Forces the network fingerprinter to use the specified network interface.
-preferred-address-family -preferred-address-family
Specify which IP family to prefer when selecting an IP address of the Specify which IP family to prefer when selecting an IP address of the
network interface. Valid values are "ipv4" and "ipv6". When not specified, network interface. Valid values are "ipv4" and "ipv6". When not specified,

View File

@@ -485,6 +485,23 @@ func TestIsValidConfig(t *testing.T) {
}, },
err: "missing protocol scheme", err: "missing protocol scheme",
}, },
{
name: "invalidate keyring provider",
conf: Config{
DataDir: "/tmp",
Server: &ServerConfig{
BootstrapExpect: 1,
Enabled: true,
},
KEKProviders: []*structs.KEKProviderConfig{
{
Name: "invalid",
Provider: "foo",
},
},
},
err: "unknown keyring provider",
},
} }
for _, tc := range cases { for _, tc := range cases {

View File

@@ -297,6 +297,22 @@ type KEKProviderConfig struct {
ExtraKeysHCL []string `hcl:",unusedKeys" json:"-"` ExtraKeysHCL []string `hcl:",unusedKeys" json:"-"`
} }
// Validate checks that the KEKProviderConfig is valid.
func (c *KEKProviderConfig) Validate() error {
if c == nil {
return nil
}
switch KEKProviderName(c.Provider) {
case KEKProviderAEAD, KEKProviderAWSKMS, KEKProviderAzureKeyVault,
KEKProviderGCPCloudKMS, KEKProviderVaultTransit:
return nil
default:
return fmt.Errorf("unknown keyring provider: %q", c.Provider)
}
}
func (c *KEKProviderConfig) Copy() *KEKProviderConfig { func (c *KEKProviderConfig) Copy() *KEKProviderConfig {
return &KEKProviderConfig{ return &KEKProviderConfig{
Provider: c.Provider, Provider: c.Provider,

View File

@@ -10,6 +10,75 @@ import (
"github.com/shoenig/test/must" "github.com/shoenig/test/must"
) )
func TestKEKProviderConfig_Validate(t *testing.T) {
ci.Parallel(t)
testCases := []struct {
name string
inputKeyringConfig *KEKProviderConfig
expectedErrorContains string
}{
{
name: "nil",
inputKeyringConfig: nil,
expectedErrorContains: "",
},
{
name: "aead",
inputKeyringConfig: &KEKProviderConfig{
Provider: "aead",
},
expectedErrorContains: "",
},
{
name: "awskms",
inputKeyringConfig: &KEKProviderConfig{
Provider: "awskms",
},
expectedErrorContains: "",
},
{
name: "azurekeyvault",
inputKeyringConfig: &KEKProviderConfig{
Provider: "azurekeyvault",
},
expectedErrorContains: "",
},
{
name: "gcpckms",
inputKeyringConfig: &KEKProviderConfig{
Provider: "gcpckms",
},
expectedErrorContains: "",
},
{
name: "transit",
inputKeyringConfig: &KEKProviderConfig{
Provider: "transit",
},
expectedErrorContains: "",
},
{
name: "unknown",
inputKeyringConfig: &KEKProviderConfig{
Provider: "unknown",
},
expectedErrorContains: "unknown keyring provider",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualError := tc.inputKeyringConfig.Validate()
if tc.expectedErrorContains == "" {
must.NoError(t, actualError)
} else {
must.ErrorContains(t, actualError, tc.expectedErrorContains)
}
})
}
}
func TestKeyring_OIDCDiscoveryConfig(t *testing.T) { func TestKeyring_OIDCDiscoveryConfig(t *testing.T) {
ci.Parallel(t) ci.Parallel(t)