diff --git a/nomad/client_csi_endpoint.go b/nomad/client_csi_endpoint.go index 4784c0c20..6c01ac91c 100644 --- a/nomad/client_csi_endpoint.go +++ b/nomad/client_csi_endpoint.go @@ -1,7 +1,6 @@ package nomad import ( - "errors" "fmt" "math/rand" "time" @@ -10,7 +9,6 @@ import ( log "github.com/hashicorp/go-hclog" memdb "github.com/hashicorp/go-memdb" cstructs "github.com/hashicorp/nomad/client/structs" - "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" ) @@ -23,22 +21,12 @@ type ClientCSI struct { func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAttachVolumeRequest, reply *cstructs.ClientCSIControllerAttachVolumeResponse) error { defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "attach_volume"}, time.Now()) - - // Verify the arguments. - if args.ControllerNodeID == "" { - return errors.New("missing ControllerNodeID") - } - - // Make sure Node is valid and new enough to support RPC - snap, err := a.srv.State().Snapshot() - if err != nil { - return err - } - - _, err = getNodeForRpc(snap, args.ControllerNodeID) + // Get a Nomad client node for the controller + nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID) if err != nil { return err } + args.ControllerNodeID = nodeID // Get the connection to the client state, ok := a.srv.getNodeConn(args.ControllerNodeID) @@ -57,21 +45,12 @@ func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAtt func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerValidateVolumeRequest, reply *cstructs.ClientCSIControllerValidateVolumeResponse) error { defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "validate_volume"}, time.Now()) - // Verify the arguments. - if args.ControllerNodeID == "" { - return errors.New("missing ControllerNodeID") - } - - // Make sure Node is valid and new enough to support RPC - snap, err := a.srv.State().Snapshot() - if err != nil { - return err - } - - _, err = getNodeForRpc(snap, args.ControllerNodeID) + // Get a Nomad client node for the controller + nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID) if err != nil { return err } + args.ControllerNodeID = nodeID // Get the connection to the client state, ok := a.srv.getNodeConn(args.ControllerNodeID) @@ -90,21 +69,12 @@ func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerV func (a *ClientCSI) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error { defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "detach_volume"}, time.Now()) - // Verify the arguments. - if args.ControllerNodeID == "" { - return errors.New("missing ControllerNodeID") - } - - // Make sure Node is valid and new enough to support RPC - snap, err := a.srv.State().Snapshot() - if err != nil { - return err - } - - _, err = getNodeForRpc(snap, args.ControllerNodeID) + // Get a Nomad client node for the controller + nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID) if err != nil { return err } + args.ControllerNodeID = nodeID // Get the connection to the client state, ok := a.srv.getNodeConn(args.ControllerNodeID) @@ -178,17 +148,43 @@ func (srv *Server) volAndPluginLookup(namespace, volID string) (*structs.CSIPlug return plug, vol, nil } -// nodeForControllerPlugin returns the node ID for a random controller -// to load-balance long-blocking RPCs across client nodes. -func nodeForControllerPlugin(state *state.StateStore, plugin *structs.CSIPlugin) (string, error) { +// nodeForController validates that the Nomad client node ID for +// a plugin exists and is new enough to support client RPC. If no node +// ID is passed, select a random node ID for the controller to load-balance +// long blocking RPCs across client nodes. +func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) { + + snap, err := a.srv.State().Snapshot() + if err != nil { + return "", err + } + + if nodeID != "" { + _, err = getNodeForRpc(snap, nodeID) + if err == nil { + return nodeID, nil + } + } + + if pluginID == "" { + return "", fmt.Errorf("missing plugin ID") + } + ws := memdb.NewWatchSet() + + // note: plugin IDs are not scoped to region/DC but volumes are. + // so any node we get for a controller is already in the same + // region/DC for the volume. + plugin, err := snap.CSIPluginByID(ws, pluginID) + if err != nil { + return "", fmt.Errorf("error getting plugin: %s, %v", pluginID, err) + } + if plugin == nil { + return "", fmt.Errorf("plugin missing: %s %v", pluginID, err) + } count := len(plugin.Controllers) if count == 0 { return "", fmt.Errorf("no controllers available for plugin %q", plugin.ID) } - snap, err := state.Snapshot() - if err != nil { - return "", err - } // iterating maps is "random" but unspecified and isn't particularly // random with small maps, so not well-suited for load balancing. diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index b9f84d6ad..7aecc47b0 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -3,10 +3,13 @@ package nomad import ( "testing" + memdb "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/client" "github.com/hashicorp/nomad/client/config" cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/require" @@ -167,3 +170,45 @@ func TestClientCSIController_DetachVolume_Forwarded(t *testing.T) { // Should recieve an error from the client endpoint require.Contains(err.Error(), "must specify plugin name to dispense") } + +func TestClientCSI_NodeForControllerPlugin(t *testing.T) { + t.Parallel() + srv, shutdown := TestServer(t, func(c *Config) {}) + testutil.WaitForLeader(t, srv.RPC) + defer shutdown() + + plugins := map[string]*structs.CSIInfo{ + "minnie": {PluginID: "minnie", + Healthy: true, + ControllerInfo: &structs.CSIControllerInfo{}, + NodeInfo: &structs.CSINodeInfo{}, + RequiresControllerPlugin: true, + }, + } + state := srv.fsm.State() + + node1 := mock.Node() + node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions + node1.CSIControllerPlugins = plugins + node2 := mock.Node() + node2.CSIControllerPlugins = plugins + node2.ID = uuid.Generate() + node3 := mock.Node() + node3.ID = uuid.Generate() + + err := state.UpsertNode(1002, node1) + require.NoError(t, err) + err = state.UpsertNode(1003, node2) + require.NoError(t, err) + err = state.UpsertNode(1004, node3) + require.NoError(t, err) + + ws := memdb.NewWatchSet() + + plugin, err := state.CSIPluginByID(ws, "minnie") + require.NoError(t, err) + nodeID, err := srv.staticEndpoints.ClientCSI.nodeForController(plugin.ID, "") + + // only node1 has both the controller and a recent Nomad version + require.Equal(t, nodeID, node1.ID) +} diff --git a/nomad/core_sched.go b/nomad/core_sched.go index 934bab4c6..708c8226f 100644 --- a/nomad/core_sched.go +++ b/nomad/core_sched.go @@ -867,16 +867,11 @@ func volumeClaimReapImpl(srv RPCServer, args *volumeClaimReapArgs) (map[string]i return args.nodeClaims, fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID) } - controllerNodeID, err := nodeForControllerPlugin(srv.State(), args.plug) - if err != nil || controllerNodeID == "" { - return args.nodeClaims, err - } cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{ VolumeID: vol.RemoteID(), ClientCSINodeID: targetCSIInfo.NodeInfo.ID, } cReq.PluginID = args.plug.ID - cReq.ControllerNodeID = controllerNodeID err = srv.RPC("ClientCSI.ControllerDetachVolume", cReq, &cstructs.ClientCSIControllerDetachVolumeResponse{}) if err != nil { diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index f33125ed1..ca6d67feb 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -207,8 +207,8 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol return v.srv.blockingRPC(&opts) } -func (srv *Server) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) { - state := srv.fsm.State() +func (v *CSIVolume) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) { + state := v.srv.fsm.State() ws := memdb.NewWatchSet() plugin, err := state.CSIPluginByID(ws, vol.PluginID) @@ -224,7 +224,7 @@ func (srv *Server) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, v return plugin, nil } -func (srv *Server) controllerValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume, plugin *structs.CSIPlugin) error { +func (v *CSIVolume) controllerValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume, plugin *structs.CSIPlugin) error { if !plugin.ControllerRequired { // The plugin does not require a controller, so for now we won't do any @@ -232,18 +232,6 @@ func (srv *Server) controllerValidateVolume(req *structs.CSIVolumeRegisterReques return nil } - // The plugin requires a controller. Now we do some validation of the Volume - // to ensure that the registered capabilities are valid and that the volume - // exists. - - // plugin IDs are not scoped to region/DC but volumes are. - // so any node we get for a controller is already in the same region/DC - // for the volume. - nodeID, err := nodeForControllerPlugin(srv.fsm.State(), plugin) - if err != nil || nodeID == "" { - return err - } - method := "ClientCSI.ControllerValidateVolume" cReq := &cstructs.ClientCSIControllerValidateVolumeRequest{ VolumeID: vol.RemoteID(), @@ -251,10 +239,9 @@ func (srv *Server) controllerValidateVolume(req *structs.CSIVolumeRegisterReques AccessMode: vol.AccessMode, } cReq.PluginID = plugin.ID - cReq.ControllerNodeID = nodeID cResp := &cstructs.ClientCSIControllerValidateVolumeResponse{} - return srv.RPC(method, cReq, cResp) + return v.srv.RPC(method, cReq, cResp) } // Register registers a new volume @@ -285,11 +272,11 @@ func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *stru return err } - plugin, err := v.srv.pluginValidateVolume(args, vol) + plugin, err := v.pluginValidateVolume(args, vol) if err != nil { return err } - if err := v.srv.controllerValidateVolume(args, vol, plugin); err != nil { + if err := v.controllerValidateVolume(args, vol, plugin); err != nil { return err } } @@ -364,7 +351,7 @@ func (v *CSIVolume) Claim(args *structs.CSIVolumeClaimRequest, reply *structs.CS // if this is a new claim, add a Volume and PublishContext from the // controller (if any) to the reply if args.Claim != structs.CSIVolumeClaimRelease { - err = v.srv.controllerPublishVolume(args, reply) + err = v.controllerPublishVolume(args, reply) if err != nil { return fmt.Errorf("controller publish: %v", err) } @@ -384,6 +371,66 @@ func (v *CSIVolume) Claim(args *structs.CSIVolumeClaimRequest, reply *structs.CS return nil } +// controllerPublishVolume sends publish request to the CSI controller +// plugin associated with a volume, if any. +func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, resp *structs.CSIVolumeClaimResponse) error { + plug, vol, err := v.srv.volAndPluginLookup(req.RequestNamespace(), req.VolumeID) + if err != nil { + return err + } + + // Set the Response volume from the lookup + resp.Volume = vol + + // Validate the existence of the allocation, regardless of whether we need it + // now. + state := v.srv.fsm.State() + ws := memdb.NewWatchSet() + alloc, err := state.AllocByID(ws, req.AllocationID) + if err != nil { + return err + } + if alloc == nil { + return fmt.Errorf("%s: %s", structs.ErrUnknownAllocationPrefix, req.AllocationID) + } + + // if no plugin was returned then controller validation is not required. + // Here we can return nil. + if plug == nil { + return nil + } + + targetNode, err := state.NodeByID(ws, alloc.NodeID) + if err != nil { + return err + } + if targetNode == nil { + return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, alloc.NodeID) + } + targetCSIInfo, ok := targetNode.CSINodePlugins[plug.ID] + if !ok { + return fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID) + } + + method := "ClientCSI.ControllerAttachVolume" + cReq := &cstructs.ClientCSIControllerAttachVolumeRequest{ + VolumeID: vol.RemoteID(), + ClientCSINodeID: targetCSIInfo.NodeInfo.ID, + AttachmentMode: vol.AttachmentMode, + AccessMode: vol.AccessMode, + ReadOnly: req.Claim == structs.CSIVolumeClaimRead, + } + cReq.PluginID = plug.ID + cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{} + + err = v.srv.RPC(method, cReq, cResp) + if err != nil { + return fmt.Errorf("attach volume: %v", err) + } + resp.PublishContext = cResp.PublishContext + return nil +} + // allowCSIMount is called on Job register to check mount permission func allowCSIMount(aclObj *acl.ACL, namespace string) bool { return aclObj.AllowPluginRead() && @@ -498,72 +545,3 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu }} return v.srv.blockingRPC(&opts) } - -// controllerPublishVolume sends publish request to the CSI controller -// plugin associated with a volume, if any. -func (srv *Server) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, resp *structs.CSIVolumeClaimResponse) error { - plug, vol, err := srv.volAndPluginLookup(req.RequestNamespace(), req.VolumeID) - if err != nil { - return err - } - - // Set the Response volume from the lookup - resp.Volume = vol - - // Validate the existence of the allocation, regardless of whether we need it - // now. - state := srv.fsm.State() - ws := memdb.NewWatchSet() - alloc, err := state.AllocByID(ws, req.AllocationID) - if err != nil { - return err - } - if alloc == nil { - return fmt.Errorf("%s: %s", structs.ErrUnknownAllocationPrefix, req.AllocationID) - } - - // if no plugin was returned then controller validation is not required. - // Here we can return nil. - if plug == nil { - return nil - } - - // plugin IDs are not scoped to region/DC but volumes are. - // so any node we get for a controller is already in the same region/DC - // for the volume. - nodeID, err := nodeForControllerPlugin(state, plug) - if err != nil || nodeID == "" { - return err - } - - targetNode, err := state.NodeByID(ws, alloc.NodeID) - if err != nil { - return err - } - if targetNode == nil { - return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, alloc.NodeID) - } - targetCSIInfo, ok := targetNode.CSINodePlugins[plug.ID] - if !ok { - return fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID) - } - - method := "ClientCSI.ControllerAttachVolume" - cReq := &cstructs.ClientCSIControllerAttachVolumeRequest{ - VolumeID: vol.RemoteID(), - ClientCSINodeID: targetCSIInfo.NodeInfo.ID, - AttachmentMode: vol.AttachmentMode, - AccessMode: vol.AccessMode, - ReadOnly: req.Claim == structs.CSIVolumeClaimRead, - } - cReq.PluginID = plug.ID - cReq.ControllerNodeID = nodeID - cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{} - - err = srv.RPC(method, cReq, cResp) - if err != nil { - return fmt.Errorf("attach volume: %v", err) - } - resp.PublishContext = cResp.PublishContext - return nil -} diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 12f56b02b..2ff04a6ec 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -4,7 +4,6 @@ import ( "fmt" "testing" - memdb "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/acl" "github.com/hashicorp/nomad/helper/uuid" @@ -601,45 +600,3 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) { require.Nil(t, vol) require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2)) } - -func TestCSI_NodeForControllerPlugin(t *testing.T) { - t.Parallel() - srv, shutdown := TestServer(t, func(c *Config) {}) - testutil.WaitForLeader(t, srv.RPC) - defer shutdown() - - plugins := map[string]*structs.CSIInfo{ - "minnie": {PluginID: "minnie", - Healthy: true, - ControllerInfo: &structs.CSIControllerInfo{}, - NodeInfo: &structs.CSINodeInfo{}, - RequiresControllerPlugin: true, - }, - } - state := srv.fsm.State() - - node1 := mock.Node() - node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions - node1.CSIControllerPlugins = plugins - node2 := mock.Node() - node2.CSIControllerPlugins = plugins - node2.ID = uuid.Generate() - node3 := mock.Node() - node3.ID = uuid.Generate() - - err := state.UpsertNode(1002, node1) - require.NoError(t, err) - err = state.UpsertNode(1003, node2) - require.NoError(t, err) - err = state.UpsertNode(1004, node3) - require.NoError(t, err) - - ws := memdb.NewWatchSet() - - plugin, err := state.CSIPluginByID(ws, "minnie") - require.NoError(t, err) - nodeID, err := nodeForControllerPlugin(state, plugin) - - // only node1 has both the controller and a recent Nomad version - require.Equal(t, nodeID, node1.ID) -}