From 40241b261bfb264cb5295bcc4347657a8390f537 Mon Sep 17 00:00:00 2001
From: Tim Gross
Date: Thu, 25 Sep 2025 09:24:53 -0400
Subject: [PATCH] CSI: ensure only client-terminal allocs are treated as past
 claims (#26831)

The volume watcher checks whether any allocations that have claims on a
volume are terminal, so that it knows when it's safe to unpublish the
volume. This check treated a claim as ready to unpublish if the
allocation was terminal on either the server or the client, rather than
on the client alone.

In many circumstances this is safe. But if an allocation takes a while
to stop (ex. it has a `shutdown_delay`), it's possible for garbage
collection to run in the window between when the alloc is marked
server-terminal and when the task actually stops. The server
unpublishes the volume, which sends a node plugin RPC. The plugin
unmounts the volume while it's still in use, and then unmounts it again
when the allocation stops and the CSI postrun hook runs. If the task
writes to the volume during the unmounting process, some providers end
up in a broken state and the volume is not usable unless it's detached
and reattached.

Fix this by considering a claim a "past claim" only when the allocation
is client-terminal. This way, if garbage collection runs while we're
waiting for allocation shutdown, the alloc will only be server-terminal
and we won't send the extra node RPCs.

Fixes: https://github.com/hashicorp/nomad/issues/24130
Fixes: https://github.com/hashicorp/nomad/issues/25819
Ref: https://hashicorp.atlassian.net/browse/NMD-1001
---
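Reviewer note (placed below the `---` so it stays out of the commit): a
minimal, self-contained Go sketch of the status distinction this patch
relies on. The alloc type and status strings are simplified stand-ins for
Nomad's structs.Allocation, not the real implementation.

package main

import "fmt"

// alloc is a simplified stand-in for Nomad's structs.Allocation; only the
// two status fields relevant to this patch are modeled here.
type alloc struct {
	DesiredStatus string // server's intent: "run" or "stop"
	ClientStatus  string // client's report: "running", "complete", "failed", "lost"
}

// serverTerminal mirrors ServerTerminalStatus: the server wants the alloc stopped.
func (a alloc) serverTerminal() bool { return a.DesiredStatus == "stop" }

// clientTerminal mirrors ClientTerminalStatus: the client confirmed the stop.
func (a alloc) clientTerminal() bool {
	switch a.ClientStatus {
	case "complete", "failed", "lost":
		return true
	}
	return false
}

// terminal mirrors TerminalStatus: true if EITHER side is terminal, which is
// why it was the wrong check for deciding when a claim is a past claim.
func (a alloc) terminal() bool { return a.serverTerminal() || a.clientTerminal() }

func main() {
	// an alloc in its shutdown_delay window: the server has marked it for
	// stopping, but the task is still running and may still use the volume
	stopping := alloc{DesiredStatus: "stop", ClientStatus: "running"}

	fmt.Println(stopping.terminal())       // true: old check would unpublish here
	fmt.Println(stopping.clientTerminal()) // false: new check waits for the client
}

An alloc sitting out its shutdown_delay is exactly the "stopping" case
above: already terminal by the old check, not yet terminal by the new one.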
 .changelog/26831.txt                       |  3 ++
 nomad/state/state_store.go                 |  3 +-
 nomad/state/state_store_test.go            | 17 ++++++-
 nomad/volumewatcher/volume_watcher.go      |  4 --
 nomad/volumewatcher/volume_watcher_test.go | 53 +++++++++++++++++-----
 5 files changed, 63 insertions(+), 17 deletions(-)
 create mode 100644 .changelog/26831.txt

diff --git a/.changelog/26831.txt b/.changelog/26831.txt
new file mode 100644
index 000000000..7b5e8b14e
--- /dev/null
+++ b/.changelog/26831.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+csi: Fixed a bug where volumes could be unmounted while in use by a task that was shutting down
+```
diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go
index 49c44fee1..d59e63493 100644
--- a/nomad/state/state_store.go
+++ b/nomad/state/state_store.go
@@ -2918,6 +2918,7 @@ func (s *StateStore) volSafeToForce(txn Txn, v *structs.CSIVolume) bool {
 	}
 
 	for _, alloc := range vol.ReadAllocs {
+		// note: TerminalStatus treats the alloc as terminal if either the server or the client marks it terminal
 		if alloc != nil && !alloc.TerminalStatus() {
 			return false
 		}
@@ -3029,7 +3030,7 @@ func (s *StateStore) csiVolumeDenormalizeTxn(txn Txn, ws memdb.WatchSet, vol *st
 		}
 		currentAllocs[id] = a
 
-		if (a == nil || a.TerminalStatus()) && pastClaim == nil {
+		if (a == nil || a.ClientTerminalStatus()) && pastClaim == nil {
 			// the alloc is garbage collected but nothing has written a PastClaim,
 			// so create one now
 			pastClaim = &structs.CSIVolumeClaim{
diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go
index 5774c6396..492062077 100644
--- a/nomad/state/state_store_test.go
+++ b/nomad/state/state_store_test.go
@@ -3680,6 +3680,8 @@ func TestStateStore_CSIVolume(t *testing.T) {
 	must.NoError(t, err)
 	vs = slurp(iter)
 	must.False(t, vs[0].HasFreeWriteClaims())
+	must.MapLen(t, 1, vs[0].ReadClaims)
+	must.MapLen(t, 0, vs[0].PastClaims)
 
 	claim2 := new(structs.CSIVolumeClaim)
 	*claim2 = *claim0
@@ -3692,7 +3694,20 @@
 	vs = slurp(iter)
 	must.True(t, vs[0].ReadSchedulable())
 
-	// deregistration is an error when the volume is in use
+	// alloc finishes, so we should see a past claim
+	a0 = a0.Copy()
+	a0.ClientStatus = structs.AllocClientStatusComplete
+	index++
+	err = state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{a0})
+	must.NoError(t, err)
+
+	v0, err = state.CSIVolumeByID(nil, ns, vol0)
+	must.NoError(t, err)
+	must.MapLen(t, 1, v0.ReadClaims)
+	must.MapLen(t, 1, v0.PastClaims)
+
+	// but until this claim is freed the volume is in use, so deregistration is
+	// still an error
 	index++
 	err = state.CSIVolumeDeregister(index, ns, []string{vol0}, false)
 	must.Error(t, err, must.Sprint("volume deregistered while in use"))
diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go
index 114aa7aa3..4f29ce809 100644
--- a/nomad/volumewatcher/volume_watcher.go
+++ b/nomad/volumewatcher/volume_watcher.go
@@ -184,10 +184,6 @@ func (vw *volumeWatcher) volumeReap(vol *structs.CSIVolume) {
 	}
 }
 
-func (vw *volumeWatcher) isUnclaimed(vol *structs.CSIVolume) bool {
-	return len(vol.ReadClaims) == 0 && len(vol.WriteClaims) == 0 && len(vol.PastClaims) == 0
-}
-
 // volumeReapImpl unpublishes all the volume's PastClaims. PastClaims
 // will be populated from nil or terminal allocs when we call
 // CSIVolumeDenormalize(), so this assumes we've done so in the caller
diff --git a/nomad/volumewatcher/volume_watcher_test.go b/nomad/volumewatcher/volume_watcher_test.go
index dc3745b5b..2c41bfcc7 100644
--- a/nomad/volumewatcher/volume_watcher_test.go
+++ b/nomad/volumewatcher/volume_watcher_test.go
@@ -18,33 +18,61 @@ import (
 func TestVolumeWatch_Reap(t *testing.T) {
 	ci.Parallel(t)
 
+	// note: this test doesn't put the volume in the state store so that we
+	// don't have to have the mock write updates back to it
+	store := state.TestStateStore(t)
 	srv := &MockRPCServer{
-		state: state.TestStateStore(t),
+		state: store,
 	}
 	plugin := mock.CSIPlugin()
-	node := testNode(plugin, srv.State())
+	node := testNode(plugin, store)
 	alloc := mock.Alloc()
 	alloc.NodeID = node.ID
-	alloc.ClientStatus = structs.AllocClientStatusComplete
+	alloc.ClientStatus = structs.AllocClientStatusRunning
+
+	index, _ := store.LatestIndex()
+	index++
+	must.NoError(t, store.UpsertAllocs(
+		structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}))
+
 	vol := testVolume(plugin, alloc, node.ID)
-	vol.PastClaims = vol.ReadClaims
 
 	ctx, exitFn := context.WithCancel(context.Background())
 	w := &volumeWatcher{
 		v:      vol,
 		rpc:    srv,
-		state:  srv.State(),
+		state:  store,
 		ctx:    ctx,
 		exitFn: exitFn,
 		logger: testlog.HCLogger(t),
 	}
 
-	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
+	vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
 	err := w.volumeReapImpl(vol)
 	must.NoError(t, err)
 
-	// past claim from a previous pass
+	// verify no change has been made
+	must.MapLen(t, 1, vol.ReadClaims)
+	must.MapLen(t, 0, vol.PastClaims)
+	must.Eq(t, 0, srv.countCSIUnpublish)
+
+	alloc = alloc.Copy()
+	alloc.ClientStatus = structs.AllocClientStatusComplete
+
+	index++
+	must.NoError(t, store.UpdateAllocsFromClient(
+		structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}))
+
+	vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
+	must.MapLen(t, 1, vol.ReadClaims)
+	must.MapLen(t, 1, vol.PastClaims)
+
+	err = w.volumeReapImpl(vol)
+	must.NoError(t, err)
+	must.Eq(t, 1, srv.countCSIUnpublish)
+
+	// simulate updated past claim from a previous pass
 	vol.PastClaims = map[string]*structs.CSIVolumeClaim{
 		alloc.ID: {
 			NodeID: node.ID,
@@ -52,10 +80,11 @@ func TestVolumeWatch_Reap(t *testing.T) {
 			State:  structs.CSIVolumeClaimStateNodeDetached,
 		},
 	}
-	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
+	vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
 	err = w.volumeReapImpl(vol)
 	must.NoError(t, err)
 	must.MapLen(t, 1, vol.PastClaims)
+	must.Eq(t, 2, srv.countCSIUnpublish)
 
 	// claim emitted by a GC event
 	vol.PastClaims = map[string]*structs.CSIVolumeClaim{
@@ -64,10 +93,11 @@ func TestVolumeWatch_Reap(t *testing.T) {
 			Mode: structs.CSIVolumeClaimGC,
 		},
 	}
-	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
+	vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
 	err = w.volumeReapImpl(vol)
 	must.NoError(t, err)
 	must.MapLen(t, 2, vol.PastClaims) // alloc claim + GC claim
+	must.Eq(t, 4, srv.countCSIUnpublish)
 
 	// release claims of a previously GC'd allocation
 	vol.ReadAllocs[alloc.ID] = nil
@@ -77,10 +107,11 @@ func TestVolumeWatch_Reap(t *testing.T) {
 			Mode: structs.CSIVolumeClaimRead,
 		},
 	}
-	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
+	vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
 	err = w.volumeReapImpl(vol)
 	must.NoError(t, err)
 	must.MapLen(t, 2, vol.PastClaims) // alloc claim + GC claim
+	must.Eq(t, 6, srv.countCSIUnpublish)
 }
 
 func TestVolumeReapBadState(t *testing.T) {
@@ -102,7 +133,7 @@ func TestVolumeReapBadState(t *testing.T) {
 	w := &volumeWatcher{
 		v:      vol,
 		rpc:    srv,
-		state:  srv.State(),
+		state:  store,
 		ctx:    ctx,
 		exitFn: exitFn,
 		logger: testlog.HCLogger(t),
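
Postscript for reviewers: a condensed, self-contained sketch of the
denormalize decision the state store change above adjusts. The claim and
alloc types are simplified stand-ins, not the real state store code.

package main

import "fmt"

// simplified stand-ins for structs.CSIVolumeClaim and structs.Allocation
type claim struct {
	allocID string
	past    bool // stand-in for the claim being recorded as a PastClaim
}

type alloc struct {
	clientTerminal bool // stand-in for Allocation.ClientTerminalStatus()
	serverTerminal bool // stand-in for Allocation.ServerTerminalStatus()
}

// markPastClaim mirrors the changed condition in csiVolumeDenormalizeTxn:
// before this patch the guard was "a == nil || a.TerminalStatus()" (either
// side terminal); now a claim becomes a PastClaim only once the alloc has
// been garbage collected (nil) or the client reports it terminal.
func markPastClaim(c claim, a *alloc) claim {
	if a == nil || a.clientTerminal {
		c.past = true
	}
	return c
}

func main() {
	c := claim{allocID: "alloc1"}

	// server-terminal but still running on the client (e.g. shutdown_delay):
	// no PastClaim is synthesized, so the watcher sends no node unpublish RPC
	stopping := &alloc{serverTerminal: true}
	fmt.Println(markPastClaim(c, stopping).past) // false

	// the client has confirmed the stop: now it's safe to record a PastClaim
	stopped := &alloc{serverTerminal: true, clientTerminal: true}
	fmt.Println(markPastClaim(c, stopped).past) // true

	// the alloc was garbage collected entirely: also safe
	fmt.Println(markPastClaim(c, nil).past) // true
}

This is also why the new watcher test holds countCSIUnpublish at zero
until the alloc's client status flips to complete.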