Merge branch 'main' into f-NMD-763-identity

This commit is contained in:
James Rasell
2025-06-24 08:42:33 +01:00
83 changed files with 3203 additions and 2441 deletions

View File

@@ -729,14 +729,35 @@ func (ar *allocRunner) killTasks() map[string]*structs.TaskState {
// run alloc prekill hooks
ar.preKillHooks()
// generate task event for given task runner
taskEventFn := func(tr *taskrunner.TaskRunner) (te *structs.TaskEvent) {
te = structs.NewTaskEvent(structs.TaskKilling).
SetKillTimeout(tr.Task().KillTimeout, ar.clientConfig.MaxKillTimeout)
// if the task is not set failed, the task has not finished,
// the job type is batch, and the allocation is being migrated
// then mark the task as failed. this ensures the task is recreated
// if no eligible nodes are immediately available.
if !tr.TaskState().Failed &&
tr.TaskState().FinishedAt.IsZero() &&
ar.alloc.Job.Type == structs.JobTypeBatch &&
ar.alloc.DesiredTransition.Migrate != nil &&
*ar.alloc.DesiredTransition.Migrate {
ar.logger.Trace("marking migrating batch job task failed on kill", "task_name", tr.Task().Name)
te.SetFailsTask()
}
return
}
// Kill leader first, synchronously
for name, tr := range ar.tasks {
if !tr.IsLeader() {
continue
}
taskEvent := structs.NewTaskEvent(structs.TaskKilling)
taskEvent.SetKillTimeout(tr.Task().KillTimeout, ar.clientConfig.MaxKillTimeout)
taskEvent := taskEventFn(tr)
err := tr.Kill(context.TODO(), taskEvent)
if err != nil && err != taskrunner.ErrTaskNotRunning {
ar.logger.Warn("error stopping leader task", "error", err, "task_name", name)
@@ -758,8 +779,8 @@ func (ar *allocRunner) killTasks() map[string]*structs.TaskState {
wg.Add(1)
go func(name string, tr *taskrunner.TaskRunner) {
defer wg.Done()
taskEvent := structs.NewTaskEvent(structs.TaskKilling)
taskEvent.SetKillTimeout(tr.Task().KillTimeout, ar.clientConfig.MaxKillTimeout)
taskEvent := taskEventFn(tr)
err := tr.Kill(context.TODO(), taskEvent)
if err != nil && err != taskrunner.ErrTaskNotRunning {
ar.logger.Warn("error stopping task", "error", err, "task_name", name)
@@ -782,8 +803,8 @@ func (ar *allocRunner) killTasks() map[string]*structs.TaskState {
wg.Add(1)
go func(name string, tr *taskrunner.TaskRunner) {
defer wg.Done()
taskEvent := structs.NewTaskEvent(structs.TaskKilling)
taskEvent.SetKillTimeout(tr.Task().KillTimeout, ar.clientConfig.MaxKillTimeout)
taskEvent := taskEventFn(tr)
err := tr.Kill(context.TODO(), taskEvent)
if err != nil && err != taskrunner.ErrTaskNotRunning {
ar.logger.Warn("error stopping sidecar task", "error", err, "task_name", name)

View File

@@ -1804,6 +1804,160 @@ func TestAllocRunner_HandlesArtifactFailure(t *testing.T) {
require.True(t, state.TaskStates["bad"].Failed)
}
// Test that alloc runner kills tasks in task group when stopping and
// fails tasks when job is batch job type and migrating
func TestAllocRunner_Migrate_Batch_KillTG(t *testing.T) {
ci.Parallel(t)
alloc := mock.BatchAlloc()
tr := alloc.AllocatedResources.Tasks[alloc.Job.TaskGroups[0].Tasks[0].Name]
alloc.Job.TaskGroups[0].RestartPolicy.Attempts = 0
alloc.Job.TaskGroups[0].Tasks[0].RestartPolicy.Attempts = 0
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config["run_for"] = "10s"
alloc.AllocatedResources.Tasks[task.Name] = tr
task2 := alloc.Job.TaskGroups[0].Tasks[0].Copy()
task2.Name = "task 2"
task2.Driver = "mock_driver"
task2.Config["run_for"] = "1ms"
alloc.Job.TaskGroups[0].Tasks = append(alloc.Job.TaskGroups[0].Tasks, task2)
alloc.AllocatedResources.Tasks[task2.Name] = tr
conf, cleanup := testAllocRunnerConfig(t, alloc)
defer cleanup()
ar, err := NewAllocRunner(conf)
must.NoError(t, err)
defer destroy(ar)
go ar.Run()
upd := conf.StateUpdater.(*MockStateUpdater)
// Wait for running
testutil.WaitForResult(func() (bool, error) {
last := upd.Last()
if last == nil {
return false, fmt.Errorf("No updates")
}
if last.ClientStatus != structs.AllocClientStatusRunning {
return false, fmt.Errorf("got status %v; want %v", last.ClientStatus, structs.AllocClientStatusRunning)
}
return true, nil
}, func(err error) {
must.NoError(t, err)
})
// Wait for completed task
testutil.WaitForResult(func() (bool, error) {
last := upd.Last()
if last == nil {
return false, fmt.Errorf("No updates")
}
if last.ClientStatus != structs.AllocClientStatusRunning {
return false, fmt.Errorf("got status %v; want %v", last.ClientStatus, structs.AllocClientStatusRunning)
}
// task should not have finished yet, task2 should be finished
if !last.TaskStates[task.Name].FinishedAt.IsZero() {
return false, fmt.Errorf("task should not be finished")
}
if last.TaskStates[task2.Name].FinishedAt.IsZero() {
return false, fmt.Errorf("task should be finished")
}
return true, nil
}, func(err error) {
must.NoError(t, err)
})
update := ar.Alloc().Copy()
migrate := true
update.DesiredTransition.Migrate = &migrate
update.DesiredStatus = structs.AllocDesiredStatusStop
ar.Update(update)
testutil.WaitForResult(func() (bool, error) {
last := upd.Last()
if last == nil {
return false, fmt.Errorf("No updates")
}
if last.ClientStatus != structs.AllocClientStatusFailed {
return false, fmt.Errorf("got client status %q; want %q", last.ClientStatus, structs.AllocClientStatusFailed)
}
// task should be failed since it was killed, task2 should not
// be failed since it was already completed
if !last.TaskStates[task.Name].Failed {
return false, fmt.Errorf("task should be failed")
}
if last.TaskStates[task2.Name].Failed {
return false, fmt.Errorf("task should not be failed")
}
return true, nil
}, func(err error) {
must.NoError(t, err)
})
}
// Test that alloc runner kills tasks in task group when stopping and
// does not fail tasks when job is batch job type and not migrating
func TestAllocRunner_Batch_KillTG(t *testing.T) {
ci.Parallel(t)
alloc := mock.BatchAlloc()
tr := alloc.AllocatedResources.Tasks[alloc.Job.TaskGroups[0].Tasks[0].Name]
alloc.Job.TaskGroups[0].RestartPolicy.Attempts = 0
alloc.Job.TaskGroups[0].Tasks[0].RestartPolicy.Attempts = 0
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config["run_for"] = "10s"
alloc.AllocatedResources.Tasks[task.Name] = tr
conf, cleanup := testAllocRunnerConfig(t, alloc)
defer cleanup()
ar, err := NewAllocRunner(conf)
must.NoError(t, err)
defer destroy(ar)
go ar.Run()
upd := conf.StateUpdater.(*MockStateUpdater)
testutil.WaitForResult(func() (bool, error) {
last := upd.Last()
if last == nil {
return false, fmt.Errorf("No updates")
}
if last.ClientStatus != structs.AllocClientStatusRunning {
return false, fmt.Errorf("got status %v; want %v", last.ClientStatus, structs.AllocClientStatusRunning)
}
return true, nil
}, func(err error) {
must.NoError(t, err)
})
update := ar.Alloc().Copy()
update.DesiredStatus = structs.AllocDesiredStatusStop
ar.Update(update)
testutil.WaitForResult(func() (bool, error) {
last := upd.Last()
if last == nil {
return false, fmt.Errorf("No updates")
}
if last.ClientStatus != structs.AllocClientStatusComplete {
return false, fmt.Errorf("got client status %q; want %q", last.ClientStatus, structs.AllocClientStatusComplete)
}
return true, nil
}, func(err error) {
must.NoError(t, err)
})
}
// Test that alloc runner kills tasks in task group when another task fails
func TestAllocRunner_TaskFailed_KillTG(t *testing.T) {
ci.Parallel(t)

View File

@@ -35,8 +35,8 @@ func TestTaskRunner_DisableFileForVaultToken_UpgradePath(t *testing.T) {
// Setup a test Vault client.
token := "1234"
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
return token, true, nil
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
return token, true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)

View File

@@ -1462,9 +1462,9 @@ func TestTaskRunner_BlockForVaultToken(t *testing.T) {
// Control when we get a Vault token
token := "1234"
waitCh := make(chan struct{})
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
<-waitCh
return token, true, nil
return token, true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
@@ -1571,8 +1571,8 @@ func TestTaskRunner_DisableFileForVaultToken(t *testing.T) {
// Setup a test Vault client
token := "1234"
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
return token, true, nil
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
return token, true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)
@@ -1639,13 +1639,13 @@ func TestTaskRunner_DeriveToken_Retry(t *testing.T) {
// Fail on the first attempt to derive a vault token
token := "1234"
count := 0
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
if count > 0 {
return token, true, nil
return token, true, 30, nil
}
count++
return "", false, structs.NewRecoverableError(fmt.Errorf("want a retry"), true)
return "", false, 0, structs.NewRecoverableError(fmt.Errorf("want a retry"), true)
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)
@@ -1741,8 +1741,8 @@ func TestTaskRunner_DeriveToken_Unrecoverable(t *testing.T) {
must.NoError(t, err)
vc.(*vaultclient.MockVaultClient).SetDeriveTokenWithJWTFn(
func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
return "", false, errors.New("unrecoverable")
func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "", false, 0, errors.New("unrecoverable")
},
)
@@ -2076,9 +2076,9 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
// Control when we get a Vault token
waitCh := make(chan struct{}, 1)
defer close(waitCh)
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, error) {
handler := func(ctx context.Context, req vaultclient.JWTLoginRequest) (string, bool, int, error) {
<-waitCh
return "1234", true, nil
return "1234", true, 30, nil
}
vc, err := vaultclient.NewMockVaultClient(structs.VaultDefaultCluster)
must.NoError(t, err)

View File

@@ -238,6 +238,7 @@ func (h *vaultHook) run(token string) {
// updatedToken lets us store state between loops. If true, a new token
// has been retrieved and we need to apply the Vault change mode
var updatedToken bool
leaseDuration := 30
OUTER:
for {
@@ -255,7 +256,7 @@ OUTER:
if token == "" {
// Get a token
var exit bool
token, exit = h.deriveVaultToken()
token, leaseDuration, exit = h.deriveVaultToken()
if exit {
// Exit the manager
return
@@ -289,7 +290,10 @@ OUTER:
//
// If Vault is having availability issues or is overloaded, a large
// number of initial token renews can exacerbate the problem.
renewCh, err := h.client.RenewToken(token, 30)
if leaseDuration == 0 {
leaseDuration = 30
}
renewCh, err := h.client.RenewToken(token, leaseDuration)
// An error returned means the token is not being renewed
if err != nil {
@@ -358,13 +362,17 @@ OUTER:
// deriveVaultToken derives the Vault token using exponential backoffs. It
// returns the Vault token and whether the manager should exit.
func (h *vaultHook) deriveVaultToken() (string, bool) {
func (h *vaultHook) deriveVaultToken() (string, int, bool) {
var attempts uint64
var backoff time.Duration
timer, stopTimer := helper.NewSafeTimer(0)
defer stopTimer()
for {
token, err := h.deriveVaultTokenJWT()
token, lease, err := h.deriveVaultTokenJWT()
if err == nil {
return token, false
return token, lease, false
}
// Check if we can't recover from the error
@@ -374,11 +382,12 @@ func (h *vaultHook) deriveVaultToken() (string, bool) {
structs.NewTaskEvent(structs.TaskKilling).
SetFailsTask().
SetDisplayMessage(fmt.Sprintf("Vault: failed to derive vault token: %v", err)))
return "", true
return "", 0, true
}
// Handle the retry case
backoff = helper.Backoff(vaultBackoffBaseline, vaultBackoffLimit, attempts)
timer.Reset(backoff)
attempts++
h.logger.Error("failed to derive Vault token", "error", err, "recoverable", true, "backoff", backoff)
@@ -386,14 +395,14 @@ func (h *vaultHook) deriveVaultToken() (string, bool) {
// Wait till retrying
select {
case <-h.ctx.Done():
return "", true
case <-time.After(backoff):
return "", 0, true
case <-timer.C:
}
}
}
// deriveVaultTokenJWT returns a Vault ACL token using JWT auth login.
func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
func (h *vaultHook) deriveVaultTokenJWT() (string, int, error) {
// Retrieve signed identity.
signed, err := h.widmgr.Get(structs.WIHandle{
IdentityName: h.widName,
@@ -401,13 +410,13 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
WorkloadType: structs.WorkloadTypeTask,
})
if err != nil {
return "", structs.NewRecoverableError(
return "", 0, structs.NewRecoverableError(
fmt.Errorf("failed to retrieve signed workload identity: %w", err),
true,
)
}
if signed == nil {
return "", structs.NewRecoverableError(
return "", 0, structs.NewRecoverableError(
errors.New("no signed workload identity available"),
false,
)
@@ -419,13 +428,13 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
}
// Derive Vault token with signed identity.
token, renewable, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{
token, renewable, leaseDuration, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{
JWT: signed.JWT,
Role: role,
Namespace: h.vaultBlock.Namespace,
})
if err != nil {
return "", structs.WrapRecoverable(
return "", 0, structs.WrapRecoverable(
fmt.Sprintf("failed to derive Vault token for identity %s: %v", h.widName, err),
err,
)
@@ -437,7 +446,7 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
h.allowTokenExpiration = true
}
return token, nil
return token, leaseDuration, nil
}
// writeToken writes the given token to disk

View File

@@ -460,10 +460,10 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
// Set unrecoverable error.
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
// Cancel the context to simulate the task being killed.
cancel()
return "", false, structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
return "", false, 0, structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
})
err := hook.Prestart(ctx, req, &resp)
@@ -509,16 +509,16 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
// Set recoverable error.
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "", false, structs.NewRecoverableError(errors.New("recoverable test error"), true)
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "", false, 0, structs.NewRecoverableError(errors.New("recoverable test error"), true)
})
go func() {
// Wait a bit for the first error then fix token renewal.
time.Sleep(time.Second)
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "secret", true, nil
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "secret", true, 30, nil
})
}()
@@ -555,8 +555,8 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
// Derive predictable token and fail renew request.
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "secret", true, nil
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, int, error) {
return "secret", true, 30, nil
})
mockVaultClient.SetRenewTokenError("secret", errors.New("test error"))

View File

@@ -50,8 +50,9 @@ type VaultClient interface {
Stop()
// DeriveTokenWithJWT returns a Vault ACL token using the JWT login
// endpoint, along with whether or not the token is renewable.
DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, error)
// endpoint, along with whether or not the token is renewable and its lease
// duration.
DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, int, error)
// RenewToken renews a token with the given increment and adds it to
// the min-heap for periodic renewal.
@@ -237,12 +238,12 @@ func (c *vaultClient) unlockAndUnset() {
}
// DeriveTokenWithJWT returns a Vault ACL token using the JWT login endpoint.
func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) {
func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, int, error) {
if !c.config.IsEnabled() {
return "", false, fmt.Errorf("vault client not enabled")
return "", false, 0, fmt.Errorf("vault client not enabled")
}
if !c.isRunning() {
return "", false, fmt.Errorf("vault client is not running")
return "", false, 0, fmt.Errorf("vault client is not running")
}
c.lock.Lock()
@@ -263,20 +264,20 @@ func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginReques
},
)
if err != nil {
return "", false, fmt.Errorf("failed to login with JWT: %v", err)
return "", false, 0, fmt.Errorf("failed to login with JWT: %v", err)
}
if s == nil {
return "", false, errors.New("JWT login returned an empty secret")
return "", false, 0, errors.New("JWT login returned an empty secret")
}
if s.Auth == nil {
return "", false, errors.New("JWT login did not return a token")
return "", false, 0, errors.New("JWT login did not return a token")
}
for _, w := range s.Warnings {
c.logger.Warn("JWT login warning", "warning", w)
}
return s.Auth.ClientToken, s.Auth.Renewable, nil
return s.Auth.ClientToken, s.Auth.Renewable, s.Auth.LeaseDuration, nil
}
// RenewToken renews the supplied token for a given duration (in seconds) and
@@ -368,6 +369,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
} else {
// Don't set this if renewal fails
leaseDuration = renewResp.Auth.LeaseDuration
req.increment = leaseDuration
}
// Reset the token in the API client before returning

View File

@@ -9,6 +9,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
@@ -218,13 +219,14 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {
// Derive Vault token using signed JWT.
jwtStr := signedWIDs[0].JWT
token, renewable, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
token, renewable, leaseDuration, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Namespace: "default",
})
must.NoError(t, err)
must.NotEq(t, "", token)
must.True(t, renewable)
must.Eq(t, 72*60*60, leaseDuration) // token_period from role
// Verify token has expected properties.
v.Client.SetToken(token)
@@ -259,7 +261,7 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {
must.Eq(t, []any{"deny"}, (s.Data[pathDenied]).([]any))
// Derive Vault token with non-existing role.
token, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
token, _, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Role: "test",
Namespace: "default",
@@ -448,8 +450,14 @@ func TestVaultClient_SetUserAgent(t *testing.T) {
func TestVaultClient_RenewalConcurrent(t *testing.T) {
ci.Parallel(t)
// collects renewal requests that the mock Vault API gets
requestCh := make(chan string, 10)
// Create test server to mock the Vault API.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := io.ReadAll(r.Body)
requestCh <- string(b)
resp := vaultapi.Secret{
RequestID: uuid.Generate(),
LeaseID: uuid.Generate(),
@@ -458,7 +466,7 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) {
Auth: &vaultapi.SecretAuth{
ClientToken: uuid.Generate(),
Accessor: uuid.Generate(),
LeaseDuration: 300,
LeaseDuration: 1, // force a fast renewal
},
}
@@ -482,9 +490,9 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) {
vc.Start()
// Renew token multiple times in parallel.
requests := 100
expectedRenewals := 100
resultCh := make(chan any)
for i := 0; i < requests; i++ {
for range expectedRenewals {
go func() {
_, err := vc.RenewToken("token", 30)
resultCh <- err
@@ -494,12 +502,28 @@ func TestVaultClient_RenewalConcurrent(t *testing.T) {
// Collect results with timeout.
timer, stop := helper.NewSafeTimer(3 * time.Second)
defer stop()
for i := 0; i < requests; i++ {
sawInitial := 0
sawRenew := 0
for {
select {
case got := <-requestCh:
switch got {
case `{"increment":1}`:
sawRenew++
case `{"increment":30}`:
sawInitial++
default:
t.Fatalf("unexpected request body: %q", got)
}
if sawInitial == expectedRenewals && sawRenew >= expectedRenewals {
return
}
case got := <-resultCh:
must.Nil(t, got, must.Sprintf("token renewal error: %v", got))
case <-timer.C:
t.Fatal("timeout waiting for token renewal")
t.Fatalf("timeout waiting for expected token renewals (initial: %d renewed: %d)",
sawInitial, sawRenew)
}
}
}
@@ -524,7 +548,7 @@ func TestVaultClient_NamespaceReset(t *testing.T) {
must.NoError(t, err)
vc.Start()
_, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
_, _, _, err = vc.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: "bogus",
Namespace: "bar",
})

View File

@@ -35,20 +35,22 @@ type MockVaultClient struct {
// deriveTokenWithJWTFn allows the caller to control the DeriveTokenWithJWT
// function.
deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, error)
deriveTokenWithJWTFn func(context.Context, JWTLoginRequest) (string, bool, int, error)
// renewable determines if the tokens returned should be marked as renewable
renewable bool
duration int
mu sync.Mutex
}
// NewMockVaultClient returns a MockVaultClient for testing
func NewMockVaultClient(_ string) (VaultClient, error) {
return &MockVaultClient{renewable: true}, nil
return &MockVaultClient{renewable: true, duration: 30}, nil
}
func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) {
func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, int, error) {
vc.mu.Lock()
defer vc.mu.Unlock()
@@ -65,7 +67,7 @@ func (vc *MockVaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginR
token = fmt.Sprintf("%s-%s", token, req.Role)
}
vc.jwtTokens[req.JWT] = token
return token, vc.renewable, nil
return token, vc.renewable, vc.duration, nil
}
func (vc *MockVaultClient) SetDeriveTokenError(allocID string, tasks []string, err error) {
@@ -161,7 +163,7 @@ func (vc *MockVaultClient) RenewTokenErrCh(token string) chan error {
}
// SetDeriveTokenWithJWTFn sets the function used to derive tokens using JWT.
func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, error)) {
func (vc *MockVaultClient) SetDeriveTokenWithJWTFn(f func(context.Context, JWTLoginRequest) (string, bool, int, error)) {
vc.mu.Lock()
defer vc.mu.Unlock()
vc.deriveTokenWithJWTFn = f