diff --git a/nomad/state/schema.go b/nomad/state/schema.go index 2dde4fc65..f76134873 100644 --- a/nomad/state/schema.go +++ b/nomad/state/schema.go @@ -25,6 +25,7 @@ const ( indexNodeID = "node_id" indexAllocID = "alloc_id" indexServiceName = "service_name" + indexKeyID = "key_id" ) var ( @@ -1230,10 +1231,71 @@ func secureVariablesTableSchema() *memdb.TableSchema { }, }, }, + indexKeyID: { + Name: indexKeyID, + AllowMissing: false, + Indexer: &secureVariableKeyIDFieldIndexer{}, + }, }, } } +type secureVariableKeyIDFieldIndexer struct{} + +// FromArgs implements go-memdb/Indexer and is used to build an exact +// index lookup based on arguments +func (s *secureVariableKeyIDFieldIndexer) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +// PrefixFromArgs implements go-memdb/PrefixIndexer and returns a +// prefix that should be used for scanning based on the arguments +func (s *secureVariableKeyIDFieldIndexer) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + +// FromObject implements go-memdb/SingleIndexer and is used to extract +// an index value from an object or to indicate that the index value +// is missing. +func (s *secureVariableKeyIDFieldIndexer) FromObject(obj interface{}) (bool, []byte, error) { + variable, ok := obj.(*structs.SecureVariable) + if !ok { + return false, nil, fmt.Errorf("object %#v is not a SecureVariable", obj) + } + + if variable.EncryptedData == nil { + return false, nil, nil + } + + keyID := variable.EncryptedData.KeyID + if keyID == "" { + return false, nil, nil + } + + // Add the null character as a terminator + keyID += "\x00" + return true, []byte(keyID), nil +} + // secureVariablesQuotasTableSchema returns the MemDB schema for Nomad // secure variables quotas tracking func secureVariablesQuotasTableSchema() *memdb.TableSchema { diff --git a/nomad/state/state_store_secure_variables.go b/nomad/state/state_store_secure_variables.go index 9702a1280..67e0f9e55 100644 --- a/nomad/state/state_store_secure_variables.go +++ b/nomad/state/state_store_secure_variables.go @@ -54,6 +54,21 @@ func (s *StateStore) GetSecureVariablesByNamespaceAndPrefix( return iter, nil } +// GetSecureVariablesByKeyID returns an iterator that contains all +// variables that were encrypted with a particular key +func (s *StateStore) GetSecureVariablesByKeyID( + ws memdb.WatchSet, keyID string) (memdb.ResultIterator, error) { + txn := s.db.ReadTxn() + + iter, err := txn.Get(TableSecureVariables, indexKeyID, keyID) + if err != nil { + return nil, fmt.Errorf("secure variable lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + + return iter, nil +} + // GetSecureVariable returns an single secure variable at a given namespace and // path. func (s *StateStore) GetSecureVariable( diff --git a/nomad/state/state_store_secure_variables_test.go b/nomad/state/state_store_secure_variables_test.go index bd3aad925..39cebc1ae 100644 --- a/nomad/state/state_store_secure_variables_test.go +++ b/nomad/state/state_store_secure_variables_test.go @@ -11,6 +11,7 @@ import ( memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/stretchr/testify/require" @@ -383,6 +384,42 @@ func TestStateStore_ListSecureVariablesByNamespaceAndPrefix(t *testing.T) { require.Equal(t, 0, count3) } +func TestStateStore_ListSecureVariablesByKeyID(t *testing.T) { + ci.Parallel(t) + testState := testStateStore(t) + + // Generate some test secure variables and upsert them. + svs, _ := mockSecureVariables(7, 7) + keyID := uuid.Generate() + + expectedForKey := []string{} + for i := 0; i < 5; i++ { + svs[i].EncryptedData.KeyID = keyID + expectedForKey = append(expectedForKey, svs[i].Path) + sort.Strings(expectedForKey) + } + + expectedOrphaned := []string{svs[5].Path, svs[6].Path} + + initialIndex := uint64(10) + require.NoError(t, testState.UpsertSecureVariables( + structs.MsgTypeTestSetup, initialIndex, svs)) + + ws := memdb.NewWatchSet() + iter, err := testState.GetSecureVariablesByKeyID(ws, keyID) + require.NoError(t, err) + + var count int + for raw := iter.Next(); raw != nil; raw = iter.Next() { + sv := raw.(*structs.SecureVariable) + require.Equal(t, keyID, sv.EncryptedData.KeyID) + require.Equal(t, expectedForKey[count], sv.Path) + require.NotContains(t, expectedOrphaned, sv.Path) + count++ + } + require.Equal(t, 5, count) +} + // mockSecureVariables returns a random number of secure variables between min // and max inclusive. func mockSecureVariables(minU, maxU uint8) (