From 7ee2a1515badbdb56be9c8164534154a1d57e076 Mon Sep 17 00:00:00 2001 From: Diptanu Choudhury Date: Tue, 1 Nov 2016 11:55:29 -0700 Subject: [PATCH] Making Nomad TLS configs region aware --- client/client.go | 2 +- helper/tlsutil/config.go | 40 ++++++++++++++++++++++++++++++++++++---- nomad/pool.go | 6 +++--- nomad/server.go | 7 ++++--- 4 files changed, 44 insertions(+), 11 deletions(-) diff --git a/client/client.go b/client/client.go index ba945f17f..5f34ec699 100644 --- a/client/client.go +++ b/client/client.go @@ -166,7 +166,7 @@ var ( // NewClient is used to create a new client from the given configuration func NewClient(cfg *config.Config, consulSyncer *consul.Syncer, logger *log.Logger) (*Client, error) { // Create the tls wrapper - var tlsWrap tlsutil.Wrapper + var tlsWrap tlsutil.RegionWrapper if cfg.TLSConfig.EnableRPC { tw, err := cfg.TLSConfiguration().OutgoingTLSWrapper() if err != nil { diff --git a/helper/tlsutil/config.go b/helper/tlsutil/config.go index 56c81cbe0..5a2068ea1 100644 --- a/helper/tlsutil/config.go +++ b/helper/tlsutil/config.go @@ -9,6 +9,22 @@ import ( "time" ) +// RegionSpecificWrapper is used to invoke a static Region and turns a +// RegionWrapper into a Wrapper type. +func RegionSpecificWrapper(region string, tlsWrap RegionWrapper) Wrapper { + if tlsWrap == nil { + return nil + } + return func(conn net.Conn) (net.Conn, error) { + return tlsWrap(region, conn) + } +} + +// RegionWrapper is a function that is used to wrap a non-TLS connection and +// returns an appropriate TLS connection or error. This takes a Region as an +// argument. +type RegionWrapper func(region string, conn net.Conn) (net.Conn, error) + // Wrapper wraps a connection and enables TLS on it. type Wrapper func(conn net.Conn) (net.Conn, error) @@ -102,6 +118,11 @@ func (c *Config) OutgoingTLSConfig() (*tls.Config, error) { tlsConfig.ServerName = c.ServerName tlsConfig.InsecureSkipVerify = false } + if c.VerifyServerHostname { + // ServerName is filled in dynamically based on the target DC + tlsConfig.ServerName = "VerifyServerHostname" + tlsConfig.InsecureSkipVerify = false + } // Ensure we have a CA if VerifyOutgoing is set if c.VerifyOutgoing && c.CAFile == "" { @@ -128,7 +149,7 @@ func (c *Config) OutgoingTLSConfig() (*tls.Config, error) { // OutgoingTLSWrapper returns a a Wrapper based on the OutgoingTLS // configuration. If hostname verification is on, the wrapper // will properly generate the dynamic server name for verification. -func (c *Config) OutgoingTLSWrapper() (Wrapper, error) { +func (c *Config) OutgoingTLSWrapper() (RegionWrapper, error) { // Get the TLS config tlsConfig, err := c.OutgoingTLSConfig() if err != nil { @@ -140,10 +161,21 @@ func (c *Config) OutgoingTLSWrapper() (Wrapper, error) { return nil, nil } - wrapper := func(conn net.Conn) (net.Conn, error) { - return WrapTLSClient(conn, tlsConfig) + // Generate the wrapper based on hostname verification + if c.VerifyServerHostname { + wrapper := func(region string, conn net.Conn) (net.Conn, error) { + conf := *tlsConfig + conf.ServerName = "server." + region + ".nomad" + return WrapTLSClient(conn, &conf) + } + return wrapper, nil + } else { + wrapper := func(dc string, c net.Conn) (net.Conn, error) { + return WrapTLSClient(c, tlsConfig) + } + return wrapper, nil } - return wrapper, nil + } // Wrap a net.Conn into a client tls connection, performing any diff --git a/nomad/pool.go b/nomad/pool.go index 7b65dbc03..0e7e29427 100644 --- a/nomad/pool.go +++ b/nomad/pool.go @@ -129,7 +129,7 @@ type ConnPool struct { limiter map[string]chan struct{} // TLS wrapper - tlsWrap tlsutil.Wrapper + tlsWrap tlsutil.RegionWrapper // Used to indicate the pool is shutdown shutdown bool @@ -141,7 +141,7 @@ type ConnPool struct { // Set maxTime to 0 to disable reaping. maxStreams is used to control // the number of idle streams allowed. // If TLS settings are provided outgoing connections use TLS. -func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.Wrapper) *ConnPool { +func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool { pool := &ConnPool{ logOutput: logOutput, maxTime: maxTime, @@ -261,7 +261,7 @@ func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, } // Wrap the connection in a TLS client - tlsConn, err := p.tlsWrap(conn) + tlsConn, err := p.tlsWrap(region, conn) if err != nil { conn.Close() return nil, err diff --git a/nomad/server.go b/nomad/server.go index 39bceeff0..97a5f1e2a 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -188,7 +188,7 @@ func NewServer(config *Config, consulSyncer *consul.Syncer, logger *log.Logger) } // Configure TLS - var tlsWrap tlsutil.Wrapper + var tlsWrap tlsutil.RegionWrapper var incomingTLS *tls.Config if config.TLSConfig.EnableRPC { tlsConf := config.tlsConfig() @@ -594,7 +594,7 @@ func (s *Server) setupVaultClient() error { } // setupRPC is used to setup the RPC listener -func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error { +func (s *Server) setupRPC(tlsWrap tlsutil.RegionWrapper) error { // Create endpoints s.endpoints.Status = &Status{s} s.endpoints.Node = &Node{srv: s} @@ -640,7 +640,8 @@ func (s *Server) setupRPC(tlsWrap tlsutil.Wrapper) error { return fmt.Errorf("RPC advertise address is not advertisable: %v", addr) } - s.raftLayer = NewRaftLayer(s.rpcAdvertise, tlsWrap) + wrapper := tlsutil.RegionSpecificWrapper(s.config.Region, tlsWrap) + s.raftLayer = NewRaftLayer(s.rpcAdvertise, wrapper) return nil }