diff --git a/nomad/rpc.go b/nomad/rpc.go index 159c1f272..537e73d9d 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -172,7 +172,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCConte s.handleStreamingConn(conn) case pool.RpcMultiplexV2: - s.handleMultiplexV2(conn, ctx) + s.handleMultiplexV2(ctx, conn, rpcCtx) default: s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) @@ -286,11 +286,11 @@ func (s *Server) handleStreamingConn(conn net.Conn) { // handleMultiplexV2 is used to multiplex a single incoming connection // using the Yamux multiplexer. Version 2 handling allows a single connection to // switch streams between regulars RPCs and Streaming RPCs. -func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) { +func (s *Server) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) { defer func() { // Remove any potential mapping between a NodeID to this connection and // close the underlying connection. - s.removeNodeConn(ctx) + s.removeNodeConn(rpcCtx) conn.Close() }() @@ -303,11 +303,11 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) { } // Update the context to store the yamux session - ctx.Session = server + rpcCtx.Session = server // Create the RPC server for this connection rpcServer := rpc.NewServer() - s.setupRpcServer(rpcServer, ctx) + s.setupRpcServer(rpcServer, rpcCtx) for { // Accept a new stream @@ -331,7 +331,7 @@ func (s *Server) handleMultiplexV2(conn net.Conn, ctx *RPCContext) { // Determine which handler to use switch pool.RPCType(buf[0]) { case pool.RpcNomad: - go s.handleNomadConn(sub, rpcServer) + go s.handleNomadConn(ctx, sub, rpcServer) case pool.RpcStreaming: go s.handleStreamingConn(sub) @@ -476,7 +476,7 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err tcp.SetNoDelay(true) } - if err := s.streamingRpcImpl(conn, method); err != nil { + if err := s.streamingRpcImpl(conn, server.Region, method); err != nil { return nil, err } @@ -487,24 +487,27 @@ func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, err // the handshake to establish a streaming RPC for the given method. If an error // is returned, the underlying connection has been closed. Otherwise it is // assumed that the connection has been hijacked by the RPC method. -func (s *Server) streamingRpcImpl(conn net.Conn, method string) error { - // TODO TLS +func (s *Server) streamingRpcImpl(conn net.Conn, region, method string) error { // Check if TLS is enabled - //if p.tlsWrap != nil { - //// Switch the connection into TLS mode - //if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil { - //conn.Close() - //return nil, err - //} + s.tlsWrapLock.RLock() + tlsWrap := s.tlsWrap + s.tlsWrapLock.RUnlock() - //// Wrap the connection in a TLS client - //tlsConn, err := p.tlsWrap(region, conn) - //if err != nil { - //conn.Close() - //return nil, err - //} - //conn = tlsConn - //} + if tlsWrap != nil { + // Switch the connection into TLS mode + if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil { + conn.Close() + return err + } + + // Wrap the connection in a TLS client + tlsConn, err := tlsWrap(region, conn) + if err != nil { + conn.Close() + return err + } + conn = tlsConn + } // Write the multiplex byte to set the mode if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil { diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index eb85af57e..43b88386b 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -1,6 +1,7 @@ package nomad import ( + "context" "net" "net/rpc" "os" @@ -201,7 +202,67 @@ func TestRPC_streamingRpcConn_badMethod(t *testing.T) { conn, err := s1.streamingRpc(server, "Bogus") require.Nil(conn) require.NotNil(err) - require.Contains(err.Error(), "unknown rpc method: \"Bogus\"") + require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"") +} + +func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) { + t.Parallel() + require := require.New(t) + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + dir := tmpDir(t) + defer os.RemoveAll(dir) + s1 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node1") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s1.Shutdown() + + s2 := TestServer(t, func(c *Config) { + c.Region = "regionFoo" + c.BootstrapExpect = 2 + c.DevMode = false + c.DevDisableBootstrap = true + c.DataDir = path.Join(dir, "node2") + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + EnableRPC: true, + VerifyServerHostname: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer s2.Shutdown() + + TestJoin(t, s1, s2) + testutil.WaitForLeader(t, s1.RPC) + + s1.peerLock.RLock() + ok, parts := isNomadServer(s2.LocalMember()) + require.True(ok) + server := s1.localPeers[raft.ServerAddress(parts.Addr.String())] + require.NotNil(server) + s1.peerLock.RUnlock() + + conn, err := s1.streamingRpc(server, "Bogus") + require.Nil(conn) + require.NotNil(err) + require.Contains(err.Error(), "Unknown rpc method: \"Bogus\"") } // COMPAT: Remove in 0.10 @@ -224,7 +285,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) { // Start the handler doneCh := make(chan struct{}) go func() { - s.handleConn(p2, &RPCContext{Conn: p2}) + s.handleConn(context.Background(), p2, &RPCContext{Conn: p2}) close(doneCh) }() @@ -257,7 +318,7 @@ func TestRPC_handleMultiplexV2(t *testing.T) { require.NotEmpty(l) // Make a streaming RPC - err = s.streamingRpcImpl(s2, "Bogus") + err = s.streamingRpcImpl(s2, s.Region(), "Bogus") require.NotNil(err) require.Contains(err.Error(), "unknown rpc") diff --git a/nomad/server.go b/nomad/server.go index 5b68a85ac..a70c2f890 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -112,6 +112,11 @@ type Server struct { rpcListener net.Listener listenerCh chan struct{} + // tlsWrap is used to wrap outbound connections using TLS. It should be + // accessed using the lock. + tlsWrap tlsutil.RegionWrapper + tlsWrapLock sync.RWMutex + // rpcServer is the static RPC server that is used by the local agent. rpcServer *rpc.Server @@ -276,6 +281,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, logger *log.Logg consulCatalog: consulCatalog, connPool: pool.NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap), logger: logger, + tlsWrap: tlsWrap, rpcServer: rpc.NewServer(), streamingRpcs: structs.NewStreamingRpcRegistery(), nodeConns: make(map[string]*nodeConnState), @@ -435,6 +441,11 @@ func (s *Server) reloadTLSConnections(newTLSConfig *config.TLSConfig) error { return err } + // Store the new tls wrapper. + s.tlsWrapLock.Lock() + s.tlsWrap = tlsWrap + s.tlsWrapLock.Unlock() + if s.rpcCancel == nil { err = fmt.Errorf("No existing RPC server to reset.") s.logger.Printf("[ERR] nomad: %s", err)