From 9bab96ebd327dfe79d86359d45b00f56db4fa186 Mon Sep 17 00:00:00 2001 From: Michael Schurter Date: Mon, 6 Feb 2023 11:31:22 -0800 Subject: [PATCH] Task API via Unix Domain Socket (#15864) This change introduces the Task API: a portable way for tasks to access Nomad's HTTP API. This particular implementation uses a Unix Domain Socket and, unlike the agent's HTTP API, always requires authentication even if ACLs are disabled. This PR contains the core feature and tests but followup work is required for the following TODO items: - Docs - might do in a followup since dynamic node metadata / task api / workload id all need to interlink - Unit tests for auth middleware - Caching for auth middleware - Rate limiting on negative lookups for auth middleware --------- Co-authored-by: Seth Hoenig --- .changelog/15864.txt | 3 + .../allocrunner/interfaces/task_lifecycle.go | 5 +- client/allocrunner/taskrunner/api_hook.go | 119 ++++++++++++ .../allocrunner/taskrunner/api_hook_test.go | 169 ++++++++++++++++++ .../taskrunner/task_runner_hooks.go | 5 +- client/config/config.go | 19 ++ client/config/testing.go | 11 ++ command/agent/agent.go | 18 +- command/agent/agent_endpoint.go | 8 +- command/agent/alloc_endpoint.go | 2 +- command/agent/consul/int_test.go | 3 +- command/agent/http.go | 132 +++++++++++++- command/agent/job_endpoint.go | 2 +- command/agent/metrics_endpoint.go | 4 +- command/agent/variable_endpoint.go | 3 +- e2e/workload_id/input/api-auth.nomad.hcl | 99 ++++++++++ e2e/workload_id/input/api-win.nomad.hcl | 36 ++++ e2e/workload_id/taskapi_test.go | 111 ++++++++++++ helper/users/lookup.go | 83 ++++++++- helper/users/lookup_linux_test.go | 39 +++- helper/users/lookup_windows_test.go | 18 +- nomad/acl_endpoint.go | 13 +- nomad/structs/structs.go | 12 +- 23 files changed, 876 insertions(+), 38 deletions(-) create mode 100644 .changelog/15864.txt create mode 100644 client/allocrunner/taskrunner/api_hook.go create mode 100644 client/allocrunner/taskrunner/api_hook_test.go 
create mode 100644 e2e/workload_id/input/api-auth.nomad.hcl create mode 100644 e2e/workload_id/input/api-win.nomad.hcl create mode 100644 e2e/workload_id/taskapi_test.go diff --git a/.changelog/15864.txt b/.changelog/15864.txt new file mode 100644 index 000000000..b91ffba97 --- /dev/null +++ b/.changelog/15864.txt @@ -0,0 +1,3 @@ +```release-note:improvement +client: added http api access for tasks via unix socket +``` diff --git a/client/allocrunner/interfaces/task_lifecycle.go b/client/allocrunner/interfaces/task_lifecycle.go index 1bf61bd5a..3ea51c4a9 100644 --- a/client/allocrunner/interfaces/task_lifecycle.go +++ b/client/allocrunner/interfaces/task_lifecycle.go @@ -33,8 +33,6 @@ import ( +-----------+ *Kill (forces terminal) - -Link: http://stable.ascii-flow.appspot.com/#Draw4489375405966393064/1824429135 */ // TaskHook is a lifecycle hook into the life cycle of a task runner. @@ -186,6 +184,9 @@ type TaskStopRequest struct { // ExistingState is previously set hook data and should only be // read. Stop hooks cannot alter state. ExistingState map[string]string + + // TaskDir contains the task's directory tree on the host + TaskDir *allocdir.TaskDir } type TaskStopResponse struct{} diff --git a/client/allocrunner/taskrunner/api_hook.go b/client/allocrunner/taskrunner/api_hook.go new file mode 100644 index 000000000..003d5fecd --- /dev/null +++ b/client/allocrunner/taskrunner/api_hook.go @@ -0,0 +1,119 @@ +package taskrunner + +import ( + "context" + "errors" + "net" + "net/http" + "os" + "path/filepath" + "sync" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/config" + "github.com/hashicorp/nomad/helper/users" +) + +// apiHook exposes the Task API. The Task API allows task's to access the Nomad +// HTTP API without having to discover and connect to an agent's address. +// Instead a unix socket is provided in a standard location. 
To prevent access +// by untrusted workloads the Task API always requires authentication even when +// ACLs are disabled. +// +// The Task API hook largely soft-fails as there are a number of ways creating +// the unix socket could fail (the most common one being path length +// restrictions), and it is assumed most tasks won't require access to the Task +// API anyway. Tasks that do require access are expected to crash and get +// rescheduled should they land on a client whose Task API hook soft-fails. +type apiHook struct { + shutdownCtx context.Context + srv config.APIListenerRegistrar + logger hclog.Logger + + // Lock listener as it is updated from multiple hooks. + lock sync.Mutex + + // Listener is the unix domain socket of the task api for this task. + ln net.Listener +} + +func newAPIHook(shutdownCtx context.Context, srv config.APIListenerRegistrar, logger hclog.Logger) *apiHook { + h := &apiHook{ + shutdownCtx: shutdownCtx, + srv: srv, + } + h.logger = logger.Named(h.Name()) + return h +} + +func (*apiHook) Name() string { + return "api" +} + +func (h *apiHook) Prestart(_ context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { + h.lock.Lock() + defer h.lock.Unlock() + + if h.ln != nil { + // Listener already set. Task is probably restarting. + return nil + } + + udsPath := apiSocketPath(req.TaskDir) + udsln, err := users.SocketFileFor(h.logger, udsPath, req.Task.User) + if err != nil { + // Soft-fail and let the task fail if it requires the task api. + h.logger.Warn("error creating task api socket", "path", udsPath, "error", err) + return nil + } + + go func() { + // Cannot use Prestart's context as it is closed after all prestart hooks + // have been closed, but we do want to try to cleanup on shutdown. 
+ if err := h.srv.Serve(h.shutdownCtx, udsln); err != nil { + if errors.Is(err, http.ErrServerClosed) { + return + } + if errors.Is(err, net.ErrClosed) { + return + } + h.logger.Error("error serving task api", "error", err) + } + }() + + h.ln = udsln + return nil +} + +func (h *apiHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error { + h.lock.Lock() + defer h.lock.Unlock() + + if h.ln != nil { + if err := h.ln.Close(); err != nil { + if !errors.Is(err, net.ErrClosed) { + h.logger.Debug("error closing task listener: %v", err) + } + } + h.ln = nil + } + + // Best-effort at cleaning things up. Alloc dir cleanup will remove it if + // this fails for any reason. + _ = os.RemoveAll(apiSocketPath(req.TaskDir)) + + return nil +} + +// apiSocketPath returns the path to the Task API socket. +// +// The path needs to be as short as possible because of the low limits on the +// sun_path char array imposed by the syscall used to create unix sockets. +// +// See https://github.com/hashicorp/nomad/pull/13971 for an example of the +// sadness this causes. 
+func apiSocketPath(taskDir *allocdir.TaskDir) string { + return filepath.Join(taskDir.SecretsDir, "api.sock") +} diff --git a/client/allocrunner/taskrunner/api_hook_test.go b/client/allocrunner/taskrunner/api_hook_test.go new file mode 100644 index 000000000..164d4433d --- /dev/null +++ b/client/allocrunner/taskrunner/api_hook_test.go @@ -0,0 +1,169 @@ +package taskrunner + +import ( + "context" + "io/fs" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "syscall" + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/helper/users" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" +) + +type testAPIListenerRegistrar struct { + cb func(net.Listener) error +} + +func (n testAPIListenerRegistrar) Serve(_ context.Context, ln net.Listener) error { + if n.cb != nil { + return n.cb(ln) + } + return nil +} + +// TestAPIHook_SoftFail asserts that the Task API Hook soft fails and does not +// return errors. 
+func TestAPIHook_SoftFail(t *testing.T) { + ci.Parallel(t) + + // Use a SecretsDir that will always exceed Unix socket path length + // limits (sun_path) + dst := filepath.Join(t.TempDir(), strings.Repeat("_NOMAD_TEST_", 100)) + + ctx := context.Background() + srv := testAPIListenerRegistrar{} + logger := testlog.HCLogger(t) + h := newAPIHook(ctx, srv, logger) + + req := &interfaces.TaskPrestartRequest{ + Task: &structs.Task{}, // needs to be non-nil for Task.User lookup + TaskDir: &allocdir.TaskDir{ + SecretsDir: dst, + }, + } + resp := &interfaces.TaskPrestartResponse{} + + err := h.Prestart(ctx, req, resp) + must.NoError(t, err) + + // listener should not have been set + must.Nil(t, h.ln) + + // File should not have been created + _, err = os.Stat(dst) + must.Error(t, err) + + // Assert stop also soft-fails + stopReq := &interfaces.TaskStopRequest{ + TaskDir: req.TaskDir, + } + stopResp := &interfaces.TaskStopResponse{} + err = h.Stop(ctx, stopReq, stopResp) + must.NoError(t, err) + + // File should not have been created + _, err = os.Stat(dst) + must.Error(t, err) +} + +// TestAPIHook_Ok asserts that the Task API Hook creates and cleans up a +// socket. +func TestAPIHook_Ok(t *testing.T) { + ci.Parallel(t) + + // If this test fails it may be because TempDir() + /api.sock is longer than + // the unix socket path length limit (sun_path) in which case the test should + // use a different temporary directory on that platform. 
+ dst := t.TempDir() + + // Write "ok" and close the connection and listener + srv := testAPIListenerRegistrar{ + cb: func(ln net.Listener) error { + conn, err := ln.Accept() + if err != nil { + return err + } + if _, err = conn.Write([]byte("ok")); err != nil { + return err + } + conn.Close() + return nil + }, + } + + ctx := context.Background() + logger := testlog.HCLogger(t) + h := newAPIHook(ctx, srv, logger) + + req := &interfaces.TaskPrestartRequest{ + Task: &structs.Task{ + User: "nobody", + }, + TaskDir: &allocdir.TaskDir{ + SecretsDir: dst, + }, + } + resp := &interfaces.TaskPrestartResponse{} + + err := h.Prestart(ctx, req, resp) + must.NoError(t, err) + + // File should have been created + sockDst := apiSocketPath(req.TaskDir) + + // Stat and chown fail on Windows, so skip these checks + if runtime.GOOS != "windows" { + stat, err := os.Stat(sockDst) + must.NoError(t, err) + must.True(t, stat.Mode()&fs.ModeSocket != 0, + must.Sprintf("expected %q to be a unix socket but got %s", sockDst, stat.Mode())) + + nobody, _ := users.Lookup("nobody") + if syscall.Getuid() == 0 && nobody != nil { + t.Logf("root and nobody exists: testing file perms") + + // We're root and nobody exists! 
Check perms + must.Eq(t, fs.FileMode(0o600), stat.Mode().Perm()) + + sysStat, ok := stat.Sys().(*syscall.Stat_t) + must.True(t, ok, must.Sprintf("expected stat.Sys() to be a *syscall.Stat_t on %s but found %T", + runtime.GOOS, stat.Sys())) + + nobodyUID, err := strconv.Atoi(nobody.Uid) + must.NoError(t, err) + must.Eq(t, nobodyUID, int(sysStat.Uid)) + } + } + + // Assert the listener is working + conn, err := net.Dial("unix", sockDst) + must.NoError(t, err) + buf := make([]byte, 2) + _, err = conn.Read(buf) + must.NoError(t, err) + must.Eq(t, []byte("ok"), buf) + conn.Close() + + // Assert stop cleans up + stopReq := &interfaces.TaskStopRequest{ + TaskDir: req.TaskDir, + } + stopResp := &interfaces.TaskStopResponse{} + err = h.Stop(ctx, stopReq, stopResp) + must.NoError(t, err) + + // File should be gone + _, err = net.Dial("unix", sockDst) + must.Error(t, err) +} diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 7c1b73dbd..7ea501982 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -68,6 +68,7 @@ func (tr *TaskRunner) initHooks() { newArtifactHook(tr, tr.getter, hookLogger), newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), newDeviceHook(tr.devicemanager, hookLogger), + newAPIHook(tr.shutdownCtx, tr.clientConfig.APIListenerRegistrar, hookLogger), } // If the task has a CSI block, add the hook. 
@@ -431,7 +432,9 @@ func (tr *TaskRunner) stop() error { tr.logger.Trace("running stop hook", "name", name, "start", start) } - req := interfaces.TaskStopRequest{} + req := interfaces.TaskStopRequest{ + TaskDir: tr.taskDir, + } origHookState := tr.hookState(name) if origHookState != nil { diff --git a/client/config/config.go b/client/config/config.go index 2945f6daa..466230ddc 100644 --- a/client/config/config.go +++ b/client/config/config.go @@ -1,8 +1,10 @@ package config import ( + "context" "errors" "fmt" + "net" "reflect" "strconv" "strings" @@ -301,10 +303,27 @@ type Config struct { // used for template functions which require access to the Nomad API. TemplateDialer *bufconndialer.BufConnWrapper + // APIListenerRegistrar allows the client to register listeners created at + // runtime (eg the Task API) with the agent's HTTP server. Since the agent + // creates the HTTP *after* the client starts, we have to use this shim to + // pass listeners back to the agent. + // This is the same design as the bufconndialer but for the + // http.Serve(listener) API instead of the net.Dial API. + APIListenerRegistrar APIListenerRegistrar + // Artifact configuration from the agent's config file. Artifact *ArtifactConfig } +type APIListenerRegistrar interface { + // Serve the HTTP API on the provided listener. + // + // The context is because Serve may be called before the HTTP server has been + // initialized. If the context is canceled before the HTTP server is + // initialized, the context's error will be returned. 
+ Serve(context.Context, net.Listener) error +} + // ClientTemplateConfig is configuration on the client specific to template // rendering type ClientTemplateConfig struct { diff --git a/client/config/testing.go b/client/config/testing.go index 02f87984f..adb703de2 100644 --- a/client/config/testing.go +++ b/client/config/testing.go @@ -1,7 +1,9 @@ package config import ( + "context" "io/ioutil" + "net" "os" "path/filepath" "time" @@ -74,5 +76,14 @@ func TestClientConfig(t testing.T) (*Config, func()) { // Same as default; necessary for task Event messages conf.MaxKillTimeout = 30 * time.Second + // Provide a stub APIListenerRegistrar implementation + conf.APIListenerRegistrar = NoopAPIListenerRegistrar{} + return conf, cleanup } + +type NoopAPIListenerRegistrar struct{} + +func (NoopAPIListenerRegistrar) Serve(_ context.Context, _ net.Listener) error { + return nil +} diff --git a/command/agent/agent.go b/command/agent/agent.go index ae3d3e661..d5703f721 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -115,7 +115,11 @@ type Agent struct { builtinListener net.Listener builtinDialer *bufconndialer.BufConnWrapper - InmemSink *metrics.InmemSink + // builtinServer is an HTTP server for attaching per-task listeners. Always + // requires auth. + builtinServer *builtinAPI + + inmemSink *metrics.InmemSink } // NewAgent is used to create a new agent with the given configuration @@ -124,7 +128,7 @@ func NewAgent(config *Config, logger log.InterceptLogger, logOutput io.Writer, i config: config, logOutput: logOutput, shutdownCh: make(chan struct{}), - InmemSink: inmem, + inmemSink: inmem, } // Create the loggers @@ -1020,6 +1024,11 @@ func (a *Agent) setupClient() error { a.builtinListener, a.builtinDialer = bufconndialer.New() conf.TemplateDialer = a.builtinDialer + // Initialize builtin API server here for use in the client, but it won't + // accept connections until the HTTP servers are created. 
+ a.builtinServer = newBuiltinAPI() + conf.APIListenerRegistrar = a.builtinServer + nomadClient, err := client.NewClient( conf, a.consulCatalog, a.consulProxies, a.consulService, nil) if err != nil { @@ -1300,6 +1309,11 @@ func (a *Agent) GetConfig() *Config { return a.config } +// GetMetricsSink returns the metrics sink. +func (a *Agent) GetMetricsSink() *metrics.InmemSink { + return a.inmemSink +} + // setupConsul creates the Consul client and starts its main Run loop. func (a *Agent) setupConsul(consulConfig *config.ConsulConfig) error { apiConf, err := consulConfig.ApiConfig() diff --git a/command/agent/agent_endpoint.go b/command/agent/agent_endpoint.go index 898001ca8..21df7c3b3 100644 --- a/command/agent/agent_endpoint.go +++ b/command/agent/agent_endpoint.go @@ -440,7 +440,7 @@ func (s *HTTPServer) listServers(resp http.ResponseWriter, req *http.Request) (i return nil, structs.ErrPermissionDenied } - peers := s.agent.client.GetServers() + peers := client.GetServers() sort.Strings(peers) return peers, nil } @@ -468,9 +468,9 @@ func (s *HTTPServer) updateServers(resp http.ResponseWriter, req *http.Request) } // Set the servers list into the client - s.agent.logger.Trace("adding servers to the client's primary server list", "servers", servers, "path", "/v1/agent/servers", "method", "PUT") + s.logger.Trace("adding servers to the client's primary server list", "servers", servers, "path", "/v1/agent/servers", "method", "PUT") if _, err := client.SetServers(servers); err != nil { - s.agent.logger.Error("failed adding servers to client's server list", "servers", servers, "error", err, "path", "/v1/agent/servers", "method", "PUT") + s.logger.Error("failed adding servers to client's server list", "servers", servers, "error", err, "path", "/v1/agent/servers", "method", "PUT") //TODO is this the right error to return? 
return nil, CodedError(400, err.Error()) } @@ -708,7 +708,7 @@ func (s *HTTPServer) AgentHostRequest(resp http.ResponseWriter, req *http.Reques // The RPC endpoint actually forwards the request to the correct // agent, but we need to use the correct RPC interface. localClient, remoteClient, localServer := s.rpcHandlerForNode(lookupNodeID) - s.agent.logger.Debug("s.rpcHandlerForNode()", "lookupNodeID", lookupNodeID, "serverID", serverID, "nodeID", nodeID, "localClient", localClient, "remoteClient", remoteClient, "localServer", localServer) + s.logger.Debug("s.rpcHandlerForNode()", "lookupNodeID", lookupNodeID, "serverID", serverID, "nodeID", nodeID, "localClient", localClient, "remoteClient", remoteClient, "localServer", localServer) // Make the RPC call if localClient { diff --git a/command/agent/alloc_endpoint.go b/command/agent/alloc_endpoint.go index 18b4dfdc3..ee071e7db 100644 --- a/command/agent/alloc_endpoint.go +++ b/command/agent/alloc_endpoint.go @@ -222,7 +222,7 @@ func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Requ case "exec": return s.allocExec(allocID, resp, req) case "snapshot": - if s.agent.client == nil { + if s.agent.Client() == nil { return nil, clientNotRunning } return s.allocSnapshot(allocID, resp, req) diff --git a/command/agent/consul/int_test.go b/command/agent/consul/int_test.go index 1fd6a6fcd..02124f108 100644 --- a/command/agent/consul/int_test.go +++ b/command/agent/consul/int_test.go @@ -46,7 +46,7 @@ func TestConsul_Integration(t *testing.T) { // Create an embedded Consul server testconsul, err := testutil.NewTestServerConfigT(t, func(c *testutil.TestServerConfig) { - c.Peering = nil // fix for older versions of Consul (<1.13.0) that don't support peering + c.Peering = nil // fix for older versions of Consul (<1.13.0) that don't support peering // If -v wasn't specified squelch consul logging if !testing.Verbose() { c.Stdout = ioutil.Discard @@ -61,6 +61,7 @@ func TestConsul_Integration(t *testing.T) { 
conf := config.DefaultConfig() conf.Node = mock.Node() conf.ConsulConfig.Addr = testconsul.HTTPAddr + conf.APIListenerRegistrar = config.NoopAPIListenerRegistrar{} consulConfig, err := conf.ConsulConfig.ApiConfig() if err != nil { t.Fatalf("error generating consul config: %v", err) diff --git a/command/agent/http.go b/command/agent/http.go index da3580e77..b51d6696a 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -2,6 +2,7 @@ package agent import ( "bytes" + "context" "crypto/tls" "encoding/json" "errors" @@ -26,8 +27,10 @@ import ( "golang.org/x/time/rate" "github.com/hashicorp/nomad/acl" + "github.com/hashicorp/nomad/client" "github.com/hashicorp/nomad/helper/noxssrw" "github.com/hashicorp/nomad/helper/tlsutil" + "github.com/hashicorp/nomad/nomad" "github.com/hashicorp/nomad/nomad/structs" ) @@ -74,9 +77,18 @@ var ( type handlerFn func(resp http.ResponseWriter, req *http.Request) (interface{}, error) type handlerByteFn func(resp http.ResponseWriter, req *http.Request) ([]byte, error) +type RPCer interface { + RPC(string, any, any) error + Server() *nomad.Server + Client() *client.Client + Stats() map[string]map[string]string + GetConfig() *Config + GetMetricsSink() *metrics.InmemSink +} + // HTTPServer is used to wrap an Agent and expose it over an HTTP interface type HTTPServer struct { - agent *Agent + agent RPCer mux *http.ServeMux listener net.Listener listenerCh chan struct{} @@ -170,7 +182,7 @@ func NewHTTPServers(agent *Agent, config *Config) ([]*HTTPServer, error) { srvs = append(srvs, srv) } - // This HTTP server is only create when running in client mode, otherwise + // This HTTP server is only created when running in client mode, otherwise // the builtinDialer and builtinListener will be nil. 
if agent.builtinDialer != nil && agent.builtinListener != nil { srv := &HTTPServer{ @@ -185,12 +197,15 @@ func NewHTTPServers(agent *Agent, config *Config) ([]*HTTPServer, error) { srv.registerHandlers(config.EnableDebug) + // builtinServer adds a wrapper to always authenticate requests httpServer := http.Server{ Addr: srv.Addr, - Handler: srv.mux, + Handler: newAuthMiddleware(srv, srv.mux), ErrorLog: newHTTPServerLogger(srv.logger), } + agent.builtinServer.SetServer(&httpServer) + go func() { defer close(srv.listenerCh) httpServer.Serve(agent.builtinListener) @@ -465,7 +480,8 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.Handle("/v1/vars", wrapCORS(s.wrap(s.VariablesListRequest))) s.mux.Handle("/v1/var/", wrapCORSWithAllowedMethods(s.wrap(s.VariableSpecificRequest), "HEAD", "GET", "PUT", "DELETE")) - uiConfigEnabled := s.agent.config.UI != nil && s.agent.config.UI.Enabled + agentConfig := s.agent.GetConfig() + uiConfigEnabled := agentConfig.UI != nil && agentConfig.UI.Enabled if uiEnabled && uiConfigEnabled { s.mux.Handle("/ui/", http.StripPrefix("/ui/", s.handleUI(http.FileServer(&UIAssetWrapper{FileSystem: assetFS()})))) @@ -484,7 +500,7 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.Handle("/", s.handleRootFallthrough()) if enableDebug { - if !s.agent.config.DevMode { + if !agentConfig.DevMode { s.logger.Warn("enable_debug is set to true. This is insecure and should not be enabled in production") } s.mux.HandleFunc("/debug/pprof/", pprof.Index) @@ -498,6 +514,54 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.registerEnterpriseHandlers() } +// builtinAPI is a wrapper around serving the HTTP API to arbitrary listeners +// such as the Task API. It is necessary because the HTTP servers are created +// *after* the client has been initialized, so this wrapper blocks Serve +// requests from task api hooks until the HTTP server is setup and ready to +// accept from new listeners. 
+// +// bufconndialer provides similar functionality to consul-template except it +// satisfies the Dialer API as opposed to the Serve(Listener) API. +type builtinAPI struct { + srv *http.Server + srvReadyCh chan struct{} +} + +func newBuiltinAPI() *builtinAPI { + return &builtinAPI{ + srvReadyCh: make(chan struct{}), + } +} + +// SetServer sets the API HTTP server for Serve to add listeners to. +// +// It must be called exactly once and will panic if called more than once. +func (b *builtinAPI) SetServer(srv *http.Server) { + select { + case <-b.srvReadyCh: + panic(fmt.Sprintf("SetServer called twice. first=%p second=%p", b.srv, srv)) + default: + } + b.srv = srv + close(b.srvReadyCh) +} + +// Serve the HTTP API on the listener unless the context is canceled before the +// HTTP API is ready to serve listeners. A non-nil error will always be +// returned, but http.ErrServerClosed and net.ErrClosed can likely be ignored +// as they indicate the server or listener is being shutdown. +func (b *builtinAPI) Serve(ctx context.Context, l net.Listener) error { + select { + case <-ctx.Done(): + // Caller canceled context before server was ready. + return ctx.Err() + case <-b.srvReadyCh: + // Server ready for listeners! Continue on... 
+ } + + return b.srv.Serve(l) +} + // HTTPCodedError is used to provide the HTTP error code type HTTPCodedError interface { error @@ -591,7 +655,7 @@ func errCodeFromHandler(err error) (int, string) { // wrap is used to wrap functions to make them more convenient func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Request) (interface{}, error)) func(resp http.ResponseWriter, req *http.Request) { f := func(resp http.ResponseWriter, req *http.Request) { - setHeaders(resp, s.agent.config.HTTPAPIResponseHeaders) + setHeaders(resp, s.agent.GetConfig().HTTPAPIResponseHeaders) // Invoke the handler reqURL := req.URL.String() start := time.Now() @@ -673,7 +737,7 @@ func (s *HTTPServer) wrap(handler func(resp http.ResponseWriter, req *http.Reque // Handler functions are responsible for setting Content-Type Header func (s *HTTPServer) wrapNonJSON(handler func(resp http.ResponseWriter, req *http.Request) ([]byte, error)) func(resp http.ResponseWriter, req *http.Request) { f := func(resp http.ResponseWriter, req *http.Request) { - setHeaders(resp, s.agent.config.HTTPAPIResponseHeaders) + setHeaders(resp, s.agent.GetConfig().HTTPAPIResponseHeaders) // Invoke the handler reqURL := req.URL.String() start := time.Now() @@ -817,7 +881,7 @@ func (s *HTTPServer) parseRegion(req *http.Request, r *string) { if other := req.URL.Query().Get("region"); other != "" { *r = other } else if *r == "" { - *r = s.agent.config.Region + *r = s.agent.GetConfig().Region } } @@ -976,3 +1040,55 @@ func wrapCORS(f func(http.ResponseWriter, *http.Request)) http.Handler { func wrapCORSWithAllowedMethods(f func(http.ResponseWriter, *http.Request), methods ...string) http.Handler { return allowCORSWithMethods(methods...).Handler(http.HandlerFunc(f)) } + +// authMiddleware implements the http.Handler interface to enforce +// authentication for *all* requests. Even with ACLs enabled there are +// endpoints which are accessible without authenticating. 
This middleware is +// used for the Task API to enforce authentication for all API access. +type authMiddleware struct { + srv *HTTPServer + wrapped http.Handler +} + +func newAuthMiddleware(srv *HTTPServer, h http.Handler) http.Handler { + return &authMiddleware{ + srv: srv, + wrapped: h, + } +} + +func (a *authMiddleware) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + args := structs.GenericRequest{} + reply := structs.ACLWhoAmIResponse{} + if a.srv.parse(resp, req, &args.Region, &args.QueryOptions) { + // Error parsing request, 400 + resp.WriteHeader(http.StatusBadRequest) + resp.Write([]byte(http.StatusText(http.StatusBadRequest))) + return + } + + if args.AuthToken == "" { + // 401 instead of 403 since no token was present. + resp.WriteHeader(http.StatusUnauthorized) + resp.Write([]byte(http.StatusText(http.StatusUnauthorized))) + return + } + + if err := a.srv.agent.RPC("ACL.WhoAmI", &args, &reply); err != nil { + a.srv.logger.Error("error authenticating built API request", "error", err, "url", req.URL, "method", req.Method) + resp.WriteHeader(500) + resp.Write([]byte("Server error authenticating request\n")) + return + } + + // Require an acl token or workload identity + if reply.Identity == nil || (reply.Identity.ACLToken == nil && reply.Identity.Claims == nil) { + a.srv.logger.Debug("Failed to authenticated Task API request", "method", req.Method, "url", req.URL) + resp.WriteHeader(http.StatusForbidden) + resp.Write([]byte(http.StatusText(http.StatusForbidden))) + return + } + + a.srv.logger.Trace("Authenticated request", "id", reply.Identity, "method", req.Method, "url", req.URL) + a.wrapped.ServeHTTP(resp, req) +} diff --git a/command/agent/job_endpoint.go b/command/agent/job_endpoint.go index 47f5fce9c..407e782b6 100644 --- a/command/agent/job_endpoint.go +++ b/command/agent/job_endpoint.go @@ -819,7 +819,7 @@ func (s *HTTPServer) apiJobAndRequestToStructs(job *api.Job, req *http.Request, queryRegion := req.URL.Query().Get("region") 
requestRegion, jobRegion := regionForJob( - job, queryRegion, writeReq.Region, s.agent.config.Region, + job, queryRegion, writeReq.Region, s.agent.GetConfig().Region, ) sJob := ApiJobToStructJob(job) diff --git a/command/agent/metrics_endpoint.go b/command/agent/metrics_endpoint.go index 7233492ae..90686cb3e 100644 --- a/command/agent/metrics_endpoint.go +++ b/command/agent/metrics_endpoint.go @@ -25,14 +25,14 @@ func (s *HTTPServer) MetricsRequest(resp http.ResponseWriter, req *http.Request) // Only return Prometheus formatted metrics if the user has enabled // this functionality. - if !s.agent.config.Telemetry.PrometheusMetrics { + if !s.agent.GetConfig().Telemetry.PrometheusMetrics { return nil, CodedError(http.StatusUnsupportedMediaType, "Prometheus is not enabled") } s.prometheusHandler().ServeHTTP(resp, req) return nil, nil } - return s.agent.InmemSink.DisplayMetrics(resp, req) + return s.agent.GetMetricsSink().DisplayMetrics(resp, req) } func (s *HTTPServer) prometheusHandler() http.Handler { diff --git a/command/agent/variable_endpoint.go b/command/agent/variable_endpoint.go index 17f55ac9c..bbf8b03bc 100644 --- a/command/agent/variable_endpoint.go +++ b/command/agent/variable_endpoint.go @@ -16,7 +16,8 @@ func (s *HTTPServer) VariablesListRequest(resp http.ResponseWriter, req *http.Re args := structs.VariablesListRequest{} if s.parse(resp, req, &args.Region, &args.QueryOptions) { - return nil, nil + //TODO(schmichael) shouldn't we return something here?! 
+ return nil, CodedError(http.StatusBadRequest, "failed to parse parameters") } var out structs.VariablesListResponse diff --git a/e2e/workload_id/input/api-auth.nomad.hcl b/e2e/workload_id/input/api-auth.nomad.hcl new file mode 100644 index 000000000..dae134697 --- /dev/null +++ b/e2e/workload_id/input/api-auth.nomad.hcl @@ -0,0 +1,99 @@ +job "api-auth" { + datacenters = ["dc1"] + type = "batch" + + constraint { + attribute = "${attr.kernel.name}" + value = "linux" + } + + group "api-auth" { + + # none task should get a 401 response + task "none" { + driver = "docker" + config { + image = "curlimages/curl:7.87.0" + args = [ + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-v", + "localhost/v1/agent/health", + ] + } + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + + # bad task should get a 403 response + task "bad" { + driver = "docker" + config { + image = "curlimages/curl:7.87.0" + args = [ + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-H", "X-Nomad-Token: 37297754-3b87-41da-9ac7-d98fd934deed", + "-v", + "localhost/v1/agent/health", + ] + } + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + + # docker-wid task should succeed due to using workload identity + task "docker-wid" { + driver = "docker" + + config { + image = "curlimages/curl:7.87.0" + args = [ + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-H", "Authorization: Bearer ${NOMAD_TOKEN}", + "-v", + "localhost/v1/agent/health", + ] + } + + identity { + env = true + } + + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + + # exec-wid task should succeed due to using workload identity + task "exec-wid" { + driver = "exec" + + config { + command = "curl" + args = [ + "-H", "Authorization: Bearer ${NOMAD_TOKEN}", + "--unix-socket", "${NOMAD_SECRETS_DIR}/api.sock", + "-v", + "localhost/v1/agent/health", + ] + } + + identity { + env = true + } + + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + } +} diff --git 
a/e2e/workload_id/input/api-win.nomad.hcl b/e2e/workload_id/input/api-win.nomad.hcl new file mode 100644 index 000000000..7552500f3 --- /dev/null +++ b/e2e/workload_id/input/api-win.nomad.hcl @@ -0,0 +1,36 @@ +job "api-win" { + datacenters = ["dc1"] + type = "batch" + + constraint { + attribute = "${attr.kernel.name}" + value = "windows" + } + + constraint { + attribute = "${attr.cpu.arch}" + value = "amd64" + } + + group "api-win" { + + task "win" { + driver = "raw_exec" + config { + command = "powershell" + args = ["local/curl-7.87.0_4-win64-mingw/bin/curl.exe -H \"Authorization: Bearer $env:NOMAD_TOKEN\" --unix-socket $env:NOMAD_SECRETS_DIR/api.sock -v localhost:4646/v1/agent/health"] + } + artifact { + source = "https://curl.se/windows/dl-7.87.0_4/curl-7.87.0_4-win64-mingw.zip" + } + identity { + env = true + } + resources { + cpu = 16 + memory = 32 + disk = 64 + } + } + } +} diff --git a/e2e/workload_id/taskapi_test.go b/e2e/workload_id/taskapi_test.go new file mode 100644 index 000000000..3c636d367 --- /dev/null +++ b/e2e/workload_id/taskapi_test.go @@ -0,0 +1,111 @@ +package main + +import ( + "fmt" + "io" + "net/http" + "testing" + + "github.com/hashicorp/nomad/e2e/e2eutil" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/shoenig/test" + "github.com/shoenig/test/must" +) + +// TestTaskAPI runs subtests exercising the Task API related functionality. +// Bundled with Workload Identity as that's a prereq for the Task API to work. 
+func TestTaskAPI(t *testing.T) { + nomad := e2eutil.NomadClient(t) + + e2eutil.WaitForLeader(t, nomad) + e2eutil.WaitForNodesReady(t, nomad, 1) + + t.Run("testTaskAPI_Auth", testTaskAPIAuth) + t.Run("testTaskAPI_Windows", testTaskAPIWindows) +} + +func testTaskAPIAuth(t *testing.T) { + nomad := e2eutil.NomadClient(t) + jobID := "api-auth-" + uuid.Short() + jobIDs := []string{jobID} + t.Cleanup(e2eutil.CleanupJobsAndGC(t, &jobIDs)) + + // start job + allocs := e2eutil.RegisterAndWaitForAllocs(t, nomad, "./input/api-auth.nomad.hcl", jobID, "") + must.Len(t, 1, allocs) + allocID := allocs[0].ID + + // wait for batch alloc to complete + alloc := e2eutil.WaitForAllocStopped(t, nomad, allocID) + must.Eq(t, alloc.ClientStatus, "complete") + + assertions := []struct { + task string + suffix string + }{ + { + task: "none", + suffix: http.StatusText(http.StatusUnauthorized), + }, + { + task: "bad", + suffix: http.StatusText(http.StatusForbidden), + }, + { + task: "docker-wid", + suffix: `"ok":true}}`, + }, + { + task: "exec-wid", + suffix: `"ok":true}}`, + }, + } + + // Ensure the assertions and input file match + must.Len(t, len(assertions), alloc.Job.TaskGroups[0].Tasks, + must.Sprintf("test and jobspec mismatch")) + + for _, tc := range assertions { + logFile := fmt.Sprintf("alloc/logs/%s.stdout.0", tc.task) + fd, err := nomad.AllocFS().Cat(alloc, logFile, nil) + must.NoError(t, err) + logBytes, err := io.ReadAll(fd) + must.NoError(t, err) + logs := string(logBytes) + + ps := must.Sprintf("Task: %s Logs: <