Dynamic RPC servers with context

This commit is contained in:
Alex Dadgar
2018-01-03 16:00:55 -08:00
parent aaf883b424
commit 96587f2413
2 changed files with 44 additions and 12 deletions

View File

@@ -55,6 +55,21 @@ const (
enqueueLimit = 30 * time.Second
)
// RPCContext provides metadata about the RPC connection.
type RPCContext struct {
// Session exposes the multiplexed connection session.
Session *yamux.Session
// TLS marks whether the RPC is over a TLS based connection
TLS bool
// TLSRole is the certificate role making the TLS connection.
TLSRole string
// TLSRegion is the region on the certificate making theTLS connection
TLSRegion string
}
// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls to
// the Nomad Server.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
@@ -94,14 +109,14 @@ func (s *Server) listen(ctx context.Context) {
continue
}
go s.handleConn(ctx, conn, false)
go s.handleConn(ctx, conn, &RPCContext{})
metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1)
}
}
// handleConn is used to determine if this is a Raft or
// Nomad type RPC connection and invoke the correct handler
func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) {
func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
// Read a single byte
buf := make([]byte, 1)
if _, err := conn.Read(buf); err != nil {
@@ -113,7 +128,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) {
}
// Enforce TLS if EnableRPC is set
if s.config.TLSConfig.EnableRPC && !isTLS && RPCType(buf[0]) != rpcTLS {
if s.config.TLSConfig.EnableRPC && !rpcCtx.TLS && RPCType(buf[0]) != rpcTLS {
if !s.config.TLSConfig.RPCUpgradeMode {
s.logger.Printf("[WARN] nomad.rpc: Non-TLS connection attempted from %s with RequireTLS set", conn.RemoteAddr().String())
conn.Close()
@@ -124,14 +139,17 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) {
// Switch on the byte
switch RPCType(buf[0]) {
case rpcNomad:
s.handleNomadConn(ctx, conn)
// Create an RPC Server and handle the request
server := rpc.NewServer()
s.setupRpcServer(server, rpcCtx)
s.handleNomadConn(ctx, conn, server)
case rpcRaft:
metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1)
s.raftLayer.Handoff(ctx, conn)
case rpcMultiplex:
s.handleMultiplex(ctx, conn)
s.handleMultiplex(ctx, conn, rpcCtx)
case rpcTLS:
if s.rpcTLS == nil {
@@ -140,7 +158,13 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) {
return
}
conn = tls.Server(conn, s.rpcTLS)
s.handleConn(ctx, conn, true)
// Update the connection context with the fact that the connection is
// using TLS
// TODO pull out more TLS information into the context
rpcCtx.TLS = true
s.handleConn(ctx, conn, rpcCtx)
default:
s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0])
@@ -151,11 +175,19 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, isTLS bool) {
// handleMultiplex is used to multiplex a single incoming connection
// using the Yamux multiplexer
func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn) {
func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
defer conn.Close()
conf := yamux.DefaultConfig()
conf.LogOutput = s.config.LogOutput
server, _ := yamux.Server(conn, conf)
// Update the context to store the yamux session
rpcCtx.Session = server
// Create the RPC server for this connection
rpcServer := rpc.NewServer()
s.setupRpcServer(rpcServer, rpcCtx)
for {
sub, err := server.Accept()
if err != nil {
@@ -164,12 +196,12 @@ func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn) {
}
return
}
go s.handleNomadConn(ctx, sub)
go s.handleNomadConn(ctx, sub, rpcServer)
}
}
// handleNomadConn is used to service a single Nomad RPC connection
func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn) {
func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn, server *rpc.Server) {
defer conn.Close()
rpcCodec := NewServerCodec(conn)
for {
@@ -182,7 +214,7 @@ func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn) {
default:
}
if err := s.rpcServer.ServeRequest(rpcCodec); err != nil {
if err := server.ServeRequest(rpcCodec); err != nil {
if err != io.EOF && !strings.Contains(err.Error(), "closed") {
s.logger.Printf("[ERR] nomad.rpc: RPC error: %v (%v)", err, conn)
metrics.IncrCounter([]string{"nomad", "rpc", "request_error"}, 1)

View File

@@ -860,7 +860,7 @@ func (s *Server) setupVaultClient() error {
// setupRPC is used to setup the RPC listener
func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error {
// Populate the static RPC server
s.setupRpcServer(s.rpcServer)
s.setupRpcServer(s.rpcServer, nil)
listener, err := s.createRPCListener()
if err != nil {
@@ -891,7 +891,7 @@ func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error {
}
// setupRpcServer is used to populate an RPC server with endpoints
func (s *Server) setupRpcServer(server *rpc.Server) {
func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
// Add the static endpoints to the RPC server.
if s.staticEndpoints.Status == nil {
// Initialize the list just once