diff --git a/nomad/rpc.go b/nomad/rpc.go index 1b7e8bcca..e2ce814ce 100644 --- a/nomad/rpc.go +++ b/nomad/rpc.go @@ -55,6 +55,21 @@ const ( enqueueLimit = 30 * time.Second ) +// RPCContext provides metadata about the RPC connection. +type RPCContext struct { + // Session exposes the multiplexed connection session. + Session *yamux.Session + + // TLS marks whether the RPC is over a TLS based connection + TLS bool + + // TLSRole is the certificate role making the TLS connection. + TLSRole string + + // TLSRegion is the region on the certificate making theTLS connection + TLSRegion string +} + // NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls to // the Nomad Server. func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { @@ -94,14 +109,14 @@ func (s *Server) listen(ctx context.Context) { continue } - go s.handleConn(ctx, conn, false) + go s.handleConn(ctx, conn, &RPCContext{}) metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1) } } // handleConn is used to determine if this is a Raft or // Nomad type RPC connection and invoke the correct handler -func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) { +func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) { // Read a single byte buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { @@ -113,7 +128,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) { } // Enforce TLS if EnableRPC is set - if s.config.TLSConfig.EnableRPC && !isTLS && RPCType(buf[0]) != rpcTLS { + if s.config.TLSConfig.EnableRPC && !rpcCtx.TLS && RPCType(buf[0]) != rpcTLS { if !s.config.TLSConfig.RPCUpgradeMode { s.logger.Printf("[WARN] nomad.rpc: Non-TLS connection attempted from %s with RequireTLS set", conn.RemoteAddr().String()) conn.Close() @@ -124,14 +139,17 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) { // Switch on the byte switch RPCType(buf[0]) { case rpcNomad: - s.handleNomadConn(ctx, conn) + // Create an RPC Server and handle the request + server := rpc.NewServer() + s.setupRpcServer(server, rpcCtx) + s.handleNomadConn(ctx, conn, server) case rpcRaft: metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1) s.raftLayer.Handoff(ctx, conn) case rpcMultiplex: - s.handleMultiplex(ctx, conn) + s.handleMultiplex(ctx, conn, rpcCtx) case rpcTLS: if s.rpcTLS == nil { @@ -140,7 +158,13 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) { return } conn = tls.Server(conn, s.rpcTLS) - s.handleConn(ctx, conn, true) + + // Update the connection context with the fact that the connection is + // using TLS + // TODO pull out more TLS information into the context + rpcCtx.TLS = true + + s.handleConn(ctx, conn, rpcCtx) default: s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0]) @@ -151,11 +175,19 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) { // handleMultiplex is used to multiplex a single incoming connection // using the Yamux multiplexer -func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn) { +func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) { defer conn.Close() conf := yamux.DefaultConfig() conf.LogOutput = s.config.LogOutput server, _ := yamux.Server(conn, conf) + + // Update the context to store the yamux session + rpcCtx.Session = server + + // Create the RPC server for this connection + rpcServer := rpc.NewServer() + s.setupRpcServer(rpcServer, rpcCtx) + for { sub, err := server.Accept() if err != nil { @@ -164,12 +196,12 @@ func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn) { } return } - go s.handleNomadConn(ctx, sub) + go s.handleNomadConn(ctx, sub, rpcServer) } } // handleNomadConn is used to service a single Nomad RPC connection -func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn) { +func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn, server *rpc.Server) { defer conn.Close() rpcCodec := NewServerCodec(conn) for { @@ -182,7 +214,7 @@ func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn) { default: } - if err := s.rpcServer.ServeRequest(rpcCodec); err != nil { + if err := server.ServeRequest(rpcCodec); err != nil { if err != io.EOF && !strings.Contains(err.Error(), "closed") { s.logger.Printf("[ERR] nomad.rpc: RPC error: %v (%v)", err, conn) metrics.IncrCounter([]string{"nomad", "rpc", "request_error"}, 1) diff --git a/nomad/server.go b/nomad/server.go index 87f07f8c9..8b2fc350a 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -860,7 +860,7 @@ func (s *Server) setupVaultClient() error { // setupRPC is used to setup the RPC listener func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { // Populate the static RPC server - s.setupRpcServer(s.rpcServer) + s.setupRpcServer(s.rpcServer, nil) listener, err := s.createRPCListener() if err != nil { @@ -891,7 +891,7 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { } // setupRpcServer is used to populate an RPC server with endpoints -func (s *Server) setupRpcServer(server *rpc.Server) { +func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) { // Add the static endpoints to the RPC server. if s.staticEndpoints.Status == nil { // Initialize the list just once