diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index a88d24c0a..dd610a6d2 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -116,16 +116,19 @@ func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIV queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, run: func(ws memdb.WatchSet, state *state.StateStore) error { + snap, err := state.Snapshot() + if err != nil { + return err + } // Query all volumes - var err error var iter memdb.ResultIterator if args.NodeID != "" { - iter, err = state.CSIVolumesByNodeID(ws, args.NodeID) + iter, err = snap.CSIVolumesByNodeID(ws, args.NodeID) } else if args.PluginID != "" { - iter, err = state.CSIVolumesByPluginID(ws, ns, args.PluginID) + iter, err = snap.CSIVolumesByPluginID(ws, ns, args.PluginID) } else { - iter, err = state.CSIVolumesByNamespace(ws, ns) + iter, err = snap.CSIVolumesByNamespace(ws, ns) } if err != nil { @@ -140,23 +143,25 @@ func (v *CSIVolume) List(args *structs.CSIVolumeListRequest, reply *structs.CSIV if raw == nil { break } - vol := raw.(*structs.CSIVolume) - vol, err := state.CSIVolumeDenormalizePlugins(ws, vol.Copy()) - if err != nil { - return err - } - // Remove (possibly again) by PluginID to handle passing both NodeID and PluginID + // Remove (possibly again) by PluginID to handle passing both + // NodeID and PluginID if args.PluginID != "" && args.PluginID != vol.PluginID { continue } - // Remove by Namespace, since CSIVolumesByNodeID hasn't used the Namespace yet + // Remove by Namespace, since CSIVolumesByNodeID hasn't used + // the Namespace yet if vol.Namespace != ns { continue } + vol, err := snap.CSIVolumeDenormalizePlugins(ws, vol.Copy()) + if err != nil { + return err + } + vs = append(vs, vol.Stub()) } reply.Volumes = vs @@ -195,12 +200,17 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, run: func(ws memdb.WatchSet, state *state.StateStore) error { - vol, err := state.CSIVolumeByID(ws, ns, args.ID) + snap, err := state.Snapshot() + if err != nil { + return err + } + + vol, err := snap.CSIVolumeByID(ws, ns, args.ID) if err != nil { return err } if vol != nil { - vol, err = state.CSIVolumeDenormalize(ws, vol) + vol, err = snap.CSIVolumeDenormalize(ws, vol) } if err != nil { return err @@ -214,9 +224,8 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol func (v *CSIVolume) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) { state := v.srv.fsm.State() - ws := memdb.NewWatchSet() - plugin, err := state.CSIPluginByID(ws, vol.PluginID) + plugin, err := state.CSIPluginByID(nil, vol.PluginID) if err != nil { return nil, err } @@ -481,9 +490,7 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlugin, *structs.CSIVolume, error) { state := v.srv.fsm.State() - ws := memdb.NewWatchSet() - - vol, err := state.CSIVolumeByID(ws, namespace, volID) + vol, err := state.CSIVolumeByID(nil, namespace, volID) if err != nil { return nil, nil, err } @@ -497,7 +504,7 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu // note: we do this same lookup in CSIVolumeByID but then throw // away the pointer to the plugin rather than attaching it to // the volume so we have to do it again here. - plug, err := state.CSIPluginByID(ws, vol.PluginID) + plug, err := state.CSIPluginByID(nil, vol.PluginID) if err != nil { return nil, nil, err } @@ -870,7 +877,12 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu queryOpts: &args.QueryOptions, queryMeta: &reply.QueryMeta, run: func(ws memdb.WatchSet, state *state.StateStore) error { - plug, err := state.CSIPluginByID(ws, args.ID) + snap, err := state.Snapshot() + if err != nil { + return err + } + + plug, err := snap.CSIPluginByID(ws, args.ID) if err != nil { return err } @@ -880,7 +892,7 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu } if withAllocs { - plug, err = state.CSIPluginDenormalize(ws, plug.Copy()) + plug, err = snap.CSIPluginDenormalize(ws, plug.Copy()) if err != nil { return err } diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 977eec5bf..9e5059407 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -72,7 +72,7 @@ type StateStore struct { // abandoned (usually during a restore). This is only ever closed. abandonCh chan struct{} - // TODO: refactor abondonCh to use a context so that both can use the same + // TODO: refactor abandonCh to use a context so that both can use the same // cancel mechanism. stopEventBroker func() } @@ -1272,7 +1272,7 @@ func deleteNodeCSIPlugins(txn *txn, node *structs.Node, index uint64) error { } // updateOrGCPlugin updates a plugin but will delete it if the plugin is empty -func updateOrGCPlugin(index uint64, txn *txn, plug *structs.CSIPlugin) error { +func updateOrGCPlugin(index uint64, txn Txn, plug *structs.CSIPlugin) error { plug.ModifyIndex = index if plug.IsEmpty() { @@ -1291,7 +1291,7 @@ func updateOrGCPlugin(index uint64, txn *txn, plug *structs.CSIPlugin) error { // deleteJobFromPlugins removes the allocations of this job from any plugins the job is // running, possibly deleting the plugin if it's no longer in use. It's called in DeleteJobTxn -func (s *StateStore) deleteJobFromPlugins(index uint64, txn *txn, job *structs.Job) error { +func (s *StateStore) deleteJobFromPlugins(index uint64, txn Txn, job *structs.Job) error { ws := memdb.NewWatchSet() summary, err := s.JobSummaryByID(ws, job.Namespace, job.ID) if err != nil { @@ -1348,7 +1348,7 @@ func (s *StateStore) deleteJobFromPlugins(index uint64, txn *txn, job *structs.J plug, ok := plugins[x.pluginID] if !ok { - plug, err = s.CSIPluginByID(ws, x.pluginID) + plug, err = s.CSIPluginByIDTxn(txn, nil, x.pluginID) if err != nil { return fmt.Errorf("error getting plugin: %s, %v", x.pluginID, err) } @@ -1826,22 +1826,20 @@ func (s *StateStore) JobsByIDPrefix(ws memdb.WatchSet, namespace, id string) (me func (s *StateStore) JobVersionsByID(ws memdb.WatchSet, namespace, id string) ([]*structs.Job, error) { txn := s.db.ReadTxn() - return s.jobVersionByID(txn, &ws, namespace, id) + return s.jobVersionByID(txn, ws, namespace, id) } // jobVersionByID is the underlying implementation for retrieving all tracked // versions of a job and is called under an existing transaction. A watch set // can optionally be passed in to add the job histories to the watch set. -func (s *StateStore) jobVersionByID(txn *txn, ws *memdb.WatchSet, namespace, id string) ([]*structs.Job, error) { +func (s *StateStore) jobVersionByID(txn *txn, ws memdb.WatchSet, namespace, id string) ([]*structs.Job, error) { // Get all the historic jobs for this ID iter, err := txn.Get("job_version", "id_prefix", namespace, id) if err != nil { return nil, err } - if ws != nil { - ws.Add(iter.WatchCh()) - } + ws.Add(iter.WatchCh()) var all []*structs.Job for { @@ -1884,9 +1882,7 @@ func (s *StateStore) jobByIDAndVersionImpl(ws memdb.WatchSet, namespace, id stri return nil, err } - if ws != nil { - ws.Add(watchCh) - } + ws.Add(watchCh) if existing != nil { job := existing.(*structs.Job) @@ -2096,7 +2092,8 @@ func (s *StateStore) CSIVolumeRegister(index uint64, volumes []*structs.CSIVolum return txn.Commit() } -// CSIVolumes returns the unfiltered list of all volumes +// CSIVolumes returns the unfiltered list of all volumes. Caller should +// snapshot if it wants to also denormalize the plugins. func (s *StateStore) CSIVolumes(ws memdb.WatchSet) (memdb.ResultIterator, error) { txn := s.db.ReadTxn() defer txn.Abort() @@ -2111,8 +2108,9 @@ func (s *StateStore) CSIVolumes(ws memdb.WatchSet) (memdb.ResultIterator, error) return iter, nil } -// CSIVolumeByID is used to lookup a single volume. Returns a copy of the volume -// because its plugins are denormalized to provide accurate Health. +// CSIVolumeByID is used to lookup a single volume. Returns a copy of the +// volume because its plugins and allocations are denormalized to provide +// accurate Health. func (s *StateStore) CSIVolumeByID(ws memdb.WatchSet, namespace, id string) (*structs.CSIVolume, error) { txn := s.db.ReadTxn() @@ -2120,17 +2118,21 @@ func (s *StateStore) CSIVolumeByID(ws memdb.WatchSet, namespace, id string) (*st if err != nil { return nil, fmt.Errorf("volume lookup failed: %s %v", id, err) } + ws.Add(watchCh) if obj == nil { return nil, nil } + // we return the volume with the plugins denormalized by default, + // because the scheduler needs them for feasibility checking vol := obj.(*structs.CSIVolume) - return s.CSIVolumeDenormalizePlugins(ws, vol.Copy()) + return s.CSIVolumeDenormalizePluginsTxn(txn, vol.Copy()) } -// CSIVolumes looks up csi_volumes by pluginID +// CSIVolumes looks up csi_volumes by pluginID. Caller should snapshot if it +// wants to also denormalize the plugins. func (s *StateStore) CSIVolumesByPluginID(ws memdb.WatchSet, namespace, pluginID string) (memdb.ResultIterator, error) { txn := s.db.ReadTxn() @@ -2152,7 +2154,8 @@ func (s *StateStore) CSIVolumesByPluginID(ws memdb.WatchSet, namespace, pluginID return wrap, nil } -// CSIVolumesByIDPrefix supports search +// CSIVolumesByIDPrefix supports search. Caller should snapshot if it wants to +// also denormalize the plugins. func (s *StateStore) CSIVolumesByIDPrefix(ws memdb.WatchSet, namespace, volumeID string) (memdb.ResultIterator, error) { txn := s.db.ReadTxn() @@ -2162,10 +2165,12 @@ func (s *StateStore) CSIVolumesByIDPrefix(ws memdb.WatchSet, namespace, volumeID } ws.Add(iter.WatchCh()) + return iter, nil } -// CSIVolumesByNodeID looks up CSIVolumes in use on a node +// CSIVolumesByNodeID looks up CSIVolumes in use on a node. Caller should +// snapshot if it wants to also denormalize the plugins. func (s *StateStore) CSIVolumesByNodeID(ws memdb.WatchSet, nodeID string) (memdb.ResultIterator, error) { allocs, err := s.AllocsByNode(ws, nodeID) if err != nil { @@ -2202,6 +2207,8 @@ func (s *StateStore) CSIVolumesByNodeID(ws memdb.WatchSet, nodeID string) (memdb iter.Add(raw) } + ws.Add(iter.WatchCh()) + return iter, nil } @@ -2213,6 +2220,7 @@ func (s *StateStore) CSIVolumesByNamespace(ws memdb.WatchSet, namespace string) if err != nil { return nil, fmt.Errorf("volume lookup failed: %v", err) } + ws.Add(iter.WatchCh()) return iter, nil @@ -2222,7 +2230,6 @@ func (s *StateStore) CSIVolumesByNamespace(ws memdb.WatchSet, namespace string) func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *structs.CSIVolumeClaim) error { txn := s.db.WriteTxn(index) defer txn.Abort() - ws := memdb.NewWatchSet() row, err := txn.First("csi_volumes", "id", namespace, id) if err != nil { @@ -2239,7 +2246,7 @@ func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *s var alloc *structs.Allocation if claim.State == structs.CSIVolumeClaimStateTaken { - alloc, err = s.AllocByID(ws, claim.AllocationID) + alloc, err = s.allocByIDImpl(txn, nil, claim.AllocationID) if err != nil { s.logger.Error("AllocByID failed", "error", err) return fmt.Errorf(structs.ErrUnknownAllocationPrefix) @@ -2252,12 +2259,11 @@ func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *s } } - volume, err := s.CSIVolumeDenormalizePlugins(ws, orig.Copy()) + volume, err := s.CSIVolumeDenormalizePluginsTxn(txn, orig.Copy()) if err != nil { return err } - - volume, err = s.CSIVolumeDenormalize(ws, volume) + volume, err = s.CSIVolumeDenormalizeTxn(txn, nil, volume) if err != nil { return err } @@ -2321,7 +2327,7 @@ func (s *StateStore) CSIVolumeDeregister(index uint64, namespace string, ids []s // allocations have been stopped but claims can't be freed because // ex. the plugins have all been removed. if vol.InUse() { - if !force || !s.volSafeToForce(vol) { + if !force || !s.volSafeToForce(txn, vol) { return fmt.Errorf("volume in use: %s", id) } } @@ -2340,9 +2346,8 @@ func (s *StateStore) CSIVolumeDeregister(index uint64, namespace string, ids []s // volSafeToForce checks if the any of the remaining allocations // are in a non-terminal state. -func (s *StateStore) volSafeToForce(v *structs.CSIVolume) bool { - ws := memdb.NewWatchSet() - vol, err := s.CSIVolumeDenormalize(ws, v) +func (s *StateStore) volSafeToForce(txn Txn, v *structs.CSIVolume) bool { + vol, err := s.CSIVolumeDenormalizeTxn(txn, nil, v) if err != nil { return false } @@ -2360,19 +2365,30 @@ func (s *StateStore) volSafeToForce(v *structs.CSIVolume) bool { return true } -// CSIVolumeDenormalizePlugins returns a CSIVolume with current health and plugins, but -// without allocations -// Use this for current volume metadata, handling lists of volumes -// Use CSIVolumeDenormalize for volumes containing both health and current allocations +// CSIVolumeDenormalizePlugins returns a CSIVolume with current health and +// plugins, but without allocations. +// Use this for current volume metadata, handling lists of volumes. +// Use CSIVolumeDenormalize for volumes containing both health and current +// allocations. func (s *StateStore) CSIVolumeDenormalizePlugins(ws memdb.WatchSet, vol *structs.CSIVolume) (*structs.CSIVolume, error) { if vol == nil { return nil, nil } - // Lookup CSIPlugin, the health records, and calculate volume health txn := s.db.ReadTxn() defer txn.Abort() + return s.CSIVolumeDenormalizePluginsTxn(txn, vol) +} - plug, err := s.CSIPluginByID(ws, vol.PluginID) +// CSIVolumeDenormalizePluginsTxn returns a CSIVolume with current health and +// plugins, but without allocations. +// Use this for current volume metadata, handling lists of volumes. +// Use CSIVolumeDenormalize for volumes containing both health and current +// allocations. +func (s *StateStore) CSIVolumeDenormalizePluginsTxn(txn Txn, vol *structs.CSIVolume) (*structs.CSIVolume, error) { + if vol == nil { + return nil, nil + } + plug, err := s.CSIPluginByIDTxn(txn, nil, vol.PluginID) if err != nil { return nil, fmt.Errorf("plugin lookup error: %s %v", vol.PluginID, err) } @@ -2403,8 +2419,17 @@ func (s *StateStore) CSIVolumeDenormalizePlugins(ws memdb.WatchSet, vol *structs // CSIVolumeDenormalize returns a CSIVolume with allocations func (s *StateStore) CSIVolumeDenormalize(ws memdb.WatchSet, vol *structs.CSIVolume) (*structs.CSIVolume, error) { + txn := s.db.ReadTxn() + return s.CSIVolumeDenormalizeTxn(txn, ws, vol) +} + +// CSIVolumeDenormalizeTxn populates a CSIVolume with allocations +func (s *StateStore) CSIVolumeDenormalizeTxn(txn Txn, ws memdb.WatchSet, vol *structs.CSIVolume) (*structs.CSIVolume, error) { + if vol == nil { + return nil, nil + } for id := range vol.ReadAllocs { - a, err := s.AllocByID(ws, id) + a, err := s.allocByIDImpl(txn, ws, id) if err != nil { return nil, err } @@ -2425,7 +2450,7 @@ func (s *StateStore) CSIVolumeDenormalize(ws memdb.WatchSet, vol *structs.CSIVol } for id := range vol.WriteAllocs { - a, err := s.AllocByID(ws, id) + a, err := s.allocByIDImpl(txn, ws, id) if err != nil { return nil, err } @@ -2474,27 +2499,40 @@ func (s *StateStore) CSIPluginsByIDPrefix(ws memdb.WatchSet, pluginID string) (m return iter, nil } -// CSIPluginByID returns the one named CSIPlugin +// CSIPluginByID returns a named CSIPlugin. This method creates a new +// transaction so you should not call it from within another transaction. func (s *StateStore) CSIPluginByID(ws memdb.WatchSet, id string) (*structs.CSIPlugin, error) { txn := s.db.ReadTxn() - defer txn.Abort() + plugin, err := s.CSIPluginByIDTxn(txn, ws, id) + if err != nil { + return nil, err + } + return plugin, nil +} - raw, err := txn.First("csi_plugins", "id_prefix", id) +// CSIPluginByIDTxn returns a named CSIPlugin +func (s *StateStore) CSIPluginByIDTxn(txn Txn, ws memdb.WatchSet, id string) (*structs.CSIPlugin, error) { + + watchCh, obj, err := txn.FirstWatch("csi_plugins", "id_prefix", id) if err != nil { return nil, fmt.Errorf("csi_plugin lookup failed: %s %v", id, err) } - if raw == nil { - return nil, nil + ws.Add(watchCh) + + if obj != nil { + return obj.(*structs.CSIPlugin), nil } - - plug := raw.(*structs.CSIPlugin) - - return plug, nil + return nil, nil } // CSIPluginDenormalize returns a CSIPlugin with allocation details. Always called on a copy of the plugin. func (s *StateStore) CSIPluginDenormalize(ws memdb.WatchSet, plug *structs.CSIPlugin) (*structs.CSIPlugin, error) { + txn := s.db.ReadTxn() + return s.CSIPluginDenormalizeTxn(txn, ws, plug) +} + +func (s *StateStore) CSIPluginDenormalizeTxn(txn Txn, ws memdb.WatchSet, plug *structs.CSIPlugin) (*structs.CSIPlugin, error) { if plug == nil { return nil, nil } @@ -2509,7 +2547,7 @@ func (s *StateStore) CSIPluginDenormalize(ws memdb.WatchSet, plug *structs.CSIPl } for id := range ids { - alloc, err := s.AllocByID(ws, id) + alloc, err := s.allocByIDImpl(txn, ws, id) if err != nil { return nil, err } @@ -2553,9 +2591,8 @@ func (s *StateStore) UpsertCSIPlugin(index uint64, plug *structs.CSIPlugin) erro func (s *StateStore) DeleteCSIPlugin(index uint64, id string) error { txn := s.db.WriteTxn(index) defer txn.Abort() - ws := memdb.NewWatchSet() - plug, err := s.CSIPluginByID(ws, id) + plug, err := s.CSIPluginByIDTxn(txn, nil, id) if err != nil { return err } @@ -2564,7 +2601,7 @@ func (s *StateStore) DeleteCSIPlugin(index uint64, id string) error { return nil } - plug, err = s.CSIPluginDenormalize(ws, plug.Copy()) + plug, err = s.CSIPluginDenormalizeTxn(txn, nil, plug.Copy()) if err != nil { return err } @@ -3307,18 +3344,25 @@ func (s *StateStore) nestedUpdateAllocDesiredTransition( // AllocByID is used to lookup an allocation by its ID func (s *StateStore) AllocByID(ws memdb.WatchSet, id string) (*structs.Allocation, error) { txn := s.db.ReadTxn() + return s.allocByIDImpl(txn, ws, id) +} - watchCh, existing, err := txn.FirstWatch("allocs", "id", id) +// allocByIDImpl retrives an allocation and is called under and existing +// transaction. An optional watch set can be passed to add allocations to the +// watch set +func (s *StateStore) allocByIDImpl(txn Txn, ws memdb.WatchSet, id string) (*structs.Allocation, error) { + watchCh, raw, err := txn.FirstWatch("allocs", "id", id) if err != nil { return nil, fmt.Errorf("alloc lookup failed: %v", err) } ws.Add(watchCh) - if existing != nil { - return existing.(*structs.Allocation), nil + if raw == nil { + return nil, nil } - return nil, nil + alloc := raw.(*structs.Allocation) + return alloc, nil } // AllocsByIDPrefix is used to lookup allocs by prefix @@ -4613,7 +4657,6 @@ func (s *StateStore) updateJobScalingPolicies(index uint64, job *structs.Job, tx // updateJobCSIPlugins runs on job update, and indexes the job in the plugin func (s *StateStore) updateJobCSIPlugins(index uint64, job, prev *structs.Job, txn *txn) error { - ws := memdb.NewWatchSet() plugIns := make(map[string]*structs.CSIPlugin) loop := func(job *structs.Job, delete bool) error { @@ -4625,7 +4668,7 @@ func (s *StateStore) updateJobCSIPlugins(index uint64, job, prev *structs.Job, t plugIn, ok := plugIns[t.CSIPluginConfig.ID] if !ok { - p, err := s.CSIPluginByID(ws, t.CSIPluginConfig.ID) + p, err := s.CSIPluginByIDTxn(txn, nil, t.CSIPluginConfig.ID) if err != nil { return err } @@ -4909,12 +4952,11 @@ func (s *StateStore) updatePluginWithAlloc(index uint64, alloc *structs.Allocati return nil } - ws := memdb.NewWatchSet() tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup) for _, t := range tg.Tasks { if t.CSIPluginConfig != nil { pluginID := t.CSIPluginConfig.ID - plug, err := s.CSIPluginByID(ws, pluginID) + plug, err := s.CSIPluginByIDTxn(txn, nil, pluginID) if err != nil { return err } @@ -4943,7 +4985,6 @@ func (s *StateStore) updatePluginWithAlloc(index uint64, alloc *structs.Allocati func (s *StateStore) updatePluginWithJobSummary(index uint64, summary *structs.JobSummary, alloc *structs.Allocation, txn *txn) error { - ws := memdb.NewWatchSet() tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup) if tg == nil { return nil @@ -4952,7 +4993,7 @@ func (s *StateStore) updatePluginWithJobSummary(index uint64, summary *structs.J for _, t := range tg.Tasks { if t.CSIPluginConfig != nil { pluginID := t.CSIPluginConfig.ID - plug, err := s.CSIPluginByID(ws, pluginID) + plug, err := s.CSIPluginByIDTxn(txn, nil, pluginID) if err != nil { return err } diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index 982080b21..76bd0ebb3 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -3520,6 +3520,84 @@ func TestStateStore_CSIPluginMultiNodeUpdates(t *testing.T) { } +// TestStateStore_CSIPlugin_ConcurrentStop tests that concurrent allocation +// updates don't cause the count to drift unexpectedly or cause allocation +// update errors. +func TestStateStore_CSIPlugin_ConcurrentStop(t *testing.T) { + t.Parallel() + index := uint64(999) + state := testStateStore(t) + ws := memdb.NewWatchSet() + + var err error + + // Create Nomad client Nodes + ns := []*structs.Node{mock.Node(), mock.Node(), mock.Node()} + for _, n := range ns { + index++ + err = state.UpsertNode(structs.MsgTypeTestSetup, index, n) + require.NoError(t, err) + } + + plugID := "foo" + plugCfg := &structs.TaskCSIPluginConfig{ID: plugID} + + allocs := []*structs.Allocation{} + + // Fingerprint 3 running node plugins and their allocs + for _, n := range ns[:] { + alloc := mock.Alloc() + n, _ := state.NodeByID(ws, n.ID) + n.CSINodePlugins = map[string]*structs.CSIInfo{ + plugID: { + PluginID: plugID, + AllocID: alloc.ID, + Healthy: true, + UpdateTime: time.Now(), + RequiresControllerPlugin: true, + RequiresTopologies: false, + NodeInfo: &structs.CSINodeInfo{}, + }, + } + index++ + err = state.UpsertNode(structs.MsgTypeTestSetup, index, n) + require.NoError(t, err) + + alloc.NodeID = n.ID + alloc.DesiredStatus = "run" + alloc.ClientStatus = "running" + alloc.Job.TaskGroups[0].Tasks[0].CSIPluginConfig = plugCfg + + index++ + err = state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}) + require.NoError(t, err) + + allocs = append(allocs, alloc) + } + + plug, err := state.CSIPluginByID(ws, plugID) + require.NoError(t, err) + require.Equal(t, 3, plug.NodesHealthy, "nodes healthy") + require.Equal(t, 3, len(plug.Nodes), "nodes expected") + + // stop all the allocs + for _, alloc := range allocs { + alloc.DesiredStatus = "stop" + alloc.ClientStatus = "complete" + } + + // this is somewhat artificial b/c we get alloc updates from multiple + // nodes concurrently but not in a single RPC call. But this guarantees + // we'll trigger any nested transaction setup bugs + index++ + err = state.UpsertAllocs(structs.MsgTypeTestSetup, index, allocs) + require.NoError(t, err) + + plug, err = state.CSIPluginByID(ws, plugID) + require.NoError(t, err) + require.Nil(t, plug) +} + func TestStateStore_CSIPluginJobs(t *testing.T) { s := testStateStore(t) index := uint64(1001)