refactor: make nodeForControllerPlugin private to ClientCSI (#7688)

The current design of `ClientCSI` RPC requires that callers in the
server know about the free-standing `nodeForControllerPlugin`
function. This makes it difficult to send `ClientCSI` RPC messages
from subpackages of `nomad` and adds a bunch of boilerplate to every
server-side caller of a controller RPC.

This changeset makes it so that the `ClientCSI` RPCs will populate and
validate the controller's client node ID if it hasn't been passed by
the caller, centralizing the logic of picking and validating
controller targets into the `nomad.ClientCSI` struct.
Tim Gross, 2020-04-10 16:47:21 -04:00 (committed by GitHub)
parent 47dfa762b3, commit 09abe0c702
5 changed files with 154 additions and 183 deletions
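
For a sense of what this buys server-side callers, here is a minimal sketch of a controller RPC call after this change, modeled on the volumeClaimReapImpl hunk below. The function name detachExample and its parameter list are illustrative only; the request type, fields, and RPC name come from the diff itself.

// detachExample is a hypothetical caller of a ClientCSI controller RPC.
// It sets only PluginID and leaves ControllerNodeID empty; the ClientCSI
// endpoint now picks and validates a controller node on its own.
func detachExample(srv RPCServer, vol *structs.CSIVolume,
	targetCSIInfo *structs.CSIInfo, plug *structs.CSIPlugin) error {

	cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{
		VolumeID:        vol.RemoteID(),
		ClientCSINodeID: targetCSIInfo.NodeInfo.ID,
	}
	cReq.PluginID = plug.ID
	// note: no ControllerNodeID lookup boilerplate here any more
	return srv.RPC("ClientCSI.ControllerDetachVolume", cReq,
		&cstructs.ClientCSIControllerDetachVolumeResponse{})
}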

---------- changed file 1 of 5 ----------

@@ -1,7 +1,6 @@
 package nomad

 import (
-	"errors"
 	"fmt"
 	"math/rand"
 	"time"
@@ -10,7 +9,6 @@ import (
 	log "github.com/hashicorp/go-hclog"
 	memdb "github.com/hashicorp/go-memdb"

 	cstructs "github.com/hashicorp/nomad/client/structs"
-	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
 )
@@ -23,22 +21,12 @@ type ClientCSI struct {
 func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAttachVolumeRequest, reply *cstructs.ClientCSIControllerAttachVolumeResponse) error {
 	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "attach_volume"}, time.Now())
-	// Verify the arguments.
-	if args.ControllerNodeID == "" {
-		return errors.New("missing ControllerNodeID")
-	}
-
-	// Make sure Node is valid and new enough to support RPC
-	snap, err := a.srv.State().Snapshot()
-	if err != nil {
-		return err
-	}
-
-	_, err = getNodeForRpc(snap, args.ControllerNodeID)
+	// Get a Nomad client node for the controller
+	nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
 	if err != nil {
 		return err
 	}
+	args.ControllerNodeID = nodeID

 	// Get the connection to the client
 	state, ok := a.srv.getNodeConn(args.ControllerNodeID)
@@ -57,21 +45,12 @@ func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAtt
 func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerValidateVolumeRequest, reply *cstructs.ClientCSIControllerValidateVolumeResponse) error {
 	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "validate_volume"}, time.Now())
-	// Verify the arguments.
-	if args.ControllerNodeID == "" {
-		return errors.New("missing ControllerNodeID")
-	}
-
-	// Make sure Node is valid and new enough to support RPC
-	snap, err := a.srv.State().Snapshot()
-	if err != nil {
-		return err
-	}
-
-	_, err = getNodeForRpc(snap, args.ControllerNodeID)
+	// Get a Nomad client node for the controller
+	nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
 	if err != nil {
 		return err
 	}
+	args.ControllerNodeID = nodeID

 	// Get the connection to the client
 	state, ok := a.srv.getNodeConn(args.ControllerNodeID)
@@ -90,21 +69,12 @@ func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerV
 func (a *ClientCSI) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error {
 	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "detach_volume"}, time.Now())
-	// Verify the arguments.
-	if args.ControllerNodeID == "" {
-		return errors.New("missing ControllerNodeID")
-	}
-
-	// Make sure Node is valid and new enough to support RPC
-	snap, err := a.srv.State().Snapshot()
-	if err != nil {
-		return err
-	}
-
-	_, err = getNodeForRpc(snap, args.ControllerNodeID)
+	// Get a Nomad client node for the controller
+	nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
 	if err != nil {
 		return err
 	}
+	args.ControllerNodeID = nodeID

 	// Get the connection to the client
 	state, ok := a.srv.getNodeConn(args.ControllerNodeID)
@@ -178,17 +148,43 @@ func (srv *Server) volAndPluginLookup(namespace, volID string) (*structs.CSIPlug
 	return plug, vol, nil
 }

-// nodeForControllerPlugin returns the node ID for a random controller
-// to load-balance long-blocking RPCs across client nodes.
-func nodeForControllerPlugin(state *state.StateStore, plugin *structs.CSIPlugin) (string, error) {
+// nodeForController validates that the Nomad client node ID for
+// a plugin exists and is new enough to support client RPC. If no node
+// ID is passed, select a random node ID for the controller to load-balance
+// long blocking RPCs across client nodes.
+func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {
+
+	snap, err := a.srv.State().Snapshot()
+	if err != nil {
+		return "", err
+	}
+
+	if nodeID != "" {
+		_, err = getNodeForRpc(snap, nodeID)
+		if err == nil {
+			return nodeID, nil
+		}
+	}
+
+	if pluginID == "" {
+		return "", fmt.Errorf("missing plugin ID")
+	}
+
+	ws := memdb.NewWatchSet()
+	// note: plugin IDs are not scoped to region/DC but volumes are.
+	// so any node we get for a controller is already in the same
+	// region/DC for the volume.
+	plugin, err := snap.CSIPluginByID(ws, pluginID)
+	if err != nil {
+		return "", fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
+	}
+	if plugin == nil {
+		return "", fmt.Errorf("plugin missing: %s %v", pluginID, err)
+	}
+
 	count := len(plugin.Controllers)
 	if count == 0 {
 		return "", fmt.Errorf("no controllers available for plugin %q", plugin.ID)
 	}
-
-	snap, err := state.Snapshot()
-	if err != nil {
-		return "", err
-	}

 	// iterating maps is "random" but unspecified and isn't particularly
 	// random with small maps, so not well-suited for load balancing.
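
The hunk above is truncated just after that load-balancing comment, so the selection loop itself is not visible in this capture. As a sketch of the approach the comment describes (the helper name pickController and the usable callback are illustrative and not part of this commit; usable stands in for the getNodeForRpc version check), the idea is to shuffle the plugin's controller client IDs and return the first healthy, RPC-capable one:

// pickController is a hypothetical standalone rendering of the selection
// step: shuffle the candidate client IDs so repeated calls spread load,
// then return the first controller that is healthy and usable.
func pickController(plugin *structs.CSIPlugin, usable func(clientID string) bool) (string, error) {
	clients := make([]string, 0, len(plugin.Controllers))
	for clientID := range plugin.Controllers {
		clients = append(clients, clientID)
	}
	// a plain range over the map would be "random" but poorly
	// distributed for small maps, hence the explicit shuffle
	rand.Shuffle(len(clients), func(i, j int) {
		clients[i], clients[j] = clients[j], clients[i]
	})
	for _, clientID := range clients {
		controller := plugin.Controllers[clientID]
		// skip plugin instances that aren't running as controllers,
		// aren't healthy, or whose node can't serve client RPCs
		if controller.ControllerInfo == nil || !controller.Healthy || !usable(clientID) {
			continue
		}
		return clientID, nil
	}
	return "", fmt.Errorf("no healthy controllers available for plugin %q", plugin.ID)
}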

---------- changed file 2 of 5 ----------

@@ -3,10 +3,13 @@ package nomad
 import (
 	"testing"

 	memdb "github.com/hashicorp/go-memdb"
 	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
 	"github.com/hashicorp/nomad/client"
 	"github.com/hashicorp/nomad/client/config"
 	cstructs "github.com/hashicorp/nomad/client/structs"
+	"github.com/hashicorp/nomad/helper/uuid"
+	"github.com/hashicorp/nomad/nomad/mock"
+	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/stretchr/testify/require"
@@ -167,3 +170,45 @@ func TestClientCSIController_DetachVolume_Forwarded(t *testing.T) {
 	// Should receive an error from the client endpoint
 	require.Contains(err.Error(), "must specify plugin name to dispense")
 }
+
+func TestClientCSI_NodeForControllerPlugin(t *testing.T) {
+	t.Parallel()
+	srv, shutdown := TestServer(t, func(c *Config) {})
+	testutil.WaitForLeader(t, srv.RPC)
+	defer shutdown()
+
+	plugins := map[string]*structs.CSIInfo{
+		"minnie": {PluginID: "minnie",
+			Healthy:                  true,
+			ControllerInfo:           &structs.CSIControllerInfo{},
+			NodeInfo:                 &structs.CSINodeInfo{},
+			RequiresControllerPlugin: true,
+		},
+	}
+	state := srv.fsm.State()
+
+	node1 := mock.Node()
+	node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions
+	node1.CSIControllerPlugins = plugins
+
+	node2 := mock.Node()
+	node2.CSIControllerPlugins = plugins
+	node2.ID = uuid.Generate()
+
+	node3 := mock.Node()
+	node3.ID = uuid.Generate()
+
+	err := state.UpsertNode(1002, node1)
+	require.NoError(t, err)
+	err = state.UpsertNode(1003, node2)
+	require.NoError(t, err)
+	err = state.UpsertNode(1004, node3)
+	require.NoError(t, err)
+
+	ws := memdb.NewWatchSet()
+	plugin, err := state.CSIPluginByID(ws, "minnie")
+	require.NoError(t, err)
+
+	nodeID, err := srv.staticEndpoints.ClientCSI.nodeForController(plugin.ID, "")
+	require.NoError(t, err)
+
+	// only node1 has both the controller and a recent Nomad version
+	require.Equal(t, nodeID, node1.ID)
+}

---------- changed file 3 of 5 ----------

@@ -867,16 +867,11 @@ func volumeClaimReapImpl(srv RPCServer, args *volumeClaimReapArgs) (map[string]i
 		return args.nodeClaims, fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID)
 	}

-	controllerNodeID, err := nodeForControllerPlugin(srv.State(), args.plug)
-	if err != nil || controllerNodeID == "" {
-		return args.nodeClaims, err
-	}
-
 	cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{
 		VolumeID:        vol.RemoteID(),
 		ClientCSINodeID: targetCSIInfo.NodeInfo.ID,
 	}
 	cReq.PluginID = args.plug.ID
-	cReq.ControllerNodeID = controllerNodeID

 	err = srv.RPC("ClientCSI.ControllerDetachVolume", cReq,
 		&cstructs.ClientCSIControllerDetachVolumeResponse{})
 	if err != nil {

---------- changed file 4 of 5 ----------

@@ -207,8 +207,8 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol
 	return v.srv.blockingRPC(&opts)
 }

-func (srv *Server) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) {
-	state := srv.fsm.State()
+func (v *CSIVolume) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) {
+	state := v.srv.fsm.State()

 	ws := memdb.NewWatchSet()
 	plugin, err := state.CSIPluginByID(ws, vol.PluginID)
@@ -224,7 +224,7 @@ func (srv *Server) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, v
 	return plugin, nil
 }

-func (srv *Server) controllerValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume, plugin *structs.CSIPlugin) error {
+func (v *CSIVolume) controllerValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume, plugin *structs.CSIPlugin) error {

 	if !plugin.ControllerRequired {
 		// The plugin does not require a controller, so for now we won't do any
@@ -232,18 +232,6 @@ func (srv *Server) controllerValidateVolume(req *structs.CSIVolumeRegisterReques
 		return nil
 	}

-	// The plugin requires a controller. Now we do some validation of the Volume
-	// to ensure that the registered capabilities are valid and that the volume
-	// exists.
-
-	// plugin IDs are not scoped to region/DC but volumes are.
-	// so any node we get for a controller is already in the same region/DC
-	// for the volume.
-	nodeID, err := nodeForControllerPlugin(srv.fsm.State(), plugin)
-	if err != nil || nodeID == "" {
-		return err
-	}
-
 	method := "ClientCSI.ControllerValidateVolume"
 	cReq := &cstructs.ClientCSIControllerValidateVolumeRequest{
 		VolumeID: vol.RemoteID(),
@@ -251,10 +239,9 @@ func (srv *Server) controllerValidateVolume(req *structs.CSIVolumeRegisterReques
 		AccessMode:     vol.AccessMode,
 	}
 	cReq.PluginID = plugin.ID
-	cReq.ControllerNodeID = nodeID
 	cResp := &cstructs.ClientCSIControllerValidateVolumeResponse{}

-	return srv.RPC(method, cReq, cResp)
+	return v.srv.RPC(method, cReq, cResp)
 }

 // Register registers a new volume
@@ -285,11 +272,11 @@ func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *stru
 		return err
 	}

-	plugin, err := v.srv.pluginValidateVolume(args, vol)
+	plugin, err := v.pluginValidateVolume(args, vol)
 	if err != nil {
 		return err
 	}
-	if err := v.srv.controllerValidateVolume(args, vol, plugin); err != nil {
+	if err := v.controllerValidateVolume(args, vol, plugin); err != nil {
 		return err
 	}
 }
@@ -364,7 +351,7 @@ func (v *CSIVolume) Claim(args *structs.CSIVolumeClaimRequest, reply *structs.CS
 	// if this is a new claim, add a Volume and PublishContext from the
 	// controller (if any) to the reply
 	if args.Claim != structs.CSIVolumeClaimRelease {
-		err = v.srv.controllerPublishVolume(args, reply)
+		err = v.controllerPublishVolume(args, reply)
 		if err != nil {
 			return fmt.Errorf("controller publish: %v", err)
 		}
@@ -384,6 +371,66 @@ func (v *CSIVolume) Claim(args *structs.CSIVolumeClaimRequest, reply *structs.CS
 	return nil
 }

+// controllerPublishVolume sends publish request to the CSI controller
+// plugin associated with a volume, if any.
+func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, resp *structs.CSIVolumeClaimResponse) error {
+	plug, vol, err := v.srv.volAndPluginLookup(req.RequestNamespace(), req.VolumeID)
+	if err != nil {
+		return err
+	}
+
+	// Set the Response volume from the lookup
+	resp.Volume = vol
+
+	// Validate the existence of the allocation, regardless of whether we need it
+	// now.
+	state := v.srv.fsm.State()
+	ws := memdb.NewWatchSet()
+	alloc, err := state.AllocByID(ws, req.AllocationID)
+	if err != nil {
+		return err
+	}
+	if alloc == nil {
+		return fmt.Errorf("%s: %s", structs.ErrUnknownAllocationPrefix, req.AllocationID)
+	}
+
+	// if no plugin was returned then controller validation is not required.
+	// Here we can return nil.
+	if plug == nil {
+		return nil
+	}
+
+	targetNode, err := state.NodeByID(ws, alloc.NodeID)
+	if err != nil {
+		return err
+	}
+	if targetNode == nil {
+		return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, alloc.NodeID)
+	}
+	targetCSIInfo, ok := targetNode.CSINodePlugins[plug.ID]
+	if !ok {
+		return fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID)
+	}
+
+	method := "ClientCSI.ControllerAttachVolume"
+	cReq := &cstructs.ClientCSIControllerAttachVolumeRequest{
+		VolumeID:        vol.RemoteID(),
+		ClientCSINodeID: targetCSIInfo.NodeInfo.ID,
+		AttachmentMode:  vol.AttachmentMode,
+		AccessMode:      vol.AccessMode,
+		ReadOnly:        req.Claim == structs.CSIVolumeClaimRead,
+	}
+	cReq.PluginID = plug.ID
+	cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}
+
+	err = v.srv.RPC(method, cReq, cResp)
+	if err != nil {
+		return fmt.Errorf("attach volume: %v", err)
+	}
+	resp.PublishContext = cResp.PublishContext
+	return nil
+}
+
 // allowCSIMount is called on Job register to check mount permission
 func allowCSIMount(aclObj *acl.ACL, namespace string) bool {
 	return aclObj.AllowPluginRead() &&
@@ -498,72 +545,3 @@ func (v *CSIPlugin) Get(args *structs.CSIPluginGetRequest, reply *structs.CSIPlu
 	}}
 	return v.srv.blockingRPC(&opts)
 }
-
-// controllerPublishVolume sends publish request to the CSI controller
-// plugin associated with a volume, if any.
-func (srv *Server) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, resp *structs.CSIVolumeClaimResponse) error {
-	plug, vol, err := srv.volAndPluginLookup(req.RequestNamespace(), req.VolumeID)
-	if err != nil {
-		return err
-	}
-
-	// Set the Response volume from the lookup
-	resp.Volume = vol
-
-	// Validate the existence of the allocation, regardless of whether we need it
-	// now.
-	state := srv.fsm.State()
-	ws := memdb.NewWatchSet()
-	alloc, err := state.AllocByID(ws, req.AllocationID)
-	if err != nil {
-		return err
-	}
-	if alloc == nil {
-		return fmt.Errorf("%s: %s", structs.ErrUnknownAllocationPrefix, req.AllocationID)
-	}
-
-	// if no plugin was returned then controller validation is not required.
-	// Here we can return nil.
-	if plug == nil {
-		return nil
-	}
-
-	// plugin IDs are not scoped to region/DC but volumes are.
-	// so any node we get for a controller is already in the same region/DC
-	// for the volume.
-	nodeID, err := nodeForControllerPlugin(state, plug)
-	if err != nil || nodeID == "" {
-		return err
-	}
-
-	targetNode, err := state.NodeByID(ws, alloc.NodeID)
-	if err != nil {
-		return err
-	}
-	if targetNode == nil {
-		return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, alloc.NodeID)
-	}
-	targetCSIInfo, ok := targetNode.CSINodePlugins[plug.ID]
-	if !ok {
-		return fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID)
-	}
-
-	method := "ClientCSI.ControllerAttachVolume"
-	cReq := &cstructs.ClientCSIControllerAttachVolumeRequest{
-		VolumeID:        vol.RemoteID(),
-		ClientCSINodeID: targetCSIInfo.NodeInfo.ID,
-		AttachmentMode:  vol.AttachmentMode,
-		AccessMode:      vol.AccessMode,
-		ReadOnly:        req.Claim == structs.CSIVolumeClaimRead,
-	}
-	cReq.PluginID = plug.ID
-	cReq.ControllerNodeID = nodeID
-	cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}
-
-	err = srv.RPC(method, cReq, cResp)
-	if err != nil {
-		return fmt.Errorf("attach volume: %v", err)
-	}
-	resp.PublishContext = cResp.PublishContext
-	return nil
-}

---------- changed file 5 of 5 ----------

@@ -4,7 +4,6 @@ import (
 	"fmt"
 	"testing"

-	memdb "github.com/hashicorp/go-memdb"
 	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
 	"github.com/hashicorp/nomad/acl"
 	"github.com/hashicorp/nomad/helper/uuid"
@@ -601,45 +600,3 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
 	require.Nil(t, vol)
 	require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2))
 }
-
-func TestCSI_NodeForControllerPlugin(t *testing.T) {
-	t.Parallel()
-	srv, shutdown := TestServer(t, func(c *Config) {})
-	testutil.WaitForLeader(t, srv.RPC)
-	defer shutdown()
-
-	plugins := map[string]*structs.CSIInfo{
-		"minnie": {PluginID: "minnie",
-			Healthy:                  true,
-			ControllerInfo:           &structs.CSIControllerInfo{},
-			NodeInfo:                 &structs.CSINodeInfo{},
-			RequiresControllerPlugin: true,
-		},
-	}
-	state := srv.fsm.State()
-
-	node1 := mock.Node()
-	node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions
-	node1.CSIControllerPlugins = plugins
-
-	node2 := mock.Node()
-	node2.CSIControllerPlugins = plugins
-	node2.ID = uuid.Generate()
-
-	node3 := mock.Node()
-	node3.ID = uuid.Generate()
-
-	err := state.UpsertNode(1002, node1)
-	require.NoError(t, err)
-	err = state.UpsertNode(1003, node2)
-	require.NoError(t, err)
-	err = state.UpsertNode(1004, node3)
-	require.NoError(t, err)
-
-	ws := memdb.NewWatchSet()
-	plugin, err := state.CSIPluginByID(ws, "minnie")
-	require.NoError(t, err)
-
-	nodeID, err := nodeForControllerPlugin(state, plugin)
-	require.NoError(t, err)
-
-	// only node1 has both the controller and a recent Nomad version
-	require.Equal(t, nodeID, node1.ID)
-}