Fix TLSServerName for Node API Client

This PR fixes the construction of the TLSServerName when connecting to a
node that has TLS enabled and adds tests for all possible permutations.

Fixes https://github.com/hashicorp/nomad/issues/3013
This commit is contained in:
Alex Dadgar
2017-08-29 11:11:19 -07:00
parent 6fb08b844b
commit 4d3b75d867
2 changed files with 147 additions and 2 deletions

View File

@@ -125,7 +125,9 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
WaitTime: c.WaitTime,
TLSConfig: c.TLSConfig.Copy(),
}
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", c.Region)
if tlsEnabled && config.TLSConfig != nil {
config.TLSConfig.TLSServerName = fmt.Sprintf("client.%s.nomad", region)
}
return config
}
@@ -221,6 +223,9 @@ func DefaultConfig() *Config {
// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
func (c *Config) ConfigureTLS() error {
if c.TLSConfig == nil {
return nil
}
if c.HttpClient == nil {
return fmt.Errorf("config HTTP Client must be set")
}
@@ -300,7 +305,17 @@ func (c *Client) SetRegion(region string) {
// GetNodeClient returns a new Client that will dial the specified node. If the
// QueryOptions is set, its region will be used.
func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error) {
node, _, err := c.Nodes().Info(nodeID, q)
return c.getNodeClientImpl(nodeID, q, c.Nodes().Info)
}
// nodeLookup is used to lookup a node
type nodeLookup func(nodeID string, q *QueryOptions) (*Node, *QueryMeta, error)
// getNodeClientImpl is the implementation of creating a API client for
// contacting a node. It is takes a function to lookup the node such that it can
// be mocked during tests.
func (c *Client) getNodeClientImpl(nodeID string, q *QueryOptions, lookup nodeLookup) (*Client, error) {
node, _, err := lookup(nodeID, q)
if err != nil {
return nil, err
}
@@ -316,6 +331,10 @@ func (c *Client) GetNodeClient(nodeID string, q *QueryOptions) (*Client, error)
region = q.Region
}
if region == "" {
region = "global"
}
// Get an API client for the node
conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled)
return NewClient(conf)

View File

@@ -2,6 +2,7 @@ package api
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
@@ -9,7 +10,9 @@ import (
"testing"
"time"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/assert"
)
type configCallback func(c *Config)
@@ -243,3 +246,126 @@ func TestQueryString(t *testing.T) {
t.Fatalf("bad uri: %q", uri)
}
}
func TestClient_NodeClient(t *testing.T) {
http := "testdomain:4646"
tlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: true,
}, nil, nil
}
noTlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: structs.GenerateUUID(),
Status: "ready",
HTTPAddr: http,
TLSEnabled: false,
}, nil, nil
}
optionNoRegion := &QueryOptions{}
optionRegion := &QueryOptions{
Region: "foo",
}
clientNoRegion, err := NewClient(DefaultConfig())
assert.Nil(t, err)
regionConfig := DefaultConfig()
regionConfig.Region = "bar"
clientRegion, err := NewClient(regionConfig)
assert.Nil(t, err)
expectedTLSAddr := fmt.Sprintf("https://%s", http)
expectedNoTLSAddr := fmt.Sprintf("http://%s", http)
cases := []struct {
Node nodeLookup
QueryOptions *QueryOptions
Client *Client
ExpectedAddr string
ExpectedRegion string
ExpectedTLSServerName string
}{
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "client.global.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "client.bar.nomad",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "",
},
}
for _, c := range cases {
name := fmt.Sprintf("%s__%s__%s", c.ExpectedAddr, c.ExpectedRegion, c.ExpectedTLSServerName)
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
nodeClient, err := c.Client.getNodeClientImpl("testID", c.QueryOptions, c.Node)
assert.Nil(err)
assert.Equal(c.ExpectedRegion, nodeClient.config.Region)
assert.Equal(c.ExpectedAddr, nodeClient.config.Address)
assert.NotNil(nodeClient.config.TLSConfig)
assert.Equal(c.ExpectedTLSServerName, nodeClient.config.TLSConfig.TLSServerName)
})
}
}