CSI: ensure only client-terminal allocs are treated as past claims (#26831)

The volume watcher checks whether the allocations holding claims on a volume
are terminal, so that it knows when it's safe to unpublish the volume. This
check treated a claim as eligible for unpublishing if the allocation was
terminal on either the server or the client, rather than on the client alone.
In many circumstances that's fine.
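
For reference, the two allocation-status helpers behave differently while an
alloc is shutting down. The sketch below is a paraphrase using a standalone
stand-in type rather than the real `structs.Allocation`; the field and method
names mirror the ones this change relies on (`TerminalStatus`,
`ClientTerminalStatus`), but the bodies are an approximation:

```go
package structs

// Paraphrased sketch of the allocation status helpers; a stand-in for
// nomad/structs, not the upstream source.
type Allocation struct {
	DesiredStatus string // server intent: "run", "stop", "evict"
	ClientStatus  string // client report: "pending", "running", "complete", "failed", "lost"
}

// TerminalStatus is true if either the server intent or the client status is
// terminal, so a freshly stopped alloc counts even while its tasks still run.
func (a *Allocation) TerminalStatus() bool {
	switch a.DesiredStatus {
	case "stop", "evict":
		return true
	}
	return a.ClientTerminalStatus()
}

// ClientTerminalStatus is true only once the client reports that the tasks
// have actually finished.
func (a *Allocation) ClientTerminalStatus() bool {
	switch a.ClientStatus {
	case "complete", "failed", "lost":
		return true
	}
	return false
}
```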

But if an allocation takes a while to stop (e.g. it has a `shutdown_delay`),
it's possible for garbage collection to run in the window between when the
alloc is marked server-terminal and when the task has actually stopped. The
server unpublishes the volume, which sends a node plugin RPC. The plugin
unmounts the volume while it's still in use, and then unmounts it again when
the allocation stops and the CSI postrun hook runs. If the task writes to the
volume during the unmount, some providers end up in a broken state and the
volume is unusable until it's detached and reattached.
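
To make that window concrete, here's a small illustrative program. It is not
the watcher's code; the status strings and checks are simplified stand-ins,
but it shows why the old either-side-terminal check fires during
`shutdown_delay` while a client-only check does not:

```go
package main

import "fmt"

// Simplified stand-ins for the allocation status checks; illustrative only.
type alloc struct{ DesiredStatus, ClientStatus string }

func (a alloc) serverTerminal() bool {
	return a.DesiredStatus == "stop" || a.DesiredStatus == "evict"
}

func (a alloc) clientTerminal() bool {
	return a.ClientStatus == "complete" || a.ClientStatus == "failed" || a.ClientStatus == "lost"
}

func main() {
	// t0: the scheduler marks the alloc for stopping; shutdown_delay keeps
	// the task running on the client for a while longer.
	a := alloc{DesiredStatus: "stop", ClientStatus: "running"}

	// If garbage collection runs inside this window, the old check (terminal
	// on either side) already treats the claim as a past claim and the
	// volume gets unpublished under the still-running task.
	fmt.Println("old check (either side terminal):", a.serverTerminal() || a.clientTerminal()) // true
	fmt.Println("new check (client terminal only):", a.clientTerminal())                       // false

	// t0 + shutdown_delay: the client reports the task stopped; only now is
	// the claim eligible to become a past claim under the fix.
	a.ClientStatus = "complete"
	fmt.Println("after client shutdown:", a.clientTerminal()) // true
}
```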

Fix this by considering a claim a "past claim" only when the allocation is
client-terminal. That way, if garbage collection runs while we're waiting for
the allocation to shut down, the alloc will only be server-terminal and we
won't send the extra node RPCs.

Fixes: https://github.com/hashicorp/nomad/issues/24130
Fixes: https://github.com/hashicorp/nomad/issues/25819
Ref: https://hashicorp.atlassian.net/browse/NMD-1001
Author: Tim Gross
Date: 2025-09-25 09:24:53 -04:00
Committed by: GitHub
Parent: c80c60965f
Commit: 40241b261b
5 changed files with 63 additions and 17 deletions

.changelog/26831.txt (new file)

@@ -0,0 +1,3 @@
```release-note:bug
csi: Fixed a bug where volumes could be unmounted while in use by a task that was shutting down
```

@@ -2918,6 +2918,7 @@ func (s *StateStore) volSafeToForce(txn Txn, v *structs.CSIVolume) bool {
}
for _, alloc := range vol.ReadAllocs {
// note we check that both server and client agree on terminal status
if alloc != nil && !alloc.TerminalStatus() {
return false
}
@@ -3029,7 +3030,7 @@ func (s *StateStore) csiVolumeDenormalizeTxn(txn Txn, ws memdb.WatchSet, vol *st
}
currentAllocs[id] = a
if (a == nil || a.TerminalStatus()) && pastClaim == nil {
if (a == nil || a.ClientTerminalStatus()) && pastClaim == nil {
// the alloc is garbage collected but nothing has written a PastClaim,
// so create one now
pastClaim = &structs.CSIVolumeClaim{

@@ -3680,6 +3680,8 @@ func TestStateStore_CSIVolume(t *testing.T) {
must.NoError(t, err)
vs = slurp(iter)
must.False(t, vs[0].HasFreeWriteClaims())
must.MapLen(t, 1, vs[0].ReadClaims)
must.MapLen(t, 0, vs[0].PastClaims)
claim2 := new(structs.CSIVolumeClaim)
*claim2 = *claim0
@@ -3692,7 +3694,20 @@ func TestStateStore_CSIVolume(t *testing.T) {
vs = slurp(iter)
must.True(t, vs[0].ReadSchedulable())
// deregistration is an error when the volume is in use
// alloc finishes, so we should see a past claim
a0 = a0.Copy()
a0.ClientStatus = structs.AllocClientStatusComplete
index++
err = state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{a0})
must.NoError(t, err)
v0, err = state.CSIVolumeByID(nil, ns, vol0)
must.NoError(t, err)
must.MapLen(t, 1, v0.ReadClaims)
must.MapLen(t, 1, v0.PastClaims)
// but until this claim is freed the volume is in use, so deregistration is
// still an error
index++
err = state.CSIVolumeDeregister(index, ns, []string{vol0}, false)
must.Error(t, err, must.Sprint("volume deregistered while in use"))

@@ -184,10 +184,6 @@ func (vw *volumeWatcher) volumeReap(vol *structs.CSIVolume) {
}
}
func (vw *volumeWatcher) isUnclaimed(vol *structs.CSIVolume) bool {
return len(vol.ReadClaims) == 0 && len(vol.WriteClaims) == 0 && len(vol.PastClaims) == 0
}
// volumeReapImpl unpublishes all the volume's PastClaims. PastClaims
// will be populated from nil or terminal allocs when we call
// CSIVolumeDenormalize(), so this assumes we've done so in the caller

@@ -18,33 +18,61 @@ import (
func TestVolumeWatch_Reap(t *testing.T) {
ci.Parallel(t)
// note: this test doesn't put the volume in the state store so that we
// don't have to have the mock write updates back to it
store := state.TestStateStore(t)
srv := &MockRPCServer{
state: state.TestStateStore(t),
state: store,
}
plugin := mock.CSIPlugin()
node := testNode(plugin, srv.State())
node := testNode(plugin, store)
alloc := mock.Alloc()
alloc.NodeID = node.ID
alloc.ClientStatus = structs.AllocClientStatusComplete
alloc.ClientStatus = structs.AllocClientStatusRunning
index, _ := store.LatestIndex()
index++
must.NoError(t, store.UpsertAllocs(
structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}))
vol := testVolume(plugin, alloc, node.ID)
vol.PastClaims = vol.ReadClaims
ctx, exitFn := context.WithCancel(context.Background())
w := &volumeWatcher{
v: vol,
rpc: srv,
state: srv.State(),
state: store,
ctx: ctx,
exitFn: exitFn,
logger: testlog.HCLogger(t),
}
vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
err := w.volumeReapImpl(vol)
must.NoError(t, err)
// past claim from a previous pass
// verify no change has been made
must.MapLen(t, 1, vol.ReadClaims)
must.MapLen(t, 0, vol.PastClaims)
must.Eq(t, 0, srv.countCSIUnpublish)
alloc = alloc.Copy()
alloc.ClientStatus = structs.AllocClientStatusComplete
index++
must.NoError(t, store.UpdateAllocsFromClient(
structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}))
vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
must.MapLen(t, 1, vol.ReadClaims)
must.MapLen(t, 1, vol.PastClaims)
err = w.volumeReapImpl(vol)
must.NoError(t, err)
must.Eq(t, 1, srv.countCSIUnpublish)
// simulate updated past claim from a previous pass
vol.PastClaims = map[string]*structs.CSIVolumeClaim{
alloc.ID: {
NodeID: node.ID,
@@ -52,10 +80,11 @@ func TestVolumeWatch_Reap(t *testing.T) {
State: structs.CSIVolumeClaimStateNodeDetached,
},
}
vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
err = w.volumeReapImpl(vol)
must.NoError(t, err)
must.MapLen(t, 1, vol.PastClaims)
must.Eq(t, 2, srv.countCSIUnpublish)
// claim emitted by a GC event
vol.PastClaims = map[string]*structs.CSIVolumeClaim{
@@ -64,10 +93,11 @@ func TestVolumeWatch_Reap(t *testing.T) {
Mode: structs.CSIVolumeClaimGC,
},
}
vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
err = w.volumeReapImpl(vol)
must.NoError(t, err)
must.MapLen(t, 2, vol.PastClaims) // alloc claim + GC claim
must.Eq(t, 4, srv.countCSIUnpublish)
// release claims of a previously GC'd allocation
vol.ReadAllocs[alloc.ID] = nil
@@ -77,10 +107,11 @@ func TestVolumeWatch_Reap(t *testing.T) {
Mode: structs.CSIVolumeClaimRead,
},
}
vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
vol, _ = store.CSIVolumeDenormalize(nil, vol.Copy())
err = w.volumeReapImpl(vol)
must.NoError(t, err)
must.MapLen(t, 2, vol.PastClaims) // alloc claim + GC claim
must.Eq(t, 6, srv.countCSIUnpublish)
}
func TestVolumeReapBadState(t *testing.T) {
@@ -102,7 +133,7 @@ func TestVolumeReapBadState(t *testing.T) {
w := &volumeWatcher{
v: vol,
rpc: srv,
state: srv.State(),
state: store,
ctx: ctx,
exitFn: exitFn,
logger: testlog.HCLogger(t),