Files
nomad/api/api_test.go
Tim Gross b7419bc940 api: only set url field in config if previously unset (#23785)
In #16872 we added support for configuring the API client with a unix domain
socket. In order to set the host correctly, we parse the address before mutating
the Address field in the configuration. But this prevents the configuration from
being reused across multiple clients, as the next time we parse the address it
will no longer be pointing to the socket. This breaks consumers like the
autoscaler, which reuse the API config between plugins.

Update the `NewClient` constructor to only override the `url` field if it hasn't
already been parsed. Include a test demonstrating safe reuse with a unix domain
socket.

Ref: https://github.com/hashicorp/nomad-autoscaler/issues/944
Ref: https://github.com/hashicorp/nomad-autoscaler/pull/945
2024-08-09 13:28:04 -04:00

642 lines
15 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package api
import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"strings"
"testing"
"time"
"github.com/hashicorp/nomad/api/internal/testutil"
"github.com/shoenig/test/must"
)
// configCallback is a hook that receives the client Config so a test can
// adjust it before the client is constructed.
type configCallback func(c *Config)
// makeACLClient starts a test server with ACLs enabled, bootstraps the ACL
// system, and returns the client, the server, and the bootstrap (root) token.
// The returned client has its secret ID set to the root token's SecretID.
//
// cb1, if non-nil, customizes the client config. cb2, if non-nil, customizes
// the server config and runs after ACLs have been switched on. The calling
// test is failed immediately if bootstrapping returns an error.
func makeACLClient(t *testing.T, cb1 configCallback,
	cb2 testutil.ServerConfigCallback) (*Client, *testutil.TestServer, *ACLToken) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	client, server := makeClient(t, cb1, func(c *testutil.TestServerConfig) {
		c.ACL.Enabled = true
		if cb2 != nil {
			cb2(c)
		}
	})

	// Get the root token
	root, _, err := client.ACLTokens().Bootstrap(nil)
	if err != nil {
		t.Fatalf("failed to bootstrap ACLs: %v", err)
	}
	client.SetSecretID(root.SecretID)
	return client, server, root
}
// makeClient starts a test server and returns a client pointed at it.
//
// cb1, if non-nil, is applied to the default client config first; note that
// the config's Address is overwritten afterwards with the test server's HTTP
// address, so cb1 cannot change it. cb2 is handed to testutil.NewTestServer.
// The calling test is failed immediately if the client cannot be created.
func makeClient(t *testing.T, cb1 configCallback,
	cb2 testutil.ServerConfigCallback) (*Client, *testutil.TestServer) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	// Make client config
	conf := DefaultConfig()
	if cb1 != nil {
		cb1(conf)
	}

	// Create server
	server := testutil.NewTestServer(t, cb2)
	conf.Address = "http://" + server.HTTPAddr

	// Create client
	client, err := NewClient(conf)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	return client, server
}
// TestRequestTime checks that query, put and delete each report a non-zero
// RequestTime in their returned metadata when the server delays its response.
func TestRequestTime(t *testing.T) {
	testutil.Parallel(t)

	// Handler that stalls briefly, then answers with a small JSON document.
	slowHandler := func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(100 * time.Millisecond)
		payload, marshalErr := json.Marshal(struct{ Done bool }{true})
		if marshalErr != nil {
			http.Error(w, marshalErr.Error(), http.StatusInternalServerError)
			return
		}
		_, _ = w.Write(payload)
	}
	srv := httptest.NewServer(http.HandlerFunc(slowHandler))
	defer srv.Close()

	cfg := DefaultConfig()
	cfg.Address = srv.URL
	client, err := NewClient(cfg)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	var result interface{}

	// Read path.
	queryMeta, err := client.query("/", &result, nil)
	if err != nil {
		t.Fatalf("query err: %v", err)
	}
	if elapsed := queryMeta.RequestTime; elapsed == 0 {
		t.Errorf("bad request time: %d", elapsed)
	}

	// Write path.
	putMeta, err := client.put("/", struct{ S string }{"input"}, &result, nil)
	if err != nil {
		t.Fatalf("write err: %v", err)
	}
	if elapsed := putMeta.RequestTime; elapsed == 0 {
		t.Errorf("bad request time: %d", elapsed)
	}

	// Delete path.
	deleteMeta, err := client.delete("/", nil, &result, nil)
	if err != nil {
		t.Fatalf("delete err: %v", err)
	}
	if elapsed := deleteMeta.RequestTime; elapsed == 0 {
		t.Errorf("bad request time: %d", elapsed)
	}
}
func TestDefaultConfig_env(t *testing.T) {
testURL := "http://1.2.3.4:5678"
auth := []string{"nomaduser", "12345"}
region := "test"
namespace := "dev"
token := "foobar"
t.Setenv("NOMAD_ADDR", testURL)
t.Setenv("NOMAD_REGION", region)
t.Setenv("NOMAD_NAMESPACE", namespace)
t.Setenv("NOMAD_HTTP_AUTH", strings.Join(auth, ":"))
t.Setenv("NOMAD_TOKEN", token)
config := DefaultConfig()
if config.Address != testURL {
t.Errorf("expected %q to be %q", config.Address, testURL)
}
if config.Region != region {
t.Errorf("expected %q to be %q", config.Region, region)
}
if config.Namespace != namespace {
t.Errorf("expected %q to be %q", config.Namespace, namespace)
}
if config.HttpAuth.Username != auth[0] {
t.Errorf("expected %q to be %q", config.HttpAuth.Username, auth[0])
}
if config.HttpAuth.Password != auth[1] {
t.Errorf("expected %q to be %q", config.HttpAuth.Password, auth[1])
}
if config.SecretID != token {
t.Errorf("Expected %q to be %q", config.SecretID, token)
}
}
// TestSetQueryOptions verifies that setQueryOptions stores the auth token on
// the request and copies the remaining query options into its query
// parameters.
func TestSetQueryOptions(t *testing.T) {
	testutil.Parallel(t)

	c, s := makeClient(t, nil, nil)
	defer s.Stop()

	// Check the error so a failed request build is reported as a test
	// failure instead of a nil dereference further down.
	r, err := c.newRequest("GET", "/v1/jobs")
	must.NoError(t, err)

	q := &QueryOptions{
		Region:     "foo",
		Namespace:  "bar",
		AllowStale: true,
		WaitIndex:  1000,
		WaitTime:   100 * time.Second,
		AuthToken:  "foobar",
		Reverse:    true,
	}
	r.setQueryOptions(q)

	// try asserts the value read back for a single query parameter.
	try := func(key, exp string) {
		result := r.params.Get(key)
		must.Eq(t, exp, result)
	}

	// Check auth token is set
	must.Eq(t, "foobar", r.token)

	// Check query parameters are set
	try("region", "foo")
	try("namespace", "bar")
	// NOTE(review): AllowStale is true above, yet the expected value is "".
	// Get cannot tell an absent key from one set to the empty string, so
	// this does not prove the parameter is missing — confirm intent.
	try("stale", "")
	try("index", "1000")
	try("wait", "100000ms")
	try("reverse", "true")
}
// TestQueryOptionsContext verifies that a context attached to QueryOptions
// via WithContext is stored on the options and that cancelling it makes
// Jobs().List fail with context.Canceled.
func TestQueryOptionsContext(t *testing.T) {
	testutil.Parallel(t)

	ctx, cancel := context.WithCancel(context.Background())
	// Release the context even if the test bails out before the goroutine
	// below runs; calling cancel more than once is harmless.
	defer cancel()

	c, s := makeClient(t, nil, nil)
	defer s.Stop()

	// NOTE(review): the large WaitIndex is presumably there to turn List
	// into a blocking query — confirm.
	q := (&QueryOptions{
		WaitIndex: 10000,
	}).WithContext(ctx)
	if q.ctx != ctx {
		t.Fatalf("expected context to be set")
	}

	go func() {
		cancel()
	}()
	_, _, err := c.Jobs().List(q)
	if !errors.Is(err, context.Canceled) {
		// %v rather than %s: err may be nil here, and %s would render it as
		// "%!s(<nil>)".
		t.Fatalf("expected job wait to fail with canceled, got %v", err)
	}
}
// TestWriteOptionsContext verifies that a context attached to WriteOptions
// via WithContext is stored on the options and that a write issued with an
// already-cancelled context fails with context.Canceled.
func TestWriteOptionsContext(t *testing.T) {
	// No blocking query to test a real cancel of a pending request so
	// just test that if we pass a pre-canceled context, writes fail quickly
	testutil.Parallel(t)

	c, err := NewClient(DefaultConfig())
	if err != nil {
		t.Fatalf("failed to initialize client: %s", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	// Release the context even if the test bails out before the explicit
	// cancel below; calling cancel more than once is harmless.
	defer cancel()

	w := (&WriteOptions{}).WithContext(ctx)
	if w.ctx != ctx {
		t.Fatalf("expected context to be set")
	}

	cancel()
	_, _, err = c.Jobs().Deregister("jobid", true, w)
	if !errors.Is(err, context.Canceled) {
		// %v rather than %s: err may be nil here, and %s would render it as
		// "%!s(<nil>)".
		t.Fatalf("expected job to fail with canceled, got %v", err)
	}
}
// TestSetWriteOptions verifies that setWriteOptions copies the region,
// namespace and idempotency token into the request's query parameters and
// stores the auth token on the request.
func TestSetWriteOptions(t *testing.T) {
	testutil.Parallel(t)

	c, s := makeClient(t, nil, nil)
	defer s.Stop()

	// Check the error so a failed request build is reported as a test
	// failure instead of a nil dereference further down.
	r, err := c.newRequest("GET", "/v1/jobs")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	q := &WriteOptions{
		Region:           "foo",
		Namespace:        "bar",
		AuthToken:        "foobar",
		IdempotencyToken: "idempotent",
	}
	r.setWriteOptions(q)

	if r.params.Get("region") != "foo" {
		t.Fatalf("bad: %v", r.params)
	}
	if r.params.Get("namespace") != "bar" {
		t.Fatalf("bad: %v", r.params)
	}
	if r.params.Get("idempotency_token") != "idempotent" {
		t.Fatalf("bad: %v", r.params)
	}
	if r.token != "foobar" {
		t.Fatalf("bad: %v", r.token)
	}
}
// TestRequestToHTTP verifies that toHTTP produces an *http.Request carrying
// the original method, the path with namespace and region as query
// parameters, and the auth token in the X-Nomad-Token header.
func TestRequestToHTTP(t *testing.T) {
	testutil.Parallel(t)

	c, s := makeClient(t, nil, nil)
	defer s.Stop()

	// Check the error so a failed request build is reported as a test
	// failure instead of a nil dereference further down.
	r, err := c.newRequest("DELETE", "/v1/jobs/foo")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	q := &QueryOptions{
		Region:    "foo",
		Namespace: "bar",
		AuthToken: "foobar",
	}
	r.setQueryOptions(q)

	req, err := r.toHTTP()
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if req.Method != "DELETE" {
		t.Fatalf("bad: %v", req)
	}
	if req.URL.RequestURI() != "/v1/jobs/foo?namespace=bar&region=foo" {
		t.Fatalf("bad: %v", req)
	}
	if req.Header.Get("X-Nomad-Token") != "foobar" {
		t.Fatalf("bad: %v", req)
	}
}
// TestParseQueryMeta feeds parseQueryMeta a response carrying the
// X-Nomad-Index, X-Nomad-LastContact and X-Nomad-KnownLeader headers and
// checks the values that end up in QueryMeta.
func TestParseQueryMeta(t *testing.T) {
	testutil.Parallel(t)

	headers := http.Header{}
	headers.Set("X-Nomad-Index", "12345")
	headers.Set("X-Nomad-LastContact", "80")
	headers.Set("X-Nomad-KnownLeader", "true")
	response := &http.Response{Header: headers}

	meta := new(QueryMeta)
	err := parseQueryMeta(response, meta)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Each case is fatal, so only the first mismatch is reported.
	switch {
	case meta.LastIndex != 12345:
		t.Fatalf("Bad: %v", meta)
	case meta.LastContact != 80*time.Millisecond:
		t.Fatalf("Bad: %v", meta)
	case !meta.KnownLeader:
		t.Fatalf("Bad: %v", meta)
	}
}
// TestParseWriteMeta feeds parseWriteMeta a response carrying the
// X-Nomad-Index header and checks that it lands in WriteMeta.LastIndex.
func TestParseWriteMeta(t *testing.T) {
	testutil.Parallel(t)

	headers := http.Header{}
	headers.Set("X-Nomad-Index", "12345")
	response := &http.Response{Header: headers}

	meta := new(WriteMeta)
	err := parseWriteMeta(response, meta)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if got := meta.LastIndex; got != 12345 {
		t.Fatalf("Bad: %v", meta)
	}
}
// TestClientHeader verifies that headers set on the client Config appear on
// requests created by that client.
func TestClientHeader(t *testing.T) {
	testutil.Parallel(t)

	c, s := makeClient(t, func(c *Config) {
		c.Headers = http.Header{
			"Hello": []string{"World"},
		}
	}, nil)
	defer s.Stop()

	// Check the error so a failed request build is reported as a test
	// failure instead of a nil dereference on the next line.
	r, err := c.newRequest("GET", "/v1/jobs")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	if r.header.Get("Hello") != "World" {
		t.Fatalf("bad: %v", r.header)
	}
}
// TestQueryString verifies that query parameters already present in the
// request path are kept and combined with the region and namespace write
// options in the final request URI.
func TestQueryString(t *testing.T) {
	testutil.Parallel(t)

	c, s := makeClient(t, nil, nil)
	defer s.Stop()

	// Check the error so a failed request build is reported as a test
	// failure instead of a nil dereference further down.
	r, err := c.newRequest("PUT", "/v1/abc?foo=bar&baz=zip")
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	q := &WriteOptions{
		Region:    "foo",
		Namespace: "bar",
	}
	r.setWriteOptions(q)

	req, err := r.toHTTP()
	if err != nil {
		t.Fatalf("err: %s", err)
	}

	if uri := req.URL.RequestURI(); uri != "/v1/abc?baz=zip&foo=bar&namespace=bar&region=foo" {
		t.Fatalf("bad uri: %q", uri)
	}
}
func TestClient_NodeClient(t *testing.T) {
addr := "testdomain:4646"
tlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: generateUUID(),
Status: "ready",
HTTPAddr: addr,
TLSEnabled: true,
}, nil, nil
}
noTlsNode := func(string, *QueryOptions) (*Node, *QueryMeta, error) {
return &Node{
ID: generateUUID(),
Status: "ready",
HTTPAddr: addr,
TLSEnabled: false,
}, nil, nil
}
optionNoRegion := &QueryOptions{}
optionRegion := &QueryOptions{
Region: "foo",
}
clientNoRegion, err := NewClient(DefaultConfig())
must.NoError(t, err)
regionConfig := DefaultConfig()
regionConfig.Region = "bar"
clientRegion, err := NewClient(regionConfig)
must.NoError(t, err)
expectedTLSAddr := fmt.Sprintf("https://%s", addr)
expectedNoTLSAddr := fmt.Sprintf("http://%s", addr)
cases := []struct {
Node nodeLookup
QueryOptions *QueryOptions
Client *Client
ExpectedAddr string
ExpectedRegion string
ExpectedTLSServerName string
}{
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "client.global.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "client.foo.nomad",
},
{
Node: tlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "client.bar.nomad",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "global",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientNoRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "foo",
ExpectedTLSServerName: "",
},
{
Node: noTlsNode,
QueryOptions: optionNoRegion,
Client: clientRegion,
ExpectedAddr: expectedNoTLSAddr,
ExpectedRegion: "bar",
ExpectedTLSServerName: "",
},
}
for _, c := range cases {
name := fmt.Sprintf("%s__%s__%s", c.ExpectedAddr, c.ExpectedRegion, c.ExpectedTLSServerName)
t.Run(name, func(t *testing.T) {
nodeClient, getErr := c.Client.getNodeClientImpl("testID", -1, c.QueryOptions, c.Node)
must.NoError(t, getErr)
must.Eq(t, c.ExpectedRegion, nodeClient.config.Region)
must.Eq(t, c.ExpectedAddr, nodeClient.config.Address)
must.NotNil(t, nodeClient.config.TLSConfig)
must.Eq(t, c.ExpectedTLSServerName, nodeClient.config.TLSConfig.TLSServerName)
})
}
}
// TestCloneHttpClient checks cloneWithTimeout: with a negative timeout the
// same client is handed back, while a positive timeout yields a distinct
// client with a distinct transport that still carries the original proxy
// function. In both cases the source client's transport must be untouched.
func TestCloneHttpClient(t *testing.T) {
	base := defaultHttpClient()
	baseTransport := base.Transport.(*http.Transport)

	// A recognizable proxy function, so the clone can be probed for it.
	stubProxy := func(*http.Request) (*url.URL, error) {
		return nil, errors.New("stub function")
	}
	baseTransport.Proxy = stubProxy

	t.Run("closing with negative timeout", func(t *testing.T) {
		cloned, err := cloneWithTimeout(base, -1)
		must.True(t, baseTransport == base.Transport, must.Sprint("original transport changed"))
		must.NoError(t, err)
		must.True(t, base == cloned)
	})

	t.Run("closing with positive timeout", func(t *testing.T) {
		cloned, err := cloneWithTimeout(base, time.Second)
		must.True(t, baseTransport == base.Transport, must.Sprint("original transport changed"))
		must.NoError(t, err)
		must.True(t, base != cloned)
		must.True(t, base.Transport != cloned.Transport)

		// the clone's proxy function must behave like the stub installed above
		proxyFn := cloned.Transport.(*http.Transport).Proxy
		must.NotNil(t, proxyFn)
		_, err = proxyFn(nil)
		must.Error(t, err)
		must.EqError(t, err, "stub function")

		// once the transport is put back, the two structs compare equal
		cloned.Transport = baseTransport
		must.Eq(t, base, cloned)
	})
}
// TestClient_HeaderRaceCondition builds requests from the same client on two
// goroutines at once and checks that each resulting HTTP request has two
// headers while the header map on the original config still has one entry.
func TestClient_HeaderRaceCondition(t *testing.T) {
	cfg := DefaultConfig()
	cfg.Headers = map[string][]string{"test-header": []string{"a"}}

	client, err := NewClient(cfg)
	must.NoError(t, err)

	// Header count observed by the background goroutine.
	background := make(chan int)
	go func() {
		bgReq, _ := client.newRequest("GET", "/any/path/will/do")
		bgHTTP, _ := bgReq.toHTTP()
		background <- len(bgHTTP.Header)
	}()

	localReq, _ := client.newRequest("GET", "/any/path/will/do")
	localHTTP, _ := localReq.toHTTP()

	must.MapLen(t, 2, localHTTP.Header, must.Sprint("local request should have two headers"))
	must.Eq(t, 2, <-background, must.Sprint("goroutine request should have two headers"))
	must.MapLen(t, 1, cfg.Headers, must.Sprint("config headers should not mutate"))
}
// TestClient_autoUnzip calls autoUnzip on a nil *Client with a range of
// responses (nil, no body, non-gzip encoding, empty gzip body, invalid gzip
// body, valid gzip body) and checks the returned error for each.
func TestClient_autoUnzip(t *testing.T) {
	var client *Client

	check := func(resp *http.Response, want error) {
		got := client.autoUnzip(resp)
		must.Eq(t, want, got)
	}

	// gzipHeader returns a fresh header map for every response.
	gzipHeader := func() http.Header {
		return http.Header{"Content-Encoding": []string{"gzip"}}
	}

	preliminary := []struct {
		resp *http.Response
		want error
	}{
		// response object is nil
		{resp: nil, want: nil},
		// response.Body is nil
		{resp: new(http.Response), want: nil},
		// content-encoding is not gzip
		{
			resp: &http.Response{Header: http.Header{"Content-Encoding": []string{"text"}}},
			want: nil,
		},
		// content-encoding is gzip but body is empty
		{
			resp: &http.Response{
				Header: gzipHeader(),
				Body:   io.NopCloser(bytes.NewBuffer([]byte{})),
			},
			want: nil,
		},
		// content-encoding is gzip but body is invalid gzip
		{
			resp: &http.Response{
				Header: gzipHeader(),
				Body:   io.NopCloser(bytes.NewBufferString("not a zip")),
			},
			want: errors.New("unexpected EOF"),
		},
	}
	for _, tc := range preliminary {
		check(tc.resp, tc.want)
	}

	// sample gzip payload
	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	_, err := zw.Write([]byte("hello world"))
	must.NoError(t, err)
	must.NoError(t, zw.Close())

	// content-encoding is gzip and body is gzip data
	check(&http.Response{
		Header: gzipHeader(),
		Body:   io.NopCloser(&compressed),
	}, nil)
}
// TestUnixSocketConfig serves a canned leader response over a unix domain
// socket and verifies that the same Config can be used to build two clients
// in a row, each of which can query the server through the socket.
func TestUnixSocketConfig(t *testing.T) {
	td := os.TempDir() // testing.TempDir() on macOS makes paths that are too long
	socketPath := path.Join(td, t.Name()+".sock")

	os.Remove(socketPath) // get rid of stale ones now.
	t.Cleanup(func() { os.Remove(socketPath) })

	ts := httptest.NewUnstartedServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Best-effort write; a failure surfaces as a client-side error.
			_, _ = w.Write([]byte(`"10.1.1.1"`))
		}))

	// Create a Unix domain socket and listen for incoming connections.
	socket, err := net.Listen("unix", socketPath)
	must.NoError(t, err)
	t.Cleanup(func() { socket.Close() })

	// The unstarted server already holds a default listener; close it before
	// swapping in the unix socket so it is not leaked.
	_ = ts.Listener.Close()
	ts.Listener = socket
	ts.Start()
	defer ts.Close()

	cfg := &Config{Address: "unix://" + socketPath}

	client1, err := NewClient(cfg)
	must.NoError(t, err)
	t.Cleanup(client1.Close)

	resp, err := client1.Status().Leader()
	must.NoError(t, err)
	must.Eq(t, "10.1.1.1", resp)

	// Reusing cfg must still reach the socket on the second client.
	client2, err := NewClient(cfg)
	must.NoError(t, err)
	t.Cleanup(client2.Close)

	// Capture the second client's own response; previously it was discarded
	// and the assertion re-checked the first client's result.
	resp, err = client2.Status().Leader()
	must.NoError(t, err)
	must.Eq(t, "10.1.1.1", resp)
}