diff --git a/.changelog/26831.txt b/.changelog/26831.txt new file mode 100644 index 000000000..7b5e8b14e --- /dev/null +++ b/.changelog/26831.txt @@ -0,0 +1,3 @@ +```release-note:bug +csi: Fixed a bug where volumes could be unmounted while in use by a task that was shutting down +``` diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 49c44fee1..d59e63493 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -2918,6 +2918,7 @@ func (s *StateStore) volSafeToForce(txn Txn, v *structs.CSIVolume) bool { } for _, alloc := range vol.ReadAllocs { + // note we check that both server and client agree on terminal status if alloc != nil && !alloc.TerminalStatus() { return false } @@ -3029,7 +3030,7 @@ func (s *StateStore) csiVolumeDenormalizeTxn(txn Txn, ws memdb.WatchSet, vol *st } currentAllocs[id] = a - if (a == nil || a.TerminalStatus()) && pastClaim == nil { + if (a == nil || a.ClientTerminalStatus()) && pastClaim == nil { // the alloc is garbage collected but nothing has written a PastClaim, // so create one now pastClaim = &structs.CSIVolumeClaim{ diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index 5774c6396..492062077 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -3680,6 +3680,8 @@ func TestStateStore_CSIVolume(t *testing.T) { must.NoError(t, err) vs = slurp(iter) must.False(t, vs[0].HasFreeWriteClaims()) + must.MapLen(t, 1, vs[0].ReadClaims) + must.MapLen(t, 0, vs[0].PastClaims) claim2 := new(structs.CSIVolumeClaim) *claim2 = *claim0 @@ -3692,7 +3694,20 @@ func TestStateStore_CSIVolume(t *testing.T) { vs = slurp(iter) must.True(t, vs[0].ReadSchedulable()) - // deregistration is an error when the volume is in use + // alloc finishes, so we should see a past claim + a0 = a0.Copy() + a0.ClientStatus = structs.AllocClientStatusComplete + index++ + err = state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{a0}) + must.NoError(t, err) + + v0, err = state.CSIVolumeByID(nil, ns, vol0) + must.NoError(t, err) + must.MapLen(t, 1, v0.ReadClaims) + must.MapLen(t, 1, v0.PastClaims) + + // but until this claim is freed the volume is in use, so deregistration is + // still an error index++ err = state.CSIVolumeDeregister(index, ns, []string{vol0}, false) must.Error(t, err, must.Sprint("volume deregistered while in use")) diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go index 114aa7aa3..4f29ce809 100644 --- a/nomad/volumewatcher/volume_watcher.go +++ b/nomad/volumewatcher/volume_watcher.go @@ -184,10 +184,6 @@ func (vw *volumeWatcher) volumeReap(vol *structs.CSIVolume) { } } -func (vw *volumeWatcher) isUnclaimed(vol *structs.CSIVolume) bool { - return len(vol.ReadClaims) == 0 && len(vol.WriteClaims) == 0 && len(vol.PastClaims) == 0 -} - // volumeReapImpl unpublished all the volume's PastClaims. PastClaims // will be populated from nil or terminal allocs when we call // CSIVolumeDenormalize(), so this assumes we've done so in the caller diff --git a/nomad/volumewatcher/volume_watcher_test.go b/nomad/volumewatcher/volume_watcher_test.go index dc3745b5b..2c41bfcc7 100644 --- a/nomad/volumewatcher/volume_watcher_test.go +++ b/nomad/volumewatcher/volume_watcher_test.go @@ -18,33 +18,61 @@ import ( func TestVolumeWatch_Reap(t *testing.T) { ci.Parallel(t) + // note: this test doesn't put the volume in the state store so that we + // don't have to have the mock write updates back to it + store := state.TestStateStore(t) srv := &MockRPCServer{ - state: state.TestStateStore(t), + state: store, } plugin := mock.CSIPlugin() - node := testNode(plugin, srv.State()) + node := testNode(plugin, store) alloc := mock.Alloc() alloc.NodeID = node.ID - alloc.ClientStatus = structs.AllocClientStatusComplete + alloc.ClientStatus = structs.AllocClientStatusRunning + + index, _ := store.LatestIndex() + index++ + must.NoError(t, store.UpsertAllocs( + structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc})) + vol := testVolume(plugin, alloc, node.ID) - vol.PastClaims = vol.ReadClaims ctx, exitFn := context.WithCancel(context.Background()) w := &volumeWatcher{ v: vol, rpc: srv, - state: srv.State(), + state: store, ctx: ctx, exitFn: exitFn, logger: testlog.HCLogger(t), } - vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy()) + vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy()) err := w.volumeReapImpl(vol) must.NoError(t, err) - // past claim from a previous pass + // verify no change has been made + must.MapLen(t, 1, vol.ReadClaims) + must.MapLen(t, 0, vol.PastClaims) + must.Eq(t, 0, srv.countCSIUnpublish) + + alloc = alloc.Copy() + alloc.ClientStatus = structs.AllocClientStatusComplete + + index++ + must.NoError(t, store.UpdateAllocsFromClient( + structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc})) + + vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy()) + must.MapLen(t, 1, vol.ReadClaims) + must.MapLen(t, 1, vol.PastClaims) + + err = w.volumeReapImpl(vol) + must.NoError(t, err) + must.Eq(t, 1, srv.countCSIUnpublish) + + // simulate updated past claim from a previous pass vol.PastClaims = map[string]*structs.CSIVolumeClaim{ alloc.ID: { NodeID: node.ID, @@ -52,10 +80,11 @@ func TestVolumeWatch_Reap(t *testing.T) { State: structs.CSIVolumeClaimStateNodeDetached, }, } - vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy()) + vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy()) err = w.volumeReapImpl(vol) must.NoError(t, err) must.MapLen(t, 1, vol.PastClaims) + must.Eq(t, 2, srv.countCSIUnpublish) // claim emitted by a GC event vol.PastClaims = map[string]*structs.CSIVolumeClaim{ @@ -64,10 +93,11 @@ func TestVolumeWatch_Reap(t *testing.T) { Mode: structs.CSIVolumeClaimGC, }, } - vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy()) + vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy()) err = w.volumeReapImpl(vol) must.NoError(t, err) must.MapLen(t, 2, vol.PastClaims) // alloc claim + GC claim + must.Eq(t, 4, srv.countCSIUnpublish) // release claims of a previously GC'd allocation vol.ReadAllocs[alloc.ID] = nil @@ -77,10 +107,11 @@ func TestVolumeWatch_Reap(t *testing.T) { Mode: structs.CSIVolumeClaimRead, }, } - vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy()) + vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy()) err = w.volumeReapImpl(vol) must.NoError(t, err) must.MapLen(t, 2, vol.PastClaims) // alloc claim + GC claim + must.Eq(t, 6, srv.countCSIUnpublish) } func TestVolumeReapBadState(t *testing.T) { @@ -102,7 +133,7 @@ func TestVolumeReapBadState(t *testing.T) { w := &volumeWatcher{ v: vol, rpc: srv, - state: srv.State(), + state: store, ctx: ctx, exitFn: exitFn, logger: testlog.HCLogger(t),