encrypter: Remove tracking of cancelation for decrypt tasks. (#25795)

New wrapped keys were added to the encrypter and tracked using
their keyID with the context cancelation function. This tracking
was performed primarily so the FSM could load its known key
objects and logs with entries for the same ID superseding existing
decryption tasks. This is a hard to reason about approach and in
theory can cause timing problems in conjunction with the locking.

The new approach still tracks decryption tasks but does not store
the cancelation context. This context is now controlled within a
single function in an attempt to provide a clearer workflow. In
the event two calls for the same key are made in close succession
meaning there is no entry in the keyring for the key yet, all
tasks will be launched. The first-past-the-post will write the
cipher to encrypter state, the second task will complete but not
write the cipher.
This commit is contained in:
James Rasell
2025-05-07 14:35:24 +01:00
committed by GitHub
parent cb09696b1c
commit 296d03d9dd
3 changed files with 397 additions and 78 deletions

3
.changelog/25795.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:bug
encrypter: Refactor startup decryption task handling to avoid timing problems with task addition on FSM restore
```

View File

@@ -58,9 +58,22 @@ type Encrypter struct {
// issuer is the OIDC Issuer to use for workload identities if configured
issuer string
keyring map[string]*cipherSet
decryptTasks map[string]context.CancelFunc
lock sync.RWMutex
// keyring stores the cipher material indexed by the key ID. keyringLock
// must be used in a fine-grained manner to access this to ensure safety and
// provide availability and performance optimizations.
keyring map[string]*cipherSet
keyringLock sync.RWMutex
// decryptTasks tracks the currently running decryption tasks. A task
// represents a key having one or more decryptWrappedKeyTask functions
// running. decryptTasksLock must be used when accessing this map and is
// distinct to keyringLock due to their responsibilities.
//
// The nature and design of the encrypter as well as other Nomad systems
// means we can have more than 1 task attempting to decrypt the same key. In
// this case, we adopt first-past-the-post.
decryptTasks map[string]struct{}
decryptTasksLock sync.RWMutex
}
// cipherSet contains the key material for variable encryption and workload
@@ -86,7 +99,7 @@ func NewEncrypter(srv *Server, keystorePath string) (*Encrypter, error) {
keyring: make(map[string]*cipherSet),
issuer: srv.GetConfig().OIDCIssuer,
providerConfigs: map[string]*structs.KEKProviderConfig{},
decryptTasks: map[string]context.CancelFunc{},
decryptTasks: map[string]struct{}{},
}
providerConfigs, err := getProviderConfigs(srv)
@@ -162,9 +175,9 @@ func (e *Encrypter) loadKeystore() error {
return nil
}
e.lock.RLock()
e.keyringLock.RLock()
_, ok := e.keyring[id]
e.lock.RUnlock()
e.keyringLock.RUnlock()
if ok {
return nil // already loaded this key from another file
}
@@ -193,8 +206,8 @@ func (e *Encrypter) loadKeystore() error {
// IsReady blocks until all decrypt tasks are complete, or the context expires.
func (e *Encrypter) IsReady(ctx context.Context) error {
err := helper.WithBackoffFunc(ctx, time.Millisecond*100, time.Second, func() error {
e.lock.RLock()
defer e.lock.RUnlock()
e.decryptTasksLock.RLock()
defer e.decryptTasksLock.RUnlock()
if len(e.decryptTasks) != 0 {
keyIDs := []string{}
for keyID := range e.decryptTasks {
@@ -364,8 +377,8 @@ func (e *Encrypter) AddUnwrappedKey(rootKey *structs.UnwrappedRootKey, isUpgrade
}
// AddWrappedKey creates decryption tasks for keys we've previously stored in
// Raft. It's only called as a goroutine by the FSM Apply for WrappedRootKeys,
// but it returns an error for ease of testing.
// Raft. It's only called as a goroutine by the FSM Apply for WrappedRootKeys
// and RootKeyMeta. It returns an error for ease of testing.
func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.RootKey) error {
// If the passed root key does not contain any wrapped keys, it has no
@@ -382,35 +395,44 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
logger := e.log.With("key_id", wrappedKeys.KeyID)
e.lock.Lock()
e.keyringLock.Lock()
_, err := e.cipherSetByIDLocked(wrappedKeys.KeyID)
if err == nil {
// key material for each key ID is immutable so nothing to do, but we
// can remove any running decrypt tasks as we no longer need these to finish
// before considering the encrypter ready for this cipher.
if _, err := e.cipherSetByIDLocked(wrappedKeys.KeyID); err == nil {
e.decryptTasksLock.Lock()
delete(e.decryptTasks, wrappedKeys.KeyID)
e.decryptTasksLock.Unlock()
// key material for each key ID is immutable so nothing to do, but we
// can cancel and remove any running decrypt tasks
if cancel, ok := e.decryptTasks[wrappedKeys.KeyID]; ok {
cancel()
delete(e.decryptTasks, wrappedKeys.KeyID)
}
e.lock.Unlock()
e.keyringLock.Unlock()
return nil
}
if cancel, ok := e.decryptTasks[wrappedKeys.KeyID]; ok {
// stop any previous tasks for this same key ID under the assumption
// they're broken or being superseded, but don't remove the CancelFunc
// from the map yet so that other callers don't think we can continue
cancel()
}
e.lock.Unlock()
completeCtx, cancel := context.WithCancel(ctx)
e.keyringLock.Unlock()
var mErr *multierror.Error
// Generate a context and cancel function for this key. It will be used by
// all tasks, so the first-past-the-post cipher winner can cancel the work
// of the other tasks in a timely manner.
completeCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Track the total number of decrypt tasks we have started for this key. It
// helps us cancel the context if we did not successfully launch a task.
decryptTasks := 0
// Use a channel to receive the cipherSet from the decrypter goroutines. It
// allows us to fan-out decryption tasks for HA in Nomad Enterprise.
cipherSetCh := make(chan *cipherSet)
// We will use the key ID to track the decrypt tasks for this key. Doing
// this here means we can do this once per function call.
e.decryptTasksLock.Lock()
e.decryptTasks[wrappedKeys.KeyID] = struct{}{}
e.decryptTasksLock.Unlock()
for _, wrappedKey := range wrappedKeys.WrappedKeys {
providerID := wrappedKey.ProviderID
if providerID == "" {
@@ -434,24 +456,64 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
}
// fan-out decryption tasks for HA in Nomad Enterprise. we can use the
// key whenever any one provider returns a successful decryption
go e.decryptWrappedKeyTask(completeCtx, cancel, wrapper, provider, wrappedKeys.Meta(), wrappedKey)
// key whenever any one provider returns a successful decryption.
go e.decryptWrappedKeyTask(completeCtx, wrapper, wrappedKeys.Meta(), wrappedKey, cipherSetCh)
decryptTasks++
}
e.lock.Lock()
defer e.lock.Unlock()
if err := mErr.ErrorOrNil(); err != nil {
e.decryptTasks[wrappedKeys.KeyID] = cancel
err = mErr.ErrorOrNil()
if err != nil {
// If we have no tasks running, we can log an error for the operator and
// exit.
//
// It is likely any decryption configuration for the key is incorrect
// and follow-up attempts from other Raft/FMS calls for this key will
// also fail. We should not, however, continue with the server startup
// without this key, and therefore we do not delete any added tracking.
if decryptTasks == 0 {
cancel()
logger.Error("root key cannot be decrypted", "error", err)
return err
}
logger.Error("root key cannot be decrypted", "error", err)
return err
// If we have at least one task running, we can log a warning for the
// operator but continue to wait for the other tasks to complete.
logger.Warn("root key cannot be decrypted by some KMS providers",
"error", err)
}
// The routine will now wait until the server tells us to stop or until we
// successfully decrypt the key. The context's cancellation function has a
// deferred call above, so we don't need to worry about it here unless we
// want to preemptively cancel in-flight tasks.
select {
case <-completeCtx.Done():
// In this event, the server is shutting down and the agent process will
// exit. This means the decrypter state will be lost, so there is no
// need to tidy the decryption task state.
return completeCtx.Err()
case generatedCipher := <-cipherSetCh:
// By reaching this point, we have a decrypted cipher. No errors can
// occur from here on, so start by telling any other running work to
// stop.
cancel()
// Write the cipher to the encrypter keyring. We could just overwrite an
// existing entry, but there is no harm in checking this first as we
// need a write lock anyway.
e.keyringLock.Lock()
if _, ok := e.keyring[wrappedKeys.KeyID]; !ok {
e.keyring[wrappedKeys.KeyID] = generatedCipher
}
e.keyringLock.Unlock()
// We can now remove the decrypt task from the tracking map indicating
// this key is ready for use.
e.decryptTasksLock.Lock()
delete(e.decryptTasks, wrappedKeys.KeyID)
e.decryptTasksLock.Unlock()
}
return nil
@@ -459,8 +521,13 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
// decryptWrappedKeyTask attempts to decrypt a wrapped key. It blocks until
// successful or until the context is canceled (another task completes or the
// server shuts down). The error returned is only for testing and diagnostics.
func (e *Encrypter) decryptWrappedKeyTask(ctx context.Context, cancel context.CancelFunc, wrapper kms.Wrapper, provider *structs.KEKProviderConfig, meta *structs.RootKeyMeta, wrappedKey *structs.WrappedKey) error {
// server shuts down) and the resulting cipher will be sent back via the respCh
// channel.
//
// The error returned is only for testing and diagnostics.
func (e *Encrypter) decryptWrappedKeyTask(
ctx context.Context, wrapper kms.Wrapper, meta *structs.RootKeyMeta,
wrappedKey *structs.WrappedKey, respCh chan *cipherSet) error {
var key []byte
var rsaKey []byte
@@ -508,8 +575,10 @@ func (e *Encrypter) decryptWrappedKeyTask(ctx context.Context, cancel context.Ca
RSAKey: rsaKey,
}
var generatedCipher *cipherSet
err = helper.WithBackoffFunc(ctx, minBackoff, maxBackoff, func() error {
err := e.addCipher(rootKey)
generatedCipher, err = e.generateCipher(rootKey)
if err != nil {
err := fmt.Errorf("could not add cipher: %w", err)
e.log.Error(err.Error(), "key_id", meta.KeyID)
@@ -521,40 +590,60 @@ func (e *Encrypter) decryptWrappedKeyTask(ctx context.Context, cancel context.Ca
return err
}
e.lock.Lock()
defer e.lock.Unlock()
cancel()
delete(e.decryptTasks, meta.KeyID)
// Send the cipher to the response channel or exit if the context is
// canceled.
//
// The context is canceled when the server is shutting down or when another
// task decrypting the same key completes.
select {
case <-ctx.Done():
return ctx.Err()
case respCh <- generatedCipher:
}
return nil
}
// addCipher creates a new cipherSet for the key and stores them in the keyring
func (e *Encrypter) addCipher(rootKey *structs.UnwrappedRootKey) error {
if rootKey == nil || rootKey.Meta == nil {
return fmt.Errorf("missing metadata")
generatedCipher, err := e.generateCipher(rootKey)
if err != nil {
return err
}
var aead cipher.AEAD
e.keyringLock.Lock()
defer e.keyringLock.Unlock()
e.keyring[rootKey.Meta.KeyID] = generatedCipher
return nil
}
func (e *Encrypter) generateCipher(rootKey *structs.UnwrappedRootKey) (*cipherSet, error) {
if rootKey == nil || rootKey.Meta == nil {
return nil, fmt.Errorf("missing metadata")
}
var aeadCipher cipher.AEAD
switch rootKey.Meta.Algorithm {
case structs.EncryptionAlgorithmAES256GCM:
block, err := aes.NewCipher(rootKey.Key)
if err != nil {
return fmt.Errorf("could not create cipher: %v", err)
return nil, fmt.Errorf("could not create cipher: %v", err)
}
aead, err = cipher.NewGCM(block)
aeadCipher, err = cipher.NewGCM(block)
if err != nil {
return fmt.Errorf("could not create cipher: %v", err)
return nil, fmt.Errorf("could not create cipher: %v", err)
}
default:
return fmt.Errorf("invalid algorithm %s", rootKey.Meta.Algorithm)
return nil, fmt.Errorf("invalid algorithm %s", rootKey.Meta.Algorithm)
}
ed25519Key := ed25519.NewKeyFromSeed(rootKey.Key)
cs := cipherSet{
rootKey: rootKey,
cipher: aead,
cipher: aeadCipher,
eddsaPrivateKey: ed25519Key,
}
@@ -563,17 +652,14 @@ func (e *Encrypter) addCipher(rootKey *structs.UnwrappedRootKey) error {
if len(rootKey.RSAKey) > 0 {
rsaKey, err := x509.ParsePKCS1PrivateKey(rootKey.RSAKey)
if err != nil {
return fmt.Errorf("error parsing rsa key: %w", err)
return nil, fmt.Errorf("error parsing rsa key: %w", err)
}
cs.rsaPrivateKey = rsaKey
cs.rsaPKCS1PublicKey = x509.MarshalPKCS1PublicKey(&rsaKey.PublicKey)
}
e.lock.Lock()
defer e.lock.Unlock()
e.keyring[rootKey.Meta.KeyID] = &cs
return nil
return &cs, nil
}
// waitForKey retrieves the key material by ID from the keyring, retrying with
@@ -583,8 +669,8 @@ func (e *Encrypter) waitForKey(ctx context.Context, keyID string) (*cipherSet, e
err := helper.WithBackoffFunc(ctx, 50*time.Millisecond, 100*time.Millisecond,
func() error {
e.lock.RLock()
defer e.lock.RUnlock()
e.keyringLock.RLock()
defer e.keyringLock.RUnlock()
var err error
ks, err = e.cipherSetByIDLocked(keyID)
if err != nil {
@@ -612,8 +698,8 @@ func (e *Encrypter) GetActiveKey() (*rsa.PrivateKey, string, error) {
// GetKey retrieves the key material by ID from the keyring.
func (e *Encrypter) GetKey(keyID string) (*structs.UnwrappedRootKey, error) {
e.lock.Lock()
defer e.lock.Unlock()
e.keyringLock.Lock()
defer e.keyringLock.Unlock()
ks, err := e.cipherSetByIDLocked(keyID)
if err != nil {
@@ -659,8 +745,8 @@ func (e *Encrypter) cipherSetByIDLocked(keyID string) (*cipherSet, error) {
// RemoveKey removes a key by ID from the keyring
func (e *Encrypter) RemoveKey(keyID string) error {
e.lock.Lock()
defer e.lock.Unlock()
e.keyringLock.Lock()
defer e.keyringLock.Unlock()
delete(e.keyring, keyID)
return nil
}
@@ -883,8 +969,8 @@ func (e *Encrypter) waitForPublicKey(keyID string) (*structs.KeyringPublicKey, e
// GetPublicKey returns the public signing key for the requested key id or an
// error if the key could not be found.
func (e *Encrypter) GetPublicKey(keyID string) (*structs.KeyringPublicKey, error) {
e.lock.RLock()
defer e.lock.RUnlock()
e.keyringLock.RLock()
defer e.keyringLock.RUnlock()
ks, err := e.cipherSetByIDLocked(keyID)
if err != nil {

View File

@@ -13,6 +13,7 @@ import (
"net/rpc"
"os"
"path/filepath"
"sync"
"testing"
"time"
@@ -20,6 +21,7 @@ import (
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/uuid"
@@ -232,13 +234,13 @@ func TestEncrypter_Restore(t *testing.T) {
}
// Ensure all rotated keys are correct
srv.encrypter.lock.Lock()
srv.encrypter.keyringLock.Lock()
test.MapLen(t, 5, srv.encrypter.keyring)
for _, keyset := range srv.encrypter.keyring {
test.Len(t, 32, keyset.rootKey.Key)
test.Greater(t, 0, len(keyset.rootKey.RSAKey))
}
srv.encrypter.lock.Unlock()
srv.encrypter.keyringLock.Unlock()
shutdown()
@@ -261,13 +263,13 @@ func TestEncrypter_Restore(t *testing.T) {
return len(listResp.Keys) == 5 // 4 new + the bootstrap key
}, time.Second*5, time.Second, "expected keyring to be restored")
srv.encrypter.lock.Lock()
srv.encrypter.keyringLock.Lock()
test.MapLen(t, 5, srv.encrypter.keyring)
for _, keyset := range srv.encrypter.keyring {
test.Len(t, 32, keyset.rootKey.Key)
test.Greater(t, 0, len(keyset.rootKey.RSAKey))
}
srv.encrypter.lock.Unlock()
srv.encrypter.keyringLock.Unlock()
for _, keyMeta := range listResp.Keys {
@@ -836,7 +838,7 @@ func TestEncrypter_TransitConfigFallback(t *testing.T) {
must.Eq(t, expect, providers[2].Config, must.Sprint("expected fallback to env"))
}
func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
func TestEncrypter_AddWrappedKey_zeroDecryptTaskError(t *testing.T) {
ci.Parallel(t)
srv := &Server{
@@ -844,12 +846,148 @@ func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
config: &Config{},
}
tmpDir := t.TempDir()
encrypter, err := NewEncrypter(srv, t.TempDir())
must.NoError(t, err)
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
must.NoError(t, err)
encrypter, err := NewEncrypter(srv, tmpDir)
wrappedKey, err := encrypter.wrapRootKey(key, false)
must.NoError(t, err)
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(timeoutCancel)
must.Error(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
must.MapLen(t, 1, encrypter.decryptTasks)
must.MapEmpty(t, encrypter.keyring)
}
func TestEncrypter_AddWrappedKey_sameKeyTwice(t *testing.T) {
ci.Parallel(t)
srv := &Server{
logger: testlog.HCLogger(t),
config: &Config{},
}
encrypter, err := NewEncrypter(srv, t.TempDir())
must.NoError(t, err)
// Create a valid and correctly formatted key and wrap it.
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
must.NoError(t, err)
wrappedKey, err := encrypter.wrapRootKey(key, true)
must.NoError(t, err)
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(timeoutCancel)
// Add the wrapped key to the encrypter and assert that the key is added to
// the keyring and no decryption tasks are queued.
must.NoError(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
must.MapEmpty(t, encrypter.decryptTasks)
must.NoError(t, encrypter.IsReady(timeoutCtx))
must.MapLen(t, 1, encrypter.keyring)
must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), 2*time.Second)
t.Cleanup(timeoutCancel)
// Add the same key again and assert that the key is not added to the
// keyring and no decryption tasks are queued.
must.NoError(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
must.MapEmpty(t, encrypter.decryptTasks)
must.NoError(t, encrypter.IsReady(timeoutCtx))
must.MapLen(t, 1, encrypter.keyring)
must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
}
func TestEncrypter_AddWrappedKey_sameKeyConcurrent(t *testing.T) {
ci.Parallel(t)
srv := &Server{
logger: testlog.HCLogger(t),
config: &Config{},
}
encrypter, err := NewEncrypter(srv, t.TempDir())
must.NoError(t, err)
// Create a valid and correctly formatted key and wrap it.
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
must.NoError(t, err)
wrappedKey, err := encrypter.wrapRootKey(key, true)
must.NoError(t, err)
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(timeoutCancel)
// Define the number of concurrent calls to AddWrappedKey. Changing this
// value should not affect the correctness of the test.
concurrentNum := 10
// Create a channel to receive the responses from the concurrent calls to
// AddWrappedKey. The channel is buffered to ensure that the launched
// routines can send to it without blocking.
respCh := make(chan error, concurrentNum)
// Create a channel to control when the concurrent calls to AddWrappedKey
// are triggered. When the channel is closed, all waiting routines will
// unblock within 0.001 ms of each other.
startCh := make(chan struct{})
// Launch the concurrent calls to AddWrappedKey and wait till they have all
// triggered and responded before moving on. The timeout ensures this test
// won't deadlock or hang indefinitely.
var wg sync.WaitGroup
wg.Add(concurrentNum)
for i := 0; i < concurrentNum; i++ {
go func() {
<-startCh
respCh <- encrypter.AddWrappedKey(timeoutCtx, wrappedKey)
wg.Done()
}()
}
close(startCh)
wg.Wait()
// Gather the responses and ensure the encrypter state is as we expect.
var respNum int
for {
select {
case resp := <-respCh:
must.NoError(t, resp)
if respNum++; respNum == concurrentNum {
must.NoError(t, encrypter.IsReady(timeoutCtx))
must.MapEmpty(t, encrypter.decryptTasks)
must.MapLen(t, 1, encrypter.keyring)
must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
return
}
case <-timeoutCtx.Done():
must.NoError(t, timeoutCtx.Err())
}
}
}
func TestEncrypter_decryptWrappedKeyTask_successful(t *testing.T) {
ci.Parallel(t)
srv := &Server{
logger: testlog.HCLogger(t),
config: &Config{},
}
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
must.NoError(t, err)
encrypter, err := NewEncrypter(srv, t.TempDir())
must.NoError(t, err)
wrappedKey, err := encrypter.encryptDEK(key, &structs.KEKProviderConfig{})
@@ -868,11 +1006,103 @@ func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
must.NoError(t, err)
must.NotNil(t, KMSWrapper)
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = encrypter.decryptWrappedKeyTask(ctx, cancel, KMSWrapper, provider, key.Meta, wrappedKey)
respCh := make(chan *cipherSet)
go encrypter.decryptWrappedKeyTask(ctx, KMSWrapper, key.Meta, wrappedKey, respCh)
select {
case <-ctx.Done():
t.Fatal("timed out waiting for decryptWrappedKeyTask to complete")
case cipherResp := <-respCh:
must.NotNil(t, cipherResp)
}
}
func TestEncrypter_decryptWrappedKeyTask_contextCancel(t *testing.T) {
ci.Parallel(t)
srv := &Server{
logger: testlog.HCLogger(t),
config: &Config{},
}
encrypter, err := NewEncrypter(srv, t.TempDir())
must.NoError(t, err)
// Create a valid and correctly formatted key and wrap it.
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
must.NoError(t, err)
wrappedKey, err := encrypter.encryptDEK(key, &structs.KEKProviderConfig{})
must.NotNil(t, wrappedKey)
must.NoError(t, err)
// Prepare the KMS wrapper and the response channel, so we can call
// decryptWrappedKeyTask. Use a buffered channel, so the decrypt task does
// not block on a send.
provider, ok := encrypter.providerConfigs[string(structs.KEKProviderAEAD)]
must.True(t, ok)
must.NotNil(t, provider)
kmsWrapper, err := encrypter.newKMSWrapper(provider, key.Meta.KeyID, wrappedKey.KeyEncryptionKey)
must.NoError(t, err)
must.NotNil(t, kmsWrapper)
respCh := make(chan *cipherSet, 1)
// Generate a context and immediately cancel it.
ctx, cancel := context.WithCancel(context.Background())
cancel()
// Ensure we receive an error indicating we hit the context done case and
// check no cipher response was sent.
err = encrypter.decryptWrappedKeyTask(ctx, kmsWrapper, key.Meta, wrappedKey, respCh)
must.ErrorContains(t, err, "operation cancelled")
must.Eq(t, 0, len(respCh))
// Recreate the response channel so that it is no longer buffered. The
// decrypt task should now block on attempting to send to it.
respCh = make(chan *cipherSet)
// Generate a new context and an error channel so we can gather the response
// of decryptWrappedKeyTask running inside a goroutine.
ctx, cancel = context.WithCancel(context.Background())
errorCh := make(chan error, 1)
// Launch the decryptWrappedKeyTask routine.
go func() {
err := encrypter.decryptWrappedKeyTask(ctx, kmsWrapper, key.Meta, wrappedKey, respCh)
errorCh <- err
}()
// Roughly ensure the decrypt task is running for enough time to get past
// the cipher generation. This is so that when we cancel the context, we
// have passed the helper.Backoff functions, which are also designed to exit
// and return if the context is canceled. As Tim correctly pointed out; this
// "is about giving this test a fighting chance to be testing the thing we
// think it is".
//
// Canceling the context should cause the routine to exit and send an error
// which we can check to ensure we correctly unblock.
timer, timerStop := helper.NewSafeTimer(500 * time.Millisecond)
defer timerStop()
<-timer.C
cancel()
timer, timerStop = helper.NewSafeTimer(200 * time.Millisecond)
defer timerStop()
select {
case <-timer.C:
t.Fatal("timed out waiting for decryptWrappedKeyTask to send its error")
case err := <-errorCh:
must.ErrorContains(t, err, "context canceled")
}
}
func TestEncrypter_AddWrappedKey_noWrappedKeys(t *testing.T) {