diff --git a/client/allocrunner/taskrunner/volume_hook.go b/client/allocrunner/taskrunner/volume_hook.go index 3bd3e6e96..d6a8ffbc5 100644 --- a/client/allocrunner/taskrunner/volume_hook.go +++ b/client/allocrunner/taskrunner/volume_hook.go @@ -122,8 +122,50 @@ func (h *volumeHook) prepareHostVolumes(volumes map[string]*structs.VolumeReques return hostVolumeMounts, nil } -func (h *volumeHook) prepareCSIVolumes(req *interfaces.TaskPrestartRequest) ([]*drivers.MountConfig, error) { - return nil, nil +// partitionMountsByVolume takes a list of volume mounts and returns them in the +// form of volume-alias:[]volume-mount because one volume may be mounted multiple +// times. +func partitionMountsByVolume(xs []*structs.VolumeMount) map[string][]*structs.VolumeMount { + result := make(map[string][]*structs.VolumeMount) + for _, mount := range xs { + result[mount.Volume] = append(result[mount.Volume], mount) + } + + return result +} + +func (h *volumeHook) prepareCSIVolumes(req *interfaces.TaskPrestartRequest, volumes map[string]*structs.VolumeRequest) ([]*drivers.MountConfig, error) { + if len(volumes) == 0 { + return nil, nil + } + + var mounts []*drivers.MountConfig + + mountRequests := partitionMountsByVolume(req.Task.VolumeMounts) + csiMountPoints := h.runner.allocHookResources.GetCSIMounts() + for alias, request := range volumes { + mountsForAlias, ok := mountRequests[alias] + if !ok { + // This task doesn't use the volume + continue + } + + csiMountPoint, ok := csiMountPoints[alias] + if !ok { + return nil, fmt.Errorf("No CSI Mount Point found for volume: %s", alias) + } + + for _, m := range mountsForAlias { + mcfg := &drivers.MountConfig{ + HostPath: csiMountPoint.Source, + TaskPath: m.Destination, + Readonly: request.ReadOnly || m.ReadOnly, + } + mounts = append(mounts, mcfg) + } + } + + return mounts, nil } func (h *volumeHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { @@ -134,7 +176,7 @@ func (h *volumeHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartR return err } - csiVolumeMounts, err := h.prepareCSIVolumes(req) + csiVolumeMounts, err := h.prepareCSIVolumes(req, volumes[structs.VolumeTypeCSI]) if err != nil { return err } diff --git a/client/allocrunner/taskrunner/volume_hook_test.go b/client/allocrunner/taskrunner/volume_hook_test.go new file mode 100644 index 000000000..8c0e924fb --- /dev/null +++ b/client/allocrunner/taskrunner/volume_hook_test.go @@ -0,0 +1,111 @@ +package taskrunner + +import ( + "testing" + + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/pluginmanager/csimanager" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/drivers" + "github.com/stretchr/testify/require" +) + +func TestVolumeHook_PartitionMountsByVolume_Works(t *testing.T) { + mounts := []*structs.VolumeMount{ + { + Volume: "foo", + Destination: "/tmp", + ReadOnly: false, + }, + { + Volume: "foo", + Destination: "/bar", + ReadOnly: false, + }, + { + Volume: "baz", + Destination: "/baz", + ReadOnly: false, + }, + } + + expected := map[string][]*structs.VolumeMount{ + "foo": { + { + Volume: "foo", + Destination: "/tmp", + ReadOnly: false, + }, + { + Volume: "foo", + Destination: "/bar", + ReadOnly: false, + }, + }, + "baz": { + { + Volume: "baz", + Destination: "/baz", + ReadOnly: false, + }, + }, + } + + // Test with a real collection + + partitioned := partitionMountsByVolume(mounts) + require.Equal(t, expected, partitioned) + + // Test with nil/emptylist + + partitioned = partitionMountsByVolume(nil) + require.Equal(t, map[string][]*structs.VolumeMount{}, partitioned) +} + +func TestVolumeHook_prepareCSIVolumes(t *testing.T) { + req := &interfaces.TaskPrestartRequest{ + Task: &structs.Task{ + VolumeMounts: []*structs.VolumeMount{ + { + Volume: "foo", + Destination: "/bar", + }, + }, + }, + } + + volumes := map[string]*structs.VolumeRequest{ + "foo": { + Type: "csi", + Source: "my-test-volume", + }, + } + + tr := &TaskRunner{ + allocHookResources: &cstructs.AllocHookResources{ + CSIMounts: map[string]*csimanager.MountInfo{ + "foo": &csimanager.MountInfo{ + Source: "/mnt/my-test-volume", + }, + }, + }, + } + + expected := []*drivers.MountConfig{ + { + HostPath: "/mnt/my-test-volume", + TaskPath: "/bar", + }, + } + + hook := &volumeHook{ + logger: testlog.HCLogger(t), + alloc: structs.MockAlloc(), + runner: tr, + } + mounts, err := hook.prepareCSIVolumes(req, volumes) + require.NoError(t, err) + require.Equal(t, expected, mounts) +}