From 4c8dcca59abc51d16bb74e2b464a950cc8dc59ea Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Fri, 19 Aug 2016 19:55:06 -0700 Subject: [PATCH] fixes --- nomad/node_endpoint.go | 27 +++++++++++++++++---------- nomad/vault.go | 2 +- nomad/vault_test.go | 5 +++++ testutil/vault.go | 2 +- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index da98a6916..7b6178947 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -982,23 +982,29 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, results := make(map[string]*vapi.Secret, len(args.Tasks)) for i := 0; i < handlers; i++ { g.Go(func() error { - task, ok := <-input - if !ok { - return nil - } + for { + select { + case task, ok := <-input: + if !ok { + return nil + } - secret, err := n.srv.vault.CreateToken(ctx, alloc, task) - if err != nil { - return fmt.Errorf("failed to create token for task %q: %v", task, err) - } + secret, err := n.srv.vault.CreateToken(ctx, alloc, task) + if err != nil { + return fmt.Errorf("failed to create token for task %q: %v", task, err) + } - results[task] = secret - return nil + results[task] = secret + case <-ctx.Done(): + return nil + } + } }) } // Send the input go func() { + defer close(input) for _, task := range args.Tasks { select { case <-ctx.Done(): @@ -1006,6 +1012,7 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest, case input <- task: } } + }() // Wait for everything to complete or for an error diff --git a/nomad/vault.go b/nomad/vault.go index 95e580182..1866505b4 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -421,7 +421,7 @@ func (v *vaultClient) Stop() { v.l.Lock() defer v.l.Unlock() - if !v.renewalRunning || !v.establishingConn { + if !v.renewalRunning && !v.establishingConn { return } diff --git a/nomad/vault_test.go b/nomad/vault_test.go index b6004a483..4d12fd6a7 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -35,6 +35,7 @@ func TestVaultClient_BadConfig(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } + defer client.Stop() if client.ConnectionEstablished() { t.Fatalf("bad") @@ -184,6 +185,7 @@ func TestVaultClient_LookupToken_Invalid(t *testing.T) { if err != nil { t.Fatalf("failed to build vault client: %v", err) } + defer client.Stop() _, err = client.LookupToken(context.Background(), "foo") if err == nil || !strings.Contains(err.Error(), "disabled") { @@ -222,6 +224,7 @@ func TestVaultClient_LookupToken(t *testing.T) { if err != nil { t.Fatalf("failed to build vault client: %v", err) } + defer client.Stop() waitForConnection(client, t) @@ -281,6 +284,7 @@ func TestVaultClient_LookupToken_RateLimit(t *testing.T) { if err != nil { t.Fatalf("failed to build vault client: %v", err) } + defer client.Stop() client.setLimit(rate.Limit(1.0)) waitForConnection(client, t) @@ -334,6 +338,7 @@ func TestVaultClient_CreateToken_Root(t *testing.T) { if err != nil { t.Fatalf("failed to build vault client: %v", err) } + defer client.Stop() waitForConnection(client, t) diff --git a/testutil/vault.go b/testutil/vault.go index bb560cecd..1f449c73c 100644 --- a/testutil/vault.go +++ b/testutil/vault.go @@ -119,6 +119,6 @@ func (tv *TestVault) waitForAPI() { // getPort returns the next available port to bind Vault against func getPort() uint64 { p := vaultStartPort + vaultPortOffset - offset += 1 + vaultPortOffset += 1 return p }