From f062c93e956e2f7d17edeeb1640249b0376925f2 Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Thu, 15 Feb 2018 15:22:57 -0800 Subject: [PATCH] Client tls --- client/client.go | 11 +++++ client/driver/mock_driver.go | 7 --- client/driver/mock_driver_testing.go | 8 ++++ client/rpc.go | 33 ++++++++------ client/rpc_test.go | 68 +++++++++++++++++++++++++++- helper/testlog/testlog.go | 2 +- 6 files changed, 105 insertions(+), 24 deletions(-) create mode 100644 client/driver/mock_driver_testing.go diff --git a/client/client.go b/client/client.go index a4a85c031..f39ad3116 100644 --- a/client/client.go +++ b/client/client.go @@ -113,6 +113,11 @@ type Client struct { connPool *pool.ConnPool + // tlsWrap is used to wrap outbound connections using TLS. It should be + // accessed using the lock. + tlsWrap tlsutil.RegionWrapper + tlsWrapLock sync.RWMutex + // servers is the list of nomad servers servers *servers.Manager @@ -197,6 +202,7 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic consulService: consulService, start: time.Now(), connPool: pool.NewPool(cfg.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap), + tlsWrap: tlsWrap, streamingRpcs: structs.NewStreamingRpcRegistery(), logger: logger, allocs: make(map[string]*AllocRunner), @@ -389,6 +395,11 @@ func (c *Client) reloadTLSConnections(newConfig *nconfig.TLSConfig) error { tlsWrap = tw } + // Store the new tls wrapper. + c.tlsWrapLock.Lock() + c.tlsWrap = tlsWrap + c.tlsWrapLock.Unlock() + // Keep the client configuration up to date as we use configuration values to // decide on what type of connections to accept c.configLock.Lock() diff --git a/client/driver/mock_driver.go b/client/driver/mock_driver.go index 07c262852..29d6a4a9d 100644 --- a/client/driver/mock_driver.go +++ b/client/driver/mock_driver.go @@ -1,5 +1,3 @@ -//+build nomad_test - package driver import ( @@ -34,11 +32,6 @@ const ( ShutdownPeriodicDuration = "test.shutdown_periodic_duration" ) -// Add the mock driver to the list of builtin drivers -func init() { - BuiltinDrivers["mock_driver"] = NewMockDriver -} - // MockDriverConfig is the driver configuration for the MockDriver type MockDriverConfig struct { diff --git a/client/driver/mock_driver_testing.go b/client/driver/mock_driver_testing.go new file mode 100644 index 000000000..1b1e861a8 --- /dev/null +++ b/client/driver/mock_driver_testing.go @@ -0,0 +1,8 @@ +//+build nomad_test + +package driver + +// Add the mock driver to the list of builtin drivers +func init() { + BuiltinDrivers["mock_driver"] = NewMockDriver +} diff --git a/client/rpc.go b/client/rpc.go index 1fe52288b..90a1eec47 100644 --- a/client/rpc.go +++ b/client/rpc.go @@ -151,23 +151,26 @@ func (c *Client) streamingRpcConn(server *servers.Server, method string) (net.Co tcp.SetNoDelay(true) } - // TODO TLS // Check if TLS is enabled - //if p.tlsWrap != nil { - //// Switch the connection into TLS mode - //if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { - //conn.Close() - //return nil, err - //} + c.tlsWrapLock.RLock() + tlsWrap := c.tlsWrap + c.tlsWrapLock.RUnlock() - //// Wrap the connection in a TLS client - //tlsConn, err := p.tlsWrap(region, conn) - //if err != nil { - //conn.Close() - //return nil, err - //} - //conn = tlsConn - //} + if tlsWrap != nil { + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := tlsWrap(c.Region(), conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + } // Write the multiplex byte to set the mode if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { diff --git a/client/rpc_test.go b/client/rpc_test.go index 09984a3b6..c25033923 100644 --- a/client/rpc_test.go +++ b/client/rpc_test.go @@ -7,6 +7,7 @@ import ( "github.com/hashicorp/nomad/client/config" "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" + sconfig "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/require" ) @@ -45,5 +46,70 @@ func TestRpc_streamingRpcConn_badEndpoint(t *testing.T) { conn, err := c.streamingRpcConn(server, "Bogus") require.Nil(conn) require.NotNil(err) - require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") + require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"") +} + +func TestRpc_streamingRpcConn_badEndpoint_TLS(t *testing.T) { + t.Parallel() + require := require.New(t) + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + s1 := nomad.TestServer(t, func(c *nomad.Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 1 + c.DevDisableBootstrap = true + c.TLSConfig = &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() + testutil.WaitForLeader(t, s1.RPC) + + c := TestClient(t, func(c *config.Config) { + c.Region = "regionFoo" + c.Servers = []string{s1.GetConfig().RPCAddr.String()} + c.TLSConfig = &sconfig.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer c.Shutdown() + + // Wait for the client to connect + testutil.WaitForResult(func() (bool, error) { + node, err := s1.State().NodeByID(nil, c.NodeID()) + if err != nil { + return false, err + } + if node == nil { + return false, errors.New("no node") + } + + return node.Status == structs.NodeStatusReady, errors.New("wrong status") + }, func(err error) { + t.Fatalf("should have a clients") + }) + + // Get the server + server := c.servers.FindServer() + require.NotNil(server) + + conn, err := c.streamingRpcConn(server, "Bogus") + require.Nil(conn) + require.NotNil(err) + require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"") } diff --git a/helper/testlog/testlog.go b/helper/testlog/testlog.go index 7f6c6cb04..b72fcfb28 100644 --- a/helper/testlog/testlog.go +++ b/helper/testlog/testlog.go @@ -42,5 +42,5 @@ func WithPrefix(t LogPrinter, prefix string) *log.Logger { // NewLog logger with "TEST" prefix and the Lmicroseconds flag. func Logger(t LogPrinter) *log.Logger { - return WithPrefix(t, "TEST ") + return WithPrefix(t, "") }