diff --git a/client/allocrunner/alloc_runner.go b/client/allocrunner/alloc_runner.go index eee412cc7..344b5204c 100644 --- a/client/allocrunner/alloc_runner.go +++ b/client/allocrunner/alloc_runner.go @@ -62,6 +62,10 @@ type allocRunner struct { // registering services and checks consulClient consul.ConsulServiceAPI + // sidsClient is the client used by the service identity hook for + // managing SI tokens + sidsClient consul.ServiceIdentityAPI + // vaultClient is the used to manage Vault tokens vaultClient vaultclient.VaultClient @@ -157,6 +161,7 @@ func NewAllocRunner(config *Config) (*allocRunner, error) { alloc: alloc, clientConfig: config.ClientConfig, consulClient: config.Consul, + sidsClient: config.ConsulSI, vaultClient: config.Vault, tasks: make(map[string]*taskrunner.TaskRunner, len(tg.Tasks)), waitCh: make(chan struct{}), @@ -202,14 +207,16 @@ func NewAllocRunner(config *Config) (*allocRunner, error) { func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error { for _, task := range tasks { config := &taskrunner.Config{ - Alloc: ar.alloc, - ClientConfig: ar.clientConfig, - Task: task, - TaskDir: ar.allocDir.NewTaskDir(task.Name), - Logger: ar.logger, - StateDB: ar.stateDB, - StateUpdater: ar, - Consul: ar.consulClient, + Alloc: ar.alloc, + ClientConfig: ar.clientConfig, + Task: task, + TaskDir: ar.allocDir.NewTaskDir(task.Name), + Logger: ar.logger, + StateDB: ar.stateDB, + StateUpdater: ar, + Consul: ar.consulClient, + ConsulSI: ar.sidsClient, + Vault: ar.vaultClient, DeviceStatsReporter: ar.deviceStatsReporter, DeviceManager: ar.devicemanager, diff --git a/client/allocrunner/config.go b/client/allocrunner/config.go index 42cea978e..a9240b3a3 100644 --- a/client/allocrunner/config.go +++ b/client/allocrunner/config.go @@ -30,6 +30,9 @@ type Config struct { // Consul is the Consul client used to register task services and checks Consul consul.ConsulServiceAPI + // ConsulSI is the Consul client used to manage service identity tokens. + ConsulSI consul.ServiceIdentityAPI + // Vault is the Vault client to use to retrieve Vault tokens Vault vaultclient.VaultClient diff --git a/client/allocrunner/taskrunner/sids_hook.go b/client/allocrunner/taskrunner/sids_hook.go new file mode 100644 index 000000000..caa548cd9 --- /dev/null +++ b/client/allocrunner/taskrunner/sids_hook.go @@ -0,0 +1,193 @@ +package taskrunner + +import ( + "context" + "io/ioutil" + "os" + "path/filepath" + "sync" + "time" + + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/pkg/errors" +) + +const ( + // the name of this hook, used in logs + sidsHookName = "consul_sids" + + // sidsBackoffBaseline is the baseline time for exponential backoff when + // attempting to retrieve a Consul SI token + sidsBackoffBaseline = 5 * time.Second + + // sidsBackoffLimit is the limit of the exponential backoff when attempting + // to retrieve a Consul SI token + sidsBackoffLimit = 3 * time.Minute + + // sidsTokenFile is the name of the file holding the Consul SI token inside + // the task's secret directory + sidsTokenFile = "sids_token" + + // sidsTokenFilePerms is the level of file permissions granted on the file + // in the secrets directory for the task + sidsTokenFilePerms = 0440 +) + +type sidsHookConfig struct { + alloc *structs.Allocation + task *structs.Task + sidsClient consul.ServiceIdentityAPI + logger hclog.Logger +} + +// Service Identities hook for managing SI tokens of connect enabled tasks. +type sidsHook struct { + alloc *structs.Allocation + taskName string + sidsClient consul.ServiceIdentityAPI + logger hclog.Logger + + lock sync.Mutex + firstRun bool +} + +func newSIDSHook(c sidsHookConfig) *sidsHook { + return &sidsHook{ + alloc: c.alloc, + taskName: c.task.Name, + sidsClient: c.sidsClient, + logger: c.logger.Named(sidsHookName), + firstRun: true, + } +} + +func (h *sidsHook) Name() string { + return sidsHookName +} + +func (h *sidsHook) Prestart( + ctx context.Context, + req *interfaces.TaskPrestartRequest, + _ *interfaces.TaskPrestartResponse) error { + + h.lock.Lock() + defer h.lock.Unlock() + + // do nothing if we have already done things + if h.earlyExit() { + return nil + } + + // optimistically try to recover token from disk + token, err := h.recoverToken(req.TaskDir.SecretsDir) + if err != nil { + return err + } + + // need to ask for a new SI token & persist it to disk + if token == "" { + if token, err = h.deriveSIToken(ctx); err != nil { + return err + } + if err := h.writeToken(req.TaskDir.SecretsDir, token); err != nil { + return err + } + } + + return nil +} + +// earlyExit returns true if the Prestart hook has already been executed during +// the instantiation of this task runner. +// +// assumes h is locked +func (h *sidsHook) earlyExit() bool { + if h.firstRun { + h.firstRun = false + return false + } + return true +} + +// writeToken writes token into the secrets directory for the task. +func (h *sidsHook) writeToken(dir string, token string) error { + tokenPath := filepath.Join(dir, sidsTokenFile) + if err := ioutil.WriteFile(tokenPath, []byte(token), sidsTokenFilePerms); err != nil { + return errors.Wrap(err, "failed to write SI token") + } + return nil +} + +// recoverToken returns the token saved to disk in the secrets directory for the +// task if it exists, or the empty string if the file does not exist. an error +// is returned only for some other (e.g. disk IO) error. +func (h *sidsHook) recoverToken(dir string) (string, error) { + tokenPath := filepath.Join(dir, sidsTokenFile) + token, err := ioutil.ReadFile(tokenPath) + if err != nil { + if !os.IsNotExist(err) { + return "", errors.Wrap(err, "failed to recover SI token") + } + h.logger.Trace("no pre-existing SI token to recover", "task", h.taskName) + return "", nil // token file does not exist yet + } + h.logger.Trace("recovered pre-existing SI token", "task", h.taskName) + return string(token), nil +} + +// deriveSIToken spawns and waits on a goroutine which will make attempts to +// derive an SI token until a token is successfully created, or ctx is signaled +// done. +func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) { + tokenCh := make(chan string) + + // keep trying to get the token in the background + go h.tryDerive(ctx, tokenCh) + + // wait until we get a token, or we get a signal to quit + for { + select { + case token := <-tokenCh: + return token, nil + case <-ctx.Done(): + return "", ctx.Err() + } + } +} + +// tryDerive loops forever until a token is created, or ctx is done. +func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- string) { + for i := 0; backoff(ctx, i); i++ { + tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.taskName}) + if err != nil { + h.logger.Warn("failed to derive SI token", "attempt", i, "error", err) + continue + } + ch <- tokens[h.taskName] + return + } +} + +func backoff(ctx context.Context, i int) bool { + next := computeBackoff(i) + select { + case <-ctx.Done(): + return false + case <-time.After(next): + return true + } +} + +func computeBackoff(attempt int) time.Duration { + switch { + case attempt <= 0: + return 0 + case attempt >= 4: + return sidsBackoffLimit + default: + return (1 << (2 * uint(attempt))) * sidsBackoffBaseline + } +} diff --git a/client/allocrunner/taskrunner/sids_hook_test.go b/client/allocrunner/taskrunner/sids_hook_test.go new file mode 100644 index 000000000..cbb056634 --- /dev/null +++ b/client/allocrunner/taskrunner/sids_hook_test.go @@ -0,0 +1,122 @@ +package taskrunner + +import ( + "context" + "io/ioutil" + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/client/allocrunner/interfaces" + "github.com/hashicorp/nomad/client/consul" + "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +var _ interfaces.TaskPrestartHook = (*sidsHook)(nil) + +func tmpDir(t *testing.T) string { + dir, err := ioutil.TempDir("", "sids-") + require.NoError(t, err) + return dir +} + +func cleanupDir(t *testing.T, dir string) { + err := os.RemoveAll(dir) + require.NoError(t, err) +} + +func TestSIDSHook_recoverToken(t *testing.T) { + t.Parallel() + + r := require.New(t) + secrets := tmpDir(t) + defer cleanupDir(t, secrets) + + h := newSIDSHook(sidsHookConfig{ + task: &structs.Task{Name: "task1"}, + logger: testlog.HCLogger(t), + }) + + expected := "12345678-1234-1234-1234-1234567890" + err := h.writeToken(secrets, expected) + r.NoError(err) + + token, err := h.recoverToken(secrets) + r.NoError(err) + r.Equal(expected, token) +} + +func TestSIDSHook_recoverToken_empty(t *testing.T) { + t.Parallel() + + r := require.New(t) + secrets := tmpDir(t) + defer cleanupDir(t, secrets) + + h := newSIDSHook(sidsHookConfig{ + task: &structs.Task{Name: "task1"}, + logger: testlog.HCLogger(t), + }) + + token, err := h.recoverToken(secrets) + r.NoError(err) + r.Empty(token) +} + +func TestSIDSHook_deriveSIToken(t *testing.T) { + t.Parallel() + + r := require.New(t) + secrets := tmpDir(t) + defer cleanupDir(t, secrets) + + h := newSIDSHook(sidsHookConfig{ + alloc: &structs.Allocation{ID: "a1"}, + task: &structs.Task{Name: "task1"}, + logger: testlog.HCLogger(t), + sidsClient: consul.NewMockServiceIdentitiesClient(), + }) + + ctx := context.Background() + token, err := h.deriveSIToken(ctx) + r.NoError(err) + r.True(helper.IsUUID(token)) +} + +func TestSIDSHook_computeBackoff(t *testing.T) { + t.Parallel() + + try := func(i int, exp time.Duration) { + result := computeBackoff(i) + require.Equal(t, exp, result) + } + + try(0, time.Duration(0)) + try(1, 20*time.Second) + try(2, 80*time.Second) + try(3, 320*time.Second) + try(4, sidsBackoffLimit) +} + +func TestSIDSHook_backoff(t *testing.T) { + t.Parallel() + r := require.New(t) + + ctx := context.Background() + stop := !backoff(ctx, 0) + r.False(stop) +} + +func TestSIDSHook_backoffKilled(t *testing.T) { + t.Parallel() + r := require.New(t) + + ctx, cancel := context.WithTimeout(context.Background(), 1) + defer cancel() + + stop := !backoff(ctx, 1000) + r.True(stop) +} diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go index bb6c6a445..2e4e09445 100644 --- a/client/allocrunner/taskrunner/task_runner.go +++ b/client/allocrunner/taskrunner/task_runner.go @@ -50,7 +50,7 @@ const ( // giving up and potentially leaking resources. killFailureLimit = 5 - // triggerUpdatechCap is the capacity for the triggerUpdateCh used for + // triggerUpdateChCap is the capacity for the triggerUpdateCh used for // triggering updates. It should be exactly 1 as even if multiple // updates have come in since the last one was handled, we only need to // handle the last one. @@ -158,6 +158,10 @@ type TaskRunner struct { // registering services and checks consulClient consul.ConsulServiceAPI + // sidsClient is the client used by the service identity hook for managing + // service identity tokens + siClient consul.ServiceIdentityAPI + // vaultClient is the client to use to derive and renew Vault tokens vaultClient vaultclient.VaultClient @@ -210,11 +214,16 @@ type TaskRunner struct { type Config struct { Alloc *structs.Allocation ClientConfig *config.Config - Consul consul.ConsulServiceAPI Task *structs.Task TaskDir *allocdir.TaskDir Logger log.Logger + // Consul is the client to use for managing Consul service registrations + Consul consul.ConsulServiceAPI + + // ConsulSI is the client to use for managing Consul SI tokens + ConsulSI consul.ServiceIdentityAPI + // Vault is the client to use to derive and renew Vault tokens Vault vaultclient.VaultClient @@ -271,6 +280,7 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) { taskLeader: config.Task.Leader, envBuilder: envBuilder, consulClient: config.Consul, + siClient: config.ConsulSI, vaultClient: config.Vault, state: tstate, localState: state.NewLocalState(), diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go index 2f5723197..46c5b5e89 100644 --- a/client/allocrunner/taskrunner/task_runner_hooks.go +++ b/client/allocrunner/taskrunner/task_runner_hooks.go @@ -45,7 +45,7 @@ func (h *hookResources) getMounts() []*drivers.MountConfig { return h.Mounts } -// initHooks intializes the tasks hooks. +// initHooks initializes the tasks hooks. func (tr *TaskRunner) initHooks() { hookLogger := tr.logger.Named("task_hook") task := tr.Task() @@ -96,7 +96,7 @@ func (tr *TaskRunner) initHooks() { })) } - // If there are any services, add the hook + // If there are any services, add the service hook if len(task.Services) != 0 { tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{ alloc: tr.Alloc(), @@ -107,6 +107,15 @@ func (tr *TaskRunner) initHooks() { })) } + if usesConnect(tr.alloc.Job.LookupTaskGroup(tr.alloc.TaskGroup)) { + tr.runnerHooks = append(tr.runnerHooks, newSIDSHook(sidsHookConfig{ + alloc: tr.Alloc(), + task: tr.Task(), + sidsClient: tr.siClient, + logger: hookLogger, + })) + } + // If there are any script checks, add the hook scriptCheckHook := newScriptCheckHook(scriptCheckHookConfig{ alloc: tr.Alloc(), @@ -117,6 +126,15 @@ func (tr *TaskRunner) initHooks() { tr.runnerHooks = append(tr.runnerHooks, scriptCheckHook) } +func usesConnect(tg *structs.TaskGroup) bool { + for _, service := range tg.Services { + if service.Connect != nil { + return true + } + } + return false +} + func (tr *TaskRunner) emitHookError(err error, hookName string) { var taskEvent *structs.TaskEvent if herr, ok := err.(*hookError); ok { @@ -131,7 +149,7 @@ func (tr *TaskRunner) emitHookError(err error, hookName string) { // prestart is used to run the runners prestart hooks. func (tr *TaskRunner) prestart() error { - // Determine if the allocation is terminaland we should avoid running + // Determine if the allocation is terminal and we should avoid running // prestart hooks. alloc := tr.Alloc() if alloc.TerminalStatus() { diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index 25124c3fd..d2dad4534 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -97,10 +97,11 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri conf := &Config{ Alloc: alloc, ClientConfig: clientConf, - Consul: consulapi.NewMockConsulServiceClient(t, logger), Task: thisTask, TaskDir: taskDir, Logger: clientConf.Logger, + Consul: consulapi.NewMockConsulServiceClient(t, logger), + ConsulSI: consulapi.NewMockServiceIdentitiesClient(), Vault: vaultclient.NewMockVaultClient(), StateDB: cstate.NoopDB{}, StateUpdater: NewMockTaskStateUpdater(), @@ -1085,6 +1086,76 @@ func TestTaskRunner_CheckWatcher_Restart(t *testing.T) { require.True(t, state.Failed, pretty.Sprint(state)) } +func TestTaskRunner_BlockForSIDS(t *testing.T) { + t.Parallel() + r := require.New(t) + + // setup a connect enabled batch job that wants to exit immediately, which + // makes testing the prestart lifecycle easier + alloc := mock.BatchAlloc() + tg := alloc.Job.TaskGroups[0] + tg.Tasks[0].Config = map[string]interface{}{"run_for": "0s"} + tg.Services = []*structs.Service{{ + Name: "testconnect", + PortLabel: "9999", + Connect: &structs.ConsulConnect{ + SidecarService: &structs.ConsulSidecarService{}, + }}, + } + taskName := tg.Tasks[0].Name + + trConfig, cleanup := testTaskRunnerConfig(t, alloc, taskName) + defer cleanup() + + // control when we get a Consul SI token + token := "12345678-1234-1234-1234-1234567890" + waitCh := make(chan struct{}) + deriveFn := func(*structs.Allocation, []string) (map[string]string, error) { + <-waitCh + return map[string]string{taskName: token}, nil + } + siClient := trConfig.ConsulSI.(*consulapi.MockServiceIdentitiesClient) + siClient.DeriveTokenFn = deriveFn + + // start the task runner + tr, err := NewTaskRunner(trConfig) + r.NoError(err) + defer tr.Kill(context.Background(), structs.NewTaskEvent("cleanup")) + go tr.Run() + + // assert task runner blocks on SI token + select { + case <-tr.WaitCh(): + r.Fail("task_runner exited before si unblocked") + case <-time.After(100 * time.Millisecond): + } + + // assert task state is still pending + r.Equal(structs.TaskStatePending, tr.TaskState().State) + + // unblock service identity token + close(waitCh) + + // task runner should exit now that it has been unblocked and it is a batch + // job with a zero sleep time + select { + case <-tr.WaitCh(): + case <-time.After(15 * time.Second * time.Duration(testutil.TestMultiplier())): + r.Fail("timed out waiting for batch task to exist") + } + + // assert task exited successfully + finalState := tr.TaskState() + r.Equal(structs.TaskStateDead, finalState.State) + r.False(finalState.Failed) + + // assert the token is on disk + tokenPath := filepath.Join(trConfig.TaskDir.SecretsDir, sidsTokenFile) + data, err := ioutil.ReadFile(tokenPath) + r.NoError(err) + r.Equal(token, string(data)) +} + // TestTaskRunner_BlockForVault asserts tasks do not start until a vault token // is derived. func TestTaskRunner_BlockForVault(t *testing.T) { diff --git a/client/allocrunner/testing.go b/client/allocrunner/testing.go index 75806644b..02751687b 100644 --- a/client/allocrunner/testing.go +++ b/client/allocrunner/testing.go @@ -60,6 +60,7 @@ func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*Config, fu ClientConfig: clientConf, StateDB: state.NoopDB{}, Consul: consul.NewMockConsulServiceClient(t, clientConf.Logger), + ConsulSI: consul.NewMockServiceIdentitiesClient(), Vault: vaultclient.NewMockVaultClient(), StateUpdater: &MockStateUpdater{}, PrevAllocWatcher: allocwatcher.NoopPrevAlloc{}, diff --git a/client/client.go b/client/client.go index f23cc1efd..4cf875059 100644 --- a/client/client.go +++ b/client/client.go @@ -1,7 +1,6 @@ package client import ( - "errors" "fmt" "io/ioutil" "net" @@ -46,6 +45,7 @@ import ( "github.com/hashicorp/nomad/plugins/device" "github.com/hashicorp/nomad/plugins/drivers" vaultapi "github.com/hashicorp/vault/api" + "github.com/pkg/errors" "github.com/shirou/gopsutil/host" ) @@ -236,6 +236,10 @@ type Client struct { // Shutdown() blocks on Wait() after closing shutdownCh. shutdownGroup group + // tokensClient is Nomad Client's custom Consul client for requesting Consul + // Service Identity tokens through Nomad Server. + tokensClient consulApi.ServiceIdentityAPI + // vaultClient is used to interact with Vault for token and secret renewals vaultClient vaultclient.VaultClient @@ -445,6 +449,10 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulServic } } + if err := c.setupConsulTokenClient(); err != nil { + return nil, errors.Wrap(err, "failed to setup consul tokens client") + } + // Setup the vault client for token and secret renewals if err := c.setupVaultClient(); err != nil { return nil, fmt.Errorf("failed to setup vault client: %v", err) @@ -1042,6 +1050,7 @@ func (c *Client) restoreState() error { StateUpdater: c, DeviceStatsReporter: c, Consul: c.consulService, + ConsulSI: c.tokensClient, // todo(shoenig), keep plumbing! Vault: c.vaultClient, PrevAllocWatcher: prevAllocWatcher, PrevAllocMigrator: prevAllocMigrator, @@ -2295,6 +2304,7 @@ func (c *Client) addAlloc(alloc *structs.Allocation, migrateToken string) error ClientConfig: c.configCopy, StateDB: c.stateDB, Consul: c.consulService, + ConsulSI: c.tokensClient, // todo(shoenig), keep plumbing! Vault: c.vaultClient, StateUpdater: c, DeviceStatsReporter: c, @@ -2317,6 +2327,14 @@ func (c *Client) addAlloc(alloc *structs.Allocation, migrateToken string) error return nil } +// setupConsulTokenClient configures a tokenClient for managing consul service +// identity tokens. +func (c *Client) setupConsulTokenClient() error { + tc := consulApi.NewIdentitiesClient(c.logger, c.deriveSIToken) + c.tokensClient = tc + return nil +} + // setupVaultClient creates an object to periodically renew tokens and secrets // with vault. func (c *Client) setupVaultClient() error { @@ -2342,33 +2360,10 @@ func (c *Client) setupVaultClient() error { // client and returns a map of unwrapped tokens, indexed by the task name. func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vclient *vaultapi.Client) (map[string]string, error) { vlogger := c.logger.Named("vault") - if alloc == nil { - return nil, fmt.Errorf("nil allocation") - } - if taskNames == nil || len(taskNames) == 0 { - return nil, fmt.Errorf("missing task names") - } - - group := alloc.Job.LookupTaskGroup(alloc.TaskGroup) - if group == nil { - return nil, fmt.Errorf("group name in allocation is not present in job") - } - - verifiedTasks := []string{} - // Check if the given task names actually exist in the allocation - for _, taskName := range taskNames { - found := false - for _, task := range group.Tasks { - if task.Name == taskName { - found = true - } - } - if !found { - vlogger.Error("task not found in the allocation", "task_name", taskName) - return nil, fmt.Errorf("task %q not found in the allocation", taskName) - } - verifiedTasks = append(verifiedTasks, taskName) + verifiedTasks, err := verifiedTasks(vlogger, alloc, taskNames) + if err != nil { + return nil, err } // DeriveVaultToken of nomad server can take in a set of tasks and @@ -2443,6 +2438,89 @@ func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vcli return unwrappedTokens, nil } +// deriveSIToken takes an allocation and a set of tasks and derives Consul +// Service Identity tokens for each of the tasks by requesting them from the +// Nomad Server. +func (c *Client) deriveSIToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + tasks, err := verifiedTasks(c.logger, alloc, taskNames) + if err != nil { + return nil, err + } + + req := &structs.DeriveSITokenRequest{ + NodeID: c.NodeID(), + AllocID: alloc.ID, + Tasks: tasks, + QueryOptions: structs.QueryOptions{Region: c.Region()}, + } + + // Nicely ask Nomad Server for the tokens. + var resp structs.DeriveSITokenResponse + if err := c.RPC("Node.DeriveSIToken", &req, &resp); err != nil { + c.logger.Error("error making derive token RPC", "error", err) + return nil, fmt.Errorf("DeriveSIToken RPC failed: %v", err) + } + if err := resp.Error; err != nil { + c.logger.Error("error deriving SI tokens", "error", err) + return nil, structs.NewWrappedServerError(err) + } + if len(resp.Tokens) == 0 { + c.logger.Error("error deriving SI tokens", "error", "invalid_response") + return nil, fmt.Errorf("failed to derive SI tokens: invalid response") + } + + // NOTE: Unlike with the Vault integration, Nomad Server replies with the + // actual Consul SI token (.SecretID), because otherwise each Nomad + // Client would need to be blessed with 'acl:write' permissions to read the + // secret value given the .AccessorID, which does not fit well in the Consul + // security model. + // + // https://www.consul.io/api/acl/tokens.html#read-a-token + // https://www.consul.io/docs/internals/security.html + + m := helper.CopyMapStringString(resp.Tokens) + return m, nil +} + +// verifiedTasks asserts each task in taskNames actually exists in the given alloc, +// otherwise an error is returned. +func verifiedTasks(logger hclog.Logger, alloc *structs.Allocation, taskNames []string) ([]string, error) { + if alloc == nil { + return nil, fmt.Errorf("nil allocation") + } + + if len(taskNames) == 0 { + return nil, fmt.Errorf("missing task names") + } + + group := alloc.Job.LookupTaskGroup(alloc.TaskGroup) + if group == nil { + return nil, fmt.Errorf("group name in allocation is not present in job") + } + + verifiedTasks := make([]string, 0, len(taskNames)) + + // confirm the requested task names actually exist in the allocation + for _, taskName := range taskNames { + if !taskIsPresent(taskName, group.Tasks) { + logger.Error("task not found in the allocation", "task_name", taskName) + return nil, fmt.Errorf("task %q not found in allocation", taskName) + } + verifiedTasks = append(verifiedTasks, taskName) + } + + return verifiedTasks, nil +} + +func taskIsPresent(taskName string, tasks []*structs.Task) bool { + for _, task := range tasks { + if task.Name == taskName { + return true + } + } + return false +} + // triggerDiscovery causes a Consul discovery to begin (if one hasn't already) func (c *Client) triggerDiscovery() { select { diff --git a/client/client_test.go b/client/client_test.go index 204c067ec..95bf51069 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1620,3 +1620,71 @@ func TestClient_hasLocalState(t *testing.T) { require.True(t, c.hasLocalState(alloc)) }) } + +func Test_verifiedTasks(t *testing.T) { + t.Parallel() + logger := testlog.HCLogger(t) + + // produce a result and check against expected tasks and/or error output + try := func(t *testing.T, a *structs.Allocation, tasks, expTasks []string, expErr string) { + result, err := verifiedTasks(logger, a, tasks) + if expErr != "" { + require.EqualError(t, err, expErr) + } else { + require.NoError(t, err) + require.Equal(t, expTasks, result) + } + } + + // create an alloc with TaskGroup=g1, tasks configured given g1Tasks + alloc := func(g1Tasks []string) *structs.Allocation { + var tasks []*structs.Task + for _, taskName := range g1Tasks { + tasks = append(tasks, &structs.Task{Name: taskName}) + } + + return &structs.Allocation{ + Job: &structs.Job{ + TaskGroups: []*structs.TaskGroup{ + {Name: "g0", Tasks: []*structs.Task{{Name: "g0t1"}}}, + {Name: "g1", Tasks: tasks}, + }, + }, + TaskGroup: "g1", + } + } + + t.Run("nil alloc", func(t *testing.T) { + tasks := []string{"g1t1"} + try(t, nil, tasks, nil, "nil allocation") + }) + + t.Run("missing task names", func(t *testing.T) { + var tasks []string + tgTasks := []string{"g1t1"} + try(t, alloc(tgTasks), tasks, nil, "missing task names") + }) + + t.Run("missing group", func(t *testing.T) { + tasks := []string{"g1t1"} + a := alloc(tasks) + a.TaskGroup = "other" + try(t, a, tasks, nil, "group name in allocation is not present in job") + }) + + t.Run("nonexistent task", func(t *testing.T) { + tasks := []string{"missing"} + try(t, alloc([]string{"task1"}), tasks, nil, `task "missing" not found in allocation`) + }) + + t.Run("matching task", func(t *testing.T) { + tasks := []string{"g1t1"} + try(t, alloc(tasks), tasks, tasks, "") + }) + + t.Run("matching task subset", func(t *testing.T) { + tasks := []string{"g1t1", "g1t3"} + tgTasks := []string{"g1t1", "g1t2", "g1t3"} + try(t, alloc(tgTasks), tasks, tasks, "") + }) +} diff --git a/client/consul/consul.go b/client/consul/consul.go index 251165ff3..f8348c220 100644 --- a/client/consul/consul.go +++ b/client/consul/consul.go @@ -2,14 +2,37 @@ package consul import ( "github.com/hashicorp/nomad/command/agent/consul" + "github.com/hashicorp/nomad/nomad/structs" ) // ConsulServiceAPI is the interface the Nomad Client uses to register and // remove services and checks from Consul. type ConsulServiceAPI interface { + // RegisterWorkload with Consul. Adds all service entries and checks to Consul. RegisterWorkload(*consul.WorkloadServices) error + + // RemoveWorkload from Consul. Removes all service entries and checks. RemoveWorkload(*consul.WorkloadServices) + + // UpdateWorkload in Consul. Does not alter the service if only checks have + // changed. UpdateWorkload(old, newTask *consul.WorkloadServices) error + + // AllocRegistrations returns the registrations for the given allocation. AllocRegistrations(allocID string) (*consul.AllocRegistration, error) + + // UpdateTTL is used to update the TTL of a check. UpdateTTL(id, output, status string) error } + +// TokenDeriverFunc takes an allocation and a set of tasks and derives a +// service identity token for each. Requests go through nomad server. +type TokenDeriverFunc func(*structs.Allocation, []string) (map[string]string, error) + +// ServiceIdentityAPI is the interface the Nomad Client uses to request Consul +// Service Identity tokens through Nomad Server. +type ServiceIdentityAPI interface { + // DeriveSITokens contacts the nomad server and requests consul service + // identity tokens be generated for tasks in the allocation. + DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) +} diff --git a/client/consul/consul_testing.go b/client/consul/consul_testing.go index 75307eae8..0384e4c95 100644 --- a/client/consul/consul_testing.go +++ b/client/consul/consul_testing.go @@ -6,7 +6,6 @@ import ( "time" log "github.com/hashicorp/go-hclog" - "github.com/hashicorp/nomad/command/agent/consul" testing "github.com/mitchellh/go-testing-interface" ) diff --git a/client/consul/identities.go b/client/consul/identities.go new file mode 100644 index 000000000..e07dfaf33 --- /dev/null +++ b/client/consul/identities.go @@ -0,0 +1,32 @@ +package consul + +import ( + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/nomad/structs" +) + +// Implementation of ServiceIdentityAPI used to interact with Nomad Server from +// Nomad Client for acquiring Consul Service Identity tokens. +// +// This client is split from the other consul client(s) to avoid a circular +// dependency between themselves and client.Client +type identitiesClient struct { + tokenDeriver TokenDeriverFunc + logger hclog.Logger +} + +func NewIdentitiesClient(logger hclog.Logger, tokenDeriver TokenDeriverFunc) *identitiesClient { + return &identitiesClient{ + tokenDeriver: tokenDeriver, + logger: logger, + } +} + +func (c *identitiesClient) DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) { + tokens, err := c.tokenDeriver(alloc, tasks) + if err != nil { + c.logger.Error("error deriving SI token", "error", err, "alloc_id", alloc.ID, "task_names", tasks) + return nil, err + } + return tokens, nil +} diff --git a/client/consul/identities_test.go b/client/consul/identities_test.go new file mode 100644 index 000000000..e56000d4a --- /dev/null +++ b/client/consul/identities_test.go @@ -0,0 +1,31 @@ +package consul + +import ( + "errors" + "testing" + + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +func TestCSI_DeriveTokens(t *testing.T) { + logger := testlog.HCLogger(t) + dFunc := func(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + return map[string]string{"a": "b"}, nil + } + tc := NewIdentitiesClient(logger, dFunc) + tokens, err := tc.DeriveSITokens(nil, nil) + require.NoError(t, err) + require.Equal(t, map[string]string{"a": "b"}, tokens) +} + +func TestCSI_DeriveTokens_error(t *testing.T) { + logger := testlog.HCLogger(t) + dFunc := func(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + return nil, errors.New("some failure") + } + tc := NewIdentitiesClient(logger, dFunc) + _, err := tc.DeriveSITokens(&structs.Allocation{ID: "a1"}, nil) + require.Error(t, err) +} diff --git a/client/consul/identities_testing.go b/client/consul/identities_testing.go new file mode 100644 index 000000000..2d4258d25 --- /dev/null +++ b/client/consul/identities_testing.go @@ -0,0 +1,85 @@ +package consul + +import ( + "fmt" + "sync" + + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/structs" +) + +// MockServiceIdentitiesClient is used for testing the client for managing consul service +// identity tokens. +type MockServiceIdentitiesClient struct { + // deriveTokenErrors maps an allocation ID and tasks to an error when the + // token is derived + deriveTokenErrors map[string]map[string]error + + // DeriveTokenFn allows the caller to control the DeriveToken function. If + // not set an error is returned if found in DeriveTokenErrors and otherwise + // a token is generated and returned + DeriveTokenFn TokenDeriverFunc + + // lock around everything + lock sync.Mutex +} + +var _ ServiceIdentityAPI = (*MockServiceIdentitiesClient)(nil) + +// NewMockServiceIdentitiesClient returns a MockServiceIdentitiesClient for testing. +func NewMockServiceIdentitiesClient() *MockServiceIdentitiesClient { + return &MockServiceIdentitiesClient{ + deriveTokenErrors: make(map[string]map[string]error), + } +} + +func (mtc *MockServiceIdentitiesClient) DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) { + mtc.lock.Lock() + defer mtc.lock.Unlock() + + fmt.Println("MockServiceIdentitiesClient.DeriveSITokens running!") + + // if the DeriveTokenFn is explicitly set, use that + if mtc.DeriveTokenFn != nil { + return mtc.DeriveTokenFn(alloc, tasks) + } + + // generate a token for each task, unless the mock has an error ready for + // one or more of the tasks in which case return that + tokens := make(map[string]string, len(tasks)) + for _, task := range tasks { + if m, ok := mtc.deriveTokenErrors[alloc.ID]; ok { + if err, ok := m[task]; ok { + return nil, err + } + } + tokens[task] = uuid.Generate() + } + return tokens, nil +} + +func (mtc *MockServiceIdentitiesClient) SetDeriveTokenError(allocID string, tasks []string, err error) { + mtc.lock.Lock() + defer mtc.lock.Unlock() + + if _, ok := mtc.deriveTokenErrors[allocID]; !ok { + mtc.deriveTokenErrors[allocID] = make(map[string]error, 10) + } + + for _, task := range tasks { + mtc.deriveTokenErrors[allocID][task] = err + } +} + +func (mtc *MockServiceIdentitiesClient) DeriveTokenErrors() map[string]map[string]error { + mtc.lock.Lock() + defer mtc.lock.Unlock() + + m := make(map[string]map[string]error) + for aID, tasks := range mtc.deriveTokenErrors { + for task, err := range tasks { + m[aID][task] = err + } + } + return m +} diff --git a/client/vaultclient/vaultclient_testing.go b/client/vaultclient/vaultclient_testing.go index fea31a31d..08adaa119 100644 --- a/client/vaultclient/vaultclient_testing.go +++ b/client/vaultclient/vaultclient_testing.go @@ -67,6 +67,7 @@ func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, e vc.deriveTokenErrors = make(map[string]map[string]error, 10) } + // todo(shoenig): this seems like a bug if _, ok := vc.renewTokenErrors[allocID]; !ok { vc.deriveTokenErrors[allocID] = make(map[string]error, 10) } @@ -111,8 +112,10 @@ func (vc *MockVaultClient) StopRenewToken(token string) error { return nil } -func (vc *MockVaultClient) Start() {} -func (vc *MockVaultClient) Stop() {} +func (vc *MockVaultClient) Start() {} + +func (vc *MockVaultClient) Stop() {} + func (vc *MockVaultClient) GetConsulACL(string, string) (*vaultapi.Secret, error) { return nil, nil } // StoppedTokens tracks the tokens that have stopped renewing diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 8d6333c21..d71213ce0 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -920,7 +920,28 @@ type DeriveVaultTokenResponse struct { Tasks map[string]string // Error stores any error that occurred. Errors are stored here so we can - // communicate whether it is retriable + // communicate whether it is retryable + Error *RecoverableError + + QueryMeta +} + +// DeriveSITokenRequest is used to request Consul Service Identity tokens from +// the Nomad Server for the named tasks in the given allocation. +type DeriveSITokenRequest struct { + NodeID string + SecretID string + AllocID string + Tasks []string + QueryOptions +} + +type DeriveSITokenResponse struct { + // Tokens maps from Task Name to its associated SI token + Tokens map[string]string + + // Error stores any error that occurred. Errors are stored here so we can + // communicate whether it is retryable Error *RecoverableError QueryMeta