diff --git a/helper/funcs.go b/helper/funcs.go index b37450f18..612dfa7a4 100644 --- a/helper/funcs.go +++ b/helper/funcs.go @@ -485,3 +485,13 @@ func WithLock(lock sync.Locker, f func()) { defer lock.Unlock() f() } + +// Merge takes two variables and returns variable b in case a has zero value. +// For pointer values please use pointer.Merge. +func Merge[T comparable](a, b T) T { + var zero T + if a == zero { + return b + } + return a +} diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index ef6ed6e0e..e1e057142 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -1716,7 +1716,13 @@ func (a *ACL) UpsertAuthMethods( } // Validate each auth method, canonicalize, and compute hash + // merge methods in case we're doing an update for idx, authMethod := range args.AuthMethods { + // if there's an existing method with the same name, we treat this as + // an update + existingMethod, _ := stateSnapshot.GetACLAuthMethodByName(nil, authMethod.Name) + authMethod.Merge(existingMethod) + if err := authMethod.Validate( a.srv.config.ACLTokenMinExpirationTTL, a.srv.config.ACLTokenMaxExpirationTTL); err != nil { @@ -2051,10 +2057,6 @@ func (a *ACL) UpsertBindingRules( // Validate each binding rules and compute the hash. for idx, bindingRule := range args.ACLBindingRules { - if err := bindingRule.Validate(); err != nil { - return structs.NewErrRPCCodedf(http.StatusBadRequest, "binding rule %d invalid: %v", idx, err) - } - // If the caller has passed a rule ID, this call is considered an // update to an existing rule. We should therefore ensure it is found // within state. @@ -2068,6 +2070,22 @@ func (a *ACL) UpsertBindingRules( return structs.NewErrRPCCodedf( http.StatusBadRequest, "cannot find binding rule %s", bindingRule.ID) } + + // merge + bindingRule.Merge(existingBindingRule) + + // Auth methods cannot be changed + if bindingRule.AuthMethod != existingBindingRule.AuthMethod { + return structs.NewErrRPCCoded( + http.StatusBadRequest, "cannot update auth method for binding rule, create a new rule instead", + ) + } + bindingRule.AuthMethod = existingBindingRule.AuthMethod + } + + // Validate only if it's not an update + if err := bindingRule.Validate(); err != nil { + return structs.NewErrRPCCodedf(http.StatusBadRequest, "binding rule %d invalid: %v", idx, err) } // Ensure the auth method linked to exists within state. diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index c578fc25d..b13a19ddd 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -3056,6 +3056,18 @@ func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) { } // We expect this to err since there's already a default method of the same type must.Error(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp)) + + // Update token locality + am3 := &structs.ACLAuthMethod{Name: am1.Name, TokenLocality: "global"} + req = &structs.ACLAuthMethodUpsertRequest{ + AuthMethods: []*structs.ACLAuthMethod{am3}, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: root.SecretID, + }, + } + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp)) + must.Eq(t, resp.AuthMethods[0].TokenLocality, am3.TokenLocality) } func TestACL_UpsertBindingRules(t *testing.T) { diff --git a/nomad/structs/acl.go b/nomad/structs/acl.go index 2ab629c23..f0eeb962d 100644 --- a/nomad/structs/acl.go +++ b/nomad/structs/acl.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-set" + "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/uuid" "golang.org/x/crypto/blake2b" @@ -758,6 +759,18 @@ func (a *ACLAuthMethod) Canonicalize() { a.ModifyTime = t } +// Merge merges auth method a with method b. It sets all required empty fields +// of method a to corresponding values of method b, except for "default" and +// "name." +func (a *ACLAuthMethod) Merge(b *ACLAuthMethod) { + if b != nil { + a.Type = helper.Merge(a.Type, b.Type) + a.TokenLocality = helper.Merge(a.TokenLocality, b.TokenLocality) + a.MaxTokenTTL = helper.Merge(a.MaxTokenTTL, b.MaxTokenTTL) + a.Config = helper.Merge(a.Config, b.Config) + } +} + // Validate returns an error is the ACLAuthMethod is invalid. // // TODO revisit possible other validity conditions in the future @@ -992,8 +1005,6 @@ func (a *ACLBindingRule) Validate() error { mErr.Errors = append(mErr.Errors, fmt.Errorf("description longer than %d", maxACLRoleDescriptionLength)) } - // Be specific about the error as returning an error that includes an empty - // quote ("") can be a little confusing. if a.BindType == "" { mErr.Errors = append(mErr.Errors, errors.New("bind type is missing")) } else { @@ -1007,6 +1018,14 @@ func (a *ACLBindingRule) Validate() error { return mErr.ErrorOrNil() } +// Merge merges binding rule a with b. It sets all required empty fields of rule +// a to corresponding values of rule b, except for "ID" which must be provided. +func (a *ACLBindingRule) Merge(b *ACLBindingRule) { + a.BindName = helper.Merge(a.BindName, b.BindName) + a.BindType = helper.Merge(a.BindType, b.BindType) + a.AuthMethod = helper.Merge(a.AuthMethod, b.AuthMethod) +} + // SetHash is used to compute and set the hash of the ACL binding rule. This // should be called every and each time a user specified field on the method is // changed before updating the Nomad state store. diff --git a/nomad/structs/acl_test.go b/nomad/structs/acl_test.go index 7700519ec..b80184ec2 100644 --- a/nomad/structs/acl_test.go +++ b/nomad/structs/acl_test.go @@ -713,7 +713,7 @@ func TestACLRole_Copy(t *testing.T) { { name: "general 1", inputACLRole: &ACLRole{ - Name: fmt.Sprintf("acl-role"), + Name: "acl-role", Description: "mocked-test-acl-role", Policies: []*ACLRolePolicyLink{ {Name: "mocked-test-policy-1"}, @@ -1039,6 +1039,45 @@ func TestACLAuthMethod_Validate(t *testing.T) { } } +func TestACLAuthMethod_Merge(t *testing.T) { + ci.Parallel(t) + + name := fmt.Sprintf("acl-auth-method-%s", uuid.Short()) + + maxTokenTTL, _ := time.ParseDuration("3600s") + am1 := &ACLAuthMethod{ + Name: name, + TokenLocality: "global", + } + am2 := &ACLAuthMethod{ + Name: name, + Type: "OIDC", + TokenLocality: "locality", + MaxTokenTTL: maxTokenTTL, + Default: true, + Config: &ACLAuthMethodConfig{ + OIDCDiscoveryURL: "http://example.com", + OIDCClientID: "mock", + OIDCClientSecret: "very secret secret", + BoundAudiences: []string{"audience1", "audience2"}, + AllowedRedirectURIs: []string{"foo", "bar"}, + DiscoveryCaPem: []string{"foo"}, + SigningAlgs: []string{"bar"}, + ClaimMappings: map[string]string{"foo": "bar"}, + ListClaimMappings: map[string]string{"foo": "bar"}, + }, + CreateTime: time.Now().UTC(), + CreateIndex: 10, + ModifyIndex: 10, + } + + am1.Merge(am2) + must.Eq(t, am1.TokenLocality, "global") + minTTL, _ := time.ParseDuration("10s") + maxTTL, _ := time.ParseDuration("10h") + must.NoError(t, am1.Validate(minTTL, maxTTL)) +} + func TestACLAuthMethodConfig_Copy(t *testing.T) { ci.Parallel(t) @@ -1163,6 +1202,31 @@ func TestACLBindingRule_Validate(t *testing.T) { must.StrContains(t, err.Error(), `unsupported bind type: "service"`) } +func TestACLBindingRule_Merge(t *testing.T) { + ci.Parallel(t) + + id := uuid.Short() + br := &ACLBindingRule{ + ID: id, + Description: "old description", + AuthMethod: "example-acl-auth-method", + BindType: "rule", + BindName: "bind name", + CreateTime: time.Now().UTC(), + CreateIndex: 10, + ModifyIndex: 10, + } + + // make a description update + br_description_update := &ACLBindingRule{ + ID: id, + Description: "new description", + } + br_description_update.Merge(br) + must.Eq(t, br_description_update.Description, "new description") + must.Eq(t, br_description_update.BindType, "rule") +} + func TestACLBindingRule_SetHash(t *testing.T) { ci.Parallel(t)