diff --git a/client/client.go b/client/client.go index 6252e9500..5dad2ee97 100644 --- a/client/client.go +++ b/client/client.go @@ -308,8 +308,12 @@ var ( noServersErr = errors.New("no servers") ) -// NewClient is used to create a new client from the given configuration -func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxies consulApi.SupportedProxiesAPI, consulService consulApi.ConsulServiceAPI) (*Client, error) { +// NewClient is used to create a new client from the given configuration. +// `rpcs` is a map of RPC names to RPC structs that, if non-nil, will be +// registered via https://golang.org/pkg/net/rpc/#Server.RegisterName in place +// of the client's normal RPC handlers. This allows server tests to override +// the behavior of the client. +func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxies consulApi.SupportedProxiesAPI, consulService consulApi.ConsulServiceAPI, rpcs map[string]interface{}) (*Client, error) { // Create the tls wrapper var tlsWrap tlsutil.RegionWrapper if cfg.TLSConfig.EnableRPC { @@ -384,7 +388,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie }) // Setup the clients RPC server - c.setupClientRpc() + c.setupClientRpc(rpcs) // Initialize the ACL state if err := c.clientACLResolver.init(); err != nil { diff --git a/client/client_test.go b/client/client_test.go index c00c18a4f..ee5215e49 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -622,7 +622,7 @@ func TestClient_SaveRestoreState(t *testing.T) { c1.config.PluginLoader = catalog.TestPluginLoaderWithOptions(t, "", c1.config.Options, nil) c1.config.PluginSingletonLoader = singleton.NewSingletonLoader(logger, c1.config.PluginLoader) - c2, err := NewClient(c1.config, consulCatalog, nil, mockService) + c2, err := NewClient(c1.config, consulCatalog, nil, mockService, nil) if err != nil { t.Fatalf("err: %v", err) } diff --git a/client/rpc.go b/client/rpc.go index 11ea5bf91..c106f5d4f 100644 --- a/client/rpc.go +++ b/client/rpc.go @@ -245,19 +245,24 @@ func (c *Client) streamingRpcConn(server *servers.Server, method string) (net.Co } // setupClientRpc is used to setup the Client's RPC endpoints -func (c *Client) setupClientRpc() { - // Initialize the RPC handlers - c.endpoints.ClientStats = &ClientStats{c} - c.endpoints.CSI = &CSI{c} - c.endpoints.FileSystem = NewFileSystemEndpoint(c) - c.endpoints.Allocations = NewAllocationsEndpoint(c) - c.endpoints.Agent = NewAgentEndpoint(c) - +func (c *Client) setupClientRpc(rpcs map[string]interface{}) { // Create the RPC Server c.rpcServer = rpc.NewServer() - // Register the endpoints with the RPC server - c.setupClientRpcServer(c.rpcServer) + // Initialize the RPC handlers + if rpcs != nil { + // override RPCs + for name, rpc := range rpcs { + c.rpcServer.RegisterName(name, rpc) + } + } else { + c.endpoints.ClientStats = &ClientStats{c} + c.endpoints.CSI = &CSI{c} + c.endpoints.FileSystem = NewFileSystemEndpoint(c) + c.endpoints.Allocations = NewAllocationsEndpoint(c) + c.endpoints.Agent = NewAgentEndpoint(c) + c.setupClientRpcServer(c.rpcServer) + } go c.rpcConnListener() } diff --git a/client/testing.go b/client/testing.go index 6ce3ddd29..94681f76e 100644 --- a/client/testing.go +++ b/client/testing.go @@ -2,14 +2,18 @@ package client import ( "fmt" + "net" + "net/rpc" "time" "github.com/hashicorp/nomad/client/config" consulapi "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/fingerprint" + "github.com/hashicorp/nomad/client/servers" agentconsul "github.com/hashicorp/nomad/command/agent/consul" "github.com/hashicorp/nomad/helper/pluginutils/catalog" "github.com/hashicorp/nomad/helper/pluginutils/singleton" + "github.com/hashicorp/nomad/helper/pool" "github.com/hashicorp/nomad/helper/testlog" testing "github.com/mitchellh/go-testing-interface" ) @@ -21,6 +25,10 @@ import ( // and removed in the returned cleanup function. If they are overridden in the // callback then the caller still must run the returned cleanup func. func TestClient(t testing.T, cb func(c *config.Config)) (*Client, func() error) { + return TestClientWithRPCs(t, cb, nil) +} + +func TestClientWithRPCs(t testing.T, cb func(c *config.Config), rpcs map[string]interface{}) (*Client, func() error) { conf, cleanup := config.TestClientConfig(t) // Tighten the fingerprinter timeouts (must be done in client package @@ -46,7 +54,7 @@ func TestClient(t testing.T, cb func(c *config.Config)) (*Client, func() error) } mockCatalog := agentconsul.NewMockCatalog(logger) mockService := consulapi.NewMockConsulServiceClient(t, logger) - client, err := NewClient(conf, mockCatalog, nil, mockService) + client, err := NewClient(conf, mockCatalog, nil, mockService, rpcs) if err != nil { cleanup() t.Fatalf("err: %v", err) @@ -75,3 +83,51 @@ func TestClient(t testing.T, cb func(c *config.Config)) (*Client, func() error) } } } + +// TestRPCOnlyClient is a client that only pings to establish a connection +// with the server and then returns mock RPC responses for those interfaces +// passed in the `rpcs` parameter. Useful for testing client RPCs from the +// server. Returns the Client, a shutdown function, and any error. +func TestRPCOnlyClient(t testing.T, srvAddr net.Addr, rpcs map[string]interface{}) (*Client, func() error, error) { + var err error + conf, cleanup := config.TestClientConfig(t) + + client := &Client{config: conf, logger: testlog.HCLogger(t)} + client.servers = servers.New(client.logger, client.shutdownCh, client) + client.configCopy = client.config.Copy() + + client.rpcServer = rpc.NewServer() + for name, rpc := range rpcs { + client.rpcServer.RegisterName(name, rpc) + } + + client.connPool = pool.NewPool(testlog.HCLogger(t), 10*time.Second, 10, nil) + + cancelFunc := func() error { + ch := make(chan error) + + go func() { + defer close(ch) + client.connPool.Shutdown() + client.shutdownGroup.Wait() + cleanup() + }() + + select { + case <-ch: + return nil + case <-time.After(1 * time.Minute): + return fmt.Errorf("timed out while shutting down client") + } + } + + go client.rpcConnListener() + + _, err = client.SetServers([]string{srvAddr.String()}) + if err != nil { + return nil, cancelFunc, fmt.Errorf("could not set servers: %v", err) + } + client.shutdownGroup.Go(client.registerAndHeartbeat) + + return client, cancelFunc, nil +} diff --git a/command/agent/agent.go b/command/agent/agent.go index b98ce1ce7..32c134bb1 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -861,7 +861,8 @@ func (a *Agent) setupClient() error { conf.StateDBFactory = state.GetStateDBFactory(conf.DevMode) } - nomadClient, err := client.NewClient(conf, a.consulCatalog, a.consulProxies, a.consulService) + nomadClient, err := client.NewClient( + conf, a.consulCatalog, a.consulProxies, a.consulService, nil) if err != nil { return fmt.Errorf("client setup failed: %v", err) } diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go index 0852bc2f1..3cc6c7853 100644 --- a/nomad/client_csi_endpoint_test.go +++ b/nomad/client_csi_endpoint_test.go @@ -18,6 +18,42 @@ import ( "github.com/stretchr/testify/require" ) +// MockClientCSI is a mock for the nomad.ClientCSI RPC server (see +// nomad/client_csi_endpoint.go). This can be used with a TestRPCOnlyClient to +// return specific plugin responses back to server RPCs for testing. Note that +// responses that have no bodies have no "Next*Response" field and will always +// return an empty response body. +type MockClientCSI struct { + NextValidateError error + NextAttachError error + NextAttachResponse *cstructs.ClientCSIControllerAttachVolumeResponse + NextDetachError error + NextNodeDetachError error +} + +func newMockClientCSI() *MockClientCSI { + return &MockClientCSI{ + NextAttachResponse: &cstructs.ClientCSIControllerAttachVolumeResponse{}, + } +} + +func (c *MockClientCSI) ControllerValidateVolume(req *cstructs.ClientCSIControllerValidateVolumeRequest, resp *cstructs.ClientCSIControllerValidateVolumeResponse) error { + return c.NextValidateError +} + +func (c *MockClientCSI) ControllerAttachVolume(req *cstructs.ClientCSIControllerAttachVolumeRequest, resp *cstructs.ClientCSIControllerAttachVolumeResponse) error { + *resp = *c.NextAttachResponse + return c.NextAttachError +} + +func (c *MockClientCSI) ControllerDetachVolume(req *cstructs.ClientCSIControllerDetachVolumeRequest, resp *cstructs.ClientCSIControllerDetachVolumeResponse) error { + return c.NextDetachError +} + +func (c *MockClientCSI) NodeDetachVolume(req *cstructs.ClientCSINodeDetachVolumeRequest, resp *cstructs.ClientCSINodeDetachVolumeResponse) error { + return c.NextNodeDetachError +} + func TestClientCSIController_AttachVolume_Local(t *testing.T) { t.Parallel() require := require.New(t) @@ -30,7 +66,7 @@ func TestClientCSIController_AttachVolume_Local(t *testing.T) { var resp structs.GenericResponse err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp) - require.NotNil(err) + require.Error(err) require.Contains(err.Error(), "no plugins registered for type") } @@ -46,7 +82,7 @@ func TestClientCSIController_AttachVolume_Forwarded(t *testing.T) { var resp structs.GenericResponse err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp) - require.NotNil(err) + require.Error(err) require.Contains(err.Error(), "no plugins registered for type") } @@ -62,7 +98,7 @@ func TestClientCSIController_DetachVolume_Local(t *testing.T) { var resp structs.GenericResponse err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp) - require.NotNil(err) + require.Error(err) require.Contains(err.Error(), "no plugins registered for type") } @@ -78,7 +114,7 @@ func TestClientCSIController_DetachVolume_Forwarded(t *testing.T) { var resp structs.GenericResponse err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp) - require.NotNil(err) + require.Error(err) require.Contains(err.Error(), "no plugins registered for type") } @@ -95,7 +131,7 @@ func TestClientCSIController_ValidateVolume_Local(t *testing.T) { var resp structs.GenericResponse err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp) - require.NotNil(err) + require.Error(err) require.Contains(err.Error(), "no plugins registered for type") } @@ -112,7 +148,7 @@ func TestClientCSIController_ValidateVolume_Forwarded(t *testing.T) { var resp structs.GenericResponse err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp) - require.NotNil(err) + require.Error(err) require.Contains(err.Error(), "no plugins registered for type") } @@ -163,9 +199,12 @@ func TestClientCSI_NodeForControllerPlugin(t *testing.T) { // returns a RPC client to the leader and a cleanup function. func setupForward(t *testing.T) (rpc.ClientCodec, func()) { - s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 1 }) + s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 }) + s2, cleanupS2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 }) + TestJoin(t, s1, s2) testutil.WaitForLeader(t, s1.RPC) + testutil.WaitForLeader(t, s2.RPC) codec := rpcClient(t, s1) c1, cleanupC1 := client.TestClient(t, func(c *config.Config) { @@ -176,24 +215,22 @@ func setupForward(t *testing.T) (rpc.ClientCodec, func()) { select { case <-c1.Ready(): case <-time.After(10 * time.Second): - cleanupS1() cleanupC1() + cleanupS1() + cleanupS2() t.Fatal("client timedout on initialize") } - waitForNodes(t, s1, 1, 1) - - s2, cleanupS2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 }) - TestJoin(t, s1, s2) - c2, cleanupC2 := client.TestClient(t, func(c *config.Config) { c.Servers = []string{s2.config.RPCAddr.String()} }) select { case <-c2.Ready(): case <-time.After(10 * time.Second): - cleanupS1() cleanupC1() + cleanupC2() + cleanupS1() + cleanupS2() t.Fatal("client timedout on initialize") } @@ -224,10 +261,10 @@ func setupForward(t *testing.T) (rpc.ClientCodec, func()) { s1.fsm.state.UpsertNode(structs.MsgTypeTestSetup, 1000, node1) cleanup := func() { - cleanupS1() cleanupC1() - cleanupS2() cleanupC2() + cleanupS2() + cleanupS1() } return codec, cleanup @@ -235,23 +272,43 @@ func setupForward(t *testing.T) (rpc.ClientCodec, func()) { // sets up a single server with a client, and registers a plugin to the client. func setupLocal(t *testing.T) (rpc.ClientCodec, func()) { - + var err error s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 1 }) testutil.WaitForLeader(t, s1.RPC) codec := rpcClient(t, s1) - c1, cleanupC1 := client.TestClient(t, func(c *config.Config) { - c.Servers = []string{s1.config.RPCAddr.String()} - }) + mockCSI := newMockClientCSI() + mockCSI.NextValidateError = fmt.Errorf("no plugins registered for type") + mockCSI.NextAttachError = fmt.Errorf("no plugins registered for type") + mockCSI.NextDetachError = fmt.Errorf("no plugins registered for type") - // Wait for client initialization - select { - case <-c1.Ready(): - case <-time.After(10 * time.Second): - cleanupS1() + c1, cleanupC1 := client.TestClientWithRPCs(t, + func(c *config.Config) { + c.Servers = []string{s1.config.RPCAddr.String()} + }, + map[string]interface{}{"CSI": mockCSI}, + ) + + if err != nil { cleanupC1() - t.Fatal("client timedout on initialize") + cleanupS1() + require.NoError(t, err, "could not setup test client") + } + + node1 := c1.Node() + node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions + + req := &structs.NodeRegisterRequest{ + Node: node1, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + var resp structs.NodeUpdateResponse + err = c1.RPC("Node.Register", req, &resp) + if err != nil { + cleanupC1() + cleanupS1() + require.NoError(t, err, "could not register client node") } waitForNodes(t, s1, 1, 1) @@ -266,15 +323,12 @@ func setupLocal(t *testing.T) (rpc.ClientCodec, func()) { } // update w/ plugin - node1 := c1.Node() - node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions node1.CSIControllerPlugins = plugins - s1.fsm.state.UpsertNode(structs.MsgTypeTestSetup, 1000, node1) cleanup := func() { - cleanupS1() cleanupC1() + cleanupS1() } return codec, cleanup