diff --git a/plugins/csi/client.go b/plugins/csi/client.go index dd650d96b..8d45a7f71 100644 --- a/plugins/csi/client.go +++ b/plugins/csi/client.go @@ -60,11 +60,18 @@ type CSIControllerClient interface { ValidateVolumeCapabilities(ctx context.Context, in *csipbv1.ValidateVolumeCapabilitiesRequest, opts ...grpc.CallOption) (*csipbv1.ValidateVolumeCapabilitiesResponse, error) } +// CSINodeClient defines the minimal CSI Node Plugin interface used +// by nomad to simplify the interface required for testing. +type CSINodeClient interface { + NodeGetCapabilities(ctx context.Context, in *csipbv1.NodeGetCapabilitiesRequest, opts ...grpc.CallOption) (*csipbv1.NodeGetCapabilitiesResponse, error) + NodeGetInfo(ctx context.Context, in *csipbv1.NodeGetInfoRequest, opts ...grpc.CallOption) (*csipbv1.NodeGetInfoResponse, error) +} + type client struct { conn *grpc.ClientConn identityClient csipbv1.IdentityClient controllerClient CSIControllerClient - nodeClient csipbv1.NodeClient + nodeClient CSINodeClient } func (c *client) Close() error { @@ -243,6 +250,22 @@ func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPub // Node Endpoints // +func (c *client) NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, error) { + if c == nil { + return nil, fmt.Errorf("Client not initialized") + } + if c.nodeClient == nil { + return nil, fmt.Errorf("Client not initialized") + } + + resp, err := c.nodeClient.NodeGetCapabilities(ctx, &csipbv1.NodeGetCapabilitiesRequest{}) + if err != nil { + return nil, err + } + + return NewNodeCapabilitySet(resp), nil +} + func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error) { if c == nil { return nil, fmt.Errorf("Client not initialized") diff --git a/plugins/csi/client_test.go b/plugins/csi/client_test.go index 84973b308..1d372ad34 100644 --- a/plugins/csi/client_test.go +++ b/plugins/csi/client_test.go @@ -11,15 +11,17 @@ import ( "github.com/stretchr/testify/require" ) -func newTestClient() (*fake.IdentityClient, *fake.ControllerClient, CSIPlugin) { - ic := &fake.IdentityClient{} - cc := &fake.ControllerClient{} +func newTestClient() (*fake.IdentityClient, *fake.ControllerClient, *fake.NodeClient, CSIPlugin) { + ic := fake.NewIdentityClient() + cc := fake.NewControllerClient() + nc := fake.NewNodeClient() client := &client{ identityClient: ic, controllerClient: cc, + nodeClient: nc, } - return ic, cc, client + return ic, cc, nc, client } func TestClient_RPC_PluginProbe(t *testing.T) { @@ -63,7 +65,7 @@ func TestClient_RPC_PluginProbe(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - ic, _, client := newTestClient() + ic, _, _, client := newTestClient() defer client.Close() ic.NextErr = c.ResponseErr @@ -111,7 +113,7 @@ func TestClient_RPC_PluginInfo(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - ic, _, client := newTestClient() + ic, _, _, client := newTestClient() defer client.Close() ic.NextErr = c.ResponseErr @@ -175,7 +177,7 @@ func TestClient_RPC_PluginGetCapabilities(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - ic, _, client := newTestClient() + ic, _, _, client := newTestClient() defer client.Close() ic.NextErr = c.ResponseErr @@ -273,7 +275,7 @@ func TestClient_RPC_ControllerGetCapabilities(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { - _, cc, client := newTestClient() + _, cc, _, client := newTestClient() defer client.Close() cc.NextErr = tc.ResponseErr @@ -289,6 +291,71 @@ func TestClient_RPC_ControllerGetCapabilities(t *testing.T) { } } +func TestClient_RPC_NodeGetCapabilities(t *testing.T) { + cases := []struct { + Name string + ResponseErr error + Response *csipbv1.NodeGetCapabilitiesResponse + ExpectedResponse *NodeCapabilitySet + ExpectedErr error + }{ + { + Name: "handles underlying grpc errors", + ResponseErr: fmt.Errorf("some grpc error"), + ExpectedErr: fmt.Errorf("some grpc error"), + }, + { + Name: "ignores unknown capabilities", + Response: &csipbv1.NodeGetCapabilitiesResponse{ + Capabilities: []*csipbv1.NodeServiceCapability{ + { + Type: &csipbv1.NodeServiceCapability_Rpc{ + Rpc: &csipbv1.NodeServiceCapability_RPC{ + Type: csipbv1.NodeServiceCapability_RPC_EXPAND_VOLUME, + }, + }, + }, + }, + }, + ExpectedResponse: &NodeCapabilitySet{}, + }, + { + Name: "detects stage volumes capability", + Response: &csipbv1.NodeGetCapabilitiesResponse{ + Capabilities: []*csipbv1.NodeServiceCapability{ + { + Type: &csipbv1.NodeServiceCapability_Rpc{ + Rpc: &csipbv1.NodeServiceCapability_RPC{ + Type: csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME, + }, + }, + }, + }, + }, + ExpectedResponse: &NodeCapabilitySet{ + HasStageUnstageVolume: true, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + _, _, nc, client := newTestClient() + defer client.Close() + + nc.NextErr = tc.ResponseErr + nc.NextCapabilitiesResponse = tc.Response + + resp, err := client.NodeGetCapabilities(context.TODO()) + if tc.ExpectedErr != nil { + require.Error(t, tc.ExpectedErr, err) + } + + require.Equal(t, tc.ExpectedResponse, resp) + }) + } +} + func TestClient_RPC_ControllerPublishVolume(t *testing.T) { cases := []struct { Name string @@ -326,7 +393,7 @@ func TestClient_RPC_ControllerPublishVolume(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - _, cc, client := newTestClient() + _, cc, _, client := newTestClient() defer client.Close() cc.NextErr = c.ResponseErr diff --git a/plugins/csi/fake/client.go b/plugins/csi/fake/client.go index b2372906d..eb3d2c79b 100644 --- a/plugins/csi/fake/client.go +++ b/plugins/csi/fake/client.go @@ -43,6 +43,10 @@ type Client struct { NextControllerPublishVolumeErr error ControllerPublishVolumeCallCount int64 + NextNodeGetCapabilitiesResponse *csi.NodeCapabilitySet + NextNodeGetCapabilitiesErr error + NodeGetCapabilitiesCallCount int64 + NextNodeGetInfoResponse *csi.NodeGetInfoResponse NextNodeGetInfoErr error NodeGetInfoCallCount int64 @@ -122,6 +126,15 @@ func (c *Client) ControllerPublishVolume(ctx context.Context, req *csi.Controlle return c.NextControllerPublishVolumeResponse, c.NextControllerPublishVolumeErr } +func (c *Client) NodeGetCapabilities(ctx context.Context) (*csi.NodeCapabilitySet, error) { + c.Mu.Lock() + defer c.Mu.Unlock() + + c.NodeGetCapabilitiesCallCount++ + + return c.NextNodeGetCapabilitiesResponse, c.NextNodeGetCapabilitiesErr +} + // NodeGetInfo is used to return semantic data about the current node in // respect to the SP. func (c *Client) NodeGetInfo(ctx context.Context) (*csi.NodeGetInfoResponse, error) { diff --git a/plugins/csi/plugin.go b/plugins/csi/plugin.go index 56a813a28..7273c65f1 100644 --- a/plugins/csi/plugin.go +++ b/plugins/csi/plugin.go @@ -33,6 +33,10 @@ type CSIPlugin interface { // ControllerPublishVolume is used to attach a remote volume to a cluster node. ControllerPublishVolume(ctx context.Context, req *ControllerPublishVolumeRequest) (*ControllerPublishVolumeResponse, error) + // NodeGetCapabilities is used to return the available capabilities from the + // Node Service. + NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, error) + // NodeGetInfo is used to return semantic data about the current node in // respect to the SP. NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error) @@ -133,3 +137,24 @@ type ControllerPublishVolumeRequest struct { type ControllerPublishVolumeResponse struct { PublishContext map[string]string } + +type NodeCapabilitySet struct { + HasStageUnstageVolume bool +} + +func NewNodeCapabilitySet(resp *csipbv1.NodeGetCapabilitiesResponse) *NodeCapabilitySet { + cs := &NodeCapabilitySet{} + pluginCapabilities := resp.GetCapabilities() + for _, pcap := range pluginCapabilities { + if c := pcap.GetRpc(); c != nil { + switch c.Type { + case csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME: + cs.HasStageUnstageVolume = true + default: + continue + } + } + } + + return cs +} diff --git a/plugins/csi/testing/client.go b/plugins/csi/testing/client.go index 95739de46..75f20929d 100644 --- a/plugins/csi/testing/client.go +++ b/plugins/csi/testing/client.go @@ -77,3 +77,29 @@ func (c *ControllerClient) ControllerUnpublishVolume(ctx context.Context, in *cs func (c *ControllerClient) ValidateVolumeCapabilities(ctx context.Context, in *csipbv1.ValidateVolumeCapabilitiesRequest, opts ...grpc.CallOption) (*csipbv1.ValidateVolumeCapabilitiesResponse, error) { panic("not implemented") // TODO: Implement } + +// NodeClient is a CSI Node client used for testing +type NodeClient struct { + NextErr error + NextCapabilitiesResponse *csipbv1.NodeGetCapabilitiesResponse + NextGetInfoResponse *csipbv1.NodeGetInfoResponse +} + +// NewNodeClient returns a new ControllerClient +func NewNodeClient() *NodeClient { + return &NodeClient{} +} + +func (f *NodeClient) Reset() { + f.NextErr = nil + f.NextCapabilitiesResponse = nil + f.NextGetInfoResponse = nil +} + +func (c *NodeClient) NodeGetCapabilities(ctx context.Context, in *csipbv1.NodeGetCapabilitiesRequest, opts ...grpc.CallOption) (*csipbv1.NodeGetCapabilitiesResponse, error) { + return c.NextCapabilitiesResponse, c.NextErr +} + +func (c *NodeClient) NodeGetInfo(ctx context.Context, in *csipbv1.NodeGetInfoRequest, opts ...grpc.CallOption) (*csipbv1.NodeGetInfoResponse, error) { + return c.NextGetInfoResponse, c.NextErr +}