diff --git a/nomad/fsm.go b/nomad/fsm.go index 752270321..5247223bd 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -19,6 +19,15 @@ var ( } ) +// SnapshotType is prefixed to a record in the FSM snapshot +// so that we can determine the type for restore +type SnapshotType byte + +const ( + NodeSnapshot SnapshotType = iota + IndexSnapshot +) + // nomadFSM implements a finite state machine that is used // along with Raft to provide strong consistency. We implement // this outside the Server to avoid exposing this outside the package. @@ -187,18 +196,27 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error { } // Decode - switch structs.MessageType(msgType[0]) { - case structs.RegisterRequestType: - var req structs.RegisterRequest - if err := dec.Decode(&req); err != nil { + switch SnapshotType(msgType[0]) { + case NodeSnapshot: + node := new(structs.Node) + if err := dec.Decode(node); err != nil { return err } - if err := restore.NodeRestore(req.Node); err != nil { + if err := restore.NodeRestore(node); err != nil { + return err + } + + case IndexSnapshot: + idx := new(IndexEntry) + if err := dec.Decode(idx); err != nil { + return err + } + if err := restore.IndexRestore(idx); err != nil { return err } default: - return fmt.Errorf("Unrecognized msg type: %v", msgType) + return fmt.Errorf("Unrecognized snapshot type: %v", msgType) } } @@ -220,6 +238,10 @@ func (s *nomadSnapshot) Persist(sink raft.SnapshotSink) error { } // Write all the data out + if err := s.persistIndexes(sink, encoder); err != nil { + sink.Cancel() + return err + } if err := s.persistNodes(sink, encoder); err != nil { sink.Cancel() return err @@ -227,6 +249,33 @@ func (s *nomadSnapshot) Persist(sink raft.SnapshotSink) error { return nil } +func (s *nomadSnapshot) persistIndexes(sink raft.SnapshotSink, + encoder *codec.Encoder) error { + // Get all the indexes + iter, err := s.snap.Indexes() + if err != nil { + return err + } + + for { + // Get the next item + raw := iter.Next() + if raw == nil { + break + } + + // Prepare the request struct + idx := raw.(*IndexEntry) + + // Write out a node registration + sink.Write([]byte{byte(IndexSnapshot)}) + if err := encoder.Encode(idx); err != nil { + return err + } + } + return nil +} + func (s *nomadSnapshot) persistNodes(sink raft.SnapshotSink, encoder *codec.Encoder) error { // Get all the nodes @@ -235,7 +284,6 @@ func (s *nomadSnapshot) persistNodes(sink raft.SnapshotSink, return err } - var req structs.RegisterRequest for { // Get the next item raw := nodes.Next() @@ -245,11 +293,10 @@ func (s *nomadSnapshot) persistNodes(sink raft.SnapshotSink, // Prepare the request struct node := raw.(*structs.Node) - req = structs.RegisterRequest{Node: node} // Write out a node registration - sink.Write([]byte{byte(structs.RegisterRequestType)}) - if err := encoder.Encode(&req); err != nil { + sink.Write([]byte{byte(NodeSnapshot)}) + if err := encoder.Encode(node); err != nil { return err } } diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index 2d8e70396..82db0fcc8 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -204,3 +204,23 @@ func TestFSM_SnapshotRestore_Nodes(t *testing.T) { t.Fatalf("bad: \n%#v\n%#v", out2, node2) } } + +func TestFSM_SnapshotRestore_Indexes(t *testing.T) { + // Add some state + fsm := testFSM(t) + state := fsm.State() + node1 := mockNode() + state.RegisterNode(1000, node1) + + // Verify the contents + fsm2 := testSnapshotRestore(t, fsm) + state2 := fsm2.State() + + index, err := state2.GetIndex("nodes") + if err != nil { + t.Fatalf("err: %v", err) + } + if index != 1000 { + t.Fatalf("bad: %d", index) + } +} diff --git a/nomad/state_store.go b/nomad/state_store.go index be7a29c7d..425f7d30d 100644 --- a/nomad/state_store.go +++ b/nomad/state_store.go @@ -245,3 +245,10 @@ func (r *StateRestore) NodeRestore(node *structs.Node) error { } return nil } + +func (r *StateRestore) IndexRestore(idx *IndexEntry) error { + if err := r.txn.Insert("index", idx); err != nil { + return fmt.Errorf("index insert failed: %v", err) + } + return nil +}