diff --git a/.changelog/26832.txt b/.changelog/26832.txt
new file mode 100644
index 000000000..71460e27a
--- /dev/null
+++ b/.changelog/26832.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+csi: Fixed a bug where multiple node plugin RPCs could be in-flight for a single volume
+```
diff --git a/client/pluginmanager/csimanager/volume.go b/client/pluginmanager/csimanager/volume.go
index f59f3c939..11eec6907 100644
--- a/client/pluginmanager/csimanager/volume.go
+++ b/client/pluginmanager/csimanager/volume.go
@@ -11,6 +11,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 	"time"
 
 	grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
@@ -58,6 +59,9 @@ type volumeManager struct {
 	// externalNodeID is the identity of a given nomad client as observed by the
 	// storage provider (ex. a hostname, VM instance ID, etc.)
 	externalNodeID string
+
+	inFlight     map[structs.NamespacedID]context.Context
+	inFlightLock sync.Mutex
 }
 
 func newVolumeManager(logger hclog.Logger, eventer TriggerNodeEvent, plugin csi.CSIPlugin, rootDir, containerRootDir string, requiresStaging bool, externalID string) *volumeManager {
@@ -71,91 +75,72 @@ func newVolumeManager(logger hclog.Logger, eventer TriggerNodeEvent, plugin csi.
 		requiresStaging:     requiresStaging,
 		usageTracker:        newVolumeUsageTracker(),
 		externalNodeID:      externalID,
+		inFlight:            make(map[structs.NamespacedID]context.Context),
 	}
 }
 
-func (v *volumeManager) stagingDirForVolume(root string, volNS, volID string, usage *UsageOptions) string {
-	return filepath.Join(root, StagingDirName, volNS, volID, usage.ToFS())
+func (v *volumeManager) ExternalID() string {
+	return v.externalNodeID
 }
 
-func (v *volumeManager) allocDirForVolume(root string, volID, allocID string) string {
-	return filepath.Join(root, AllocSpecificDirName, allocID, volID)
-}
-
-func (v *volumeManager) targetForVolume(root string, volID, allocID string, usage *UsageOptions) string {
-	return filepath.Join(root, AllocSpecificDirName, allocID, volID, usage.ToFS())
-}
-
-// ensureStagingDir attempts to create a directory for use when staging a volume
-// and then validates that the path is not already a mount point for e.g an
-// existing volume stage.
-//
-// Returns whether the directory is a pre-existing mountpoint, the staging path,
-// and any errors that occurred.
-func (v *volumeManager) ensureStagingDir(vol *structs.CSIVolume, usage *UsageOptions) (string, bool, error) {
-	hostStagingPath := v.stagingDirForVolume(v.mountRoot, vol.Namespace, vol.ID, usage)
-
-	// Make the staging path, owned by the Nomad User
-	if err := os.MkdirAll(hostStagingPath, 0700); err != nil && !os.IsExist(err) {
-		return "", false, fmt.Errorf("failed to create staging directory for volume (%s): %v", vol.ID, err)
-
-	}
-
-	// Validate that it is not already a mount point
+func (v *volumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) {
 	m := mount.New()
-	isNotMount, err := m.IsNotAMountPoint(hostStagingPath)
-	if err != nil {
-		return "", false, fmt.Errorf("mount point detection failed for volume (%s): %v", vol.ID, err)
-	}
-
-	return hostStagingPath, !isNotMount, nil
+	isNotMount, err := m.IsNotAMountPoint(mountInfo.Source)
+	return !isNotMount, err
 }
 
-// ensureAllocDir attempts to create a directory for use when publishing a volume
-// and then validates that the path is not already a mount point (e.g when reattaching
-// to existing allocs).
-//
-// Returns whether the directory is a pre-existing mountpoint, the publish path,
-// and any errors that occurred.
-func (v *volumeManager) ensureAllocDir(vol *structs.CSIVolume, alloc *structs.Allocation, usage *UsageOptions) (string, bool, error) {
-	allocPath := v.allocDirForVolume(v.mountRoot, vol.ID, alloc.ID)
+// MountVolume performs the steps required for using a given volume
+// configuration for the provided allocation. It is passed the publishContext
+// from remote attachment, and specific usage modes from the CSI Hook. It then
+// uses this state to stage and publish the volume as required for use by the
+// given allocation.
+func (v *volumeManager) MountVolume(ctx context.Context,
+	vol *structs.CSIVolume, alloc *structs.Allocation,
+	usage *UsageOptions, publishContext map[string]string,
+) (*MountInfo, error) {
 
-	// Make the alloc path, owned by the Nomad User
-	if err := os.MkdirAll(allocPath, 0700); err != nil && !os.IsExist(err) {
-		return "", false, fmt.Errorf("failed to create allocation directory for volume (%s): %v", vol.ID, err)
-	}
+	var mountInfo *MountInfo
+	err := v.serializedOp(ctx, vol.Namespace, vol.ID, func() error {
+		var err error
+		mountInfo, err = v.mountVolumeImpl(ctx, vol, alloc, usage, publishContext)
+		return err
+	})
 
-	// Validate that the target is not already a mount point
-	targetPath := v.targetForVolume(v.mountRoot, vol.ID, alloc.ID, usage)
-
-	m := mount.New()
-	isNotMount, err := m.IsNotAMountPoint(targetPath)
-
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-		// ignore; path does not exist and as such is not a mount
-	case err != nil:
-		return "", false, fmt.Errorf("mount point detection failed for volume (%s): %v", vol.ID, err)
-	}
-
-	return targetPath, !isNotMount, nil
+	return mountInfo, err
 }
 
-func volumeCapability(vol *structs.CSIVolume, usage *UsageOptions) (*csi.VolumeCapability, error) {
-	var opts *structs.CSIMountOptions
-	if vol.MountOptions == nil {
-		opts = usage.MountOptions
+func (v *volumeManager) mountVolumeImpl(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usage *UsageOptions, publishContext map[string]string) (mountInfo *MountInfo, err error) {
+
+	logger := v.logger.With("volume_id", vol.ID, "alloc_id", alloc.ID)
+	ctx = hclog.WithContext(ctx, logger)
+
+	// Claim before we stage/publish to prevent interleaved Unmount for another
+	// alloc from unstaging between stage/publish steps below
+	v.usageTracker.Claim(alloc.ID, vol.ID, vol.Namespace, usage)
+
+	if v.requiresStaging {
+		err = v.stageVolume(ctx, vol, usage, publishContext)
+	}
+
+	if err == nil {
+		mountInfo, err = v.publishVolume(ctx, vol, alloc, usage, publishContext)
+	}
+
+	event := structs.NewNodeEvent().
+		SetSubsystem(structs.NodeEventSubsystemStorage).
+		SetMessage("Mount volume").
+		AddDetail("volume_id", vol.ID)
+	if err == nil {
+		event.AddDetail("success", "true")
 	} else {
-		opts = vol.MountOptions.Copy()
-		opts.Merge(usage.MountOptions)
+		event.AddDetail("success", "false")
+		event.AddDetail("error", err.Error())
+		v.usageTracker.Free(alloc.ID, vol.ID, vol.Namespace, usage)
 	}
 
-	capability, err := csi.VolumeCapabilityFromStructs(usage.AttachmentMode, usage.AccessMode, opts)
-	if err != nil {
-		return nil, err
-	}
+	v.eventer(event)
 
-	return capability, nil
+	return mountInfo, err
 }
 
 // stageVolume prepares a volume for use by allocations. When a plugin exposes
@@ -243,92 +228,51 @@ func (v *volumeManager) publishVolume(ctx context.Context, vol *structs.CSIVolum
 	return &MountInfo{Source: hostTargetPath}, err
 }
 
-// MountVolume performs the steps required for using a given volume
-// configuration for the provided allocation.
-// It is passed the publishContext from remote attachment, and specific usage
-// modes from the CSI Hook.
-// It then uses this state to stage and publish the volume as required for use
-// by the given allocation.
-func (v *volumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usage *UsageOptions, publishContext map[string]string) (mountInfo *MountInfo, err error) {
-	logger := v.logger.With("volume_id", vol.ID, "alloc_id", alloc.ID)
+// UnmountVolume unpublishes the volume for a specific allocation, and unstages
+// the volume if there are no more allocations claiming it on the node
+func (v *volumeManager) UnmountVolume(ctx context.Context,
+	volNS, volID, remoteID, allocID string, usage *UsageOptions,
+) error {
+	return v.serializedOp(ctx, volNS, volID, func() error {
+		return v.unmountVolumeImpl(ctx, volNS, volID, remoteID, allocID, usage)
+	})
+}
+
+func (v *volumeManager) unmountVolumeImpl(ctx context.Context, volNS, volID, remoteID, allocID string, usage *UsageOptions) error {
+
+	logger := v.logger.With("volume_id", volID, "ns", volNS, "alloc_id", allocID)
 	ctx = hclog.WithContext(ctx, logger)
+	logger.Trace("unmounting volume")
 
-	// Claim before we stage/publish to prevent interleaved Unmount for another
-	// alloc from unstaging between stage/publish steps below
-	v.usageTracker.Claim(alloc.ID, vol.ID, vol.Namespace, usage)
+	err := v.unpublishVolume(ctx, volID, remoteID, allocID, usage)
 
-	if v.requiresStaging {
-		err = v.stageVolume(ctx, vol, usage, publishContext)
+	if err == nil || errors.Is(err, structs.ErrCSIClientRPCIgnorable) {
+		canRelease := v.usageTracker.Free(allocID, volID, volNS, usage)
+		if v.requiresStaging && canRelease {
+			err = v.unstageVolume(ctx, volNS, volID, remoteID, usage)
+		}
 	}
 
-	if err == nil {
-		mountInfo, err = v.publishVolume(ctx, vol, alloc, usage, publishContext)
+	if errors.Is(err, structs.ErrCSIClientRPCIgnorable) {
+		logger.Trace("unmounting volume failed with ignorable error", "error", err)
+		err = nil
 	}
 
 	event := structs.NewNodeEvent().
 		SetSubsystem(structs.NodeEventSubsystemStorage).
-		SetMessage("Mount volume").
-		AddDetail("volume_id", vol.ID)
+		SetMessage("Unmount volume").
+		AddDetail("volume_id", volID)
 	if err == nil {
 		event.AddDetail("success", "true")
 	} else {
 		event.AddDetail("success", "false")
 		event.AddDetail("error", err.Error())
-		v.usageTracker.Free(alloc.ID, vol.ID, vol.Namespace, usage)
 	}
 
 	v.eventer(event)
 
-	return mountInfo, err
-}
+	return err
 
-// unstageVolume is the inverse operation of `stageVolume` and must be called
-// once for each staging path that a volume has been staged under.
-// It is safe to call multiple times and a plugin is required to return OK if
-// the volume has been unstaged or was never staged on the node.
-func (v *volumeManager) unstageVolume(ctx context.Context, volNS, volID, remoteID string, usage *UsageOptions) error {
-	logger := hclog.FromContext(ctx)
-
-	// This is the staging path inside the container, which we pass to the
-	// plugin to perform unstaging
-	stagingPath := v.stagingDirForVolume(v.containerMountPoint, volNS, volID, usage)
-
-	// This is the path from the host, which we need to use to verify whether
-	// the path is the right one to pass to the plugin container
-	hostStagingPath := v.stagingDirForVolume(v.mountRoot, volNS, volID, usage)
-	_, err := os.Stat(hostStagingPath)
-	if err != nil && errors.Is(err, fs.ErrNotExist) {
-		// COMPAT: it's possible to get an unmount request that includes the
-		// namespace even for volumes that were mounted before the path included
-		// the namespace, so if the staging path doesn't exist, try the older
-		// path
-		stagingPath = v.stagingDirForVolume(v.containerMountPoint, "", volID, usage)
-	}
-
-	logger.Trace("unstaging volume", "staging_path", stagingPath)
-
-	// CSI NodeUnstageVolume errors for timeout, codes.Unavailable and
-	// codes.ResourceExhausted are retried; all other errors are fatal.
-	return v.plugin.NodeUnstageVolume(ctx,
-		remoteID,
-		stagingPath,
-		grpc_retry.WithPerRetryTimeout(DefaultMountActionTimeout),
-		grpc_retry.WithMax(3),
-		grpc_retry.WithBackoff(grpc_retry.BackoffExponential(100*time.Millisecond)),
-	)
-}
-
-func combineErrors(maybeErrs ...error) error {
-	var result *multierror.Error
-	for _, err := range maybeErrs {
-		if err == nil {
-			continue
-		}
-
-		result = multierror.Append(result, err)
-	}
-
-	return result.ErrorOrNil()
-}
 }
 
 func (v *volumeManager) unpublishVolume(ctx context.Context, volID, remoteID, allocID string, usage *UsageOptions) error {
@@ -372,43 +316,61 @@ func (v *volumeManager) unpublishVolume(ctx context.Context, volID, remoteID, al
 	return fmt.Errorf("%w: %v", structs.ErrCSIClientRPCIgnorable, rpcErr)
 }
 
-func (v *volumeManager) UnmountVolume(ctx context.Context, volNS, volID, remoteID, allocID string, usage *UsageOptions) (err error) {
-	logger := v.logger.With("volume_id", volID, "ns", volNS, "alloc_id", allocID)
-	ctx = hclog.WithContext(ctx, logger)
-	logger.Trace("unmounting volume")
+// unstageVolume is the inverse operation of `stageVolume` and must be called
+// once for each staging path that a volume has been staged under.
+// It is safe to call multiple times and a plugin is required to return OK if
+// the volume has been unstaged or was never staged on the node.
+func (v *volumeManager) unstageVolume(ctx context.Context, volNS, volID, remoteID string, usage *UsageOptions) error {
+	logger := hclog.FromContext(ctx)
 
-	err = v.unpublishVolume(ctx, volID, remoteID, allocID, usage)
+	// This is the staging path inside the container, which we pass to the
+	// plugin to perform unstaging
+	stagingPath := v.stagingDirForVolume(v.containerMountPoint, volNS, volID, usage)
 
-	if err == nil || errors.Is(err, structs.ErrCSIClientRPCIgnorable) {
-		canRelease := v.usageTracker.Free(allocID, volID, volNS, usage)
-		if v.requiresStaging && canRelease {
-			err = v.unstageVolume(ctx, volNS, volID, remoteID, usage)
-		}
+	// This is the path from the host, which we need to use to verify whether
+	// the path is the right one to pass to the plugin container
+	hostStagingPath := v.stagingDirForVolume(v.mountRoot, volNS, volID, usage)
+	_, err := os.Stat(hostStagingPath)
+	if err != nil && errors.Is(err, fs.ErrNotExist) {
+		// COMPAT: it's possible to get an unmount request that includes the
+		// namespace even for volumes that were mounted before the path included
+		// the namespace, so if the staging path doesn't exist, try the older
+		// path
+		stagingPath = v.stagingDirForVolume(v.containerMountPoint, "", volID, usage)
 	}
 
-	if errors.Is(err, structs.ErrCSIClientRPCIgnorable) {
-		logger.Trace("unmounting volume failed with ignorable error", "error", err)
-		err = nil
-	}
+	logger.Trace("unstaging volume", "staging_path", stagingPath)
 
-	event := structs.NewNodeEvent().
-		SetSubsystem(structs.NodeEventSubsystemStorage).
-		SetMessage("Unmount volume").
-		AddDetail("volume_id", volID)
-	if err == nil {
-		event.AddDetail("success", "true")
-	} else {
-		event.AddDetail("success", "false")
-		event.AddDetail("error", err.Error())
-	}
-
-	v.eventer(event)
-
-	return err
+	// CSI NodeUnstageVolume errors for timeout, codes.Unavailable and
+	// codes.ResourceExhausted are retried; all other errors are fatal.
+	return v.plugin.NodeUnstageVolume(ctx,
+		remoteID,
+		stagingPath,
+		grpc_retry.WithPerRetryTimeout(DefaultMountActionTimeout),
+		grpc_retry.WithMax(3),
+		grpc_retry.WithBackoff(grpc_retry.BackoffExponential(100*time.Millisecond)),
+	)
 }
 
 // ExpandVolume sends a NodeExpandVolume request to the node plugin
-func (v *volumeManager) ExpandVolume(ctx context.Context, volNS, volID, remoteID, allocID string, usage *UsageOptions, capacity *csi.CapacityRange) (newCapacity int64, err error) {
+func (v *volumeManager) ExpandVolume(ctx context.Context,
+	volNS, volID, remoteID, allocID string,
+	usage *UsageOptions, capacity *csi.CapacityRange,
+) (int64, error) {
+
+	var newCapacity int64
+	err := v.serializedOp(ctx, volNS, volID, func() error {
+		var err error
+		newCapacity, err = v.expandVolumeImpl(
+			ctx, volNS, volID, remoteID, allocID, usage, capacity)
+		return err
+	})
+
+	return newCapacity, err
+}
+
+func (v *volumeManager) expandVolumeImpl(ctx context.Context, volNS, volID, remoteID, allocID string, usage *UsageOptions, capacity *csi.CapacityRange) (newCapacity int64, err error) {
+
 	capability, err := csi.VolumeCapabilityFromStructs(usage.AttachmentMode, usage.AccessMode, usage.MountOptions)
 	if err != nil {
 		// nil may be acceptable, so let the node plugin decide.
@@ -453,12 +415,139 @@ func (v *volumeManager) ExpandVolume(ctx context.Context, volNS, volID, remoteID
 	return resp.CapacityBytes, nil
 }
 
-func (v *volumeManager) ExternalID() string {
-	return v.externalNodeID
+func (v *volumeManager) stagingDirForVolume(root string, volNS, volID string, usage *UsageOptions) string {
+	return filepath.Join(root, StagingDirName, volNS, volID, usage.ToFS())
 }
 
-func (v *volumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) {
-	m := mount.New()
-	isNotMount, err := m.IsNotAMountPoint(mountInfo.Source)
-	return !isNotMount, err
+func (v *volumeManager) allocDirForVolume(root string, volID, allocID string) string {
+	return filepath.Join(root, AllocSpecificDirName, allocID, volID)
+}
+
+func (v *volumeManager) targetForVolume(root string, volID, allocID string, usage *UsageOptions) string {
+	return filepath.Join(root, AllocSpecificDirName, allocID, volID, usage.ToFS())
+}
+
+// ensureStagingDir attempts to create a directory for use when staging a volume
+// and then validates that the path is not already a mount point for e.g an
+// existing volume stage.
+//
+// Returns whether the directory is a pre-existing mountpoint, the staging path,
+// and any errors that occurred.
+func (v *volumeManager) ensureStagingDir(vol *structs.CSIVolume, usage *UsageOptions) (string, bool, error) {
+	hostStagingPath := v.stagingDirForVolume(v.mountRoot, vol.Namespace, vol.ID, usage)
+
+	// Make the staging path, owned by the Nomad User
+	if err := os.MkdirAll(hostStagingPath, 0700); err != nil && !os.IsExist(err) {
+		return "", false, fmt.Errorf("failed to create staging directory for volume (%s): %v", vol.ID, err)
+
+	}
+
+	// Validate that it is not already a mount point
+	m := mount.New()
+	isNotMount, err := m.IsNotAMountPoint(hostStagingPath)
+	if err != nil {
+		return "", false, fmt.Errorf("mount point detection failed for volume (%s): %v", vol.ID, err)
+	}
+
+	return hostStagingPath, !isNotMount, nil
+}
+
+// ensureAllocDir attempts to create a directory for use when publishing a volume
+// and then validates that the path is not already a mount point (e.g when reattaching
+// to existing allocs).
+//
+// Returns whether the directory is a pre-existing mountpoint, the publish path,
+// and any errors that occurred.
+func (v *volumeManager) ensureAllocDir(vol *structs.CSIVolume, alloc *structs.Allocation, usage *UsageOptions) (string, bool, error) {
+	allocPath := v.allocDirForVolume(v.mountRoot, vol.ID, alloc.ID)
+
+	// Make the alloc path, owned by the Nomad User
+	if err := os.MkdirAll(allocPath, 0700); err != nil && !os.IsExist(err) {
+		return "", false, fmt.Errorf("failed to create allocation directory for volume (%s): %v", vol.ID, err)
+	}
+
+	// Validate that the target is not already a mount point
+	targetPath := v.targetForVolume(v.mountRoot, vol.ID, alloc.ID, usage)
+
+	m := mount.New()
+	isNotMount, err := m.IsNotAMountPoint(targetPath)
+
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// ignore; path does not exist and as such is not a mount
+	case err != nil:
+		return "", false, fmt.Errorf("mount point detection failed for volume (%s): %v", vol.ID, err)
+	}
+
+	return targetPath, !isNotMount, nil
+}
+
+func volumeCapability(vol *structs.CSIVolume, usage *UsageOptions) (*csi.VolumeCapability, error) {
+	var opts *structs.CSIMountOptions
+	if vol.MountOptions == nil {
+		opts = usage.MountOptions
+	} else {
+		opts = vol.MountOptions.Copy()
+		opts.Merge(usage.MountOptions)
+	}
+
+	capability, err := csi.VolumeCapabilityFromStructs(usage.AttachmentMode, usage.AccessMode, opts)
+	if err != nil {
+		return nil, err
+	}
+
+	return capability, nil
+}
+
+func combineErrors(maybeErrs ...error) error {
+	var result *multierror.Error
+	for _, err := range maybeErrs {
+		if err == nil {
+			continue
+		}
+
+		result = multierror.Append(result, err)
+	}
+
+	return result.ErrorOrNil()
+}
+
+// serializedOp ensures that we only have one in-flight request per volume, and
+// keeps multi-step operations (ex. stage + publish) together in a single batch
+// rather than potentially interleaving
+func (v *volumeManager) serializedOp(ctx context.Context, volumeNS, volumeID string, fn func() error) error {
+
+	id := structs.NewNamespacedID(volumeID, volumeNS)
+
+	for {
+		v.inFlightLock.Lock()
+		future := v.inFlight[id]
+
+		if future == nil {
+			future, resolveFuture := context.WithCancel(ctx)
+			v.inFlight[id] = future
+			v.inFlightLock.Unlock()
+
+			err := fn()
+
+			// close the future while holding the lock and not in a defer so
+			// that we can ensure we've cleared it from the map before allowing
+			// anyone else to take the lock and write a new one
+			v.inFlightLock.Lock()
+			resolveFuture()
+			delete(v.inFlight, id)
+			v.inFlightLock.Unlock()
+
+			return err
+		} else {
+			v.inFlightLock.Unlock()
+
+			select {
+			case <-future.Done():
+				continue
+			case <-ctx.Done():
+				return nil // agent shutdown
+			}
+		}
+	}
+}
diff --git a/client/pluginmanager/csimanager/volume_test.go b/client/pluginmanager/csimanager/volume_test.go
index 1138d0e1f..6e35ffdb4 100644
--- a/client/pluginmanager/csimanager/volume_test.go
+++ b/client/pluginmanager/csimanager/volume_test.go
@@ -8,6 +8,7 @@ import (
 	"errors"
 	"os"
 	"runtime"
+	"sync"
 	"testing"
 	"time"
 
@@ -597,3 +598,93 @@ func TestVolumeManager_InterleavedStaging(t *testing.T) {
 	must.Eq(t, 1, csiFake.NodeUnpublishVolumeCallCount, must.Sprint("expected 1 unpublish call"))
 	must.Eq(t, 0, csiFake.NodeUnstageVolumeCallCount, must.Sprint("expected no unstage call"))
 }
+
+func TestVolumeManager_Serialization(t *testing.T) {
+	ci.Parallel(t)
+
+	tmpPath := t.TempDir()
+	csiFake := &csifake.Client{}
+
+	logger := testlog.HCLogger(t)
+
+	ctx := hclog.WithContext(t.Context(), logger)
+
+	manager := newVolumeManager(logger,
+		func(e *structs.NodeEvent) {}, csiFake,
+		tmpPath, tmpPath, true, "i-example")
+
+	ctx, cancel := context.WithTimeout(t.Context(), time.Second)
+	t.Cleanup(cancel)
+
+	// test that an operation on a volume can block another operation on the
+	// same volume
+	//
+	// we can't guarantee the goroutines will try to contend, so we'll force the
+	// op in the goroutine to wait until we've entered the serialized function,
+	// and then have the serialized function sleep. the wait + the op in the
+	// goroutine should take at least as long as that sleep to complete and
+	// return
+	var wg sync.WaitGroup
+	wg.Add(1)
+
+	elapsedCh := make(chan time.Duration)
+
+	go func() {
+		now := time.Now()
+		wg.Wait()
+		manager.serializedOp(ctx, "ns", "vol0", func() error {
+			return errors.New("two")
+		})
+		elapsedCh <- time.Since(now)
+	}()
+
+	manager.serializedOp(ctx, "ns", "vol0", func() error {
+		wg.Done()
+		time.Sleep(100 * time.Millisecond)
+		return errors.New("one")
+	})
+
+	must.GreaterEq(t, 100*time.Millisecond, <-elapsedCh)
+
+	// test that serialized ops for different volumes don't block each other
+
+	var wg1 sync.WaitGroup
+	var wg2 sync.WaitGroup
+	wg1.Add(1)
+	wg2.Add(1)
+	errs := make(chan error, 2)
+	go func() {
+		errs <- manager.serializedOp(ctx, "ns", "vol0", func() error {
+			// at this point we've entered the serialized op for vol0 and are
+			// waiting to enter the serialized op for vol1. if serialization
+			// blocks vol1's op, we'll never unblock here and will hit the
+			// timeout
+			wg1.Wait()
+			wg2.Done()
+			return errors.New("four")
+		})
+	}()
+
+	errs <- manager.serializedOp(ctx, "ns", "vol1", func() error {
+		wg1.Done() // unblock the first op
+		wg2.Wait() // wait for the first op to make sure we're running concurrently
+		return errors.New("five")
+	})
+
+	ctx2, cancel2 := context.WithTimeout(t.Context(), time.Second)
+	t.Cleanup(cancel2)
+
+	found := 0
+	for {
+		if found >= 2 {
+			break
+		}
+		select {
+		case <-errs:
+			found++
+		case <-ctx2.Done():
+			t.Fatal("timed out waiting for error")
+		}
+	}
+
+}
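For readers reviewing the approach rather than the patch itself, the per-volume serialization used by `serializedOp` can be sketched in isolation. The sketch below is illustrative only and is not part of the change: `keyedSerializer`, `do`, and the demo in `main` are invented names, and it returns `ctx.Err()` on caller cancellation where the patch deliberately returns nil (agent shutdown).

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// keyedSerializer allows only one in-flight operation per key. The first
// caller for a key stores a cancellable context in the map, runs its
// function, and cancels that context when done; later callers for the same
// key find the stored context, wait on Done(), and then retry.
type keyedSerializer struct {
	mu       sync.Mutex
	inFlight map[string]context.Context
}

func newKeyedSerializer() *keyedSerializer {
	return &keyedSerializer{inFlight: make(map[string]context.Context)}
}

func (s *keyedSerializer) do(ctx context.Context, key string, fn func() error) error {
	for {
		s.mu.Lock()
		future := s.inFlight[key]
		if future == nil {
			fut, resolve := context.WithCancel(ctx)
			s.inFlight[key] = fut
			s.mu.Unlock()

			err := fn() // run the operation without holding the lock

			// clear the map entry and wake waiters while holding the lock,
			// so a woken waiter cannot observe the stale entry
			s.mu.Lock()
			delete(s.inFlight, key)
			resolve()
			s.mu.Unlock()
			return err
		}
		s.mu.Unlock()

		select {
		case <-future.Done():
			continue // previous operation finished; try to claim the key
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

func main() {
	s := newKeyedSerializer()
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			_ = s.do(context.Background(), "vol0", func() error {
				fmt.Println("operation", i, "holds vol0")
				time.Sleep(50 * time.Millisecond)
				return nil
			})
		}(i)
	}
	wg.Wait()
}
```

The design choice mirrored here is that a cancellable context doubles as a per-key "future": waiters select on `Done()` and retry, which keeps multi-step operations (stage + publish) together for a volume without holding the lock during the plugin RPCs.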