diff --git a/nomad/state/schema.go b/nomad/state/schema.go index e451c35a2..dc2adb171 100644 --- a/nomad/state/schema.go +++ b/nomad/state/schema.go @@ -49,6 +49,7 @@ func init() { siTokenAccessorTableSchema, aclPolicyTableSchema, aclTokenTableSchema, + oneTimeTokenTableSchema, autopilotConfigTableSchema, schedulerConfigTableSchema, clusterMetaTableSchema, @@ -651,6 +652,32 @@ func aclTokenTableSchema() *memdb.TableSchema { } } +// oneTimeTokenTableSchema returns the MemDB schema for the tokens table. +// This table is used to store one-time tokens for ACL tokens +func oneTimeTokenTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "one_time_token", + Indexes: map[string]*memdb.IndexSchema{ + "secret": { + Name: "secret", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "OneTimeSecretID", + }, + }, + "id": { + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.UUIDFieldIndex{ + Field: "AccessorID", + }, + }, + }, + } +} + // singletonRecord can be used to describe tables which should contain only 1 entry. // Example uses include storing node config or cluster metadata blobs. var singletonRecord = &memdb.ConditionalIndex{ diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 2317631e2..d604b3009 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -5333,6 +5333,102 @@ func (s *StateStore) BootstrapACLTokens(msgType structs.MessageType, index uint6 return txn.Commit() } +// UpsertOneTimeToken is used to create or update a set of ACL +// tokens. Validating that we're not upserting an already-expired token is +// made the responsibility of the caller to facilitate testing. +func (s *StateStore) UpsertOneTimeToken(msgType structs.MessageType, index uint64, token *structs.OneTimeToken) error { + txn := s.db.WriteTxnMsgT(msgType, index) + defer txn.Abort() + + // we expect the RPC call to set the ExpiresAt + if token.ExpiresAt.IsZero() { + return fmt.Errorf("one-time token must have an ExpiresAt time") + } + + // Update all the indexes + token.CreateIndex = index + token.ModifyIndex = index + + // Create the token + if err := txn.Insert("one_time_token", token); err != nil { + return fmt.Errorf("upserting one-time token failed: %v", err) + } + + // Update the indexes table + if err := txn.Insert("index", &IndexEntry{"one_time_token", index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + return txn.Commit() +} + +// DeleteOneTimeTokens deletes the tokens with the given ACLToken Accessor IDs +func (s *StateStore) DeleteOneTimeTokens(msgType structs.MessageType, index uint64, ids []string) error { + txn := s.db.WriteTxnMsgT(msgType, index) + defer txn.Abort() + + var deleted int + for _, id := range ids { + d, err := txn.DeleteAll("one_time_token", "id", id) + if err != nil { + return fmt.Errorf("deleting one-time token failed: %v", err) + } + deleted += d + } + if deleted > 0 { + if err := txn.Insert("index", &IndexEntry{"one_time_token", index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + } + return txn.Commit() +} + +// OneTimeTokenBySecret is used to lookup a token by secret +func (s *StateStore) OneTimeTokenBySecret(ws memdb.WatchSet, secret string) (*structs.OneTimeToken, error) { + if secret == "" { + return nil, fmt.Errorf("one-time token lookup failed: missing secret") + } + + txn := s.db.ReadTxn() + + watchCh, existing, err := txn.FirstWatch("one_time_token", "secret", secret) + if err != nil { + return nil, fmt.Errorf("one-time token lookup failed: %v", err) + } + ws.Add(watchCh) + + if existing != nil { + return existing.(*structs.OneTimeToken), nil + } + return nil, nil +} + +// OneTimeTokensExpired returns an iterator over all expired one-time tokens +func (s *StateStore) OneTimeTokensExpired(ws memdb.WatchSet) (memdb.ResultIterator, error) { + txn := s.db.ReadTxn() + + iter, err := txn.Get("one_time_token", "id") + if err != nil { + return nil, fmt.Errorf("one-time token lookup failed: %v", err) + } + + ws.Add(iter.WatchCh()) + iter = memdb.NewFilterIterator(iter, expiredOneTimeTokenFilter(time.Now())) + return iter, nil +} + +// expiredOneTimeTokenFilter returns a filter function that returns only +// expired one-time tokens +func expiredOneTimeTokenFilter(now time.Time) func(interface{}) bool { + return func(raw interface{}) bool { + ott, ok := raw.(*structs.OneTimeToken) + if !ok { + return true + } + + return ott.ExpiresAt.After(now) + } +} + // SchedulerConfig is used to get the current Scheduler configuration. func (s *StateStore) SchedulerConfig() (uint64, *structs.SchedulerConfiguration, error) { tx := s.db.ReadTxn() @@ -6178,6 +6274,14 @@ func (r *StateRestore) ACLTokenRestore(token *structs.ACLToken) error { return nil } +// OneTimeTokenRestore is used to restore a one-time token +func (r *StateRestore) OneTimeTokenRestore(token *structs.OneTimeToken) error { + if err := r.txn.Insert("one_time_token", token); err != nil { + return fmt.Errorf("inserting one-time token failed: %v", err) + } + return nil +} + func (r *StateRestore) SchedulerConfigRestore(schedConfig *structs.SchedulerConfiguration) error { if err := r.txn.Insert("scheduler_config", schedConfig); err != nil { return fmt.Errorf("inserting scheduler config failed: %s", err) diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index 8f01cac78..2d8039d10 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -8493,6 +8493,114 @@ func TestStateStore_RestoreACLToken(t *testing.T) { assert.Equal(t, token, out) } +func TestStateStore_OneTimeTokens(t *testing.T) { + t.Parallel() + index := uint64(100) + state := testStateStore(t) + + // create some ACL tokens + + token1 := mock.ACLToken() + token2 := mock.ACLToken() + token3 := mock.ACLToken() + index++ + require.Nil(t, state.UpsertACLTokens( + structs.MsgTypeTestSetup, index, + []*structs.ACLToken{token1, token2, token3})) + + otts := []*structs.OneTimeToken{ + { + // expired OTT for token1 + OneTimeSecretID: uuid.Generate(), + AccessorID: token1.AccessorID, + ExpiresAt: time.Now().Add(-1 * time.Minute), + }, + { + // valid OTT for token2 + OneTimeSecretID: uuid.Generate(), + AccessorID: token2.AccessorID, + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + { + // new but expired OTT for token2; this will be accepted even + // though it's expired and overwrite the other one + OneTimeSecretID: uuid.Generate(), + AccessorID: token2.AccessorID, + ExpiresAt: time.Now().Add(-10 * time.Minute), + }, + { + // valid OTT for token3 + AccessorID: token3.AccessorID, + OneTimeSecretID: uuid.Generate(), + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + { + // new valid OTT for token3 + OneTimeSecretID: uuid.Generate(), + AccessorID: token3.AccessorID, + ExpiresAt: time.Now().Add(5 * time.Minute), + }, + } + + for _, ott := range otts { + index++ + require.NoError(t, state.UpsertOneTimeToken(structs.MsgTypeTestSetup, index, ott)) + } + + getExpiredTokens := func() []*structs.OneTimeToken { + // find all the expired tokens + iter, err := state.OneTimeTokensExpired(nil) + require.NoError(t, err) + + results := []*structs.OneTimeToken{} + for { + raw := iter.Next() + if raw == nil { + break + } + ott, ok := raw.(*structs.OneTimeToken) + require.True(t, ok) + results = append(results, ott) + } + return results + } + + results := getExpiredTokens() + require.Len(t, results, 2) + + // results aren't ordered + expiredAccessors := []string{results[0].AccessorID, results[1].AccessorID} + require.Contains(t, expiredAccessors, token1.AccessorID) + require.Contains(t, expiredAccessors, token2.AccessorID) + require.True(t, time.Now().After(results[0].ExpiresAt)) + require.True(t, time.Now().After(results[1].ExpiresAt)) + + // clear the expired tokens and verify they're gone + index++ + require.NoError(t, + state.DeleteOneTimeTokens(structs.MsgTypeTestSetup, index, + []string{results[0].AccessorID, results[1].AccessorID})) + + results = getExpiredTokens() + require.Len(t, results, 0) + + // query the unexpired token + ott, err := state.OneTimeTokenBySecret(nil, otts[len(otts)-1].OneTimeSecretID) + require.NoError(t, err) + require.Equal(t, token3.AccessorID, ott.AccessorID) + require.True(t, time.Now().Before(ott.ExpiresAt)) + + restore, err := state.Restore() + require.NoError(t, err) + err = restore.OneTimeTokenRestore(ott) + require.NoError(t, err) + require.NoError(t, restore.Commit()) + + ott, err = state.OneTimeTokenBySecret(nil, otts[len(otts)-1].OneTimeSecretID) + require.NoError(t, err) + require.Equal(t, token3.AccessorID, ott.AccessorID) +} + func TestStateStore_SchedulerConfig(t *testing.T) { t.Parallel() diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 8e574cee8..efc8d1b35 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -11040,6 +11040,16 @@ type ACLTokenUpsertResponse struct { WriteMeta } +// OneTimeToken is used to log into the web UI using a token provided by the +// command line. +type OneTimeToken struct { + OneTimeSecretID string + AccessorID string + ExpiresAt time.Time + CreateIndex uint64 + ModifyIndex uint64 +} + // RpcError is used for serializing errors with a potential error code type RpcError struct { Message string