mirror of
https://github.com/kemko/nomad.git
synced 2026-01-08 03:15:42 +03:00
acl: make sure there is only one default Auth Method per type (#15504)
This PR adds a check that makes sure we don't insert a duplicate default ACL auth method for a given type.
This commit is contained in:
committed by
GitHub
parent
605597ffd0
commit
f4e89e2895
@@ -1709,6 +1709,12 @@ func (a *ACL) UpsertAuthMethods(
|
|||||||
return structs.NewErrRPCCoded(http.StatusBadRequest, "must specify as least one auth method")
|
return structs.NewErrRPCCoded(http.StatusBadRequest, "must specify as least one auth method")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Snapshot the state so we can make lookups to verify default method
|
||||||
|
stateSnapshot, err := a.srv.State().Snapshot()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Validate each auth method, canonicalize, and compute hash
|
// Validate each auth method, canonicalize, and compute hash
|
||||||
for idx, authMethod := range args.AuthMethods {
|
for idx, authMethod := range args.AuthMethods {
|
||||||
if err := authMethod.Validate(
|
if err := authMethod.Validate(
|
||||||
@@ -1716,6 +1722,18 @@ func (a *ACL) UpsertAuthMethods(
|
|||||||
a.srv.config.ACLTokenMaxExpirationTTL); err != nil {
|
a.srv.config.ACLTokenMaxExpirationTTL); err != nil {
|
||||||
return structs.NewErrRPCCodedf(http.StatusBadRequest, "auth method %d invalid: %v", idx, err)
|
return structs.NewErrRPCCodedf(http.StatusBadRequest, "auth method %d invalid: %v", idx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Are we trying to upsert a default auth method? Check if there isn't
|
||||||
|
// a default one for that very type already.
|
||||||
|
if authMethod.Default {
|
||||||
|
existingMethodsDefaultmethod, _ := stateSnapshot.GetDefaultACLAuthMethodByType(nil, authMethod.Type)
|
||||||
|
if existingMethodsDefaultmethod != nil {
|
||||||
|
return structs.NewErrRPCCodedf(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
"default method for type %s already exists: %v", authMethod.Type, existingMethodsDefaultmethod.Name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
authMethod.Canonicalize()
|
authMethod.Canonicalize()
|
||||||
authMethod.SetHash()
|
authMethod.SetHash()
|
||||||
}
|
}
|
||||||
@@ -1733,7 +1751,7 @@ func (a *ACL) UpsertAuthMethods(
|
|||||||
|
|
||||||
// Populate the response. We do a lookup against the state to pick up the
|
// Populate the response. We do a lookup against the state to pick up the
|
||||||
// proper create / modify times.
|
// proper create / modify times.
|
||||||
stateSnapshot, err := a.srv.State().Snapshot()
|
stateSnapshot, err = a.srv.State().Snapshot()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3022,6 +3022,7 @@ func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) {
|
|||||||
|
|
||||||
// Create the register request
|
// Create the register request
|
||||||
am1 := mock.ACLAuthMethod()
|
am1 := mock.ACLAuthMethod()
|
||||||
|
am1.Default = true // make sure it's going to be a default method
|
||||||
|
|
||||||
// Lookup the authMethods
|
// Lookup the authMethods
|
||||||
req := &structs.ACLAuthMethodUpsertRequest{
|
req := &structs.ACLAuthMethodUpsertRequest{
|
||||||
@@ -3032,14 +3033,26 @@ func TestACLEndpoint_UpsertACLAuthMethods(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
var resp structs.ACLAuthMethodUpsertResponse
|
var resp structs.ACLAuthMethodUpsertResponse
|
||||||
if err := msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp); err != nil {
|
must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp))
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
must.NotEq(t, uint64(0), resp.Index)
|
must.NotEq(t, uint64(0), resp.Index)
|
||||||
|
|
||||||
// Check we created the authMethod
|
// Check we created the authMethod
|
||||||
out, err := s1.fsm.State().GetACLAuthMethodByName(nil, am1.Name)
|
out, err := s1.fsm.State().GetACLAuthMethodByName(nil, am1.Name)
|
||||||
must.Nil(t, err)
|
must.Nil(t, err)
|
||||||
must.NotNil(t, out)
|
must.NotNil(t, out)
|
||||||
|
must.NotEq(t, 0, len(resp.AuthMethods))
|
||||||
must.True(t, am1.Equal(resp.AuthMethods[0]))
|
must.True(t, am1.Equal(resp.AuthMethods[0]))
|
||||||
|
|
||||||
|
// Try to insert another default authMethod
|
||||||
|
am2 := mock.ACLAuthMethod()
|
||||||
|
am2.Default = true
|
||||||
|
req = &structs.ACLAuthMethodUpsertRequest{
|
||||||
|
AuthMethods: []*structs.ACLAuthMethod{am2},
|
||||||
|
WriteRequest: structs.WriteRequest{
|
||||||
|
Region: "global",
|
||||||
|
AuthToken: root.SecretID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// We expect this to err since there's already a default method of the same type
|
||||||
|
must.Error(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertAuthMethodsRPCMethod, req, &resp))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ func ACLAuthMethod() *structs.ACLAuthMethod {
|
|||||||
Type: "OIDC",
|
Type: "OIDC",
|
||||||
TokenLocality: "local",
|
TokenLocality: "local",
|
||||||
MaxTokenTTL: maxTokenTTL,
|
MaxTokenTTL: maxTokenTTL,
|
||||||
Default: true,
|
Default: false,
|
||||||
Config: &structs.ACLAuthMethodConfig{
|
Config: &structs.ACLAuthMethodConfig{
|
||||||
OIDCDiscoveryURL: "http://example.com",
|
OIDCDiscoveryURL: "http://example.com",
|
||||||
OIDCClientID: "mock",
|
OIDCClientID: "mock",
|
||||||
|
|||||||
@@ -61,9 +61,18 @@ func (s *StateStore) upsertACLAuthMethodTxn(index uint64, txn *txn, method *stru
|
|||||||
|
|
||||||
// This validation also happens within the RPC handler, but Raft latency
|
// This validation also happens within the RPC handler, but Raft latency
|
||||||
// could mean that by the time the state call is invoked, another Raft
|
// could mean that by the time the state call is invoked, another Raft
|
||||||
// update has already written a method with the same name. We therefore
|
// update has already written a method with the same name or default
|
||||||
// need to check we are not trying to create a method with an existing
|
// setting. We therefore need to check we are not trying to create a method
|
||||||
// name.
|
// with an existing name or a duplicate default for the same type.
|
||||||
|
if method.Default {
|
||||||
|
existingMethodsDefaultmethod, _ := s.GetDefaultACLAuthMethodByType(nil, method.Type)
|
||||||
|
if existingMethodsDefaultmethod != nil {
|
||||||
|
return false, fmt.Errorf(
|
||||||
|
"default ACL auth method for type %s already exists: %v",
|
||||||
|
method.Type, existingMethodsDefaultmethod.Name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
existingRaw, err := txn.First(TableACLAuthMethods, indexID, method.Name)
|
existingRaw, err := txn.First(TableACLAuthMethods, indexID, method.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("ACL auth method lookup failed: %v", err)
|
return false, fmt.Errorf("ACL auth method lookup failed: %v", err)
|
||||||
@@ -176,3 +185,34 @@ func (s *StateStore) GetACLAuthMethodByName(ws memdb.WatchSet, authMethod string
|
|||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDefaultACLAuthMethodByType returns a default ACL Auth Methods for a given
|
||||||
|
// auth type. Since we only want 1 default auth method per type, this function
|
||||||
|
// is used during upserts to facilitate that check.
|
||||||
|
func (s *StateStore) GetDefaultACLAuthMethodByType(ws memdb.WatchSet, methodType string) (*structs.ACLAuthMethod, error) {
|
||||||
|
txn := s.db.ReadTxn()
|
||||||
|
|
||||||
|
// Walk the entire table to get all ACL auth methods.
|
||||||
|
iter, err := txn.Get(TableACLAuthMethods, indexID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ACL auth method lookup failed: %v", err)
|
||||||
|
}
|
||||||
|
ws.Add(iter.WatchCh())
|
||||||
|
|
||||||
|
// Filter out non-default methods
|
||||||
|
filter := memdb.NewFilterIterator(iter, func(raw interface{}) bool {
|
||||||
|
method, ok := raw.(*structs.ACLAuthMethod)
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// any non-default method or method of different type than desired gets filtered-out
|
||||||
|
return !method.Default || method.Type != methodType
|
||||||
|
})
|
||||||
|
|
||||||
|
for raw := filter.Next(); raw != nil; raw = filter.Next() {
|
||||||
|
method := raw.(*structs.ACLAuthMethod)
|
||||||
|
return method, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,11 +4,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/go-memdb"
|
"github.com/hashicorp/go-memdb"
|
||||||
"github.com/shoenig/test/must"
|
|
||||||
|
|
||||||
"github.com/hashicorp/nomad/ci"
|
"github.com/hashicorp/nomad/ci"
|
||||||
"github.com/hashicorp/nomad/nomad/mock"
|
"github.com/hashicorp/nomad/nomad/mock"
|
||||||
"github.com/hashicorp/nomad/nomad/structs"
|
"github.com/hashicorp/nomad/nomad/structs"
|
||||||
|
"github.com/shoenig/test/must"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStateStore_UpsertACLAuthMethods(t *testing.T) {
|
func TestStateStore_UpsertACLAuthMethods(t *testing.T) {
|
||||||
@@ -227,3 +226,30 @@ func TestStateStore_GetACLAuthMethodByName(t *testing.T) {
|
|||||||
must.NoError(t, err)
|
must.NoError(t, err)
|
||||||
must.Equal(t, mockedACLAuthMethods[1], authMethod)
|
must.Equal(t, mockedACLAuthMethods[1], authMethod)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStateStore_GetDefaultACLAuthMethodByType(t *testing.T) {
|
||||||
|
ci.Parallel(t)
|
||||||
|
testState := testStateStore(t)
|
||||||
|
|
||||||
|
// Generate 2 auth methods, make one of them default
|
||||||
|
am1 := mock.ACLAuthMethod()
|
||||||
|
am1.Default = true
|
||||||
|
am2 := mock.ACLAuthMethod()
|
||||||
|
|
||||||
|
// upsert
|
||||||
|
mockedACLAuthMethods := []*structs.ACLAuthMethod{am1, am2}
|
||||||
|
must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
|
||||||
|
|
||||||
|
// Get the default method for OIDC
|
||||||
|
ws := memdb.NewWatchSet()
|
||||||
|
defaultOIDCMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "OIDC")
|
||||||
|
must.NoError(t, err)
|
||||||
|
|
||||||
|
must.True(t, defaultOIDCMethod.Default)
|
||||||
|
must.Eq(t, am1, defaultOIDCMethod)
|
||||||
|
|
||||||
|
// Get the default method for jwt (should not return anything)
|
||||||
|
defaultJWTMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "JWT")
|
||||||
|
must.NoError(t, err)
|
||||||
|
must.Nil(t, defaultJWTMethod)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user