diff --git a/.changelog/25795.txt b/.changelog/25795.txt
new file mode 100644
index 000000000..7183b6aa1
--- /dev/null
+++ b/.changelog/25795.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+encrypter: Refactor startup decryption task handling to avoid timing problems with task addition on FSM restore
+```
diff --git a/nomad/encrypter.go b/nomad/encrypter.go
index a470e4955..4b96f8246 100644
--- a/nomad/encrypter.go
+++ b/nomad/encrypter.go
@@ -58,9 +58,22 @@ type Encrypter struct {
     // issuer is the OIDC Issuer to use for workload identities if configured
     issuer string

-    keyring      map[string]*cipherSet
-    decryptTasks map[string]context.CancelFunc
-    lock         sync.RWMutex
+    // keyring stores the cipher material indexed by the key ID. keyringLock
+    // must be held when accessing this map, and held in as fine-grained a
+    // manner as possible so the keyring stays available to concurrent readers.
+    keyring     map[string]*cipherSet
+    keyringLock sync.RWMutex
+
+    // decryptTasks tracks the currently running decryption tasks. A task
+    // represents a key having one or more decryptWrappedKeyTask functions
+    // running. decryptTasksLock must be held when accessing this map and is
+    // distinct from keyringLock because the two have different responsibilities.
+    //
+    // The design of the encrypter and of other Nomad systems means more than
+    // one task may attempt to decrypt the same key. In that case, the first
+    // task to succeed wins (first-past-the-post).
+    decryptTasks     map[string]struct{}
+    decryptTasksLock sync.RWMutex
 }

 // cipherSet contains the key material for variable encryption and workload
@@ -86,7 +99,7 @@ func NewEncrypter(srv *Server, keystorePath string) (*Encrypter, error) {
         keyring:         make(map[string]*cipherSet),
         issuer:          srv.GetConfig().OIDCIssuer,
         providerConfigs: map[string]*structs.KEKProviderConfig{},
-        decryptTasks:    map[string]context.CancelFunc{},
+        decryptTasks:    map[string]struct{}{},
     }

     providerConfigs, err := getProviderConfigs(srv)
@@ -162,9 +175,9 @@ func (e *Encrypter) loadKeystore() error {
             return nil
         }

-        e.lock.RLock()
+        e.keyringLock.RLock()
         _, ok := e.keyring[id]
-        e.lock.RUnlock()
+        e.keyringLock.RUnlock()
         if ok {
             return nil // already loaded this key from another file
         }
@@ -193,8 +206,8 @@
 // IsReady blocks until all decrypt tasks are complete, or the context expires.
 func (e *Encrypter) IsReady(ctx context.Context) error {
     err := helper.WithBackoffFunc(ctx, time.Millisecond*100, time.Second, func() error {
-        e.lock.RLock()
-        defer e.lock.RUnlock()
+        e.decryptTasksLock.RLock()
+        defer e.decryptTasksLock.RUnlock()
         if len(e.decryptTasks) != 0 {
             keyIDs := []string{}
             for keyID := range e.decryptTasks {
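The `IsReady` change above amounts to a poll-until-empty readiness gate: readers check the task map under its own lock and back off until it drains or the context expires. A minimal, self-contained sketch of the same shape, using only the standard library — the `tracker` type, `WaitEmpty`, and the fixed 100ms backoff are illustrative stand-ins, not Nomad's `helper.WithBackoffFunc`:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// tracker is a hypothetical stand-in for the encrypter's decryptTasks map
// guarded by its dedicated decryptTasksLock.
type tracker struct {
	mu    sync.RWMutex
	tasks map[string]struct{}
}

// WaitEmpty polls the task map under a read lock until it is empty or the
// context expires, mirroring the shape of Encrypter.IsReady.
func (t *tracker) WaitEmpty(ctx context.Context) error {
	for {
		t.mu.RLock()
		n := len(t.tasks)
		t.mu.RUnlock()
		if n == 0 {
			return nil
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("still waiting on %d decrypt tasks: %w", n, ctx.Err())
		case <-time.After(100 * time.Millisecond): // fixed backoff for brevity
		}
	}
}

func main() {
	tr := &tracker{tasks: map[string]struct{}{"key-1": {}}}

	// Simulate a decrypt task finishing shortly after startup.
	go func() {
		time.Sleep(300 * time.Millisecond)
		tr.mu.Lock()
		delete(tr.tasks, "key-1")
		tr.mu.Unlock()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println("ready:", tr.WaitEmpty(ctx)) // ready: <nil>
}
```

Keeping the readiness check on its own lock means a slow KMS provider can never block a caller that only wants to know whether the keyring is usable yet.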
@@ -364,8 +377,8 @@ func (e *Encrypter) AddUnwrappedKey(rootKey *structs.UnwrappedRootKey, isUpgrade
 }

 // AddWrappedKey creates decryption tasks for keys we've previously stored in
-// Raft. It's only called as a goroutine by the FSM Apply for WrappedRootKeys,
-// but it returns an error for ease of testing.
+// Raft. It's only called as a goroutine by the FSM Apply for WrappedRootKeys
+// and RootKeyMeta. It returns an error for ease of testing.
 func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.RootKey) error {

     // If the passed root key does not contain any wrapped keys, it has no
@@ -382,35 +395,44 @@

     logger := e.log.With("key_id", wrappedKeys.KeyID)

-    e.lock.Lock()
+    e.keyringLock.Lock()

-    _, err := e.cipherSetByIDLocked(wrappedKeys.KeyID)
-    if err == nil {
+    // key material for each key ID is immutable so there is nothing to do,
+    // but we can remove any running decrypt tasks as we no longer need them
+    // to finish before considering the encrypter ready for this cipher.
+    if _, err := e.cipherSetByIDLocked(wrappedKeys.KeyID); err == nil {
+        e.decryptTasksLock.Lock()
+        delete(e.decryptTasks, wrappedKeys.KeyID)
+        e.decryptTasksLock.Unlock()

-        // key material for each key ID is immutable so nothing to do, but we
-        // can cancel and remove any running decrypt tasks
-        if cancel, ok := e.decryptTasks[wrappedKeys.KeyID]; ok {
-            cancel()
-            delete(e.decryptTasks, wrappedKeys.KeyID)
-        }
-        e.lock.Unlock()
+        e.keyringLock.Unlock()
         return nil
     }

-    if cancel, ok := e.decryptTasks[wrappedKeys.KeyID]; ok {
-        // stop any previous tasks for this same key ID under the assumption
-        // they're broken or being superseded, but don't remove the CancelFunc
-        // from the map yet so that other callers don't think we can continue
-        cancel()
-    }
-
-    e.lock.Unlock()
-
-    completeCtx, cancel := context.WithCancel(ctx)
+    e.keyringLock.Unlock()

     var mErr *multierror.Error

+    // Generate a context and cancel function for this key. It is shared by
+    // all tasks, so the first-past-the-post winner can promptly cancel the
+    // work of the remaining tasks.
+    completeCtx, cancel := context.WithCancel(ctx)
+    defer cancel()
+
+    // Track the total number of decrypt tasks we have started for this key,
+    // so we can detect the case where no task was launched successfully.
     decryptTasks := 0
+
+    // Use a channel to receive the cipherSet from the decrypter goroutines.
+    // It allows us to fan-out decryption tasks for HA in Nomad Enterprise.
+    cipherSetCh := make(chan *cipherSet)
+
+    // Register the decrypt task under its key ID before launching any
+    // workers, so the registration happens exactly once per call.
+    e.decryptTasksLock.Lock()
+    e.decryptTasks[wrappedKeys.KeyID] = struct{}{}
+    e.decryptTasksLock.Unlock()
+
     for _, wrappedKey := range wrappedKeys.WrappedKeys {
         providerID := wrappedKey.ProviderID
         if providerID == "" {
@@ -434,24 +456,64 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
         }

         // fan-out decryption tasks for HA in Nomad Enterprise. we can use the
-        // key whenever any one provider returns a successful decryption
-        go e.decryptWrappedKeyTask(completeCtx, cancel, wrapper, provider, wrappedKeys.Meta(), wrappedKey)
+        // key whenever any one provider returns a successful decryption.
+        go e.decryptWrappedKeyTask(completeCtx, wrapper, wrappedKeys.Meta(), wrappedKey, cipherSetCh)
         decryptTasks++
     }

-    e.lock.Lock()
-    defer e.lock.Unlock()
+    if err := mErr.ErrorOrNil(); err != nil {

-    e.decryptTasks[wrappedKeys.KeyID] = cancel
-
-    err = mErr.ErrorOrNil()
-    if err != nil {
+        // If we have no tasks running, we can log an error for the operator
+        // and exit.
+        //
+        // It is likely that the decryption configuration for the key is
+        // incorrect and follow-up attempts from other Raft/FSM calls for this
+        // key will also fail.
+        // We should not, however, continue with the server startup without
+        // this key, and therefore we do not delete the tracking entry added
+        // above.
         if decryptTasks == 0 {
-            cancel()
+            logger.Error("root key cannot be decrypted", "error", err)
+            return err
         }

-        logger.Error("root key cannot be decrypted", "error", err)
-        return err
+        // If we have at least one task running, we can log a warning for the
+        // operator but continue to wait for the other tasks to complete.
+        logger.Warn("root key cannot be decrypted by some KMS providers",
+            "error", err)
+    }
+
+    // The routine now waits until the server tells us to stop or until the
+    // key is successfully decrypted. The cancel function is deferred above,
+    // so we only need to call it early when we want to preemptively stop
+    // in-flight tasks.
+    select {
+    case <-completeCtx.Done():
+
+        // In this event, the server is shutting down and the agent process
+        // will exit. This means the decrypter state will be lost, so there
+        // is no need to tidy the decryption task state.
+        return completeCtx.Err()
+
+    case generatedCipher := <-cipherSetCh:
+
+        // Reaching this point means we have a decrypted cipher. No errors
+        // can occur from here on, so start by telling any other running work
+        // to stop.
+        cancel()
+
+        // Write the cipher to the encrypter keyring. We could simply
+        // overwrite an existing entry, but there is no harm in checking
+        // first as we need a write lock anyway.
+        e.keyringLock.Lock()
+        if _, ok := e.keyring[wrappedKeys.KeyID]; !ok {
+            e.keyring[wrappedKeys.KeyID] = generatedCipher
+        }
+        e.keyringLock.Unlock()
+
+        // We can now remove the decrypt task from the tracking map,
+        // indicating this key is ready for use.
+        e.decryptTasksLock.Lock()
+        delete(e.decryptTasks, wrappedKeys.KeyID)
+        e.decryptTasksLock.Unlock()
+    }

     return nil
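This select is the heart of the refactor: `AddWrappedKey` fans out one decryption task per KMS provider, adopts the first `cipherSet` to arrive, and cancels the rest. A stripped-down sketch of that first-past-the-post pattern — the `decrypt` function, provider names, and delays below are hypothetical placeholders, not Nomad APIs:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// decrypt is a hypothetical stand-in for decryptWrappedKeyTask: it either
// produces a result or gives up when the context is canceled.
func decrypt(ctx context.Context, provider string, delay time.Duration, respCh chan<- string) {
	select {
	case <-ctx.Done():
		return
	case <-time.After(delay): // simulated KMS round trip
	}
	select {
	case <-ctx.Done(): // another provider already won
	case respCh <- "plaintext from " + provider:
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // stops any still-running tasks when we return

	respCh := make(chan string)

	// Fan out one task per provider; the first past the post wins.
	go decrypt(ctx, "aead", 50*time.Millisecond, respCh)
	go decrypt(ctx, "vault", 200*time.Millisecond, respCh)

	select {
	case <-ctx.Done():
		fmt.Println("shutdown before any provider succeeded")
	case result := <-respCh:
		cancel() // tell the losers to stop promptly
		fmt.Println(result) // plaintext from aead
	}
}
```

Because losing tasks hit the `ctx.Done()` branch before sending, the unbuffered response channel never strands a goroutine once the winner has been received.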
@@ -459,8 +521,13 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
 }

 // decryptWrappedKeyTask attempts to decrypt a wrapped key. It blocks until
 // successful or until the context is canceled (another task completes or the
-// server shuts down). The error returned is only for testing and diagnostics.
-func (e *Encrypter) decryptWrappedKeyTask(ctx context.Context, cancel context.CancelFunc, wrapper kms.Wrapper, provider *structs.KEKProviderConfig, meta *structs.RootKeyMeta, wrappedKey *structs.WrappedKey) error {
+// server shuts down) and the resulting cipher will be sent back via the respCh
+// channel.
+//
+// The error returned is only for testing and diagnostics.
+func (e *Encrypter) decryptWrappedKeyTask(
+    ctx context.Context, wrapper kms.Wrapper, meta *structs.RootKeyMeta,
+    wrappedKey *structs.WrappedKey, respCh chan *cipherSet) error {

     var key []byte
     var rsaKey []byte
@@ -508,8 +575,10 @@
         RSAKey: rsaKey,
     }

+    var generatedCipher *cipherSet
+
     err = helper.WithBackoffFunc(ctx, minBackoff, maxBackoff, func() error {
-        err := e.addCipher(rootKey)
+        generatedCipher, err = e.generateCipher(rootKey)
         if err != nil {
             err := fmt.Errorf("could not add cipher: %w", err)
             e.log.Error(err.Error(), "key_id", meta.KeyID)
@@ -521,40 +590,60 @@ func (e *Encrypter) decryptWrappedKeyTask(ctx context.Ca
         return err
     }

-    e.lock.Lock()
-    defer e.lock.Unlock()
-    cancel()
-    delete(e.decryptTasks, meta.KeyID)
+    // Send the cipher to the response channel or exit if the context is
+    // canceled.
+    //
+    // The context is canceled when the server is shutting down or when
+    // another task decrypting the same key completes.
+    select {
+    case <-ctx.Done():
+        return ctx.Err()
+    case respCh <- generatedCipher:
+    }
+
     return nil
 }

 // addCipher creates a new cipherSet for the key and stores them in the keyring
 func (e *Encrypter) addCipher(rootKey *structs.UnwrappedRootKey) error {
-    if rootKey == nil || rootKey.Meta == nil {
-        return fmt.Errorf("missing metadata")
+    generatedCipher, err := e.generateCipher(rootKey)
+    if err != nil {
+        return err
     }
-    var aead cipher.AEAD
+
+    e.keyringLock.Lock()
+    defer e.keyringLock.Unlock()
+    e.keyring[rootKey.Meta.KeyID] = generatedCipher
+    return nil
+}
+
+func (e *Encrypter) generateCipher(rootKey *structs.UnwrappedRootKey) (*cipherSet, error) {
+
+    if rootKey == nil || rootKey.Meta == nil {
+        return nil, fmt.Errorf("missing metadata")
+    }
+    var aeadCipher cipher.AEAD
     switch rootKey.Meta.Algorithm {
     case structs.EncryptionAlgorithmAES256GCM:
         block, err := aes.NewCipher(rootKey.Key)
         if err != nil {
-            return fmt.Errorf("could not create cipher: %v", err)
+            return nil, fmt.Errorf("could not create cipher: %v", err)
         }
-        aead, err = cipher.NewGCM(block)
+        aeadCipher, err = cipher.NewGCM(block)
         if err != nil {
-            return fmt.Errorf("could not create cipher: %v", err)
+            return nil, fmt.Errorf("could not create cipher: %v", err)
         }
     default:
-        return fmt.Errorf("invalid algorithm %s", rootKey.Meta.Algorithm)
+        return nil, fmt.Errorf("invalid algorithm %s", rootKey.Meta.Algorithm)
     }

     ed25519Key := ed25519.NewKeyFromSeed(rootKey.Key)

     cs := cipherSet{
         rootKey:         rootKey,
-        cipher:          aead,
+        cipher:          aeadCipher,
         eddsaPrivateKey: ed25519Key,
     }
@@ -563,17 +652,14 @@ func (e *Encrypter) addCipher(rootKey *structs.UnwrappedRootKey) error {
     if len(rootKey.RSAKey) > 0 {
         rsaKey, err := x509.ParsePKCS1PrivateKey(rootKey.RSAKey)
         if err != nil {
-            return fmt.Errorf("error parsing rsa key: %w", err)
+            return nil, fmt.Errorf("error parsing rsa key: %w", err)
         }

         cs.rsaPrivateKey = rsaKey
         cs.rsaPKCS1PublicKey = x509.MarshalPKCS1PublicKey(&rsaKey.PublicKey)
     }

-    e.lock.Lock()
-    defer e.lock.Unlock()
-    e.keyring[rootKey.Meta.KeyID] = &cs
-    return nil
+    return &cs, nil
 }

 // waitForKey retrieves the key material by ID from the keyring, retrying with
@@ -583,8 +669,8 @@ func (e *Encrypter) waitForKey(ctx context.Context, keyID string) (*cipherSet, e
     err := helper.WithBackoffFunc(ctx, 50*time.Millisecond, 100*time.Millisecond,
         func() error {
-            e.lock.RLock()
-            defer e.lock.RUnlock()
+            e.keyringLock.RLock()
+            defer e.keyringLock.RUnlock()
             var err error
             ks, err = e.cipherSetByIDLocked(keyID)
             if err != nil {
@@ -612,8 +698,8 @@ func (e *Encrypter) GetActiveKey() (*rsa.PrivateKey, string, error) {

 // GetKey retrieves the key material by ID from the keyring.
 func (e *Encrypter) GetKey(keyID string) (*structs.UnwrappedRootKey, error) {
-    e.lock.Lock()
-    defer e.lock.Unlock()
+    e.keyringLock.Lock()
+    defer e.keyringLock.Unlock()

     ks, err := e.cipherSetByIDLocked(keyID)
     if err != nil {
@@ -659,8 +745,8 @@ func (e *Encrypter) cipherSetByIDLocked(keyID string) (*cipherSet, error) {

 // RemoveKey removes a key by ID from the keyring
 func (e *Encrypter) RemoveKey(keyID string) error {
-    e.lock.Lock()
-    defer e.lock.Unlock()
+    e.keyringLock.Lock()
+    defer e.keyringLock.Unlock()
     delete(e.keyring, keyID)
     return nil
 }
@@ -883,8 +969,8 @@ func (e *Encrypter) waitForPublicKey(keyID string) (*structs.KeyringPublicKey, e

 // GetPublicKey returns the public signing key for the requested key id or an
 // error if the key could not be found.
 func (e *Encrypter) GetPublicKey(keyID string) (*structs.KeyringPublicKey, error) {
-    e.lock.RLock()
-    defer e.lock.RUnlock()
+    e.keyringLock.RLock()
+    defer e.keyringLock.RUnlock()

     ks, err := e.cipherSetByIDLocked(keyID)
     if err != nil {
diff --git a/nomad/encrypter_test.go b/nomad/encrypter_test.go
index 8b182254a..e313e764b 100644
--- a/nomad/encrypter_test.go
+++ b/nomad/encrypter_test.go
@@ -13,6 +13,7 @@ import (
     "net/rpc"
     "os"
     "path/filepath"
+    "sync"
     "testing"
     "time"

@@ -20,6 +21,7 @@ import (
     wrapping "github.com/hashicorp/go-kms-wrapping/v2"
     msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
     "github.com/hashicorp/nomad/ci"
+    "github.com/hashicorp/nomad/helper"
     "github.com/hashicorp/nomad/helper/pointer"
     "github.com/hashicorp/nomad/helper/testlog"
     "github.com/hashicorp/nomad/helper/uuid"
@@ -232,13 +234,13 @@ func TestEncrypter_Restore(t *testing.T) {
     }

     // Ensure all rotated keys are correct
-    srv.encrypter.lock.Lock()
+    srv.encrypter.keyringLock.Lock()
     test.MapLen(t, 5, srv.encrypter.keyring)
     for _, keyset := range srv.encrypter.keyring {
         test.Len(t, 32, keyset.rootKey.Key)
         test.Greater(t, 0, len(keyset.rootKey.RSAKey))
     }
-    srv.encrypter.lock.Unlock()
+    srv.encrypter.keyringLock.Unlock()

     shutdown()

@@ -261,13 +263,13 @@ func TestEncrypter_Restore(t *testing.T) {
         return len(listResp.Keys) == 5 // 4 new + the bootstrap key
     }, time.Second*5, time.Second, "expected keyring to be restored")

-    srv.encrypter.lock.Lock()
+    srv.encrypter.keyringLock.Lock()
     test.MapLen(t, 5, srv.encrypter.keyring)
     for _, keyset := range srv.encrypter.keyring {
         test.Len(t, 32, keyset.rootKey.Key)
         test.Greater(t, 0, len(keyset.rootKey.RSAKey))
     }
-    srv.encrypter.lock.Unlock()
+    srv.encrypter.keyringLock.Unlock()

     for _, keyMeta := range listResp.Keys {

@@ -836,7 +838,7 @@ func TestEncrypter_TransitConfigFallback(t *testing.T) {
     must.Eq(t, expect, providers[2].Config, must.Sprint("expected fallback to env"))
 }

-func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
+func TestEncrypter_AddWrappedKey_zeroDecryptTaskError(t *testing.T) {
     ci.Parallel(t)

     srv := &Server{
@@ -844,12 +846,148 @@ func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
         config: &Config{},
     }

-    tmpDir := t.TempDir()
+    encrypter, err := NewEncrypter(srv, t.TempDir())
+    must.NoError(t, err)

     key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
     must.NoError(t, err)

-    encrypter, err := NewEncrypter(srv, tmpDir)
+    // Wrap the root key so it can be handed to AddWrappedKey.
+    wrappedKey, err := encrypter.wrapRootKey(key, false)
+    must.NoError(t, err)
+
+    timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 2*time.Second)
+    t.Cleanup(timeoutCancel)
+
+    // AddWrappedKey should return an error because no decrypt task could be
+    // started; the tracking entry is deliberately left in place and nothing
+    // is added to the keyring.
+    must.Error(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
+    must.MapLen(t, 1, encrypter.decryptTasks)
+    must.MapEmpty(t, encrypter.keyring)
+}
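The concurrency test below (`TestEncrypter_AddWrappedKey_sameKeyConcurrent`) releases all of its goroutines with a closed "starting gun" channel so they contend at effectively the same instant. A self-contained sketch of just that pattern, with hypothetical worker names:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	const workers = 10

	startCh := make(chan struct{})     // closed to release all workers at once
	results := make(chan int, workers) // buffered so senders never block

	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func(id int) {
			defer wg.Done()
			<-startCh // every worker parks here until the gun fires
			results <- id
		}(i)
	}

	close(startCh) // fire: all workers unblock near-simultaneously
	wg.Wait()
	close(results)

	for id := range results {
		fmt.Println("worker finished:", id)
	}
}
```

A closed channel is readable by any number of goroutines forever, which is what makes it a better starting gun than a single send.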
+
+func TestEncrypter_AddWrappedKey_sameKeyTwice(t *testing.T) {
+    ci.Parallel(t)
+
+    srv := &Server{
+        logger: testlog.HCLogger(t),
+        config: &Config{},
+    }
+
+    encrypter, err := NewEncrypter(srv, t.TempDir())
+    must.NoError(t, err)
+
+    // Create a valid and correctly formatted key and wrap it.
+    key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
+    must.NoError(t, err)
+
+    wrappedKey, err := encrypter.wrapRootKey(key, true)
+    must.NoError(t, err)
+
+    timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 2*time.Second)
+    t.Cleanup(timeoutCancel)
+
+    // Add the wrapped key to the encrypter and assert that the key is added
+    // to the keyring and no decryption tasks are queued.
+    must.NoError(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
+    must.MapEmpty(t, encrypter.decryptTasks)
+    must.NoError(t, encrypter.IsReady(timeoutCtx))
+    must.MapLen(t, 1, encrypter.keyring)
+    must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
+
+    timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), 2*time.Second)
+    t.Cleanup(timeoutCancel)
+
+    // Add the same key again and assert that the keyring is unchanged and no
+    // decryption tasks are queued.
+    must.NoError(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
+    must.MapEmpty(t, encrypter.decryptTasks)
+    must.NoError(t, encrypter.IsReady(timeoutCtx))
+    must.MapLen(t, 1, encrypter.keyring)
+    must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
+}
+
+func TestEncrypter_AddWrappedKey_sameKeyConcurrent(t *testing.T) {
+    ci.Parallel(t)
+
+    srv := &Server{
+        logger: testlog.HCLogger(t),
+        config: &Config{},
+    }
+
+    encrypter, err := NewEncrypter(srv, t.TempDir())
+    must.NoError(t, err)
+
+    // Create a valid and correctly formatted key and wrap it.
+    key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
+    must.NoError(t, err)
+
+    wrappedKey, err := encrypter.wrapRootKey(key, true)
+    must.NoError(t, err)
+
+    timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 5*time.Second)
+    t.Cleanup(timeoutCancel)
+
+    // Define the number of concurrent calls to AddWrappedKey. Changing this
+    // value should not affect the correctness of the test.
+    concurrentNum := 10
+
+    // Create a channel to receive the responses from the concurrent calls to
+    // AddWrappedKey. The channel is buffered to ensure that the launched
+    // routines can send to it without blocking.
+    respCh := make(chan error, concurrentNum)
+
+    // Create a channel to control when the concurrent calls to AddWrappedKey
+    // are triggered. When the channel is closed, all waiting routines will
+    // unblock at effectively the same moment.
+    startCh := make(chan struct{})
+
+    // Launch the concurrent calls to AddWrappedKey and wait until they have
+    // all triggered and responded before moving on. The timeout ensures this
+    // test won't deadlock or hang indefinitely.
+    var wg sync.WaitGroup
+    wg.Add(concurrentNum)
+
+    for i := 0; i < concurrentNum; i++ {
+        go func() {
+            <-startCh
+            respCh <- encrypter.AddWrappedKey(timeoutCtx, wrappedKey)
+            wg.Done()
+        }()
+    }
+
+    close(startCh)
+    wg.Wait()
+
+    // Gather the responses and ensure the encrypter state is as we expect.
+    var respNum int
+
+    for {
+        select {
+        case resp := <-respCh:
+            must.NoError(t, resp)
+            if respNum++; respNum == concurrentNum {
+                must.NoError(t, encrypter.IsReady(timeoutCtx))
+                must.MapEmpty(t, encrypter.decryptTasks)
+                must.MapLen(t, 1, encrypter.keyring)
+                must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
+                return
+            }
+        case <-timeoutCtx.Done():
+            must.NoError(t, timeoutCtx.Err())
+        }
+    }
+}
+
+func TestEncrypter_decryptWrappedKeyTask_successful(t *testing.T) {
+    ci.Parallel(t)
+
+    srv := &Server{
+        logger: testlog.HCLogger(t),
+        config: &Config{},
+    }
+
+    key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
+    must.NoError(t, err)
+
+    encrypter, err := NewEncrypter(srv, t.TempDir())
     must.NoError(t, err)

     wrappedKey, err := encrypter.encryptDEK(key, &structs.KEKProviderConfig{})
@@ -868,11 +1006,103 @@ func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
     must.NoError(t, err)
     must.NotNil(t, KMSWrapper)

-    ctx, cancel := context.WithCancel(context.Background())
+    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
     defer cancel()

-    err = encrypter.decryptWrappedKeyTask(ctx, cancel, KMSWrapper, provider, key.Meta, wrappedKey)
+    respCh := make(chan *cipherSet)
+
+    go encrypter.decryptWrappedKeyTask(ctx, KMSWrapper, key.Meta, wrappedKey, respCh)
+
+    select {
+    case <-ctx.Done():
+        t.Fatal("timed out waiting for decryptWrappedKeyTask to complete")
+    case cipherResp := <-respCh:
+        must.NotNil(t, cipherResp)
+    }
+}
+
+func TestEncrypter_decryptWrappedKeyTask_contextCancel(t *testing.T) {
+    ci.Parallel(t)
+
+    srv := &Server{
+        logger: testlog.HCLogger(t),
+        config: &Config{},
+    }
+
+    encrypter, err := NewEncrypter(srv, t.TempDir())
     must.NoError(t, err)
+
+    // Create a valid and correctly formatted key and wrap it.
+    key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
+    must.NoError(t, err)
+
+    wrappedKey, err := encrypter.encryptDEK(key, &structs.KEKProviderConfig{})
+    must.NotNil(t, wrappedKey)
+    must.NoError(t, err)
+
+    // Prepare the KMS wrapper and the response channel, so we can call
+    // decryptWrappedKeyTask. Use a buffered channel, so the decrypt task does
+    // not block on a send.
+    provider, ok := encrypter.providerConfigs[string(structs.KEKProviderAEAD)]
+    must.True(t, ok)
+    must.NotNil(t, provider)
+
+    kmsWrapper, err := encrypter.newKMSWrapper(provider, key.Meta.KeyID, wrappedKey.KeyEncryptionKey)
+    must.NoError(t, err)
+    must.NotNil(t, kmsWrapper)
+
+    respCh := make(chan *cipherSet, 1)
+
+    // Generate a context and immediately cancel it.
+    ctx, cancel := context.WithCancel(context.Background())
+    cancel()
+
+    // Ensure we receive an error indicating we hit the context done case and
+    // check no cipher response was sent.
+    err = encrypter.decryptWrappedKeyTask(ctx, kmsWrapper, key.Meta, wrappedKey, respCh)
+    must.ErrorContains(t, err, "operation cancelled")
+    must.Eq(t, 0, len(respCh))
+
+    // Recreate the response channel so that it is no longer buffered. The
+    // decrypt task should now block on attempting to send to it.
+    respCh = make(chan *cipherSet)
+
+    // Generate a new context and an error channel so we can gather the
+    // response of decryptWrappedKeyTask running inside a goroutine.
+    ctx, cancel = context.WithCancel(context.Background())
+
+    errorCh := make(chan error, 1)
+
+    // Launch the decryptWrappedKeyTask routine.
+    go func() {
+        err := encrypter.decryptWrappedKeyTask(ctx, kmsWrapper, key.Meta, wrappedKey, respCh)
+        errorCh <- err
+    }()
+
+    // Roughly ensure the decrypt task has been running long enough to get
+    // past the cipher generation, so that by the time we cancel the context
+    // we have passed the helper backoff functions, which are also designed
+    // to exit and return if the context is canceled. As Tim correctly
+    // pointed out, this "is about giving this test a fighting chance to be
+    // testing the thing we think it is".
+    //
+    // Canceling the context should cause the routine to exit and send an
+    // error, which we can check to ensure we correctly unblock.
+    timer, timerStop := helper.NewSafeTimer(500 * time.Millisecond)
+    defer timerStop()
+
+    <-timer.C
+    cancel()
+
+    timer, timerStop = helper.NewSafeTimer(200 * time.Millisecond)
+    defer timerStop()
+
+    select {
+    case <-timer.C:
+        t.Fatal("timed out waiting for decryptWrappedKeyTask to send its error")
+    case err := <-errorCh:
+        must.ErrorContains(t, err, "context canceled")
+    }
+}

 func TestEncrypter_AddWrappedKey_noWrappedKeys(t *testing.T) {