diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index ff4c050ba..6a0aa995b 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -201,6 +201,39 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol return v.srv.blockingRPC(&opts) } +func (srv *Server) controllerValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) error { + state := srv.fsm.State() + ws := memdb.NewWatchSet() + + plugin, err := state.CSIPluginByID(ws, vol.PluginID) + if err != nil { + return err + } + if plugin == nil { + return fmt.Errorf("no CSI plugin named: %s could be found", vol.PluginID) + } + + if !plugin.ControllerRequired { + // The plugin does not require a controller, so for now we won't do any + // further validation of the volume. + return nil + } + + // The plugin requires a controller. Now we do some validation of the Volume + // to ensure that the registered capabilities are valid and that the volume + // exists. + method := "ClientCSI.CSIControllerValidateVolume" + cReq := &cstructs.ClientCSIControllerValidateVolumeRequest{ + PluginID: plugin.ID, + VolumeID: vol.ID, + AttachmentMode: vol.AttachmentMode, + AccessMode: vol.AccessMode, + } + cResp := &cstructs.ClientCSIControllerValidateVolumeResponse{} + + return srv.csiControllerRPC(plugin, method, cReq, cResp) +} + // Register registers a new volume func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *structs.CSIVolumeRegisterResponse) error { if done, err := v.srv.forward("CSIVolume.Register", args, args, reply); done { @@ -220,12 +253,18 @@ func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *stru return structs.ErrPermissionDenied } - // This is the only namespace we ACL checked, force all the volumes to use it + // This is the only namespace we ACL checked, force all the volumes to use it. + // We also validate that the plugin exists for each plugin, and validate the + // capabilities when the plugin has a controller. for _, vol := range args.Volumes { vol.Namespace = args.RequestNamespace() if err = vol.Validate(); err != nil { return err } + + if err := v.srv.controllerValidateVolume(args, vol); err != nil { + return err + } } resp, index, err := v.srv.raftApply(structs.CSIVolumeRegisterRequestType, args) diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 1931d39c4..6d784ad77 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -116,15 +116,22 @@ func TestCSIVolumeEndpoint_Register(t *testing.T) { ns := structs.DefaultNamespace state := srv.fsm.State() - state.BootstrapACLTokens(1, 0, mock.ACLManagementToken()) - srv.config.ACLEnabled = true - policy := mock.NamespacePolicy(ns, "", []string{acl.NamespaceCapabilityCSICreateVolume}) - validToken := mock.CreatePolicyAndToken(t, state, 1001, acl.NamespaceCapabilityCSICreateVolume, policy) - codec := rpcClient(t, srv) id0 := uuid.Generate() + // Create the node and plugin + node := mock.Node() + node.CSINodePlugins = map[string]*structs.CSIInfo{ + "minnie": {PluginID: "minnie", + Healthy: true, + // Registers as node plugin that does not require a controller to skip + // the client RPC during registration. + NodeInfo: &structs.CSINodeInfo{}, + }, + } + require.NoError(t, state.UpsertNode(1000, node)) + // Create the volume vols := []*structs.CSIVolume{{ ID: id0, @@ -132,9 +139,6 @@ func TestCSIVolumeEndpoint_Register(t *testing.T) { PluginID: "minnie", AccessMode: structs.CSIVolumeAccessModeMultiNodeReader, AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, - Topologies: []*structs.CSITopology{{ - Segments: map[string]string{"foo": "bar"}, - }}, }} // Create the register request @@ -143,7 +147,6 @@ func TestCSIVolumeEndpoint_Register(t *testing.T) { WriteRequest: structs.WriteRequest{ Region: "global", Namespace: ns, - AuthToken: validToken.SecretID, }, } resp1 := &structs.CSIVolumeRegisterResponse{} @@ -152,14 +155,10 @@ func TestCSIVolumeEndpoint_Register(t *testing.T) { require.NotEqual(t, uint64(0), resp1.Index) // Get the volume back out - policy = mock.NamespacePolicy(ns, "", []string{acl.NamespaceCapabilityCSIAccess}) - getToken := mock.CreatePolicyAndToken(t, state, 1001, "csi-access", policy) - req2 := &structs.CSIVolumeGetRequest{ ID: id0, QueryOptions: structs.QueryOptions{ - Region: "global", - AuthToken: getToken.SecretID, + Region: "global", }, } resp2 := &structs.CSIVolumeGetResponse{} @@ -179,7 +178,6 @@ func TestCSIVolumeEndpoint_Register(t *testing.T) { WriteRequest: structs.WriteRequest{ Region: "global", Namespace: ns, - AuthToken: validToken.SecretID, }, } resp3 := &structs.CSIVolumeDeregisterResponse{}