diff --git a/client/client_test.go b/client/client_test.go index 0455ca3fa..975517245 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/nomad" @@ -477,7 +478,8 @@ func TestClient_UpdateAllocStatus(t *testing.T) { state.UpsertAllocs(101, []*structs.Allocation{alloc}) testutil.WaitForResult(func() (bool, error) { - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { return false, err } @@ -724,7 +726,8 @@ func TestClient_BlockedAllocations(t *testing.T) { // Wait for the node to be ready state := s1.State() testutil.WaitForResult(func() (bool, error) { - out, err := state.NodeByID(c1.Node().ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, c1.Node().ID) if err != nil { return false, err } @@ -753,7 +756,8 @@ func TestClient_BlockedAllocations(t *testing.T) { // Wait until the client downloads and starts the allocation testutil.WaitForResult(func() (bool, error) { - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { return false, err } diff --git a/nomad/alloc_endpoint.go b/nomad/alloc_endpoint.go index 9826a1547..47514b023 100644 --- a/nomad/alloc_endpoint.go +++ b/nomad/alloc_endpoint.go @@ -5,8 +5,8 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/nomad/watch" ) // Alloc endpoint is used for manipulating allocations @@ -25,18 +25,14 @@ func (a *Alloc) List(args *structs.AllocListRequest, reply *structs.AllocListRes opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Table: "allocs"}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Capture all the allocations - snap, err := a.srv.fsm.State().Snapshot() - if err != nil { - return err - } + var err error var iter memdb.ResultIterator if prefix := args.QueryOptions.Prefix; prefix != "" { - iter, err = snap.AllocsByIDPrefix(prefix) + iter, err = state.AllocsByIDPrefix(ws, prefix) } else { - iter, err = snap.Allocs() + iter, err = state.Allocs(ws) } if err != nil { return err @@ -54,7 +50,7 @@ func (a *Alloc) List(args *structs.AllocListRequest, reply *structs.AllocListRes reply.Allocations = allocs // Use the last index that affected the jobs table - index, err := snap.Index("allocs") + index, err := state.Index("allocs") if err != nil { return err } @@ -79,14 +75,9 @@ func (a *Alloc) GetAlloc(args *structs.AllocSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Alloc: args.AllocID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Lookup the allocation - snap, err := a.srv.fsm.State().Snapshot() - if err != nil { - return err - } - out, err := snap.AllocByID(args.AllocID) + out, err := state.AllocByID(ws, args.AllocID) if err != nil { return err } @@ -97,7 +88,7 @@ func (a *Alloc) GetAlloc(args *structs.AllocSpecificRequest, reply.Index = out.ModifyIndex } else { // Use the last index that affected the allocs table - index, err := snap.Index("allocs") + index, err := state.Index("allocs") if err != 
nil { return err } @@ -119,12 +110,6 @@ func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest, } defer metrics.MeasureSince([]string{"nomad", "alloc", "get_allocs"}, time.Now()) - // Build the watch - items := make([]watch.Item, 0, len(args.AllocIDs)) - for _, allocID := range args.AllocIDs { - items = append(items, watch.Item{Alloc: allocID}) - } - allocs := make([]*structs.Allocation, len(args.AllocIDs)) // Setup the blocking query. We wait for at least one of the requested @@ -133,18 +118,12 @@ func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(items...), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Lookup the allocation - snap, err := a.srv.fsm.State().Snapshot() - if err != nil { - return err - } - thresholdMet := false maxIndex := uint64(0) for i, alloc := range args.AllocIDs { - out, err := snap.AllocByID(alloc) + out, err := state.AllocByID(ws, alloc) if err != nil { return err } @@ -173,7 +152,7 @@ func (a *Alloc) GetAllocs(args *structs.AllocsGetRequest, reply.Index = maxIndex } else { // Use the last index that affected the nodes table - index, err := snap.Index("allocs") + index, err := state.Index("allocs") if err != nil { return err } diff --git a/nomad/alloc_endpoint_test.go b/nomad/alloc_endpoint_test.go index 854a66a6e..8a3909d3c 100644 --- a/nomad/alloc_endpoint_test.go +++ b/nomad/alloc_endpoint_test.go @@ -197,7 +197,7 @@ func TestAllocEndpoint_GetAlloc_Blocking(t *testing.T) { // Create the alloc we are watching later time.AfterFunc(200*time.Millisecond, func() { - state.UpsertJobSummary(999, mock.JobSummary(alloc2.JobID)) + state.UpsertJobSummary(199, mock.JobSummary(alloc2.JobID)) err := state.UpsertAllocs(200, []*structs.Allocation{alloc2}) if err != nil { t.Fatalf("err: %v", err) @@ -209,7 +209,7 @@ func TestAllocEndpoint_GetAlloc_Blocking(t *testing.T) { AllocID: alloc2.ID, QueryOptions: structs.QueryOptions{ Region: "global", - MinQueryIndex: 50, + MinQueryIndex: 150, }, } var resp structs.SingleAllocResponse diff --git a/nomad/core_sched.go b/nomad/core_sched.go index 8cb170045..e5dbea5a7 100644 --- a/nomad/core_sched.go +++ b/nomad/core_sched.go @@ -5,6 +5,7 @@ import ( "math" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/scheduler" @@ -67,7 +68,8 @@ func (c *CoreScheduler) forceGC(eval *structs.Evaluation) error { // jobGC is used to garbage collect eligible jobs. func (c *CoreScheduler) jobGC(eval *structs.Evaluation) error { // Get all the jobs eligible for garbage collection. 
- iter, err := c.snap.JobsByGC(true) + ws := memdb.NewWatchSet() + iter, err := c.snap.JobsByGC(ws, true) if err != nil { return err } @@ -99,7 +101,8 @@ OUTER: continue } - evals, err := c.snap.EvalsByJob(job.ID) + ws := memdb.NewWatchSet() + evals, err := c.snap.EvalsByJob(ws, job.ID) if err != nil { c.srv.logger.Printf("[ERR] sched.core: failed to get evals for job %s: %v", job.ID, err) continue @@ -163,7 +166,8 @@ OUTER: // evalGC is used to garbage collect old evaluations func (c *CoreScheduler) evalGC(eval *structs.Evaluation) error { // Iterate over the evaluations - iter, err := c.snap.Evals() + ws := memdb.NewWatchSet() + iter, err := c.snap.Evals(ws) if err != nil { return err } @@ -227,6 +231,9 @@ func (c *CoreScheduler) gcEval(eval *structs.Evaluation, thresholdIndex uint64, return false, nil, nil } + // Create a watchset + ws := memdb.NewWatchSet() + // If the eval is from a running "batch" job we don't want to garbage // collect its allocations. If there is a long running batch job and its // terminal allocations get GC'd the scheduler would re-run the @@ -237,7 +244,7 @@ func (c *CoreScheduler) gcEval(eval *structs.Evaluation, thresholdIndex uint64, } // Check if the job is running - job, err := c.snap.JobByID(eval.JobID) + job, err := c.snap.JobByID(ws, eval.JobID) if err != nil { return false, nil, err } @@ -249,7 +256,7 @@ func (c *CoreScheduler) gcEval(eval *structs.Evaluation, thresholdIndex uint64, } // Get the allocations by eval - allocs, err := c.snap.AllocsByEval(eval.ID) + allocs, err := c.snap.AllocsByEval(ws, eval.ID) if err != nil { c.srv.logger.Printf("[ERR] sched.core: failed to get allocs for eval %s: %v", eval.ID, err) @@ -336,7 +343,8 @@ func (c *CoreScheduler) partitionReap(evals, allocs []string) []*structs.EvalDel // nodeGC is used to garbage collect old nodes func (c *CoreScheduler) nodeGC(eval *structs.Evaluation) error { // Iterate over the evaluations - iter, err := c.snap.Nodes() + ws := memdb.NewWatchSet() + iter, err := c.snap.Nodes(ws) if err != nil { return err } @@ -374,7 +382,8 @@ OUTER: } // Get the allocations by node - allocs, err := c.snap.AllocsByNode(node.ID) + ws := memdb.NewWatchSet() + allocs, err := c.snap.AllocsByNode(ws, node.ID) if err != nil { c.srv.logger.Printf("[ERR] sched.core: failed to get allocs for node %s: %v", eval.ID, err) diff --git a/nomad/core_sched_test.go b/nomad/core_sched_test.go index 72bd4bf66..3f1c6d247 100644 --- a/nomad/core_sched_test.go +++ b/nomad/core_sched_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" @@ -63,7 +64,8 @@ func TestCoreScheduler_EvalGC(t *testing.T) { } // Should be gone - out, err := state.EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -71,7 +73,7 @@ func TestCoreScheduler_EvalGC(t *testing.T) { t.Fatalf("bad: %v", out) } - outA, err := state.AllocByID(alloc.ID) + outA, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -79,7 +81,7 @@ func TestCoreScheduler_EvalGC(t *testing.T) { t.Fatalf("bad: %v", outA) } - outA2, err := state.AllocByID(alloc2.ID) + outA2, err := state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -154,7 +156,8 @@ func TestCoreScheduler_EvalGC_Batch(t *testing.T) { } // Nothing should be gone - out, err := state.EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, 
err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -162,7 +165,7 @@ func TestCoreScheduler_EvalGC_Batch(t *testing.T) { t.Fatalf("bad: %v", out) } - outA, err := state.AllocByID(alloc.ID) + outA, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -170,7 +173,7 @@ func TestCoreScheduler_EvalGC_Batch(t *testing.T) { t.Fatalf("bad: %v", outA) } - outA2, err := state.AllocByID(alloc2.ID) + outA2, err := state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -178,7 +181,7 @@ func TestCoreScheduler_EvalGC_Batch(t *testing.T) { t.Fatalf("bad: %v", outA2) } - outB, err := state.JobByID(job.ID) + outB, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -251,7 +254,8 @@ func TestCoreScheduler_EvalGC_Partial(t *testing.T) { } // Should not be gone - out, err := state.EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -259,7 +263,7 @@ func TestCoreScheduler_EvalGC_Partial(t *testing.T) { t.Fatalf("bad: %v", out) } - outA, err := state.AllocByID(alloc3.ID) + outA, err := state.AllocByID(ws, alloc3.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -268,7 +272,7 @@ func TestCoreScheduler_EvalGC_Partial(t *testing.T) { } // Should be gone - outB, err := state.AllocByID(alloc.ID) + outB, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -276,7 +280,7 @@ func TestCoreScheduler_EvalGC_Partial(t *testing.T) { t.Fatalf("bad: %v", outB) } - outC, err := state.AllocByID(alloc2.ID) + outC, err := state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -328,7 +332,8 @@ func TestCoreScheduler_EvalGC_Force(t *testing.T) { } // Should be gone - out, err := state.EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -336,7 +341,7 @@ func TestCoreScheduler_EvalGC_Force(t *testing.T) { t.Fatalf("bad: %v", out) } - outA, err := state.AllocByID(alloc.ID) + outA, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -381,7 +386,8 @@ func TestCoreScheduler_NodeGC(t *testing.T) { } // Should be gone - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -434,7 +440,8 @@ func TestCoreScheduler_NodeGC_TerminalAllocs(t *testing.T) { } // Should be gone - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -489,7 +496,8 @@ func TestCoreScheduler_NodeGC_RunningAllocs(t *testing.T) { } // Should still be here - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -530,7 +538,8 @@ func TestCoreScheduler_NodeGC_Force(t *testing.T) { } // Should be gone - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -589,7 +598,8 @@ func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { } // Should still exist - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -597,7 +607,7 @@ func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { t.Fatalf("bad: %v", out) } - outE, err := 
state.EvalByID(eval.ID) + outE, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -605,7 +615,7 @@ func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { t.Fatalf("bad: %v", outE) } - outE2, err := state.EvalByID(eval2.ID) + outE2, err := state.EvalByID(ws, eval2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -635,7 +645,7 @@ func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { } // Should not still exist - out, err = state.JobByID(job.ID) + out, err = state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -643,7 +653,7 @@ func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { t.Fatalf("bad: %v", out) } - outE, err = state.EvalByID(eval.ID) + outE, err = state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -651,7 +661,7 @@ func TestCoreScheduler_JobGC_OutstandingEvals(t *testing.T) { t.Fatalf("bad: %v", outE) } - outE2, err = state.EvalByID(eval2.ID) + outE2, err = state.EvalByID(ws, eval2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -724,7 +734,8 @@ func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { } // Should still exist - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -732,7 +743,7 @@ func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { t.Fatalf("bad: %v", out) } - outA, err := state.AllocByID(alloc.ID) + outA, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -740,7 +751,7 @@ func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { t.Fatalf("bad: %v", outA) } - outA2, err := state.AllocByID(alloc2.ID) + outA2, err := state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -770,7 +781,7 @@ func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { } // Should not still exist - out, err = state.JobByID(job.ID) + out, err = state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -778,7 +789,7 @@ func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { t.Fatalf("bad: %v", out) } - outA, err = state.AllocByID(alloc.ID) + outA, err = state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -786,7 +797,7 @@ func TestCoreScheduler_JobGC_OutstandingAllocs(t *testing.T) { t.Fatalf("bad: %v", outA) } - outA2, err = state.AllocByID(alloc2.ID) + outA2, err = state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -866,7 +877,8 @@ func TestCoreScheduler_JobGC_OneShot(t *testing.T) { } // Should still exist - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -874,7 +886,7 @@ func TestCoreScheduler_JobGC_OneShot(t *testing.T) { t.Fatalf("bad: %v", out) } - outE, err := state.EvalByID(eval.ID) + outE, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -882,7 +894,7 @@ func TestCoreScheduler_JobGC_OneShot(t *testing.T) { t.Fatalf("bad: %v", outE) } - outE2, err := state.EvalByID(eval2.ID) + outE2, err := state.EvalByID(ws, eval2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -890,14 +902,14 @@ func TestCoreScheduler_JobGC_OneShot(t *testing.T) { t.Fatalf("bad: %v", outE2) } - outA, err := state.AllocByID(alloc.ID) + outA, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } if outA == nil { t.Fatalf("bad: %v", outA) } - outA2, err := state.AllocByID(alloc2.ID) + outA2, err := state.AllocByID(ws, 
alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -948,7 +960,8 @@ func TestCoreScheduler_JobGC_Force(t *testing.T) { } // Shouldn't still exist - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -956,7 +969,7 @@ func TestCoreScheduler_JobGC_Force(t *testing.T) { t.Fatalf("bad: %v", out) } - outE, err := state.EvalByID(eval.ID) + outE, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1008,7 +1021,8 @@ func TestCoreScheduler_JobGC_NonGCable(t *testing.T) { } // Should still exist - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1016,7 +1030,7 @@ func TestCoreScheduler_JobGC_NonGCable(t *testing.T) { t.Fatalf("bad: %v", out) } - outE, err := state.JobByID(job2.ID) + outE, err := state.JobByID(ws, job2.ID) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/eval_endpoint.go b/nomad/eval_endpoint.go index 32ea14faa..6f8f404f0 100644 --- a/nomad/eval_endpoint.go +++ b/nomad/eval_endpoint.go @@ -6,8 +6,8 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/nomad/watch" "github.com/hashicorp/nomad/scheduler" ) @@ -33,14 +33,9 @@ func (e *Eval) GetEval(args *structs.EvalSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Eval: args.EvalID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Look for the job - snap, err := e.srv.fsm.State().Snapshot() - if err != nil { - return err - } - out, err := snap.EvalByID(args.EvalID) + out, err := state.EvalByID(ws, args.EvalID) if err != nil { return err } @@ -51,7 +46,7 @@ func (e *Eval) GetEval(args *structs.EvalSpecificRequest, reply.Index = out.ModifyIndex } else { // Use the last index that affected the nodes table - index, err := snap.Index("evals") + index, err := state.Index("evals") if err != nil { return err } @@ -190,7 +185,9 @@ func (e *Eval) Create(args *structs.EvalUpdateRequest, if err != nil { return err } - out, err := snap.EvalByID(eval.ID) + + ws := memdb.NewWatchSet() + out, err := snap.EvalByID(ws, eval.ID) if err != nil { return err } @@ -233,7 +230,9 @@ func (e *Eval) Reblock(args *structs.EvalUpdateRequest, reply *structs.GenericRe if err != nil { return err } - out, err := snap.EvalByID(eval.ID) + + ws := memdb.NewWatchSet() + out, err := snap.EvalByID(ws, eval.ID) if err != nil { return err } @@ -280,18 +279,14 @@ func (e *Eval) List(args *structs.EvalListRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Table: "evals"}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Scan all the evaluations - snap, err := e.srv.fsm.State().Snapshot() - if err != nil { - return err - } + var err error var iter memdb.ResultIterator if prefix := args.QueryOptions.Prefix; prefix != "" { - iter, err = snap.EvalsByIDPrefix(prefix) + iter, err = state.EvalsByIDPrefix(ws, prefix) } else { - iter, err = snap.Evals() + iter, err = state.Evals(ws) } if err != nil { return err @@ -309,7 +304,7 @@ func (e *Eval) List(args *structs.EvalListRequest, reply.Evaluations = evals // Use the last index that affected the jobs table - index, err 
:= snap.Index("evals") + index, err := state.Index("evals") if err != nil { return err } @@ -334,14 +329,9 @@ func (e *Eval) Allocations(args *structs.EvalSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{AllocEval: args.EvalID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Capture the allocations - snap, err := e.srv.fsm.State().Snapshot() - if err != nil { - return err - } - allocs, err := snap.AllocsByEval(args.EvalID) + allocs, err := state.AllocsByEval(ws, args.EvalID) if err != nil { return err } @@ -355,7 +345,7 @@ func (e *Eval) Allocations(args *structs.EvalSpecificRequest, } // Use the last index that affected the allocs table - index, err := snap.Index("allocs") + index, err := state.Index("allocs") if err != nil { return err } diff --git a/nomad/eval_endpoint_test.go b/nomad/eval_endpoint_test.go index cf5473ca7..63000a286 100644 --- a/nomad/eval_endpoint_test.go +++ b/nomad/eval_endpoint_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -85,7 +86,7 @@ func TestEvalEndpoint_GetEval_Blocking(t *testing.T) { EvalID: eval2.ID, QueryOptions: structs.QueryOptions{ Region: "global", - MinQueryIndex: 50, + MinQueryIndex: 150, }, } var resp structs.SingleEvalResponse @@ -314,7 +315,8 @@ func TestEvalEndpoint_Update(t *testing.T) { } // Ensure updated - outE, err := s1.fsm.State().EvalByID(eval2.ID) + ws := memdb.NewWatchSet() + outE, err := s1.fsm.State().EvalByID(ws, eval2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -361,7 +363,8 @@ func TestEvalEndpoint_Create(t *testing.T) { } // Ensure created - outE, err := s1.fsm.State().EvalByID(eval1.ID) + ws := memdb.NewWatchSet() + outE, err := s1.fsm.State().EvalByID(ws, eval1.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -397,7 +400,8 @@ func TestEvalEndpoint_Reap(t *testing.T) { } // Ensure deleted - outE, err := s1.fsm.State().EvalByID(eval1.ID) + ws := memdb.NewWatchSet() + outE, err := s1.fsm.State().EvalByID(ws, eval1.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -588,7 +592,7 @@ func TestEvalEndpoint_Allocations_Blocking(t *testing.T) { EvalID: alloc2.EvalID, QueryOptions: structs.QueryOptions{ Region: "global", - MinQueryIndex: 50, + MinQueryIndex: 150, }, } var resp structs.EvalAllocationsResponse diff --git a/nomad/fsm.go b/nomad/fsm.go index a5c6d86e6..25221ca4f 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -5,9 +5,11 @@ import ( "io" "log" "reflect" + "sync" "time" "github.com/armon/go-metrics" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/scheduler" @@ -50,6 +52,12 @@ type nomadFSM struct { logger *log.Logger state *state.StateStore timetable *TimeTable + + // stateLock is only used to protect outside callers to State() from + // racing with Restore(), which is called by Raft (it puts in a totally + // new state store). Everything internal here is synchronized by the + // Raft side, so doesn't need to lock this. 
+ stateLock sync.RWMutex } // nomadSnapshot is used to provide a snapshot of the current @@ -92,6 +100,8 @@ func (n *nomadFSM) Close() error { // State is used to return a handle to the current state func (n *nomadFSM) State() *state.StateStore { + n.stateLock.RLock() + defer n.stateLock.RUnlock() return n.state } @@ -203,7 +213,8 @@ func (n *nomadFSM) applyStatusUpdate(buf []byte, index uint64) interface{} { // Unblock evals for the nodes computed node class if it is in a ready // state. if req.Status == structs.NodeStatusReady { - node, err := n.state.NodeByID(req.NodeID) + ws := memdb.NewWatchSet() + node, err := n.state.NodeByID(ws, req.NodeID) if err != nil { n.logger.Printf("[ERR] nomad.fsm: looking up node %q failed: %v", req.NodeID, err) return err @@ -256,13 +267,16 @@ func (n *nomadFSM) applyUpsertJob(buf []byte, index uint64) interface{} { return err } + // Create a watch set + ws := memdb.NewWatchSet() + // If it is periodic, record the time it was inserted. This is necessary for // recovering during leader election. It is possible that from the time it // is added to when it was suppose to launch, leader election occurs and the // job was not launched. In this case, we use the insertion time to // determine if a launch was missed. if req.Job.IsPeriodic() { - prevLaunch, err := n.state.PeriodicLaunchByID(req.Job.ID) + prevLaunch, err := n.state.PeriodicLaunchByID(ws, req.Job.ID) if err != nil { n.logger.Printf("[ERR] nomad.fsm: PeriodicLaunchByID failed: %v", err) return err @@ -282,7 +296,7 @@ func (n *nomadFSM) applyUpsertJob(buf []byte, index uint64) interface{} { // Check if the parent job is periodic and mark the launch time. parentID := req.Job.ParentID if parentID != "" { - parent, err := n.state.JobByID(parentID) + parent, err := n.state.JobByID(ws, parentID) if err != nil { n.logger.Printf("[ERR] nomad.fsm: JobByID(%v) lookup for parent failed: %v", parentID, err) return err @@ -435,9 +449,12 @@ func (n *nomadFSM) applyAllocClientUpdate(buf []byte, index uint64) interface{} return nil } + // Create a watch set + ws := memdb.NewWatchSet() + // Updating the allocs with the job id and task group name for _, alloc := range req.Alloc { - if existing, _ := n.state.AllocByID(alloc.ID); existing != nil { + if existing, _ := n.state.AllocByID(ws, alloc.ID); existing != nil { alloc.JobID = existing.JobID alloc.TaskGroup = existing.TaskGroup } @@ -455,7 +472,7 @@ func (n *nomadFSM) applyAllocClientUpdate(buf []byte, index uint64) interface{} if alloc.ClientStatus == structs.AllocClientStatusComplete || alloc.ClientStatus == structs.AllocClientStatusFailed { nodeID := alloc.NodeID - node, err := n.state.NodeByID(nodeID) + node, err := n.state.NodeByID(ws, nodeID) if err != nil || node == nil { n.logger.Printf("[ERR] nomad.fsm: looking up node %q failed: %v", nodeID, err) return err @@ -531,7 +548,6 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error { if err != nil { return err } - n.state = newState // Start the state restore restore, err := newState.Restore() @@ -660,7 +676,7 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error { // summaries if they were not present previously. When users upgrade to 0.5 // from 0.4.1, the snapshot will contain job summaries so it will be safe to // remove this block. 
- index, err := n.state.Index("job_summary") + index, err := newState.Index("job_summary") if err != nil { return fmt.Errorf("couldn't fetch index of job summary table: %v", err) } @@ -669,15 +685,27 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error { // we will have to create them if index == 0 { // query the latest index - latestIndex, err := n.state.LatestIndex() + latestIndex, err := newState.LatestIndex() if err != nil { return fmt.Errorf("unable to query latest index: %v", index) } - if err := n.state.ReconcileJobSummaries(latestIndex); err != nil { + if err := newState.ReconcileJobSummaries(latestIndex); err != nil { return fmt.Errorf("error reconciling summaries: %v", err) } } + // External code might be calling State(), so we need to synchronize + // here to make sure we swap in the new state store atomically. + n.stateLock.Lock() + stateOld := n.state + n.state = newState + n.stateLock.Unlock() + + // Signal that the old state store has been abandoned. This is required + // because we don't operate on it any more, we just throw it away, so + // blocking queries won't see any changes and need to be woken up. + stateOld.Abandon() + return nil } @@ -685,7 +713,8 @@ func (n *nomadFSM) Restore(old io.ReadCloser) error { // created a Job Summary during the snap shot restore func (n *nomadFSM) reconcileQueuedAllocations(index uint64) error { // Get all the jobs - iter, err := n.state.Jobs() + ws := memdb.NewWatchSet() + iter, err := n.state.Jobs(ws) if err != nil { return err } @@ -729,7 +758,7 @@ func (n *nomadFSM) reconcileQueuedAllocations(index uint64) error { } // Get the job summary from the fsm state store - originalSummary, err := n.state.JobSummaryByID(job.ID) + originalSummary, err := n.state.JobSummaryByID(ws, job.ID) if err != nil { return err } @@ -865,7 +894,8 @@ func (s *nomadSnapshot) persistIndexes(sink raft.SnapshotSink, func (s *nomadSnapshot) persistNodes(sink raft.SnapshotSink, encoder *codec.Encoder) error { // Get all the nodes - nodes, err := s.snap.Nodes() + ws := memdb.NewWatchSet() + nodes, err := s.snap.Nodes(ws) if err != nil { return err } @@ -892,7 +922,8 @@ func (s *nomadSnapshot) persistNodes(sink raft.SnapshotSink, func (s *nomadSnapshot) persistJobs(sink raft.SnapshotSink, encoder *codec.Encoder) error { // Get all the jobs - jobs, err := s.snap.Jobs() + ws := memdb.NewWatchSet() + jobs, err := s.snap.Jobs(ws) if err != nil { return err } @@ -919,7 +950,8 @@ func (s *nomadSnapshot) persistJobs(sink raft.SnapshotSink, func (s *nomadSnapshot) persistEvals(sink raft.SnapshotSink, encoder *codec.Encoder) error { // Get all the evaluations - evals, err := s.snap.Evals() + ws := memdb.NewWatchSet() + evals, err := s.snap.Evals(ws) if err != nil { return err } @@ -946,7 +978,8 @@ func (s *nomadSnapshot) persistEvals(sink raft.SnapshotSink, func (s *nomadSnapshot) persistAllocs(sink raft.SnapshotSink, encoder *codec.Encoder) error { // Get all the allocations - allocs, err := s.snap.Allocs() + ws := memdb.NewWatchSet() + allocs, err := s.snap.Allocs(ws) if err != nil { return err } @@ -973,7 +1006,8 @@ func (s *nomadSnapshot) persistAllocs(sink raft.SnapshotSink, func (s *nomadSnapshot) persistPeriodicLaunches(sink raft.SnapshotSink, encoder *codec.Encoder) error { // Get all the jobs - launches, err := s.snap.PeriodicLaunches() + ws := memdb.NewWatchSet() + launches, err := s.snap.PeriodicLaunches(ws) if err != nil { return err } @@ -1000,7 +1034,8 @@ func (s *nomadSnapshot) persistPeriodicLaunches(sink raft.SnapshotSink, func (s *nomadSnapshot) 
persistJobSummaries(sink raft.SnapshotSink, encoder *codec.Encoder) error { - summaries, err := s.snap.JobSummaries() + ws := memdb.NewWatchSet() + summaries, err := s.snap.JobSummaries(ws) if err != nil { return err } @@ -1024,7 +1059,8 @@ func (s *nomadSnapshot) persistJobSummaries(sink raft.SnapshotSink, func (s *nomadSnapshot) persistVaultAccessors(sink raft.SnapshotSink, encoder *codec.Encoder) error { - accessors, err := s.snap.VaultAccessors() + ws := memdb.NewWatchSet() + accessors, err := s.snap.VaultAccessors(ws) if err != nil { return err } diff --git a/nomad/fsm_test.go b/nomad/fsm_test.go index 1883a6c97..80e48faee 100644 --- a/nomad/fsm_test.go +++ b/nomad/fsm_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" @@ -92,7 +93,8 @@ func TestFSM_UpsertNode(t *testing.T) { } // Verify we are registered - n, err := fsm.State().NodeByID(req.Node.ID) + ws := memdb.NewWatchSet() + n, err := fsm.State().NodeByID(ws, req.Node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -153,7 +155,8 @@ func TestFSM_DeregisterNode(t *testing.T) { } // Verify we are NOT registered - node, err = fsm.State().NodeByID(req.Node.ID) + ws := memdb.NewWatchSet() + node, err = fsm.State().NodeByID(ws, req.Node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -200,7 +203,8 @@ func TestFSM_UpdateNodeStatus(t *testing.T) { } // Verify the status is ready. - node, err = fsm.State().NodeByID(req.Node.ID) + ws := memdb.NewWatchSet() + node, err = fsm.State().NodeByID(ws, req.Node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -252,7 +256,8 @@ func TestFSM_UpdateNodeDrain(t *testing.T) { } // Verify we are NOT registered - node, err = fsm.State().NodeByID(req.Node.ID) + ws := memdb.NewWatchSet() + node, err = fsm.State().NodeByID(ws, req.Node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -279,7 +284,8 @@ func TestFSM_RegisterJob(t *testing.T) { } // Verify we are registered - jobOut, err := fsm.State().JobByID(req.Job.ID) + ws := memdb.NewWatchSet() + jobOut, err := fsm.State().JobByID(ws, req.Job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -296,7 +302,7 @@ func TestFSM_RegisterJob(t *testing.T) { } // Verify the launch time was tracked. - launchOut, err := fsm.State().PeriodicLaunchByID(req.Job.ID) + launchOut, err := fsm.State().PeriodicLaunchByID(ws, req.Job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -339,7 +345,8 @@ func TestFSM_DeregisterJob(t *testing.T) { } // Verify we are NOT registered - jobOut, err := fsm.State().JobByID(req.Job.ID) + ws := memdb.NewWatchSet() + jobOut, err := fsm.State().JobByID(ws, req.Job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -353,7 +360,7 @@ func TestFSM_DeregisterJob(t *testing.T) { } // Verify it was removed from the periodic launch table. 
- launchOut, err := fsm.State().PeriodicLaunchByID(req.Job.ID) + launchOut, err := fsm.State().PeriodicLaunchByID(ws, req.Job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -380,7 +387,8 @@ func TestFSM_UpdateEval(t *testing.T) { } // Verify we are registered - eval, err := fsm.State().EvalByID(req.Evals[0].ID) + ws := memdb.NewWatchSet() + eval, err := fsm.State().EvalByID(ws, req.Evals[0].ID) if err != nil { t.Fatalf("err: %v", err) } @@ -421,7 +429,8 @@ func TestFSM_UpdateEval_Blocked(t *testing.T) { } // Verify we are registered - out, err := fsm.State().EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -474,7 +483,8 @@ func TestFSM_UpdateEval_Untrack(t *testing.T) { } // Verify we are registered - out, err := fsm.State().EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -529,7 +539,8 @@ func TestFSM_UpdateEval_NoUntrack(t *testing.T) { } // Verify we are registered - out, err := fsm.State().EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -584,7 +595,8 @@ func TestFSM_DeleteEval(t *testing.T) { } // Verify we are NOT registered - eval, err = fsm.State().EvalByID(req.Evals[0].ID) + ws := memdb.NewWatchSet() + eval, err = fsm.State().EvalByID(ws, req.Evals[0].ID) if err != nil { t.Fatalf("err: %v", err) } @@ -612,7 +624,8 @@ func TestFSM_UpsertAllocs(t *testing.T) { } // Verify we are registered - out, err := fsm.State().AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -640,7 +653,7 @@ func TestFSM_UpsertAllocs(t *testing.T) { } // Verify we are evicted - out, err = fsm.State().AllocByID(alloc.ID) + out, err = fsm.State().AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -671,7 +684,8 @@ func TestFSM_UpsertAllocs_SharedJob(t *testing.T) { } // Verify we are registered - out, err := fsm.State().AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -708,7 +722,7 @@ func TestFSM_UpsertAllocs_SharedJob(t *testing.T) { } // Verify we are evicted - out, err = fsm.State().AllocByID(alloc.ID) + out, err = fsm.State().AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -743,7 +757,8 @@ func TestFSM_UpsertAllocs_StrippedResources(t *testing.T) { } // Verify we are registered - out, err := fsm.State().AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -808,7 +823,8 @@ func TestFSM_UpdateAllocFromClient_Unblock(t *testing.T) { } // Verify we are updated - out, err := fsm.State().AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -818,7 +834,7 @@ func TestFSM_UpdateAllocFromClient_Unblock(t *testing.T) { t.Fatalf("bad: %#v %#v", clientAlloc, out) } - out, err = fsm.State().AllocByID(alloc2.ID) + out, err = fsm.State().AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -868,7 +884,8 @@ func TestFSM_UpdateAllocFromClient(t *testing.T) { } // Verify we are registered - out, err := fsm.State().AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := fsm.State().AllocByID(ws, alloc.ID) if err != nil { 
t.Fatalf("err: %v", err) } @@ -899,7 +916,8 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) { } // Verify we are registered - out1, err := fsm.State().VaultAccessor(va.Accessor) + ws := memdb.NewWatchSet() + out1, err := fsm.State().VaultAccessor(ws, va.Accessor) if err != nil { t.Fatalf("err: %v", err) } @@ -909,7 +927,7 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) { if out1.CreateIndex != 1 { t.Fatalf("bad index: %d", out1.CreateIndex) } - out2, err := fsm.State().VaultAccessor(va2.Accessor) + out2, err := fsm.State().VaultAccessor(ws, va2.Accessor) if err != nil { t.Fatalf("err: %v", err) } @@ -953,7 +971,8 @@ func TestFSM_DeregisterVaultAccessor(t *testing.T) { t.Fatalf("resp: %v", resp) } - out1, err := fsm.State().VaultAccessor(va.Accessor) + ws := memdb.NewWatchSet() + out1, err := fsm.State().VaultAccessor(ws, va.Accessor) if err != nil { t.Fatalf("err: %v", err) } @@ -985,11 +1004,25 @@ func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM { // Try to restore on a new FSM fsm2 := testFSM(t) + snap, err = fsm2.Snapshot() + if err != nil { + t.Fatalf("err: %v", err) + } + defer snap.Release() + + abandonCh := fsm2.State().AbandonCh() // Do a restore if err := fsm2.Restore(sink); err != nil { t.Fatalf("err: %v", err) } + + select { + case <-abandonCh: + default: + t.Fatalf("bad") + } + return fsm2 } @@ -1005,8 +1038,9 @@ func TestFSM_SnapshotRestore_Nodes(t *testing.T) { // Verify the contents fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.NodeByID(node1.ID) - out2, _ := state2.NodeByID(node2.ID) + ws := memdb.NewWatchSet() + out1, _ := state2.NodeByID(ws, node1.ID) + out2, _ := state2.NodeByID(ws, node2.ID) if !reflect.DeepEqual(node1, out1) { t.Fatalf("bad: \n%#v\n%#v", out1, node1) } @@ -1025,10 +1059,11 @@ func TestFSM_SnapshotRestore_Jobs(t *testing.T) { state.UpsertJob(1001, job2) // Verify the contents + ws := memdb.NewWatchSet() fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.JobByID(job1.ID) - out2, _ := state2.JobByID(job2.ID) + out1, _ := state2.JobByID(ws, job1.ID) + out2, _ := state2.JobByID(ws, job2.ID) if !reflect.DeepEqual(job1, out1) { t.Fatalf("bad: \n%#v\n%#v", out1, job1) } @@ -1049,8 +1084,9 @@ func TestFSM_SnapshotRestore_Evals(t *testing.T) { // Verify the contents fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.EvalByID(eval1.ID) - out2, _ := state2.EvalByID(eval2.ID) + ws := memdb.NewWatchSet() + out1, _ := state2.EvalByID(ws, eval1.ID) + out2, _ := state2.EvalByID(ws, eval2.ID) if !reflect.DeepEqual(eval1, out1) { t.Fatalf("bad: \n%#v\n%#v", out1, eval1) } @@ -1073,8 +1109,9 @@ func TestFSM_SnapshotRestore_Allocs(t *testing.T) { // Verify the contents fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.AllocByID(alloc1.ID) - out2, _ := state2.AllocByID(alloc2.ID) + ws := memdb.NewWatchSet() + out1, _ := state2.AllocByID(ws, alloc1.ID) + out2, _ := state2.AllocByID(ws, alloc2.ID) if !reflect.DeepEqual(alloc1, out1) { t.Fatalf("bad: \n%#v\n%#v", out1, alloc1) } @@ -1099,8 +1136,9 @@ func TestFSM_SnapshotRestore_Allocs_NoSharedResources(t *testing.T) { // Verify the contents fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.AllocByID(alloc1.ID) - out2, _ := state2.AllocByID(alloc2.ID) + ws := memdb.NewWatchSet() + out1, _ := state2.AllocByID(ws, alloc1.ID) + out2, _ := state2.AllocByID(ws, alloc2.ID) alloc1.SharedResources = &structs.Resources{DiskMB: 150} alloc2.SharedResources = 
&structs.Resources{DiskMB: 150} @@ -1167,8 +1205,9 @@ func TestFSM_SnapshotRestore_PeriodicLaunches(t *testing.T) { // Verify the contents fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.PeriodicLaunchByID(launch1.ID) - out2, _ := state2.PeriodicLaunchByID(launch2.ID) + ws := memdb.NewWatchSet() + out1, _ := state2.PeriodicLaunchByID(ws, launch1.ID) + out2, _ := state2.PeriodicLaunchByID(ws, launch2.ID) if !reflect.DeepEqual(launch1, out1) { t.Fatalf("bad: \n%#v\n%#v", out1, job1) } @@ -1184,17 +1223,18 @@ func TestFSM_SnapshotRestore_JobSummary(t *testing.T) { job1 := mock.Job() state.UpsertJob(1000, job1) - js1, _ := state.JobSummaryByID(job1.ID) + ws := memdb.NewWatchSet() + js1, _ := state.JobSummaryByID(ws, job1.ID) job2 := mock.Job() state.UpsertJob(1001, job2) - js2, _ := state.JobSummaryByID(job2.ID) + js2, _ := state.JobSummaryByID(ws, job2.ID) // Verify the contents fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.JobSummaryByID(job1.ID) - out2, _ := state2.JobSummaryByID(job2.ID) + out1, _ := state2.JobSummaryByID(ws, job1.ID) + out2, _ := state2.JobSummaryByID(ws, job2.ID) if !reflect.DeepEqual(js1, out1) { t.Fatalf("bad: \n%#v\n%#v", js1, out1) } @@ -1214,8 +1254,9 @@ func TestFSM_SnapshotRestore_VaultAccessors(t *testing.T) { // Verify the contents fsm2 := testSnapshotRestore(t, fsm) state2 := fsm2.State() - out1, _ := state2.VaultAccessor(a1.Accessor) - out2, _ := state2.VaultAccessor(a2.Accessor) + ws := memdb.NewWatchSet() + out1, _ := state2.VaultAccessor(ws, a1.Accessor) + out2, _ := state2.VaultAccessor(ws, a2.Accessor) if !reflect.DeepEqual(a1, out1) { t.Fatalf("bad: \n%#v\n%#v", out1, a1) } @@ -1246,7 +1287,8 @@ func TestFSM_SnapshotRestore_AddMissingSummary(t *testing.T) { state2 := fsm2.State() latestIndex, _ := state.LatestIndex() - out, _ := state2.JobSummaryByID(alloc.Job.ID) + ws := memdb.NewWatchSet() + out, _ := state2.JobSummaryByID(ws, alloc.Job.ID) expected := structs.JobSummary{ JobID: alloc.Job.ID, Summary: map[string]structs.TaskGroupSummary{ @@ -1297,7 +1339,8 @@ func TestFSM_ReconcileSummaries(t *testing.T) { t.Fatalf("resp: %v", resp) } - out1, _ := state.JobSummaryByID(job1.ID) + ws := memdb.NewWatchSet() + out1, _ := state.JobSummaryByID(ws, job1.ID) expected := structs.JobSummary{ JobID: job1.ID, Summary: map[string]structs.TaskGroupSummary{ @@ -1315,7 +1358,7 @@ func TestFSM_ReconcileSummaries(t *testing.T) { // This exercises the code path which adds the allocations made by the // planner and the number of unplaced allocations in the reconcile summaries // codepath - out2, _ := state.JobSummaryByID(alloc.Job.ID) + out2, _ := state.JobSummaryByID(ws, alloc.Job.ID) expected = structs.JobSummary{ JobID: alloc.Job.ID, Summary: map[string]structs.TaskGroupSummary{ diff --git a/nomad/heartbeat.go b/nomad/heartbeat.go index 9b2867eca..89bc86010 100644 --- a/nomad/heartbeat.go +++ b/nomad/heartbeat.go @@ -5,6 +5,7 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/consul/lib" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" ) @@ -19,7 +20,8 @@ func (s *Server) initializeHeartbeatTimers() error { } // Get an iterator over nodes - iter, err := snap.Nodes() + ws := memdb.NewWatchSet() + iter, err := snap.Nodes(ws) if err != nil { return err } diff --git a/nomad/heartbeat_test.go b/nomad/heartbeat_test.go index 7ab5495a8..dc5b29c4c 100644 --- a/nomad/heartbeat_test.go +++ b/nomad/heartbeat_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + memdb 
"github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -132,7 +133,8 @@ func TestInvalidateHeartbeat(t *testing.T) { s1.invalidateHeartbeat(node.ID) // Check it is updated - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/job_endpoint.go b/nomad/job_endpoint.go index 992a06538..ce25d1ab2 100644 --- a/nomad/job_endpoint.go +++ b/nomad/job_endpoint.go @@ -13,8 +13,8 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/client/driver" "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/nomad/watch" "github.com/hashicorp/nomad/scheduler" ) @@ -72,7 +72,8 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis if err != nil { return err } - job, err := snap.JobByID(args.Job.ID) + ws := memdb.NewWatchSet() + job, err := snap.JobByID(ws, args.Job.ID) if err != nil { return err } @@ -257,15 +258,9 @@ func (j *Job) Summary(args *structs.JobSummaryRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{JobSummary: args.JobID}), - run: func() error { - snap, err := j.srv.fsm.State().Snapshot() - if err != nil { - return err - } - + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Look for job summary - out, err := snap.JobSummaryByID(args.JobID) + out, err := state.JobSummaryByID(ws, args.JobID) if err != nil { return err } @@ -276,7 +271,7 @@ func (j *Job) Summary(args *structs.JobSummaryRequest, reply.Index = out.ModifyIndex } else { // Use the last index that affected the job_summary table - index, err := snap.Index("job_summary") + index, err := state.Index("job_summary") if err != nil { return err } @@ -307,7 +302,8 @@ func (j *Job) Evaluate(args *structs.JobEvaluateRequest, reply *structs.JobRegis if err != nil { return err } - job, err := snap.JobByID(args.JobID) + ws := memdb.NewWatchSet() + job, err := snap.JobByID(ws, args.JobID) if err != nil { return err } @@ -368,7 +364,8 @@ func (j *Job) Deregister(args *structs.JobDeregisterRequest, reply *structs.JobD if err != nil { return err } - job, err := snap.JobByID(args.JobID) + ws := memdb.NewWatchSet() + job, err := snap.JobByID(ws, args.JobID) if err != nil { return err } @@ -432,15 +429,9 @@ func (j *Job) GetJob(args *structs.JobSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Job: args.JobID}), - run: func() error { - + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Look for the job - snap, err := j.srv.fsm.State().Snapshot() - if err != nil { - return err - } - out, err := snap.JobByID(args.JobID) + out, err := state.JobByID(ws, args.JobID) if err != nil { return err } @@ -451,7 +442,7 @@ func (j *Job) GetJob(args *structs.JobSpecificRequest, reply.Index = out.ModifyIndex } else { // Use the last index that affected the nodes table - index, err := snap.Index("jobs") + index, err := state.Index("jobs") if err != nil { return err } @@ -477,18 +468,14 @@ func (j *Job) List(args *structs.JobListRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Table: "jobs"}), - run: func() error { + run: func(ws 
memdb.WatchSet, state *state.StateStore) error { // Capture all the jobs - snap, err := j.srv.fsm.State().Snapshot() - if err != nil { - return err - } + var err error var iter memdb.ResultIterator if prefix := args.QueryOptions.Prefix; prefix != "" { - iter, err = snap.JobsByIDPrefix(prefix) + iter, err = state.JobsByIDPrefix(ws, prefix) } else { - iter, err = snap.Jobs() + iter, err = state.Jobs(ws) } if err != nil { return err @@ -501,7 +488,7 @@ func (j *Job) List(args *structs.JobListRequest, break } job := raw.(*structs.Job) - summary, err := snap.JobSummaryByID(job.ID) + summary, err := state.JobSummaryByID(ws, job.ID) if err != nil { return fmt.Errorf("unable to look up summary for job: %v", job.ID) } @@ -510,7 +497,7 @@ func (j *Job) List(args *structs.JobListRequest, reply.Jobs = jobs // Use the last index that affected the jobs table - index, err := snap.Index("jobs") + index, err := state.Index("jobs") if err != nil { return err } @@ -535,14 +522,9 @@ func (j *Job) Allocations(args *structs.JobSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{AllocJob: args.JobID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Capture the allocations - snap, err := j.srv.fsm.State().Snapshot() - if err != nil { - return err - } - allocs, err := snap.AllocsByJob(args.JobID, args.AllAllocs) + allocs, err := state.AllocsByJob(ws, args.JobID, args.AllAllocs) if err != nil { return err } @@ -556,7 +538,7 @@ func (j *Job) Allocations(args *structs.JobSpecificRequest, } // Use the last index that affected the allocs table - index, err := snap.Index("allocs") + index, err := state.Index("allocs") if err != nil { return err } @@ -582,21 +564,16 @@ func (j *Job) Evaluations(args *structs.JobSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{EvalJob: args.JobID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Capture the evals - snap, err := j.srv.fsm.State().Snapshot() - if err != nil { - return err - } - - reply.Evaluations, err = snap.EvalsByJob(args.JobID) + var err error + reply.Evaluations, err = state.EvalsByJob(ws, args.JobID) if err != nil { return err } // Use the last index that affected the evals table - index, err := snap.Index("evals") + index, err := state.Index("evals") if err != nil { return err } @@ -641,7 +618,8 @@ func (j *Job) Plan(args *structs.JobPlanRequest, reply *structs.JobPlanResponse) } // Get the original job - oldJob, err := snap.JobByID(args.Job.ID) + ws := memdb.NewWatchSet() + oldJob, err := snap.JobByID(ws, args.Job.ID) if err != nil { return err } @@ -797,7 +775,8 @@ func (j *Job) Dispatch(args *structs.JobDispatchRequest, reply *structs.JobDispa if err != nil { return err } - parameterizedJob, err := snap.JobByID(args.JobID) + ws := memdb.NewWatchSet() + parameterizedJob, err := snap.JobByID(ws, args.JobID) if err != nil { return err } diff --git a/nomad/job_endpoint_test.go b/nomad/job_endpoint_test.go index e2f17166b..4a9af00e2 100644 --- a/nomad/job_endpoint_test.go +++ b/nomad/job_endpoint_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -39,7 +40,8 @@ func TestJobEndpoint_Register(t *testing.T) { // Check for the node in the FSM state := 
s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -56,7 +58,7 @@ func TestJobEndpoint_Register(t *testing.T) { } // Lookup the evaluation - eval, err := state.EvalByID(resp.EvalID) + eval, err := state.EvalByID(ws, resp.EvalID) if err != nil { t.Fatalf("err: %v", err) } @@ -185,7 +187,8 @@ func TestJobEndpoint_Register_Existing(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -200,7 +203,7 @@ func TestJobEndpoint_Register_Existing(t *testing.T) { } // Lookup the evaluation - eval, err := state.EvalByID(resp.EvalID) + eval, err := state.EvalByID(ws, resp.EvalID) if err != nil { t.Fatalf("err: %v", err) } @@ -257,7 +260,8 @@ func TestJobEndpoint_Register_Periodic(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -306,7 +310,8 @@ func TestJobEndpoint_Register_ParameterizedJob(t *testing.T) { // Check for the job in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -365,7 +370,8 @@ func TestJobEndpoint_Register_EnforceIndex(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -421,7 +427,7 @@ func TestJobEndpoint_Register_EnforceIndex(t *testing.T) { t.Fatalf("bad index: %d", resp.Index) } - out, err = state.JobByID(job.ID) + out, err = state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -498,7 +504,8 @@ func TestJobEndpoint_Register_Vault_AllowUnauthenticated(t *testing.T) { // Check for the job in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -621,7 +628,8 @@ func TestJobEndpoint_Register_Vault_Policies(t *testing.T) { // Check for the job in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -664,7 +672,7 @@ func TestJobEndpoint_Register_Vault_Policies(t *testing.T) { } // Check for the job in the FSM - out, err = state.JobByID(job2.ID) + out, err = state.JobByID(ws, job2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -719,7 +727,8 @@ func TestJobEndpoint_Evaluate(t *testing.T) { // Lookup the evaluation state := s1.fsm.State() - eval, err := state.EvalByID(resp.EvalID) + ws := memdb.NewWatchSet() + eval, err := state.EvalByID(ws, resp.EvalID) if err != nil { t.Fatalf("err: %v", err) } @@ -859,8 +868,9 @@ func TestJobEndpoint_Deregister(t *testing.T) { } // Check for the node in the FSM + ws := memdb.NewWatchSet() state := s1.fsm.State() - out, err := state.JobByID(job.ID) + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -869,7 +879,7 @@ func TestJobEndpoint_Deregister(t *testing.T) { } // Lookup the evaluation - eval, err := state.EvalByID(resp2.EvalID) + eval, err := state.EvalByID(ws, resp2.EvalID) if err != nil { 
t.Fatalf("err: %v", err) } @@ -924,7 +934,8 @@ func TestJobEndpoint_Deregister_NonExistent(t *testing.T) { // Lookup the evaluation state := s1.fsm.State() - eval, err := state.EvalByID(resp2.EvalID) + ws := memdb.NewWatchSet() + eval, err := state.EvalByID(ws, resp2.EvalID) if err != nil { t.Fatalf("err: %v", err) } @@ -991,7 +1002,8 @@ func TestJobEndpoint_Deregister_Periodic(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1042,7 +1054,8 @@ func TestJobEndpoint_Deregister_ParameterizedJob(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1294,7 +1307,7 @@ func TestJobEndpoint_GetJob_Blocking(t *testing.T) { JobID: job2.ID, QueryOptions: structs.QueryOptions{ Region: "global", - MinQueryIndex: 50, + MinQueryIndex: 150, }, } start := time.Now() @@ -1428,7 +1441,7 @@ func TestJobEndpoint_ListJobs_Blocking(t *testing.T) { t.Fatalf("Bad index: %d %d", resp.Index, 100) } if len(resp.Jobs) != 1 || resp.Jobs[0].ID != job.ID { - t.Fatalf("bad: %#v", resp.Jobs) + t.Fatalf("bad: %#v", resp) } // Job deletion triggers watches @@ -1452,7 +1465,7 @@ func TestJobEndpoint_ListJobs_Blocking(t *testing.T) { t.Fatalf("Bad index: %d %d", resp2.Index, 200) } if len(resp2.Jobs) != 0 { - t.Fatalf("bad: %#v", resp2.Jobs) + t.Fatalf("bad: %#v", resp2) } } @@ -1528,7 +1541,7 @@ func TestJobEndpoint_Allocations_Blocking(t *testing.T) { JobID: "job1", QueryOptions: structs.QueryOptions{ Region: "global", - MinQueryIndex: 50, + MinQueryIndex: 150, }, } var resp structs.JobAllocationsResponse @@ -1616,7 +1629,7 @@ func TestJobEndpoint_Evaluations_Blocking(t *testing.T) { JobID: "job1", QueryOptions: structs.QueryOptions{ Region: "global", - MinQueryIndex: 50, + MinQueryIndex: 150, }, } var resp structs.JobEvaluationsResponse @@ -1782,7 +1795,8 @@ func TestJobEndpoint_ImplicitConstraints_Vault(t *testing.T) { // Check for the job in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1837,7 +1851,8 @@ func TestJobEndpoint_ImplicitConstraints_Signals(t *testing.T) { // Check for the job in the FSM state := s1.fsm.State() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2078,7 +2093,8 @@ func TestJobEndpoint_Dispatch(t *testing.T) { } state := s1.fsm.State() - out, err := state.JobByID(dispatchResp.DispatchedJobID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, dispatchResp.DispatchedJobID) if err != nil { t.Fatalf("err: %v", err) } @@ -2093,7 +2109,7 @@ func TestJobEndpoint_Dispatch(t *testing.T) { } // Lookup the evaluation - eval, err := state.EvalByID(dispatchResp.EvalID) + eval, err := state.EvalByID(ws, dispatchResp.EvalID) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/leader.go b/nomad/leader.go index 3307608e6..dc9cd4231 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -7,6 +7,7 @@ import ( "time" "github.com/armon/go-metrics" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/raft" "github.com/hashicorp/serf/serf" @@ -191,7 +192,8 @@ func (s *Server) 
establishLeadership(stopCh chan struct{}) error { // a leadership transition takes place. func (s *Server) restoreEvals() error { // Get an iterator over every evaluation - iter, err := s.fsm.State().Evals() + ws := memdb.NewWatchSet() + iter, err := s.fsm.State().Evals(ws) if err != nil { return fmt.Errorf("failed to get evaluations: %v", err) } @@ -216,8 +218,9 @@ func (s *Server) restoreEvals() error { // revoked. func (s *Server) restoreRevokingAccessors() error { // An accessor should be revoked if its allocation or node is terminal + ws := memdb.NewWatchSet() state := s.fsm.State() - iter, err := state.VaultAccessors() + iter, err := state.VaultAccessors(ws) if err != nil { return fmt.Errorf("failed to get vault accessors: %v", err) } @@ -232,7 +235,7 @@ func (s *Server) restoreRevokingAccessors() error { va := raw.(*structs.VaultAccessor) // Check the allocation - alloc, err := state.AllocByID(va.AllocID) + alloc, err := state.AllocByID(ws, va.AllocID) if err != nil { return fmt.Errorf("failed to lookup allocation: %v", va.AllocID, err) } @@ -243,7 +246,7 @@ func (s *Server) restoreRevokingAccessors() error { } // Check the node - node, err := state.NodeByID(va.NodeID) + node, err := state.NodeByID(ws, va.NodeID) if err != nil { return fmt.Errorf("failed to lookup node %q: %v", va.NodeID, err) } @@ -269,7 +272,8 @@ func (s *Server) restoreRevokingAccessors() error { // dispatcher is maintained only by the leader, so it must be restored anytime a // leadership transition takes place. func (s *Server) restorePeriodicDispatcher() error { - iter, err := s.fsm.State().JobsByPeriodic(true) + ws := memdb.NewWatchSet() + iter, err := s.fsm.State().JobsByPeriodic(ws, true) if err != nil { return fmt.Errorf("failed to get periodic jobs: %v", err) } @@ -282,7 +286,7 @@ func (s *Server) restorePeriodicDispatcher() error { // If the periodic job has never been launched before, launch will hold // the time the periodic job was added. Otherwise it has the last launch // time of the periodic job. - launch, err := s.fsm.State().PeriodicLaunchByID(job.ID) + launch, err := s.fsm.State().PeriodicLaunchByID(ws, job.ID) if err != nil || launch == nil { return fmt.Errorf("failed to get periodic launch time: %v", err) } diff --git a/nomad/leader_test.go b/nomad/leader_test.go index 71b4e7878..987e716bc 100644 --- a/nomad/leader_test.go +++ b/nomad/leader_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" @@ -406,7 +407,8 @@ func TestLeader_PeriodicDispatcher_Restore_NoEvals(t *testing.T) { } // Check that an eval was made. - last, err := s1.fsm.State().PeriodicLaunchByID(job.ID) + ws := memdb.NewWatchSet() + last, err := s1.fsm.State().PeriodicLaunchByID(ws, job.ID) if err != nil || last == nil { t.Fatalf("failed to get periodic launch time: %v", err) } @@ -457,7 +459,8 @@ func TestLeader_PeriodicDispatcher_Restore_Evals(t *testing.T) { } // Check that an eval was made. 
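
Aside (illustrative, not part of the patch): the restore helpers above (restoreEvals, restoreRevokingAccessors, restorePeriodicDispatcher) never block, so the memdb.WatchSet they now build exists only to satisfy the new lookup signatures and is discarded on return. A minimal sketch of that one-shot read pattern against a toy go-memdb table; the Eval type, evalSchema helper, and table layout are invented for the example and are not Nomad's real schema.

package main

import (
	"fmt"

	memdb "github.com/hashicorp/go-memdb"
)

// Eval is a stand-in for the objects stored in the table.
type Eval struct {
	ID string
}

// evalSchema defines a single "evals" table indexed by ID.
func evalSchema() *memdb.DBSchema {
	return &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"evals": {
				Name: "evals",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
				},
			},
		},
	}
}

// listEvals is a one-shot read: the WatchSet is registered on the iterator but
// never waited on, mirroring how the leader restore paths use it.
func listEvals(db *memdb.MemDB) ([]*Eval, error) {
	ws := memdb.NewWatchSet()
	txn := db.Txn(false)

	iter, err := txn.Get("evals", "id")
	if err != nil {
		return nil, err
	}
	ws.Add(iter.WatchCh())

	var out []*Eval
	for raw := iter.Next(); raw != nil; raw = iter.Next() {
		out = append(out, raw.(*Eval))
	}
	return out, nil
}

func main() {
	db, err := memdb.NewMemDB(evalSchema())
	if err != nil {
		panic(err)
	}

	txn := db.Txn(true)
	if err := txn.Insert("evals", &Eval{ID: "eval-1"}); err != nil {
		panic(err)
	}
	txn.Commit()

	evals, err := listEvals(db)
	fmt.Println(len(evals), err) // 1 <nil>
}
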
- last, err := s1.fsm.State().PeriodicLaunchByID(job.ID) + ws := memdb.NewWatchSet() + last, err := s1.fsm.State().PeriodicLaunchByID(ws, job.ID) if err != nil || last == nil { t.Fatalf("failed to get periodic launch time: %v", err) } @@ -508,7 +511,8 @@ func TestLeader_ReapFailedEval(t *testing.T) { // Wait updated evaluation state := s1.fsm.State() testutil.WaitForResult(func() (bool, error) { - out, err := state.EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { return false, err } @@ -535,7 +539,8 @@ func TestLeader_ReapDuplicateEval(t *testing.T) { // Wait for the evaluation to marked as cancelled state := s1.fsm.State() testutil.WaitForResult(func() (bool, error) { - out, err := state.EvalByID(eval2.ID) + ws := memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval2.ID) if err != nil { return false, err } diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index 0646bfc7d..916bd7921 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -14,7 +14,6 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/nomad/watch" "github.com/hashicorp/raft" vapi "github.com/hashicorp/vault/api" ) @@ -103,7 +102,9 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp if err != nil { return err } - originalNode, err := snap.NodeByID(args.Node.ID) + + ws := memdb.NewWatchSet() + originalNode, err := snap.NodeByID(ws, args.Node.ID) if err != nil { return err } @@ -203,7 +204,8 @@ func (n *Node) constructNodeServerInfoResponse(snap *state.StateSnapshot, reply // Snapshot is used only to iterate over all nodes to create a node // count to send back to Nomad Clients in their heartbeat so Clients // can estimate the size of the cluster. 
- iter, err := snap.Nodes() + ws := memdb.NewWatchSet() + iter, err := snap.Nodes(ws) if err == nil { for { raw := iter.Next() @@ -248,7 +250,8 @@ func (n *Node) Deregister(args *structs.NodeDeregisterRequest, reply *structs.No } // Determine if there are any Vault accessors on the node - accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID) + ws := memdb.NewWatchSet() + accessors, err := n.srv.State().VaultAccessorsByNode(ws, args.NodeID) if err != nil { n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err) return err @@ -289,7 +292,9 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct if err != nil { return err } - node, err := snap.NodeByID(args.NodeID) + + ws := memdb.NewWatchSet() + node, err := snap.NodeByID(ws, args.NodeID) if err != nil { return err } @@ -330,7 +335,7 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct switch args.Status { case structs.NodeStatusDown: // Determine if there are any Vault accessors on the node - accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID) + accessors, err := n.srv.State().VaultAccessorsByNode(ws, args.NodeID) if err != nil { n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err) return err @@ -389,7 +394,8 @@ func (n *Node) UpdateDrain(args *structs.NodeUpdateDrainRequest, if err != nil { return err } - node, err := snap.NodeByID(args.NodeID) + ws := memdb.NewWatchSet() + node, err := snap.NodeByID(ws, args.NodeID) if err != nil { return err } @@ -443,7 +449,8 @@ func (n *Node) Evaluate(args *structs.NodeEvaluateRequest, reply *structs.NodeUp if err != nil { return err } - node, err := snap.NodeByID(args.NodeID) + ws := memdb.NewWatchSet() + node, err := snap.NodeByID(ws, args.NodeID) if err != nil { return err } @@ -484,19 +491,14 @@ func (n *Node) GetNode(args *structs.NodeSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Node: args.NodeID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Verify the arguments if args.NodeID == "" { return fmt.Errorf("missing node ID") } // Look for the node - snap, err := n.srv.fsm.State().Snapshot() - if err != nil { - return err - } - out, err := snap.NodeByID(args.NodeID) + out, err := state.NodeByID(ws, args.NodeID) if err != nil { return err } @@ -509,7 +511,7 @@ func (n *Node) GetNode(args *structs.NodeSpecificRequest, reply.Index = out.ModifyIndex } else { // Use the last index that affected the nodes table - index, err := snap.Index("nodes") + index, err := state.Index("nodes") if err != nil { return err } @@ -541,14 +543,9 @@ func (n *Node) GetAllocs(args *structs.NodeSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{AllocNode: args.NodeID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Look for the node - snap, err := n.srv.fsm.State().Snapshot() - if err != nil { - return err - } - allocs, err := snap.AllocsByNode(args.NodeID) + allocs, err := state.AllocsByNode(ws, args.NodeID) if err != nil { return err } @@ -563,7 +560,7 @@ func (n *Node) GetAllocs(args *structs.NodeSpecificRequest, reply.Allocs = nil // Use the last index that affected the nodes table - index, err := snap.Index("allocs") + index, err := state.Index("allocs") if err != nil { return 
err } @@ -599,16 +596,9 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{AllocNode: args.NodeID}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Look for the node - snap, err := n.srv.fsm.State().Snapshot() - if err != nil { - return err - } - - // Look for the node - node, err := snap.NodeByID(args.NodeID) + node, err := state.NodeByID(ws, args.NodeID) if err != nil { return err } @@ -628,7 +618,7 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, } var err error - allocs, err = snap.AllocsByNode(args.NodeID) + allocs, err = state.AllocsByNode(ws, args.NodeID) if err != nil { return err } @@ -643,7 +633,7 @@ func (n *Node) GetClientAllocs(args *structs.NodeSpecificRequest, } } else { // Use the last index that affected the nodes table - index, err := snap.Index("allocs") + index, err := state.Index("allocs") if err != nil { return err } @@ -734,7 +724,8 @@ func (n *Node) batchUpdate(future *batchFuture, updates []*structs.Allocation) { } // Determine if there are any Vault accessors for the allocation - accessors, err := n.srv.State().VaultAccessorsByAlloc(alloc.ID) + ws := memdb.NewWatchSet() + accessors, err := n.srv.State().VaultAccessorsByAlloc(ws, alloc.ID) if err != nil { n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for alloc %q failed: %v", alloc.ID, err) mErr.Errors = append(mErr.Errors, err) @@ -766,18 +757,14 @@ func (n *Node) List(args *structs.NodeListRequest, opts := blockingOptions{ queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, - watch: watch.NewItems(watch.Item{Table: "nodes"}), - run: func() error { + run: func(ws memdb.WatchSet, state *state.StateStore) error { // Capture all the nodes - snap, err := n.srv.fsm.State().Snapshot() - if err != nil { - return err - } + var err error var iter memdb.ResultIterator if prefix := args.QueryOptions.Prefix; prefix != "" { - iter, err = snap.NodesByIDPrefix(prefix) + iter, err = state.NodesByIDPrefix(ws, prefix) } else { - iter, err = snap.Nodes() + iter, err = state.Nodes(ws) } if err != nil { return err @@ -795,7 +782,7 @@ func (n *Node) List(args *structs.NodeListRequest, reply.Nodes = nodes // Use the last index that affected the jobs table - index, err := snap.Index("nodes") + index, err := state.Index("nodes") if err != nil { return err } @@ -818,12 +805,13 @@ func (n *Node) createNodeEvals(nodeID string, nodeIndex uint64) ([]string, uint6 } // Find all the allocations for this node - allocs, err := snap.AllocsByNode(nodeID) + ws := memdb.NewWatchSet() + allocs, err := snap.AllocsByNode(ws, nodeID) if err != nil { return nil, 0, fmt.Errorf("failed to find allocs for '%s': %v", nodeID, err) } - sysJobsIter, err := snap.JobsByScheduler("system") + sysJobsIter, err := snap.JobsByScheduler(ws, "system") if err != nil { return nil, 0, fmt.Errorf("failed to find system jobs for '%s': %v", nodeID, err) } @@ -985,7 +973,8 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, setErr(err, false) return nil } - node, err := snap.NodeByID(args.NodeID) + ws := memdb.NewWatchSet() + node, err := snap.NodeByID(ws, args.NodeID) if err != nil { setErr(err, false) return nil @@ -999,7 +988,7 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, return nil } - alloc, err := snap.AllocByID(args.AllocID) + alloc, err := snap.AllocByID(ws, args.AllocID) if err != nil { setErr(err, 
false) return nil diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 749e48b83..df75b8978 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -38,7 +39,8 @@ func TestClientEndpoint_Register(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -85,7 +87,8 @@ func TestClientEndpoint_Register_NoSecret(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -161,7 +164,8 @@ func TestClientEndpoint_Deregister(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -215,7 +219,8 @@ func TestClientEndpoint_Deregister_Vault(t *testing.T) { } // Check for the node in the FSM - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -276,7 +281,8 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -384,7 +390,8 @@ func TestClientEndpoint_Register_GetEvals(t *testing.T) { } evalID := resp.EvalIDs[0] - eval, err := state.EvalByID(evalID) + ws := memdb.NewWatchSet() + eval, err := state.EvalByID(ws, evalID) if err != nil { t.Fatalf("could not get eval %v", evalID) } @@ -394,7 +401,7 @@ func TestClientEndpoint_Register_GetEvals(t *testing.T) { } // Check for the node in the FSM - out, err := state.NodeByID(node.ID) + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -490,7 +497,8 @@ func TestClientEndpoint_UpdateStatus_GetEvals(t *testing.T) { } evalID := resp2.EvalIDs[0] - eval, err := state.EvalByID(evalID) + ws := memdb.NewWatchSet() + eval, err := state.EvalByID(ws, evalID) if err != nil { t.Fatalf("could not get eval %v", evalID) } @@ -506,7 +514,7 @@ func TestClientEndpoint_UpdateStatus_GetEvals(t *testing.T) { } // Check for the node in the FSM - out, err := state.NodeByID(node.ID) + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -627,7 +635,8 @@ func TestClientEndpoint_UpdateDrain(t *testing.T) { // Check for the node in the FSM state := s1.fsm.State() - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -683,11 +692,12 @@ func TestClientEndpoint_Drain_Down(t *testing.T) { // Wait for the scheduler to create an allocation testutil.WaitForResult(func() (bool, error) { - allocs, err := s1.fsm.state.AllocsByJob(job.ID, true) + ws := memdb.NewWatchSet() + allocs, err := s1.fsm.state.AllocsByJob(ws, job.ID, true) if err != nil { return false, err } - allocs1, err := s1.fsm.state.AllocsByJob(job1.ID, true) + allocs1, err := s1.fsm.state.AllocsByJob(ws, job1.ID, 
true) if err != nil { return false, err } @@ -719,7 +729,8 @@ func TestClientEndpoint_Drain_Down(t *testing.T) { // Ensure that the allocation has transitioned to lost testutil.WaitForResult(func() (bool, error) { - summary, err := s1.fsm.state.JobSummaryByID(job.ID) + ws := memdb.NewWatchSet() + summary, err := s1.fsm.state.JobSummaryByID(ws, job.ID) if err != nil { return false, err } @@ -739,7 +750,7 @@ func TestClientEndpoint_Drain_Down(t *testing.T) { return false, fmt.Errorf("expected: %#v, actual: %#v", expectedSummary, summary) } - summary1, err := s1.fsm.state.JobSummaryByID(job1.ID) + summary1, err := s1.fsm.state.JobSummaryByID(ws, job1.ID) if err != nil { return false, err } @@ -851,7 +862,7 @@ func TestClientEndpoint_GetNode_Blocking(t *testing.T) { NodeID: node2.ID, QueryOptions: structs.QueryOptions{ Region: "global", - MinQueryIndex: 50, + MinQueryIndex: 150, }, } var resp structs.SingleNodeResponse @@ -1289,7 +1300,8 @@ func TestClientEndpoint_UpdateAlloc(t *testing.T) { } // Lookup the alloc - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1344,7 +1356,8 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) { } // Lookup the alloc - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1415,7 +1428,8 @@ func TestClientEndpoint_UpdateAlloc_Vault(t *testing.T) { } // Lookup the alloc - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1460,9 +1474,10 @@ func TestClientEndpoint_CreateNodeEvals(t *testing.T) { } // Lookup the evaluations + ws := memdb.NewWatchSet() evalByType := make(map[string]*structs.Evaluation, 2) for _, id := range ids { - eval, err := state.EvalByID(id) + eval, err := state.EvalByID(ws, id) if err != nil { t.Fatalf("err: %v", err) } @@ -1559,7 +1574,8 @@ func TestClientEndpoint_Evaluate(t *testing.T) { } // Lookup the evaluation - eval, err := state.EvalByID(ids[0]) + ws := memdb.NewWatchSet() + eval, err := state.EvalByID(ws, ids[0]) if err != nil { t.Fatalf("err: %v", err) } @@ -1936,7 +1952,8 @@ func TestClientEndpoint_DeriveVaultToken(t *testing.T) { } // Check the state store and ensure that we created a VaultAccessor - va, err := state.VaultAccessor(accessor) + ws := memdb.NewWatchSet() + va, err := state.VaultAccessor(ws, accessor) if err != nil { t.Fatalf("bad: %v", err) } diff --git a/nomad/periodic.go b/nomad/periodic.go index 09dadfdff..b06267585 100644 --- a/nomad/periodic.go +++ b/nomad/periodic.go @@ -9,6 +9,7 @@ import ( "sync" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" ) @@ -86,8 +87,9 @@ func (s *Server) RunningChildren(job *structs.Job) (bool, error) { return false, err } + ws := memdb.NewWatchSet() prefix := fmt.Sprintf("%s%s", job.ID, structs.PeriodicLaunchSuffix) - iter, err := state.JobsByIDPrefix(prefix) + iter, err := state.JobsByIDPrefix(ws, prefix) if err != nil { return false, err } @@ -102,7 +104,7 @@ func (s *Server) RunningChildren(job *structs.Job) (bool, error) { } // Get the childs evaluations. 
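
Aside (illustrative, not part of the patch): RunningChildren above builds the periodic job-ID prefix and walks state.JobsByIDPrefix(ws, prefix). Under go-memdb, a prefix scan is just the base index name with a "_prefix" suffix, and the iterator's watch channel covers everything under that prefix. A rough sketch under that assumption, with a made-up "jobs" table keyed on a string ID (not the real schema):

package main

import (
	"fmt"

	memdb "github.com/hashicorp/go-memdb"
)

type Job struct {
	ID string
}

func jobSchema() *memdb.DBSchema {
	return &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"jobs": {
				Name: "jobs",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
				},
			},
		},
	}
}

// jobsByIDPrefix returns all jobs whose ID starts with prefix and registers the
// iterator's watch channel so a blocking caller is woken by any change under
// that prefix.
func jobsByIDPrefix(db *memdb.MemDB, ws memdb.WatchSet, prefix string) ([]*Job, error) {
	txn := db.Txn(false)

	// "id_prefix" is go-memdb's prefix form of the "id" index.
	iter, err := txn.Get("jobs", "id_prefix", prefix)
	if err != nil {
		return nil, err
	}
	ws.Add(iter.WatchCh())

	var out []*Job
	for raw := iter.Next(); raw != nil; raw = iter.Next() {
		out = append(out, raw.(*Job))
	}
	return out, nil
}

func main() {
	db, err := memdb.NewMemDB(jobSchema())
	if err != nil {
		panic(err)
	}

	txn := db.Txn(true)
	for _, id := range []string{"periodic/1", "periodic/2", "other"} {
		if err := txn.Insert("jobs", &Job{ID: id}); err != nil {
			panic(err)
		}
	}
	txn.Commit()

	ws := memdb.NewWatchSet()
	children, err := jobsByIDPrefix(db, ws, "periodic/")
	fmt.Println(len(children), err) // 2 <nil>
}
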
- evals, err := state.EvalsByJob(child.ID) + evals, err := state.EvalsByJob(ws, child.ID) if err != nil { return false, err } @@ -113,7 +115,7 @@ func (s *Server) RunningChildren(job *structs.Job) (bool, error) { return true, nil } - allocs, err := state.AllocsByEval(eval.ID) + allocs, err := state.AllocsByEval(ws, eval.ID) if err != nil { return false, err } diff --git a/nomad/periodic_endpoint.go b/nomad/periodic_endpoint.go index f8de4ae00..b172c4658 100644 --- a/nomad/periodic_endpoint.go +++ b/nomad/periodic_endpoint.go @@ -5,6 +5,7 @@ import ( "time" "github.com/armon/go-metrics" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" ) @@ -30,7 +31,9 @@ func (p *Periodic) Force(args *structs.PeriodicForceRequest, reply *structs.Peri if err != nil { return err } - job, err := snap.JobByID(args.JobID) + + ws := memdb.NewWatchSet() + job, err := snap.JobByID(ws, args.JobID) if err != nil { return err } diff --git a/nomad/periodic_endpoint_test.go b/nomad/periodic_endpoint_test.go index 295070162..fb6cbcec2 100644 --- a/nomad/periodic_endpoint_test.go +++ b/nomad/periodic_endpoint_test.go @@ -3,6 +3,7 @@ package nomad import ( "testing" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -42,7 +43,8 @@ func TestPeriodicEndpoint_Force(t *testing.T) { } // Lookup the evaluation - eval, err := state.EvalByID(resp.EvalID) + ws := memdb.NewWatchSet() + eval, err := state.EvalByID(ws, resp.EvalID) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/plan_apply.go b/nomad/plan_apply.go index c094f16e8..5262eb94e 100644 --- a/nomad/plan_apply.go +++ b/nomad/plan_apply.go @@ -6,6 +6,7 @@ import ( "time" "github.com/armon/go-metrics" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" @@ -322,7 +323,8 @@ func evaluateNodePlan(snap *state.StateSnapshot, plan *structs.Plan, nodeID stri } // Get the node itself - node, err := snap.NodeByID(nodeID) + ws := memdb.NewWatchSet() + node, err := snap.NodeByID(ws, nodeID) if err != nil { return false, fmt.Errorf("failed to get node '%s': %v", nodeID, err) } @@ -335,7 +337,7 @@ func evaluateNodePlan(snap *state.StateSnapshot, plan *structs.Plan, nodeID stri } // Get the existing allocations that are non-terminal - existingAlloc, err := snap.AllocsByNodeTerminal(nodeID, false) + existingAlloc, err := snap.AllocsByNodeTerminal(ws, nodeID, false) if err != nil { return false, fmt.Errorf("failed to get existing allocations for '%s': %v", nodeID, err) } diff --git a/nomad/plan_apply_test.go b/nomad/plan_apply_test.go index 2584556b0..7b0e3e659 100644 --- a/nomad/plan_apply_test.go +++ b/nomad/plan_apply_test.go @@ -4,6 +4,7 @@ import ( "reflect" "testing" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" @@ -88,7 +89,8 @@ func TestPlanApply_applyPlan(t *testing.T) { } // Verify our optimistic snapshot is updated - if out, err := snap.AllocByID(alloc.ID); err != nil || out == nil { + ws := memdb.NewWatchSet() + if out, err := snap.AllocByID(ws, alloc.ID); err != nil || out == nil { t.Fatalf("bad: %v %v", out, err) } @@ -102,7 +104,7 @@ func TestPlanApply_applyPlan(t *testing.T) { } // Lookup the allocation - out, err := s1.fsm.State().AllocByID(alloc.ID) + out, err := s1.fsm.State().AllocByID(ws, 
alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -141,7 +143,7 @@ func TestPlanApply_applyPlan(t *testing.T) { } // Check that our optimistic view is updated - if out, _ := snap.AllocByID(allocEvict.ID); out.DesiredStatus != structs.AllocDesiredStatusEvict { + if out, _ := snap.AllocByID(ws, allocEvict.ID); out.DesiredStatus != structs.AllocDesiredStatusEvict { t.Fatalf("bad: %#v", out) } @@ -155,7 +157,7 @@ func TestPlanApply_applyPlan(t *testing.T) { } // Lookup the allocation - out, err = s1.fsm.State().AllocByID(alloc.ID) + out, err = s1.fsm.State().AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -167,7 +169,7 @@ func TestPlanApply_applyPlan(t *testing.T) { } // Lookup the allocation - out, err = s1.fsm.State().AllocByID(alloc2.ID) + out, err = s1.fsm.State().AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } diff --git a/nomad/rpc.go b/nomad/rpc.go index f9a0ebd17..7aaf225c0 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -12,10 +12,10 @@ import ( "github.com/armon/go-metrics" "github.com/hashicorp/consul/lib" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/nomad/watch" "github.com/hashicorp/raft" "github.com/hashicorp/yamux" ) @@ -321,19 +321,24 @@ func (s *Server) setQueryMeta(m *structs.QueryMeta) { } } +// queryFn is used to perform a query operation. If a re-query is needed, the +// passed-in watch set will be used to block for changes. The passed-in state +// store should be used (vs. calling fsm.State()) since the given state store +// will be correctly watched for changes if the state store is restored from +// a snapshot. +type queryFn func(memdb.WatchSet, *state.StateStore) error + // blockingOptions is used to parameterize blockingRPC type blockingOptions struct { queryOpts *structs.QueryOptions queryMeta *structs.QueryMeta - watch watch.Items - run func() error + run queryFn } // blockingRPC is used for queries that need to wait for a // minimum index. This is used to block and wait for changes. func (s *Server) blockingRPC(opts *blockingOptions) error { var timeout *time.Timer - var notifyCh chan struct{} var state *state.StateStore // Fast path non-blocking @@ -353,36 +358,40 @@ func (s *Server) blockingRPC(opts *blockingOptions) error { // Setup a query timeout timeout = time.NewTimer(opts.queryOpts.MaxQueryTime) - - // Setup the notify channel - notifyCh = make(chan struct{}, 1) - - // Ensure we tear down any watchers on return - state = s.fsm.State() - defer func() { - timeout.Stop() - state.StopWatch(opts.watch, notifyCh) - }() - -REGISTER_NOTIFY: - // Register the notification channel. This may be done - // multiple times if we have not reached the target wait index. - state.Watch(opts.watch, notifyCh) + defer timeout.Stop() RUN_QUERY: // Update the query meta data s.setQueryMeta(opts.queryMeta) - // Run the query function + // Increment the rpc query counter metrics.IncrCounter([]string{"nomad", "rpc", "query"}, 1) - err := opts.run() + + // We capture the state store and its abandon channel but pass a snapshot to + // the blocking query function. We operate on the snapshot to allow separate + // calls to the state store not all wrapped within the same transaction. + state = s.fsm.State() + abandonCh := state.AbandonCh() + snap, _ := state.Snapshot() + stateSnap := &snap.StateStore + + // We can skip all watch tracking if this isn't a blocking query. 
+ var ws memdb.WatchSet + if opts.queryOpts.MinQueryIndex > 0 { + ws = memdb.NewWatchSet() + + // This channel will be closed if a snapshot is restored and the + // whole state store is abandoned. + ws.Add(abandonCh) + } + + // Block up to the timeout if we didn't see anything fresh. + err := opts.run(ws, stateSnap) // Check for minimum query time if err == nil && opts.queryOpts.MinQueryIndex > 0 && opts.queryMeta.Index <= opts.queryOpts.MinQueryIndex { - select { - case <-notifyCh: - goto REGISTER_NOTIFY - case <-timeout.C: + if expired := ws.Watch(timeout.C); !expired { + goto RUN_QUERY } } return err diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 65ae22dea..a973678ce 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -4,11 +4,9 @@ import ( "fmt" "io" "log" - "sync" "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/nomad/watch" ) // IndexEntry is used with the "index" table @@ -28,7 +26,10 @@ type IndexEntry struct { type StateStore struct { logger *log.Logger db *memdb.MemDB - watch *stateWatch + + // abandonCh is used to signal watchers that this state store has been + // abandoned (usually during a restore). This is only ever closed. + abandonCh chan struct{} } // NewStateStore is used to create a new state store @@ -41,9 +42,9 @@ func NewStateStore(logOutput io.Writer) (*StateStore, error) { // Create the state store s := &StateStore{ - logger: log.New(logOutput, "", log.LstdFlags), - db: db, - watch: newStateWatch(), + logger: log.New(logOutput, "", log.LstdFlags), + db: db, + abandonCh: make(chan struct{}), } return s, nil } @@ -56,7 +57,6 @@ func (s *StateStore) Snapshot() (*StateSnapshot, error) { StateStore: StateStore{ logger: s.logger, db: s.db.Snapshot(), - watch: s.watch, }, } return snap, nil @@ -68,21 +68,21 @@ func (s *StateStore) Snapshot() (*StateSnapshot, error) { func (s *StateStore) Restore() (*StateRestore, error) { txn := s.db.Txn(true) r := &StateRestore{ - txn: txn, - watch: s.watch, - items: watch.NewItems(), + txn: txn, } return r, nil } -// Watch subscribes a channel to a set of watch items. -func (s *StateStore) Watch(items watch.Items, notify chan struct{}) { - s.watch.watch(items, notify) +// AbandonCh returns a channel you can wait on to know if the state store was +// abandoned. +func (s *StateStore) AbandonCh() <-chan struct{} { + return s.abandonCh } -// StopWatch unsubscribes a channel from a set of watch items. -func (s *StateStore) StopWatch(items watch.Items, notify chan struct{}) { - s.watch.stopWatch(items, notify) +// Abandon is used to signal that the given state store has been abandoned. +// Calling this more than one time will panic. +func (s *StateStore) Abandon() { + close(s.abandonCh) } // UpsertJobSummary upserts a job summary into the state store. 
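
Aside (illustrative, not part of the patch): the rewritten blockingRPC above replaces the REGISTER_NOTIFY dance with a simpler loop: snapshot the store, build a WatchSet seeded with the store's abandon channel, run the query, and if nothing fresh was seen, ws.Watch(timeout.C) until either a watched channel fires or the timer expires, then re-run. The self-contained sketch below shows that loop against a toy table; the blockingJob helper, Job type, and jobSchema are invented for illustration and are not Nomad code.

package main

import (
	"fmt"
	"time"

	memdb "github.com/hashicorp/go-memdb"
)

type Job struct {
	ID     string
	Status string
}

func jobSchema() *memdb.DBSchema {
	return &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"jobs": {
				Name: "jobs",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
				},
			},
		},
	}
}

// blockingJob mirrors the RUN_QUERY loop: look the job up, and if it is not
// there yet, block on the WatchSet (which includes abandonCh) until something
// relevant changes or the timeout expires, then retry.
func blockingJob(db *memdb.MemDB, abandonCh <-chan struct{}, id string, maxWait time.Duration) (*Job, error) {
	timeout := time.NewTimer(maxWait)
	defer timeout.Stop()

	for {
		ws := memdb.NewWatchSet()
		ws.Add(abandonCh) // closed when the whole store is abandoned, e.g. on snapshot restore

		txn := db.Txn(false)
		watchCh, raw, err := txn.FirstWatch("jobs", "id", id)
		if err != nil {
			return nil, err
		}
		ws.Add(watchCh)

		if raw != nil {
			return raw.(*Job), nil
		}

		// Watch returns true if the timeout fired before any watched channel did.
		if expired := ws.Watch(timeout.C); expired {
			return nil, fmt.Errorf("timed out waiting for job %q", id)
		}
	}
}

func main() {
	db, err := memdb.NewMemDB(jobSchema())
	if err != nil {
		panic(err)
	}
	abandonCh := make(chan struct{})

	// Simulate a write arriving while the query is blocked.
	go func() {
		time.Sleep(50 * time.Millisecond)
		txn := db.Txn(true)
		if err := txn.Insert("jobs", &Job{ID: "example", Status: "running"}); err != nil {
			panic(err)
		}
		txn.Commit()
	}()

	job, err := blockingJob(db, abandonCh, "example", time.Second)
	fmt.Println(job, err) // &{example running} <nil>
}

As in the diff, a real server would run the query against a snapshot of the state store rather than the live store, so that separate lookups inside one blocking query are not forced into a single transaction.
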
@@ -128,10 +128,6 @@ func (s *StateStore) UpsertNode(index uint64, node *structs.Node) error { txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "nodes"}) - watcher.Add(watch.Item{Node: node.ID}) - // Check if the node already exists existing, err := txn.First("nodes", "id", node.ID) if err != nil { @@ -157,7 +153,6 @@ func (s *StateStore) UpsertNode(index uint64, node *structs.Node) error { return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } @@ -176,10 +171,6 @@ func (s *StateStore) DeleteNode(index uint64, nodeID string) error { return fmt.Errorf("node not found") } - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "nodes"}) - watcher.Add(watch.Item{Node: nodeID}) - // Delete the node if err := txn.Delete("nodes", existing); err != nil { return fmt.Errorf("node delete failed: %v", err) @@ -188,7 +179,6 @@ func (s *StateStore) DeleteNode(index uint64, nodeID string) error { return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } @@ -198,10 +188,6 @@ func (s *StateStore) UpdateNodeStatus(index uint64, nodeID, status string) error txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "nodes"}) - watcher.Add(watch.Item{Node: nodeID}) - // Lookup the node existing, err := txn.First("nodes", "id", nodeID) if err != nil { @@ -228,7 +214,6 @@ func (s *StateStore) UpdateNodeStatus(index uint64, nodeID, status string) error return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } @@ -238,10 +223,6 @@ func (s *StateStore) UpdateNodeDrain(index uint64, nodeID string, drain bool) er txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "nodes"}) - watcher.Add(watch.Item{Node: nodeID}) - // Lookup the node existing, err := txn.First("nodes", "id", nodeID) if err != nil { @@ -268,19 +249,19 @@ func (s *StateStore) UpdateNodeDrain(index uint64, nodeID string, drain bool) er return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } // NodeByID is used to lookup a node by ID -func (s *StateStore) NodeByID(nodeID string) (*structs.Node, error) { +func (s *StateStore) NodeByID(ws memdb.WatchSet, nodeID string) (*structs.Node, error) { txn := s.db.Txn(false) - existing, err := txn.First("nodes", "id", nodeID) + watchCh, existing, err := txn.FirstWatch("nodes", "id", nodeID) if err != nil { return nil, fmt.Errorf("node lookup failed: %v", err) } + ws.Add(watchCh) if existing != nil { return existing.(*structs.Node), nil @@ -289,19 +270,20 @@ func (s *StateStore) NodeByID(nodeID string) (*structs.Node, error) { } // NodesByIDPrefix is used to lookup nodes by prefix -func (s *StateStore) NodesByIDPrefix(nodeID string) (memdb.ResultIterator, error) { +func (s *StateStore) NodesByIDPrefix(ws memdb.WatchSet, nodeID string) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("nodes", "id_prefix", nodeID) if err != nil { return nil, fmt.Errorf("node lookup failed: %v", err) } + ws.Add(iter.WatchCh()) return iter, nil } // Nodes returns an iterator over all the nodes -func (s *StateStore) Nodes() (memdb.ResultIterator, error) { +func (s *StateStore) Nodes(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.Txn(false) // Walk the entire nodes table @@ 
-309,6 +291,7 @@ func (s *StateStore) Nodes() (memdb.ResultIterator, error) { if err != nil { return nil, err } + ws.Add(iter.WatchCh()) return iter, nil } @@ -317,10 +300,6 @@ func (s *StateStore) UpsertJob(index uint64, job *structs.Job) error { txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "jobs"}) - watcher.Add(watch.Item{Job: job.ID}) - // Check if the job already exists existing, err := txn.First("jobs", "id", job.ID) if err != nil { @@ -344,7 +323,7 @@ func (s *StateStore) UpsertJob(index uint64, job *structs.Job) error { job.ModifyIndex = index job.JobModifyIndex = index - if err := s.setJobStatus(index, watcher, txn, job, false, ""); err != nil { + if err := s.setJobStatus(index, txn, job, false, ""); err != nil { return fmt.Errorf("setting job status for %q failed: %v", job.ID, err) } @@ -358,7 +337,7 @@ func (s *StateStore) UpsertJob(index uint64, job *structs.Job) error { } } - if err := s.updateSummaryWithJob(index, job, watcher, txn); err != nil { + if err := s.updateSummaryWithJob(index, job, txn); err != nil { return fmt.Errorf("unable to create job summary: %v", err) } @@ -374,7 +353,6 @@ func (s *StateStore) UpsertJob(index uint64, job *structs.Job) error { return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } @@ -393,12 +371,6 @@ func (s *StateStore) DeleteJob(index uint64, jobID string) error { return fmt.Errorf("job not found") } - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "jobs"}) - watcher.Add(watch.Item{Job: jobID}) - watcher.Add(watch.Item{Table: "job_summary"}) - watcher.Add(watch.Item{JobSummary: jobID}) - // Check if we should update a parent job summary job := existing.(*structs.Job) if job.ParentID != "" { @@ -433,9 +405,6 @@ func (s *StateStore) DeleteJob(index uint64, jobID string) error { // Update the modify index pSummary.ModifyIndex = index - watcher.Add(watch.Item{Table: "job_summary"}) - watcher.Add(watch.Item{JobSummary: job.ParentID}) - // Insert the summary if err := txn.Insert("job_summary", pSummary); err != nil { return fmt.Errorf("job summary insert failed: %v", err) @@ -464,19 +433,19 @@ func (s *StateStore) DeleteJob(index uint64, jobID string) error { return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } // JobByID is used to lookup a job by its ID -func (s *StateStore) JobByID(id string) (*structs.Job, error) { +func (s *StateStore) JobByID(ws memdb.WatchSet, id string) (*structs.Job, error) { txn := s.db.Txn(false) - existing, err := txn.First("jobs", "id", id) + watchCh, existing, err := txn.FirstWatch("jobs", "id", id) if err != nil { return nil, fmt.Errorf("job lookup failed: %v", err) } + ws.Add(watchCh) if existing != nil { return existing.(*structs.Job), nil @@ -485,7 +454,7 @@ func (s *StateStore) JobByID(id string) (*structs.Job, error) { } // JobsByIDPrefix is used to lookup a job by prefix -func (s *StateStore) JobsByIDPrefix(id string) (memdb.ResultIterator, error) { +func (s *StateStore) JobsByIDPrefix(ws memdb.WatchSet, id string) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("jobs", "id_prefix", id) @@ -493,11 +462,13 @@ func (s *StateStore) JobsByIDPrefix(id string) (memdb.ResultIterator, error) { return nil, fmt.Errorf("job lookup failed: %v", err) } + ws.Add(iter.WatchCh()) + return iter, nil } // Jobs returns an iterator over all the jobs -func (s *StateStore) Jobs() 
(memdb.ResultIterator, error) { +func (s *StateStore) Jobs(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.Txn(false) // Walk the entire jobs table @@ -505,23 +476,29 @@ func (s *StateStore) Jobs() (memdb.ResultIterator, error) { if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } // JobsByPeriodic returns an iterator over all the periodic or non-periodic jobs. -func (s *StateStore) JobsByPeriodic(periodic bool) (memdb.ResultIterator, error) { +func (s *StateStore) JobsByPeriodic(ws memdb.WatchSet, periodic bool) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("jobs", "periodic", periodic) if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } // JobsByScheduler returns an iterator over all the jobs with the specific // scheduler type. -func (s *StateStore) JobsByScheduler(schedulerType string) (memdb.ResultIterator, error) { +func (s *StateStore) JobsByScheduler(ws memdb.WatchSet, schedulerType string) (memdb.ResultIterator, error) { txn := s.db.Txn(false) // Return an iterator for jobs with the specific type. @@ -529,29 +506,38 @@ func (s *StateStore) JobsByScheduler(schedulerType string) (memdb.ResultIterator if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } // JobsByGC returns an iterator over all jobs eligible or uneligible for garbage // collection. -func (s *StateStore) JobsByGC(gc bool) (memdb.ResultIterator, error) { +func (s *StateStore) JobsByGC(ws memdb.WatchSet, gc bool) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("jobs", "gc", gc) if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } // JobSummary returns a job summary object which matches a specific id. 
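
Aside (illustrative, not part of the patch): each converted reader above (JobByID, JobsByIDPrefix, Jobs, JobsByPeriodic, JobsByScheduler, JobsByGC) takes the caller's WatchSet and adds the watch channel for exactly what it touched. Because a WatchSet is just a collection of channels, one blocking query can call several readers and then wait once; whichever lookup's data changes first wakes it. A small sketch of that composition, again with a made-up jobs table and a hypothetical jobByID helper in the converted style:

package main

import (
	"fmt"
	"time"

	memdb "github.com/hashicorp/go-memdb"
)

type Job struct {
	ID     string
	Status string
}

func jobSchema() *memdb.DBSchema {
	return &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"jobs": {
				Name: "jobs",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
				},
			},
		},
	}
}

// jobByID follows the converted reader style: register a watch channel for
// this specific lookup on the caller's WatchSet, then return the value.
func jobByID(db *memdb.MemDB, ws memdb.WatchSet, id string) (*Job, error) {
	txn := db.Txn(false)
	watchCh, raw, err := txn.FirstWatch("jobs", "id", id)
	if err != nil {
		return nil, err
	}
	ws.Add(watchCh)
	if raw == nil {
		return nil, nil
	}
	return raw.(*Job), nil
}

func main() {
	db, err := memdb.NewMemDB(jobSchema())
	if err != nil {
		panic(err)
	}

	txn := db.Txn(true)
	for _, j := range []*Job{{ID: "a", Status: "pending"}, {ID: "b", Status: "pending"}} {
		if err := txn.Insert("jobs", j); err != nil {
			panic(err)
		}
	}
	txn.Commit()

	// One WatchSet collects the watch channels from both lookups.
	ws := memdb.NewWatchSet()
	if _, err := jobByID(db, ws, "a"); err != nil {
		panic(err)
	}
	if _, err := jobByID(db, ws, "b"); err != nil {
		panic(err)
	}

	// Update "b" in the background; the single Watch call below wakes up even
	// though "a" never changed.
	go func() {
		time.Sleep(50 * time.Millisecond)
		wtxn := db.Txn(true)
		if err := wtxn.Insert("jobs", &Job{ID: "b", Status: "running"}); err != nil {
			panic(err)
		}
		wtxn.Commit()
	}()

	timeout := time.NewTimer(time.Second)
	defer timeout.Stop()
	fmt.Println("timed out?", ws.Watch(timeout.C)) // timed out? false
}
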
-func (s *StateStore) JobSummaryByID(jobID string) (*structs.JobSummary, error) { +func (s *StateStore) JobSummaryByID(ws memdb.WatchSet, jobID string) (*structs.JobSummary, error) { txn := s.db.Txn(false) - existing, err := txn.First("job_summary", "id", jobID) + watchCh, existing, err := txn.FirstWatch("job_summary", "id", jobID) if err != nil { return nil, err } + + ws.Add(watchCh) + if existing != nil { summary := existing.(*structs.JobSummary) return summary, nil @@ -562,18 +548,21 @@ func (s *StateStore) JobSummaryByID(jobID string) (*structs.JobSummary, error) { // JobSummaries walks the entire job summary table and returns all the job // summary objects -func (s *StateStore) JobSummaries() (memdb.ResultIterator, error) { +func (s *StateStore) JobSummaries(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("job_summary", "id") if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } // JobSummaryByPrefix is used to look up Job Summary by id prefix -func (s *StateStore) JobSummaryByPrefix(id string) (memdb.ResultIterator, error) { +func (s *StateStore) JobSummaryByPrefix(ws memdb.WatchSet, id string) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("job_summary", "id_prefix", id) @@ -581,6 +570,8 @@ func (s *StateStore) JobSummaryByPrefix(id string) (memdb.ResultIterator, error) return nil, fmt.Errorf("eval lookup failed: %v", err) } + ws.Add(iter.WatchCh()) + return iter, nil } @@ -589,10 +580,6 @@ func (s *StateStore) UpsertPeriodicLaunch(index uint64, launch *structs.Periodic txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "periodic_launch"}) - watcher.Add(watch.Item{Job: launch.ID}) - // Check if the job already exists existing, err := txn.First("periodic_launch", "id", launch.ID) if err != nil { @@ -616,7 +603,6 @@ func (s *StateStore) UpsertPeriodicLaunch(index uint64, launch *structs.Periodic return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } @@ -635,10 +621,6 @@ func (s *StateStore) DeletePeriodicLaunch(index uint64, jobID string) error { return fmt.Errorf("launch not found") } - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "periodic_launch"}) - watcher.Add(watch.Item{Job: jobID}) - // Delete the launch if err := txn.Delete("periodic_launch", existing); err != nil { return fmt.Errorf("launch delete failed: %v", err) @@ -647,21 +629,22 @@ func (s *StateStore) DeletePeriodicLaunch(index uint64, jobID string) error { return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } // PeriodicLaunchByID is used to lookup a periodic launch by the periodic job // ID. 
-func (s *StateStore) PeriodicLaunchByID(id string) (*structs.PeriodicLaunch, error) { +func (s *StateStore) PeriodicLaunchByID(ws memdb.WatchSet, id string) (*structs.PeriodicLaunch, error) { txn := s.db.Txn(false) - existing, err := txn.First("periodic_launch", "id", id) + watchCh, existing, err := txn.FirstWatch("periodic_launch", "id", id) if err != nil { return nil, fmt.Errorf("periodic launch lookup failed: %v", err) } + ws.Add(watchCh) + if existing != nil { return existing.(*structs.PeriodicLaunch), nil } @@ -669,7 +652,7 @@ func (s *StateStore) PeriodicLaunchByID(id string) (*structs.PeriodicLaunch, err } // PeriodicLaunches returns an iterator over all the periodic launches -func (s *StateStore) PeriodicLaunches() (memdb.ResultIterator, error) { +func (s *StateStore) PeriodicLaunches(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.Txn(false) // Walk the entire table @@ -677,6 +660,9 @@ func (s *StateStore) PeriodicLaunches() (memdb.ResultIterator, error) { if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } @@ -685,15 +671,10 @@ func (s *StateStore) UpsertEvals(index uint64, evals []*structs.Evaluation) erro txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "evals"}) - // Do a nested upsert jobs := make(map[string]string, len(evals)) for _, eval := range evals { - watcher.Add(watch.Item{Eval: eval.ID}) - watcher.Add(watch.Item{EvalJob: eval.JobID}) - if err := s.nestedUpsertEval(txn, watcher, index, eval); err != nil { + if err := s.nestedUpsertEval(txn, index, eval); err != nil { return err } @@ -701,17 +682,16 @@ func (s *StateStore) UpsertEvals(index uint64, evals []*structs.Evaluation) erro } // Set the job's status - if err := s.setJobStatuses(index, watcher, txn, jobs, false); err != nil { + if err := s.setJobStatuses(index, txn, jobs, false); err != nil { return fmt.Errorf("setting job status failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } // nestedUpsertEvaluation is used to nest an evaluation upsert within a transaction -func (s *StateStore) nestedUpsertEval(txn *memdb.Txn, watcher watch.Items, index uint64, eval *structs.Evaluation) error { +func (s *StateStore) nestedUpsertEval(txn *memdb.Txn, index uint64, eval *structs.Evaluation) error { // Lookup the evaluation existing, err := txn.First("evals", "id", eval.ID) if err != nil { @@ -785,8 +765,6 @@ func (s *StateStore) nestedUpsertEval(txn *memdb.Txn, watcher watch.Items, index if err := txn.Insert("evals", newEval); err != nil { return fmt.Errorf("eval insert failed: %v", err) } - - watcher.Add(watch.Item{Eval: newEval.ID}) } } @@ -804,9 +782,6 @@ func (s *StateStore) nestedUpsertEval(txn *memdb.Txn, watcher watch.Items, index func (s *StateStore) DeleteEval(index uint64, evals []string, allocs []string) error { txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "evals"}) - watcher.Add(watch.Item{Table: "allocs"}) jobs := make(map[string]string, len(evals)) for _, eval := range evals { @@ -821,8 +796,6 @@ func (s *StateStore) DeleteEval(index uint64, evals []string, allocs []string) e return fmt.Errorf("eval delete failed: %v", err) } jobID := existing.(*structs.Evaluation).JobID - watcher.Add(watch.Item{Eval: eval}) - watcher.Add(watch.Item{EvalJob: jobID}) jobs[jobID] = "" } @@ -837,11 +810,6 @@ func (s *StateStore) DeleteEval(index uint64, evals []string, allocs []string) e if err := txn.Delete("allocs", existing); err != 
nil { return fmt.Errorf("alloc delete failed: %v", err) } - realAlloc := existing.(*structs.Allocation) - watcher.Add(watch.Item{Alloc: realAlloc.ID}) - watcher.Add(watch.Item{AllocEval: realAlloc.EvalID}) - watcher.Add(watch.Item{AllocJob: realAlloc.JobID}) - watcher.Add(watch.Item{AllocNode: realAlloc.NodeID}) } // Update the indexes @@ -853,24 +821,25 @@ func (s *StateStore) DeleteEval(index uint64, evals []string, allocs []string) e } // Set the job's status - if err := s.setJobStatuses(index, watcher, txn, jobs, true); err != nil { + if err := s.setJobStatuses(index, txn, jobs, true); err != nil { return fmt.Errorf("setting job status failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } // EvalByID is used to lookup an eval by its ID -func (s *StateStore) EvalByID(id string) (*structs.Evaluation, error) { +func (s *StateStore) EvalByID(ws memdb.WatchSet, id string) (*structs.Evaluation, error) { txn := s.db.Txn(false) - existing, err := txn.First("evals", "id", id) + watchCh, existing, err := txn.FirstWatch("evals", "id", id) if err != nil { return nil, fmt.Errorf("eval lookup failed: %v", err) } + ws.Add(watchCh) + if existing != nil { return existing.(*structs.Evaluation), nil } @@ -878,7 +847,7 @@ func (s *StateStore) EvalByID(id string) (*structs.Evaluation, error) { } // EvalsByIDPrefix is used to lookup evaluations by prefix -func (s *StateStore) EvalsByIDPrefix(id string) (memdb.ResultIterator, error) { +func (s *StateStore) EvalsByIDPrefix(ws memdb.WatchSet, id string) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("evals", "id_prefix", id) @@ -886,11 +855,13 @@ func (s *StateStore) EvalsByIDPrefix(id string) (memdb.ResultIterator, error) { return nil, fmt.Errorf("eval lookup failed: %v", err) } + ws.Add(iter.WatchCh()) + return iter, nil } // EvalsByJob returns all the evaluations by job id -func (s *StateStore) EvalsByJob(jobID string) ([]*structs.Evaluation, error) { +func (s *StateStore) EvalsByJob(ws memdb.WatchSet, jobID string) ([]*structs.Evaluation, error) { txn := s.db.Txn(false) // Get an iterator over the node allocations @@ -899,6 +870,8 @@ func (s *StateStore) EvalsByJob(jobID string) ([]*structs.Evaluation, error) { return nil, err } + ws.Add(iter.WatchCh()) + var out []*structs.Evaluation for { raw := iter.Next() @@ -911,7 +884,7 @@ func (s *StateStore) EvalsByJob(jobID string) ([]*structs.Evaluation, error) { } // Evals returns an iterator over all the evaluations -func (s *StateStore) Evals() (memdb.ResultIterator, error) { +func (s *StateStore) Evals(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.Txn(false) // Walk the entire table @@ -919,11 +892,13 @@ func (s *StateStore) Evals() (memdb.ResultIterator, error) { if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } // UpdateAllocsFromClient is used to update an allocation based on input - // from a client. While the schedulers are the authority on the allocation for // most things, some updates are authoritative from the client. 
Specifically, // the desired state comes from the schedulers, while the actual state comes @@ -932,13 +907,9 @@ func (s *StateStore) UpdateAllocsFromClient(index uint64, allocs []*structs.Allo txn := s.db.Txn(true) defer txn.Abort() - // Setup the watcher - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "allocs"}) - // Handle each of the updated allocations for _, alloc := range allocs { - if err := s.nestedUpdateAllocFromClient(txn, watcher, index, alloc); err != nil { + if err := s.nestedUpdateAllocFromClient(txn, index, alloc); err != nil { return err } } @@ -948,13 +919,12 @@ func (s *StateStore) UpdateAllocsFromClient(index uint64, allocs []*structs.Allo return fmt.Errorf("index update failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } // nestedUpdateAllocFromClient is used to nest an update of an allocation with client status -func (s *StateStore) nestedUpdateAllocFromClient(txn *memdb.Txn, watcher watch.Items, index uint64, alloc *structs.Allocation) error { +func (s *StateStore) nestedUpdateAllocFromClient(txn *memdb.Txn, index uint64, alloc *structs.Allocation) error { // Look for existing alloc existing, err := txn.First("allocs", "id", alloc.ID) if err != nil { @@ -967,15 +937,8 @@ func (s *StateStore) nestedUpdateAllocFromClient(txn *memdb.Txn, watcher watch.I } exist := existing.(*structs.Allocation) - // Trigger the watcher - watcher.Add(watch.Item{Alloc: alloc.ID}) - watcher.Add(watch.Item{AllocEval: exist.EvalID}) - watcher.Add(watch.Item{AllocJob: exist.JobID}) - watcher.Add(watch.Item{AllocNode: exist.NodeID}) - // Copy everything from the existing allocation - copyAlloc := new(structs.Allocation) - *copyAlloc = *exist + copyAlloc := exist.Copy() // Pull in anything the client is the authority on copyAlloc.ClientStatus = alloc.ClientStatus @@ -985,7 +948,7 @@ func (s *StateStore) nestedUpdateAllocFromClient(txn *memdb.Txn, watcher watch.I // Update the modify index copyAlloc.ModifyIndex = index - if err := s.updateSummaryWithAlloc(index, copyAlloc, exist, watcher, txn); err != nil { + if err := s.updateSummaryWithAlloc(index, copyAlloc, exist, txn); err != nil { return fmt.Errorf("error updating job summary: %v", err) } @@ -1000,7 +963,7 @@ func (s *StateStore) nestedUpdateAllocFromClient(txn *memdb.Txn, watcher watch.I forceStatus = structs.JobStatusRunning } jobs := map[string]string{exist.JobID: forceStatus} - if err := s.setJobStatuses(index, watcher, txn, jobs, false); err != nil { + if err := s.setJobStatuses(index, txn, jobs, false); err != nil { return fmt.Errorf("setting job status failed: %v", err) } return nil @@ -1012,9 +975,6 @@ func (s *StateStore) UpsertAllocs(index uint64, allocs []*structs.Allocation) er txn := s.db.Txn(true) defer txn.Abort() - watcher := watch.NewItems() - watcher.Add(watch.Item{Table: "allocs"}) - // Handle the allocations jobs := make(map[string]string, 1) for _, alloc := range allocs { @@ -1046,7 +1006,7 @@ func (s *StateStore) UpsertAllocs(index uint64, allocs []*structs.Allocation) er } } - if err := s.updateSummaryWithAlloc(index, alloc, exist, watcher, txn); err != nil { + if err := s.updateSummaryWithAlloc(index, alloc, exist, txn); err != nil { return fmt.Errorf("error updating job summary: %v", err) } @@ -1066,11 +1026,6 @@ func (s *StateStore) UpsertAllocs(index uint64, allocs []*structs.Allocation) er forceStatus = structs.JobStatusRunning } jobs[alloc.JobID] = forceStatus - - watcher.Add(watch.Item{Alloc: alloc.ID}) - watcher.Add(watch.Item{AllocEval: alloc.EvalID}) - 
watcher.Add(watch.Item{AllocJob: alloc.JobID}) - watcher.Add(watch.Item{AllocNode: alloc.NodeID}) } // Update the indexes @@ -1079,24 +1034,25 @@ func (s *StateStore) UpsertAllocs(index uint64, allocs []*structs.Allocation) er } // Set the job's status - if err := s.setJobStatuses(index, watcher, txn, jobs, false); err != nil { + if err := s.setJobStatuses(index, txn, jobs, false); err != nil { return fmt.Errorf("setting job status failed: %v", err) } - txn.Defer(func() { s.watch.notify(watcher) }) txn.Commit() return nil } // AllocByID is used to lookup an allocation by its ID -func (s *StateStore) AllocByID(id string) (*structs.Allocation, error) { +func (s *StateStore) AllocByID(ws memdb.WatchSet, id string) (*structs.Allocation, error) { txn := s.db.Txn(false) - existing, err := txn.First("allocs", "id", id) + watchCh, existing, err := txn.FirstWatch("allocs", "id", id) if err != nil { return nil, fmt.Errorf("alloc lookup failed: %v", err) } + ws.Add(watchCh) + if existing != nil { return existing.(*structs.Allocation), nil } @@ -1104,7 +1060,7 @@ func (s *StateStore) AllocByID(id string) (*structs.Allocation, error) { } // AllocsByIDPrefix is used to lookup allocs by prefix -func (s *StateStore) AllocsByIDPrefix(id string) (memdb.ResultIterator, error) { +func (s *StateStore) AllocsByIDPrefix(ws memdb.WatchSet, id string) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("allocs", "id_prefix", id) @@ -1112,11 +1068,13 @@ func (s *StateStore) AllocsByIDPrefix(id string) (memdb.ResultIterator, error) { return nil, fmt.Errorf("alloc lookup failed: %v", err) } + ws.Add(iter.WatchCh()) + return iter, nil } // AllocsByNode returns all the allocations by node -func (s *StateStore) AllocsByNode(node string) ([]*structs.Allocation, error) { +func (s *StateStore) AllocsByNode(ws memdb.WatchSet, node string) ([]*structs.Allocation, error) { txn := s.db.Txn(false) // Get an iterator over the node allocations, using only the @@ -1126,6 +1084,8 @@ func (s *StateStore) AllocsByNode(node string) ([]*structs.Allocation, error) { return nil, err } + ws.Add(iter.WatchCh()) + var out []*structs.Allocation for { raw := iter.Next() @@ -1138,7 +1098,7 @@ func (s *StateStore) AllocsByNode(node string) ([]*structs.Allocation, error) { } // AllocsByNode returns all the allocations by node and terminal status -func (s *StateStore) AllocsByNodeTerminal(node string, terminal bool) ([]*structs.Allocation, error) { +func (s *StateStore) AllocsByNodeTerminal(ws memdb.WatchSet, node string, terminal bool) ([]*structs.Allocation, error) { txn := s.db.Txn(false) // Get an iterator over the node allocations @@ -1147,6 +1107,8 @@ func (s *StateStore) AllocsByNodeTerminal(node string, terminal bool) ([]*struct return nil, err } + ws.Add(iter.WatchCh()) + var out []*structs.Allocation for { raw := iter.Next() @@ -1159,7 +1121,7 @@ func (s *StateStore) AllocsByNodeTerminal(node string, terminal bool) ([]*struct } // AllocsByJob returns all the allocations by job id -func (s *StateStore) AllocsByJob(jobID string, all bool) ([]*structs.Allocation, error) { +func (s *StateStore) AllocsByJob(ws memdb.WatchSet, jobID string, all bool) ([]*structs.Allocation, error) { txn := s.db.Txn(false) // Get the job @@ -1178,6 +1140,8 @@ func (s *StateStore) AllocsByJob(jobID string, all bool) ([]*structs.Allocation, return nil, err } + ws.Add(iter.WatchCh()) + var out []*structs.Allocation for { raw := iter.Next() @@ -1198,7 +1162,7 @@ func (s *StateStore) AllocsByJob(jobID string, all bool) 
([]*structs.Allocation, } // AllocsByEval returns all the allocations by eval id -func (s *StateStore) AllocsByEval(evalID string) ([]*structs.Allocation, error) { +func (s *StateStore) AllocsByEval(ws memdb.WatchSet, evalID string) ([]*structs.Allocation, error) { txn := s.db.Txn(false) // Get an iterator over the eval allocations @@ -1207,6 +1171,8 @@ func (s *StateStore) AllocsByEval(evalID string) ([]*structs.Allocation, error) return nil, err } + ws.Add(iter.WatchCh()) + var out []*structs.Allocation for { raw := iter.Next() @@ -1219,7 +1185,7 @@ func (s *StateStore) AllocsByEval(evalID string) ([]*structs.Allocation, error) } // Allocs returns an iterator over all the evaluations -func (s *StateStore) Allocs() (memdb.ResultIterator, error) { +func (s *StateStore) Allocs(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.Txn(false) // Walk the entire table @@ -1227,6 +1193,9 @@ func (s *StateStore) Allocs() (memdb.ResultIterator, error) { if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } @@ -1275,14 +1244,16 @@ func (s *StateStore) DeleteVaultAccessors(index uint64, accessors []*structs.Vau } // VaultAccessor returns the given Vault accessor -func (s *StateStore) VaultAccessor(accessor string) (*structs.VaultAccessor, error) { +func (s *StateStore) VaultAccessor(ws memdb.WatchSet, accessor string) (*structs.VaultAccessor, error) { txn := s.db.Txn(false) - existing, err := txn.First("vault_accessors", "id", accessor) + watchCh, existing, err := txn.FirstWatch("vault_accessors", "id", accessor) if err != nil { return nil, fmt.Errorf("accessor lookup failed: %v", err) } + ws.Add(watchCh) + if existing != nil { return existing.(*structs.VaultAccessor), nil } @@ -1291,18 +1262,21 @@ func (s *StateStore) VaultAccessor(accessor string) (*structs.VaultAccessor, err } // VaultAccessors returns an iterator of Vault accessors. 
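
Aside (illustrative, not part of the patch): the deletions above are also why the write paths could drop their watcher.Add and txn.Defer(s.watch.notify(...)) plumbing entirely. go-memdb's radix tree closes the relevant watch channels when a write transaction commits, so any WatchSet registered through FirstWatch or an iterator's WatchCh is woken without the writer doing anything explicit. A rough demonstration under that assumption, with a toy nodes table that is not Nomad's schema:

package main

import (
	"fmt"
	"time"

	memdb "github.com/hashicorp/go-memdb"
)

type Node struct {
	ID     string
	Status string
}

func nodeSchema() *memdb.DBSchema {
	return &memdb.DBSchema{
		Tables: map[string]*memdb.TableSchema{
			"nodes": {
				Name: "nodes",
				Indexes: map[string]*memdb.IndexSchema{
					"id": {
						Name:    "id",
						Unique:  true,
						Indexer: &memdb.StringFieldIndex{Field: "ID"},
					},
				},
			},
		},
	}
}

func main() {
	db, err := memdb.NewMemDB(nodeSchema())
	if err != nil {
		panic(err)
	}

	txn := db.Txn(true)
	if err := txn.Insert("nodes", &Node{ID: "n1", Status: "ready"}); err != nil {
		panic(err)
	}
	txn.Commit()

	// Reader: register a watch on node n1, much like the converted NodeByID.
	ws := memdb.NewWatchSet()
	rtxn := db.Txn(false)
	watchCh, _, err := rtxn.FirstWatch("nodes", "id", "n1")
	if err != nil {
		panic(err)
	}
	ws.Add(watchCh)

	// Writer: a plain insert plus commit, with no notify step of any kind.
	go func() {
		time.Sleep(50 * time.Millisecond)
		wtxn := db.Txn(true)
		if err := wtxn.Insert("nodes", &Node{ID: "n1", Status: "down"}); err != nil {
			panic(err)
		}
		wtxn.Commit()
	}()

	timeout := time.NewTimer(time.Second)
	defer timeout.Stop()
	fmt.Println("timed out?", ws.Watch(timeout.C)) // timed out? false
}

This automatic invalidation is what makes the stateWatch and NotifyGroup machinery removed at the end of this file unnecessary.
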
-func (s *StateStore) VaultAccessors() (memdb.ResultIterator, error) { +func (s *StateStore) VaultAccessors(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.Txn(false) iter, err := txn.Get("vault_accessors", "id") if err != nil { return nil, err } + + ws.Add(iter.WatchCh()) + return iter, nil } // VaultAccessorsByAlloc returns all the Vault accessors by alloc id -func (s *StateStore) VaultAccessorsByAlloc(allocID string) ([]*structs.VaultAccessor, error) { +func (s *StateStore) VaultAccessorsByAlloc(ws memdb.WatchSet, allocID string) ([]*structs.VaultAccessor, error) { txn := s.db.Txn(false) // Get an iterator over the accessors @@ -1311,6 +1285,8 @@ func (s *StateStore) VaultAccessorsByAlloc(allocID string) ([]*structs.VaultAcce return nil, err } + ws.Add(iter.WatchCh()) + var out []*structs.VaultAccessor for { raw := iter.Next() @@ -1323,7 +1299,7 @@ func (s *StateStore) VaultAccessorsByAlloc(allocID string) ([]*structs.VaultAcce } // VaultAccessorsByNode returns all the Vault accessors by node id -func (s *StateStore) VaultAccessorsByNode(nodeID string) ([]*structs.VaultAccessor, error) { +func (s *StateStore) VaultAccessorsByNode(ws memdb.WatchSet, nodeID string) ([]*structs.VaultAccessor, error) { txn := s.db.Txn(false) // Get an iterator over the accessors @@ -1332,6 +1308,8 @@ func (s *StateStore) VaultAccessorsByNode(nodeID string) ([]*structs.VaultAccess return nil, err } + ws.Add(iter.WatchCh()) + var out []*structs.VaultAccessor for { raw := iter.Next() @@ -1496,7 +1474,7 @@ func (s *StateStore) ReconcileJobSummaries(index uint64) error { // setJobStatuses is a helper for calling setJobStatus on multiple jobs by ID. // It takes a map of job IDs to an optional forceStatus string. It returns an // error if the job doesn't exist or setJobStatus fails. -func (s *StateStore) setJobStatuses(index uint64, watcher watch.Items, txn *memdb.Txn, +func (s *StateStore) setJobStatuses(index uint64, txn *memdb.Txn, jobs map[string]string, evalDelete bool) error { for job, forceStatus := range jobs { existing, err := txn.First("jobs", "id", job) @@ -1508,7 +1486,7 @@ func (s *StateStore) setJobStatuses(index uint64, watcher watch.Items, txn *memd continue } - if err := s.setJobStatus(index, watcher, txn, existing.(*structs.Job), evalDelete, forceStatus); err != nil { + if err := s.setJobStatus(index, txn, existing.(*structs.Job), evalDelete, forceStatus); err != nil { return err } } @@ -1521,7 +1499,7 @@ func (s *StateStore) setJobStatuses(index uint64, watcher watch.Items, txn *memd // called because an evaluation is being deleted (potentially because of garbage // collection). If forceStatus is non-empty, the job's status will be set to the // passed status. -func (s *StateStore) setJobStatus(index uint64, watcher watch.Items, txn *memdb.Txn, +func (s *StateStore) setJobStatus(index uint64, txn *memdb.Txn, job *structs.Job, evalDelete bool, forceStatus string) error { // Capture the current status so we can check if there is a change @@ -1545,10 +1523,6 @@ func (s *StateStore) setJobStatus(index uint64, watcher watch.Items, txn *memdb. return nil } - // The job has changed, so add to watcher. - watcher.Add(watch.Item{Table: "jobs"}) - watcher.Add(watch.Item{Job: job.ID}) - // Copy and update the existing job updated := job.Copy() updated.Status = newStatus @@ -1611,9 +1585,6 @@ func (s *StateStore) setJobStatus(index uint64, watcher watch.Items, txn *memdb. 
// Update the index pSummary.ModifyIndex = index - watcher.Add(watch.Item{Table: "job_summary"}) - watcher.Add(watch.Item{JobSummary: updated.ParentID}) - // Insert the summary if err := txn.Insert("job_summary", pSummary); err != nil { return fmt.Errorf("job summary insert failed: %v", err) @@ -1673,7 +1644,7 @@ func (s *StateStore) getJobStatus(txn *memdb.Txn, job *structs.Job, evalDelete b // updateSummaryWithJob creates or updates job summaries when new jobs are // upserted or existing ones are updated func (s *StateStore) updateSummaryWithJob(index uint64, job *structs.Job, - watcher watch.Items, txn *memdb.Txn) error { + txn *memdb.Txn) error { // Update the job summary summaryRaw, err := txn.First("job_summary", "id", job.ID) @@ -1709,12 +1680,9 @@ func (s *StateStore) updateSummaryWithJob(index uint64, job *structs.Job, } } - // The job summary has changed, so add to watcher and update the modify - // index. + // The job summary has changed, so update the modify index. if hasSummaryChanged { summary.ModifyIndex = index - watcher.Add(watch.Item{Table: "job_summary"}) - watcher.Add(watch.Item{JobSummary: job.ID}) // Update the indexes table for job summary if err := txn.Insert("index", &IndexEntry{"job_summary", index}); err != nil { @@ -1731,7 +1699,7 @@ func (s *StateStore) updateSummaryWithJob(index uint64, job *structs.Job, // updateSummaryWithAlloc updates the job summary when allocations are updated // or inserted func (s *StateStore) updateSummaryWithAlloc(index uint64, alloc *structs.Allocation, - existingAlloc *structs.Allocation, watcher watch.Items, txn *memdb.Txn) error { + existingAlloc *structs.Allocation, txn *memdb.Txn) error { // We don't have to update the summary if the job is missing if alloc.Job == nil { @@ -1825,8 +1793,6 @@ func (s *StateStore) updateSummaryWithAlloc(index uint64, alloc *structs.Allocat if summaryChanged { jobSummary.ModifyIndex = index - watcher.Add(watch.Item{Table: "job_summary"}) - watcher.Add(watch.Item{JobSummary: alloc.JobID}) // Update the indexes table for job summary if err := txn.Insert("index", &IndexEntry{"job_summary", index}); err != nil { @@ -1869,9 +1835,7 @@ type StateSnapshot struct { // restoring state by only using a single large transaction // instead of thousands of sub transactions type StateRestore struct { - txn *memdb.Txn - watch *stateWatch - items watch.Items + txn *memdb.Txn } // Abort is used to abort the restore operation @@ -1881,14 +1845,11 @@ func (s *StateRestore) Abort() { // Commit is used to commit the restore operation func (s *StateRestore) Commit() { - s.txn.Defer(func() { s.watch.notify(s.items) }) s.txn.Commit() } // NodeRestore is used to restore a node func (r *StateRestore) NodeRestore(node *structs.Node) error { - r.items.Add(watch.Item{Table: "nodes"}) - r.items.Add(watch.Item{Node: node.ID}) if err := r.txn.Insert("nodes", node); err != nil { return fmt.Errorf("node insert failed: %v", err) } @@ -1897,9 +1858,6 @@ func (r *StateRestore) NodeRestore(node *structs.Node) error { // JobRestore is used to restore a job func (r *StateRestore) JobRestore(job *structs.Job) error { - r.items.Add(watch.Item{Table: "jobs"}) - r.items.Add(watch.Item{Job: job.ID}) - // Create the EphemeralDisk if it's nil by adding up DiskMB from task resources. 
// COMPAT 0.4.1 -> 0.5 r.addEphemeralDiskToTaskGroups(job) @@ -1912,9 +1870,6 @@ func (r *StateRestore) JobRestore(job *structs.Job) error { // EvalRestore is used to restore an evaluation func (r *StateRestore) EvalRestore(eval *structs.Evaluation) error { - r.items.Add(watch.Item{Table: "evals"}) - r.items.Add(watch.Item{Eval: eval.ID}) - r.items.Add(watch.Item{EvalJob: eval.JobID}) if err := r.txn.Insert("evals", eval); err != nil { return fmt.Errorf("eval insert failed: %v", err) } @@ -1923,12 +1878,6 @@ func (r *StateRestore) EvalRestore(eval *structs.Evaluation) error { // AllocRestore is used to restore an allocation func (r *StateRestore) AllocRestore(alloc *structs.Allocation) error { - r.items.Add(watch.Item{Table: "allocs"}) - r.items.Add(watch.Item{Alloc: alloc.ID}) - r.items.Add(watch.Item{AllocEval: alloc.EvalID}) - r.items.Add(watch.Item{AllocJob: alloc.JobID}) - r.items.Add(watch.Item{AllocNode: alloc.NodeID}) - // Set the shared resources if it's not present // COMPAT 0.4.1 -> 0.5 if alloc.SharedResources == nil { @@ -1958,8 +1907,6 @@ func (r *StateRestore) IndexRestore(idx *IndexEntry) error { // PeriodicLaunchRestore is used to restore a periodic launch. func (r *StateRestore) PeriodicLaunchRestore(launch *structs.PeriodicLaunch) error { - r.items.Add(watch.Item{Table: "periodic_launch"}) - r.items.Add(watch.Item{Job: launch.ID}) if err := r.txn.Insert("periodic_launch", launch); err != nil { return fmt.Errorf("periodic launch insert failed: %v", err) } @@ -1968,8 +1915,6 @@ func (r *StateRestore) PeriodicLaunchRestore(launch *structs.PeriodicLaunch) err // JobSummaryRestore is used to restore a job summary func (r *StateRestore) JobSummaryRestore(jobSummary *structs.JobSummary) error { - r.items.Add(watch.Item{Table: "job_summary"}) - r.items.Add(watch.Item{JobSummary: jobSummary.JobID}) if err := r.txn.Insert("job_summary", jobSummary); err != nil { return fmt.Errorf("job summary insert failed: %v", err) } @@ -2002,59 +1947,3 @@ func (r *StateRestore) addEphemeralDiskToTaskGroups(job *structs.Job) { } } } - -// stateWatch holds shared state for watching updates. This is -// outside of StateStore so it can be shared with snapshots. -type stateWatch struct { - items map[watch.Item]*NotifyGroup - l sync.Mutex -} - -// newStateWatch creates a new stateWatch for change notification. -func newStateWatch() *stateWatch { - return &stateWatch{ - items: make(map[watch.Item]*NotifyGroup), - } -} - -// watch subscribes a channel to the given watch items. -func (w *stateWatch) watch(items watch.Items, ch chan struct{}) { - w.l.Lock() - defer w.l.Unlock() - - for item, _ := range items { - grp, ok := w.items[item] - if !ok { - grp = new(NotifyGroup) - w.items[item] = grp - } - grp.Wait(ch) - } -} - -// stopWatch unsubscribes a channel from the given watch items. -func (w *stateWatch) stopWatch(items watch.Items, ch chan struct{}) { - w.l.Lock() - defer w.l.Unlock() - - for item, _ := range items { - if grp, ok := w.items[item]; ok { - grp.Clear(ch) - if grp.Empty() { - delete(w.items, item) - } - } - } -} - -// notify is used to fire notifications on the given watch items. 
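The test-file diff that follows leans on a small `watchFired` helper in place of the removed `setupNotifyTest`/`notify.verify` pair. Its body is not shown in these hunks; the sketch below captures the behavior the tests rely on, with the 100ms timeout being an assumption of the sketch rather than the real value.

```go
package state

import (
	"time"

	memdb "github.com/hashicorp/go-memdb"
)

// watchFired reports whether any channel registered in the WatchSet fired.
// WatchSet.Watch returns true when the timeout channel wins (i.e. nothing
// triggered), so the result is negated. A short timeout is fine here because
// the tests operate on a local, in-memory memdb.
func watchFired(ws memdb.WatchSet) bool {
	timedOut := ws.Watch(time.After(100 * time.Millisecond))
	return !timedOut
}
```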
-func (w *stateWatch) notify(items watch.Items) { - w.l.Lock() - defer w.l.Unlock() - - for wi, _ := range items { - if grp, ok := w.items[wi]; ok { - grp.Notify() - } - } -} diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index b73f24e4f..5162e8f24 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -7,10 +7,9 @@ import ( "testing" "time" - "github.com/hashicorp/go-memdb" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/nomad/watch" ) func testStateStore(t *testing.T) *StateStore { @@ -28,17 +27,23 @@ func TestStateStore_UpsertNode_Node(t *testing.T) { state := testStateStore(t) node := mock.Node() - notify := setupNotifyTest( - state, - watch.Item{Table: "nodes"}, - watch.Item{Node: node.ID}) + // Create a watchset so we can test that upsert fires the watch + ws := memdb.NewWatchSet() + _, err := state.NodeByID(ws, node.ID) + if err != nil { + t.Fatalf("bad: %v", err) + } - err := state.UpsertNode(1000, node) + err = state.UpsertNode(1000, node) if err != nil { t.Fatalf("err: %v", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } - out, err := state.NodeByID(node.ID) + ws = memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -55,29 +60,37 @@ func TestStateStore_UpsertNode_Node(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_DeleteNode_Node(t *testing.T) { state := testStateStore(t) node := mock.Node() - notify := setupNotifyTest( - state, - watch.Item{Table: "nodes"}, - watch.Item{Node: node.ID}) - err := state.UpsertNode(1000, node) if err != nil { t.Fatalf("err: %v", err) } + // Create a watchset so we can test that delete fires the watch + ws := memdb.NewWatchSet() + if _, err := state.NodeByID(ws, node.ID); err != nil { + t.Fatalf("bad: %v", err) + } + err = state.DeleteNode(1001, node.ID) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.NodeByID(node.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -94,29 +107,37 @@ func TestStateStore_DeleteNode_Node(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpdateNodeStatus_Node(t *testing.T) { state := testStateStore(t) node := mock.Node() - notify := setupNotifyTest( - state, - watch.Item{Table: "nodes"}, - watch.Item{Node: node.ID}) - err := state.UpsertNode(800, node) if err != nil { t.Fatalf("err: %v", err) } + // Create a watchset so we can test that update node status fires the watch + ws := memdb.NewWatchSet() + if _, err := state.NodeByID(ws, node.ID); err != nil { + t.Fatalf("bad: %v", err) + } + err = state.UpdateNodeStatus(801, node.ID, structs.NodeStatusReady) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.NodeByID(node.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -136,29 +157,37 @@ func TestStateStore_UpdateNodeStatus_Node(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpdateNodeDrain_Node(t *testing.T) { state := testStateStore(t) node := mock.Node() - notify := setupNotifyTest( - state, - 
watch.Item{Table: "nodes"}, - watch.Item{Node: node.ID}) - err := state.UpsertNode(1000, node) if err != nil { t.Fatalf("err: %v", err) } + // Create a watchset so we can test that update node drain fires the watch + ws := memdb.NewWatchSet() + if _, err := state.NodeByID(ws, node.ID); err != nil { + t.Fatalf("bad: %v", err) + } + err = state.UpdateNodeDrain(1001, node.ID, true) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.NodeByID(node.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -178,7 +207,9 @@ func TestStateStore_UpdateNodeDrain_Node(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_Nodes(t *testing.T) { @@ -195,9 +226,11 @@ func TestStateStore_Nodes(t *testing.T) { } } - iter, err := state.Nodes() + // Create a watchset so we can test that getters don't cause it to fire + ws := memdb.NewWatchSet() + iter, err := state.Nodes(ws) if err != nil { - t.Fatalf("err: %v", err) + t.Fatalf("bad: %v", err) } var out []*structs.Node @@ -215,6 +248,10 @@ func TestStateStore_Nodes(t *testing.T) { if !reflect.DeepEqual(nodes, out) { t.Fatalf("bad: %#v %#v", nodes, out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_NodesByIDPrefix(t *testing.T) { @@ -227,7 +264,9 @@ func TestStateStore_NodesByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err := state.NodesByIDPrefix(node.ID) + // Create a watchset so we can test that getters don't cause it to fire + ws := memdb.NewWatchSet() + iter, err := state.NodesByIDPrefix(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -250,7 +289,11 @@ func TestStateStore_NodesByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err = state.NodesByIDPrefix("11") + if watchFired(ws) { + t.Fatalf("bad") + } + + iter, err = state.NodesByIDPrefix(ws, "11") if err != nil { t.Fatalf("err: %v", err) } @@ -267,7 +310,12 @@ func TestStateStore_NodesByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err = state.NodesByIDPrefix("11") + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + iter, err = state.NodesByIDPrefix(ws, "11") if err != nil { t.Fatalf("err: %v", err) } @@ -277,7 +325,7 @@ func TestStateStore_NodesByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err = state.NodesByIDPrefix("1111") + iter, err = state.NodesByIDPrefix(ws, "1111") if err != nil { t.Fatalf("err: %v", err) } @@ -286,17 +334,16 @@ func TestStateStore_NodesByIDPrefix(t *testing.T) { if len(nodes) != 1 { t.Fatalf("err: %v", err) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestoreNode(t *testing.T) { state := testStateStore(t) node := mock.Node() - notify := setupNotifyTest( - state, - watch.Item{Table: "nodes"}, - watch.Item{Node: node.ID}) - restore, err := state.Restore() if err != nil { t.Fatalf("err: %v", err) @@ -308,7 +355,8 @@ func TestStateStore_RestoreNode(t *testing.T) { } restore.Commit() - out, err := state.NodeByID(node.ID) + ws := memdb.NewWatchSet() + out, err := state.NodeByID(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -316,25 +364,28 @@ func TestStateStore_RestoreNode(t *testing.T) { if !reflect.DeepEqual(out, node) { t.Fatalf("Bad: %#v %#v", out, node) } - - notify.verify(t) } func TestStateStore_UpsertJob_Job(t *testing.T) { state := testStateStore(t) job := mock.Job() - notify := setupNotifyTest( - state, - watch.Item{Table: 
"jobs"}, - watch.Item{Job: job.ID}) - - err := state.UpsertJob(1000, job) + // Create a watchset so we can test that upsert fires the watch + ws := memdb.NewWatchSet() + _, err := state.JobByID(ws, job.ID) if err != nil { - t.Fatalf("err: %v", err) + t.Fatalf("bad: %v", err) } - out, err := state.JobByID(job.ID) + if err := state.UpsertJob(1000, job); err != nil { + t.Fatalf("err: %v", err) + } + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -351,7 +402,7 @@ func TestStateStore_UpsertJob_Job(t *testing.T) { t.Fatalf("bad: %d", index) } - summary, err := state.JobSummaryByID(job.ID) + summary, err := state.JobSummaryByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -365,20 +416,23 @@ func TestStateStore_UpsertJob_Job(t *testing.T) { if !ok { t.Fatalf("nil summary for task group") } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpdateUpsertJob_Job(t *testing.T) { state := testStateStore(t) job := mock.Job() - notify := setupNotifyTest( - state, - watch.Item{Table: "jobs"}, - watch.Item{Job: job.ID}) - - err := state.UpsertJob(1000, job) + // Create a watchset so we can test that upsert fires the watch + ws := memdb.NewWatchSet() + _, err := state.JobByID(ws, job.ID) if err != nil { + t.Fatalf("bad: %v", err) + } + + if err := state.UpsertJob(1000, job); err != nil { t.Fatalf("err: %v", err) } @@ -389,7 +443,12 @@ func TestStateStore_UpdateUpsertJob_Job(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.JobByID(job.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -415,7 +474,7 @@ func TestStateStore_UpdateUpsertJob_Job(t *testing.T) { // Test that the job summary remains the same if the job is updated but // count remains same - summary, err := state.JobSummaryByID(job.ID) + summary, err := state.JobSummaryByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -430,7 +489,9 @@ func TestStateStore_UpdateUpsertJob_Job(t *testing.T) { t.Fatalf("nil summary for task group") } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } // This test ensures that UpsertJob creates the EphemeralDisk is a job doesn't have @@ -449,7 +510,8 @@ func TestStateStore_UpsertJob_NoEphemeralDisk(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -471,25 +533,26 @@ func TestStateStore_UpsertJob_NoEphemeralDisk(t *testing.T) { // updated. 
func TestStateStore_UpsertJob_ChildJob(t *testing.T) { state := testStateStore(t) + + // Create a watchset so we can test that upsert fires the watch parent := mock.Job() + ws := memdb.NewWatchSet() + _, err := state.JobByID(ws, parent.ID) + if err != nil { + t.Fatalf("bad: %v", err) + } + if err := state.UpsertJob(1000, parent); err != nil { t.Fatalf("err: %v", err) } child := mock.Job() child.ParentID = parent.ID - - notify := setupNotifyTest( - state, - watch.Item{Table: "job_summary"}, - watch.Item{JobSummary: parent.ID}) - - err := state.UpsertJob(1001, child) - if err != nil { + if err := state.UpsertJob(1001, child); err != nil { t.Fatalf("err: %v", err) } - summary, err := state.JobSummaryByID(parent.ID) + summary, err := state.JobSummaryByID(ws, parent.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -505,29 +568,37 @@ func TestStateStore_UpsertJob_ChildJob(t *testing.T) { if summary.Children.Pending != 1 || summary.Children.Running != 0 || summary.Children.Dead != 0 { t.Fatalf("bad children summary: %v", summary.Children) } - notify.verify(t) + if !watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_DeleteJob_Job(t *testing.T) { state := testStateStore(t) job := mock.Job() - notify := setupNotifyTest( - state, - watch.Item{Table: "jobs"}, - watch.Item{Job: job.ID}) - err := state.UpsertJob(1000, job) if err != nil { t.Fatalf("err: %v", err) } + // Create a watchset so we can test that delete fires the watch + ws := memdb.NewWatchSet() + if _, err := state.JobByID(ws, job.ID); err != nil { + t.Fatalf("bad: %v", err) + } + err = state.DeleteJob(1001, job.ID) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.JobByID(job.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -544,7 +615,7 @@ func TestStateStore_DeleteJob_Job(t *testing.T) { t.Fatalf("bad: %d", index) } - summary, err := state.JobSummaryByID(job.ID) + summary, err := state.JobSummaryByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -552,7 +623,9 @@ func TestStateStore_DeleteJob_Job(t *testing.T) { t.Fatalf("expected summary to be nil, but got: %v", summary) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_DeleteJob_ChildJob(t *testing.T) { @@ -570,17 +643,22 @@ func TestStateStore_DeleteJob_ChildJob(t *testing.T) { t.Fatalf("err: %v", err) } - notify := setupNotifyTest( - state, - watch.Item{Table: "job_summary"}, - watch.Item{JobSummary: parent.ID}) + // Create a watchset so we can test that delete fires the watch + ws := memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws, parent.ID); err != nil { + t.Fatalf("bad: %v", err) + } err := state.DeleteJob(1001, child.ID) if err != nil { t.Fatalf("err: %v", err) } + if !watchFired(ws) { + t.Fatalf("bad") + } - summary, err := state.JobSummaryByID(parent.ID) + ws = memdb.NewWatchSet() + summary, err := state.JobSummaryByID(ws, parent.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -596,7 +674,9 @@ func TestStateStore_DeleteJob_ChildJob(t *testing.T) { if summary.Children.Pending != 0 || summary.Children.Running != 0 || summary.Children.Dead != 1 { t.Fatalf("bad children summary: %v", summary.Children) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_Jobs(t *testing.T) { @@ -613,7 +693,8 @@ func TestStateStore_Jobs(t *testing.T) { } } - iter, err := state.Jobs() + ws := memdb.NewWatchSet() + iter, err := state.Jobs(ws) if err != nil { 
t.Fatalf("err: %v", err) } @@ -633,6 +714,9 @@ func TestStateStore_Jobs(t *testing.T) { if !reflect.DeepEqual(jobs, out) { t.Fatalf("bad: %#v %#v", jobs, out) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_JobsByIDPrefix(t *testing.T) { @@ -645,7 +729,8 @@ func TestStateStore_JobsByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err := state.JobsByIDPrefix(job.ID) + ws := memdb.NewWatchSet() + iter, err := state.JobsByIDPrefix(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -667,7 +752,7 @@ func TestStateStore_JobsByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err = state.JobsByIDPrefix("re") + iter, err = state.JobsByIDPrefix(ws, "re") if err != nil { t.Fatalf("err: %v", err) } @@ -676,6 +761,9 @@ func TestStateStore_JobsByIDPrefix(t *testing.T) { if len(jobs) != 1 { t.Fatalf("err: %v", err) } + if watchFired(ws) { + t.Fatalf("bad") + } job = mock.Job() job.ID = "riak" @@ -684,7 +772,12 @@ func TestStateStore_JobsByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err = state.JobsByIDPrefix("r") + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + iter, err = state.JobsByIDPrefix(ws, "r") if err != nil { t.Fatalf("err: %v", err) } @@ -694,7 +787,7 @@ func TestStateStore_JobsByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err = state.JobsByIDPrefix("ri") + iter, err = state.JobsByIDPrefix(ws, "ri") if err != nil { t.Fatalf("err: %v", err) } @@ -703,6 +796,9 @@ func TestStateStore_JobsByIDPrefix(t *testing.T) { if len(jobs) != 1 { t.Fatalf("err: %v", err) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_JobsByPeriodic(t *testing.T) { @@ -729,7 +825,8 @@ func TestStateStore_JobsByPeriodic(t *testing.T) { } } - iter, err := state.JobsByPeriodic(true) + ws := memdb.NewWatchSet() + iter, err := state.JobsByPeriodic(ws, true) if err != nil { t.Fatalf("err: %v", err) } @@ -743,7 +840,7 @@ func TestStateStore_JobsByPeriodic(t *testing.T) { outPeriodic = append(outPeriodic, raw.(*structs.Job)) } - iter, err = state.JobsByPeriodic(false) + iter, err = state.JobsByPeriodic(ws, false) if err != nil { t.Fatalf("err: %v", err) } @@ -769,6 +866,9 @@ func TestStateStore_JobsByPeriodic(t *testing.T) { if !reflect.DeepEqual(nonPeriodic, outNonPeriodic) { t.Fatalf("bad: %#v %#v", nonPeriodic, outNonPeriodic) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_JobsByScheduler(t *testing.T) { @@ -796,7 +896,8 @@ func TestStateStore_JobsByScheduler(t *testing.T) { } } - iter, err := state.JobsByScheduler("service") + ws := memdb.NewWatchSet() + iter, err := state.JobsByScheduler(ws, "service") if err != nil { t.Fatalf("err: %v", err) } @@ -810,7 +911,7 @@ func TestStateStore_JobsByScheduler(t *testing.T) { outService = append(outService, raw.(*structs.Job)) } - iter, err = state.JobsByScheduler("system") + iter, err = state.JobsByScheduler(ws, "system") if err != nil { t.Fatalf("err: %v", err) } @@ -836,6 +937,9 @@ func TestStateStore_JobsByScheduler(t *testing.T) { if !reflect.DeepEqual(sysJobs, outSystem) { t.Fatalf("bad: %#v %#v", sysJobs, outSystem) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_JobsByGC(t *testing.T) { @@ -866,7 +970,8 @@ func TestStateStore_JobsByGC(t *testing.T) { } } - iter, err := state.JobsByGC(true) + ws := memdb.NewWatchSet() + iter, err := state.JobsByGC(ws, true) if err != nil { t.Fatalf("err: %v", err) } @@ -876,7 +981,7 @@ func TestStateStore_JobsByGC(t *testing.T) { outGc = append(outGc, i.(*structs.Job)) } - iter, err 
= state.JobsByGC(false) + iter, err = state.JobsByGC(ws, false) if err != nil { t.Fatalf("err: %v", err) } @@ -898,17 +1003,15 @@ func TestStateStore_JobsByGC(t *testing.T) { if !reflect.DeepEqual(nonGc, outNonGc) { t.Fatalf("bad: %#v %#v", nonGc, outNonGc) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestoreJob(t *testing.T) { state := testStateStore(t) job := mock.Job() - notify := setupNotifyTest( - state, - watch.Item{Table: "jobs"}, - watch.Item{Job: job.ID}) - restore, err := state.Restore() if err != nil { t.Fatalf("err: %v", err) @@ -920,7 +1023,8 @@ func TestStateStore_RestoreJob(t *testing.T) { } restore.Commit() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -928,8 +1032,6 @@ func TestStateStore_RestoreJob(t *testing.T) { if !reflect.DeepEqual(out, job) { t.Fatalf("Bad: %#v %#v", out, job) } - - notify.verify(t) } // This test ensures that the state restore creates the EphemeralDisk for a job if @@ -943,11 +1045,6 @@ func TestStateStore_Jobs_NoEphemeralDisk(t *testing.T) { job.TaskGroups[0].EphemeralDisk = nil job.TaskGroups[0].Tasks[0].Resources.DiskMB = 150 - notify := setupNotifyTest( - state, - watch.Item{Table: "jobs"}, - watch.Item{Job: job.ID}) - restore, err := state.Restore() if err != nil { t.Fatalf("err: %v", err) @@ -959,7 +1056,8 @@ func TestStateStore_Jobs_NoEphemeralDisk(t *testing.T) { } restore.Commit() - out, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -973,8 +1071,6 @@ func TestStateStore_Jobs_NoEphemeralDisk(t *testing.T) { if !reflect.DeepEqual(out, expected) { t.Fatalf("Bad: %#v %#v", out, job) } - - notify.verify(t) } func TestStateStore_UpsertPeriodicLaunch(t *testing.T) { @@ -982,17 +1078,23 @@ func TestStateStore_UpsertPeriodicLaunch(t *testing.T) { job := mock.Job() launch := &structs.PeriodicLaunch{ID: job.ID, Launch: time.Now()} - notify := setupNotifyTest( - state, - watch.Item{Table: "periodic_launch"}, - watch.Item{Job: job.ID}) + // Create a watchset so we can test that upsert fires the watch + ws := memdb.NewWatchSet() + if _, err := state.PeriodicLaunchByID(ws, launch.ID); err != nil { + t.Fatalf("bad: %v", err) + } err := state.UpsertPeriodicLaunch(1000, launch) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.PeriodicLaunchByID(job.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.PeriodicLaunchByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1015,7 +1117,9 @@ func TestStateStore_UpsertPeriodicLaunch(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpdateUpsertPeriodicLaunch(t *testing.T) { @@ -1023,16 +1127,17 @@ func TestStateStore_UpdateUpsertPeriodicLaunch(t *testing.T) { job := mock.Job() launch := &structs.PeriodicLaunch{ID: job.ID, Launch: time.Now()} - notify := setupNotifyTest( - state, - watch.Item{Table: "periodic_launch"}, - watch.Item{Job: job.ID}) - err := state.UpsertPeriodicLaunch(1000, launch) if err != nil { t.Fatalf("err: %v", err) } + // Create a watchset so we can test that upsert fires the watch + ws := memdb.NewWatchSet() + if _, err := state.PeriodicLaunchByID(ws, launch.ID); err != nil { + t.Fatalf("bad: %v", err) + } + launch2 := &structs.PeriodicLaunch{ ID: job.ID, Launch: launch.Launch.Add(1 * time.Second), @@ -1042,7 +1147,12 @@ func 
TestStateStore_UpdateUpsertPeriodicLaunch(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.PeriodicLaunchByID(job.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.PeriodicLaunchByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1065,7 +1175,9 @@ func TestStateStore_UpdateUpsertPeriodicLaunch(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_DeletePeriodicLaunch(t *testing.T) { @@ -1073,22 +1185,28 @@ func TestStateStore_DeletePeriodicLaunch(t *testing.T) { job := mock.Job() launch := &structs.PeriodicLaunch{ID: job.ID, Launch: time.Now()} - notify := setupNotifyTest( - state, - watch.Item{Table: "periodic_launch"}, - watch.Item{Job: job.ID}) - err := state.UpsertPeriodicLaunch(1000, launch) if err != nil { t.Fatalf("err: %v", err) } - err = state.DeletePeriodicLaunch(1001, job.ID) + // Create a watchset so we can test that delete fires the watch + ws := memdb.NewWatchSet() + if _, err := state.PeriodicLaunchByID(ws, launch.ID); err != nil { + t.Fatalf("bad: %v", err) + } + + err = state.DeletePeriodicLaunch(1001, launch.ID) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.PeriodicLaunchByID(job.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.PeriodicLaunchByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1105,7 +1223,9 @@ func TestStateStore_DeletePeriodicLaunch(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_PeriodicLaunches(t *testing.T) { @@ -1123,7 +1243,8 @@ func TestStateStore_PeriodicLaunches(t *testing.T) { } } - iter, err := state.PeriodicLaunches() + ws := memdb.NewWatchSet() + iter, err := state.PeriodicLaunches(ws) if err != nil { t.Fatalf("err: %v", err) } @@ -1158,6 +1279,10 @@ func TestStateStore_PeriodicLaunches(t *testing.T) { if len(out) != 0 { t.Fatalf("leftover: %#v", out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestorePeriodicLaunch(t *testing.T) { @@ -1165,11 +1290,6 @@ func TestStateStore_RestorePeriodicLaunch(t *testing.T) { job := mock.Job() launch := &structs.PeriodicLaunch{ID: job.ID, Launch: time.Now()} - notify := setupNotifyTest( - state, - watch.Item{Table: "periodic_launch"}, - watch.Item{Job: job.ID}) - restore, err := state.Restore() if err != nil { t.Fatalf("err: %v", err) @@ -1181,7 +1301,8 @@ func TestStateStore_RestorePeriodicLaunch(t *testing.T) { } restore.Commit() - out, err := state.PeriodicLaunchByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.PeriodicLaunchByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1190,7 +1311,9 @@ func TestStateStore_RestorePeriodicLaunch(t *testing.T) { t.Fatalf("Bad: %#v %#v", out, job) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestoreJobSummary(t *testing.T) { @@ -1215,7 +1338,8 @@ func TestStateStore_RestoreJobSummary(t *testing.T) { } restore.Commit() - out, err := state.JobSummaryByID(job.ID) + ws := memdb.NewWatchSet() + out, err := state.JobSummaryByID(ws, job.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1309,18 +1433,23 @@ func TestStateStore_UpsertEvals_Eval(t *testing.T) { state := testStateStore(t) eval := mock.Eval() - notify := setupNotifyTest( - state, - watch.Item{Table: "evals"}, - watch.Item{Eval: eval.ID}, - watch.Item{EvalJob: eval.JobID}) + // Create a watchset so we can test that 
upsert fires the watch + ws := memdb.NewWatchSet() + if _, err := state.EvalByID(ws, eval.ID); err != nil { + t.Fatalf("bad: %v", err) + } err := state.UpsertEvals(1000, []*structs.Evaluation{eval}) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.EvalByID(eval.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1337,7 +1466,9 @@ func TestStateStore_UpsertEvals_Eval(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpsertEvals_CancelBlocked(t *testing.T) { @@ -1361,19 +1492,26 @@ func TestStateStore_UpsertEvals_CancelBlocked(t *testing.T) { eval.JobID = j eval.Status = structs.EvalStatusComplete - notify := setupNotifyTest( - state, - watch.Item{Table: "evals"}, - watch.Item{Eval: b1.ID}, - watch.Item{Eval: b2.ID}, - watch.Item{Eval: eval.ID}, - watch.Item{EvalJob: eval.JobID}) + // Create a watchset so we can test that the upsert of the complete eval + // fires the watch + ws := memdb.NewWatchSet() + if _, err := state.EvalByID(ws, b1.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.EvalByID(ws, b2.ID); err != nil { + t.Fatalf("bad: %v", err) + } if err := state.UpsertEvals(1000, []*structs.Evaluation{eval}); err != nil { t.Fatalf("err: %v", err) } - out, err := state.EvalByID(eval.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1391,12 +1529,12 @@ func TestStateStore_UpsertEvals_CancelBlocked(t *testing.T) { } // Get b1/b2 and check they are cancelled - out1, err := state.EvalByID(b1.ID) + out1, err := state.EvalByID(ws, b1.ID) if err != nil { t.Fatalf("err: %v", err) } - out2, err := state.EvalByID(b2.ID) + out2, err := state.EvalByID(ws, b2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1405,7 +1543,9 @@ func TestStateStore_UpsertEvals_CancelBlocked(t *testing.T) { t.Fatalf("bad: %#v %#v", out1, out2) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_Update_UpsertEvals_Eval(t *testing.T) { @@ -1417,11 +1557,16 @@ func TestStateStore_Update_UpsertEvals_Eval(t *testing.T) { t.Fatalf("err: %v", err) } - notify := setupNotifyTest( - state, - watch.Item{Table: "evals"}, - watch.Item{Eval: eval.ID}, - watch.Item{EvalJob: eval.JobID}) + // Create a watchset so we can test that delete fires the watch + ws := memdb.NewWatchSet() + ws2 := memdb.NewWatchSet() + if _, err := state.EvalByID(ws, eval.ID); err != nil { + t.Fatalf("bad: %v", err) + } + + if _, err := state.EvalsByJob(ws2, eval.JobID); err != nil { + t.Fatalf("bad: %v", err) + } eval2 := mock.Eval() eval2.ID = eval.ID @@ -1431,7 +1576,15 @@ func TestStateStore_Update_UpsertEvals_Eval(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.EvalByID(eval.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + if !watchFired(ws2) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1455,7 +1608,9 @@ func TestStateStore_Update_UpsertEvals_Eval(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpsertEvals_Eval_ChildJob(t *testing.T) { @@ -1477,20 +1632,37 @@ func TestStateStore_UpsertEvals_Eval_ChildJob(t *testing.T) { eval.Status = structs.EvalStatusComplete eval.JobID = child.ID - notify 
:= setupNotifyTest( - state, - watch.Item{Table: "job_summary"}, - watch.Item{JobSummary: parent.ID}, - watch.Item{Table: "evals"}, - watch.Item{Eval: eval.ID}, - watch.Item{EvalJob: eval.JobID}) + // Create watchsets so we can test that upsert fires the watch + ws := memdb.NewWatchSet() + ws2 := memdb.NewWatchSet() + ws3 := memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws, parent.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.EvalByID(ws2, eval.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.EvalsByJob(ws3, eval.JobID); err != nil { + t.Fatalf("bad: %v", err) + } err := state.UpsertEvals(1000, []*structs.Evaluation{eval}) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.EvalByID(eval.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + if !watchFired(ws2) { + t.Fatalf("bad") + } + if !watchFired(ws3) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1507,7 +1679,7 @@ func TestStateStore_UpsertEvals_Eval_ChildJob(t *testing.T) { t.Fatalf("bad: %d", index) } - summary, err := state.JobSummaryByID(parent.ID) + summary, err := state.JobSummaryByID(ws, parent.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1524,7 +1696,9 @@ func TestStateStore_UpsertEvals_Eval_ChildJob(t *testing.T) { t.Fatalf("bad children summary: %v", summary.Children) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_DeleteEval_Eval(t *testing.T) { @@ -1534,22 +1708,47 @@ func TestStateStore_DeleteEval_Eval(t *testing.T) { alloc1 := mock.Alloc() alloc2 := mock.Alloc() - notify := setupNotifyTest( - state, - watch.Item{Table: "evals"}, - watch.Item{Table: "allocs"}, - watch.Item{Eval: eval1.ID}, - watch.Item{Eval: eval2.ID}, - watch.Item{EvalJob: eval1.JobID}, - watch.Item{EvalJob: eval2.JobID}, - watch.Item{Alloc: alloc1.ID}, - watch.Item{Alloc: alloc2.ID}, - watch.Item{AllocEval: alloc1.EvalID}, - watch.Item{AllocEval: alloc2.EvalID}, - watch.Item{AllocJob: alloc1.JobID}, - watch.Item{AllocJob: alloc2.JobID}, - watch.Item{AllocNode: alloc1.NodeID}, - watch.Item{AllocNode: alloc2.NodeID}) + // Create watchsets so we can test that upsert fires the watch + watches := make([]memdb.WatchSet, 12) + for i := 0; i < 12; i++ { + watches[i] = memdb.NewWatchSet() + } + if _, err := state.EvalByID(watches[0], eval1.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.EvalByID(watches[1], eval2.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.EvalsByJob(watches[2], eval1.JobID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.EvalsByJob(watches[3], eval2.JobID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocByID(watches[4], alloc1.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocByID(watches[5], alloc2.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByEval(watches[6], alloc1.EvalID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByEval(watches[7], alloc2.EvalID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByJob(watches[8], alloc1.JobID, false); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByJob(watches[9], alloc2.JobID, false); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByNode(watches[10], alloc1.NodeID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByNode(watches[11], 
alloc2.NodeID); err != nil { + t.Fatalf("bad: %v", err) + } state.UpsertJobSummary(900, mock.JobSummary(eval1.JobID)) state.UpsertJobSummary(901, mock.JobSummary(eval2.JobID)) @@ -1570,7 +1769,14 @@ func TestStateStore_DeleteEval_Eval(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.EvalByID(eval1.ID) + for i, ws := range watches { + if !watchFired(ws) { + t.Fatalf("bad %d", i) + } + } + + ws := memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval1.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1579,7 +1785,7 @@ func TestStateStore_DeleteEval_Eval(t *testing.T) { t.Fatalf("bad: %#v %#v", eval1, out) } - out, err = state.EvalByID(eval2.ID) + out, err = state.EvalByID(ws, eval2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1588,7 +1794,7 @@ func TestStateStore_DeleteEval_Eval(t *testing.T) { t.Fatalf("bad: %#v %#v", eval1, out) } - outA, err := state.AllocByID(alloc1.ID) + outA, err := state.AllocByID(ws, alloc1.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1597,7 +1803,7 @@ func TestStateStore_DeleteEval_Eval(t *testing.T) { t.Fatalf("bad: %#v %#v", alloc1, outA) } - outA, err = state.AllocByID(alloc2.ID) + outA, err = state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1622,7 +1828,9 @@ func TestStateStore_DeleteEval_Eval(t *testing.T) { t.Fatalf("bad: %d", index) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_DeleteEval_ChildJob(t *testing.T) { @@ -1655,17 +1863,23 @@ func TestStateStore_DeleteEval_ChildJob(t *testing.T) { t.Fatalf("err: %v", err) } - notify := setupNotifyTest( - state, - watch.Item{Table: "job_summary"}, - watch.Item{JobSummary: parent.ID}) + // Create watchsets so we can test that delete fires the watch + ws := memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws, parent.ID); err != nil { + t.Fatalf("bad: %v", err) + } err = state.DeleteEval(1002, []string{eval1.ID}, []string{alloc1.ID}) if err != nil { t.Fatalf("err: %v", err) } - summary, err := state.JobSummaryByID(parent.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + summary, err := state.JobSummaryByID(ws, parent.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1682,7 +1896,9 @@ func TestStateStore_DeleteEval_ChildJob(t *testing.T) { t.Fatalf("bad children summary: %v", summary.Children) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_EvalsByJob(t *testing.T) { @@ -1703,7 +1919,8 @@ func TestStateStore_EvalsByJob(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.EvalsByJob(eval1.JobID) + ws := memdb.NewWatchSet() + out, err := state.EvalsByJob(ws, eval1.JobID) if err != nil { t.Fatalf("err: %v", err) } @@ -1714,6 +1931,10 @@ func TestStateStore_EvalsByJob(t *testing.T) { if !reflect.DeepEqual(evals, out) { t.Fatalf("bad: %#v %#v", evals, out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_Evals(t *testing.T) { @@ -1730,7 +1951,8 @@ func TestStateStore_Evals(t *testing.T) { } } - iter, err := state.Evals() + ws := memdb.NewWatchSet() + iter, err := state.Evals(ws) if err != nil { t.Fatalf("err: %v", err) } @@ -1750,6 +1972,10 @@ func TestStateStore_Evals(t *testing.T) { if !reflect.DeepEqual(evals, out) { t.Fatalf("bad: %#v %#v", evals, out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_EvalsByIDPrefix(t *testing.T) { @@ -1778,7 +2004,8 @@ func TestStateStore_EvalsByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err := state.EvalsByIDPrefix("aaaa") + ws := 
memdb.NewWatchSet() + iter, err := state.EvalsByIDPrefix(ws, "aaaa") if err != nil { t.Fatalf("err: %v", err) } @@ -1808,7 +2035,7 @@ func TestStateStore_EvalsByIDPrefix(t *testing.T) { } } - iter, err = state.EvalsByIDPrefix("b-a7bfb") + iter, err = state.EvalsByIDPrefix(ws, "b-a7bfb") if err != nil { t.Fatalf("err: %v", err) } @@ -1818,17 +2045,15 @@ func TestStateStore_EvalsByIDPrefix(t *testing.T) { t.Fatalf("bad: unexpected zero evaluations, got: %#v", out) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestoreEval(t *testing.T) { state := testStateStore(t) eval := mock.Eval() - notify := setupNotifyTest( - state, - watch.Item{Table: "evals"}, - watch.Item{Eval: eval.ID}) - restore, err := state.Restore() if err != nil { t.Fatalf("err: %v", err) @@ -1840,7 +2065,8 @@ func TestStateStore_RestoreEval(t *testing.T) { } restore.Commit() - out, err := state.EvalByID(eval.ID) + ws := memdb.NewWatchSet() + out, err := state.EvalByID(ws, eval.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1848,13 +2074,10 @@ func TestStateStore_RestoreEval(t *testing.T) { if !reflect.DeepEqual(out, eval) { t.Fatalf("Bad: %#v %#v", out, eval) } - - notify.verify(t) } func TestStateStore_UpdateAllocsFromClient(t *testing.T) { state := testStateStore(t) - parent := mock.Job() if err := state.UpsertJob(998, parent); err != nil { t.Fatalf("err: %v", err) @@ -1870,31 +2093,13 @@ func TestStateStore_UpdateAllocsFromClient(t *testing.T) { alloc.JobID = child.ID alloc.Job = child - notify := setupNotifyTest( - state, - watch.Item{Table: "job_summary"}, - watch.Item{JobSummary: parent.ID}) - err := state.UpsertAllocs(1000, []*structs.Allocation{alloc}) if err != nil { t.Fatalf("err: %v", err) } - // Create the delta updates - ts := map[string]*structs.TaskState{"web": &structs.TaskState{State: structs.TaskStatePending}} - update := &structs.Allocation{ - ID: alloc.ID, - ClientStatus: structs.AllocClientStatusRunning, - TaskStates: ts, - JobID: alloc.JobID, - TaskGroup: alloc.TaskGroup, - } - err = state.UpdateAllocsFromClient(1001, []*structs.Allocation{update}) - if err != nil { - t.Fatalf("err: %v", err) - } - - summary, err := state.JobSummaryByID(parent.ID) + ws := memdb.NewWatchSet() + summary, err := state.JobSummaryByID(ws, parent.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -1911,46 +2116,108 @@ func TestStateStore_UpdateAllocsFromClient(t *testing.T) { t.Fatalf("bad children summary: %v", summary.Children) } - notify.verify(t) + // Create watchsets so we can test that update fires the watch + ws = memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws, parent.ID); err != nil { + t.Fatalf("bad: %v", err) + } + + // Create the delta updates + ts := map[string]*structs.TaskState{"web": &structs.TaskState{State: structs.TaskStateRunning}} + update := &structs.Allocation{ + ID: alloc.ID, + ClientStatus: structs.AllocClientStatusComplete, + TaskStates: ts, + JobID: alloc.JobID, + TaskGroup: alloc.TaskGroup, + } + err = state.UpdateAllocsFromClient(1001, []*structs.Allocation{update}) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + summary, err = state.JobSummaryByID(ws, parent.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if summary == nil { + t.Fatalf("nil summary") + } + if summary.JobID != parent.ID { + t.Fatalf("bad summary id: %v", parent.ID) + } + if summary.Children == nil { + t.Fatalf("nil children summary") + } + if summary.Children.Pending != 0 || summary.Children.Running != 0 || 
summary.Children.Dead != 1 { + t.Fatalf("bad children summary: %v", summary.Children) + } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpdateAllocsFromClient_ChildJob(t *testing.T) { state := testStateStore(t) - alloc := mock.Alloc() + alloc1 := mock.Alloc() alloc2 := mock.Alloc() - notify := setupNotifyTest( - state, - watch.Item{Table: "allocs"}, - watch.Item{Alloc: alloc.ID}, - watch.Item{AllocEval: alloc.EvalID}, - watch.Item{AllocJob: alloc.JobID}, - watch.Item{AllocNode: alloc.NodeID}, - watch.Item{Alloc: alloc2.ID}, - watch.Item{AllocEval: alloc2.EvalID}, - watch.Item{AllocJob: alloc2.JobID}, - watch.Item{AllocNode: alloc2.NodeID}) - - if err := state.UpsertJob(999, alloc.Job); err != nil { + if err := state.UpsertJob(999, alloc1.Job); err != nil { t.Fatalf("err: %v", err) } if err := state.UpsertJob(999, alloc2.Job); err != nil { t.Fatalf("err: %v", err) } - err := state.UpsertAllocs(1000, []*structs.Allocation{alloc, alloc2}) + err := state.UpsertAllocs(1000, []*structs.Allocation{alloc1, alloc2}) if err != nil { t.Fatalf("err: %v", err) } + // Create watchsets so we can test that update fires the watch + watches := make([]memdb.WatchSet, 8) + for i := 0; i < 8; i++ { + watches[i] = memdb.NewWatchSet() + } + if _, err := state.AllocByID(watches[0], alloc1.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocByID(watches[1], alloc2.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByEval(watches[2], alloc1.EvalID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByEval(watches[3], alloc2.EvalID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByJob(watches[4], alloc1.JobID, false); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByJob(watches[5], alloc2.JobID, false); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByNode(watches[6], alloc1.NodeID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByNode(watches[7], alloc2.NodeID); err != nil { + t.Fatalf("bad: %v", err) + } + // Create the delta updates ts := map[string]*structs.TaskState{"web": &structs.TaskState{State: structs.TaskStatePending}} update := &structs.Allocation{ - ID: alloc.ID, + ID: alloc1.ID, ClientStatus: structs.AllocClientStatusFailed, TaskStates: ts, - JobID: alloc.JobID, - TaskGroup: alloc.TaskGroup, + JobID: alloc1.JobID, + TaskGroup: alloc1.TaskGroup, } update2 := &structs.Allocation{ ID: alloc2.ID, @@ -1965,20 +2232,27 @@ func TestStateStore_UpdateAllocsFromClient_ChildJob(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocByID(alloc.ID) + for i, ws := range watches { + if !watchFired(ws) { + t.Fatalf("bad %d", i) + } + } + + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc1.ID) if err != nil { t.Fatalf("err: %v", err) } - alloc.CreateIndex = 1000 - alloc.ModifyIndex = 1001 - alloc.TaskStates = ts - alloc.ClientStatus = structs.AllocClientStatusFailed - if !reflect.DeepEqual(alloc, out) { - t.Fatalf("bad: %#v %#v", alloc, out) + alloc1.CreateIndex = 1000 + alloc1.ModifyIndex = 1001 + alloc1.TaskStates = ts + alloc1.ClientStatus = structs.AllocClientStatusFailed + if !reflect.DeepEqual(alloc1, out) { + t.Fatalf("bad: %#v %#v", alloc1, out) } - out, err = state.AllocByID(alloc2.ID) + out, err = state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2000,7 +2274,7 @@ func TestStateStore_UpdateAllocsFromClient_ChildJob(t *testing.T) { } // Ensure summaries have been 
updated - summary, err := state.JobSummaryByID(alloc.JobID) + summary, err := state.JobSummaryByID(ws, alloc1.JobID) if err != nil { t.Fatalf("err: %v", err) } @@ -2009,7 +2283,7 @@ func TestStateStore_UpdateAllocsFromClient_ChildJob(t *testing.T) { t.Fatalf("expected failed: %v, actual: %v, summary: %#v", 1, tgSummary.Failed, tgSummary) } - summary2, err := state.JobSummaryByID(alloc2.JobID) + summary2, err := state.JobSummaryByID(ws, alloc2.JobID) if err != nil { t.Fatalf("err: %v", err) } @@ -2018,7 +2292,9 @@ func TestStateStore_UpdateAllocsFromClient_ChildJob(t *testing.T) { t.Fatalf("expected running: %v, actual: %v", 1, tgSummary2.Running) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpdateMultipleAllocsFromClient(t *testing.T) { @@ -2055,7 +2331,8 @@ func TestStateStore_UpdateMultipleAllocsFromClient(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2068,7 +2345,7 @@ func TestStateStore_UpdateMultipleAllocsFromClient(t *testing.T) { t.Fatalf("bad: %#v , actual:%#v", alloc, out) } - summary, err := state.JobSummaryByID(alloc.JobID) + summary, err := state.JobSummaryByID(ws, alloc.JobID) expectedSummary := &structs.JobSummary{ JobID: alloc.JobID, Summary: map[string]structs.TaskGroupSummary{ @@ -2092,24 +2369,41 @@ func TestStateStore_UpsertAlloc_Alloc(t *testing.T) { state := testStateStore(t) alloc := mock.Alloc() - notify := setupNotifyTest( - state, - watch.Item{Table: "allocs"}, - watch.Item{Alloc: alloc.ID}, - watch.Item{AllocEval: alloc.EvalID}, - watch.Item{AllocJob: alloc.JobID}, - watch.Item{AllocNode: alloc.NodeID}) - if err := state.UpsertJob(999, alloc.Job); err != nil { t.Fatalf("err: %v", err) } + // Create watchsets so we can test that update fires the watch + watches := make([]memdb.WatchSet, 4) + for i := 0; i < 4; i++ { + watches[i] = memdb.NewWatchSet() + } + if _, err := state.AllocByID(watches[0], alloc.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByEval(watches[1], alloc.EvalID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByJob(watches[2], alloc.JobID, false); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByNode(watches[3], alloc.NodeID); err != nil { + t.Fatalf("bad: %v", err) + } + err := state.UpsertAllocs(1000, []*structs.Allocation{alloc}) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.AllocByID(alloc.ID) + for i, ws := range watches { + if !watchFired(ws) { + t.Fatalf("bad %d", i) + } + } + + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2126,7 +2420,7 @@ func TestStateStore_UpsertAlloc_Alloc(t *testing.T) { t.Fatalf("bad: %d", index) } - summary, err := state.JobSummaryByID(alloc.JobID) + summary, err := state.JobSummaryByID(ws, alloc.JobID) if err != nil { t.Fatalf("err: %v", err) } @@ -2139,7 +2433,9 @@ func TestStateStore_UpsertAlloc_Alloc(t *testing.T) { t.Fatalf("expected queued: %v, actual: %v", 1, tgSummary.Starting) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpsertAlloc_NoEphemeralDisk(t *testing.T) { @@ -2157,7 +2453,8 @@ func TestStateStore_UpsertAlloc_NoEphemeralDisk(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { 
t.Fatalf("err: %v", err) } @@ -2188,17 +2485,23 @@ func TestStateStore_UpsertAlloc_ChildJob(t *testing.T) { alloc.JobID = child.ID alloc.Job = child - notify := setupNotifyTest( - state, - watch.Item{Table: "job_summary"}, - watch.Item{JobSummary: parent.ID}) + // Create watchsets so we can test that delete fires the watch + ws := memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws, parent.ID); err != nil { + t.Fatalf("bad: %v", err) + } err := state.UpsertAllocs(1000, []*structs.Allocation{alloc}) if err != nil { t.Fatalf("err: %v", err) } - summary, err := state.JobSummaryByID(parent.ID) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + summary, err := state.JobSummaryByID(ws, parent.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2215,7 +2518,9 @@ func TestStateStore_UpsertAlloc_ChildJob(t *testing.T) { t.Fatalf("bad children summary: %v", summary.Children) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_UpdateAlloc_Alloc(t *testing.T) { @@ -2231,7 +2536,8 @@ func TestStateStore_UpdateAlloc_Alloc(t *testing.T) { t.Fatalf("err: %v", err) } - summary, err := state.JobSummaryByID(alloc.JobID) + ws := memdb.NewWatchSet() + summary, err := state.JobSummaryByID(ws, alloc.JobID) if err != nil { t.Fatalf("err: %v", err) } @@ -2245,20 +2551,37 @@ func TestStateStore_UpdateAlloc_Alloc(t *testing.T) { alloc2.NodeID = alloc.NodeID + ".new" state.UpsertJobSummary(1001, mock.JobSummary(alloc2.JobID)) - notify := setupNotifyTest( - state, - watch.Item{Table: "allocs"}, - watch.Item{Alloc: alloc2.ID}, - watch.Item{AllocEval: alloc2.EvalID}, - watch.Item{AllocJob: alloc2.JobID}, - watch.Item{AllocNode: alloc2.NodeID}) + // Create watchsets so we can test that update fires the watch + watches := make([]memdb.WatchSet, 4) + for i := 0; i < 4; i++ { + watches[i] = memdb.NewWatchSet() + } + if _, err := state.AllocByID(watches[0], alloc2.ID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByEval(watches[1], alloc2.EvalID); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByJob(watches[2], alloc2.JobID, false); err != nil { + t.Fatalf("bad: %v", err) + } + if _, err := state.AllocsByNode(watches[3], alloc2.NodeID); err != nil { + t.Fatalf("bad: %v", err) + } err = state.UpsertAllocs(1002, []*structs.Allocation{alloc2}) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.AllocByID(alloc.ID) + for i, ws := range watches { + if !watchFired(ws) { + t.Fatalf("bad %d", i) + } + } + + ws = memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2283,7 +2606,7 @@ func TestStateStore_UpdateAlloc_Alloc(t *testing.T) { } // Ensure that summary hasb't changed - summary, err = state.JobSummaryByID(alloc.JobID) + summary, err = state.JobSummaryByID(ws, alloc.JobID) if err != nil { t.Fatalf("err: %v", err) } @@ -2292,7 +2615,9 @@ func TestStateStore_UpdateAlloc_Alloc(t *testing.T) { t.Fatalf("expected starting: %v, actual: %v", 1, tgSummary.Starting) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } // This test ensures that the state store will mark the clients status as lost @@ -2318,7 +2643,8 @@ func TestStateStore_UpdateAlloc_Lost(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocByID(alloc2.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2364,7 +2690,8 @@ func TestStateStore_UpdateAlloc_NoJob(t *testing.T) { t.Fatalf("err: 
%v", err) } - out, _ := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, _ := state.AllocByID(ws, alloc.ID) // Update the modify index of the alloc before comparing allocCopy1.ModifyIndex = 1003 if !reflect.DeepEqual(out, allocCopy1) { @@ -2380,16 +2707,17 @@ func TestStateStore_JobSummary(t *testing.T) { state.UpsertJob(900, job) // Get the job back - outJob, _ := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + outJob, _ := state.JobByID(ws, job.ID) if outJob.CreateIndex != 900 { t.Fatalf("bad create index: %v", outJob.CreateIndex) } - summary, _ := state.JobSummaryByID(job.ID) + summary, _ := state.JobSummaryByID(ws, job.ID) if summary.CreateIndex != 900 { t.Fatalf("bad create index: %v", summary.CreateIndex) } - // Upser an allocation + // Upsert an allocation alloc := mock.Alloc() alloc.JobID = job.ID alloc.Job = job @@ -2418,6 +2746,10 @@ func TestStateStore_JobSummary(t *testing.T) { alloc5.DesiredStatus = structs.AllocDesiredStatusRun state.UpsertAllocs(970, []*structs.Allocation{alloc5}) + if !watchFired(ws) { + t.Fatalf("bad") + } + expectedSummary := structs.JobSummary{ JobID: job.ID, Summary: map[string]structs.TaskGroupSummary{ @@ -2430,7 +2762,7 @@ func TestStateStore_JobSummary(t *testing.T) { ModifyIndex: 930, } - summary, _ = state.JobSummaryByID(job.ID) + summary, _ = state.JobSummaryByID(ws, job.ID) if !reflect.DeepEqual(&expectedSummary, summary) { t.Fatalf("expected: %#v, actual: %v", expectedSummary, summary) } @@ -2445,7 +2777,7 @@ func TestStateStore_JobSummary(t *testing.T) { state.UpdateAllocsFromClient(990, []*structs.Allocation{alloc6}) // We shouldn't have any summary at this point - summary, _ = state.JobSummaryByID(job.ID) + summary, _ = state.JobSummaryByID(ws, job.ID) if summary != nil { t.Fatalf("expected nil, actual: %#v", summary) } @@ -2454,11 +2786,11 @@ func TestStateStore_JobSummary(t *testing.T) { job1 := mock.Job() job1.ID = job.ID state.UpsertJob(1000, job1) - outJob2, _ := state.JobByID(job1.ID) + outJob2, _ := state.JobByID(ws, job1.ID) if outJob2.CreateIndex != 1000 { t.Fatalf("bad create index: %v", outJob2.CreateIndex) } - summary, _ = state.JobSummaryByID(job1.ID) + summary, _ = state.JobSummaryByID(ws, job1.ID) if summary.CreateIndex != 1000 { t.Fatalf("bad create index: %v", summary.CreateIndex) } @@ -2481,7 +2813,7 @@ func TestStateStore_JobSummary(t *testing.T) { ModifyIndex: 1000, } - summary, _ = state.JobSummaryByID(job1.ID) + summary, _ = state.JobSummaryByID(ws, job1.ID) if !reflect.DeepEqual(&expectedSummary, summary) { t.Fatalf("expected: %#v, actual: %#v", expectedSummary, summary) } @@ -2551,7 +2883,8 @@ func TestStateStore_ReconcileJobSummary(t *testing.T) { state.ReconcileJobSummaries(120) - summary, _ := state.JobSummaryByID(alloc.Job.ID) + ws := memdb.NewWatchSet() + summary, _ := state.JobSummaryByID(ws, alloc.Job.ID) expectedSummary := structs.JobSummary{ JobID: alloc.Job.ID, Summary: map[string]structs.TaskGroupSummary{ @@ -2614,7 +2947,9 @@ func TestStateStore_UpdateAlloc_JobNotPresent(t *testing.T) { CreateIndex: 500, ModifyIndex: 500, } - summary, _ := state.JobSummaryByID(alloc.Job.ID) + + ws := memdb.NewWatchSet() + summary, _ := state.JobSummaryByID(ws, alloc.Job.ID) if !reflect.DeepEqual(&expectedSummary, summary) { t.Fatalf("expected: %v, actual: %v", expectedSummary, summary) } @@ -2638,7 +2973,8 @@ func TestStateStore_EvictAlloc_Alloc(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != 
nil { t.Fatalf("err: %v", err) } @@ -2675,7 +3011,8 @@ func TestStateStore_AllocsByNode(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocsByNode("foo") + ws := memdb.NewWatchSet() + out, err := state.AllocsByNode(ws, "foo") if err != nil { t.Fatalf("err: %v", err) } @@ -2686,6 +3023,10 @@ func TestStateStore_AllocsByNode(t *testing.T) { if !reflect.DeepEqual(allocs, out) { t.Fatalf("bad: %#v %#v", allocs, out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_AllocsByNodeTerminal(t *testing.T) { @@ -2714,7 +3055,8 @@ func TestStateStore_AllocsByNodeTerminal(t *testing.T) { } // Verify the terminal allocs - out, err := state.AllocsByNodeTerminal("foo", true) + ws := memdb.NewWatchSet() + out, err := state.AllocsByNodeTerminal(ws, "foo", true) if err != nil { t.Fatalf("err: %v", err) } @@ -2727,7 +3069,7 @@ func TestStateStore_AllocsByNodeTerminal(t *testing.T) { } // Verify the non-terminal allocs - out, err = state.AllocsByNodeTerminal("foo", false) + out, err = state.AllocsByNodeTerminal(ws, "foo", false) if err != nil { t.Fatalf("err: %v", err) } @@ -2738,6 +3080,10 @@ func TestStateStore_AllocsByNodeTerminal(t *testing.T) { if !reflect.DeepEqual(nonterm, out) { t.Fatalf("bad: %#v %#v", nonterm, out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_AllocsByJob(t *testing.T) { @@ -2759,7 +3105,8 @@ func TestStateStore_AllocsByJob(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocsByJob("foo", false) + ws := memdb.NewWatchSet() + out, err := state.AllocsByJob(ws, "foo", false) if err != nil { t.Fatalf("err: %v", err) } @@ -2770,6 +3117,10 @@ func TestStateStore_AllocsByJob(t *testing.T) { if !reflect.DeepEqual(allocs, out) { t.Fatalf("bad: %#v %#v", allocs, out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_AllocsForRegisteredJob(t *testing.T) { @@ -2809,7 +3160,8 @@ func TestStateStore_AllocsForRegisteredJob(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.AllocsByJob(job1.ID, true) + ws := memdb.NewWatchSet() + out, err := state.AllocsByJob(ws, job1.ID, true) if err != nil { t.Fatalf("err: %v", err) } @@ -2819,12 +3171,15 @@ func TestStateStore_AllocsForRegisteredJob(t *testing.T) { t.Fatalf("expected: %v, actual: %v", expected, len(out)) } - out1, err := state.AllocsByJob(job1.ID, false) + out1, err := state.AllocsByJob(ws, job1.ID, false) expected = len(allocs1) if len(out1) != expected { t.Fatalf("expected: %v, actual: %v", expected, len(out1)) } + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_AllocsByIDPrefix(t *testing.T) { @@ -2857,7 +3212,8 @@ func TestStateStore_AllocsByIDPrefix(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err := state.AllocsByIDPrefix("aaaa") + ws := memdb.NewWatchSet() + iter, err := state.AllocsByIDPrefix(ws, "aaaa") if err != nil { t.Fatalf("err: %v", err) } @@ -2887,7 +3243,7 @@ func TestStateStore_AllocsByIDPrefix(t *testing.T) { } } - iter, err = state.AllocsByIDPrefix("b-a7bfb") + iter, err = state.AllocsByIDPrefix(ws, "b-a7bfb") if err != nil { t.Fatalf("err: %v", err) } @@ -2896,6 +3252,10 @@ func TestStateStore_AllocsByIDPrefix(t *testing.T) { if len(out) != 0 { t.Fatalf("bad: unexpected zero allocations, got: %#v", out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_Allocs(t *testing.T) { @@ -2915,7 +3275,8 @@ func TestStateStore_Allocs(t *testing.T) { t.Fatalf("err: %v", err) } - iter, err := state.Allocs() + ws := memdb.NewWatchSet() + iter, err := state.Allocs(ws) if err != nil { 
t.Fatalf("err: %v", err) } @@ -2935,20 +3296,16 @@ func TestStateStore_Allocs(t *testing.T) { if !reflect.DeepEqual(allocs, out) { t.Fatalf("bad: %#v %#v", allocs, out) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestoreAlloc(t *testing.T) { state := testStateStore(t) alloc := mock.Alloc() - notify := setupNotifyTest( - state, - watch.Item{Table: "allocs"}, - watch.Item{Alloc: alloc.ID}, - watch.Item{AllocEval: alloc.EvalID}, - watch.Item{AllocJob: alloc.JobID}, - watch.Item{AllocNode: alloc.NodeID}) - restore, err := state.Restore() if err != nil { t.Fatalf("err: %v", err) @@ -2961,7 +3318,8 @@ func TestStateStore_RestoreAlloc(t *testing.T) { restore.Commit() - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -2970,7 +3328,9 @@ func TestStateStore_RestoreAlloc(t *testing.T) { t.Fatalf("Bad: %#v %#v", out, alloc) } - notify.verify(t) + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestoreAlloc_NoEphemeralDisk(t *testing.T) { @@ -2991,7 +3351,8 @@ func TestStateStore_RestoreAlloc_NoEphemeralDisk(t *testing.T) { restore.Commit() - out, err := state.AllocByID(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.AllocByID(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -3003,11 +3364,14 @@ func TestStateStore_RestoreAlloc_NoEphemeralDisk(t *testing.T) { if !reflect.DeepEqual(out, expected) { t.Fatalf("Bad: %#v %#v", out, expected) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_SetJobStatus_ForceStatus(t *testing.T) { state := testStateStore(t) - watcher := watch.NewItems() txn := state.db.Txn(true) // Create and insert a mock job. @@ -3020,7 +3384,7 @@ func TestStateStore_SetJobStatus_ForceStatus(t *testing.T) { exp := "foobar" index := uint64(1000) - if err := state.setJobStatus(index, watcher, txn, job, false, exp); err != nil { + if err := state.setJobStatus(index, txn, job, false, exp); err != nil { t.Fatalf("setJobStatus() failed: %v", err) } @@ -3041,7 +3405,6 @@ func TestStateStore_SetJobStatus_ForceStatus(t *testing.T) { func TestStateStore_SetJobStatus_NoOp(t *testing.T) { state := testStateStore(t) - watcher := watch.NewItems() txn := state.db.Txn(true) // Create and insert a mock job that should be pending. 
@@ -3053,7 +3416,7 @@ func TestStateStore_SetJobStatus_NoOp(t *testing.T) { } index := uint64(1000) - if err := state.setJobStatus(index, watcher, txn, job, false, ""); err != nil { + if err := state.setJobStatus(index, txn, job, false, ""); err != nil { t.Fatalf("setJobStatus() failed: %v", err) } @@ -3070,7 +3433,6 @@ func TestStateStore_SetJobStatus_NoOp(t *testing.T) { func TestStateStore_SetJobStatus(t *testing.T) { state := testStateStore(t) - watcher := watch.NewItems() txn := state.db.Txn(true) // Create and insert a mock job that should be pending but has an incorrect @@ -3083,7 +3445,7 @@ func TestStateStore_SetJobStatus(t *testing.T) { } index := uint64(1000) - if err := state.setJobStatus(index, watcher, txn, job, false, ""); err != nil { + if err := state.setJobStatus(index, txn, job, false, ""); err != nil { t.Fatalf("setJobStatus() failed: %v", err) } @@ -3223,69 +3585,32 @@ func TestStateStore_SetJobStatus_PendingEval(t *testing.T) { } } -func TestStateWatch_watch(t *testing.T) { - sw := newStateWatch() - notify1 := make(chan struct{}, 1) - notify2 := make(chan struct{}, 1) - notify3 := make(chan struct{}, 1) - - // Notifications trigger subscribed channels - sw.watch(watch.NewItems(watch.Item{Table: "foo"}), notify1) - sw.watch(watch.NewItems(watch.Item{Table: "bar"}), notify2) - sw.watch(watch.NewItems(watch.Item{Table: "baz"}), notify3) - - items := watch.NewItems() - items.Add(watch.Item{Table: "foo"}) - items.Add(watch.Item{Table: "bar"}) - - sw.notify(items) - if len(notify1) != 1 { - t.Fatalf("should notify") - } - if len(notify2) != 1 { - t.Fatalf("should notify") - } - if len(notify3) != 0 { - t.Fatalf("should not notify") - } -} - -func TestStateWatch_stopWatch(t *testing.T) { - sw := newStateWatch() - notify := make(chan struct{}) - - // First subscribe - sw.watch(watch.NewItems(watch.Item{Table: "foo"}), notify) - - // Unsubscribe stop notifications - sw.stopWatch(watch.NewItems(watch.Item{Table: "foo"}), notify) - - // Check that the group was removed - if _, ok := sw.items[watch.Item{Table: "foo"}]; ok { - t.Fatalf("should remove group") - } - - // Check that we are not notified - sw.notify(watch.NewItems(watch.Item{Table: "foo"})) - if len(notify) != 0 { - t.Fatalf("should not notify") - } -} - func TestStateJobSummary_UpdateJobCount(t *testing.T) { state := testStateStore(t) alloc := mock.Alloc() job := alloc.Job job.TaskGroups[0].Count = 3 - err := state.UpsertJob(1000, job) - if err != nil { + + // Create watchsets so we can test that upsert fires the watch + ws := memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws, job.ID); err != nil { + t.Fatalf("bad: %v", err) + } + + if err := state.UpsertJob(1000, job); err != nil { t.Fatalf("err: %v", err) } if err := state.UpsertAllocs(1001, []*structs.Allocation{alloc}); err != nil { t.Fatalf("err: %v", err) } - summary, _ := state.JobSummaryByID(job.ID) + + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + summary, _ := state.JobSummaryByID(ws, job.ID) expectedSummary := structs.JobSummary{ JobID: job.ID, Summary: map[string]structs.TaskGroupSummary{ @@ -3301,6 +3626,12 @@ func TestStateJobSummary_UpdateJobCount(t *testing.T) { t.Fatalf("expected: %v, actual: %v", expectedSummary, summary) } + // Create watchsets so we can test that upsert fires the watch + ws2 := memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws2, job.ID); err != nil { + t.Fatalf("bad: %v", err) + } + alloc2 := mock.Alloc() alloc2.Job = job alloc2.JobID = job.ID @@ -3313,9 +3644,13 @@ func 
TestStateJobSummary_UpdateJobCount(t *testing.T) { t.Fatalf("err: %v", err) } - outA, _ := state.AllocByID(alloc3.ID) + if !watchFired(ws2) { + t.Fatalf("bad") + } - summary, _ = state.JobSummaryByID(job.ID) + outA, _ := state.AllocByID(ws, alloc3.ID) + + summary, _ = state.JobSummaryByID(ws, job.ID) expectedSummary = structs.JobSummary{ JobID: job.ID, Summary: map[string]structs.TaskGroupSummary{ @@ -3331,6 +3666,12 @@ func TestStateJobSummary_UpdateJobCount(t *testing.T) { t.Fatalf("expected summary: %v, actual: %v", expectedSummary, summary) } + // Create watchsets so we can test that upsert fires the watch + ws3 := memdb.NewWatchSet() + if _, err := state.JobSummaryByID(ws3, job.ID); err != nil { + t.Fatalf("bad: %v", err) + } + alloc4 := mock.Alloc() alloc4.ID = alloc2.ID alloc4.Job = alloc2.Job @@ -3346,8 +3687,13 @@ func TestStateJobSummary_UpdateJobCount(t *testing.T) { if err := state.UpdateAllocsFromClient(1004, []*structs.Allocation{alloc4, alloc5}); err != nil { t.Fatalf("err: %v", err) } - outA, _ = state.AllocByID(alloc5.ID) - summary, _ = state.JobSummaryByID(job.ID) + + if !watchFired(ws2) { + t.Fatalf("bad") + } + + outA, _ = state.AllocByID(ws, alloc5.ID) + summary, _ = state.JobSummaryByID(ws, job.ID) expectedSummary = structs.JobSummary{ JobID: job.ID, Summary: map[string]structs.TaskGroupSummary{ @@ -3387,7 +3733,9 @@ func TestJobSummary_UpdateClientStatus(t *testing.T) { if err := state.UpsertAllocs(1001, []*structs.Allocation{alloc, alloc2, alloc3}); err != nil { t.Fatalf("err: %v", err) } - summary, _ := state.JobSummaryByID(job.ID) + + ws := memdb.NewWatchSet() + summary, _ := state.JobSummaryByID(ws, job.ID) if summary.Summary["web"].Starting != 3 { t.Fatalf("bad job summary: %v", summary) } @@ -3413,7 +3761,12 @@ func TestJobSummary_UpdateClientStatus(t *testing.T) { if err := state.UpdateAllocsFromClient(1002, []*structs.Allocation{alloc4, alloc5, alloc6}); err != nil { t.Fatalf("err: %v", err) } - summary, _ = state.JobSummaryByID(job.ID) + + if !watchFired(ws) { + t.Fatalf("bad") + } + + summary, _ = state.JobSummaryByID(ws, job.ID) if summary.Summary["web"].Running != 1 || summary.Summary["web"].Failed != 1 || summary.Summary["web"].Complete != 1 { t.Fatalf("bad job summary: %v", summary) } @@ -3425,7 +3778,7 @@ func TestJobSummary_UpdateClientStatus(t *testing.T) { if err := state.UpsertAllocs(1003, []*structs.Allocation{alloc7}); err != nil { t.Fatalf("err: %v", err) } - summary, _ = state.JobSummaryByID(job.ID) + summary, _ = state.JobSummaryByID(ws, job.ID) if summary.Summary["web"].Starting != 1 || summary.Summary["web"].Running != 1 || summary.Summary["web"].Failed != 1 || summary.Summary["web"].Complete != 1 { t.Fatalf("bad job summary: %v", summary) } @@ -3436,12 +3789,26 @@ func TestStateStore_UpsertVaultAccessors(t *testing.T) { a := mock.VaultAccessor() a2 := mock.VaultAccessor() + ws := memdb.NewWatchSet() + if _, err := state.VaultAccessor(ws, a.Accessor); err != nil { + t.Fatalf("err: %v", err) + } + + if _, err := state.VaultAccessor(ws, a2.Accessor); err != nil { + t.Fatalf("err: %v", err) + } + err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{a, a2}) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.VaultAccessor(a.Accessor) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.VaultAccessor(ws, a.Accessor) if err != nil { t.Fatalf("err: %v", err) } @@ -3450,7 +3817,7 @@ func TestStateStore_UpsertVaultAccessors(t *testing.T) { t.Fatalf("bad: %#v %#v", a, out) } - out, err = 
state.VaultAccessor(a2.Accessor) + out, err = state.VaultAccessor(ws, a2.Accessor) if err != nil { t.Fatalf("err: %v", err) } @@ -3459,7 +3826,7 @@ func TestStateStore_UpsertVaultAccessors(t *testing.T) { t.Fatalf("bad: %#v %#v", a2, out) } - iter, err := state.VaultAccessors() + iter, err := state.VaultAccessors(ws) if err != nil { t.Fatalf("err: %v", err) } @@ -3490,6 +3857,10 @@ func TestStateStore_UpsertVaultAccessors(t *testing.T) { if index != 1000 { t.Fatalf("bad: %d", index) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_DeleteVaultAccessors(t *testing.T) { @@ -3503,19 +3874,29 @@ func TestStateStore_DeleteVaultAccessors(t *testing.T) { t.Fatalf("err: %v", err) } + ws := memdb.NewWatchSet() + if _, err := state.VaultAccessor(ws, a1.Accessor); err != nil { + t.Fatalf("err: %v", err) + } + err = state.DeleteVaultAccessors(1001, accessors) if err != nil { t.Fatalf("err: %v", err) } - out, err := state.VaultAccessor(a1.Accessor) + if !watchFired(ws) { + t.Fatalf("bad") + } + + ws = memdb.NewWatchSet() + out, err := state.VaultAccessor(ws, a1.Accessor) if err != nil { t.Fatalf("err: %v", err) } if out != nil { t.Fatalf("bad: %#v %#v", a1, out) } - out, err = state.VaultAccessor(a2.Accessor) + out, err = state.VaultAccessor(ws, a2.Accessor) if err != nil { t.Fatalf("err: %v", err) } @@ -3530,6 +3911,10 @@ func TestStateStore_DeleteVaultAccessors(t *testing.T) { if index != 1001 { t.Fatalf("bad: %d", index) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_VaultAccessorsByAlloc(t *testing.T) { @@ -3555,7 +3940,8 @@ func TestStateStore_VaultAccessorsByAlloc(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.VaultAccessorsByAlloc(alloc.ID) + ws := memdb.NewWatchSet() + out, err := state.VaultAccessorsByAlloc(ws, alloc.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -3571,6 +3957,10 @@ func TestStateStore_VaultAccessorsByAlloc(t *testing.T) { if index != 1000 { t.Fatalf("bad: %d", index) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_VaultAccessorsByNode(t *testing.T) { @@ -3596,7 +3986,8 @@ func TestStateStore_VaultAccessorsByNode(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := state.VaultAccessorsByNode(node.ID) + ws := memdb.NewWatchSet() + out, err := state.VaultAccessorsByNode(ws, node.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -3612,6 +4003,10 @@ func TestStateStore_VaultAccessorsByNode(t *testing.T) { if index != 1000 { t.Fatalf("bad: %d", index) } + + if watchFired(ws) { + t.Fatalf("bad") + } } func TestStateStore_RestoreVaultAccessor(t *testing.T) { @@ -3629,7 +4024,8 @@ func TestStateStore_RestoreVaultAccessor(t *testing.T) { } restore.Commit() - out, err := state.VaultAccessor(a.Accessor) + ws := memdb.NewWatchSet() + out, err := state.VaultAccessor(ws, a.Accessor) if err != nil { t.Fatalf("err: %v", err) } @@ -3637,38 +4033,33 @@ func TestStateStore_RestoreVaultAccessor(t *testing.T) { if !reflect.DeepEqual(out, a) { t.Fatalf("Bad: %#v %#v", out, a) } -} -// setupNotifyTest takes a state store and a set of watch items, then creates -// and subscribes a notification channel for each item. -func setupNotifyTest(state *StateStore, items ...watch.Item) notifyTest { - var n notifyTest - for _, item := range items { - ch := make(chan struct{}, 1) - state.Watch(watch.NewItems(item), ch) - n = append(n, ¬ifyTestCase{item, ch}) + if watchFired(ws) { + t.Fatalf("bad") } - return n } -// notifyTestCase is used to set up and verify watch triggers. 
-type notifyTestCase struct { - item watch.Item - ch chan struct{} -} - -// notifyTest is a suite of notifyTestCases. -type notifyTest []*notifyTestCase - -// verify ensures that each channel received a notification. -func (n notifyTest) verify(t *testing.T) { - for _, tcase := range n { - if len(tcase.ch) != 1 { - t.Fatalf("should notify %#v", tcase.item) - } +func TestStateStore_Abandon(t *testing.T) { + s := testStateStore(t) + abandonCh := s.AbandonCh() + s.Abandon() + select { + case <-abandonCh: + default: + t.Fatalf("bad") } } +// watchFired is a helper for unit tests that returns if the given watch set +// fired (it doesn't care which watch actually fired). This uses a fixed +// timeout since we already expect the event happened before calling this and +// just need to distinguish a fire from a timeout. We do need a little time to +// allow the watch to set up any goroutines, though. +func watchFired(ws memdb.WatchSet) bool { + timedOut := ws.Watch(time.After(50 * time.Millisecond)) + return !timedOut +} + // NodeIDSort is used to sort nodes by ID type NodeIDSort []*structs.Node diff --git a/nomad/system_endpoint_test.go b/nomad/system_endpoint_test.go index 91e5d51c9..14dada680 100644 --- a/nomad/system_endpoint_test.go +++ b/nomad/system_endpoint_test.go @@ -5,6 +5,7 @@ import ( "reflect" "testing" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -38,7 +39,8 @@ func TestSystemEndpoint_GarbageCollect(t *testing.T) { testutil.WaitForResult(func() (bool, error) { // Check if the job has been GC'd - exist, err := state.JobByID(job.ID) + ws := memdb.NewWatchSet() + exist, err := state.JobByID(ws, job.ID) if err != nil { return false, err } @@ -81,7 +83,8 @@ func TestSystemEndpoint_ReconcileSummaries(t *testing.T) { testutil.WaitForResult(func() (bool, error) { // Check if Nomad has reconciled the summary for the job - summary, err := state.JobSummaryByID(job.ID) + ws := memdb.NewWatchSet() + summary, err := state.JobSummaryByID(ws, job.ID) if err != nil { return false, err } diff --git a/nomad/watch/watch.go b/nomad/watch/watch.go deleted file mode 100644 index 8578df33f..000000000 --- a/nomad/watch/watch.go +++ /dev/null @@ -1,40 +0,0 @@ -package watch - -// The watch package provides a means of describing a watch for a blocking -// query. It is exported so it may be shared between Nomad's RPC layer and -// the underlying state store. - -// Item describes the scope of a watch. It is used to provide a uniform -// input for subscribe/unsubscribe and notification firing. Specifying -// multiple fields does not place a watch on multiple items. Each Item -// describes exactly one scoped watch. -type Item struct { - Alloc string - AllocEval string - AllocJob string - AllocNode string - Eval string - EvalJob string - Job string - JobSummary string - Node string - Table string -} - -// Items is a helper used to construct a set of watchItems. It deduplicates -// the items as they are added using map keys. -type Items map[Item]struct{} - -// NewItems creates a new Items set and adds the given items. -func NewItems(items ...Item) Items { - wi := make(Items) - for _, item := range items { - wi.Add(item) - } - return wi -} - -// Add adds an item to the watch set. 
-func (wi Items) Add(i Item) { - wi[i] = struct{}{} -} diff --git a/nomad/watch/watch_test.go b/nomad/watch/watch_test.go deleted file mode 100644 index 9a8901aa8..000000000 --- a/nomad/watch/watch_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package watch - -import ( - "testing" -) - -func TestWatchItems(t *testing.T) { - // Creates an empty set of items - wi := NewItems() - if len(wi) != 0 { - t.Fatalf("expect 0 items, got: %#v", wi) - } - - // Creates a new set of supplied items - wi = NewItems(Item{Table: "foo"}) - if len(wi) != 1 { - t.Fatalf("expected 1 item, got: %#v", wi) - } - - // Adding items works - wi.Add(Item{Node: "bar"}) - if len(wi) != 2 { - t.Fatalf("expected 2 items, got: %#v", wi) - } - - // Adding duplicates auto-dedupes - wi.Add(Item{Table: "foo"}) - if len(wi) != 2 { - t.Fatalf("expected 2 items, got: %#v", wi) - } -} diff --git a/nomad/worker.go b/nomad/worker.go index f4ef1fb5c..6a274bf12 100644 --- a/nomad/worker.go +++ b/nomad/worker.go @@ -8,6 +8,7 @@ import ( "time" "github.com/armon/go-metrics" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/scheduler" ) @@ -446,7 +447,8 @@ func (w *Worker) ReblockEval(eval *structs.Evaluation) error { // Update the evaluation if the queued jobs is not same as what is // recorded in the job summary - summary, err := w.srv.fsm.state.JobSummaryByID(eval.JobID) + ws := memdb.NewWatchSet() + summary, err := w.srv.fsm.state.JobSummaryByID(ws, eval.JobID) if err != nil { return fmt.Errorf("couldn't retreive job summary: %v", err) } diff --git a/nomad/worker_test.go b/nomad/worker_test.go index fea703ceb..4de1ddcf3 100644 --- a/nomad/worker_test.go +++ b/nomad/worker_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/scheduler" @@ -397,7 +398,8 @@ func TestWorker_UpdateEval(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := s1.fsm.State().EvalByID(eval2.ID) + ws := memdb.NewWatchSet() + out, err := s1.fsm.State().EvalByID(ws, eval2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -443,7 +445,8 @@ func TestWorker_CreateEval(t *testing.T) { t.Fatalf("err: %v", err) } - out, err := s1.fsm.State().EvalByID(eval2.ID) + ws := memdb.NewWatchSet() + out, err := s1.fsm.State().EvalByID(ws, eval2.ID) if err != nil { t.Fatalf("err: %v", err) } @@ -512,7 +515,8 @@ func TestWorker_ReblockEval(t *testing.T) { } // Check that the eval was updated - eval, err := s1.fsm.State().EvalByID(eval2.ID) + ws := memdb.NewWatchSet() + eval, err := s1.fsm.State().EvalByID(ws, eval2.ID) if err != nil { t.Fatal(err) } diff --git a/scheduler/context.go b/scheduler/context.go index 5f3366f46..0e9d483c8 100644 --- a/scheduler/context.go +++ b/scheduler/context.go @@ -4,6 +4,7 @@ import ( "log" "regexp" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-version" "github.com/hashicorp/nomad/nomad/structs" ) @@ -107,7 +108,8 @@ func (e *EvalContext) Reset() { func (e *EvalContext) ProposedAllocs(nodeID string) ([]*structs.Allocation, error) { // Get the existing allocations that are non-terminal - existingAlloc, err := e.state.AllocsByNodeTerminal(nodeID, false) + ws := memdb.NewWatchSet() + existingAlloc, err := e.state.AllocsByNodeTerminal(ws, nodeID, false) if err != nil { return nil, err } diff --git a/scheduler/generic_sched.go b/scheduler/generic_sched.go index 7c2cc24bb..5653d9537 100644 --- a/scheduler/generic_sched.go +++ 
b/scheduler/generic_sched.go @@ -4,6 +4,7 @@ import ( "fmt" "log" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-multierror" "github.com/hashicorp/nomad/nomad/structs" ) @@ -183,7 +184,8 @@ func (s *GenericScheduler) createBlockedEval(planFailure bool) error { func (s *GenericScheduler) process() (bool, error) { // Lookup the Job by ID var err error - s.job, err = s.state.JobByID(s.eval.JobID) + ws := memdb.NewWatchSet() + s.job, err = s.state.JobByID(ws, s.eval.JobID) if err != nil { return false, fmt.Errorf("failed to get job '%s': %v", s.eval.JobID, err) @@ -354,7 +356,8 @@ func (s *GenericScheduler) computeJobAllocs() error { } // Lookup the allocations by JobID - allocs, err := s.state.AllocsByJob(s.eval.JobID, true) + ws := memdb.NewWatchSet() + allocs, err := s.state.AllocsByJob(ws, s.eval.JobID, true) if err != nil { return fmt.Errorf("failed to get allocs for job '%s': %v", s.eval.JobID, err) @@ -513,7 +516,8 @@ func (s *GenericScheduler) findPreferredNode(allocTuple *allocTuple) (node *stru } if taskGroup.EphemeralDisk.Sticky == true { var preferredNode *structs.Node - preferredNode, err = s.state.NodeByID(allocTuple.Alloc.NodeID) + ws := memdb.NewWatchSet() + preferredNode, err = s.state.NodeByID(ws, allocTuple.Alloc.NodeID) if preferredNode.Ready() { node = preferredNode } diff --git a/scheduler/generic_sched_test.go b/scheduler/generic_sched_test.go index e0fdbaaf4..36ec50b0f 100644 --- a/scheduler/generic_sched_test.go +++ b/scheduler/generic_sched_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" ) @@ -67,7 +68,8 @@ func TestServiceSched_JobRegister(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -215,7 +217,8 @@ func TestServiceSched_JobRegister_DiskConstraints(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure only one allocation was placed @@ -270,7 +273,8 @@ func TestServiceSched_JobRegister_Annotate(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -335,7 +339,8 @@ func TestServiceSched_JobRegister_CountZero(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure no allocations placed @@ -561,7 +566,8 @@ func TestServiceSched_JobRegister_FeasibleAndInfeasibleTG(t *testing.T) { } // Ensure two allocations placed - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) if len(out) != 2 { t.Fatalf("bad: %#v", out) @@ -680,7 +686,8 @@ func TestServiceSched_Plan_Partial_Progress(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure only one allocations placed @@ -800,7 +807,8 @@ func TestServiceSched_EvaluateBlockedEval_Finished(t 
*testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -908,7 +916,8 @@ func TestServiceSched_JobModify(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -999,7 +1008,8 @@ func TestServiceSched_JobModify_IncrCount_NodeLimit(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -1095,7 +1105,8 @@ func TestServiceSched_JobModify_CountZero(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -1283,7 +1294,8 @@ func TestServiceSched_JobModify_InPlace(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -1350,7 +1362,8 @@ func TestServiceSched_JobDeregister(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure that the job field on the allocation is still populated @@ -1401,8 +1414,9 @@ func TestServiceSched_NodeDown(t *testing.T) { noErr(t, h.State.UpsertAllocs(h.NextIndex(), allocs)) // Mark some allocs as running + ws := memdb.NewWatchSet() for i := 0; i < 4; i++ { - out, _ := h.State.AllocByID(allocs[i].ID) + out, _ := h.State.AllocByID(ws, allocs[i].ID) out.ClientStatus = structs.AllocClientStatusRunning noErr(t, h.State.UpdateAllocsFromClient(h.NextIndex(), []*structs.Allocation{out})) } @@ -1468,8 +1482,9 @@ func TestServiceSched_NodeUpdate(t *testing.T) { noErr(t, h.State.UpsertAllocs(h.NextIndex(), allocs)) // Mark some allocs as running + ws := memdb.NewWatchSet() for i := 0; i < 4; i++ { - out, _ := h.State.AllocByID(allocs[i].ID) + out, _ := h.State.AllocByID(ws, allocs[i].ID) out.ClientStatus = structs.AllocClientStatusRunning noErr(t, h.State.UpdateAllocsFromClient(h.NextIndex(), []*structs.Allocation{out})) } @@ -1560,7 +1575,8 @@ func TestServiceSched_NodeDrain(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -1829,7 +1845,8 @@ func TestServiceSched_RetryLimit(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure no allocations placed @@ -1882,7 +1899,8 @@ func TestBatchSched_Run_CompleteAlloc(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure no allocations placed @@ -1935,7 +1953,8 @@ func TestBatchSched_Run_DrainedAlloc(t *testing.T) { } // Lookup the 
allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure a replacement alloc was placed. @@ -1987,7 +2006,8 @@ func TestBatchSched_Run_FailedAlloc(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure a replacement alloc was placed. @@ -2105,7 +2125,8 @@ func TestBatchSched_ReRun_SuccessfullyFinishedAlloc(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure no replacement alloc was placed. diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index c69a5984e..ddbf855c4 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -63,22 +63,22 @@ type Scheduler interface { type State interface { // Nodes returns an iterator over all the nodes. // The type of each result is *structs.Node - Nodes() (memdb.ResultIterator, error) + Nodes(ws memdb.WatchSet) (memdb.ResultIterator, error) // AllocsByJob returns the allocations by JobID - AllocsByJob(jobID string, all bool) ([]*structs.Allocation, error) + AllocsByJob(ws memdb.WatchSet, jobID string, all bool) ([]*structs.Allocation, error) // AllocsByNode returns all the allocations by node - AllocsByNode(node string) ([]*structs.Allocation, error) + AllocsByNode(ws memdb.WatchSet, node string) ([]*structs.Allocation, error) // AllocsByNodeTerminal returns all the allocations by node filtering by terminal status - AllocsByNodeTerminal(node string, terminal bool) ([]*structs.Allocation, error) + AllocsByNodeTerminal(ws memdb.WatchSet, node string, terminal bool) ([]*structs.Allocation, error) // GetNodeByID is used to lookup a node by ID - NodeByID(nodeID string) (*structs.Node, error) + NodeByID(ws memdb.WatchSet, nodeID string) (*structs.Node, error) // GetJobByID is used to lookup a job by ID - JobByID(id string) (*structs.Job, error) + JobByID(ws memdb.WatchSet, id string) (*structs.Job, error) } // Planner interface is used to submit a task allocation plan. diff --git a/scheduler/system_sched.go b/scheduler/system_sched.go index f68de6b8f..755153d9c 100644 --- a/scheduler/system_sched.go +++ b/scheduler/system_sched.go @@ -4,6 +4,7 @@ import ( "fmt" "log" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" ) @@ -87,7 +88,8 @@ func (s *SystemScheduler) Process(eval *structs.Evaluation) error { func (s *SystemScheduler) process() (bool, error) { // Lookup the Job by ID var err error - s.job, err = s.state.JobByID(s.eval.JobID) + ws := memdb.NewWatchSet() + s.job, err = s.state.JobByID(ws, s.eval.JobID) if err != nil { return false, fmt.Errorf("failed to get job '%s': %v", s.eval.JobID, err) @@ -178,7 +180,8 @@ func (s *SystemScheduler) process() (bool, error) { // existing allocations and node status to update the allocations. 
func (s *SystemScheduler) computeJobAllocs() error { // Lookup the allocations by JobID - allocs, err := s.state.AllocsByJob(s.eval.JobID, true) + ws := memdb.NewWatchSet() + allocs, err := s.state.AllocsByJob(ws, s.eval.JobID, true) if err != nil { return fmt.Errorf("failed to get allocs for job '%s': %v", s.eval.JobID, err) diff --git a/scheduler/system_sched_test.go b/scheduler/system_sched_test.go index 438014e2d..313d573d3 100644 --- a/scheduler/system_sched_test.go +++ b/scheduler/system_sched_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" ) @@ -58,7 +59,8 @@ func TestSystemSched_JobRegister(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -182,7 +184,8 @@ func TestSystemSched_JobRegister_EphemeralDiskConstraint(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -205,7 +208,7 @@ func TestSystemSched_JobRegister_EphemeralDiskConstraint(t *testing.T) { t.Fatalf("err: %v", err) } - out, err = h1.State.AllocsByJob(job1.ID, false) + out, err = h1.State.AllocsByJob(ws, job1.ID, false) noErr(t, err) if len(out) != 0 { t.Fatalf("bad: %#v", out) @@ -319,7 +322,8 @@ func TestSystemSched_JobRegister_Annotate(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -430,7 +434,8 @@ func TestSystemSched_JobRegister_AddNode(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -558,7 +563,8 @@ func TestSystemSched_JobModify(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -746,7 +752,8 @@ func TestSystemSched_JobModify_InPlace(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure all allocations placed @@ -822,7 +829,8 @@ func TestSystemSched_JobDeregister(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure no remaining allocations @@ -1094,7 +1102,8 @@ func TestSystemSched_RetryLimit(t *testing.T) { } // Lookup the allocations by JobID - out, err := h.State.AllocsByJob(job.ID, false) + ws := memdb.NewWatchSet() + out, err := h.State.AllocsByJob(ws, job.ID, false) noErr(t, err) // Ensure no allocations placed diff --git a/scheduler/testing.go b/scheduler/testing.go index 08254eae3..74c01c486 100644 --- a/scheduler/testing.go +++ b/scheduler/testing.go @@ -7,6 +7,7 @@ import ( "sync" "testing" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/state" 
"github.com/hashicorp/nomad/nomad/structs" ) @@ -159,7 +160,8 @@ func (h *Harness) ReblockEval(eval *structs.Evaluation) error { defer h.planLock.Unlock() // Check that the evaluation was already blocked. - old, err := h.State.EvalByID(eval.ID) + ws := memdb.NewWatchSet() + old, err := h.State.EvalByID(ws, eval.ID) if err != nil { return err } diff --git a/scheduler/util.go b/scheduler/util.go index 1ed306b76..f305134cd 100644 --- a/scheduler/util.go +++ b/scheduler/util.go @@ -6,6 +6,7 @@ import ( "math/rand" "reflect" + memdb "github.com/hashicorp/go-memdb" "github.com/hashicorp/nomad/nomad/structs" ) @@ -228,8 +229,9 @@ func readyNodesInDCs(state State, dcs []string) ([]*structs.Node, map[string]int } // Scan the nodes + ws := memdb.NewWatchSet() var out []*structs.Node - iter, err := state.Nodes() + iter, err := state.Nodes(ws) if err != nil { return nil, nil, err } @@ -301,7 +303,8 @@ func taintedNodes(state State, allocs []*structs.Allocation) (map[string]*struct continue } - node, err := state.NodeByID(alloc.NodeID) + ws := memdb.NewWatchSet() + node, err := state.NodeByID(ws, alloc.NodeID) if err != nil { return nil, err } @@ -452,6 +455,7 @@ func setStatus(logger *log.Logger, planner Planner, func inplaceUpdate(ctx Context, eval *structs.Evaluation, job *structs.Job, stack Stack, updates []allocTuple) (destructive, inplace []allocTuple) { + ws := memdb.NewWatchSet() n := len(updates) inplaceCount := 0 for i := 0; i < n; i++ { @@ -471,7 +475,7 @@ func inplaceUpdate(ctx Context, eval *structs.Evaluation, job *structs.Job, } // Get the existing node - node, err := ctx.State().NodeByID(update.Alloc.NodeID) + node, err := ctx.State().NodeByID(ws, update.Alloc.NodeID) if err != nil { ctx.Logger().Printf("[ERR] sched: %#v failed to get node '%s': %v", eval, update.Alloc.NodeID, err) diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go index 8d26fc95f..1f63f769e 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go @@ -2,6 +2,7 @@ package iradix import ( "bytes" + "strings" "github.com/hashicorp/golang-lru/simplelru" ) @@ -11,7 +12,9 @@ const ( // cache used per transaction. This is used to cache the updates // to the nodes near the root, while the leaves do not need to be // cached. This is important for very large transactions to prevent - // the modified cache from growing to be enormous. + // the modified cache from growing to be enormous. This is also used + // to set the max size of the mutation notify maps since those should + // also be bounded in a similar way. defaultModifiedCache = 8192 ) @@ -27,7 +30,11 @@ type Tree struct { // New returns an empty Tree func New() *Tree { - t := &Tree{root: &Node{}} + t := &Tree{ + root: &Node{ + mutateCh: make(chan struct{}), + }, + } return t } @@ -40,75 +47,148 @@ func (t *Tree) Len() int { // atomically and returns a new tree when committed. A transaction // is not thread safe, and should only be used by a single goroutine. type Txn struct { - root *Node - size int - modified *simplelru.LRU + // root is the modified root for the transaction. + root *Node + + // snap is a snapshot of the root node for use if we have to run the + // slow notify algorithm. + snap *Node + + // size tracks the size of the tree as it is modified during the + // transaction. + size int + + // writable is a cache of writable nodes that have been created during + // the course of the transaction. 
This allows us to re-use the same + // nodes for further writes and avoid unnecessary copies of nodes that + // have never been exposed outside the transaction. This will only hold + // up to defaultModifiedCache number of entries. + writable *simplelru.LRU + + // trackChannels is used to hold channels that need to be notified to + // signal mutation of the tree. This will only hold up to + // defaultModifiedCache number of entries, after which we will set the + // trackOverflow flag, which will cause us to use a more expensive + // algorithm to perform the notifications. Mutation tracking is only + // performed if trackMutate is true. + trackChannels map[*chan struct{}]struct{} + trackOverflow bool + trackMutate bool } // Txn starts a new transaction that can be used to mutate the tree func (t *Tree) Txn() *Txn { txn := &Txn{ root: t.root, + snap: t.root, size: t.size, } return txn } -// writeNode returns a node to be modified, if the current -// node as already been modified during the course of -// the transaction, it is used in-place. -func (t *Txn) writeNode(n *Node) *Node { - // Ensure the modified set exists - if t.modified == nil { +// TrackMutate can be used to toggle if mutations are tracked. If this is enabled +// then notifications will be issued for affected internal nodes and leaves when +// the transaction is committed. +func (t *Txn) TrackMutate(track bool) { + t.trackMutate = track +} + +// trackChannel safely attempts to track the given mutation channel, setting the +// overflow flag if we can no longer track any more. This limits the amount of +// state that will accumulate during a transaction and we have a slower algorithm +// to switch to if we overflow. +func (t *Txn) trackChannel(ch *chan struct{}) { + // In overflow, make sure we don't store any more objects. + if t.trackOverflow { + return + } + + // Create the map on the fly when we need it. + if t.trackChannels == nil { + t.trackChannels = make(map[*chan struct{}]struct{}) + } + + // If this would overflow the state we reject it and set the flag (since + // we aren't tracking everything that's required any longer). + if len(t.trackChannels) >= defaultModifiedCache { + t.trackOverflow = true + return + } + + // Otherwise we are good to track it. + t.trackChannels[ch] = struct{}{} +} + +// writeNode returns a node to be modified, if the current node has already been +// modified during the course of the transaction, it is used in-place. Set +// forLeafUpdate to true if you are getting a write node to update the leaf, +// which will set leaf mutation tracking appropriately as well. +func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node { + // Ensure the writable set exists. + if t.writable == nil { lru, err := simplelru.NewLRU(defaultModifiedCache, nil) if err != nil { panic(err) } - t.modified = lru + t.writable = lru } - // If this node has already been modified, we can - // continue to use it during this transaction. - if _, ok := t.modified.Get(n); ok { + // If this node has already been modified, we can continue to use it + // during this transaction. If a node gets kicked out of cache then we + // *may* notify for its mutation if we end up copying the node again, + // but we don't make any guarantees about notifying for intermediate + // mutations that were never exposed outside of a transaction. + if _, ok := t.writable.Get(n); ok { return n } - // Copy the existing node - nc := new(Node) + // Mark this node as being mutated. 
+ if t.trackMutate { + t.trackChannel(&(n.mutateCh)) + } + + // Mark its leaf as being mutated, if appropriate. + if t.trackMutate && forLeafUpdate && n.leaf != nil { + t.trackChannel(&(n.leaf.mutateCh)) + } + + // Copy the existing node. + nc := &Node{ + mutateCh: make(chan struct{}), + leaf: n.leaf, + } if n.prefix != nil { nc.prefix = make([]byte, len(n.prefix)) copy(nc.prefix, n.prefix) } - if n.leaf != nil { - nc.leaf = new(leafNode) - *nc.leaf = *n.leaf - } if len(n.edges) != 0 { nc.edges = make([]edge, len(n.edges)) copy(nc.edges, n.edges) } - // Mark this node as modified - t.modified.Add(nc, nil) + // Mark this node as writable. + t.writable.Add(nc, nil) return nc } // insert does a recursive insertion func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface{}, bool) { - // Handle key exhaution + // Handle key exhaustion if len(search) == 0 { - nc := t.writeNode(n) + var oldVal interface{} + didUpdate := false if n.isLeaf() { - old := nc.leaf.val - nc.leaf.val = v - return nc, old, true - } else { - nc.leaf = &leafNode{ - key: k, - val: v, - } - return nc, nil, false + oldVal = n.leaf.val + didUpdate = true } + + nc := t.writeNode(n, true) + nc.leaf = &leafNode{ + mutateCh: make(chan struct{}), + key: k, + val: v, + } + return nc, oldVal, didUpdate } // Look for the edge @@ -119,14 +199,16 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface e := edge{ label: search[0], node: &Node{ + mutateCh: make(chan struct{}), leaf: &leafNode{ - key: k, - val: v, + mutateCh: make(chan struct{}), + key: k, + val: v, }, prefix: search, }, } - nc := t.writeNode(n) + nc := t.writeNode(n, false) nc.addEdge(e) return nc, nil, false } @@ -137,7 +219,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface search = search[commonPrefix:] newChild, oldVal, didUpdate := t.insert(child, k, search, v) if newChild != nil { - nc := t.writeNode(n) + nc := t.writeNode(n, false) nc.edges[idx].node = newChild return nc, oldVal, didUpdate } @@ -145,9 +227,10 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface } // Split the node - nc := t.writeNode(n) + nc := t.writeNode(n, false) splitNode := &Node{ - prefix: search[:commonPrefix], + mutateCh: make(chan struct{}), + prefix: search[:commonPrefix], } nc.replaceEdge(edge{ label: search[0], @@ -155,7 +238,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface }) // Restore the existing child node - modChild := t.writeNode(child) + modChild := t.writeNode(child, false) splitNode.addEdge(edge{ label: modChild.prefix[commonPrefix], node: modChild, @@ -164,8 +247,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface // Create a new leaf node leaf := &leafNode{ - key: k, - val: v, + mutateCh: make(chan struct{}), + key: k, + val: v, } // If the new key is a subset, add to to this node @@ -179,8 +263,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface splitNode.addEdge(edge{ label: search[0], node: &Node{ - leaf: leaf, - prefix: search, + mutateCh: make(chan struct{}), + leaf: leaf, + prefix: search, }, }) return nc, nil, false @@ -188,14 +273,14 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface // delete does a recursive deletion func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { - // Check for key exhaution + // Check for key exhaustion if len(search) == 0 { if !n.isLeaf() { return nil, nil } // Remove 
the leaf node - nc := t.writeNode(n) + nc := t.writeNode(n, true) nc.leaf = nil // Check if this node should be merged @@ -219,8 +304,11 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { return nil, nil } - // Copy this node - nc := t.writeNode(n) + // Copy this node. WATCH OUT - it's safe to pass "false" here because we + // will only ADD a leaf via nc.mergeChilde() if there isn't one due to + // the !nc.isLeaf() check in the logic just below. This is pretty subtle, + // so be careful if you change any of the logic here. + nc := t.writeNode(n, false) // Delete the edge if the node has no edges if newChild.leaf == nil && len(newChild.edges) == 0 { @@ -274,10 +362,109 @@ func (t *Txn) Get(k []byte) (interface{}, bool) { return t.root.Get(k) } -// Commit is used to finalize the transaction and return a new tree +// GetWatch is used to lookup a specific key, returning +// the watch channel, value and if it was found +func (t *Txn) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { + return t.root.GetWatch(k) +} + +// Commit is used to finalize the transaction and return a new tree. If mutation +// tracking is turned on then notifications will also be issued. func (t *Txn) Commit() *Tree { - t.modified = nil - return &Tree{t.root, t.size} + nt := t.commit() + if t.trackMutate { + t.notify() + } + return nt +} + +// commit is an internal helper for Commit(), useful for unit tests. +func (t *Txn) commit() *Tree { + nt := &Tree{t.root, t.size} + t.writable = nil + return nt +} + +// slowNotify does a complete comparison of the before and after trees in order +// to trigger notifications. This doesn't require any additional state but it +// is very expensive to compute. +func (t *Txn) slowNotify() { + snapIter := t.snap.rawIterator() + rootIter := t.root.rawIterator() + for snapIter.Front() != nil || rootIter.Front() != nil { + // If we've exhausted the nodes in the old snapshot, we know + // there's nothing remaining to notify. + if snapIter.Front() == nil { + return + } + snapElem := snapIter.Front() + + // If we've exhausted the nodes in the new root, we know we need + // to invalidate everything that remains in the old snapshot. We + // know from the loop condition there's something in the old + // snapshot. + if rootIter.Front() == nil { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // Do one string compare so we can check the various conditions + // below without repeating the compare. + cmp := strings.Compare(snapIter.Path(), rootIter.Path()) + + // If the snapshot is behind the root, then we must have deleted + // this node during the transaction. + if cmp < 0 { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // If the snapshot is ahead of the root, then we must have added + // this node during the transaction. + if cmp > 0 { + rootIter.Next() + continue + } + + // If we have the same path, then we need to see if we mutated a + // node and possibly the leaf. + rootElem := rootIter.Front() + if snapElem != rootElem { + close(snapElem.mutateCh) + if snapElem.leaf != nil && (snapElem.leaf != rootElem.leaf) { + close(snapElem.leaf.mutateCh) + } + } + snapIter.Next() + rootIter.Next() + } +} + +// notify is used along with TrackMutate to trigger notifications. This should +// only be done once a transaction is committed. 
+func (t *Txn) notify() { + // If we've overflowed the tracking state we can't use it in any way and + // need to do a full tree compare. + if t.trackOverflow { + t.slowNotify() + } else { + for ch := range t.trackChannels { + close(*ch) + } + } + + // Clean up the tracking state so that a re-notify is safe (will trigger + // the else clause above which will be a no-op). + t.trackChannels = nil + t.trackOverflow = false } // Insert is used to add or update a given key. The return provides diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iter.go b/vendor/github.com/hashicorp/go-immutable-radix/iter.go index 75cbaa110..9815e0253 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/iter.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/iter.go @@ -9,11 +9,13 @@ type Iterator struct { stack []edges } -// SeekPrefix is used to seek the iterator to a given prefix -func (i *Iterator) SeekPrefix(prefix []byte) { +// SeekPrefixWatch is used to seek the iterator to a given prefix +// and returns the watch channel of the finest granularity +func (i *Iterator) SeekPrefixWatch(prefix []byte) (watch <-chan struct{}) { // Wipe the stack i.stack = nil n := i.node + watch = n.mutateCh search := prefix for { // Check for key exhaution @@ -29,6 +31,9 @@ func (i *Iterator) SeekPrefix(prefix []byte) { return } + // Update to the finest granularity as the search makes progress + watch = n.mutateCh + // Consume the search prefix if bytes.HasPrefix(search, n.prefix) { search = search[len(n.prefix):] @@ -43,6 +48,11 @@ func (i *Iterator) SeekPrefix(prefix []byte) { } } +// SeekPrefix is used to seek the iterator to a given prefix +func (i *Iterator) SeekPrefix(prefix []byte) { + i.SeekPrefixWatch(prefix) +} + // Next returns the next node in order func (i *Iterator) Next() ([]byte, interface{}, bool) { // Initialize our stack if needed diff --git a/vendor/github.com/hashicorp/go-immutable-radix/node.go b/vendor/github.com/hashicorp/go-immutable-radix/node.go index fea6f6343..cf7137f93 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/node.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/node.go @@ -12,8 +12,9 @@ type WalkFn func(k []byte, v interface{}) bool // leafNode is used to represent a value type leafNode struct { - key []byte - val interface{} + mutateCh chan struct{} + key []byte + val interface{} } // edge is used to represent an edge node @@ -24,6 +25,9 @@ type edge struct { // Node is an immutable node in the radix tree type Node struct { + // mutateCh is closed if this node is modified + mutateCh chan struct{} + // leaf is used to store possible leaf leaf *leafNode @@ -105,13 +109,14 @@ func (n *Node) mergeChild() { } } -func (n *Node) Get(k []byte) (interface{}, bool) { +func (n *Node) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { search := k + watch := n.mutateCh for { - // Check for key exhaution + // Check for key exhaustion if len(search) == 0 { if n.isLeaf() { - return n.leaf.val, true + return n.leaf.mutateCh, n.leaf.val, true } break } @@ -122,6 +127,9 @@ func (n *Node) Get(k []byte) (interface{}, bool) { break } + // Update to the finest granularity as the search makes progress + watch = n.mutateCh + // Consume the search prefix if bytes.HasPrefix(search, n.prefix) { search = search[len(n.prefix):] @@ -129,7 +137,12 @@ func (n *Node) Get(k []byte) (interface{}, bool) { break } } - return nil, false + return watch, nil, false +} + +func (n *Node) Get(k []byte) (interface{}, bool) { + _, val, ok := n.GetWatch(k) + return val, ok } // 
LongestPrefix is like Get, but instead of an @@ -204,6 +217,14 @@ func (n *Node) Iterator() *Iterator { return &Iterator{node: n} } +// rawIterator is used to return a raw iterator at the given node to walk the +// tree. +func (n *Node) rawIterator() *rawIterator { + iter := &rawIterator{node: n} + iter.Next() + return iter +} + // Walk is used to walk the tree func (n *Node) Walk(fn WalkFn) { recursiveWalk(n, fn) diff --git a/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go new file mode 100644 index 000000000..04814c132 --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go @@ -0,0 +1,78 @@ +package iradix + +// rawIterator visits each of the nodes in the tree, even the ones that are not +// leaves. It keeps track of the effective path (what a leaf at a given node +// would be called), which is useful for comparing trees. +type rawIterator struct { + // node is the starting node in the tree for the iterator. + node *Node + + // stack keeps track of edges in the frontier. + stack []rawStackEntry + + // pos is the current position of the iterator. + pos *Node + + // path is the effective path of the current iterator position, + // regardless of whether the current node is a leaf. + path string +} + +// rawStackEntry is used to keep track of the cumulative common path as well as +// its associated edges in the frontier. +type rawStackEntry struct { + path string + edges edges +} + +// Front returns the current node that has been iterated to. +func (i *rawIterator) Front() *Node { + return i.pos +} + +// Path returns the effective path of the current node, even if it's not actually +// a leaf. +func (i *rawIterator) Path() string { + return i.path +} + +// Next advances the iterator to the next node. +func (i *rawIterator) Next() { + // Initialize our stack if needed. + if i.stack == nil && i.node != nil { + i.stack = []rawStackEntry{ + rawStackEntry{ + edges: edges{ + edge{node: i.node}, + }, + }, + } + } + + for len(i.stack) > 0 { + // Inspect the last element of the stack. + n := len(i.stack) + last := i.stack[n-1] + elem := last.edges[0].node + + // Update the stack. + if len(last.edges) > 1 { + i.stack[n-1].edges = last.edges[1:] + } else { + i.stack = i.stack[:n-1] + } + + // Push the edges onto the frontier. + if len(elem.edges) > 0 { + path := last.path + string(elem.prefix) + i.stack = append(i.stack, rawStackEntry{path, elem.edges}) + } + + i.pos = elem + i.path = last.path + string(elem.prefix) + return + } + + i.pos = nil + i.path = "" +} diff --git a/vendor/github.com/hashicorp/go-memdb/README.md b/vendor/github.com/hashicorp/go-memdb/README.md index 203a0af14..675044beb 100644 --- a/vendor/github.com/hashicorp/go-memdb/README.md +++ b/vendor/github.com/hashicorp/go-memdb/README.md @@ -19,7 +19,7 @@ The database provides the following: * Rich Indexing - Tables can support any number of indexes, which can be simple like a single field index, or more advanced compound field indexes. Certain types like - UUID can be efficiently compressed from strings into byte indexes for reduces + UUID can be efficiently compressed from strings into byte indexes for reduced storage requirements. For the underlying immutable radix trees, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix). 
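The vendored go-immutable-radix changes above are what make the WatchSet model work: every node and leaf now carries a `mutateCh` that a mutation-tracking transaction closes on commit. The sketch below exercises only the APIs introduced in this patch (`GetWatch`, `TrackMutate`, notify-on-`Commit`); it is an illustration of the mechanism under those assumptions, not code from the Nomad tree.

```go
package main

import (
	"fmt"

	iradix "github.com/hashicorp/go-immutable-radix"
)

func main() {
	// Build a tree with one key.
	r := iradix.New()
	r, _, _ = r.Insert([]byte("jobs/web"), 1)

	// Read through a transaction and capture the key's mutate channel.
	txn := r.Txn()
	watch, val, ok := txn.GetWatch([]byte("jobs/web"))
	fmt.Println(val, ok) // 1 true

	// Mutate the same key in a tracked transaction and commit; notify()
	// closes the channels of every node and leaf that was copied.
	txn2 := r.Txn()
	txn2.TrackMutate(true)
	txn2.Insert([]byte("jobs/web"), 2)
	_ = txn2.Commit()

	select {
	case <-watch:
		fmt.Println("watch fired")
	default:
		fmt.Println("watch not fired")
	}
}
```

go-memdb layers its WatchSet on exactly this: each table read collects the relevant `mutateCh` (or the channel returned by `SeekPrefixWatch` for prefix scans), and a write transaction's commit closes them, waking any blocked queries.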
diff --git a/vendor/github.com/hashicorp/go-memdb/index.go b/vendor/github.com/hashicorp/go-memdb/index.go index 7237f33e2..17aa02699 100644 --- a/vendor/github.com/hashicorp/go-memdb/index.go +++ b/vendor/github.com/hashicorp/go-memdb/index.go @@ -9,15 +9,27 @@ import ( // Indexer is an interface used for defining indexes type Indexer interface { - // FromObject is used to extract an index value from an - // object or to indicate that the index value is missing. - FromObject(raw interface{}) (bool, []byte, error) - // ExactFromArgs is used to build an exact index lookup // based on arguments FromArgs(args ...interface{}) ([]byte, error) } +// SingleIndexer is an interface used for defining indexes +// generating a single entry per object +type SingleIndexer interface { + // FromObject is used to extract an index value from an + // object or to indicate that the index value is missing. + FromObject(raw interface{}) (bool, []byte, error) +} + +// MultiIndexer is an interface used for defining indexes +// generating multiple entries per object +type MultiIndexer interface { + // FromObject is used to extract index values from an + // object or to indicate that the index value is missing. + FromObject(raw interface{}) (bool, [][]byte, error) +} + // PrefixIndexer can optionally be implemented for any // indexes that support prefix based iteration. This may // not apply to all indexes. @@ -88,6 +100,155 @@ func (s *StringFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { return val, nil } +// StringSliceFieldIndex is used to extract a field from an object +// using reflection and builds an index on that field. +type StringSliceFieldIndex struct { + Field string + Lowercase bool +} + +func (s *StringSliceFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(s.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) + } + + if fv.Kind() != reflect.Slice || fv.Type().Elem().Kind() != reflect.String { + return false, nil, fmt.Errorf("field '%s' is not a string slice", s.Field) + } + + length := fv.Len() + vals := make([][]byte, 0, length) + for i := 0; i < fv.Len(); i++ { + val := fv.Index(i).String() + if val == "" { + continue + } + + if s.Lowercase { + val = strings.ToLower(val) + } + + // Add the null character as a terminator + val += "\x00" + vals = append(vals, []byte(val)) + } + if len(vals) == 0 { + return false, nil, nil + } + return true, vals, nil +} + +func (s *StringSliceFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + arg, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + if s.Lowercase { + arg = strings.ToLower(arg) + } + // Add the null character as a terminator + arg += "\x00" + return []byte(arg), nil +} + +func (s *StringSliceFieldIndex) PrefixFromArgs(args ...interface{}) ([]byte, error) { + val, err := s.FromArgs(args...) + if err != nil { + return nil, err + } + + // Strip the null terminator, the rest is a prefix + n := len(val) + if n > 0 { + return val[:n-1], nil + } + return val, nil +} + +// StringMapFieldIndex is used to extract a field of type map[string]string +// from an object using reflection and builds an index on that field. 
+type StringMapFieldIndex struct { + Field string + Lowercase bool +} + +var MapType = reflect.MapOf(reflect.TypeOf(""), reflect.TypeOf("")).Kind() + +func (s *StringMapFieldIndex) FromObject(obj interface{}) (bool, [][]byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(s.Field) + if !fv.IsValid() { + return false, nil, fmt.Errorf("field '%s' for %#v is invalid", s.Field, obj) + } + + if fv.Kind() != MapType { + return false, nil, fmt.Errorf("field '%s' is not a map[string]string", s.Field) + } + + length := fv.Len() + vals := make([][]byte, 0, length) + for _, key := range fv.MapKeys() { + k := key.String() + if k == "" { + continue + } + val := fv.MapIndex(key).String() + + if s.Lowercase { + k = strings.ToLower(k) + val = strings.ToLower(val) + } + + // Add the null character as a terminator + k += "\x00" + val + "\x00" + + vals = append(vals, []byte(k)) + } + if len(vals) == 0 { + return false, nil, nil + } + return true, vals, nil +} + +func (s *StringMapFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) > 2 || len(args) == 0 { + return nil, fmt.Errorf("must provide one or two arguments") + } + key, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[0]) + } + if s.Lowercase { + key = strings.ToLower(key) + } + // Add the null character as a terminator + key += "\x00" + + if len(args) == 2 { + val, ok := args[1].(string) + if !ok { + return nil, fmt.Errorf("argument must be a string: %#v", args[1]) + } + if s.Lowercase { + val = strings.ToLower(val) + } + // Add the null character as a terminator + key += val + "\x00" + } + + return []byte(key), nil +} + // UUIDFieldIndex is used to extract a field from an object // using reflection and builds an index on that field by treating // it as a UUID. 
This is an optimization to using a StringFieldIndex @@ -270,7 +431,11 @@ type CompoundIndex struct { func (c *CompoundIndex) FromObject(raw interface{}) (bool, []byte, error) { var out []byte - for i, idx := range c.Indexes { + for i, idxRaw := range c.Indexes { + idx, ok := idxRaw.(SingleIndexer) + if !ok { + return false, nil, fmt.Errorf("sub-index %d error: %s", i, "sub-index must be a SingleIndexer") + } ok, val, err := idx.FromObject(raw) if err != nil { return false, nil, fmt.Errorf("sub-index %d error: %v", i, err) diff --git a/vendor/github.com/hashicorp/go-memdb/memdb.go b/vendor/github.com/hashicorp/go-memdb/memdb.go index 1d708517d..13817547b 100644 --- a/vendor/github.com/hashicorp/go-memdb/memdb.go +++ b/vendor/github.com/hashicorp/go-memdb/memdb.go @@ -15,6 +15,7 @@ import ( type MemDB struct { schema *DBSchema root unsafe.Pointer // *iradix.Tree underneath + primary bool // There can only be a single writter at once writer sync.Mutex @@ -31,6 +32,7 @@ func NewMemDB(schema *DBSchema) (*MemDB, error) { db := &MemDB{ schema: schema, root: unsafe.Pointer(iradix.New()), + primary: true, } if err := db.initialize(); err != nil { return nil, err @@ -65,6 +67,7 @@ func (db *MemDB) Snapshot() *MemDB { clone := &MemDB{ schema: db.schema, root: unsafe.Pointer(db.getRoot()), + primary: false, } return clone } diff --git a/vendor/github.com/hashicorp/go-memdb/schema.go b/vendor/github.com/hashicorp/go-memdb/schema.go index 2b8ffb476..d7210f91c 100644 --- a/vendor/github.com/hashicorp/go-memdb/schema.go +++ b/vendor/github.com/hashicorp/go-memdb/schema.go @@ -38,7 +38,7 @@ func (s *TableSchema) Validate() error { return fmt.Errorf("missing table name") } if len(s.Indexes) == 0 { - return fmt.Errorf("missing table schemas for '%s'", s.Name) + return fmt.Errorf("missing table indexes for '%s'", s.Name) } if _, ok := s.Indexes["id"]; !ok { return fmt.Errorf("must have id index") @@ -46,6 +46,9 @@ func (s *TableSchema) Validate() error { if !s.Indexes["id"].Unique { return fmt.Errorf("id index must be unique") } + if _, ok := s.Indexes["id"].Indexer.(SingleIndexer); !ok { + return fmt.Errorf("id index must be a SingleIndexer") + } for name, index := range s.Indexes { if name != index.Name { return fmt.Errorf("index name mis-match for '%s'", name) @@ -72,5 +75,11 @@ func (s *IndexSchema) Validate() error { if s.Indexer == nil { return fmt.Errorf("missing index function for '%s'", s.Name) } + switch s.Indexer.(type) { + case SingleIndexer: + case MultiIndexer: + default: + return fmt.Errorf("indexer for '%s' must be a SingleIndexer or MultiIndexer", s.Name) + } return nil } diff --git a/vendor/github.com/hashicorp/go-memdb/txn.go b/vendor/github.com/hashicorp/go-memdb/txn.go index 6228677da..a069a9fd9 100644 --- a/vendor/github.com/hashicorp/go-memdb/txn.go +++ b/vendor/github.com/hashicorp/go-memdb/txn.go @@ -70,6 +70,11 @@ func (txn *Txn) writableIndex(table, index string) *iradix.Txn { raw, _ := txn.rootTxn.Get(path) indexTxn := raw.(*iradix.Tree).Txn() + // If we are the primary DB, enable mutation tracking. Snapshots should + // not notify, otherwise we will trigger watches on the primary DB when + // the writes will not be visible. 
+ indexTxn.TrackMutate(txn.db.primary) + // Keep this open for the duration of the txn txn.modified[key] = indexTxn return indexTxn @@ -148,7 +153,8 @@ func (txn *Txn) Insert(table string, obj interface{}) error { // Get the primary ID of the object idSchema := tableSchema.Indexes[id] - ok, idVal, err := idSchema.Indexer.FromObject(obj) + idIndexer := idSchema.Indexer.(SingleIndexer) + ok, idVal, err := idIndexer.FromObject(obj) if err != nil { return fmt.Errorf("failed to build primary index: %v", err) } @@ -167,7 +173,19 @@ func (txn *Txn) Insert(table string, obj interface{}) error { indexTxn := txn.writableIndex(table, name) // Determine the new index value - ok, val, err := indexSchema.Indexer.FromObject(obj) + var ( + ok bool + vals [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var val []byte + ok, val, err = indexer.FromObject(obj) + vals = [][]byte{val} + case MultiIndexer: + ok, vals, err = indexer.FromObject(obj) + } if err != nil { return fmt.Errorf("failed to build index '%s': %v", name, err) } @@ -176,28 +194,44 @@ func (txn *Txn) Insert(table string, obj interface{}) error { // This is done by appending the primary key which must // be unique anyways. if ok && !indexSchema.Unique { - val = append(val, idVal...) + for i := range vals { + vals[i] = append(vals[i], idVal...) + } } // Handle the update by deleting from the index first if update { - okExist, valExist, err := indexSchema.Indexer.FromObject(existing) + var ( + okExist bool + valsExist [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var valExist []byte + okExist, valExist, err = indexer.FromObject(existing) + valsExist = [][]byte{valExist} + case MultiIndexer: + okExist, valsExist, err = indexer.FromObject(existing) + } if err != nil { return fmt.Errorf("failed to build index '%s': %v", name, err) } if okExist { - // Handle non-unique index by computing a unique index. - // This is done by appending the primary key which must - // be unique anyways. - if !indexSchema.Unique { - valExist = append(valExist, idVal...) - } + for i, valExist := range valsExist { + // Handle non-unique index by computing a unique index. + // This is done by appending the primary key which must + // be unique anyways. + if !indexSchema.Unique { + valExist = append(valExist, idVal...) + } - // If we are writing to the same index with the same value, - // we can avoid the delete as the insert will overwrite the - // value anyways. - if !bytes.Equal(valExist, val) { - indexTxn.Delete(valExist) + // If we are writing to the same index with the same value, + // we can avoid the delete as the insert will overwrite the + // value anyways. 
+ if i >= len(vals) || !bytes.Equal(valExist, vals[i]) { + indexTxn.Delete(valExist) + } } } } @@ -213,7 +247,9 @@ func (txn *Txn) Insert(table string, obj interface{}) error { } // Update the value of the index - indexTxn.Insert(val, obj) + for _, val := range vals { + indexTxn.Insert(val, obj) + } } return nil } @@ -233,7 +269,8 @@ func (txn *Txn) Delete(table string, obj interface{}) error { // Get the primary ID of the object idSchema := tableSchema.Indexes[id] - ok, idVal, err := idSchema.Indexer.FromObject(obj) + idIndexer := idSchema.Indexer.(SingleIndexer) + ok, idVal, err := idIndexer.FromObject(obj) if err != nil { return fmt.Errorf("failed to build primary index: %v", err) } @@ -253,7 +290,19 @@ func (txn *Txn) Delete(table string, obj interface{}) error { indexTxn := txn.writableIndex(table, name) // Handle the update by deleting from the index first - ok, val, err := indexSchema.Indexer.FromObject(existing) + var ( + ok bool + vals [][]byte + err error + ) + switch indexer := indexSchema.Indexer.(type) { + case SingleIndexer: + var val []byte + ok, val, err = indexer.FromObject(existing) + vals = [][]byte{val} + case MultiIndexer: + ok, vals, err = indexer.FromObject(existing) + } if err != nil { return fmt.Errorf("failed to build index '%s': %v", name, err) } @@ -261,10 +310,12 @@ func (txn *Txn) Delete(table string, obj interface{}) error { // Handle non-unique index by computing a unique index. // This is done by appending the primary key which must // be unique anyways. - if !indexSchema.Unique { - val = append(val, idVal...) + for _, val := range vals { + if !indexSchema.Unique { + val = append(val, idVal...) + } + indexTxn.Delete(val) } - indexTxn.Delete(val) } } return nil @@ -306,13 +357,13 @@ func (txn *Txn) DeleteAll(table, index string, args ...interface{}) (int, error) return num, nil } -// First is used to return the first matching object for -// the given constraints on the index -func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { +// FirstWatch is used to return the first matching object for +// the given constraints on the index along with the watch channel +func (txn *Txn) FirstWatch(table, index string, args ...interface{}) (<-chan struct{}, interface{}, error) { // Get the index value indexSchema, val, err := txn.getIndexValue(table, index, args...) if err != nil { - return nil, err + return nil, nil, err } // Get the index itself @@ -320,18 +371,25 @@ func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, er // Do an exact lookup if indexSchema.Unique && val != nil && indexSchema.Name == index { - obj, ok := indexTxn.Get(val) + watch, obj, ok := indexTxn.GetWatch(val) if !ok { - return nil, nil + return watch, nil, nil } - return obj, nil + return watch, obj, nil } // Handle non-unique index by using an iterator and getting the first value iter := indexTxn.Root().Iterator() - iter.SeekPrefix(val) + watch := iter.SeekPrefixWatch(val) _, value, _ := iter.Next() - return value, nil + return watch, value, nil +} + +// First is used to return the first matching object for +// the given constraints on the index +func (txn *Txn) First(table, index string, args ...interface{}) (interface{}, error) { + _, val, err := txn.FirstWatch(table, index, args...) 
+ return val, err } // LongestPrefix is used to fetch the longest prefix match for the given @@ -422,6 +480,7 @@ func (txn *Txn) getIndexValue(table, index string, args ...interface{}) (*IndexS // ResultIterator is used to iterate over a list of results // from a Get query on a table. type ResultIterator interface { + WatchCh() <-chan struct{} Next() interface{} } @@ -442,11 +501,12 @@ func (txn *Txn) Get(table, index string, args ...interface{}) (ResultIterator, e indexIter := indexRoot.Iterator() // Seek the iterator to the appropriate sub-set - indexIter.SeekPrefix(val) + watchCh := indexIter.SeekPrefixWatch(val) // Create an iterator iter := &radixIterator{ - iter: indexIter, + iter: indexIter, + watchCh: watchCh, } return iter, nil } @@ -460,10 +520,15 @@ func (txn *Txn) Defer(fn func()) { } // radixIterator is used to wrap an underlying iradix iterator. -// This is much mroe efficient than a sliceIterator as we are not +// This is much more efficient than a sliceIterator as we are not // materializing the entire view. type radixIterator struct { - iter *iradix.Iterator + iter *iradix.Iterator + watchCh <-chan struct{} +} + +func (r *radixIterator) WatchCh() <-chan struct{} { + return r.watchCh } func (r *radixIterator) Next() interface{} { diff --git a/vendor/github.com/hashicorp/go-memdb/watch.go b/vendor/github.com/hashicorp/go-memdb/watch.go new file mode 100644 index 000000000..7c4a3ba6e --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/watch.go @@ -0,0 +1,108 @@ +package memdb + +import "time" + +// WatchSet is a collection of watch channels. +type WatchSet map[<-chan struct{}]struct{} + +// NewWatchSet constructs a new watch set. +func NewWatchSet() WatchSet { + return make(map[<-chan struct{}]struct{}) +} + +// Add appends a watchCh to the WatchSet if non-nil. +func (w WatchSet) Add(watchCh <-chan struct{}) { + if w == nil { + return + } + + if _, ok := w[watchCh]; !ok { + w[watchCh] = struct{}{} + } +} + +// AddWithLimit appends a watchCh to the WatchSet if non-nil, and if the given +// softLimit hasn't been exceeded. Otherwise, it will watch the given alternate +// channel. It's expected that the altCh will be the same on many calls to this +// function, so you will exceed the soft limit a little bit if you hit this, but +// not by much. +// +// This is useful if you want to track individual items up to some limit, after +// which you watch a higher-level channel (usually a channel from start start of +// an iterator higher up in the radix tree) that will watch a superset of items. +func (w WatchSet) AddWithLimit(softLimit int, watchCh <-chan struct{}, altCh <-chan struct{}) { + // This is safe for a nil WatchSet so we don't need to check that here. + if len(w) < softLimit { + w.Add(watchCh) + } else { + w.Add(altCh) + } +} + +// Watch is used to wait for either the watch set to trigger or a timeout. +// Returns true on timeout. +func (w WatchSet) Watch(timeoutCh <-chan time.Time) bool { + if w == nil { + return false + } + + if n := len(w); n <= aFew { + idx := 0 + chunk := make([]<-chan struct{}, aFew) + for watchCh := range w { + chunk[idx] = watchCh + idx++ + } + return watchFew(chunk, timeoutCh) + } else { + return w.watchMany(timeoutCh) + } +} + +// watchMany is used if there are many watchers. +func (w WatchSet) watchMany(timeoutCh <-chan time.Time) bool { + // Make a fake timeout channel we can feed into watchFew to cancel all + // the blocking goroutines. + doneCh := make(chan time.Time) + defer close(doneCh) + + // Set up a goroutine for each watcher. 
+ triggerCh := make(chan struct{}, 1) + watcher := func(chunk []<-chan struct{}) { + if timeout := watchFew(chunk, doneCh); !timeout { + select { + case triggerCh <- struct{}{}: + default: + } + } + } + + // Apportion the watch channels into chunks we can feed into the + // watchFew helper. + idx := 0 + chunk := make([]<-chan struct{}, aFew) + for watchCh := range w { + subIdx := idx % aFew + chunk[subIdx] = watchCh + idx++ + + // Fire off this chunk and start a fresh one. + if idx%aFew == 0 { + go watcher(chunk) + chunk = make([]<-chan struct{}, aFew) + } + } + + // Make sure to watch any residual channels in the last chunk. + if idx%aFew != 0 { + go watcher(chunk) + } + + // Wait for a channel to trigger or timeout. + select { + case <-triggerCh: + return false + case <-timeoutCh: + return true + } +} diff --git a/vendor/github.com/hashicorp/go-memdb/watch_few.go b/vendor/github.com/hashicorp/go-memdb/watch_few.go new file mode 100644 index 000000000..f2bb19db1 --- /dev/null +++ b/vendor/github.com/hashicorp/go-memdb/watch_few.go @@ -0,0 +1,116 @@ +//go:generate sh -c "go run watch-gen/main.go >watch_few.go" +package memdb + +import( + "time" +) + +// aFew gives how many watchers this function is wired to support. You must +// always pass a full slice of this length, but unused channels can be nil. +const aFew = 32 + +// watchFew is used if there are only a few watchers as a performance +// optimization. +func watchFew(ch []<-chan struct{}, timeoutCh <-chan time.Time) bool { + select { + + case <-ch[0]: + return false + + case <-ch[1]: + return false + + case <-ch[2]: + return false + + case <-ch[3]: + return false + + case <-ch[4]: + return false + + case <-ch[5]: + return false + + case <-ch[6]: + return false + + case <-ch[7]: + return false + + case <-ch[8]: + return false + + case <-ch[9]: + return false + + case <-ch[10]: + return false + + case <-ch[11]: + return false + + case <-ch[12]: + return false + + case <-ch[13]: + return false + + case <-ch[14]: + return false + + case <-ch[15]: + return false + + case <-ch[16]: + return false + + case <-ch[17]: + return false + + case <-ch[18]: + return false + + case <-ch[19]: + return false + + case <-ch[20]: + return false + + case <-ch[21]: + return false + + case <-ch[22]: + return false + + case <-ch[23]: + return false + + case <-ch[24]: + return false + + case <-ch[25]: + return false + + case <-ch[26]: + return false + + case <-ch[27]: + return false + + case <-ch[28]: + return false + + case <-ch[29]: + return false + + case <-ch[30]: + return false + + case <-ch[31]: + return false + + case <-timeoutCh: + return true + } +} diff --git a/vendor/vendor.json b/vendor/vendor.json index fbca2b391..bdc8d1cce 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -672,16 +672,16 @@ "revision": "3142ddc1d627a166970ddd301bc09cb510c74edc" }, { - "checksumSHA1": "qmE9mO0WW6ALLpUU81rXDyspP5M=", + "checksumSHA1": "jPxyofQxI1PRPq6LPc6VlcRn5fI=", "path": "github.com/hashicorp/go-immutable-radix", - "revision": "afc5a0dbb18abdf82c277a7bc01533e81fa1d6b8", - "revisionTime": "2016-06-09T02:05:29Z" + "revision": "76b5f4e390910df355bfb9b16b41899538594a05", + "revisionTime": "2017-01-13T02:29:29Z" }, { - "checksumSHA1": "/V57CyN7x2NUlHoOzVL5GgGXX84=", + "checksumSHA1": "K8Fsgt1llTXP0EwqdBzvSGdKOKc=", "path": "github.com/hashicorp/go-memdb", - "revision": "98f52f52d7a476958fa9da671354d270c50661a7", - "revisionTime": "2016-03-01T23:01:42Z" + "revision": "c01f56b44823e8ba697e23c18d12dca984b85aca", + "revisionTime": "2017-01-23T15:32:28Z" }, { 
"path": "github.com/hashicorp/go-msgpack/codec",