diff --git a/nomad/fsm.go b/nomad/fsm.go index 4b757b89a..754139a3a 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -35,6 +35,7 @@ const ( TimeTableSnapshot PeriodicLaunchSnapshot JobSummarySnapshot + VaultAccessorSnapshot ) // nomadFSM implements a finite state machine that is used @@ -137,6 +138,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} { return n.applyAllocClientUpdate(buf[1:], log.Index) case structs.ReconcileJobSummariesRequestType: return n.applyReconcileSummaries(buf[1:], log.Index) + case structs.VaultAccessorRegisterRequestType: + return n.applyUpsertVaultAccessor(buf[1:], log.Index) default: if ignoreUnknown { n.logger.Printf("[WARN] nomad.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType) @@ -454,6 +457,22 @@ func (n *nomadFSM) applyReconcileSummaries(buf []byte, index uint64) interface{} return n.reconcileQueuedAllocations(index) } +// applyUpsertVaultAccessor stores the Vault accessors for a given +func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{} { + defer metrics.MeasureSince([]string{"nomad", "fsm", "upsert_vault_accessor"}, time.Now()) + var req structs.VaultAccessorRegisterRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) + } + + if err := n.state.UpsertVaultAccessor(index, req.Accessors); err != nil { + n.logger.Printf("[ERR] nomad.fsm: UpsertVaultAccessor failed: %v", err) + return err + } + + return nil +} + func (n *nomadFSM) Snapshot() (raft.FSMSnapshot, error) { // Create a new snapshot snap, err := n.state.Snapshot() @@ -583,6 +602,15 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error { return err } + case VaultAccessorSnapshot: + accessor := new(structs.VaultAccessor) + if err := dec.Decode(accessor); err != nil { + return err + } + if err := restore.VaultAccessorRestore(accessor); err != nil { + return err + } + default: return fmt.Errorf("Unrecognized snapshot type: %v", msgType) } @@ -756,6 +784,10 @@ func (s *nomadSnapshot) Persist(sink raft.SnapshotSink) error { sink.Cancel() return err } + if err := s.persistVaultAccessors(sink, encoder); err != nil { + sink.Cancel() + return err + } return nil } @@ -945,6 +977,30 @@ func (s *nomadSnapshot) persistJobSummaries(sink raft.SnapshotSink, return nil } +func (s *nomadSnapshot) persistVaultAccessors(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + + accessors, err := s.snap.VaultAccessors() + if err != nil { + return err + } + + for { + raw := accessors.Next() + if raw == nil { + break + } + + accessor := raw.(*structs.VaultAccessor) + + sink.Write([]byte{byte(VaultAccessorSnapshot)}) + if err := encoder.Encode(accessor); err != nil { + return err + } + } + return nil +} + // Release is a no-op, as we just need to GC the pointer // to the state store snapshot. There is nothing to explicitly // cleanup. diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index 2274ea27d..fe6f58813 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -976,6 +976,27 @@ func TestFSM_SnapshotRestore_JobSummary(t *testing.T) { } } +func TestFSM_SnapshotRestore_VaultAccessors(t *testing.T) { + // Add some state + fsm := testFSM(t) + state := fsm.State() + a1 := mock.VaultAccessor() + a2 := mock.VaultAccessor() + state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a1, a2}) + + // Verify the contents + fsm2 := testSnapshotRestore(t, fsm) + state2 := fsm2.State() + out1, _ := state2.VaultAccessor(a1.Accessor) + out2, _ := state2.VaultAccessor(a2.Accessor) + if !reflect.DeepEqual(a1, out1) { + t.Fatalf("bad: \n%#v\n%#v", out1, a1) + } + if !reflect.DeepEqual(a2, out2) { + t.Fatalf("bad: \n%#v\n%#v", out2, a2) + } +} + func TestFSM_SnapshotRestore_AddMissingSummary(t *testing.T) { // Add some state fsm := testFSM(t) diff --git a/nomad/mock/mock.go b/nomad/mock/mock.go index 1cdec8d18..937d09783 100644 --- a/nomad/mock/mock.go +++ b/nomad/mock/mock.go @@ -290,6 +290,16 @@ func Alloc() *structs.Allocation { return alloc } +func VaultAccessor() *structs.VaultAccessor { + return &structs.VaultAccessor{ + Accessor: structs.GenerateUUID(), + NodeID: structs.GenerateUUID(), + AllocID: structs.GenerateUUID(), + CreationTTL: 86400, + Task: "foo", + } +} + func Plan() *structs.Plan { return &structs.Plan{ Priority: 50, diff --git a/nomad/state/schema.go b/nomad/state/schema.go index b88cf43cf..05ffd78cf 100644 --- a/nomad/state/schema.go +++ b/nomad/state/schema.go @@ -23,6 +23,7 @@ func stateStoreSchema() *memdb.DBSchema { periodicLaunchTableSchema, evalTableSchema, allocTableSchema, + vaultAccessorTableSchema, } // Add each of the tables @@ -291,3 +292,41 @@ func allocTableSchema() *memdb.TableSchema { }, } } + +// vaultAccessorTableSchema returns the MemDB schema for the Vault Accessor +// Table. This table tracks Vault accessors for tokens created on behalf of +// allocations required Vault tokens. +func vaultAccessorTableSchema() *memdb.TableSchema { + return &memdb.TableSchema{ + Name: "vault_accessors", + Indexes: map[string]*memdb.IndexSchema{ + // The primary index is the accessor id + "id": &memdb.IndexSchema{ + Name: "id", + AllowMissing: false, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Accessor", + }, + }, + + "alloc_id": &memdb.IndexSchema{ + Name: "alloc_id", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "AllocID", + }, + }, + + "node_id": &memdb.IndexSchema{ + Name: "node_id", + AllowMissing: false, + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "NodeID", + }, + }, + }, + } +} diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index dfe5945ff..cee5935f3 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -1113,6 +1113,124 @@ func (s *StateStore) Allocs() (memdb.ResultIterator, error) { return iter, nil } +// UpsertVaultAccessors is used to register a set of Vault Accessors +func (s *StateStore) UpsertVaultAccessor(index uint64, accessors []*structs.VaultAccessor) error { + txn := s.db.Txn(true) + defer txn.Abort() + + for _, accessor := range accessors { + // Set the create index + accessor.CreateIndex = index + + // Insert the accessor + if err := txn.Insert("vault_accessors", accessor); err != nil { + return fmt.Errorf("accessor insert failed: %v", err) + } + } + + if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + + txn.Commit() + return nil +} + +// DeleteVaultAccessor is used to delete a Vault Accessor +func (s *StateStore) DeleteVaultAccessor(index uint64, accessor string) error { + txn := s.db.Txn(true) + defer txn.Abort() + + // Lookup the accessor + existing, err := txn.First("vault_accessors", "id", accessor) + if err != nil { + return fmt.Errorf("accessor lookup failed: %v", err) + } + if existing == nil { + return fmt.Errorf("vault_accessor not found") + } + + // Delete the accessor + if err := txn.Delete("vault_accessors", existing); err != nil { + return fmt.Errorf("accessor delete failed: %v", err) + } + if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil { + return fmt.Errorf("index update failed: %v", err) + } + + txn.Commit() + return nil +} + +// VaultAccessor returns the given Vault accessor +func (s *StateStore) VaultAccessor(accessor string) (*structs.VaultAccessor, error) { + txn := s.db.Txn(false) + + existing, err := txn.First("vault_accessors", "id", accessor) + if err != nil { + return nil, fmt.Errorf("accessor lookup failed: %v", err) + } + + if existing != nil { + return existing.(*structs.VaultAccessor), nil + } + + return nil, nil +} + +// VaultAccessors returns an iterator of Vault accessors. +func (s *StateStore) VaultAccessors() (memdb.ResultIterator, error) { + txn := s.db.Txn(false) + + iter, err := txn.Get("vault_accessors", "id") + if err != nil { + return nil, err + } + return iter, nil +} + +// VaultAccessorsByAlloc returns all the Vault accessors by alloc id +func (s *StateStore) VaultAccessorsByAlloc(allocID string) ([]*structs.VaultAccessor, error) { + txn := s.db.Txn(false) + + // Get an iterator over the accessors + iter, err := txn.Get("vault_accessors", "alloc_id", allocID) + if err != nil { + return nil, err + } + + var out []*structs.VaultAccessor + for { + raw := iter.Next() + if raw == nil { + break + } + out = append(out, raw.(*structs.VaultAccessor)) + } + return out, nil +} + +// VaultAccessorsByNode returns all the Vault accessors by node id +func (s *StateStore) VaultAccessorsByNode(nodeID string) ([]*structs.VaultAccessor, error) { + txn := s.db.Txn(false) + + // Get an iterator over the accessors + iter, err := txn.Get("vault_accessors", "node_id", nodeID) + if err != nil { + return nil, err + } + + var out []*structs.VaultAccessor + for { + raw := iter.Next() + if raw == nil { + break + } + out = append(out, raw.(*structs.VaultAccessor)) + } + return out, nil +} + // LastIndex returns the greatest index value for all indexes func (s *StateStore) LatestIndex() (uint64, error) { indexes, err := s.Indexes() @@ -1627,6 +1745,14 @@ func (r *StateRestore) JobSummaryRestore(jobSummary *structs.JobSummary) error { return nil } +// VaultAccessorRestore is used to restore a vault accessor +func (r *StateRestore) VaultAccessorRestore(accessor *structs.VaultAccessor) error { + if err := r.txn.Insert("vault_accessors", accessor); err != nil { + return fmt.Errorf("vault accessor insert failed: %v", err) + } + return nil +} + // stateWatch holds shared state for watching updates. This is // outside of StateStore so it can be shared with snapshots. type stateWatch struct { diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index 4ee119e68..f272b9fb1 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -2833,6 +2833,206 @@ func TestJobSummary_UpdateClientStatus(t *testing.T) { } } +func TestStateStore_UpsertVaultAccessors(t *testing.T) { + state := testStateStore(t) + a := mock.VaultAccessor() + a2 := mock.VaultAccessor() + + err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a, a2}) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessor(a.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(a, out) { + t.Fatalf("bad: %#v %#v", a, out) + } + + out, err = state.VaultAccessor(a2.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(a2, out) { + t.Fatalf("bad: %#v %#v", a2, out) + } + + iter, err := state.VaultAccessors() + if err != nil { + t.Fatalf("err: %v", err) + } + + count := 0 + for { + raw := iter.Next() + if raw == nil { + break + } + + count++ + accessor := raw.(*structs.VaultAccessor) + + if !reflect.DeepEqual(accessor, a) && !reflect.DeepEqual(accessor, a2) { + t.Fatalf("bad: %#v", accessor) + } + } + + if count != 2 { + t.Fatalf("bad: %d", count) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1000 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_DeleteVaultAccessor(t *testing.T) { + state := testStateStore(t) + accessor := mock.VaultAccessor() + + err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{accessor}) + if err != nil { + t.Fatalf("err: %v", err) + } + + err = state.DeleteVaultAccessor(1001, accessor.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessor(accessor.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if out != nil { + t.Fatalf("bad: %#v %#v", accessor, out) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1001 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_VaultAccessorsByAlloc(t *testing.T) { + state := testStateStore(t) + alloc := mock.Alloc() + var accessors []*structs.VaultAccessor + var expected []*structs.VaultAccessor + + for i := 0; i < 5; i++ { + accessor := mock.VaultAccessor() + accessor.AllocID = alloc.ID + expected = append(expected, accessor) + accessors = append(accessors, accessor) + } + + for i := 0; i < 10; i++ { + accessor := mock.VaultAccessor() + accessors = append(accessors, accessor) + } + + err := state.UpsertVaultAccessor(1000, accessors) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessorsByAlloc(alloc.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(expected) != len(out) { + t.Fatalf("bad: %#v %#v", len(expected), len(out)) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1000 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_VaultAccessorsByNode(t *testing.T) { + state := testStateStore(t) + node := mock.Node() + var accessors []*structs.VaultAccessor + var expected []*structs.VaultAccessor + + for i := 0; i < 5; i++ { + accessor := mock.VaultAccessor() + accessor.NodeID = node.ID + expected = append(expected, accessor) + accessors = append(accessors, accessor) + } + + for i := 0; i < 10; i++ { + accessor := mock.VaultAccessor() + accessors = append(accessors, accessor) + } + + err := state.UpsertVaultAccessor(1000, accessors) + if err != nil { + t.Fatalf("err: %v", err) + } + + out, err := state.VaultAccessorsByNode(node.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(expected) != len(out) { + t.Fatalf("bad: %#v %#v", len(expected), len(out)) + } + + index, err := state.Index("vault_accessors") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1000 { + t.Fatalf("bad: %d", index) + } +} + +func TestStateStore_RestoreVaultAccessor(t *testing.T) { + state := testStateStore(t) + a := mock.VaultAccessor() + + restore, err := state.Restore() + if err != nil { + t.Fatalf("err: %v", err) + } + + err = restore.VaultAccessorRestore(a) + if err != nil { + t.Fatalf("err: %v", err) + } + restore.Commit() + + out, err := state.VaultAccessor(a.Accessor) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !reflect.DeepEqual(out, a) { + t.Fatalf("Bad: %#v %#v", out, a) + } +} + // setupNotifyTest takes a state store and a set of watch items, then creates // and subscribes a notification channel for each item. func setupNotifyTest(state *StateStore, items ...watch.Item) notifyTest { diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 3071beb79..7f1079372 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -47,6 +47,7 @@ const ( AllocUpdateRequestType AllocClientUpdateRequestType ReconcileJobSummariesRequestType + VaultAccessorRegisterRequestType ) const ( @@ -364,6 +365,24 @@ type DeriveVaultTokenRequest struct { QueryOptions } +// VaultAccessorRegisterRequest is used to register a set of Vault accessors +type VaultAccessorRegisterRequest struct { + Accessors []*VaultAccessor +} + +// VaultAccessor is a reference to a created Vault token on behalf of +// an allocation's task. +type VaultAccessor struct { + AllocID string + Task string + NodeID string + Accessor string + CreationTTL int64 + + // Raft Indexes + CreateIndex uint64 +} + // DeriveVaultTokenResponse returns the wrapped tokens for each requested task type DeriveVaultTokenResponse struct { Tasks map[string]string