diff --git a/plugins/csi/client.go b/plugins/csi/client.go index 83a2c7c30..dd650d96b 100644 --- a/plugins/csi/client.go +++ b/plugins/csi/client.go @@ -198,6 +198,22 @@ func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySe // Controller Endpoints // +func (c *client) ControllerGetCapabilities(ctx context.Context) (*ControllerCapabilitySet, error) { + if c == nil { + return nil, fmt.Errorf("Client not initialized") + } + if c.controllerClient == nil { + return nil, fmt.Errorf("controllerClient not initialized") + } + + resp, err := c.controllerClient.ControllerGetCapabilities(ctx, &csipbv1.ControllerGetCapabilitiesRequest{}) + if err != nil { + return nil, err + } + + return NewControllerCapabilitySet(resp), nil +} + func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPublishVolumeRequest) (*ControllerPublishVolumeResponse, error) { if c == nil { return nil, fmt.Errorf("Client not initialized") diff --git a/plugins/csi/client_test.go b/plugins/csi/client_test.go index 5832aeb5e..84973b308 100644 --- a/plugins/csi/client_test.go +++ b/plugins/csi/client_test.go @@ -191,6 +191,104 @@ func TestClient_RPC_PluginGetCapabilities(t *testing.T) { } } +func TestClient_RPC_ControllerGetCapabilities(t *testing.T) { + cases := []struct { + Name string + ResponseErr error + Response *csipbv1.ControllerGetCapabilitiesResponse + ExpectedResponse *ControllerCapabilitySet + ExpectedErr error + }{ + { + Name: "handles underlying grpc errors", + ResponseErr: fmt.Errorf("some grpc error"), + ExpectedErr: fmt.Errorf("some grpc error"), + }, + { + Name: "ignores unknown capabilities", + Response: &csipbv1.ControllerGetCapabilitiesResponse{ + Capabilities: []*csipbv1.ControllerServiceCapability{ + { + Type: &csipbv1.ControllerServiceCapability_Rpc{ + Rpc: &csipbv1.ControllerServiceCapability_RPC{ + Type: csipbv1.ControllerServiceCapability_RPC_GET_CAPACITY, + }, + }, + }, + }, + }, + ExpectedResponse: &ControllerCapabilitySet{}, + }, + { + Name: "detects list volumes capabilities", + Response: &csipbv1.ControllerGetCapabilitiesResponse{ + Capabilities: []*csipbv1.ControllerServiceCapability{ + { + Type: &csipbv1.ControllerServiceCapability_Rpc{ + Rpc: &csipbv1.ControllerServiceCapability_RPC{ + Type: csipbv1.ControllerServiceCapability_RPC_LIST_VOLUMES, + }, + }, + }, + { + Type: &csipbv1.ControllerServiceCapability_Rpc{ + Rpc: &csipbv1.ControllerServiceCapability_RPC{ + Type: csipbv1.ControllerServiceCapability_RPC_LIST_VOLUMES_PUBLISHED_NODES, + }, + }, + }, + }, + }, + ExpectedResponse: &ControllerCapabilitySet{ + HasListVolumes: true, + HasListVolumesPublishedNodes: true, + }, + }, + { + Name: "detects publish capabilities", + Response: &csipbv1.ControllerGetCapabilitiesResponse{ + Capabilities: []*csipbv1.ControllerServiceCapability{ + { + Type: &csipbv1.ControllerServiceCapability_Rpc{ + Rpc: &csipbv1.ControllerServiceCapability_RPC{ + Type: csipbv1.ControllerServiceCapability_RPC_PUBLISH_READONLY, + }, + }, + }, + { + Type: &csipbv1.ControllerServiceCapability_Rpc{ + Rpc: &csipbv1.ControllerServiceCapability_RPC{ + Type: csipbv1.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + }, + }, + }, + }, + }, + ExpectedResponse: &ControllerCapabilitySet{ + HasPublishUnpublishVolume: true, + HasPublishReadonly: true, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + _, cc, client := newTestClient() + defer client.Close() + + cc.NextErr = tc.ResponseErr + cc.NextCapabilitiesResponse = tc.Response + + resp, err := client.ControllerGetCapabilities(context.TODO()) + if tc.ExpectedErr != nil { + require.Error(t, tc.ExpectedErr, err) + } + + require.Equal(t, tc.ExpectedResponse, resp) + }) + } +} + func TestClient_RPC_ControllerPublishVolume(t *testing.T) { cases := []struct { Name string diff --git a/plugins/csi/fake/client.go b/plugins/csi/fake/client.go index f3d218ff8..b2372906d 100644 --- a/plugins/csi/fake/client.go +++ b/plugins/csi/fake/client.go @@ -35,6 +35,10 @@ type Client struct { NextPluginGetCapabilitiesErr error PluginGetCapabilitiesCallCount int64 + NextControllerGetCapabilitiesResponse *csi.ControllerCapabilitySet + NextControllerGetCapabilitiesErr error + ControllerGetCapabilitiesCallCount int64 + NextControllerPublishVolumeResponse *csi.ControllerPublishVolumeResponse NextControllerPublishVolumeErr error ControllerPublishVolumeCallCount int64 @@ -99,6 +103,15 @@ func (c *Client) PluginGetCapabilities(ctx context.Context) (*csi.PluginCapabili return c.NextPluginGetCapabilitiesResponse, c.NextPluginGetCapabilitiesErr } +func (c *Client) ControllerGetCapabilities(ctx context.Context) (*csi.ControllerCapabilitySet, error) { + c.Mu.Lock() + defer c.Mu.Unlock() + + c.ControllerGetCapabilitiesCallCount++ + + return c.NextControllerGetCapabilitiesResponse, c.NextControllerGetCapabilitiesErr +} + // ControllerPublishVolume is used to attach a remote volume to a node func (c *Client) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { c.Mu.Lock() diff --git a/plugins/csi/plugin.go b/plugins/csi/plugin.go index 012313670..56a813a28 100644 --- a/plugins/csi/plugin.go +++ b/plugins/csi/plugin.go @@ -26,6 +26,10 @@ type CSIPlugin interface { // Accessible Topology Support PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySet, error) + // GetControllerCapabilities is used to get controller-specific capabilities + // for a plugin. + ControllerGetCapabilities(ctx context.Context) (*ControllerCapabilitySet, error) + // ControllerPublishVolume is used to attach a remote volume to a cluster node. ControllerPublishVolume(ctx context.Context, req *ControllerPublishVolumeRequest) (*ControllerPublishVolumeResponse, error) @@ -87,6 +91,37 @@ func NewPluginCapabilitySet(capabilities *csipbv1.GetPluginCapabilitiesResponse) return cs } +type ControllerCapabilitySet struct { + HasPublishUnpublishVolume bool + HasPublishReadonly bool + HasListVolumes bool + HasListVolumesPublishedNodes bool +} + +func NewControllerCapabilitySet(resp *csipbv1.ControllerGetCapabilitiesResponse) *ControllerCapabilitySet { + cs := &ControllerCapabilitySet{} + + pluginCapabilities := resp.GetCapabilities() + for _, pcap := range pluginCapabilities { + if c := pcap.GetRpc(); c != nil { + switch c.Type { + case csipbv1.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME: + cs.HasPublishUnpublishVolume = true + case csipbv1.ControllerServiceCapability_RPC_PUBLISH_READONLY: + cs.HasPublishReadonly = true + case csipbv1.ControllerServiceCapability_RPC_LIST_VOLUMES: + cs.HasListVolumes = true + case csipbv1.ControllerServiceCapability_RPC_LIST_VOLUMES_PUBLISHED_NODES: + cs.HasListVolumesPublishedNodes = true + default: + continue + } + } + } + + return cs +} + type ControllerPublishVolumeRequest struct { VolumeID string NodeID string