Files
nomad/client/vaultclient/vaultclient_test.go
Tim Gross 18fdda6242 vault: fix namespace reset for clients with unset namespace (#23491)
The Vault "logical" API doesn't allow configuring the namespace on a per-request
basis. Instead, it's set on the client. Our `vaultclient` wrapper locks access
to the API client and sets the namespace (and token, if applicable) for each
request, and then resets the namespace and unlocks the API client.

The logic for resetting the namespace incorrectly assumed that if the Vault
configuration didn't set the namespace that it was canonicalized to the
non-empty string `"default"`. This results in the API client's namespace getting
"stuck" whenever a job uses a non-default namespace if the configuration value
is empty. Update the logic to always go back to the configuration, rather than
accepting the "previous" namespace from the caller.

This changeset also removes some long-dead code in the Vault client wrapper.

Fixes: https://github.com/hashicorp/nomad/issues/22230
Ref: https://hashicorp.atlassian.net/browse/NET-10207
2024-07-03 10:13:20 -04:00

710 lines
18 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package vaultclient
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"text/template"
"time"
josejwt "github.com/go-jose/go-jose/v3/jwt"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/widmgr"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/useragent"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
structsc "github.com/hashicorp/nomad/nomad/structs/config"
"github.com/hashicorp/nomad/testutil"
"github.com/hashicorp/vault/api"
vaultapi "github.com/hashicorp/vault/api"
"github.com/shoenig/test/must"
)
const (
jwtAuthMountPathTest = "jwt_test"
jwtAuthConfigTemplate = `
{
"jwks_url": "<<.JWKSURL>>",
"jwt_supported_algs": ["RS256", "EdDSA"],
"default_role": "nomad-workloads"
}
`
widVaultPolicyTemplate = `
path "secret/data/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_job_id}}/*" {
capabilities = ["read"]
}
path "secret/data/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_job_id}}" {
capabilities = ["read"]
}
path "secret/metadata/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/*" {
capabilities = ["list"]
}
path "secret/metadata/*" {
capabilities = ["list"]
}
`
widVaultRole = `
{
"role_type": "jwt",
"bound_audiences": "vault.io",
"user_claim": "/nomad_job_id",
"user_claim_json_pointer": true,
"claim_mappings": {
"nomad_namespace": "nomad_namespace",
"nomad_job_id": "nomad_job_id"
},
"token_ttl": "30m",
"token_type": "service",
"token_period": "72h",
"token_policies": ["nomad-workloads"]
}
`
)
func renderVaultTemplate(tmplStr string, data any) ([]byte, error) {
var buf bytes.Buffer
tmpl, err := template.New("policy").
Delims("<<", ">>").
Parse(tmplStr)
if err != nil {
return nil, fmt.Errorf("failed to parse policy template: %w", err)
}
err = tmpl.Execute(&buf, data)
if err != nil {
return nil, fmt.Errorf("failed to render policy template: %w", err)
}
return buf.Bytes(), nil
}
func setupVaultForWorkloadIdentity(v *testutil.TestVault, jwksURL string) error {
logical := v.Client.Logical()
sys := v.Client.Sys()
ctx := context.Background()
// Enable JWT auth method.
err := sys.EnableAuthWithOptions(jwtAuthMountPathTest, &api.MountInput{
Type: "jwt",
})
if err != nil {
return fmt.Errorf("failed to enable JWT auth method: %w", err)
}
secret, err := logical.Read(fmt.Sprintf("sys/auth/%s", jwtAuthMountPathTest))
jwtAuthAccessor := secret.Data["accessor"].(string)
// Write JWT auth method config.
jwtAuthConfigData := struct {
JWKSURL string
}{
JWKSURL: jwksURL,
}
jwtAuthConfig, err := renderVaultTemplate(jwtAuthConfigTemplate, jwtAuthConfigData)
if err != nil {
return err
}
_, err = logical.WriteBytesWithContext(ctx, fmt.Sprintf("auth/%s/config", jwtAuthMountPathTest), jwtAuthConfig)
if err != nil {
return fmt.Errorf("failed to write JWT auth method config: %w", err)
}
// Write Nomad workload policy.
data := struct {
JWTAuthAccessorID string
}{
JWTAuthAccessorID: jwtAuthAccessor,
}
policy, err := renderVaultTemplate(widVaultPolicyTemplate, data)
if err != nil {
return err
}
encoded := base64.StdEncoding.EncodeToString(policy)
policyReqBody := fmt.Sprintf(`{"policy": "%s"}`, encoded)
policyPath := "sys/policies/acl/nomad-workloads"
_, err = logical.WriteBytesWithContext(ctx, policyPath, []byte(policyReqBody))
if err != nil {
return fmt.Errorf("failed to write policy: %w", err)
}
// Write Nomad workload role.
rolePath := fmt.Sprintf("auth/%s/role/nomad-workloads", jwtAuthMountPathTest)
_, err = logical.WriteBytesWithContext(ctx, rolePath, []byte(widVaultRole))
if err != nil {
return fmt.Errorf("failed to write role: %w", err)
}
return nil
}
func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {
ci.Parallel(t)
// Create signer and signed identities.
alloc := mock.MinAlloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Identities = []*structs.WorkloadIdentity{
{
Name: "vault_default",
Audience: []string{"vault.io"},
TTL: time.Second,
},
}
signer := widmgr.NewMockWIDSigner(task.Identities)
signedWIDs, err := signer.SignIdentities(1, []*structs.WorkloadIdentityRequest{
{
AllocID: alloc.ID,
WIHandle: structs.WIHandle{
IdentityName: task.Identities[0].Name,
WorkloadIdentifier: task.Name,
WorkloadType: structs.WorkloadTypeTask,
},
},
})
must.NoError(t, err)
must.Len(t, 1, signedWIDs)
// Setup test JWKS server.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
out, err := json.Marshal(signer.JSONWebKeySet())
if err != nil {
t.Errorf("failed to generate JWKS json response: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
fmt.Fprintln(w, string(out))
}))
defer ts.Close()
// Start and configure Vault cluster for JWT authentication.
v := testutil.NewTestVault(t)
defer v.Stop()
err = setupVaultForWorkloadIdentity(v, ts.URL)
must.NoError(t, err)
// Start Vault client.
logger := testlog.HCLogger(t)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
v.Config.JWTAuthBackendPath = jwtAuthMountPathTest
c, err := NewVaultClient(v.Config, logger, nil)
must.NoError(t, err)
c.Start()
defer c.Stop()
// Derive Vault token using signed JWT.
jwtStr := signedWIDs[0].JWT
token, renewable, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Namespace: "default",
})
must.NoError(t, err)
must.NotEq(t, "", token)
must.True(t, renewable)
// Verify token has expected properties.
v.Client.SetToken(token)
s, err := v.Client.Logical().Read("auth/token/lookup-self")
must.NoError(t, err)
jwt, err := josejwt.ParseSigned(jwtStr)
must.NoError(t, err)
claims := make(map[string]any)
err = jwt.UnsafeClaimsWithoutVerification(&claims)
must.NoError(t, err)
must.Eq(t, "service", s.Data["type"].(string))
must.True(t, s.Data["renewable"].(bool))
must.SliceContains(t, s.Data["policies"].([]any), "nomad-workloads")
must.MapEq(t, map[string]any{
"nomad_namespace": claims["nomad_namespace"],
"nomad_job_id": claims["nomad_job_id"],
"role": "nomad-workloads",
}, s.Data["meta"].(map[string]any))
// Verify token has the expected permissions.
pathAllowed := fmt.Sprintf("secret/data/%s/%s/a", claims["nomad_namespace"], claims["nomad_job_id"])
pathDenied := "secret/data/denied"
s, err = v.Client.Logical().Write("sys/capabilities-self", map[string]any{
"paths": []string{pathAllowed, pathDenied},
})
must.NoError(t, err)
must.Eq(t, []any{"read"}, (s.Data[pathAllowed]).([]any))
must.Eq(t, []any{"deny"}, (s.Data[pathDenied]).([]any))
// Derive Vault token with non-existing role.
token, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Role: "test",
Namespace: "default",
})
must.ErrorContains(t, err, `role "test" could not be found`)
}
func TestVaultClient_TokenRenewals(t *testing.T) {
ci.Parallel(t)
v := testutil.NewTestVault(t)
defer v.Stop()
logger := testlog.HCLogger(t)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
v.Config.TaskTokenTTL = "4s"
c, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault Vault: %v", err)
}
c.Start()
defer c.Stop()
// Sleep a little while to ensure that the renewal loop is active
time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
tcr := &vaultapi.TokenCreateRequest{
Policies: []string{"foo", "bar"},
TTL: "2s",
DisplayName: "derived-for-task",
Renewable: new(bool),
}
*tcr.Renewable = true
num := 5
tokens := make([]string, num)
for i := 0; i < num; i++ {
c.client.SetToken(v.Config.Token)
if err := c.client.SetAddress(v.Config.Addr); err != nil {
t.Fatal(err)
}
secret, err := c.client.Auth().Token().Create(tcr)
if err != nil {
t.Fatalf("failed to create vault token: %v", err)
}
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
t.Fatal("failed to derive a wrapped vault token")
}
tokens[i] = secret.Auth.ClientToken
errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
go func(errCh <-chan error) {
for {
select {
case err := <-errCh:
must.NoError(t, err, must.Sprintf("unexpected error while renewing vault token"))
}
}
}(errCh)
}
c.lock.Lock()
length := c.heap.Length()
c.lock.Unlock()
if length != num {
t.Fatalf("bad: Heap length: expected: %d, actual: %d", num, length)
}
time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
for i := 0; i < num; i++ {
if err := c.StopRenewToken(tokens[i]); err != nil {
must.NoError(t, err)
}
}
c.lock.Lock()
length = c.heap.Length()
c.lock.Unlock()
if length != 0 {
t.Fatalf("bad: Heap length: expected: 0, actual: %d", length)
}
}
// TestVaultClient_NamespaceSupport tests that the Vault namespace Config, if present, will result in the
// namespace header being set on the created Vault Vault.
func TestVaultClient_NamespaceSupport(t *testing.T) {
ci.Parallel(t)
tr := true
testNs := "test-namespace"
logger := testlog.HCLogger(t)
conf := structsc.DefaultVaultConfig()
conf.Enabled = &tr
conf.Token = "testvaulttoken"
conf.Namespace = testNs
c, err := NewVaultClient(conf, logger, nil)
must.NoError(t, err)
must.Eq(t, testNs, c.client.Headers().Get(structs.VaultNamespaceHeaderName))
}
func TestVaultClient_Heap(t *testing.T) {
ci.Parallel(t)
tr := true
conf := structsc.DefaultVaultConfig()
conf.Enabled = &tr
conf.Token = "testvaulttoken"
conf.TaskTokenTTL = "10s"
logger := testlog.HCLogger(t)
c, err := NewVaultClient(conf, logger, nil)
if err != nil {
t.Fatal(err)
}
if c == nil {
t.Fatal("failed to create vault Vault")
}
now := time.Now()
renewalReq1 := &vaultClientRenewalRequest{
errCh: make(chan error, 1),
id: "id1",
increment: 10,
}
if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
t.Fatal(err)
}
if !c.isTracked("id1") {
t.Fatalf("id1 should have been tracked")
}
renewalReq2 := &vaultClientRenewalRequest{
errCh: make(chan error, 1),
id: "id2",
increment: 10,
}
if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
t.Fatal(err)
}
if !c.isTracked("id2") {
t.Fatalf("id2 should have been tracked")
}
renewalReq3 := &vaultClientRenewalRequest{
errCh: make(chan error, 1),
id: "id3",
increment: 10,
}
if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
t.Fatal(err)
}
if !c.isTracked("id3") {
t.Fatalf("id3 should have been tracked")
}
// Reading elements should yield id2, id1 and id3 in order
req, _ := c.nextRenewal()
if req != renewalReq2 {
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req)
}
if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil {
t.Fatal(err)
}
req, _ = c.nextRenewal()
if req != renewalReq1 {
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req)
}
if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil {
t.Fatal(err)
}
req, _ = c.nextRenewal()
if req != renewalReq3 {
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req)
}
if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil {
t.Fatal(err)
}
if err := c.StopRenewToken("id1"); err != nil {
t.Fatal(err)
}
if err := c.StopRenewToken("id2"); err != nil {
t.Fatal(err)
}
if err := c.StopRenewToken("id3"); err != nil {
t.Fatal(err)
}
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
if c.isTracked("id1") {
t.Fatalf("id1 should not have been tracked")
}
}
func TestVaultClient_RenewNonRenewableLease(t *testing.T) {
ci.Parallel(t)
v := testutil.NewTestVault(t)
defer v.Stop()
logger := testlog.HCLogger(t)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
v.Config.TaskTokenTTL = "4s"
c, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault Vault: %v", err)
}
c.Start()
defer c.Stop()
// Sleep a little while to ensure that the renewal loop is active
time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
tcr := &vaultapi.TokenCreateRequest{
Policies: []string{"foo", "bar"},
TTL: "2s",
DisplayName: "derived-for-task",
Renewable: new(bool),
}
c.client.SetToken(v.Config.Token)
if err := c.client.SetAddress(v.Config.Addr); err != nil {
t.Fatal(err)
}
secret, err := c.client.Auth().Token().Create(tcr)
if err != nil {
t.Fatalf("failed to create vault token: %v", err)
}
if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
t.Fatal("failed to derive a wrapped vault token")
}
_, err = c.RenewToken(secret.Auth.ClientToken, secret.Auth.LeaseDuration)
if err == nil {
t.Fatalf("expected error, got nil")
} else if !strings.Contains(err.Error(), "lease is not renewable") {
t.Fatalf("expected \"%s\" in error message, got \"%v\"", "lease is not renewable", err)
}
}
func TestVaultClient_RenewNonexistentLease(t *testing.T) {
ci.Parallel(t)
v := testutil.NewTestVault(t)
defer v.Stop()
logger := testlog.HCLogger(t)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
v.Config.TaskTokenTTL = "4s"
c, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault Vault: %v", err)
}
c.Start()
defer c.Stop()
// Sleep a little while to ensure that the renewal loop is active
time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
c.client.SetToken(v.Config.Token)
if err := c.client.SetAddress(v.Config.Addr); err != nil {
t.Fatal(err)
}
_, err = c.RenewToken(c.client.Token(), 10)
if err == nil {
t.Fatalf("expected error, got nil")
// The Vault error message changed between 0.10.2 and 1.0.1
} else if !strings.Contains(err.Error(), "lease not found") && !strings.Contains(err.Error(), "lease is not renewable") {
t.Fatalf("expected \"%s\" or \"%s\" in error message, got \"%v\"", "lease not found", "lease is not renewable", err.Error())
}
}
// TestVaultClient_RenewalTime_Long asserts that for leases over 1m the renewal
// time is jittered.
func TestVaultClient_RenewalTime_Long(t *testing.T) {
ci.Parallel(t)
// highRoller is a randIntn func that always returns the max value
highRoller := func(n int) int {
return n - 1
}
// lowRoller is a randIntn func that always returns the min value (0)
lowRoller := func(int) int {
return 0
}
must.Eq(t, 39*time.Second, renewalTime(highRoller, 60))
must.Eq(t, 20*time.Second, renewalTime(lowRoller, 60))
must.Eq(t, 309*time.Second, renewalTime(highRoller, 600))
must.Eq(t, 290*time.Second, renewalTime(lowRoller, 600))
const days3 = 60 * 60 * 24 * 3
must.Eq(t, (days3/2+9)*time.Second, renewalTime(highRoller, days3))
must.Eq(t, (days3/2-10)*time.Second, renewalTime(lowRoller, days3))
}
// TestVaultClient_RenewalTime_Short asserts that for leases under 1m the renewal
// time is lease/2.
func TestVaultClient_RenewalTime_Short(t *testing.T) {
ci.Parallel(t)
dice := func(int) int {
t.Error("dice should not have been called")
panic("unreachable")
}
must.Eq(t, 29*time.Second, renewalTime(dice, 58))
must.Eq(t, 15*time.Second, renewalTime(dice, 30))
must.Eq(t, 1*time.Second, renewalTime(dice, 2))
}
func TestVaultClient_SetUserAgent(t *testing.T) {
ci.Parallel(t)
conf := structsc.DefaultVaultConfig()
conf.Enabled = pointer.Of(true)
logger := testlog.HCLogger(t)
c, err := NewVaultClient(conf, logger, nil)
must.NoError(t, err)
ua := c.client.Headers().Get("User-Agent")
must.Eq(t, useragent.String(), ua)
}
func TestVaultClient_RenewalConcurrent(t *testing.T) {
ci.Parallel(t)
// Create test server to mock the Vault API.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := vaultapi.Secret{
RequestID: uuid.Generate(),
LeaseID: uuid.Generate(),
Renewable: true,
Data: map[string]any{},
Auth: &vaultapi.SecretAuth{
ClientToken: uuid.Generate(),
Accessor: uuid.Generate(),
LeaseDuration: 300,
},
}
out, err := json.Marshal(resp)
if err != nil {
t.Errorf("failed to generate JWKS json response: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
fmt.Fprintln(w, string(out))
}))
defer ts.Close()
// Start Vault client.
conf := structsc.DefaultVaultConfig()
conf.Addr = ts.URL
conf.Enabled = pointer.Of(true)
vc, err := NewVaultClient(conf, testlog.HCLogger(t), nil)
must.NoError(t, err)
vc.Start()
// Renew token multiple times in parallel.
requests := 100
resultCh := make(chan any)
for i := 0; i < requests; i++ {
go func() {
_, err := vc.RenewToken("token", 30)
resultCh <- err
}()
}
// Collect results with timeout.
timer, stop := helper.NewSafeTimer(3 * time.Second)
defer stop()
for i := 0; i < requests; i++ {
select {
case got := <-resultCh:
must.Nil(t, got, must.Sprintf("token renewal error: %v", got))
case <-timer.C:
t.Fatal("timeout waiting for token renewal")
}
}
}
func TestVaultClient_NamespaceReset(t *testing.T) {
// Mock Vault API that always returns an error
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintln(w, "error")
}))
defer ts.Close()
conf := structsc.DefaultVaultConfig()
conf.Addr = ts.URL
conf.Enabled = pointer.Of(true)
for _, ns := range []string{"", "foo"} {
conf.Namespace = ns
vc, err := NewVaultClient(conf, testlog.HCLogger(t), nil)
must.NoError(t, err)
vc.Start()
_, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: "bogus",
Namespace: "bar",
})
must.Error(t, err)
must.Eq(t, ns, vc.client.Namespace(),
must.Sprintf("expected %q, not %q", ns, vc.client.Namespace()))
}
}