mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
Merge pull request #5275 from hashicorp/f-api-config-httpclient
api: allow configuring http client
This commit is contained in:
158
api/api.go
158
api/api.go
@@ -121,8 +121,11 @@ type Config struct {
|
||||
// Namespace to use. If not provided the default namespace is used.
|
||||
Namespace string
|
||||
|
||||
// httpClient is the client to use. Default will be used if not provided.
|
||||
httpClient *http.Client
|
||||
// HttpClient is the client to use. Default will be used if not provided.
|
||||
//
|
||||
// If set, it expected to be configured for tls already, and TLSConfig is ignored.
|
||||
// You may use ConfigureTLS() function to aid with initialization.
|
||||
HttpClient *http.Client
|
||||
|
||||
// HttpAuth is the auth info to use for http access.
|
||||
HttpAuth *HttpBasicAuth
|
||||
@@ -132,7 +135,9 @@ type Config struct {
|
||||
WaitTime time.Duration
|
||||
|
||||
// TLSConfig provides the various TLS related configurations for the http
|
||||
// client
|
||||
// client.
|
||||
//
|
||||
// TLSConfig is ignored if HttpClient is set.
|
||||
TLSConfig *TLSConfig
|
||||
}
|
||||
|
||||
@@ -143,12 +148,11 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
|
||||
if tlsEnabled {
|
||||
scheme = "https"
|
||||
}
|
||||
defaultConfig := DefaultConfig()
|
||||
config := &Config{
|
||||
Address: fmt.Sprintf("%s://%s", scheme, address),
|
||||
Region: region,
|
||||
Namespace: c.Namespace,
|
||||
httpClient: defaultConfig.httpClient,
|
||||
HttpClient: c.HttpClient,
|
||||
SecretID: c.SecretID,
|
||||
HttpAuth: c.HttpAuth,
|
||||
WaitTime: c.WaitTime,
|
||||
@@ -198,19 +202,23 @@ func (t *TLSConfig) Copy() *TLSConfig {
|
||||
return nt
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for the client
|
||||
func DefaultConfig() *Config {
|
||||
config := &Config{
|
||||
Address: "http://127.0.0.1:4646",
|
||||
httpClient: cleanhttp.DefaultClient(),
|
||||
TLSConfig: &TLSConfig{},
|
||||
}
|
||||
transport := config.httpClient.Transport.(*http.Transport)
|
||||
func defaultHttpClient() *http.Client {
|
||||
httpClient := cleanhttp.DefaultClient()
|
||||
transport := httpClient.Transport.(*http.Transport)
|
||||
transport.TLSHandshakeTimeout = 10 * time.Second
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
return httpClient
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for the client
|
||||
func DefaultConfig() *Config {
|
||||
config := &Config{
|
||||
Address: "http://127.0.0.1:4646",
|
||||
TLSConfig: &TLSConfig{},
|
||||
}
|
||||
if addr := os.Getenv("NOMAD_ADDR"); addr != "" {
|
||||
config.Address = addr
|
||||
}
|
||||
@@ -260,49 +268,72 @@ func DefaultConfig() *Config {
|
||||
return config
|
||||
}
|
||||
|
||||
// SetTimeout is used to place a timeout for connecting to Nomad. A negative
|
||||
// duration is ignored, a duration of zero means no timeout, and any other value
|
||||
// will add a timeout.
|
||||
func (c *Config) SetTimeout(t time.Duration) error {
|
||||
if c == nil {
|
||||
return fmt.Errorf("nil config")
|
||||
} else if c.httpClient == nil {
|
||||
return fmt.Errorf("nil HTTP client")
|
||||
} else if c.httpClient.Transport == nil {
|
||||
return fmt.Errorf("nil HTTP client transport")
|
||||
// cloneWithTimeout returns a cloned httpClient with set timeout if positive;
|
||||
// otherwise, returns the same client
|
||||
func cloneWithTimeout(httpClient *http.Client, t time.Duration) (*http.Client, error) {
|
||||
if httpClient == nil {
|
||||
return nil, fmt.Errorf("nil HTTP client")
|
||||
} else if httpClient.Transport == nil {
|
||||
return nil, fmt.Errorf("nil HTTP client transport")
|
||||
}
|
||||
|
||||
// Apply a timeout.
|
||||
if t.Nanoseconds() >= 0 {
|
||||
transport, ok := c.httpClient.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected HTTP transport: %T", c.httpClient.Transport)
|
||||
}
|
||||
|
||||
transport.DialContext = (&net.Dialer{
|
||||
Timeout: t,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext
|
||||
if t.Nanoseconds() < 0 {
|
||||
return httpClient, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
tr, ok := httpClient.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected HTTP transport: %T", httpClient.Transport)
|
||||
}
|
||||
|
||||
// copy all public fields, to avoid copying transient state and locks
|
||||
ntr := &http.Transport{
|
||||
Proxy: tr.Proxy,
|
||||
DialContext: tr.DialContext,
|
||||
Dial: tr.Dial,
|
||||
DialTLS: tr.DialTLS,
|
||||
TLSClientConfig: tr.TLSClientConfig,
|
||||
TLSHandshakeTimeout: tr.TLSHandshakeTimeout,
|
||||
DisableKeepAlives: tr.DisableKeepAlives,
|
||||
DisableCompression: tr.DisableCompression,
|
||||
MaxIdleConns: tr.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: tr.MaxIdleConnsPerHost,
|
||||
MaxConnsPerHost: tr.MaxConnsPerHost,
|
||||
IdleConnTimeout: tr.IdleConnTimeout,
|
||||
ResponseHeaderTimeout: tr.ResponseHeaderTimeout,
|
||||
ExpectContinueTimeout: tr.ExpectContinueTimeout,
|
||||
TLSNextProto: tr.TLSNextProto,
|
||||
ProxyConnectHeader: tr.ProxyConnectHeader,
|
||||
MaxResponseHeaderBytes: tr.MaxResponseHeaderBytes,
|
||||
}
|
||||
|
||||
// apply timeout
|
||||
ntr.DialContext = (&net.Dialer{
|
||||
Timeout: t,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext
|
||||
|
||||
// clone http client with new transport
|
||||
nc := *httpClient
|
||||
nc.Transport = ntr
|
||||
return &nc, nil
|
||||
}
|
||||
|
||||
// ConfigureTLS applies a set of TLS configurations to the the HTTP client.
|
||||
func (c *Config) ConfigureTLS() error {
|
||||
if c.TLSConfig == nil {
|
||||
func ConfigureTLS(httpClient *http.Client, tlsConfig *TLSConfig) error {
|
||||
if tlsConfig == nil {
|
||||
return nil
|
||||
}
|
||||
if c.httpClient == nil {
|
||||
if httpClient == nil {
|
||||
return fmt.Errorf("config HTTP Client must be set")
|
||||
}
|
||||
|
||||
var clientCert tls.Certificate
|
||||
foundClientCert := false
|
||||
if c.TLSConfig.ClientCert != "" || c.TLSConfig.ClientKey != "" {
|
||||
if c.TLSConfig.ClientCert != "" && c.TLSConfig.ClientKey != "" {
|
||||
if tlsConfig.ClientCert != "" || tlsConfig.ClientKey != "" {
|
||||
if tlsConfig.ClientCert != "" && tlsConfig.ClientKey != "" {
|
||||
var err error
|
||||
clientCert, err = tls.LoadX509KeyPair(c.TLSConfig.ClientCert, c.TLSConfig.ClientKey)
|
||||
clientCert, err = tls.LoadX509KeyPair(tlsConfig.ClientCert, tlsConfig.ClientKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -312,22 +343,22 @@ func (c *Config) ConfigureTLS() error {
|
||||
}
|
||||
}
|
||||
|
||||
clientTLSConfig := c.httpClient.Transport.(*http.Transport).TLSClientConfig
|
||||
clientTLSConfig := httpClient.Transport.(*http.Transport).TLSClientConfig
|
||||
rootConfig := &rootcerts.Config{
|
||||
CAFile: c.TLSConfig.CACert,
|
||||
CAPath: c.TLSConfig.CAPath,
|
||||
CAFile: tlsConfig.CACert,
|
||||
CAPath: tlsConfig.CAPath,
|
||||
}
|
||||
if err := rootcerts.ConfigureTLS(clientTLSConfig, rootConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientTLSConfig.InsecureSkipVerify = c.TLSConfig.Insecure
|
||||
clientTLSConfig.InsecureSkipVerify = tlsConfig.Insecure
|
||||
|
||||
if foundClientCert {
|
||||
clientTLSConfig.Certificates = []tls.Certificate{clientCert}
|
||||
}
|
||||
if c.TLSConfig.TLSServerName != "" {
|
||||
clientTLSConfig.ServerName = c.TLSConfig.TLSServerName
|
||||
if tlsConfig.TLSServerName != "" {
|
||||
clientTLSConfig.ServerName = tlsConfig.TLSServerName
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -335,7 +366,8 @@ func (c *Config) ConfigureTLS() error {
|
||||
|
||||
// Client provides a client to the Nomad API
|
||||
type Client struct {
|
||||
config Config
|
||||
httpClient *http.Client
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewClient returns a new client
|
||||
@@ -349,17 +381,17 @@ func NewClient(config *Config) (*Client, error) {
|
||||
return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err)
|
||||
}
|
||||
|
||||
if config.httpClient == nil {
|
||||
config.httpClient = defConfig.httpClient
|
||||
}
|
||||
|
||||
// Configure the TLS configurations
|
||||
if err := config.ConfigureTLS(); err != nil {
|
||||
return nil, err
|
||||
httpClient := config.HttpClient
|
||||
if httpClient == nil {
|
||||
httpClient = defaultHttpClient()
|
||||
if err := ConfigureTLS(httpClient, config.TLSConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
config: *config,
|
||||
config: *config,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
@@ -428,8 +460,12 @@ func (c *Client) getNodeClientImpl(nodeID string, timeout time.Duration, q *Quer
|
||||
// Get an API client for the node
|
||||
conf := c.config.ClientConfig(region, node.HTTPAddr, node.TLSEnabled)
|
||||
|
||||
// Set the timeout
|
||||
conf.SetTimeout(timeout)
|
||||
// set timeout - preserve old behavior where errors are ignored and use untimed one
|
||||
httpClient, err := cloneWithTimeout(c.httpClient, timeout)
|
||||
if err == nil {
|
||||
httpClient = c.httpClient
|
||||
}
|
||||
conf.HttpClient = httpClient
|
||||
|
||||
return NewClient(conf)
|
||||
}
|
||||
@@ -612,7 +648,7 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
|
||||
return 0, nil, err
|
||||
}
|
||||
start := time.Now()
|
||||
resp, err := c.config.httpClient.Do(req)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
diff := time.Now().Sub(start)
|
||||
|
||||
// If the response is compressed, we swap the body's reader.
|
||||
@@ -659,14 +695,14 @@ func (c *Client) rawQuery(endpoint string, q *QueryOptions) (io.ReadCloser, erro
|
||||
// websocket makes a websocket request to the specific endpoint
|
||||
func (c *Client) websocket(endpoint string, q *QueryOptions) (*websocket.Conn, *http.Response, error) {
|
||||
|
||||
transport, ok := c.config.httpClient.Transport.(*http.Transport)
|
||||
transport, ok := c.httpClient.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("unsupported transport")
|
||||
}
|
||||
dialer := websocket.Dialer{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
HandshakeTimeout: c.config.httpClient.Timeout,
|
||||
HandshakeTimeout: c.httpClient.Timeout,
|
||||
|
||||
// values to inherit from http client configuration
|
||||
NetDial: transport.Dial,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/nomad/api/internal/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type configCallback func(c *Config)
|
||||
@@ -443,3 +445,40 @@ func TestClient_NodeClient(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneHttpClient(t *testing.T) {
|
||||
client := defaultHttpClient()
|
||||
originalTransport := client.Transport.(*http.Transport)
|
||||
originalTransport.Proxy = func(*http.Request) (*url.URL, error) {
|
||||
return nil, fmt.Errorf("stub function")
|
||||
}
|
||||
|
||||
t.Run("closing with negative timeout", func(t *testing.T) {
|
||||
clone, err := cloneWithTimeout(client, -1)
|
||||
require.True(t, originalTransport == client.Transport, "original transport changed")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client, clone)
|
||||
require.True(t, client == clone)
|
||||
})
|
||||
|
||||
t.Run("closing with positive timeout", func(t *testing.T) {
|
||||
clone, err := cloneWithTimeout(client, 1*time.Second)
|
||||
require.True(t, originalTransport == client.Transport, "original transport changed")
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, client, clone)
|
||||
require.True(t, client != clone)
|
||||
require.True(t, client.Transport != clone.Transport)
|
||||
|
||||
// test that proxy function is the same in clone
|
||||
clonedProxy := clone.Transport.(*http.Transport).Proxy
|
||||
require.NotNil(t, clonedProxy)
|
||||
_, err = clonedProxy(nil)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "stub function", err.Error())
|
||||
|
||||
// if we reset transport, the strutcs are equal
|
||||
clone.Transport = originalTransport
|
||||
require.Equal(t, client, clone)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user