mirror of
https://github.com/kemko/nomad.git
synced 2026-01-09 20:05:42 +03:00
state store updates for one-time tokens
The `OneTimeToken` struct is to support the `nomad ui -login` command. This changeset adds the struct to the Nomad state store.
This commit is contained in:
@@ -49,6 +49,7 @@ func init() {
|
||||
siTokenAccessorTableSchema,
|
||||
aclPolicyTableSchema,
|
||||
aclTokenTableSchema,
|
||||
oneTimeTokenTableSchema,
|
||||
autopilotConfigTableSchema,
|
||||
schedulerConfigTableSchema,
|
||||
clusterMetaTableSchema,
|
||||
@@ -651,6 +652,32 @@ func aclTokenTableSchema() *memdb.TableSchema {
|
||||
}
|
||||
}
|
||||
|
||||
// oneTimeTokenTableSchema returns the MemDB schema for the tokens table.
|
||||
// This table is used to store one-time tokens for ACL tokens
|
||||
func oneTimeTokenTableSchema() *memdb.TableSchema {
|
||||
return &memdb.TableSchema{
|
||||
Name: "one_time_token",
|
||||
Indexes: map[string]*memdb.IndexSchema{
|
||||
"secret": {
|
||||
Name: "secret",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.UUIDFieldIndex{
|
||||
Field: "OneTimeSecretID",
|
||||
},
|
||||
},
|
||||
"id": {
|
||||
Name: "id",
|
||||
AllowMissing: false,
|
||||
Unique: true,
|
||||
Indexer: &memdb.UUIDFieldIndex{
|
||||
Field: "AccessorID",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// singletonRecord can be used to describe tables which should contain only 1 entry.
|
||||
// Example uses include storing node config or cluster metadata blobs.
|
||||
var singletonRecord = &memdb.ConditionalIndex{
|
||||
|
||||
@@ -5333,6 +5333,102 @@ func (s *StateStore) BootstrapACLTokens(msgType structs.MessageType, index uint6
|
||||
return txn.Commit()
|
||||
}
|
||||
|
||||
// UpsertOneTimeToken is used to create or update a set of ACL
|
||||
// tokens. Validating that we're not upserting an already-expired token is
|
||||
// made the responsibility of the caller to facilitate testing.
|
||||
func (s *StateStore) UpsertOneTimeToken(msgType structs.MessageType, index uint64, token *structs.OneTimeToken) error {
|
||||
txn := s.db.WriteTxnMsgT(msgType, index)
|
||||
defer txn.Abort()
|
||||
|
||||
// we expect the RPC call to set the ExpiresAt
|
||||
if token.ExpiresAt.IsZero() {
|
||||
return fmt.Errorf("one-time token must have an ExpiresAt time")
|
||||
}
|
||||
|
||||
// Update all the indexes
|
||||
token.CreateIndex = index
|
||||
token.ModifyIndex = index
|
||||
|
||||
// Create the token
|
||||
if err := txn.Insert("one_time_token", token); err != nil {
|
||||
return fmt.Errorf("upserting one-time token failed: %v", err)
|
||||
}
|
||||
|
||||
// Update the indexes table
|
||||
if err := txn.Insert("index", &IndexEntry{"one_time_token", index}); err != nil {
|
||||
return fmt.Errorf("index update failed: %v", err)
|
||||
}
|
||||
return txn.Commit()
|
||||
}
|
||||
|
||||
// DeleteOneTimeTokens deletes the tokens with the given ACLToken Accessor IDs
|
||||
func (s *StateStore) DeleteOneTimeTokens(msgType structs.MessageType, index uint64, ids []string) error {
|
||||
txn := s.db.WriteTxnMsgT(msgType, index)
|
||||
defer txn.Abort()
|
||||
|
||||
var deleted int
|
||||
for _, id := range ids {
|
||||
d, err := txn.DeleteAll("one_time_token", "id", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting one-time token failed: %v", err)
|
||||
}
|
||||
deleted += d
|
||||
}
|
||||
if deleted > 0 {
|
||||
if err := txn.Insert("index", &IndexEntry{"one_time_token", index}); err != nil {
|
||||
return fmt.Errorf("index update failed: %v", err)
|
||||
}
|
||||
}
|
||||
return txn.Commit()
|
||||
}
|
||||
|
||||
// OneTimeTokenBySecret is used to lookup a token by secret
|
||||
func (s *StateStore) OneTimeTokenBySecret(ws memdb.WatchSet, secret string) (*structs.OneTimeToken, error) {
|
||||
if secret == "" {
|
||||
return nil, fmt.Errorf("one-time token lookup failed: missing secret")
|
||||
}
|
||||
|
||||
txn := s.db.ReadTxn()
|
||||
|
||||
watchCh, existing, err := txn.FirstWatch("one_time_token", "secret", secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("one-time token lookup failed: %v", err)
|
||||
}
|
||||
ws.Add(watchCh)
|
||||
|
||||
if existing != nil {
|
||||
return existing.(*structs.OneTimeToken), nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// OneTimeTokensExpired returns an iterator over all expired one-time tokens
|
||||
func (s *StateStore) OneTimeTokensExpired(ws memdb.WatchSet) (memdb.ResultIterator, error) {
|
||||
txn := s.db.ReadTxn()
|
||||
|
||||
iter, err := txn.Get("one_time_token", "id")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("one-time token lookup failed: %v", err)
|
||||
}
|
||||
|
||||
ws.Add(iter.WatchCh())
|
||||
iter = memdb.NewFilterIterator(iter, expiredOneTimeTokenFilter(time.Now()))
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
// expiredOneTimeTokenFilter returns a filter function that returns only
|
||||
// expired one-time tokens
|
||||
func expiredOneTimeTokenFilter(now time.Time) func(interface{}) bool {
|
||||
return func(raw interface{}) bool {
|
||||
ott, ok := raw.(*structs.OneTimeToken)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return ott.ExpiresAt.After(now)
|
||||
}
|
||||
}
|
||||
|
||||
// SchedulerConfig is used to get the current Scheduler configuration.
|
||||
func (s *StateStore) SchedulerConfig() (uint64, *structs.SchedulerConfiguration, error) {
|
||||
tx := s.db.ReadTxn()
|
||||
@@ -6178,6 +6274,14 @@ func (r *StateRestore) ACLTokenRestore(token *structs.ACLToken) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OneTimeTokenRestore is used to restore a one-time token
|
||||
func (r *StateRestore) OneTimeTokenRestore(token *structs.OneTimeToken) error {
|
||||
if err := r.txn.Insert("one_time_token", token); err != nil {
|
||||
return fmt.Errorf("inserting one-time token failed: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StateRestore) SchedulerConfigRestore(schedConfig *structs.SchedulerConfiguration) error {
|
||||
if err := r.txn.Insert("scheduler_config", schedConfig); err != nil {
|
||||
return fmt.Errorf("inserting scheduler config failed: %s", err)
|
||||
|
||||
@@ -8493,6 +8493,114 @@ func TestStateStore_RestoreACLToken(t *testing.T) {
|
||||
assert.Equal(t, token, out)
|
||||
}
|
||||
|
||||
func TestStateStore_OneTimeTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
index := uint64(100)
|
||||
state := testStateStore(t)
|
||||
|
||||
// create some ACL tokens
|
||||
|
||||
token1 := mock.ACLToken()
|
||||
token2 := mock.ACLToken()
|
||||
token3 := mock.ACLToken()
|
||||
index++
|
||||
require.Nil(t, state.UpsertACLTokens(
|
||||
structs.MsgTypeTestSetup, index,
|
||||
[]*structs.ACLToken{token1, token2, token3}))
|
||||
|
||||
otts := []*structs.OneTimeToken{
|
||||
{
|
||||
// expired OTT for token1
|
||||
OneTimeSecretID: uuid.Generate(),
|
||||
AccessorID: token1.AccessorID,
|
||||
ExpiresAt: time.Now().Add(-1 * time.Minute),
|
||||
},
|
||||
{
|
||||
// valid OTT for token2
|
||||
OneTimeSecretID: uuid.Generate(),
|
||||
AccessorID: token2.AccessorID,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
},
|
||||
{
|
||||
// new but expired OTT for token2; this will be accepted even
|
||||
// though it's expired and overwrite the other one
|
||||
OneTimeSecretID: uuid.Generate(),
|
||||
AccessorID: token2.AccessorID,
|
||||
ExpiresAt: time.Now().Add(-10 * time.Minute),
|
||||
},
|
||||
{
|
||||
// valid OTT for token3
|
||||
AccessorID: token3.AccessorID,
|
||||
OneTimeSecretID: uuid.Generate(),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
},
|
||||
{
|
||||
// new valid OTT for token3
|
||||
OneTimeSecretID: uuid.Generate(),
|
||||
AccessorID: token3.AccessorID,
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
for _, ott := range otts {
|
||||
index++
|
||||
require.NoError(t, state.UpsertOneTimeToken(structs.MsgTypeTestSetup, index, ott))
|
||||
}
|
||||
|
||||
getExpiredTokens := func() []*structs.OneTimeToken {
|
||||
// find all the expired tokens
|
||||
iter, err := state.OneTimeTokensExpired(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
results := []*structs.OneTimeToken{}
|
||||
for {
|
||||
raw := iter.Next()
|
||||
if raw == nil {
|
||||
break
|
||||
}
|
||||
ott, ok := raw.(*structs.OneTimeToken)
|
||||
require.True(t, ok)
|
||||
results = append(results, ott)
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
results := getExpiredTokens()
|
||||
require.Len(t, results, 2)
|
||||
|
||||
// results aren't ordered
|
||||
expiredAccessors := []string{results[0].AccessorID, results[1].AccessorID}
|
||||
require.Contains(t, expiredAccessors, token1.AccessorID)
|
||||
require.Contains(t, expiredAccessors, token2.AccessorID)
|
||||
require.True(t, time.Now().After(results[0].ExpiresAt))
|
||||
require.True(t, time.Now().After(results[1].ExpiresAt))
|
||||
|
||||
// clear the expired tokens and verify they're gone
|
||||
index++
|
||||
require.NoError(t,
|
||||
state.DeleteOneTimeTokens(structs.MsgTypeTestSetup, index,
|
||||
[]string{results[0].AccessorID, results[1].AccessorID}))
|
||||
|
||||
results = getExpiredTokens()
|
||||
require.Len(t, results, 0)
|
||||
|
||||
// query the unexpired token
|
||||
ott, err := state.OneTimeTokenBySecret(nil, otts[len(otts)-1].OneTimeSecretID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, token3.AccessorID, ott.AccessorID)
|
||||
require.True(t, time.Now().Before(ott.ExpiresAt))
|
||||
|
||||
restore, err := state.Restore()
|
||||
require.NoError(t, err)
|
||||
err = restore.OneTimeTokenRestore(ott)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, restore.Commit())
|
||||
|
||||
ott, err = state.OneTimeTokenBySecret(nil, otts[len(otts)-1].OneTimeSecretID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, token3.AccessorID, ott.AccessorID)
|
||||
}
|
||||
|
||||
func TestStateStore_SchedulerConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -11040,6 +11040,16 @@ type ACLTokenUpsertResponse struct {
|
||||
WriteMeta
|
||||
}
|
||||
|
||||
// OneTimeToken is used to log into the web UI using a token provided by the
|
||||
// command line.
|
||||
type OneTimeToken struct {
|
||||
OneTimeSecretID string
|
||||
AccessorID string
|
||||
ExpiresAt time.Time
|
||||
CreateIndex uint64
|
||||
ModifyIndex uint64
|
||||
}
|
||||
|
||||
// RpcError is used for serializing errors with a potential error code
|
||||
type RpcError struct {
|
||||
Message string
|
||||
|
||||
Reference in New Issue
Block a user