Add unix domain socket support to API (#16872)

- Expose internal HTTP client's Do() via Raw
- Use URL parser to identify scheme
- Align more with curl output
- Add changelog
- Fix test failure; add tests for socket envvars
- Apply review feedback for tests
- Consolidate address parsing
- Address feedback from code reviews

Co-authored-by: Tim Gross <tgross@hashicorp.com>
This commit is contained in:
Charlie Voiselle
2023-10-11 11:04:12 -04:00
committed by GitHub
parent a92461cdc9
commit 7266d267b0
5 changed files with 189 additions and 35 deletions

3
.changelog/16872.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
api: Added support for Unix domain sockets
```

View File

@@ -205,6 +205,16 @@ type Config struct {
// retryOptions holds the configuration necessary to perform retries
// on put calls.
retryOptions *retryOptions
// url is populated with the initial parsed address and is not modified in the
// case of a unix:// URL, as opposed to Address.
url *url.URL
}
// URL returns a copy of the initial parsed address and is not modified in the
// case of a `unix://` URL, as opposed to Address.
func (c *Config) URL() *url.URL {
return c.url
}
// ClientConfig copies the configuration with a new client address, region, and
@@ -214,6 +224,7 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
if tlsEnabled {
scheme = "https"
}
config := &Config{
Address: fmt.Sprintf("%s://%s", scheme, address),
Region: region,
@@ -223,6 +234,7 @@ func (c *Config) ClientConfig(region, address string, tlsEnabled bool) *Config {
HttpAuth: c.HttpAuth,
WaitTime: c.WaitTime,
TLSConfig: c.TLSConfig.Copy(),
url: copyURL(c.url),
}
// Update the tls server name for connecting to a client
@@ -278,9 +290,30 @@ func (t *TLSConfig) Copy() *TLSConfig {
return nt
}
// defaultUDSClient creates a unix domain socket client. Errors return a nil
// http.Client, which is tested for in ConfigureTLS. This function expects that
// the Address has already been parsed into the config.url value.
func defaultUDSClient(config *Config) *http.Client {
config.Address = "http://127.0.0.1"
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", config.url.EscapedPath())
},
},
}
return defaultClient(httpClient)
}
func defaultHttpClient() *http.Client {
httpClient := cleanhttp.DefaultPooledClient()
transport := httpClient.Transport.(*http.Transport)
return defaultClient(httpClient)
}
func defaultClient(c *http.Client) *http.Client {
transport := c.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
@@ -290,7 +323,7 @@ func defaultHttpClient() *http.Client {
// well yet: https://github.com/gorilla/websocket/issues/417
transport.ForceAttemptHTTP2 = false
return httpClient
return c
}
// DefaultConfig returns a default configuration for the client
@@ -467,18 +500,29 @@ type Client struct {
// NewClient returns a new client
func NewClient(config *Config) (*Client, error) {
var err error
// bootstrap the config
defConfig := DefaultConfig()
if config.Address == "" {
config.Address = defConfig.Address
} else if _, err := url.Parse(config.Address); err != nil {
}
// we have to test the address that comes from DefaultConfig, because it
// could be the value of NOMAD_ADDR which is applied without testing
if config.url, err = url.Parse(config.Address); err != nil {
return nil, fmt.Errorf("invalid address '%s': %v", config.Address, err)
}
httpClient := config.HttpClient
if httpClient == nil {
httpClient = defaultHttpClient()
switch {
case config.url.Scheme == "unix":
httpClient = defaultUDSClient(config) // mutates config
default:
httpClient = defaultHttpClient()
}
if err := ConfigureTLS(httpClient, config.TLSConfig); err != nil {
return nil, err
}
@@ -760,24 +804,32 @@ func (r *request) toHTTP() (*http.Request, error) {
// newRequest is used to create a new request
func (c *Client) newRequest(method, path string) (*request, error) {
base, _ := url.Parse(c.config.Address)
u, err := url.Parse(path)
if err != nil {
return nil, err
}
r := &request{
config: &c.config,
method: method,
url: &url.URL{
Scheme: base.Scheme,
User: base.User,
Host: base.Host,
Scheme: c.config.url.Scheme,
User: c.config.url.User,
Host: c.config.url.Host,
Path: u.Path,
RawPath: u.RawPath,
},
header: make(http.Header),
params: make(map[string][]string),
}
// fixup socket paths
if r.url.Scheme == "unix" {
r.url.Scheme = "http"
r.url.Host = "127.0.0.1"
}
if c.config.Region != "" {
r.params.Set("region", c.config.Region)
}
@@ -1210,3 +1262,16 @@ func (o *WriteOptions) WithContext(ctx context.Context) *WriteOptions {
o2.ctx = ctx
return o2
}
// copyURL makes a deep copy of a net/url.URL
func copyURL(u1 *url.URL) *url.URL {
if u1 == nil {
return nil
}
o := *u1
if o.User != nil {
ou := *u1.User
o.User = &ou
}
return &o
}

View File

@@ -3,7 +3,10 @@
package api
import "io"
import (
"io"
"net/http"
)
// Raw can be used to do raw queries against custom endpoints
type Raw struct {
@@ -39,3 +42,8 @@ func (raw *Raw) Write(endpoint string, in, out interface{}, q *WriteOptions) (*W
func (raw *Raw) Delete(endpoint string, out interface{}, q *WriteOptions) (*WriteMeta, error) {
return raw.c.delete(endpoint, nil, out, q)
}
// Do uses the raw client's internal httpClient to process the request
func (raw *Raw) Do(req *http.Request) (*http.Response, error) {
return raw.c.httpClient.Do(req)
}

View File

@@ -5,7 +5,6 @@ package command
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
@@ -13,9 +12,7 @@ import (
"net/url"
"os"
"strings"
"time"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/nomad/api"
"github.com/posener/complete"
)
@@ -138,11 +135,18 @@ func (c *OperatorAPICommand) Run(args []string) int {
// By default verbose func is a noop
verbose := func(string, ...interface{}) {}
verboseSocket := func(*api.Config, string, ...interface{}) {}
if c.verboseFlag {
verbose = func(format string, a ...interface{}) {
// Use Warn instead of Info because Info goes to stdout
c.Ui.Warn(fmt.Sprintf(format, a...))
}
verboseSocket = func(cfg *api.Config, format string, a ...interface{}) {
if cfg.URL() != nil && cfg.URL().Scheme == "unix" {
c.Ui.Warn(fmt.Sprintf(format, a...))
}
}
}
// Opportunistically read from stdin and POST unless method has been
@@ -166,11 +170,13 @@ func (c *OperatorAPICommand) Run(args []string) int {
c.method = "GET"
}
config := c.clientConfig()
// NewClient mutates or validates Config.Address, so call it to match
// the behavior of other commands.
_, err := api.NewClient(config)
// the behavior of other commands. Typically these are called as a combination
// using c.Client(); however, we need access to the client configuration
// to build the corresponding curl output.
config := c.clientConfig()
apiC, err := api.NewClient(config)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error initializing client: %v", err))
return 1
@@ -198,23 +204,10 @@ func (c *OperatorAPICommand) Run(args []string) int {
c.Ui.Output(out)
return 0
}
// Re-implement a big chunk of api/api.go since we don't export it.
client := cleanhttp.DefaultClient()
transport := client.Transport.(*http.Transport)
transport.TLSHandshakeTimeout = 10 * time.Second
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
if err := api.ConfigureTLS(client, config.TLSConfig); err != nil {
c.Ui.Error(fmt.Sprintf("Error configuring TLS: %v", err))
return 1
}
apiR := apiC.Raw()
setQueryParams(config, path)
verbose("> %s %s", c.method, path)
verboseSocket(config, fmt.Sprintf("* Trying %s...", config.URL().EscapedPath()))
req, err := http.NewRequest(c.method, path.String(), c.body)
if err != nil {
@@ -222,6 +215,10 @@ func (c *OperatorAPICommand) Run(args []string) int {
return 1
}
h := req.URL.Hostname()
verboseSocket(config, fmt.Sprintf("* Connected to %s (%s)", h, config.URL().EscapedPath()))
verbose("> %s %s %s", c.method, req.URL.Path, req.Proto)
// Set headers from command line
req.Header = headerFlags.headers
@@ -244,11 +241,11 @@ func (c *OperatorAPICommand) Run(args []string) int {
verbose("> %s: %s", k, v)
}
}
verbose(">")
verbose("* Sending request and receiving response...")
// Do the request!
resp, err := client.Do(req)
resp, err := apiR.Do(req)
if err != nil {
c.Ui.Error(fmt.Sprintf("Error performing request: %v", err))
return 1
@@ -310,7 +307,8 @@ func (c *OperatorAPICommand) apiToCurl(config *api.Config, headers http.Header,
parts = append(parts, "--verbose")
}
if c.method != "" {
// add method flags. Note: curl output complains about `-X GET`
if c.method != "" && c.method != http.MethodGet {
parts = append(parts, "-X "+c.method)
}
@@ -318,6 +316,10 @@ func (c *OperatorAPICommand) apiToCurl(config *api.Config, headers http.Header,
parts = append(parts, "--data-binary @-")
}
if config.URL().EscapedPath() != "" {
parts = append(parts, fmt.Sprintf("--unix-socket %q", config.URL().EscapedPath()))
}
if config.TLSConfig != nil {
parts = tlsToCurl(parts, config.TLSConfig)
@@ -412,7 +414,9 @@ func pathToURL(config *api.Config, path string) (*url.URL, error) {
// If the scheme is missing from the path, it likely means the path is just
// the HTTP handler path. Attempt to infer this.
if !strings.HasPrefix(path, "http://") && !strings.HasPrefix(path, "https://") {
if !strings.HasPrefix(path, "http://") &&
!strings.HasPrefix(path, "https://") &&
!strings.HasPrefix(path, "unix://") {
scheme := "http"
// If the user has set any TLS configuration value, this is a good sign

View File

@@ -5,10 +5,13 @@ package command
import (
"bytes"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"testing"
"time"
@@ -220,3 +223,74 @@ func TestOperatorAPICommand_ContentLength(t *testing.T) {
t.Fatalf("timed out waiting for request")
}
}
func makeSocketListener(t *testing.T) (net.Listener, string) {
td := os.TempDir() // testing.TempDir() on macOS makes paths that are too long
sPath := path.Join(td, t.Name()+".sock")
os.Remove(sPath) // git rid of stale ones now.
t.Cleanup(func() { os.Remove(sPath) })
// Create a Unix domain socket and listen for incoming connections.
socket, err := net.Listen("unix", sPath)
must.NoError(t, err)
return socket, sPath
}
// TestOperatorAPICommand_Socket tests that requests can be routed over a unix
// domain socket
//
// Can not be run in parallel as it modifies the environment.
func TestOperatorAPICommand_Socket(t *testing.T) {
ping := make(chan struct{}, 1)
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ping <- struct{}{}
}))
sock, sockPath := makeSocketListener(t)
ts.Listener = sock
ts.Start()
defer ts.Close()
// Setup command.
ui := cli.NewMockUi()
cmd := &OperatorAPICommand{Meta: Meta{Ui: ui}}
tcs := []struct {
name string
env map[string]string
args []string
exitCode int
}{
{
name: "nomad_addr",
env: map[string]string{"NOMAD_ADDR": "unix://" + sockPath},
args: []string{"/v1/jobs"},
exitCode: 0,
},
{
name: "nomad_addr opaques host",
env: map[string]string{"NOMAD_ADDR": "unix://" + sockPath},
args: []string{"http://example.com/v1/jobs"},
exitCode: 0,
},
}
for i, tc := range tcs {
t.Run(fmt.Sprintf("%v_%s", i+1, t.Name()), func(t *testing.T) {
tc := tc
for k, v := range tc.env {
t.Setenv(k, v)
}
exitCode := cmd.Run(tc.args)
must.Eq(t, tc.exitCode, exitCode, must.Sprint(ui.ErrorWriter.String()))
select {
case l := <-ping:
must.Eq(t, struct{}{}, l)
case <-time.After(5 * time.Second):
t.Fatalf("timed out waiting for request")
}
})
}
}