mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
encrypter: Remove tracking of cancelation for decrypt tasks. (#25795)
New wrapped keys were added to the encrypter and tracked using their keyID with the context cancelation function. This tracking was performed primarily so the FSM could load its known key objects and logs with entries for the same ID superseding existing decryption tasks. This is a hard to reason about approach and in theory can cause timing problems in conjunction with the locking. The new approach still tracks decryption tasks but does not store the cancelation context. This context is now controlled within a single function in an attempt to provide a clearer workflow. In the event two calls for the same key are made in close succession meaning there is no entry in the keyring for the key yet, all tasks will be launched. The first-past-the-post will write the cipher to encrypter state, the second task will complete but not write the cipher.
This commit is contained in:
3
.changelog/25795.txt
Normal file
3
.changelog/25795.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
```release-note:bug
|
||||
encrypter: Refactor startup decryption task handling to avoid timing problems with task addition on FSM restore
|
||||
```
|
||||
@@ -58,9 +58,22 @@ type Encrypter struct {
|
||||
// issuer is the OIDC Issuer to use for workload identities if configured
|
||||
issuer string
|
||||
|
||||
keyring map[string]*cipherSet
|
||||
decryptTasks map[string]context.CancelFunc
|
||||
lock sync.RWMutex
|
||||
// keyring stores the cipher material indexed by the key ID. keyringLock
|
||||
// must be used in a fine-grained manner to access this to ensure safety and
|
||||
// provide availability and performance optimizations.
|
||||
keyring map[string]*cipherSet
|
||||
keyringLock sync.RWMutex
|
||||
|
||||
// decryptTasks tracks the currently running decryption tasks. A task
|
||||
// represents a key having one or more decryptWrappedKeyTask functions
|
||||
// running. decryptTasksLock must be used when accessing this map and is
|
||||
// distinct to keyringLock due to their responsibilities.
|
||||
//
|
||||
// The nature and design of the encrypter as well as other Nomad systems
|
||||
// means we can have more than 1 task attempting to decrypt the same key. In
|
||||
// this case, we adopt first-past-the-post.
|
||||
decryptTasks map[string]struct{}
|
||||
decryptTasksLock sync.RWMutex
|
||||
}
|
||||
|
||||
// cipherSet contains the key material for variable encryption and workload
|
||||
@@ -86,7 +99,7 @@ func NewEncrypter(srv *Server, keystorePath string) (*Encrypter, error) {
|
||||
keyring: make(map[string]*cipherSet),
|
||||
issuer: srv.GetConfig().OIDCIssuer,
|
||||
providerConfigs: map[string]*structs.KEKProviderConfig{},
|
||||
decryptTasks: map[string]context.CancelFunc{},
|
||||
decryptTasks: map[string]struct{}{},
|
||||
}
|
||||
|
||||
providerConfigs, err := getProviderConfigs(srv)
|
||||
@@ -162,9 +175,9 @@ func (e *Encrypter) loadKeystore() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
e.lock.RLock()
|
||||
e.keyringLock.RLock()
|
||||
_, ok := e.keyring[id]
|
||||
e.lock.RUnlock()
|
||||
e.keyringLock.RUnlock()
|
||||
if ok {
|
||||
return nil // already loaded this key from another file
|
||||
}
|
||||
@@ -193,8 +206,8 @@ func (e *Encrypter) loadKeystore() error {
|
||||
// IsReady blocks until all decrypt tasks are complete, or the context expires.
|
||||
func (e *Encrypter) IsReady(ctx context.Context) error {
|
||||
err := helper.WithBackoffFunc(ctx, time.Millisecond*100, time.Second, func() error {
|
||||
e.lock.RLock()
|
||||
defer e.lock.RUnlock()
|
||||
e.decryptTasksLock.RLock()
|
||||
defer e.decryptTasksLock.RUnlock()
|
||||
if len(e.decryptTasks) != 0 {
|
||||
keyIDs := []string{}
|
||||
for keyID := range e.decryptTasks {
|
||||
@@ -364,8 +377,8 @@ func (e *Encrypter) AddUnwrappedKey(rootKey *structs.UnwrappedRootKey, isUpgrade
|
||||
}
|
||||
|
||||
// AddWrappedKey creates decryption tasks for keys we've previously stored in
|
||||
// Raft. It's only called as a goroutine by the FSM Apply for WrappedRootKeys,
|
||||
// but it returns an error for ease of testing.
|
||||
// Raft. It's only called as a goroutine by the FSM Apply for WrappedRootKeys
|
||||
// and RootKeyMeta. It returns an error for ease of testing.
|
||||
func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.RootKey) error {
|
||||
|
||||
// If the passed root key does not contain any wrapped keys, it has no
|
||||
@@ -382,35 +395,44 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
|
||||
|
||||
logger := e.log.With("key_id", wrappedKeys.KeyID)
|
||||
|
||||
e.lock.Lock()
|
||||
e.keyringLock.Lock()
|
||||
|
||||
_, err := e.cipherSetByIDLocked(wrappedKeys.KeyID)
|
||||
if err == nil {
|
||||
// key material for each key ID is immutable so nothing to do, but we
|
||||
// can remove any running decrypt tasks as we no longer need these to finish
|
||||
// before considering the encrypter ready for this cipher.
|
||||
if _, err := e.cipherSetByIDLocked(wrappedKeys.KeyID); err == nil {
|
||||
e.decryptTasksLock.Lock()
|
||||
delete(e.decryptTasks, wrappedKeys.KeyID)
|
||||
e.decryptTasksLock.Unlock()
|
||||
|
||||
// key material for each key ID is immutable so nothing to do, but we
|
||||
// can cancel and remove any running decrypt tasks
|
||||
if cancel, ok := e.decryptTasks[wrappedKeys.KeyID]; ok {
|
||||
cancel()
|
||||
delete(e.decryptTasks, wrappedKeys.KeyID)
|
||||
}
|
||||
e.lock.Unlock()
|
||||
e.keyringLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if cancel, ok := e.decryptTasks[wrappedKeys.KeyID]; ok {
|
||||
// stop any previous tasks for this same key ID under the assumption
|
||||
// they're broken or being superseded, but don't remove the CancelFunc
|
||||
// from the map yet so that other callers don't think we can continue
|
||||
cancel()
|
||||
}
|
||||
|
||||
e.lock.Unlock()
|
||||
|
||||
completeCtx, cancel := context.WithCancel(ctx)
|
||||
e.keyringLock.Unlock()
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
// Generate a context and cancel function for this key. It will be used by
|
||||
// all tasks, so the first-past-the-post cipher winner can cancel the work
|
||||
// of the other tasks in a timely manner.
|
||||
completeCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Track the total number of decrypt tasks we have started for this key. It
|
||||
// helps us cancel the context if we did not successfully launch a task.
|
||||
decryptTasks := 0
|
||||
|
||||
// Use a channel to receive the cipherSet from the decrypter goroutines. It
|
||||
// allows us to fan-out decryption tasks for HA in Nomad Enterprise.
|
||||
cipherSetCh := make(chan *cipherSet)
|
||||
|
||||
// We will use the key ID to track the decrypt tasks for this key. Doing
|
||||
// this here means we can do this once per function call.
|
||||
e.decryptTasksLock.Lock()
|
||||
e.decryptTasks[wrappedKeys.KeyID] = struct{}{}
|
||||
e.decryptTasksLock.Unlock()
|
||||
|
||||
for _, wrappedKey := range wrappedKeys.WrappedKeys {
|
||||
providerID := wrappedKey.ProviderID
|
||||
if providerID == "" {
|
||||
@@ -434,24 +456,64 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
|
||||
}
|
||||
|
||||
// fan-out decryption tasks for HA in Nomad Enterprise. we can use the
|
||||
// key whenever any one provider returns a successful decryption
|
||||
go e.decryptWrappedKeyTask(completeCtx, cancel, wrapper, provider, wrappedKeys.Meta(), wrappedKey)
|
||||
// key whenever any one provider returns a successful decryption.
|
||||
go e.decryptWrappedKeyTask(completeCtx, wrapper, wrappedKeys.Meta(), wrappedKey, cipherSetCh)
|
||||
decryptTasks++
|
||||
}
|
||||
|
||||
e.lock.Lock()
|
||||
defer e.lock.Unlock()
|
||||
if err := mErr.ErrorOrNil(); err != nil {
|
||||
|
||||
e.decryptTasks[wrappedKeys.KeyID] = cancel
|
||||
|
||||
err = mErr.ErrorOrNil()
|
||||
if err != nil {
|
||||
// If we have no tasks running, we can log an error for the operator and
|
||||
// exit.
|
||||
//
|
||||
// It is likely any decryption configuration for the key is incorrect
|
||||
// and follow-up attempts from other Raft/FMS calls for this key will
|
||||
// also fail. We should not, however, continue with the server startup
|
||||
// without this key, and therefore we do not delete any added tracking.
|
||||
if decryptTasks == 0 {
|
||||
cancel()
|
||||
logger.Error("root key cannot be decrypted", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Error("root key cannot be decrypted", "error", err)
|
||||
return err
|
||||
// If we have at least one task running, we can log a warning for the
|
||||
// operator but continue to wait for the other tasks to complete.
|
||||
logger.Warn("root key cannot be decrypted by some KMS providers",
|
||||
"error", err)
|
||||
}
|
||||
|
||||
// The routine will now wait until the server tells us to stop or until we
|
||||
// successfully decrypt the key. The context's cancellation function has a
|
||||
// deferred call above, so we don't need to worry about it here unless we
|
||||
// want to preemptively cancel in-flight tasks.
|
||||
select {
|
||||
case <-completeCtx.Done():
|
||||
|
||||
// In this event, the server is shutting down and the agent process will
|
||||
// exit. This means the decrypter state will be lost, so there is no
|
||||
// need to tidy the decryption task state.
|
||||
return completeCtx.Err()
|
||||
|
||||
case generatedCipher := <-cipherSetCh:
|
||||
|
||||
// By reaching this point, we have a decrypted cipher. No errors can
|
||||
// occur from here on, so start by telling any other running work to
|
||||
// stop.
|
||||
cancel()
|
||||
|
||||
// Write the cipher to the encrypter keyring. We could just overwrite an
|
||||
// existing entry, but there is no harm in checking this first as we
|
||||
// need a write lock anyway.
|
||||
e.keyringLock.Lock()
|
||||
if _, ok := e.keyring[wrappedKeys.KeyID]; !ok {
|
||||
e.keyring[wrappedKeys.KeyID] = generatedCipher
|
||||
}
|
||||
e.keyringLock.Unlock()
|
||||
|
||||
// We can now remove the decrypt task from the tracking map indicating
|
||||
// this key is ready for use.
|
||||
e.decryptTasksLock.Lock()
|
||||
delete(e.decryptTasks, wrappedKeys.KeyID)
|
||||
e.decryptTasksLock.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -459,8 +521,13 @@ func (e *Encrypter) AddWrappedKey(ctx context.Context, wrappedKeys *structs.Root
|
||||
|
||||
// decryptWrappedKeyTask attempts to decrypt a wrapped key. It blocks until
|
||||
// successful or until the context is canceled (another task completes or the
|
||||
// server shuts down). The error returned is only for testing and diagnostics.
|
||||
func (e *Encrypter) decryptWrappedKeyTask(ctx context.Context, cancel context.CancelFunc, wrapper kms.Wrapper, provider *structs.KEKProviderConfig, meta *structs.RootKeyMeta, wrappedKey *structs.WrappedKey) error {
|
||||
// server shuts down) and the resulting cipher will be sent back via the respCh
|
||||
// channel.
|
||||
//
|
||||
// The error returned is only for testing and diagnostics.
|
||||
func (e *Encrypter) decryptWrappedKeyTask(
|
||||
ctx context.Context, wrapper kms.Wrapper, meta *structs.RootKeyMeta,
|
||||
wrappedKey *structs.WrappedKey, respCh chan *cipherSet) error {
|
||||
|
||||
var key []byte
|
||||
var rsaKey []byte
|
||||
@@ -508,8 +575,10 @@ func (e *Encrypter) decryptWrappedKeyTask(ctx context.Context, cancel context.Ca
|
||||
RSAKey: rsaKey,
|
||||
}
|
||||
|
||||
var generatedCipher *cipherSet
|
||||
|
||||
err = helper.WithBackoffFunc(ctx, minBackoff, maxBackoff, func() error {
|
||||
err := e.addCipher(rootKey)
|
||||
generatedCipher, err = e.generateCipher(rootKey)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("could not add cipher: %w", err)
|
||||
e.log.Error(err.Error(), "key_id", meta.KeyID)
|
||||
@@ -521,40 +590,60 @@ func (e *Encrypter) decryptWrappedKeyTask(ctx context.Context, cancel context.Ca
|
||||
return err
|
||||
}
|
||||
|
||||
e.lock.Lock()
|
||||
defer e.lock.Unlock()
|
||||
cancel()
|
||||
delete(e.decryptTasks, meta.KeyID)
|
||||
// Send the cipher to the response channel or exit if the context is
|
||||
// canceled.
|
||||
//
|
||||
// The context is canceled when the server is shutting down or when another
|
||||
// task decrypting the same key completes.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case respCh <- generatedCipher:
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addCipher creates a new cipherSet for the key and stores them in the keyring
|
||||
func (e *Encrypter) addCipher(rootKey *structs.UnwrappedRootKey) error {
|
||||
|
||||
if rootKey == nil || rootKey.Meta == nil {
|
||||
return fmt.Errorf("missing metadata")
|
||||
generatedCipher, err := e.generateCipher(rootKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var aead cipher.AEAD
|
||||
|
||||
e.keyringLock.Lock()
|
||||
defer e.keyringLock.Unlock()
|
||||
e.keyring[rootKey.Meta.KeyID] = generatedCipher
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Encrypter) generateCipher(rootKey *structs.UnwrappedRootKey) (*cipherSet, error) {
|
||||
|
||||
if rootKey == nil || rootKey.Meta == nil {
|
||||
return nil, fmt.Errorf("missing metadata")
|
||||
}
|
||||
var aeadCipher cipher.AEAD
|
||||
|
||||
switch rootKey.Meta.Algorithm {
|
||||
case structs.EncryptionAlgorithmAES256GCM:
|
||||
block, err := aes.NewCipher(rootKey.Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create cipher: %v", err)
|
||||
return nil, fmt.Errorf("could not create cipher: %v", err)
|
||||
}
|
||||
aead, err = cipher.NewGCM(block)
|
||||
aeadCipher, err = cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create cipher: %v", err)
|
||||
return nil, fmt.Errorf("could not create cipher: %v", err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid algorithm %s", rootKey.Meta.Algorithm)
|
||||
return nil, fmt.Errorf("invalid algorithm %s", rootKey.Meta.Algorithm)
|
||||
}
|
||||
|
||||
ed25519Key := ed25519.NewKeyFromSeed(rootKey.Key)
|
||||
|
||||
cs := cipherSet{
|
||||
rootKey: rootKey,
|
||||
cipher: aead,
|
||||
cipher: aeadCipher,
|
||||
eddsaPrivateKey: ed25519Key,
|
||||
}
|
||||
|
||||
@@ -563,17 +652,14 @@ func (e *Encrypter) addCipher(rootKey *structs.UnwrappedRootKey) error {
|
||||
if len(rootKey.RSAKey) > 0 {
|
||||
rsaKey, err := x509.ParsePKCS1PrivateKey(rootKey.RSAKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing rsa key: %w", err)
|
||||
return nil, fmt.Errorf("error parsing rsa key: %w", err)
|
||||
}
|
||||
|
||||
cs.rsaPrivateKey = rsaKey
|
||||
cs.rsaPKCS1PublicKey = x509.MarshalPKCS1PublicKey(&rsaKey.PublicKey)
|
||||
}
|
||||
|
||||
e.lock.Lock()
|
||||
defer e.lock.Unlock()
|
||||
e.keyring[rootKey.Meta.KeyID] = &cs
|
||||
return nil
|
||||
return &cs, nil
|
||||
}
|
||||
|
||||
// waitForKey retrieves the key material by ID from the keyring, retrying with
|
||||
@@ -583,8 +669,8 @@ func (e *Encrypter) waitForKey(ctx context.Context, keyID string) (*cipherSet, e
|
||||
|
||||
err := helper.WithBackoffFunc(ctx, 50*time.Millisecond, 100*time.Millisecond,
|
||||
func() error {
|
||||
e.lock.RLock()
|
||||
defer e.lock.RUnlock()
|
||||
e.keyringLock.RLock()
|
||||
defer e.keyringLock.RUnlock()
|
||||
var err error
|
||||
ks, err = e.cipherSetByIDLocked(keyID)
|
||||
if err != nil {
|
||||
@@ -612,8 +698,8 @@ func (e *Encrypter) GetActiveKey() (*rsa.PrivateKey, string, error) {
|
||||
|
||||
// GetKey retrieves the key material by ID from the keyring.
|
||||
func (e *Encrypter) GetKey(keyID string) (*structs.UnwrappedRootKey, error) {
|
||||
e.lock.Lock()
|
||||
defer e.lock.Unlock()
|
||||
e.keyringLock.Lock()
|
||||
defer e.keyringLock.Unlock()
|
||||
|
||||
ks, err := e.cipherSetByIDLocked(keyID)
|
||||
if err != nil {
|
||||
@@ -659,8 +745,8 @@ func (e *Encrypter) cipherSetByIDLocked(keyID string) (*cipherSet, error) {
|
||||
|
||||
// RemoveKey removes a key by ID from the keyring
|
||||
func (e *Encrypter) RemoveKey(keyID string) error {
|
||||
e.lock.Lock()
|
||||
defer e.lock.Unlock()
|
||||
e.keyringLock.Lock()
|
||||
defer e.keyringLock.Unlock()
|
||||
delete(e.keyring, keyID)
|
||||
return nil
|
||||
}
|
||||
@@ -883,8 +969,8 @@ func (e *Encrypter) waitForPublicKey(keyID string) (*structs.KeyringPublicKey, e
|
||||
// GetPublicKey returns the public signing key for the requested key id or an
|
||||
// error if the key could not be found.
|
||||
func (e *Encrypter) GetPublicKey(keyID string) (*structs.KeyringPublicKey, error) {
|
||||
e.lock.RLock()
|
||||
defer e.lock.RUnlock()
|
||||
e.keyringLock.RLock()
|
||||
defer e.keyringLock.RUnlock()
|
||||
|
||||
ks, err := e.cipherSetByIDLocked(keyID)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"net/rpc"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
|
||||
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
|
||||
"github.com/hashicorp/nomad/ci"
|
||||
"github.com/hashicorp/nomad/helper"
|
||||
"github.com/hashicorp/nomad/helper/pointer"
|
||||
"github.com/hashicorp/nomad/helper/testlog"
|
||||
"github.com/hashicorp/nomad/helper/uuid"
|
||||
@@ -232,13 +234,13 @@ func TestEncrypter_Restore(t *testing.T) {
|
||||
}
|
||||
|
||||
// Ensure all rotated keys are correct
|
||||
srv.encrypter.lock.Lock()
|
||||
srv.encrypter.keyringLock.Lock()
|
||||
test.MapLen(t, 5, srv.encrypter.keyring)
|
||||
for _, keyset := range srv.encrypter.keyring {
|
||||
test.Len(t, 32, keyset.rootKey.Key)
|
||||
test.Greater(t, 0, len(keyset.rootKey.RSAKey))
|
||||
}
|
||||
srv.encrypter.lock.Unlock()
|
||||
srv.encrypter.keyringLock.Unlock()
|
||||
|
||||
shutdown()
|
||||
|
||||
@@ -261,13 +263,13 @@ func TestEncrypter_Restore(t *testing.T) {
|
||||
return len(listResp.Keys) == 5 // 4 new + the bootstrap key
|
||||
}, time.Second*5, time.Second, "expected keyring to be restored")
|
||||
|
||||
srv.encrypter.lock.Lock()
|
||||
srv.encrypter.keyringLock.Lock()
|
||||
test.MapLen(t, 5, srv.encrypter.keyring)
|
||||
for _, keyset := range srv.encrypter.keyring {
|
||||
test.Len(t, 32, keyset.rootKey.Key)
|
||||
test.Greater(t, 0, len(keyset.rootKey.RSAKey))
|
||||
}
|
||||
srv.encrypter.lock.Unlock()
|
||||
srv.encrypter.keyringLock.Unlock()
|
||||
|
||||
for _, keyMeta := range listResp.Keys {
|
||||
|
||||
@@ -836,7 +838,7 @@ func TestEncrypter_TransitConfigFallback(t *testing.T) {
|
||||
must.Eq(t, expect, providers[2].Config, must.Sprint("expected fallback to env"))
|
||||
}
|
||||
|
||||
func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
|
||||
func TestEncrypter_AddWrappedKey_zeroDecryptTaskError(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
srv := &Server{
|
||||
@@ -844,12 +846,148 @@ func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
|
||||
config: &Config{},
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
encrypter, err := NewEncrypter(srv, t.TempDir())
|
||||
must.NoError(t, err)
|
||||
|
||||
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
|
||||
must.NoError(t, err)
|
||||
|
||||
encrypter, err := NewEncrypter(srv, tmpDir)
|
||||
wrappedKey, err := encrypter.wrapRootKey(key, false)
|
||||
must.NoError(t, err)
|
||||
|
||||
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
t.Cleanup(timeoutCancel)
|
||||
|
||||
must.Error(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
|
||||
must.MapLen(t, 1, encrypter.decryptTasks)
|
||||
must.MapEmpty(t, encrypter.keyring)
|
||||
}
|
||||
|
||||
func TestEncrypter_AddWrappedKey_sameKeyTwice(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
srv := &Server{
|
||||
logger: testlog.HCLogger(t),
|
||||
config: &Config{},
|
||||
}
|
||||
|
||||
encrypter, err := NewEncrypter(srv, t.TempDir())
|
||||
must.NoError(t, err)
|
||||
|
||||
// Create a valid and correctly formatted key and wrap it.
|
||||
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
|
||||
must.NoError(t, err)
|
||||
|
||||
wrappedKey, err := encrypter.wrapRootKey(key, true)
|
||||
must.NoError(t, err)
|
||||
|
||||
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
t.Cleanup(timeoutCancel)
|
||||
|
||||
// Add the wrapped key to the encrypter and assert that the key is added to
|
||||
// the keyring and no decryption tasks are queued.
|
||||
must.NoError(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
|
||||
must.MapEmpty(t, encrypter.decryptTasks)
|
||||
must.NoError(t, encrypter.IsReady(timeoutCtx))
|
||||
must.MapLen(t, 1, encrypter.keyring)
|
||||
must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
|
||||
|
||||
timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
t.Cleanup(timeoutCancel)
|
||||
|
||||
// Add the same key again and assert that the key is not added to the
|
||||
// keyring and no decryption tasks are queued.
|
||||
must.NoError(t, encrypter.AddWrappedKey(timeoutCtx, wrappedKey))
|
||||
must.MapEmpty(t, encrypter.decryptTasks)
|
||||
must.NoError(t, encrypter.IsReady(timeoutCtx))
|
||||
must.MapLen(t, 1, encrypter.keyring)
|
||||
must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
|
||||
}
|
||||
|
||||
func TestEncrypter_AddWrappedKey_sameKeyConcurrent(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
srv := &Server{
|
||||
logger: testlog.HCLogger(t),
|
||||
config: &Config{},
|
||||
}
|
||||
|
||||
encrypter, err := NewEncrypter(srv, t.TempDir())
|
||||
must.NoError(t, err)
|
||||
|
||||
// Create a valid and correctly formatted key and wrap it.
|
||||
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
|
||||
must.NoError(t, err)
|
||||
|
||||
wrappedKey, err := encrypter.wrapRootKey(key, true)
|
||||
must.NoError(t, err)
|
||||
|
||||
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(timeoutCancel)
|
||||
|
||||
// Define the number of concurrent calls to AddWrappedKey. Changing this
|
||||
// value should not affect the correctness of the test.
|
||||
concurrentNum := 10
|
||||
|
||||
// Create a channel to receive the responses from the concurrent calls to
|
||||
// AddWrappedKey. The channel is buffered to ensure that the launched
|
||||
// routines can send to it without blocking.
|
||||
respCh := make(chan error, concurrentNum)
|
||||
|
||||
// Create a channel to control when the concurrent calls to AddWrappedKey
|
||||
// are triggered. When the channel is closed, all waiting routines will
|
||||
// unblock within 0.001 ms of each other.
|
||||
startCh := make(chan struct{})
|
||||
|
||||
// Launch the concurrent calls to AddWrappedKey and wait till they have all
|
||||
// triggered and responded before moving on. The timeout ensures this test
|
||||
// won't deadlock or hang indefinitely.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentNum)
|
||||
|
||||
for i := 0; i < concurrentNum; i++ {
|
||||
go func() {
|
||||
<-startCh
|
||||
respCh <- encrypter.AddWrappedKey(timeoutCtx, wrappedKey)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
close(startCh)
|
||||
wg.Wait()
|
||||
|
||||
// Gather the responses and ensure the encrypter state is as we expect.
|
||||
var respNum int
|
||||
|
||||
for {
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
must.NoError(t, resp)
|
||||
if respNum++; respNum == concurrentNum {
|
||||
must.NoError(t, encrypter.IsReady(timeoutCtx))
|
||||
must.MapEmpty(t, encrypter.decryptTasks)
|
||||
must.MapLen(t, 1, encrypter.keyring)
|
||||
must.MapContainsKey(t, encrypter.keyring, key.Meta.KeyID)
|
||||
return
|
||||
}
|
||||
case <-timeoutCtx.Done():
|
||||
must.NoError(t, timeoutCtx.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypter_decryptWrappedKeyTask_successful(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
srv := &Server{
|
||||
logger: testlog.HCLogger(t),
|
||||
config: &Config{},
|
||||
}
|
||||
|
||||
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
|
||||
must.NoError(t, err)
|
||||
|
||||
encrypter, err := NewEncrypter(srv, t.TempDir())
|
||||
must.NoError(t, err)
|
||||
|
||||
wrappedKey, err := encrypter.encryptDEK(key, &structs.KEKProviderConfig{})
|
||||
@@ -868,11 +1006,103 @@ func TestEncrypter_decryptWrappedKeyTask(t *testing.T) {
|
||||
must.NoError(t, err)
|
||||
must.NotNil(t, KMSWrapper)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = encrypter.decryptWrappedKeyTask(ctx, cancel, KMSWrapper, provider, key.Meta, wrappedKey)
|
||||
respCh := make(chan *cipherSet)
|
||||
|
||||
go encrypter.decryptWrappedKeyTask(ctx, KMSWrapper, key.Meta, wrappedKey, respCh)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for decryptWrappedKeyTask to complete")
|
||||
case cipherResp := <-respCh:
|
||||
must.NotNil(t, cipherResp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypter_decryptWrappedKeyTask_contextCancel(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
srv := &Server{
|
||||
logger: testlog.HCLogger(t),
|
||||
config: &Config{},
|
||||
}
|
||||
|
||||
encrypter, err := NewEncrypter(srv, t.TempDir())
|
||||
must.NoError(t, err)
|
||||
|
||||
// Create a valid and correctly formatted key and wrap it.
|
||||
key, err := structs.NewUnwrappedRootKey(structs.EncryptionAlgorithmAES256GCM)
|
||||
must.NoError(t, err)
|
||||
|
||||
wrappedKey, err := encrypter.encryptDEK(key, &structs.KEKProviderConfig{})
|
||||
must.NotNil(t, wrappedKey)
|
||||
must.NoError(t, err)
|
||||
|
||||
// Prepare the KMS wrapper and the response channel, so we can call
|
||||
// decryptWrappedKeyTask. Use a buffered channel, so the decrypt task does
|
||||
// not block on a send.
|
||||
provider, ok := encrypter.providerConfigs[string(structs.KEKProviderAEAD)]
|
||||
must.True(t, ok)
|
||||
must.NotNil(t, provider)
|
||||
|
||||
kmsWrapper, err := encrypter.newKMSWrapper(provider, key.Meta.KeyID, wrappedKey.KeyEncryptionKey)
|
||||
must.NoError(t, err)
|
||||
must.NotNil(t, kmsWrapper)
|
||||
|
||||
respCh := make(chan *cipherSet, 1)
|
||||
|
||||
// Generate a context and immediately cancel it.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
// Ensure we receive an error indicating we hit the context done case and
|
||||
// check no cipher response was sent.
|
||||
err = encrypter.decryptWrappedKeyTask(ctx, kmsWrapper, key.Meta, wrappedKey, respCh)
|
||||
must.ErrorContains(t, err, "operation cancelled")
|
||||
must.Eq(t, 0, len(respCh))
|
||||
|
||||
// Recreate the response channel so that it is no longer buffered. The
|
||||
// decrypt task should now block on attempting to send to it.
|
||||
respCh = make(chan *cipherSet)
|
||||
|
||||
// Generate a new context and an error channel so we can gather the response
|
||||
// of decryptWrappedKeyTask running inside a goroutine.
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
errorCh := make(chan error, 1)
|
||||
|
||||
// Launch the decryptWrappedKeyTask routine.
|
||||
go func() {
|
||||
err := encrypter.decryptWrappedKeyTask(ctx, kmsWrapper, key.Meta, wrappedKey, respCh)
|
||||
errorCh <- err
|
||||
}()
|
||||
|
||||
// Roughly ensure the decrypt task is running for enough time to get past
|
||||
// the cipher generation. This is so that when we cancel the context, we
|
||||
// have passed the helper.Backoff functions, which are also designed to exit
|
||||
// and return if the context is canceled. As Tim correctly pointed out; this
|
||||
// "is about giving this test a fighting chance to be testing the thing we
|
||||
// think it is".
|
||||
//
|
||||
// Canceling the context should cause the routine to exit and send an error
|
||||
// which we can check to ensure we correctly unblock.
|
||||
timer, timerStop := helper.NewSafeTimer(500 * time.Millisecond)
|
||||
defer timerStop()
|
||||
|
||||
<-timer.C
|
||||
cancel()
|
||||
|
||||
timer, timerStop = helper.NewSafeTimer(200 * time.Millisecond)
|
||||
defer timerStop()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("timed out waiting for decryptWrappedKeyTask to send its error")
|
||||
case err := <-errorCh:
|
||||
must.ErrorContains(t, err, "context canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypter_AddWrappedKey_noWrappedKeys(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user