mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
When we renew Vault tokens, we use the lease duration to determine how often to renew. But we also set an `increment` value which is never updated from the initial 30s. For periodic tokens this is not a problem because the `increment` field is ignored on renewal. But for non-periodic tokens this prevents the token TTL from being properly incremented. This behavior has been in place since the initial Vault client implementation in #1606 but before the switch to workload identity most (all?) tokens being created were periodic tokens so this was never detected. Fix this bug by updating the request's `increment` field to the lease duration on each renewal. Also switch out a `time.After` call in backoff of the derive token caller with a safe timer so that we don't have to spawn a new goroutine per loop, and have tighter control over when that's GC'd. Ref: https://github.com/hashicorp/nomad/pull/1606 Ref: https://github.com/hashicorp/nomad/issues/25812
560 lines
14 KiB
Go
560 lines
14 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package vaultclient
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"text/template"
|
|
"time"
|
|
|
|
josejwt "github.com/go-jose/go-jose/v3/jwt"
|
|
"github.com/hashicorp/nomad/ci"
|
|
"github.com/hashicorp/nomad/client/widmgr"
|
|
"github.com/hashicorp/nomad/helper"
|
|
"github.com/hashicorp/nomad/helper/pointer"
|
|
"github.com/hashicorp/nomad/helper/testlog"
|
|
"github.com/hashicorp/nomad/helper/useragent"
|
|
"github.com/hashicorp/nomad/helper/uuid"
|
|
"github.com/hashicorp/nomad/nomad/mock"
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
structsc "github.com/hashicorp/nomad/nomad/structs/config"
|
|
"github.com/hashicorp/nomad/testutil"
|
|
"github.com/hashicorp/vault/api"
|
|
vaultapi "github.com/hashicorp/vault/api"
|
|
"github.com/shoenig/test/must"
|
|
)
|
|
|
|
const (
	// jwtAuthMountPathTest is the path the JWT auth method is mounted at in the
	// test Vault cluster.
	jwtAuthMountPathTest = "jwt_test"

	// jwtAuthConfigTemplate is the JSON configuration written to the JWT auth
	// method. It is rendered by renderVaultTemplate, which uses "<<" and ">>"
	// as delimiters; .JWKSURL is the only field it references.
	jwtAuthConfigTemplate = `
{
  "jwks_url": "<<.JWKSURL>>",
  "jwt_supported_algs": ["RS256", "EdDSA"],
  "default_role": "nomad-workloads"
}
`

	// widVaultPolicyTemplate is the ACL policy written as "nomad-workloads".
	// Only <<.JWTAuthAccessorID>> is substituted when it is rendered; the
	// "{{...}}" sequences are not template actions here (the delimiters are
	// "<<" and ">>") and are emitted verbatim.
	widVaultPolicyTemplate = `
path "secret/data/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_job_id}}/*" {
  capabilities = ["read"]
}

path "secret/data/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_job_id}}" {
  capabilities = ["read"]
}

path "secret/metadata/{{identity.entity.aliases.<<.JWTAuthAccessorID>>.metadata.nomad_namespace}}/*" {
  capabilities = ["list"]
}

path "secret/metadata/*" {
  capabilities = ["list"]
}
`

	// widVaultRole is the JSON body of the "nomad-workloads" role written to
	// the JWT auth method. It is written as-is, without template rendering.
	// TestVaultClient_DeriveTokenWithJWT relies on token_period (72h) being
	// reported as the lease duration of derived tokens.
	widVaultRole = `
{
  "role_type": "jwt",
  "bound_audiences": "vault.io",
  "user_claim": "/nomad_job_id",
  "user_claim_json_pointer": true,
  "claim_mappings": {
    "nomad_namespace": "nomad_namespace",
    "nomad_job_id": "nomad_job_id"
  },
  "token_ttl": "30m",
  "token_type": "service",
  "token_period": "72h",
  "token_policies": ["nomad-workloads"]
}
`
)
|
|
|
|
// renderVaultTemplate parses tmplStr as a text/template that uses "<<" and
// ">>" as action delimiters and executes it against data, returning the
// rendered bytes.
//
// The non-default delimiters mean that any "{{...}}" sequence in tmplStr is
// treated as literal text and copied to the output unchanged.
//
// An error is returned if the template fails to parse or to execute; the
// returned slice is nil in that case.
func renderVaultTemplate(tmplStr string, data any) ([]byte, error) {
	// The helper renders more than policies (it is also used for the JWT auth
	// method config), so the template name and error text stay generic.
	tmpl, err := template.New("vault").
		Delims("<<", ">>").
		Parse(tmplStr)
	if err != nil {
		return nil, fmt.Errorf("failed to parse Vault template: %w", err)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return nil, fmt.Errorf("failed to render Vault template: %w", err)
	}

	return buf.Bytes(), nil
}
|
|
|
|
func setupVaultForWorkloadIdentity(v *testutil.TestVault, jwksURL string) error {
|
|
logical := v.Client.Logical()
|
|
sys := v.Client.Sys()
|
|
ctx := context.Background()
|
|
|
|
// Enable JWT auth method.
|
|
err := sys.EnableAuthWithOptions(jwtAuthMountPathTest, &api.MountInput{
|
|
Type: "jwt",
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to enable JWT auth method: %w", err)
|
|
}
|
|
|
|
secret, err := logical.Read(fmt.Sprintf("sys/auth/%s", jwtAuthMountPathTest))
|
|
jwtAuthAccessor := secret.Data["accessor"].(string)
|
|
|
|
// Write JWT auth method config.
|
|
jwtAuthConfigData := struct {
|
|
JWKSURL string
|
|
}{
|
|
JWKSURL: jwksURL,
|
|
}
|
|
jwtAuthConfig, err := renderVaultTemplate(jwtAuthConfigTemplate, jwtAuthConfigData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = logical.WriteBytesWithContext(ctx, fmt.Sprintf("auth/%s/config", jwtAuthMountPathTest), jwtAuthConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write JWT auth method config: %w", err)
|
|
}
|
|
|
|
// Write Nomad workload policy.
|
|
data := struct {
|
|
JWTAuthAccessorID string
|
|
}{
|
|
JWTAuthAccessorID: jwtAuthAccessor,
|
|
}
|
|
policy, err := renderVaultTemplate(widVaultPolicyTemplate, data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
encoded := base64.StdEncoding.EncodeToString(policy)
|
|
policyReqBody := fmt.Sprintf(`{"policy": "%s"}`, encoded)
|
|
|
|
policyPath := "sys/policies/acl/nomad-workloads"
|
|
_, err = logical.WriteBytesWithContext(ctx, policyPath, []byte(policyReqBody))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write policy: %w", err)
|
|
}
|
|
|
|
// Write Nomad workload role.
|
|
rolePath := fmt.Sprintf("auth/%s/role/nomad-workloads", jwtAuthMountPathTest)
|
|
_, err = logical.WriteBytesWithContext(ctx, rolePath, []byte(widVaultRole))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write role: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
// Create signer and signed identities.
|
|
alloc := mock.MinAlloc()
|
|
task := alloc.Job.TaskGroups[0].Tasks[0]
|
|
task.Identities = []*structs.WorkloadIdentity{
|
|
{
|
|
Name: "vault_default",
|
|
Audience: []string{"vault.io"},
|
|
TTL: time.Second,
|
|
},
|
|
}
|
|
|
|
signer := widmgr.NewMockWIDSigner(task.Identities)
|
|
signedWIDs, err := signer.SignIdentities(1, []*structs.WorkloadIdentityRequest{
|
|
{
|
|
AllocID: alloc.ID,
|
|
WIHandle: structs.WIHandle{
|
|
IdentityName: task.Identities[0].Name,
|
|
WorkloadIdentifier: task.Name,
|
|
WorkloadType: structs.WorkloadTypeTask,
|
|
},
|
|
},
|
|
})
|
|
must.NoError(t, err)
|
|
must.Len(t, 1, signedWIDs)
|
|
|
|
// Setup test JWKS server.
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
out, err := json.Marshal(signer.JSONWebKeySet())
|
|
if err != nil {
|
|
t.Errorf("failed to generate JWKS json response: %v", err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
fmt.Fprintln(w, string(out))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
// Start and configure Vault cluster for JWT authentication.
|
|
v := testutil.NewTestVault(t)
|
|
defer v.Stop()
|
|
|
|
err = setupVaultForWorkloadIdentity(v, ts.URL)
|
|
must.NoError(t, err)
|
|
|
|
// Start Vault client.
|
|
logger := testlog.HCLogger(t)
|
|
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
|
|
v.Config.JWTAuthBackendPath = jwtAuthMountPathTest
|
|
|
|
c, err := NewVaultClient(v.Config, logger)
|
|
must.NoError(t, err)
|
|
|
|
c.Start()
|
|
defer c.Stop()
|
|
|
|
// Derive Vault token using signed JWT.
|
|
jwtStr := signedWIDs[0].JWT
|
|
token, renewable, leaseDuration, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
|
|
JWT: jwtStr,
|
|
Namespace: "default",
|
|
})
|
|
must.NoError(t, err)
|
|
must.NotEq(t, "", token)
|
|
must.True(t, renewable)
|
|
must.Eq(t, 72*60*60, leaseDuration) // token_period from role
|
|
|
|
// Verify token has expected properties.
|
|
v.Client.SetToken(token)
|
|
s, err := v.Client.Logical().Read("auth/token/lookup-self")
|
|
must.NoError(t, err)
|
|
|
|
jwt, err := josejwt.ParseSigned(jwtStr)
|
|
must.NoError(t, err)
|
|
|
|
claims := make(map[string]any)
|
|
err = jwt.UnsafeClaimsWithoutVerification(&claims)
|
|
must.NoError(t, err)
|
|
|
|
must.Eq(t, "service", s.Data["type"].(string))
|
|
must.True(t, s.Data["renewable"].(bool))
|
|
must.SliceContains(t, s.Data["policies"].([]any), "nomad-workloads")
|
|
must.MapEq(t, map[string]any{
|
|
"nomad_namespace": claims["nomad_namespace"],
|
|
"nomad_job_id": claims["nomad_job_id"],
|
|
"role": "nomad-workloads",
|
|
}, s.Data["meta"].(map[string]any))
|
|
|
|
// Verify token has the expected permissions.
|
|
pathAllowed := fmt.Sprintf("secret/data/%s/%s/a", claims["nomad_namespace"], claims["nomad_job_id"])
|
|
pathDenied := "secret/data/denied"
|
|
|
|
s, err = v.Client.Logical().Write("sys/capabilities-self", map[string]any{
|
|
"paths": []string{pathAllowed, pathDenied},
|
|
})
|
|
must.NoError(t, err)
|
|
must.Eq(t, []any{"read"}, (s.Data[pathAllowed]).([]any))
|
|
must.Eq(t, []any{"deny"}, (s.Data[pathDenied]).([]any))
|
|
|
|
// Derive Vault token with non-existing role.
|
|
token, _, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
|
|
JWT: jwtStr,
|
|
Role: "test",
|
|
Namespace: "default",
|
|
})
|
|
must.ErrorContains(t, err, `role "test" could not be found`)
|
|
}
|
|
|
|
// TestVaultClient_NamespaceSupport tests that the Vault namespace Config, if
|
|
// present, will result in the namespace header being set on the created Vault
|
|
// client.
|
|
func TestVaultClient_NamespaceSupport(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
tr := true
|
|
testNs := "test-namespace"
|
|
|
|
logger := testlog.HCLogger(t)
|
|
|
|
conf := structsc.DefaultVaultConfig()
|
|
conf.Enabled = &tr
|
|
conf.Namespace = testNs
|
|
c, err := NewVaultClient(conf, logger)
|
|
must.NoError(t, err)
|
|
must.Eq(t, testNs, c.client.Headers().Get(structs.VaultNamespaceHeaderName))
|
|
}
|
|
|
|
func TestVaultClient_Heap(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
tr := true
|
|
conf := structsc.DefaultVaultConfig()
|
|
conf.Enabled = &tr
|
|
|
|
logger := testlog.HCLogger(t)
|
|
c, err := NewVaultClient(conf, logger)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if c == nil {
|
|
t.Fatal("failed to create vault Vault")
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
renewalReq1 := &vaultClientRenewalRequest{
|
|
errCh: make(chan error, 1),
|
|
id: "id1",
|
|
increment: 10,
|
|
}
|
|
if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !c.isTracked("id1") {
|
|
t.Fatalf("id1 should have been tracked")
|
|
}
|
|
|
|
renewalReq2 := &vaultClientRenewalRequest{
|
|
errCh: make(chan error, 1),
|
|
id: "id2",
|
|
increment: 10,
|
|
}
|
|
if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !c.isTracked("id2") {
|
|
t.Fatalf("id2 should have been tracked")
|
|
}
|
|
|
|
renewalReq3 := &vaultClientRenewalRequest{
|
|
errCh: make(chan error, 1),
|
|
id: "id3",
|
|
increment: 10,
|
|
}
|
|
if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !c.isTracked("id3") {
|
|
t.Fatalf("id3 should have been tracked")
|
|
}
|
|
|
|
// Reading elements should yield id2, id1 and id3 in order
|
|
req, _ := c.nextRenewal()
|
|
if req != renewalReq2 {
|
|
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req)
|
|
}
|
|
if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
req, _ = c.nextRenewal()
|
|
if req != renewalReq1 {
|
|
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req)
|
|
}
|
|
if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
req, _ = c.nextRenewal()
|
|
if req != renewalReq3 {
|
|
t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req)
|
|
}
|
|
if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := c.StopRenewToken("id1"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := c.StopRenewToken("id2"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := c.StopRenewToken("id3"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if c.isTracked("id1") {
|
|
t.Fatalf("id1 should not have been tracked")
|
|
}
|
|
|
|
if c.isTracked("id1") {
|
|
t.Fatalf("id1 should not have been tracked")
|
|
}
|
|
|
|
if c.isTracked("id1") {
|
|
t.Fatalf("id1 should not have been tracked")
|
|
}
|
|
|
|
}
|
|
|
|
// TestVaultClient_RenewalTime_Long asserts that for leases over 1m the renewal
|
|
// time is jittered.
|
|
func TestVaultClient_RenewalTime_Long(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
// highRoller is a randIntn func that always returns the max value
|
|
highRoller := func(n int) int {
|
|
return n - 1
|
|
}
|
|
|
|
// lowRoller is a randIntn func that always returns the min value (0)
|
|
lowRoller := func(int) int {
|
|
return 0
|
|
}
|
|
|
|
must.Eq(t, 39*time.Second, renewalTime(highRoller, 60))
|
|
must.Eq(t, 20*time.Second, renewalTime(lowRoller, 60))
|
|
|
|
must.Eq(t, 309*time.Second, renewalTime(highRoller, 600))
|
|
must.Eq(t, 290*time.Second, renewalTime(lowRoller, 600))
|
|
|
|
const days3 = 60 * 60 * 24 * 3
|
|
must.Eq(t, (days3/2+9)*time.Second, renewalTime(highRoller, days3))
|
|
must.Eq(t, (days3/2-10)*time.Second, renewalTime(lowRoller, days3))
|
|
}
|
|
|
|
// TestVaultClient_RenewalTime_Short asserts that for leases under 1m the renewal
|
|
// time is lease/2.
|
|
func TestVaultClient_RenewalTime_Short(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
dice := func(int) int {
|
|
t.Error("dice should not have been called")
|
|
panic("unreachable")
|
|
}
|
|
|
|
must.Eq(t, 29*time.Second, renewalTime(dice, 58))
|
|
must.Eq(t, 15*time.Second, renewalTime(dice, 30))
|
|
must.Eq(t, 1*time.Second, renewalTime(dice, 2))
|
|
}
|
|
|
|
func TestVaultClient_SetUserAgent(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
conf := structsc.DefaultVaultConfig()
|
|
conf.Enabled = pointer.Of(true)
|
|
logger := testlog.HCLogger(t)
|
|
c, err := NewVaultClient(conf, logger)
|
|
must.NoError(t, err)
|
|
|
|
ua := c.client.Headers().Get("User-Agent")
|
|
must.Eq(t, useragent.String(), ua)
|
|
}
|
|
|
|
func TestVaultClient_RenewalConcurrent(t *testing.T) {
|
|
ci.Parallel(t)
|
|
|
|
// collects renewal requests that the mock Vault API gets
|
|
requestCh := make(chan string, 10)
|
|
|
|
// Create test server to mock the Vault API.
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
b, _ := io.ReadAll(r.Body)
|
|
requestCh <- string(b)
|
|
|
|
resp := vaultapi.Secret{
|
|
RequestID: uuid.Generate(),
|
|
LeaseID: uuid.Generate(),
|
|
Renewable: true,
|
|
Data: map[string]any{},
|
|
Auth: &vaultapi.SecretAuth{
|
|
ClientToken: uuid.Generate(),
|
|
Accessor: uuid.Generate(),
|
|
LeaseDuration: 1, // force a fast renewal
|
|
},
|
|
}
|
|
|
|
out, err := json.Marshal(resp)
|
|
if err != nil {
|
|
t.Errorf("failed to generate JWKS json response: %v", err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
fmt.Fprintln(w, string(out))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
// Start Vault client.
|
|
conf := structsc.DefaultVaultConfig()
|
|
conf.Addr = ts.URL
|
|
conf.Enabled = pointer.Of(true)
|
|
|
|
vc, err := NewVaultClient(conf, testlog.HCLogger(t))
|
|
must.NoError(t, err)
|
|
vc.Start()
|
|
|
|
// Renew token multiple times in parallel.
|
|
expectedRenewals := 100
|
|
resultCh := make(chan any)
|
|
for range expectedRenewals {
|
|
go func() {
|
|
_, err := vc.RenewToken("token", 30)
|
|
resultCh <- err
|
|
}()
|
|
}
|
|
|
|
// Collect results with timeout.
|
|
timer, stop := helper.NewSafeTimer(3 * time.Second)
|
|
defer stop()
|
|
|
|
sawInitial := 0
|
|
sawRenew := 0
|
|
for {
|
|
select {
|
|
case got := <-requestCh:
|
|
switch got {
|
|
case `{"increment":1}`:
|
|
sawRenew++
|
|
case `{"increment":30}`:
|
|
sawInitial++
|
|
default:
|
|
t.Fatalf("unexpected request body: %q", got)
|
|
}
|
|
if sawInitial == expectedRenewals && sawRenew >= expectedRenewals {
|
|
return
|
|
}
|
|
case got := <-resultCh:
|
|
must.Nil(t, got, must.Sprintf("token renewal error: %v", got))
|
|
case <-timer.C:
|
|
t.Fatalf("timeout waiting for expected token renewals (initial: %d renewed: %d)",
|
|
sawInitial, sawRenew)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestVaultClient_NamespaceReset(t *testing.T) {
|
|
|
|
// Mock Vault API that always returns an error
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
fmt.Fprintln(w, "error")
|
|
}))
|
|
defer ts.Close()
|
|
|
|
conf := structsc.DefaultVaultConfig()
|
|
conf.Addr = ts.URL
|
|
conf.Enabled = pointer.Of(true)
|
|
|
|
for _, ns := range []string{"", "foo"} {
|
|
conf.Namespace = ns
|
|
|
|
vc, err := NewVaultClient(conf, testlog.HCLogger(t))
|
|
must.NoError(t, err)
|
|
vc.Start()
|
|
|
|
_, _, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
|
|
JWT: "bogus",
|
|
Namespace: "bar",
|
|
})
|
|
must.Error(t, err)
|
|
must.Eq(t, ns, vc.client.Namespace(),
|
|
must.Sprintf("expected %q, not %q", ns, vc.client.Namespace()))
|
|
}
|
|
}
|