From 0abadb6804dfbc193ccf72701cbaa145eba7603d Mon Sep 17 00:00:00 2001 From: Piotr Kazmierczak <470696+pkazmierczak@users.noreply.github.com> Date: Thu, 26 Jan 2023 14:17:11 +0100 Subject: [PATCH] acl: make auth method default across all types (#15869) --- command/login.go | 29 ++++++++++++++++--------- command/login_test.go | 23 +++++++++++++++++++- nomad/acl_endpoint.go | 8 +++---- nomad/state/state_store_acl_sso.go | 18 +++++++-------- nomad/state/state_store_acl_sso_test.go | 14 +++++------- website/content/docs/commands/login.mdx | 5 +++-- 6 files changed, 61 insertions(+), 36 deletions(-) diff --git a/command/login.go b/command/login.go index a2e361504..831ae78bf 100644 --- a/command/login.go +++ b/command/login.go @@ -48,7 +48,8 @@ Login Options: has configured a default, this flag is optional. -type - Type of the auth method to login to. Defaults to "OIDC". + Type of the auth method to login to. If the cluster administrator has + configured a default, this flag is optional. -oidc-callback-addr The address to use for the local OIDC callback server. This should be given @@ -88,7 +89,7 @@ func (l *LoginCommand) Run(args []string) int { flags := l.Meta.FlagSet(l.Name(), FlagSetClient) flags.Usage = func() { l.Ui.Output(l.Help()) } flags.StringVar(&l.authMethodName, "method", "", "") - flags.StringVar(&l.authMethodType, "type", "OIDC", "") + flags.StringVar(&l.authMethodType, "type", "", "") flags.StringVar(&l.callbackAddr, "oidc-callback-addr", "localhost:4649", "") flags.BoolVar(&l.json, "json", false, "") flags.StringVar(&l.template, "t", "", "") @@ -112,9 +113,6 @@ func (l *LoginCommand) Run(args []string) int { // means an empty type is only possible is the caller specifies this // explicitly. switch sanitizedMethodType { - case "": - l.Ui.Error("Please supply an authentication type") - return 1 case api.ACLAuthMethodTypeOIDC: default: l.Ui.Error(fmt.Sprintf("Unsupported authentication type %q", sanitizedMethodType)) @@ -127,9 +125,10 @@ func (l *LoginCommand) Run(args []string) int { return 1 } - // If the caller did not supply and auth method name, attempt to lookup the - // default. This ensures a nice UX as clusters are expected to only have - // one method, and this avoids having to type the name during each login. + // If the caller did not supply an auth method name or type, attempt to + // lookup the default. This ensures a nice UX as clusters are expected to + // only have one method, and this avoids having to type the name during + // each login. if l.authMethodName == "" { authMethodList, _, err := client.ACLAuthMethods().List(nil) @@ -141,11 +140,21 @@ func (l *LoginCommand) Run(args []string) int { for _, authMethod := range authMethodList { if authMethod.Default { l.authMethodName = authMethod.Name + if l.authMethodType == "" { + l.authMethodType = authMethod.Type + } + if l.authMethodType != authMethod.Type { + l.Ui.Error(fmt.Sprintf( + "Specified type: %s does not match the type of the default method: %s", + l.authMethodType, authMethod.Type, + )) + return 1 + } } } - if l.authMethodName == "" { - l.Ui.Error("Must specify an auth method name, no default found") + if l.authMethodName == "" || l.authMethodType == "" { + l.Ui.Error("Must specify an auth method name and type, no default found") return 1 } } diff --git a/command/login_test.go b/command/login_test.go index 98b7af6bd..a7348a55f 100644 --- a/command/login_test.go +++ b/command/login_test.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/command/agent" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" "github.com/mitchellh/cli" "github.com/shoenig/test/must" @@ -47,7 +48,27 @@ func TestLoginCommand_Run(t *testing.T) { // Use a valid method type but with incorrect casing so we can ensure this // is handled. must.Eq(t, 1, cmd.Run([]string{"-address=" + agentURL, "-type=oIdC"})) - must.StrContains(t, ui.ErrorWriter.String(), "Must specify an auth method name, no default found") + must.StrContains(t, ui.ErrorWriter.String(), "Must specify an auth method name and type, no default found") + + ui.OutputWriter.Reset() + ui.ErrorWriter.Reset() + + // Store a default auth method + state := srv.Agent.Server().State() + method := &structs.ACLAuthMethod{ + Name: "test-auth-method", + Default: true, + Type: "JWT", + Config: &structs.ACLAuthMethodConfig{ + OIDCDiscoveryURL: "http://example.com", + }, + } + method.SetHash() + must.NoError(t, state.UpsertACLAuthMethods(1000, []*structs.ACLAuthMethod{method})) + + // Specify an incorrect type of default method + must.Eq(t, 1, cmd.Run([]string{"-address=" + agentURL, "-type=OIDC"})) + must.StrContains(t, ui.ErrorWriter.String(), "Specified type: OIDC does not match the type of the default method: JWT") ui.OutputWriter.Reset() ui.ErrorWriter.Reset() diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index 58c6792c4..27c87248a 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -1824,13 +1824,13 @@ func (a *ACL) UpsertAuthMethods( } // Are we trying to upsert a default auth method? Check if there isn't - // a default one for that very type already. + // a default one already. if authMethod.Default { - existingMethodsDefaultmethod, _ := stateSnapshot.GetDefaultACLAuthMethodByType(nil, authMethod.Type) - if existingMethodsDefaultmethod != nil && existingMethodsDefaultmethod.Name != authMethod.Name { + existingMethodsDefaultMethod, _ := stateSnapshot.GetDefaultACLAuthMethod(nil) + if existingMethodsDefaultMethod != nil && existingMethodsDefaultMethod.Name != authMethod.Name { return structs.NewErrRPCCodedf( http.StatusBadRequest, - "default method for type %s already exists: %v", authMethod.Type, existingMethodsDefaultmethod.Name, + "default method already exists: %v", existingMethodsDefaultMethod.Name, ) } } diff --git a/nomad/state/state_store_acl_sso.go b/nomad/state/state_store_acl_sso.go index 5ce3278e5..f27ac63ce 100644 --- a/nomad/state/state_store_acl_sso.go +++ b/nomad/state/state_store_acl_sso.go @@ -65,11 +65,10 @@ func (s *StateStore) upsertACLAuthMethodTxn(index uint64, txn *txn, method *stru // setting. We therefore need to check we are not trying to create a method // with an existing name or a duplicate default for the same type. if method.Default { - existingMethodsDefaultmethod, _ := s.GetDefaultACLAuthMethodByType(nil, method.Type) - if existingMethodsDefaultmethod != nil && existingMethodsDefaultmethod.Name != method.Name { + existingMethodsDefaultMethod, _ := s.GetDefaultACLAuthMethod(nil) + if existingMethodsDefaultMethod != nil && existingMethodsDefaultMethod.Name != method.Name { return false, fmt.Errorf( - "default ACL auth method for type %s already exists: %v", - method.Type, existingMethodsDefaultmethod.Name, + "default ACL auth method already exists: %v", existingMethodsDefaultMethod.Name, ) } } @@ -186,10 +185,9 @@ func (s *StateStore) GetACLAuthMethodByName(ws memdb.WatchSet, authMethod string return nil, nil } -// GetDefaultACLAuthMethodByType returns a default ACL Auth Methods for a given -// auth type. Since we only want 1 default auth method per type, this function -// is used during upserts to facilitate that check. -func (s *StateStore) GetDefaultACLAuthMethodByType(ws memdb.WatchSet, methodType string) (*structs.ACLAuthMethod, error) { +// GetDefaultACLAuthMethod returns a default ACL Auth Method. This function is +// used during upserts to facilitate a check that there's only 1 default Auth Method. +func (s *StateStore) GetDefaultACLAuthMethod(ws memdb.WatchSet) (*structs.ACLAuthMethod, error) { txn := s.db.ReadTxn() // Walk the entire table to get all ACL auth methods. @@ -205,8 +203,8 @@ func (s *StateStore) GetDefaultACLAuthMethodByType(ws memdb.WatchSet, methodType if !ok { return true } - // any non-default method or method of different type than desired gets filtered-out - return !method.Default || method.Type != methodType + // any non-default method gets filtered-out + return !method.Default }) for raw := filter.Next(); raw != nil; raw = filter.Next() { diff --git a/nomad/state/state_store_acl_sso_test.go b/nomad/state/state_store_acl_sso_test.go index 58a762d12..2d41e4661 100644 --- a/nomad/state/state_store_acl_sso_test.go +++ b/nomad/state/state_store_acl_sso_test.go @@ -227,7 +227,7 @@ func TestStateStore_GetACLAuthMethodByName(t *testing.T) { must.Equal(t, mockedACLAuthMethods[1], authMethod) } -func TestStateStore_GetDefaultACLAuthMethodByType(t *testing.T) { +func TestStateStore_GetDefaultACLAuthMethod(t *testing.T) { ci.Parallel(t) testState := testStateStore(t) @@ -240,16 +240,12 @@ func TestStateStore_GetDefaultACLAuthMethodByType(t *testing.T) { mockedACLAuthMethods := []*structs.ACLAuthMethod{am1, am2} must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods)) - // Get the default method for OIDC + // Get the default method ws := memdb.NewWatchSet() - defaultOIDCMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "OIDC") + defaultACLAuthMethod, err := testState.GetDefaultACLAuthMethod(ws) must.NoError(t, err) - must.True(t, defaultOIDCMethod.Default) - must.Eq(t, am1, defaultOIDCMethod) + must.True(t, defaultACLAuthMethod.Default) + must.Eq(t, am1, defaultACLAuthMethod) - // Get the default method for jwt (should not return anything) - defaultJWTMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "JWT") - must.NoError(t, err) - must.Nil(t, defaultJWTMethod) } diff --git a/website/content/docs/commands/login.mdx b/website/content/docs/commands/login.mdx index 6c1238a24..5f83c5a8c 100644 --- a/website/content/docs/commands/login.mdx +++ b/website/content/docs/commands/login.mdx @@ -28,8 +28,9 @@ requested auth method for a newly minted Nomad ACL token. - `-method`: The name of the ACL auth method to log in via. If the cluster administrator has configured a default, this flag is optional. -- `-type`: Type of the auth method to log in via. Defaults to, and currently - only supports, "OIDC". +- `-type`: Type of the auth method to log in via. If the cluster administrator + has configured a default, this flag is optional. Currently only supports + "OIDC". - `-oidc-callback-addr`: The address to use for the local OIDC callback server. This should be given in the form of `:` and defaults to