From 4895d708b438b42e52fd54a128f9ec4cb6d72277 Mon Sep 17 00:00:00 2001 From: Daniel Bennett Date: Mon, 18 Sep 2023 10:30:15 -0500 Subject: [PATCH] csi: implement NodeExpandVolume (#18522) following ControllerExpandVolume in c6dbba7cde911bb08f1f8da445a44a0125cd2047, which expands the disk at e.g. a cloud vendor, the controller plugin may say that we also need to issue NodeExpandVolume for the node plugin to make the new disk space available to task(s) that have claims on the volume by e.g. expanding the filesystem on the node. csi spec: https://github.com/container-storage-interface/spec/blob/c918b7f/spec.md#nodeexpandvolume --- client/allocrunner/csi_hook_test.go | 5 + client/csi_endpoint.go | 39 +++++ client/csi_endpoint_test.go | 99 ++++++++++++ client/pluginmanager/csimanager/interface.go | 2 + client/pluginmanager/csimanager/testing.go | 73 +++++++++ client/pluginmanager/csimanager/volume.go | 30 ++++ client/structs/csi.go | 43 +++++ client/structs/csi_test.go | 45 ++++++ nomad/client_csi_endpoint.go | 36 ++++- nomad/client_csi_endpoint_test.go | 7 + nomad/csi_endpoint.go | 47 +++++- nomad/csi_endpoint_test.go | 107 ++++++++++++- plugins/csi/client.go | 34 +++- plugins/csi/client_test.go | 160 +++++++++++++++++++ plugins/csi/plugin.go | 41 +++-- plugins/csi/testing/client.go | 2 + 16 files changed, 744 insertions(+), 26 deletions(-) create mode 100644 client/pluginmanager/csimanager/testing.go create mode 100644 client/structs/csi_test.go diff --git a/client/allocrunner/csi_hook_test.go b/client/allocrunner/csi_hook_test.go index c1836a732..a0fd5aea7 100644 --- a/client/allocrunner/csi_hook_test.go +++ b/client/allocrunner/csi_hook_test.go @@ -22,6 +22,7 @@ import ( "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" "github.com/hashicorp/nomad/plugins/drivers" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" @@ -498,6 +499,10 @@ func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.Mo return mountInfo != nil && vm.hasMounts, nil } +func (vm mockVolumeManager) ExpandVolume(_ context.Context, _, _, _ string, _ *csimanager.UsageOptions, _ *csi.CapacityRange) (int64, error) { + return 0, nil +} + func (vm mockVolumeManager) ExternalID() string { return "i-example" } diff --git a/client/csi_endpoint.go b/client/csi_endpoint.go index 2d3ab673d..dcbcd70e5 100644 --- a/client/csi_endpoint.go +++ b/client/csi_endpoint.go @@ -537,6 +537,45 @@ func (c *CSI) NodeDetachVolume(req *structs.ClientCSINodeDetachVolumeRequest, re return nil } +// NodeExpandVolume instructs the node plugin to complete a volume expansion +// for a particular claim held by an allocation. +func (c *CSI) NodeExpandVolume(req *structs.ClientCSINodeExpandVolumeRequest, resp *structs.ClientCSINodeExpandVolumeResponse) error { + defer metrics.MeasureSince([]string{"client", "csi_node", "expand_volume"}, time.Now()) + + if err := req.Validate(); err != nil { + return err + } + usageOpts := &csimanager.UsageOptions{ + // Claim will not be nil here, per req.Validate() above. + ReadOnly: req.Claim.Mode == nstructs.CSIVolumeClaimRead, + AttachmentMode: req.Claim.AttachmentMode, + AccessMode: req.Claim.AccessMode, + } + + ctx, cancel := c.requestContext() // note: this has a 2-minute timeout + defer cancel() + + err := c.c.csimanager.WaitForPlugin(ctx, dynamicplugins.PluginTypeCSINode, req.PluginID) + if err != nil { + return err + } + + manager, err := c.c.csimanager.ManagerForPlugin(ctx, req.PluginID) + if err != nil { + return err + } + + newCapacity, err := manager.ExpandVolume(ctx, + req.VolumeID, req.ExternalID, req.Claim.AllocationID, usageOpts, req.Capacity) + + if err != nil && !errors.Is(err, nstructs.ErrCSIClientRPCIgnorable) { + return err + } + resp.CapacityBytes = newCapacity + + return nil +} + func (c *CSI) findControllerPlugin(name string) (csi.CSIPlugin, error) { return c.findPlugin(dynamicplugins.PluginTypeCSIController, name) } diff --git a/client/csi_endpoint_test.go b/client/csi_endpoint_test.go index a83b975ba..072a427da 100644 --- a/client/csi_endpoint_test.go +++ b/client/csi_endpoint_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/client/dynamicplugins" + "github.com/hashicorp/nomad/client/pluginmanager/csimanager" "github.com/hashicorp/nomad/client/structs" nstructs "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/csi" @@ -1069,3 +1070,101 @@ func TestCSINode_DetachVolume(t *testing.T) { }) } } + +func TestCSINode_ExpandVolume(t *testing.T) { + ci.Parallel(t) + + client, cleanup := TestClient(t, nil) + t.Cleanup(func() { test.NoError(t, cleanup()) }) + + cases := []struct { + Name string + ModRequest func(r *structs.ClientCSINodeExpandVolumeRequest) + ModManager func(m *csimanager.MockCSIManager) + ExpectErr error + }{ + { + Name: "success", + }, + { + Name: "invalid request", + ModRequest: func(r *structs.ClientCSINodeExpandVolumeRequest) { + r.Claim = nil + }, + ExpectErr: errors.New("Claim is required"), + }, + { + Name: "error waiting for plugin", + ModManager: func(m *csimanager.MockCSIManager) { + m.NextWaitForPluginErr = errors.New("sad plugin") + }, + ExpectErr: errors.New("sad plugin"), + }, + { + Name: "error from manager expand", + ModManager: func(m *csimanager.MockCSIManager) { + m.VM.NextExpandVolumeErr = errors.New("no expand, so sad") + }, + ExpectErr: errors.New("no expand, so sad"), + }, + { + Name: "ignorable error from manager expand", + ModManager: func(m *csimanager.MockCSIManager) { + m.VM.NextExpandVolumeErr = fmt.Errorf("%w: not found", nstructs.ErrCSIClientRPCIgnorable) + }, + ExpectErr: nil, // explicitly expecting no error + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + mockManager := &csimanager.MockCSIManager{ + VM: &csimanager.MockVolumeManager{}, + } + if tc.ModManager != nil { + tc.ModManager(mockManager) + } + client.csimanager = mockManager + + req := &structs.ClientCSINodeExpandVolumeRequest{ + PluginID: "fake-plug", + VolumeID: "fake-vol", + ExternalID: "fake-external", + Capacity: &csi.CapacityRange{ + RequiredBytes: 5, + }, + Claim: &nstructs.CSIVolumeClaim{ + // minimal claim to pass validation + AllocationID: "fake-alloc", + }, + } + if tc.ModRequest != nil { + tc.ModRequest(req) + } + + var resp structs.ClientCSINodeExpandVolumeResponse + err := client.ClientRPC("CSI.NodeExpandVolume", req, &resp) + + if tc.ExpectErr != nil { + test.EqError(t, tc.ExpectErr, err.Error()) + return + } + test.NoError(t, err) + + expect := csimanager.MockExpandVolumeCall{ + VolID: req.VolumeID, + RemoteID: req.ExternalID, + AllocID: req.Claim.AllocationID, + Capacity: req.Capacity, + UsageOpts: &csimanager.UsageOptions{ + ReadOnly: true, + }, + } + test.Eq(t, req.Capacity.RequiredBytes, resp.CapacityBytes) + test.NotNil(t, mockManager.VM.LastExpandVolumeCall) + test.Eq(t, &expect, mockManager.VM.LastExpandVolumeCall) + + }) + } +} diff --git a/client/pluginmanager/csimanager/interface.go b/client/pluginmanager/csimanager/interface.go index 5dde5a4c6..526df7515 100644 --- a/client/pluginmanager/csimanager/interface.go +++ b/client/pluginmanager/csimanager/interface.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/nomad/client/pluginmanager" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" ) type MountInfo struct { @@ -57,6 +58,7 @@ type VolumeManager interface { MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error HasMount(ctx context.Context, mountInfo *MountInfo) (bool, error) + ExpandVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) ExternalID() string } diff --git a/client/pluginmanager/csimanager/testing.go b/client/pluginmanager/csimanager/testing.go new file mode 100644 index 000000000..f27f74265 --- /dev/null +++ b/client/pluginmanager/csimanager/testing.go @@ -0,0 +1,73 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package csimanager + +import ( + "context" + + "github.com/hashicorp/nomad/client/pluginmanager" + nstructs "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/csi" +) + +var _ Manager = &MockCSIManager{} + +type MockCSIManager struct { + VM *MockVolumeManager + + NextWaitForPluginErr error + NextManagerForPluginErr error +} + +func (m *MockCSIManager) PluginManager() pluginmanager.PluginManager { + panic("implement me") +} + +func (m *MockCSIManager) WaitForPlugin(_ context.Context, pluginType, pluginID string) error { + return m.NextWaitForPluginErr +} + +func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) { + return m.VM, m.NextManagerForPluginErr +} + +func (m *MockCSIManager) Shutdown() { + panic("implement me") +} + +var _ VolumeManager = &MockVolumeManager{} + +type MockVolumeManager struct { + NextExpandVolumeErr error + LastExpandVolumeCall *MockExpandVolumeCall +} + +func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) { + panic("implement me") +} + +func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error { + panic("implement me") +} + +func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) { + panic("implement me") +} + +func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) { + m.LastExpandVolumeCall = &MockExpandVolumeCall{ + volID, remoteID, allocID, usageOpts, capacity, + } + return capacity.RequiredBytes, m.NextExpandVolumeErr +} + +type MockExpandVolumeCall struct { + VolID, RemoteID, AllocID string + UsageOpts *UsageOptions + Capacity *csi.CapacityRange +} + +func (m *MockVolumeManager) ExternalID() string { + return "mock-volume-manager" +} diff --git a/client/pluginmanager/csimanager/volume.go b/client/pluginmanager/csimanager/volume.go index b0611b306..f243e226b 100644 --- a/client/pluginmanager/csimanager/volume.go +++ b/client/pluginmanager/csimanager/volume.go @@ -383,6 +383,36 @@ func (v *volumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allo return err } +// ExpandVolume sends a NodeExpandVolume request to the node plugin +func (v *volumeManager) ExpandVolume(ctx context.Context, volID, remoteID, allocID string, usage *UsageOptions, capacity *csi.CapacityRange) (newCapacity int64, err error) { + capability, err := csi.VolumeCapabilityFromStructs(usage.AttachmentMode, usage.AccessMode, usage.MountOptions) + if err != nil { + // nil may be acceptable, so let the node plugin decide. + v.logger.Warn("ExpandVolume: unable to detect volume capability", + "volume_id", volID, "alloc_id", allocID, "error", err) + } + + req := &csi.NodeExpandVolumeRequest{ + ExternalVolumeID: remoteID, + CapacityRange: capacity, + Capability: capability, + TargetPath: v.targetForVolume(v.containerMountPoint, volID, allocID, usage), + StagingPath: v.stagingDirForVolume(v.containerMountPoint, volID, usage), + } + resp, err := v.plugin.NodeExpandVolume(ctx, req, + grpc_retry.WithPerRetryTimeout(DefaultMountActionTimeout), + grpc_retry.WithMax(3), + grpc_retry.WithBackoff(grpc_retry.BackoffExponential(100*time.Millisecond)), + ) + if err != nil { + return 0, err + } + if resp == nil { + return 0, errors.New("nil response from plugin.NodeExpandVolume") + } + return resp.CapacityBytes, nil +} + func (v *volumeManager) ExternalID() string { return v.externalNodeID } diff --git a/client/structs/csi.go b/client/structs/csi.go index 131071196..86f2812cb 100644 --- a/client/structs/csi.go +++ b/client/structs/csi.go @@ -4,6 +4,8 @@ package structs import ( + "errors" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/csi" ) @@ -452,3 +454,44 @@ type ClientCSINodeDetachVolumeRequest struct { } type ClientCSINodeDetachVolumeResponse struct{} + +// ClientCSINodeExpandVolumeRequest is the RPC made from the server to +// a Nomad client to tell a CSI node plugin on that client to perform +// NodeExpandVolume. +type ClientCSINodeExpandVolumeRequest struct { + PluginID string // ID of the plugin that manages the volume (required) + VolumeID string // ID of the volume to be expanded (required) + ExternalID string // External ID of the volume to be expanded (required) + + // Capacity range (required) to be sent to the node plugin + Capacity *csi.CapacityRange + + // Claim currently held for the allocation (required) + // used to determine capabilities and the mount point on the client + Claim *structs.CSIVolumeClaim +} + +func (req *ClientCSINodeExpandVolumeRequest) Validate() error { + var err error + // These should not occur during normal operations; they're here + // mainly to catch potential programmer error. + if req.PluginID == "" { + err = errors.Join(err, errors.New("PluginID is required")) + } + if req.VolumeID == "" { + err = errors.Join(err, errors.New("VolumeID is required")) + } + if req.ExternalID == "" { + err = errors.Join(err, errors.New("ExternalID is required")) + } + if req.Claim == nil { + err = errors.Join(err, errors.New("Claim is required")) + } else if req.Claim.AllocationID == "" { + err = errors.Join(err, errors.New("Claim.AllocationID is required")) + } + return err +} + +type ClientCSINodeExpandVolumeResponse struct { + CapacityBytes int64 +} diff --git a/client/structs/csi_test.go b/client/structs/csi_test.go new file mode 100644 index 000000000..117a230b5 --- /dev/null +++ b/client/structs/csi_test.go @@ -0,0 +1,45 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package structs + +import ( + "testing" + + "github.com/shoenig/test/must" + + "github.com/hashicorp/nomad/nomad/structs" +) + +func TestClientCSINodeExpandVolumeRequest_Validate(t *testing.T) { + req := &ClientCSINodeExpandVolumeRequest{ + PluginID: "plug-id", + VolumeID: "vol-id", + ExternalID: "ext-id", + Claim: &structs.CSIVolumeClaim{ + AllocationID: "alloc-id", + }, + } + err := req.Validate() + must.NoError(t, err) + + req.PluginID = "" + err = req.Validate() + must.ErrorContains(t, err, "PluginID is required") + + req.VolumeID = "" + err = req.Validate() + must.ErrorContains(t, err, "VolumeID is required") + + req.ExternalID = "" + err = req.Validate() + must.ErrorContains(t, err, "ExternalID is required") + + req.Claim.AllocationID = "" + err = req.Validate() + must.ErrorContains(t, err, "Claim.AllocationID is required") + + req.Claim = nil + err = req.Validate() + must.ErrorContains(t, err, "Claim is required") +} diff --git a/nomad/client_csi_endpoint.go b/nomad/client_csi_endpoint.go index f1c531ce0..7965d9b1f 100644 --- a/nomad/client_csi_endpoint.go +++ b/nomad/client_csi_endpoint.go @@ -224,12 +224,34 @@ func (a *ClientCSI) isRetryable(err error) bool { func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error { defer metrics.MeasureSince([]string{"nomad", "client_csi_node", "detach_volume"}, time.Now()) + return a.sendCSINodeRPC( + args.NodeID, + "CSI.NodeDetachVolume", + "ClientCSI.NodeDetachVolume", + structs.RateMetricWrite, + args, + reply, + ) +} +func (a *ClientCSI) NodeExpandVolume(args *cstructs.ClientCSINodeExpandVolumeRequest, reply *cstructs.ClientCSINodeExpandVolumeResponse) error { + defer metrics.MeasureSince([]string{"nomad", "client_csi_node", "expand_volume"}, time.Now()) + return a.sendCSINodeRPC( + args.Claim.NodeID, + "CSI.NodeExpandVolume", + "ClientCSI.NodeExpandVolume", + structs.RateMetricWrite, + args, + reply, + ) +} + +func (a *ClientCSI) sendCSINodeRPC(nodeID, method, fwdMethod, op string, args any, reply any) error { // client requests aren't RequestWithIdentity, so we use a placeholder here // to populate the identity data for metrics identityReq := &structs.GenericRequest{} authErr := a.srv.Authenticate(a.ctx, identityReq) - a.srv.MeasureRPCRate("client_csi", structs.RateMetricWrite, identityReq) + a.srv.MeasureRPCRate("client_csi", op, identityReq) // only servers can send these client RPCs err := validateTLSCertificateLevel(a.srv, a.ctx, tlsCertificateLevelServer) @@ -243,24 +265,22 @@ func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeReq return err } - _, err = getNodeForRpc(snap, args.NodeID) + _, err = getNodeForRpc(snap, nodeID) if err != nil { return err } // Get the connection to the client - state, ok := a.srv.getNodeConn(args.NodeID) + state, ok := a.srv.getNodeConn(nodeID) if !ok { - return findNodeConnAndForward(a.srv, args.NodeID, "ClientCSI.NodeDetachVolume", args, reply) + return findNodeConnAndForward(a.srv, nodeID, fwdMethod, args, reply) } // Make the RPC - err = NodeRpc(state.Session, "CSI.NodeDetachVolume", args, reply) - if err != nil { - return fmt.Errorf("node detach volume: %v", err) + if err := NodeRpc(state.Session, method, args, reply); err != nil { + return fmt.Errorf("%s error: %w", method, err) } return nil - } // clientIDsForController returns a shuffled list of client IDs where the diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index 74c4adb05..037e75358 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -45,6 +45,8 @@ type MockClientCSI struct { NextControllerExpandVolumeError error NextControllerExpandVolumeResponse *cstructs.ClientCSIControllerExpandVolumeResponse NextNodeDetachError error + NextNodeExpandError error + LastNodeExpandRequest *cstructs.ClientCSINodeExpandVolumeRequest } func newMockClientCSI() *MockClientCSI { @@ -108,6 +110,11 @@ func (c *MockClientCSI) NodeDetachVolume(req *cstructs.ClientCSINodeDetachVolume return c.NextNodeDetachError } +func (c *MockClientCSI) NodeExpandVolume(req *cstructs.ClientCSINodeExpandVolumeRequest, resp *cstructs.ClientCSINodeExpandVolumeResponse) error { + c.LastNodeExpandRequest = req + return c.NextNodeExpandError +} + func TestClientCSIController_AttachVolume_Local(t *testing.T) { ci.Parallel(t) require := require.New(t) diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index 705982064..3ffb41e0e 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -1272,12 +1272,57 @@ func (v *CSIVolume) expandVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug logger.Info("controller done expanding volume") if cResp.NodeExpansionRequired { - v.logger.Warn("TODO: also do node volume expansion if needed") // TODO + return v.nodeExpandVolume(vol, plugin, capacity) } return nil } +// nodeExpandVolume sends NodeExpandVolume requests to the appropriate client +// for each allocation that has a claim on the volume. The client will then +// send a gRPC call to the CSI node plugin colocated with the allocation. +func (v *CSIVolume) nodeExpandVolume(vol *structs.CSIVolume, plugin *structs.CSIPlugin, capacity *csi.CapacityRange) error { + var mErr multierror.Error + logger := v.logger.Named("nodeExpandVolume"). + With("volume", vol.ID, "plugin", plugin.ID) + + expand := func(claim *structs.CSIVolumeClaim) { + if claim == nil { + return + } + + logger.Debug("starting volume expansion on node", + "node_id", claim.NodeID, "alloc_id", claim.AllocationID) + + resp := &cstructs.ClientCSINodeExpandVolumeResponse{} + req := &cstructs.ClientCSINodeExpandVolumeRequest{ + PluginID: plugin.ID, + VolumeID: vol.ID, + ExternalID: vol.ExternalID, + Capacity: capacity, + Claim: claim, + } + if err := v.srv.RPC("ClientCSI.NodeExpandVolume", req, resp); err != nil { + mErr.Errors = append(mErr.Errors, err) + } + + if resp.CapacityBytes != vol.Capacity { + // not necessarily an error, but maybe notable + logger.Warn("unexpected capacity from NodeExpandVolume", + "expected", vol.Capacity, "resp", resp.CapacityBytes) + } + } + + for _, claim := range vol.ReadClaims { + expand(claim) + } + for _, claim := range vol.WriteClaims { + expand(claim) + } + + return mErr.ErrorOrNil() +} + func (v *CSIVolume) Delete(args *structs.CSIVolumeDeleteRequest, reply *structs.CSIVolumeDeleteResponse) error { authErr := v.srv.Authenticate(v.ctx, args) diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index dce281428..ebd97724a 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -4,6 +4,7 @@ package nomad import ( + "errors" "fmt" "strings" "sync" @@ -1889,7 +1890,8 @@ func TestCSIVolume_expandVolume(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { fake.NextControllerExpandVolumeResponse = &cstructs.ClientCSIControllerExpandVolumeResponse{ - CapacityBytes: tc.ControllerResp, + CapacityBytes: tc.ControllerResp, + // this also exercises some node expand code, incidentally NodeExpansionRequired: true, } @@ -1914,6 +1916,81 @@ func TestCSIVolume_expandVolume(t *testing.T) { }) } + // a nodeExpandVolume error should fail expandVolume too + t.Run("node error", func(t *testing.T) { + expect := "sad node expand" + fake.NextNodeExpandError = errors.New(expect) + fake.NextControllerExpandVolumeResponse = &cstructs.ClientCSIControllerExpandVolumeResponse{ + CapacityBytes: 2000, + NodeExpansionRequired: true, + } + err = endpoint.expandVolume(vol, plug, &csi.CapacityRange{ + RequiredBytes: 2000, + }) + test.ErrorContains(t, err, expect) + }) + +} + +func TestCSIVolume_nodeExpandVolume(t *testing.T) { + ci.Parallel(t) + + srv, cleanupSrv := TestServer(t, nil) + t.Cleanup(cleanupSrv) + testutil.WaitForLeader(t, srv.RPC) + t.Log("server started 👍") + + c, fake, _, fakeVolID := testClientWithCSI(t, srv) + fakeClaim := fakeCSIClaim(c.NodeID()) + + endpoint := NewCSIVolumeEndpoint(srv, nil) + plug, vol, err := endpoint.volAndPluginLookup(structs.DefaultNamespace, fakeVolID) + must.NoError(t, err) + + // there's not a lot of logic here -- validation has been done prior, + // in (controller) expandVolume and what preceeds it. + cases := []struct { + Name string + Error error + }{ + { + Name: "ok", + }, + { + Name: "not ok", + Error: errors.New("test node expand fail"), + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + fake.NextNodeExpandError = tc.Error + capacity := &csi.CapacityRange{ + RequiredBytes: 10, + LimitBytes: 10, + } + + err = endpoint.nodeExpandVolume(vol, plug, capacity) + + if tc.Error == nil { + test.NoError(t, err) + } else { + must.Error(t, err) + must.ErrorContains(t, err, + fmt.Sprintf("CSI.NodeExpandVolume error: %s", tc.Error)) + } + + req := fake.LastNodeExpandRequest + must.NotNil(t, req, must.Sprint("request should have happened")) + test.Eq(t, fakeVolID, req.VolumeID) + test.Eq(t, capacity, req.Capacity) + test.Eq(t, "fake-csi-plugin", req.PluginID) + test.Eq(t, "fake-csi-external-id", req.ExternalID) + test.Eq(t, fakeClaim, req.Claim) + + }) + } } func TestCSIPluginEndpoint_RegisterViaFingerprint(t *testing.T) { @@ -2266,8 +2343,8 @@ func testClientWithCSI(t *testing.T, srv *Server) (c *client.Client, m *MockClie t.Helper() m = newMockClientCSI() - plugID = "fake-plugin" - volID = "fake-volume" + plugID = "fake-csi-plugin" + volID = "fake-csi-volume" c, cleanup := client.TestClientWithRPCs(t, func(c *cconfig.Config) { @@ -2316,15 +2393,19 @@ func testClientWithCSI(t *testing.T, srv *Server) (c *client.Client, m *MockClie // Register a minimum-viable fake volume req := &structs.CSIVolumeRegisterRequest{ Volumes: []*structs.CSIVolume{{ - PluginID: plugID, - ID: volID, - Namespace: structs.DefaultNamespace, + PluginID: plugID, + ID: volID, + ExternalID: "fake-csi-external-id", + Namespace: structs.DefaultNamespace, RequestedCapabilities: []*structs.CSIVolumeCapability{ { - AccessMode: structs.CSIVolumeAccessModeMultiNodeMultiWriter, + AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, }, }, + WriteClaims: map[string]*structs.CSIVolumeClaim{ + "fake-csi-claim": fakeCSIClaim(c.NodeID()), + }, }}, WriteRequest: structs.WriteRequest{Region: srv.Region()}, } @@ -2333,3 +2414,15 @@ func testClientWithCSI(t *testing.T, srv *Server) (c *client.Client, m *MockClie return c, m, plugID, volID } + +func fakeCSIClaim(nodeID string) *structs.CSIVolumeClaim { + return &structs.CSIVolumeClaim{ + NodeID: nodeID, + AllocationID: "fake-csi-alloc", + ExternalNodeID: "fake-csi-external-node", + Mode: structs.CSIVolumeClaimWrite, + AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, + AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, + State: structs.CSIVolumeClaimStateTaken, + } +} diff --git a/plugins/csi/client.go b/plugins/csi/client.go index 0f101f4c5..736289e9c 100644 --- a/plugins/csi/client.go +++ b/plugins/csi/client.go @@ -926,5 +926,37 @@ func (c *client) NodeUnpublishVolume(ctx context.Context, volumeID, targetPath s } func (c *client) NodeExpandVolume(ctx context.Context, req *NodeExpandVolumeRequest, opts ...grpc.CallOption) (*NodeExpandVolumeResponse, error) { - return nil, nil + if err := req.Validate(); err != nil { + return nil, err + } + if err := c.ensureConnected(ctx); err != nil { + return nil, err + } + + exReq := req.ToCSIRepresentation() + resp, err := c.nodeClient.NodeExpandVolume(ctx, exReq, opts...) + if err != nil { + code := status.Code(err) + switch code { + case codes.InvalidArgument: + return nil, fmt.Errorf( + "requested capabilities not compatible with volume %q: %v", + req.ExternalVolumeID, err) + case codes.NotFound: + return nil, fmt.Errorf("%w: volume %q could not be found: %v", + structs.ErrCSIClientRPCIgnorable, req.ExternalVolumeID, err) + case codes.FailedPrecondition: + return nil, fmt.Errorf("volume %q cannot be expanded while in use: %v", req.ExternalVolumeID, err) + case codes.OutOfRange: + return nil, fmt.Errorf( + "unsupported capacity_range for volume %q: %v", req.ExternalVolumeID, err) + case codes.Internal: + return nil, fmt.Errorf( + "node plugin returned an internal error, check the plugin allocation logs for more information: %v", err) + default: + return nil, fmt.Errorf("node plugin returned an error: %v", err) + } + } + + return &NodeExpandVolumeResponse{resp.GetCapacityBytes()}, nil } diff --git a/plugins/csi/client_test.go b/plugins/csi/client_test.go index 6947a953e..45bfb4783 100644 --- a/plugins/csi/client_test.go +++ b/plugins/csi/client_test.go @@ -1518,3 +1518,163 @@ func TestClient_RPC_NodeUnpublishVolume(t *testing.T) { }) } } + +func TestClient_RPC_NodeExpandVolume(t *testing.T) { + // minimum valid request + minRequest := &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-vol", + TargetPath: "/test-path", + } + + cases := []struct { + Name string + Request *NodeExpandVolumeRequest + ExpectCall *csipbv1.NodeExpandVolumeRequest + ResponseErr error + ExpectedErr error + }{ + { + Name: "success min", + Request: minRequest, + ExpectCall: &csipbv1.NodeExpandVolumeRequest{ + VolumeId: "test-vol", + VolumePath: "/test-path", + }, + }, + { + Name: "success full", + Request: &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-vol", + TargetPath: "/test-path", + StagingPath: "/test-staging-path", + CapacityRange: &CapacityRange{ + RequiredBytes: 5, + LimitBytes: 10, + }, + Capability: &VolumeCapability{ + AccessType: VolumeAccessTypeMount, + AccessMode: VolumeAccessModeMultiNodeSingleWriter, + MountVolume: &structs.CSIMountOptions{ + FSType: "test-fstype", + MountFlags: []string{"test-flags"}, + }, + }, + }, + ExpectCall: &csipbv1.NodeExpandVolumeRequest{ + VolumeId: "test-vol", + VolumePath: "/test-path", + StagingTargetPath: "/test-staging-path", + CapacityRange: &csipbv1.CapacityRange{ + RequiredBytes: 5, + LimitBytes: 10, + }, + VolumeCapability: &csipbv1.VolumeCapability{ + AccessType: &csipbv1.VolumeCapability_Mount{ + Mount: &csipbv1.VolumeCapability_MountVolume{ + FsType: "test-fstype", + MountFlags: []string{"test-flags"}, + VolumeMountGroup: "", + }}, + AccessMode: &csipbv1.VolumeCapability_AccessMode{ + Mode: csipbv1.VolumeCapability_AccessMode_MULTI_NODE_SINGLE_WRITER}, + }, + }, + }, + + { + Name: "validate missing volume id", + Request: &NodeExpandVolumeRequest{ + TargetPath: "/test-path", + }, + ExpectedErr: errors.New("ExternalVolumeID is required"), + }, + { + Name: "validate missing target path", + Request: &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-volume", + }, + ExpectedErr: errors.New("TargetPath is required"), + }, + { + Name: "validate min greater than max", + Request: &NodeExpandVolumeRequest{ + ExternalVolumeID: "test-vol", + TargetPath: "/test-path", + CapacityRange: &CapacityRange{ + RequiredBytes: 4, + LimitBytes: 2, + }, + }, + ExpectedErr: errors.New("LimitBytes cannot be less than RequiredBytes"), + }, + + { + Name: "grpc error default case", + Request: minRequest, + ResponseErr: status.Errorf(codes.DataLoss, "misc unspecified error"), + ExpectedErr: errors.New("node plugin returned an error: rpc error: code = DataLoss desc = misc unspecified error"), + }, + { + Name: "grpc error invalid argument", + Request: minRequest, + ResponseErr: status.Errorf(codes.InvalidArgument, "sad args"), + ExpectedErr: errors.New("requested capabilities not compatible with volume \"test-vol\": rpc error: code = InvalidArgument desc = sad args"), + }, + { + Name: "grpc error NotFound", + Request: minRequest, + ResponseErr: status.Errorf(codes.NotFound, "does not exist"), + ExpectedErr: errors.New("CSI client error (ignorable): volume \"test-vol\" could not be found: rpc error: code = NotFound desc = does not exist"), + }, + { + Name: "grpc error FailedPrecondition", + Request: minRequest, + ResponseErr: status.Errorf(codes.FailedPrecondition, "unsupported"), + ExpectedErr: errors.New("volume \"test-vol\" cannot be expanded while in use: rpc error: code = FailedPrecondition desc = unsupported"), + }, + { + Name: "grpc error OutOfRange", + Request: minRequest, + ResponseErr: status.Errorf(codes.OutOfRange, "too small"), + ExpectedErr: errors.New("unsupported capacity_range for volume \"test-vol\": rpc error: code = OutOfRange desc = too small"), + }, + { + Name: "grpc error Internal", + Request: minRequest, + ResponseErr: status.Errorf(codes.Internal, "some grpc error"), + ExpectedErr: errors.New("node plugin returned an internal error, check the plugin allocation logs for more information: rpc error: code = Internal desc = some grpc error"), + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + _, _, nc, client := newTestClient(t) + + nc.NextErr = tc.ResponseErr + // the fake client should take ~no time, but set a timeout just in case + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + resp, err := client.NodeExpandVolume(ctx, tc.Request) + if tc.ExpectedErr != nil { + must.EqError(t, err, tc.ExpectedErr.Error()) + return + } + must.NoError(t, err) + must.NotNil(t, resp) + must.Eq(t, tc.ExpectCall, nc.LastExpandVolumeRequest) + + }) + } + + t.Run("connection error", func(t *testing.T) { + c := &client{} // induce c.ensureConnected() error + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + resp, err := c.NodeExpandVolume(ctx, &NodeExpandVolumeRequest{ + ExternalVolumeID: "valid-id", + TargetPath: "/some-path", + }) + must.Nil(t, resp) + must.EqError(t, err, "address is empty") + }) +} diff --git a/plugins/csi/plugin.go b/plugins/csi/plugin.go index df3c9c4fd..e4f5155df 100644 --- a/plugins/csi/plugin.go +++ b/plugins/csi/plugin.go @@ -1020,6 +1020,19 @@ type CapacityRange struct { LimitBytes int64 } +func (c *CapacityRange) Validate() error { + if c == nil { + return nil + } + if c.RequiredBytes == 0 && c.LimitBytes == 0 { + return errors.New("either RequiredBytes or LimitBytes must be set") + } + if c.LimitBytes > 0 && c.LimitBytes < c.RequiredBytes { + return errors.New("LimitBytes cannot be less than RequiredBytes") + } + return nil +} + func (c *CapacityRange) ToCSIRepresentation() *csipbv1.CapacityRange { if c == nil { return nil @@ -1032,11 +1045,24 @@ func (c *CapacityRange) ToCSIRepresentation() *csipbv1.CapacityRange { type NodeExpandVolumeRequest struct { ExternalVolumeID string - RequiredBytes int64 - LimitBytes int64 + CapacityRange *CapacityRange + Capability *VolumeCapability TargetPath string StagingPath string - Capability *VolumeCapability +} + +func (r *NodeExpandVolumeRequest) Validate() error { + var err error + if r.ExternalVolumeID == "" { + err = errors.Join(err, errors.New("ExternalVolumeID is required")) + } + if r.TargetPath == "" { + err = errors.Join(err, errors.New("TargetPath is required")) + } + if e := r.CapacityRange.Validate(); e != nil { + err = errors.Join(err, e) + } + return err } func (r *NodeExpandVolumeRequest) ToCSIRepresentation() *csipbv1.NodeExpandVolumeRequest { @@ -1044,13 +1070,10 @@ func (r *NodeExpandVolumeRequest) ToCSIRepresentation() *csipbv1.NodeExpandVolum return nil } return &csipbv1.NodeExpandVolumeRequest{ - VolumeId: r.ExternalVolumeID, - VolumePath: r.TargetPath, - CapacityRange: &csipbv1.CapacityRange{ - RequiredBytes: r.RequiredBytes, - LimitBytes: r.LimitBytes, - }, + VolumeId: r.ExternalVolumeID, + VolumePath: r.TargetPath, StagingTargetPath: r.StagingPath, + CapacityRange: r.CapacityRange.ToCSIRepresentation(), VolumeCapability: r.Capability.ToCSIRepresentation(), } } diff --git a/plugins/csi/testing/client.go b/plugins/csi/testing/client.go index 7595d79f2..64e9cf8b9 100644 --- a/plugins/csi/testing/client.go +++ b/plugins/csi/testing/client.go @@ -150,6 +150,7 @@ type NodeClient struct { NextPublishVolumeResponse *csipbv1.NodePublishVolumeResponse NextUnpublishVolumeResponse *csipbv1.NodeUnpublishVolumeResponse NextExpandVolumeResponse *csipbv1.NodeExpandVolumeResponse + LastExpandVolumeRequest *csipbv1.NodeExpandVolumeRequest } // NewNodeClient returns a new stub NodeClient @@ -193,5 +194,6 @@ func (c *NodeClient) NodeUnpublishVolume(ctx context.Context, in *csipbv1.NodeUn } func (c *NodeClient) NodeExpandVolume(ctx context.Context, in *csipbv1.NodeExpandVolumeRequest, opts ...grpc.CallOption) (*csipbv1.NodeExpandVolumeResponse, error) { + c.LastExpandVolumeRequest = in return c.NextExpandVolumeResponse, c.NextErr }