From f4e89e2895b1d6a5aa270ce27483d4bc2cc6d97a Mon Sep 17 00:00:00 2001 From: Piotr Kazmierczak <470696+pkazmierczak@users.noreply.github.com> Date: Fri, 9 Dec 2022 14:46:54 +0100 Subject: [PATCH] acl: make sure there is only one default Auth Method per type (#15504) This PR adds a check that makes sure we don't insert a duplicate default ACL auth method for a given type. --- nomad/acl_endpoint.go | 20 ++++++++++- nomad/acl_endpoint_test.go | 19 ++++++++-- nomad/mock/acl.go | 2 +- nomad/state/state_store_acl_sso.go | 46 +++++++++++++++++++++++-- nomad/state/state_store_acl_sso_test.go | 30 ++++++++++++++-- 5 files changed, 107 insertions(+), 10 deletions(-) diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index e06b414fe..d18ac3128 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -1709,6 +1709,12 @@ func (a *ACL) UpsertAuthMethods( return structs.NewErrRPCCoded(http.StatusBadRequest, "must specify as least one auth method") } + // Snapshot the state so we can make lookups to verify default method + stateSnapshot, err := a.srv.State().Snapshot() + if err != nil { + return err + } + // Validate each auth method, canonicalize, and compute hash for idx, authMethod := range args.AuthMethods { if err := authMethod.Validate( @@ -1716,6 +1722,18 @@ func (a *ACL) UpsertAuthMethods( a.srv.config.ACLTokenMaxExpirationTTL); err != nil { return structs.NewErrRPCCodedf(http.StatusBadRequest, "auth method %d invalid: %v", idx, err) } + + // Are we trying to upsert a default auth method? Check if there isn't + // a default one for that very type already. + if authMethod.Default { + existingMethodsDefaultmethod, _ := stateSnapshot.GetDefaultACLAuthMethodByType(nil, authMethod.Type) + if existingMethodsDefaultmethod != nil { + return structs.NewErrRPCCodedf( + http.StatusBadRequest, + "default method for type %s already exists: %v", authMethod.Type, existingMethodsDefaultmethod.Name, + ) + } + } authMethod.Canonicalize() authMethod.SetHash() } @@ -1733,7 +1751,7 @@ func (a *ACL) UpsertAuthMethods( // Populate the response. We do a lookup against the state to pick up the // proper create / modify times. - stateSnapshot, err := a.srv.State().Snapshot() + stateSnapshot, err = a.srv.State().Snapshot() if err != nil { return err } diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index 8c6efe05b..790d28f1b 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -3022,6 +3022,7 @@ func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) { // Create the register request am1 := mock.ACLAuthMethod() + am1.Default = true // make sure it's going to be a default method // Lookup the authMethods req := &structs.ACLAuthMethodUpsertRequest{ @@ -3032,14 +3033,26 @@ func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) { }, } var resp structs.ACLAuthMethodUpsertResponse - if err := msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp); err != nil { - t.Fatalf("err: %v", err) - } + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp)) must.NotEq(t, uint64(0), resp.Index) // Check we created the authMethod out, err := s1.fsm.State().GetACLAuthMethodByName(nil, am1.Name) must.Nil(t, err) must.NotNil(t, out) + must.NotEq(t, 0, len(resp.AuthMethods)) must.True(t, am1.Equal(resp.AuthMethods[0])) + + // Try to insert another default authMethod + am2 := mock.ACLAuthMethod() + am2.Default = true + req = &structs.ACLAuthMethodUpsertRequest{ + AuthMethods: []*structs.ACLAuthMethod{am2}, + WriteRequest: structs.WriteRequest{ + Region: "global", + AuthToken: root.SecretID, + }, + } + // We expect this to err since there's already a default method of the same type + must.Error(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp)) } diff --git a/nomad/mock/acl.go b/nomad/mock/acl.go index 5f34f5c86..bdb460490 100644 --- a/nomad/mock/acl.go +++ b/nomad/mock/acl.go @@ -226,7 +226,7 @@ func ACLAuthMethod() *structs.ACLAuthMethod { Type: "OIDC", TokenLocality: "local", MaxTokenTTL: maxTokenTTL, - Default: true, + Default: false, Config: &structs.ACLAuthMethodConfig{ OIDCDiscoveryURL: "http://example.com", OIDCClientID: "mock", diff --git a/nomad/state/state_store_acl_sso.go b/nomad/state/state_store_acl_sso.go index 8c8016c3e..337df25cb 100644 --- a/nomad/state/state_store_acl_sso.go +++ b/nomad/state/state_store_acl_sso.go @@ -61,9 +61,18 @@ func (s *StateStore) upsertACLAuthMethodTxn(index uint64, txn *txn, method *stru // This validation also happens within the RPC handler, but Raft latency // could mean that by the time the state call is invoked, another Raft - // update has already written a method with the same name. We therefore - // need to check we are not trying to create a method with an existing - // name. + // update has already written a method with the same name or default + // setting. We therefore need to check we are not trying to create a method + // with an existing name or a duplicate default for the same type. + if method.Default { + existingMethodsDefaultmethod, _ := s.GetDefaultACLAuthMethodByType(nil, method.Type) + if existingMethodsDefaultmethod != nil { + return false, fmt.Errorf( + "default ACL auth method for type %s already exists: %v", + method.Type, existingMethodsDefaultmethod.Name, + ) + } + } existingRaw, err := txn.First(TableACLAuthMethods, indexID, method.Name) if err != nil { return false, fmt.Errorf("ACL auth method lookup failed: %v", err) @@ -176,3 +185,34 @@ func (s *StateStore) GetACLAuthMethodByName(ws memdb.WatchSet, authMethod string } return nil, nil } + +// GetDefaultACLAuthMethodByType returns a default ACL Auth Methods for a given +// auth type. Since we only want 1 default auth method per type, this function +// is used during upserts to facilitate that check. +func (s *StateStore) GetDefaultACLAuthMethodByType(ws memdb.WatchSet, methodType string) (*structs.ACLAuthMethod, error) { + txn := s.db.ReadTxn() + + // Walk the entire table to get all ACL auth methods. + iter, err := txn.Get(TableACLAuthMethods, indexID) + if err != nil { + return nil, fmt.Errorf("ACL auth method lookup failed: %v", err) + } + ws.Add(iter.WatchCh()) + + // Filter out non-default methods + filter := memdb.NewFilterIterator(iter, func(raw interface{}) bool { + method, ok := raw.(*structs.ACLAuthMethod) + if !ok { + return true + } + // any non-default method or method of different type than desired gets filtered-out + return !method.Default || method.Type != methodType + }) + + for raw := filter.Next(); raw != nil; raw = filter.Next() { + method := raw.(*structs.ACLAuthMethod) + return method, nil + } + + return nil, nil +} diff --git a/nomad/state/state_store_acl_sso_test.go b/nomad/state/state_store_acl_sso_test.go index 38e9666df..58a762d12 100644 --- a/nomad/state/state_store_acl_sso_test.go +++ b/nomad/state/state_store_acl_sso_test.go @@ -4,11 +4,10 @@ import ( "testing" "github.com/hashicorp/go-memdb" - "github.com/shoenig/test/must" - "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" ) func TestStateStore_UpsertACLAuthMethods(t *testing.T) { @@ -227,3 +226,30 @@ func TestStateStore_GetACLAuthMethodByName(t *testing.T) { must.NoError(t, err) must.Equal(t, mockedACLAuthMethods[1], authMethod) } + +func TestStateStore_GetDefaultACLAuthMethodByType(t *testing.T) { + ci.Parallel(t) + testState := testStateStore(t) + + // Generate 2 auth methods, make one of them default + am1 := mock.ACLAuthMethod() + am1.Default = true + am2 := mock.ACLAuthMethod() + + // upsert + mockedACLAuthMethods := []*structs.ACLAuthMethod{am1, am2} + must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods)) + + // Get the default method for OIDC + ws := memdb.NewWatchSet() + defaultOIDCMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "OIDC") + must.NoError(t, err) + + must.True(t, defaultOIDCMethod.Default) + must.Eq(t, am1, defaultOIDCMethod) + + // Get the default method for jwt (should not return anything) + defaultJWTMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "JWT") + must.NoError(t, err) + must.Nil(t, defaultJWTMethod) +}