acl: make sure there is only one default Auth Method per type (#15504)

This PR adds a check that makes sure we don't insert a duplicate default ACL auth method for a given type.
This commit is contained in:
Piotr Kazmierczak
2022-12-09 14:46:54 +01:00
committed by GitHub
parent 605597ffd0
commit f4e89e2895
5 changed files with 107 additions and 10 deletions

View File

@@ -1709,6 +1709,12 @@ func (a *ACL) UpsertAuthMethods(
return structs.NewErrRPCCoded(http.StatusBadRequest, "must specify as least one auth method")
}
// Snapshot the state so we can make lookups to verify default method
stateSnapshot, err := a.srv.State().Snapshot()
if err != nil {
return err
}
// Validate each auth method, canonicalize, and compute hash
for idx, authMethod := range args.AuthMethods {
if err := authMethod.Validate(
@@ -1716,6 +1722,18 @@ func (a *ACL) UpsertAuthMethods(
a.srv.config.ACLTokenMaxExpirationTTL); err != nil {
return structs.NewErrRPCCodedf(http.StatusBadRequest, "auth method %d invalid: %v", idx, err)
}
// Are we trying to upsert a default auth method? Check if there isn't
// a default one for that very type already.
if authMethod.Default {
existingMethodsDefaultmethod, _ := stateSnapshot.GetDefaultACLAuthMethodByType(nil, authMethod.Type)
if existingMethodsDefaultmethod != nil {
return structs.NewErrRPCCodedf(
http.StatusBadRequest,
"default method for type %s already exists: %v", authMethod.Type, existingMethodsDefaultmethod.Name,
)
}
}
authMethod.Canonicalize()
authMethod.SetHash()
}
@@ -1733,7 +1751,7 @@ func (a *ACL) UpsertAuthMethods(
// Populate the response. We do a lookup against the state to pick up the
// proper create / modify times.
stateSnapshot, err := a.srv.State().Snapshot()
stateSnapshot, err = a.srv.State().Snapshot()
if err != nil {
return err
}

View File

@@ -3022,6 +3022,7 @@ func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) {
// Create the register request
am1 := mock.ACLAuthMethod()
am1.Default = true // make sure it's going to be a default method
// Lookup the authMethods
req := &structs.ACLAuthMethodUpsertRequest{
@@ -3032,14 +3033,26 @@ func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) {
},
}
var resp structs.ACLAuthMethodUpsertResponse
if err := msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp); err != nil {
t.Fatalf("err: %v", err)
}
must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp))
must.NotEq(t, uint64(0), resp.Index)
// Check we created the authMethod
out, err := s1.fsm.State().GetACLAuthMethodByName(nil, am1.Name)
must.Nil(t, err)
must.NotNil(t, out)
must.NotEq(t, 0, len(resp.AuthMethods))
must.True(t, am1.Equal(resp.AuthMethods[0]))
// Try to insert another default authMethod
am2 := mock.ACLAuthMethod()
am2.Default = true
req = &structs.ACLAuthMethodUpsertRequest{
AuthMethods: []*structs.ACLAuthMethod{am2},
WriteRequest: structs.WriteRequest{
Region: "global",
AuthToken: root.SecretID,
},
}
// We expect this to err since there's already a default method of the same type
must.Error(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp))
}

View File

@@ -226,7 +226,7 @@ func ACLAuthMethod() *structs.ACLAuthMethod {
Type: "OIDC",
TokenLocality: "local",
MaxTokenTTL: maxTokenTTL,
Default: true,
Default: false,
Config: &structs.ACLAuthMethodConfig{
OIDCDiscoveryURL: "http://example.com",
OIDCClientID: "mock",

View File

@@ -61,9 +61,18 @@ func (s *StateStore) upsertACLAuthMethodTxn(index uint64, txn *txn, method *stru
// This validation also happens within the RPC handler, but Raft latency
// could mean that by the time the state call is invoked, another Raft
// update has already written a method with the same name. We therefore
// need to check we are not trying to create a method with an existing
// name.
// update has already written a method with the same name or default
// setting. We therefore need to check we are not trying to create a method
// with an existing name or a duplicate default for the same type.
if method.Default {
existingMethodsDefaultmethod, _ := s.GetDefaultACLAuthMethodByType(nil, method.Type)
if existingMethodsDefaultmethod != nil {
return false, fmt.Errorf(
"default ACL auth method for type %s already exists: %v",
method.Type, existingMethodsDefaultmethod.Name,
)
}
}
existingRaw, err := txn.First(TableACLAuthMethods, indexID, method.Name)
if err != nil {
return false, fmt.Errorf("ACL auth method lookup failed: %v", err)
@@ -176,3 +185,34 @@ func (s *StateStore) GetACLAuthMethodByName(ws memdb.WatchSet, authMethod string
}
return nil, nil
}
// GetDefaultACLAuthMethodByType returns a default ACL Auth Methods for a given
// auth type. Since we only want 1 default auth method per type, this function
// is used during upserts to facilitate that check.
func (s *StateStore) GetDefaultACLAuthMethodByType(ws memdb.WatchSet, methodType string) (*structs.ACLAuthMethod, error) {
txn := s.db.ReadTxn()
// Walk the entire table to get all ACL auth methods.
iter, err := txn.Get(TableACLAuthMethods, indexID)
if err != nil {
return nil, fmt.Errorf("ACL auth method lookup failed: %v", err)
}
ws.Add(iter.WatchCh())
// Filter out non-default methods
filter := memdb.NewFilterIterator(iter, func(raw interface{}) bool {
method, ok := raw.(*structs.ACLAuthMethod)
if !ok {
return true
}
// any non-default method or method of different type than desired gets filtered-out
return !method.Default || method.Type != methodType
})
for raw := filter.Next(); raw != nil; raw = filter.Next() {
method := raw.(*structs.ACLAuthMethod)
return method, nil
}
return nil, nil
}

View File

@@ -4,11 +4,10 @@ import (
"testing"
"github.com/hashicorp/go-memdb"
"github.com/shoenig/test/must"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/shoenig/test/must"
)
func TestStateStore_UpsertACLAuthMethods(t *testing.T) {
@@ -227,3 +226,30 @@ func TestStateStore_GetACLAuthMethodByName(t *testing.T) {
must.NoError(t, err)
must.Equal(t, mockedACLAuthMethods[1], authMethod)
}
func TestStateStore_GetDefaultACLAuthMethodByType(t *testing.T) {
ci.Parallel(t)
testState := testStateStore(t)
// Generate 2 auth methods, make one of them default
am1 := mock.ACLAuthMethod()
am1.Default = true
am2 := mock.ACLAuthMethod()
// upsert
mockedACLAuthMethods := []*structs.ACLAuthMethod{am1, am2}
must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
// Get the default method for OIDC
ws := memdb.NewWatchSet()
defaultOIDCMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "OIDC")
must.NoError(t, err)
must.True(t, defaultOIDCMethod.Default)
must.Eq(t, am1, defaultOIDCMethod)
// Get the default method for jwt (should not return anything)
defaultJWTMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "JWT")
must.NoError(t, err)
must.Nil(t, defaultJWTMethod)
}