diff --git a/.semgrep/rpc_endpoint.yml b/.semgrep/rpc_endpoint.yml index 622251c55..9488dc63a 100644 --- a/.semgrep/rpc_endpoint.yml +++ b/.semgrep/rpc_endpoint.yml @@ -78,6 +78,9 @@ rules: - pattern-not: '"ACL.ResolveToken"' - pattern-not: '"ACL.UpsertOneTimeToken"' - pattern-not: '"ACL.ExchangeOneTimeToken"' + - pattern-not: 'structs.ACLListAuthMethodsRPCMethod' + - pattern-not: 'structs.ACLOIDCAuthURLRPCMethod' + - pattern-not: 'structs.ACLOIDCCompleteAuthRPCMethod' - pattern-not: '"CSIPlugin.Get"' - pattern-not: '"CSIPlugin.List"' - pattern-not: '"Status.Leader"' diff --git a/api/acl.go b/api/acl.go index fb80d7d24..f4ebd2ed1 100644 --- a/api/acl.go +++ b/api/acl.go @@ -442,6 +442,38 @@ func (a *ACLBindingRules) Get(bindingRuleID string, q *QueryOptions) (*ACLBindin return &resp, qm, nil } +// ACLOIDC is used to query the ACL OIDC endpoints. +type ACLOIDC struct { + client *Client +} + +// ACLOIDC returns a new handle on the ACL auth-methods API client. +func (c *Client) ACLOIDC() *ACLOIDC { + return &ACLOIDC{client: c} +} + +// GetAuthURL generates the OIDC provider authentication URL. This URL should +// be visited in order to sign in to the provider. +func (a *ACLOIDC) GetAuthURL(req *ACLOIDCAuthURLRequest, q *WriteOptions) (*ACLOIDCAuthURLResponse, *WriteMeta, error) { + var resp ACLOIDCAuthURLResponse + wm, err := a.client.write("/v1/acl/oidc/auth-url", req, &resp, q) + if err != nil { + return nil, nil, err + } + return &resp, wm, nil +} + +// CompleteAuth exchanges the OIDC provider token for a Nomad token with the +// appropriate claims attached. +func (a *ACLOIDC) CompleteAuth(req *ACLOIDCCompleteAuthRequest, q *WriteOptions) (*ACLToken, *WriteMeta, error) { + var resp ACLToken + wm, err := a.client.write("/v1/acl/oidc/complete-auth", req, &resp, q) + if err != nil { + return nil, nil, err + } + return &resp, wm, nil +} + // ACLPolicyListStub is used to for listing ACL policies type ACLPolicyListStub struct { Name string @@ -666,6 +698,7 @@ type ACLAuthMethodConfig struct { OIDCDiscoveryURL string OIDCClientID string OIDCClientSecret string + OIDCScopes []string BoundAudiences []string AllowedRedirectURIs []string DiscoveryCaPem []string @@ -816,3 +849,50 @@ type ACLBindingRuleListStub struct { CreateIndex uint64 ModifyIndex uint64 } + +// ACLOIDCAuthURLRequest is the request to make when starting the OIDC +// authentication login flow. +type ACLOIDCAuthURLRequest struct { + + // AuthMethodName is the OIDC auth-method to use. This is a required + // parameter. + AuthMethodName string + + // RedirectURI is the URL that authorization should redirect to. This is a + // required parameter. + RedirectURI string + + // ClientNonce is a randomly generated string to prevent replay attacks. It + // is up to the client to generate this and Go integrations should use the + // oidc.NewID function within the hashicorp/cap library. + ClientNonce string +} + +// ACLOIDCAuthURLResponse is the response when starting the OIDC authentication +// login flow. +type ACLOIDCAuthURLResponse struct { + + // AuthURL is URL to begin authorization and is where the user logging in + // should go. + AuthURL string +} + +// ACLOIDCCompleteAuthRequest is the request object to begin completing the +// OIDC auth cycle after receiving the callback from the OIDC provider. +type ACLOIDCCompleteAuthRequest struct { + + // AuthMethodName is the name of the auth method being used to login via + // OIDC. This will match AuthUrlArgs.AuthMethodName. This is a required + // parameter. + AuthMethodName string + + // ClientNonce, State, and Code are provided from the parameters given to + // the redirect URL. These are all required parameters. + ClientNonce string + State string + Code string + + // RedirectURI is the URL that authorization should redirect to. This is a + // required parameter. + RedirectURI string +} diff --git a/command/acl_auth_method.go b/command/acl_auth_method.go index 1cfc32282..a5b0a24e6 100644 --- a/command/acl_auth_method.go +++ b/command/acl_auth_method.go @@ -85,6 +85,7 @@ func formatAuthMethodConfig(config *api.ACLAuthMethodConfig) []string { fmt.Sprintf("OIDC Discovery URL|%s", config.OIDCDiscoveryURL), fmt.Sprintf("OIDC Client ID|%s", config.OIDCClientID), fmt.Sprintf("OIDC Client Secret|%s", config.OIDCClientSecret), + fmt.Sprintf("OIDC Scopes|%s", strings.Join(config.OIDCScopes, ",")), fmt.Sprintf("Bound audiences|%s", strings.Join(config.BoundAudiences, ",")), fmt.Sprintf("Allowed redirects URIs|%s", strings.Join(config.AllowedRedirectURIs, ",")), fmt.Sprintf("Discovery CA pem|%s", strings.Join(config.DiscoveryCaPem, ",")), diff --git a/command/acl_auth_method_list_test.go b/command/acl_auth_method_list_test.go index a4cdd2d1c..9c2966152 100644 --- a/command/acl_auth_method_list_test.go +++ b/command/acl_auth_method_list_test.go @@ -39,15 +39,19 @@ func TestACLAuthMethodListCommand(t *testing.T) { ui := cli.NewMockUi() cmd := &ACLAuthMethodListCommand{Meta: Meta{Ui: ui, flagAddress: url}} - // Attempt to list auth methods without a valid management token + // List with an invalid token works fine invalidToken := mock.ACLToken() code := cmd.Run([]string{"-address=" + url, "-token=" + invalidToken.SecretID}) - must.One(t, code) + must.Zero(t, code) // List with a valid management token code = cmd.Run([]string{"-address=" + url, "-token=" + token.SecretID}) must.Zero(t, code) + // List with no token at all + code = cmd.Run([]string{"-address=" + url}) + must.Zero(t, code) + // Check the output out := ui.OutputWriter.String() must.StrContains(t, out, method.Name) diff --git a/command/agent/acl_endpoint.go b/command/agent/acl_endpoint.go index 04c6c2742..1c3fc5868 100644 --- a/command/agent/acl_endpoint.go +++ b/command/agent/acl_endpoint.go @@ -829,3 +829,48 @@ func (s *HTTPServer) aclBindingRuleUpsertRequest( } return nil, nil } + +// ACLOIDCAuthURLRequest starts the OIDC login workflow. +func (s *HTTPServer) ACLOIDCAuthURLRequest(_ http.ResponseWriter, req *http.Request) (interface{}, error) { + + // The endpoint only supports PUT or POST requests. + if req.Method != http.MethodPost && req.Method != http.MethodPut { + return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod) + } + + var args structs.ACLOIDCAuthURLRequest + s.parseWriteRequest(req, &args.WriteRequest) + + if err := decodeBody(req, &args); err != nil { + return nil, CodedError(http.StatusBadRequest, err.Error()) + } + + var out structs.ACLOIDCAuthURLResponse + if err := s.agent.RPC(structs.ACLOIDCAuthURLRPCMethod, &args, &out); err != nil { + return nil, err + } + return out, nil +} + +// ACLOIDCCompleteAuthRequest completes the OIDC login workflow. +func (s *HTTPServer) ACLOIDCCompleteAuthRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) { + + // The endpoint only supports PUT or POST requests. + if req.Method != http.MethodPost && req.Method != http.MethodPut { + return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod) + } + + var args structs.ACLOIDCCompleteAuthRequest + s.parseWriteRequest(req, &args.WriteRequest) + + if err := decodeBody(req, &args); err != nil { + return nil, CodedError(http.StatusBadRequest, err.Error()) + } + + var out structs.ACLOIDCCompleteAuthResponse + if err := s.agent.RPC(structs.ACLOIDCCompleteAuthRPCMethod, &args, &out); err != nil { + return nil, err + } + setIndex(resp, out.Index) + return out.ACLToken, nil +} diff --git a/command/agent/acl_endpoint_test.go b/command/agent/acl_endpoint_test.go index 59ae857d8..848ccdca6 100644 --- a/command/agent/acl_endpoint_test.go +++ b/command/agent/acl_endpoint_test.go @@ -4,9 +4,11 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "testing" "time" + capOIDC "github.com/hashicorp/cap/oidc" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" @@ -1213,10 +1215,10 @@ func TestHTTPServer_ACLAuthMethodSpecificRequest(t *testing.T) { must.NoError(t, srv.server.State().UpsertACLAuthMethods( 20, []*structs.ACLAuthMethod{mockACLAuthMethod})) - url := "/v1/acl/auth-method/" + mockACLAuthMethod.Name + authMethodURL := "/v1/acl/auth-method/" + mockACLAuthMethod.Name // Build the HTTP request. - req, err := http.NewRequest(http.MethodGet, url, nil) + req, err := http.NewRequest(http.MethodGet, authMethodURL, nil) must.NoError(t, err) respW := httptest.NewRecorder() @@ -1238,10 +1240,10 @@ func TestHTTPServer_ACLAuthMethodSpecificRequest(t *testing.T) { must.NoError(t, srv.server.State().UpsertACLAuthMethods( 20, []*structs.ACLAuthMethod{mockACLAuthMethod})) - url := "/v1/acl/auth-method/" + mockACLAuthMethod.Name + authMethodURL := "/v1/acl/auth-method/" + mockACLAuthMethod.Name // Build the HTTP request to read the auth-method. - req, err := http.NewRequest(http.MethodGet, url, nil) + req, err := http.NewRequest(http.MethodGet, authMethodURL, nil) must.NoError(t, err) respW := httptest.NewRecorder() @@ -1258,7 +1260,7 @@ func TestHTTPServer_ACLAuthMethodSpecificRequest(t *testing.T) { mockACLAuthMethod.MaxTokenTTL = 3600 * time.Hour mockACLAuthMethod.SetHash() - req, err = http.NewRequest(http.MethodPost, url, encodeReq(mockACLAuthMethod)) + req, err = http.NewRequest(http.MethodPost, authMethodURL, encodeReq(mockACLAuthMethod)) must.NoError(t, err) respW = httptest.NewRecorder() @@ -1270,7 +1272,7 @@ func TestHTTPServer_ACLAuthMethodSpecificRequest(t *testing.T) { must.NoError(t, err) // Delete the ACL auth-method. - req, err = http.NewRequest(http.MethodDelete, url, nil) + req, err = http.NewRequest(http.MethodDelete, authMethodURL, nil) must.NoError(t, err) respW = httptest.NewRecorder() @@ -1622,3 +1624,221 @@ func TestHTTPServer_ACLBindingRuleSpecificRequest(t *testing.T) { }) } } + +func TestHTTPServer_ACLOIDCAuthURLRequest(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + testFn func(srv *TestAgent) + }{ + { + name: "incorrect method", + testFn: func(testAgent *TestAgent) { + + // Build the HTTP request. + req, err := http.NewRequest(http.MethodConnect, "/v1/acl/oidc/auth-url", nil) + must.NoError(t, err) + respW := httptest.NewRecorder() + + // Send the HTTP request. + obj, err := testAgent.Server.ACLOIDCAuthURLRequest(respW, req) + must.Error(t, err) + must.StrContains(t, err.Error(), "Invalid method") + must.Nil(t, obj) + }, + }, + { + name: "success", + testFn: func(testAgent *TestAgent) { + + // Set up the test OIDC provider. + oidcTestProvider := capOIDC.StartTestProvider(t) + defer oidcTestProvider.Stop() + + // Generate and upsert an ACL auth method for use. Certain values must be + // taken from the cap OIDC provider just like real world use. + mockedAuthMethod := mock.ACLAuthMethod() + mockedAuthMethod.Config.AllowedRedirectURIs = []string{"http://127.0.0.1:4649/oidc/callback"} + mockedAuthMethod.Config.OIDCDiscoveryURL = oidcTestProvider.Addr() + mockedAuthMethod.Config.SigningAlgs = []string{"ES256"} + mockedAuthMethod.Config.DiscoveryCaPem = []string{oidcTestProvider.CACert()} + + must.NoError(t, testAgent.server.State().UpsertACLAuthMethods( + 10, []*structs.ACLAuthMethod{mockedAuthMethod})) + + // Generate the request body. + requestBody := structs.ACLOIDCAuthURLRequest{ + AuthMethodName: mockedAuthMethod.Name, + RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "fpSPuaodKevKfDU3IeXa", + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + // Build the HTTP request. + req, err := http.NewRequest(http.MethodPost, "/v1/acl/oidc/auth-url", encodeReq(&requestBody)) + must.NoError(t, err) + respW := httptest.NewRecorder() + + // Send the HTTP request. + obj, err := testAgent.Server.ACLOIDCAuthURLRequest(respW, req) + must.NoError(t, err) + + // The response URL comes encoded, so decode this and check we have each + // component we expect. + escapedURL, err := url.PathUnescape(obj.(structs.ACLOIDCAuthURLResponse).AuthURL) + must.NoError(t, err) + must.StrContains(t, escapedURL, "/authorize?client_id=mock") + must.StrContains(t, escapedURL, "&nonce=fpSPuaodKevKfDU3IeXa") + must.StrContains(t, escapedURL, "&redirect_uri=http://127.0.0.1:4649/oidc/callback") + must.StrContains(t, escapedURL, "&response_type=code") + must.StrContains(t, escapedURL, "&scope=openid") + must.StrContains(t, escapedURL, "&state=st_") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + httpACLTest(t, nil, tc.testFn) + }) + } +} + +func TestHTTPServer_ACLOIDCCompleteAuthRequest(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + testFn func(srv *TestAgent) + }{ + { + name: "incorrect method", + testFn: func(testAgent *TestAgent) { + + // Build the HTTP request. + req, err := http.NewRequest(http.MethodConnect, "/v1/acl/oidc/complete-auth", nil) + must.NoError(t, err) + respW := httptest.NewRecorder() + + // Send the HTTP request. + obj, err := testAgent.Server.ACLOIDCCompleteAuthRequest(respW, req) + must.Error(t, err) + must.StrContains(t, err.Error(), "Invalid method") + must.Nil(t, obj) + }, + }, + { + name: "success", + testFn: func(testAgent *TestAgent) { + + // Set up the test OIDC provider. + oidcTestProvider := capOIDC.StartTestProvider(t) + defer oidcTestProvider.Stop() + oidcTestProvider.SetAllowedRedirectURIs([]string{"http://127.0.0.1:4649/oidc/callback"}) + + // Generate and upsert an ACL auth method for use. Certain values must be + // taken from the cap OIDC provider just like real world use. + mockedAuthMethod := mock.ACLAuthMethod() + mockedAuthMethod.Config.BoundAudiences = []string{"mock"} + mockedAuthMethod.Config.AllowedRedirectURIs = []string{"http://127.0.0.1:4649/oidc/callback"} + mockedAuthMethod.Config.OIDCDiscoveryURL = oidcTestProvider.Addr() + mockedAuthMethod.Config.SigningAlgs = []string{"ES256"} + mockedAuthMethod.Config.DiscoveryCaPem = []string{oidcTestProvider.CACert()} + mockedAuthMethod.Config.ClaimMappings = map[string]string{} + mockedAuthMethod.Config.ListClaimMappings = map[string]string{ + "http://nomad.internal/roles": "roles", + "http://nomad.internal/policies": "policies", + } + + must.NoError(t, testAgent.server.State().UpsertACLAuthMethods( + 10, []*structs.ACLAuthMethod{mockedAuthMethod})) + + // Set our custom data and some expected values, so we can make the RPC and + // use the test provider. + oidcTestProvider.SetExpectedAuthNonce("fpSPuaodKevKfDU3IeXa") + oidcTestProvider.SetExpectedAuthCode("codeABC") + oidcTestProvider.SetCustomAudience("mock") + oidcTestProvider.SetExpectedState("st_someweirdstateid") + oidcTestProvider.SetCustomClaims(map[string]interface{}{ + "azp": "mock", + "http://nomad.internal/policies": []string{"engineering"}, + "http://nomad.internal/roles": []string{"engineering"}, + }) + + // Generate the request body. + requestBody := structs.ACLOIDCCompleteAuthRequest{ + AuthMethodName: mockedAuthMethod.Name, + ClientNonce: "fpSPuaodKevKfDU3IeXa", + State: "st_someweirdstateid", + Code: "codeABC", + RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], + WriteRequest: structs.WriteRequest{ + Region: "global", + }, + } + + // Build the HTTP request. + req, err := http.NewRequest(http.MethodPost, "/v1/acl/oidc/complete-auth", encodeReq(&requestBody)) + must.NoError(t, err) + respW := httptest.NewRecorder() + + // Send the HTTP request. + _, err = testAgent.Server.ACLOIDCCompleteAuthRequest(respW, req) + must.ErrorContains(t, err, "no role or policy bindings matched") + + // Upsert an ACL policy and role, so that we can reference this within our + // OIDC claims. + mockACLPolicy := mock.ACLPolicy() + must.NoError(t, testAgent.server.State().UpsertACLPolicies( + structs.MsgTypeTestSetup, 20, []*structs.ACLPolicy{mockACLPolicy})) + + mockACLRole := mock.ACLRole() + mockACLRole.Policies = []*structs.ACLRolePolicyLink{{Name: mockACLPolicy.Name}} + must.NoError(t, testAgent.server.State().UpsertACLRoles( + structs.MsgTypeTestSetup, 30, []*structs.ACLRole{mockACLRole}, true)) + + // Generate and upsert two binding rules, so we can test both ACL Policy + // and Role claim mapping. + mockBindingRule1 := mock.ACLBindingRule() + mockBindingRule1.AuthMethod = mockedAuthMethod.Name + mockBindingRule1.BindType = structs.ACLBindingRuleBindTypePolicy + mockBindingRule1.Selector = "engineering in list.policies" + mockBindingRule1.BindName = mockACLPolicy.Name + + mockBindingRule2 := mock.ACLBindingRule() + mockBindingRule2.AuthMethod = mockedAuthMethod.Name + mockBindingRule2.BindName = mockACLRole.Name + + must.NoError(t, testAgent.server.State().UpsertACLBindingRules( + 40, []*structs.ACLBindingRule{mockBindingRule1, mockBindingRule2}, true)) + + // Build the HTTP request. + req, err = http.NewRequest(http.MethodPost, "/v1/acl/oidc/complete-auth", encodeReq(&requestBody)) + must.NoError(t, err) + respW = httptest.NewRecorder() + + // Send the HTTP request. + obj, err := testAgent.Server.ACLOIDCCompleteAuthRequest(respW, req) + must.NoError(t, err) + + aclTokenResp, ok := obj.(*structs.ACLToken) + must.True(t, ok) + must.NotNil(t, aclTokenResp) + must.Len(t, 1, aclTokenResp.Policies) + must.Eq(t, mockACLPolicy.Name, aclTokenResp.Policies[0]) + must.Len(t, 1, aclTokenResp.Roles) + must.Eq(t, mockACLRole.Name, aclTokenResp.Roles[0].Name) + must.Eq(t, mockACLRole.ID, aclTokenResp.Roles[0].ID) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + httpACLTest(t, nil, tc.testFn) + }) + } +} diff --git a/command/agent/http.go b/command/agent/http.go index 47c970532..da3580e77 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -397,6 +397,10 @@ func (s *HTTPServer) registerHandlers(enableDebug bool) { s.mux.HandleFunc("/v1/acl/binding-rule", s.wrap(s.ACLBindingRuleRequest)) s.mux.HandleFunc("/v1/acl/binding-rule/", s.wrap(s.ACLBindingRuleSpecificRequest)) + // Register out ACL OIDC SSO provider handlers. + s.mux.HandleFunc("/v1/acl/oidc/auth-url", s.wrap(s.ACLOIDCAuthURLRequest)) + s.mux.HandleFunc("/v1/acl/oidc/complete-auth", s.wrap(s.ACLOIDCCompleteAuthRequest)) + s.mux.Handle("/v1/client/fs/", wrapCORS(s.wrap(s.FsRequest))) s.mux.HandleFunc("/v1/client/gc", s.wrap(s.ClientGCRequest)) s.mux.Handle("/v1/client/stats", wrapCORS(s.wrap(s.ClientStatsRequest))) diff --git a/command/commands.go b/command/commands.go index 8ffa7d125..1eac75833 100644 --- a/command/commands.go +++ b/command/commands.go @@ -510,6 +510,11 @@ func Commands(metaPtr *Meta, agentUi cli.Ui) map[string]cli.CommandFactory { Meta: meta, }, nil }, + "login": func() (cli.Command, error) { + return &LoginCommand{ + Meta: meta, + }, nil + }, "logs": func() (cli.Command, error) { return &AllocLogsCommand{ Meta: meta, diff --git a/command/login.go b/command/login.go new file mode 100644 index 000000000..15910c763 --- /dev/null +++ b/command/login.go @@ -0,0 +1,284 @@ +package command + +import ( + "context" + "fmt" + "os" + "os/signal" + "strings" + + "github.com/hashicorp/cap/util" + "github.com/hashicorp/nomad/api" + "github.com/hashicorp/nomad/lib/auth/oidc" + "github.com/mitchellh/cli" + "github.com/posener/complete" +) + +// Ensure LoginCommand satisfies the cli.Command interface. +var _ cli.Command = &LoginCommand{} + +// LoginCommand implements cli.Command. +type LoginCommand struct { + Meta + + authMethodType string + authMethodName string + callbackAddr string + + template string + json bool +} + +// Help satisfies the cli.Command Help function. +func (l *LoginCommand) Help() string { + helpText := ` +Usage: nomad login [options] + + The login command will exchange the provided third party credentials with the + requested auth method for a newly minted Nomad ACL token. + +General Options: + + ` + generalOptionsUsage(usageOptsNoNamespace) + ` + +Login Options: + + -method + The name of the ACL auth method to login to. If the cluster administrator + has configured a default, this flag is optional. + + -type + Type of the auth method to login to. Defaults to "OIDC". + + -oidc-callback-addr + The address to use for the local OIDC callback server. This should be given + in the form of : and defaults to "127.0.0.1:4649". + + -json + Output the ACL token in JSON format. + + -t + Format and display the ACL token using a Go template. +` + return strings.TrimSpace(helpText) +} + +// Synopsis satisfies the cli.Command Synopsis function. +func (l *LoginCommand) Synopsis() string { + return "Login to Nomad using an auth method" +} + +func (l *LoginCommand) AutocompleteFlags() complete.Flags { + return mergeAutocompleteFlags(l.Meta.AutocompleteFlags(FlagSetClient), + complete.Flags{ + "-method": complete.PredictAnything, + "-type": complete.PredictSet("OIDC"), + "-oidc-callback-addr": complete.PredictAnything, + "-json": complete.PredictNothing, + "-t": complete.PredictAnything, + }) +} + +// Name returns the name of this command. +func (l *LoginCommand) Name() string { return "login" } + +// Run satisfies the cli.Command Run function. +func (l *LoginCommand) Run(args []string) int { + + flags := l.Meta.FlagSet(l.Name(), FlagSetClient) + flags.Usage = func() { l.Ui.Output(l.Help()) } + flags.StringVar(&l.authMethodName, "method", "", "") + flags.StringVar(&l.authMethodType, "type", "OIDC", "") + flags.StringVar(&l.callbackAddr, "oidc-callback-addr", "127.0.0.1:4649", "") + flags.BoolVar(&l.json, "json", false, "") + flags.StringVar(&l.template, "t", "", "") + if err := flags.Parse(args); err != nil { + return 1 + } + args = flags.Args() + + if len(args) != 0 { + l.Ui.Error("This command takes no arguments") + l.Ui.Error(commandErrorText(l)) + return 1 + } + + // Auth method types are particular with their naming, so ensure we forgive + // any case mistakes here from the user. + sanitizedMethodType := strings.ToUpper(l.authMethodType) + + // Ensure we sanitize the method type so we do not pedantically return an + // error when the caller uses "oidc" rather than "OIDC". The flag default + // means an empty type is only possible is the caller specifies this + // explicitly. + switch sanitizedMethodType { + case "": + l.Ui.Error("Please supply an authentication type") + return 1 + case api.ACLAuthMethodTypeOIDC: + default: + l.Ui.Error(fmt.Sprintf("Unsupported authentication type %q", sanitizedMethodType)) + return 1 + } + + client, err := l.Meta.Client() + if err != nil { + l.Ui.Error(fmt.Sprintf("Error initializing client: %s", err)) + return 1 + } + + // If the caller did not supply and auth method name, attempt to lookup the + // default. This ensures a nice UX as clusters are expected to only have + // one method, and this avoids having to type the name during each login. + if l.authMethodName == "" { + + authMethodList, _, err := client.ACLAuthMethods().List(nil) + if err != nil { + l.Ui.Error(fmt.Sprintf("Error listing ACL auth methods: %s", err)) + return 1 + } + + for _, authMethod := range authMethodList { + if authMethod.Default { + l.authMethodName = authMethod.Name + } + } + + if l.authMethodName == "" { + l.Ui.Error("Must specify an auth method name, no default found") + return 1 + } + } + + // Each login type should implement a function which matches this signature + // for the specific login implementation. This allows the command to have + // reusable and generic handling of errors and outputs. + var authFn func(context.Context, *api.Client) (*api.ACLToken, error) + + switch sanitizedMethodType { + case api.ACLAuthMethodTypeOIDC: + authFn = l.loginOIDC + default: + l.Ui.Error(fmt.Sprintf("Unsupported authentication type %q", sanitizedMethodType)) + return 1 + } + + ctx, cancel := contextWithInterrupt() + defer cancel() + + token, err := authFn(ctx, client) + if err != nil { + l.Ui.Error(fmt.Sprintf("Error performing login: %v", err)) + return 1 + } + + if l.json || l.template != "" { + out, err := Format(l.json, l.template, token) + if err != nil { + l.Ui.Error(err.Error()) + return 1 + } + l.Ui.Output(out) + return 0 + } + + l.Ui.Output(fmt.Sprintf("Successfully logged in via %s and %s\n", sanitizedMethodType, l.authMethodName)) + outputACLToken(l.Ui, token) + return 0 +} + +func (l *LoginCommand) loginOIDC(ctx context.Context, client *api.Client) (*api.ACLToken, error) { + + callbackServer, err := oidc.NewCallbackServer(l.callbackAddr) + if err != nil { + return nil, err + } + defer callbackServer.Close() + + getAuthArgs := api.ACLOIDCAuthURLRequest{ + AuthMethodName: l.authMethodName, + RedirectURI: callbackServer.RedirectURI(), + ClientNonce: callbackServer.Nonce(), + } + + getAuthURLResp, _, err := client.ACLOIDC().GetAuthURL(&getAuthArgs, nil) + if err != nil { + return nil, err + } + + // Open the auth URL in the user browser or ask them to visit it. + // We purposely use fmt here and NOT c.ui because the ui will truncate + // our URL (a known bug). + if err := util.OpenURL(getAuthURLResp.AuthURL); err != nil { + l.Ui.Error(fmt.Sprintf("Error opening OIDC provider URL: %v\n", err)) + l.Ui.Output(fmt.Sprintf(strings.TrimSpace(oidcErrorVisitURLMsg)+"\n\n", getAuthURLResp.AuthURL)) + } + + // Wait. The login process can end to one of the following reasons: + // - the user interrupts the login process via CTRL-C + // - the login process returns an error via the callback server + // - the login process is successful as returned by the callback server + var req *api.ACLOIDCCompleteAuthRequest + select { + case <-ctx.Done(): + _ = callbackServer.Close() + return nil, ctx.Err() + case err := <-callbackServer.ErrorCh(): + return nil, err + case req = <-callbackServer.SuccessCh(): + } + + cbArgs := api.ACLOIDCCompleteAuthRequest{ + AuthMethodName: l.authMethodName, + RedirectURI: callbackServer.RedirectURI(), + ClientNonce: callbackServer.Nonce(), + Code: req.Code, + State: req.State, + } + + token, _, err := client.ACLOIDC().CompleteAuth(&cbArgs, nil) + return token, err +} + +const ( + // oidcErrorVisitURLMsg is a message to show users when opening the OIDC + // provider URL automatically fails. This type of message is otherwise not + // needed, as it just clutters the console without providing value. + oidcErrorVisitURLMsg = ` +Automatic opening of the OIDC provider for login has failed. To complete the +authentication, please visit your provider using the URL below: + +%s +` +) + +// contextWithInterrupt returns a context and cancel function that adheres to +// expected behaviour and also includes cancellation when the user interrupts +// the login process via CTRL-C. +func contextWithInterrupt() (context.Context, func()) { + + // Create the cancellable context that we'll use when we receive an + // interrupt. + ctx, cancel := context.WithCancel(context.Background()) + + // Create the signal channel and cancel the context when we get a signal. + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + + // Start a routine which waits for the signals. + go func() { + select { + case <-ch: + cancel() + case <-ctx.Done(): + return + } + }() + + // Return the context and a closer that cancels the context and also + // stops any signals from coming to our channel. + return ctx, func() { + signal.Stop(ch) + cancel() + } +} diff --git a/command/login_test.go b/command/login_test.go new file mode 100644 index 000000000..98b7af6bd --- /dev/null +++ b/command/login_test.go @@ -0,0 +1,57 @@ +package command + +import ( + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/command/agent" + "github.com/hashicorp/nomad/testutil" + "github.com/mitchellh/cli" + "github.com/shoenig/test/must" +) + +func TestLoginCommand_Run(t *testing.T) { + ci.Parallel(t) + + // Build a test server with ACLs enabled. + srv, _, agentURL := testServer(t, false, func(c *agent.Config) { + c.ACL.Enabled = true + }) + defer srv.Shutdown() + + // Wait for the server to start fully. + testutil.WaitForLeader(t, srv.Agent.RPC) + + ui := cli.NewMockUi() + cmd := &LoginCommand{ + Meta: Meta{ + Ui: ui, + flagAddress: agentURL, + }, + } + + // Test the basic validation on the command. + must.Eq(t, 1, cmd.Run([]string{"-address=" + agentURL, "this-command-does-not-take-args"})) + must.StrContains(t, ui.ErrorWriter.String(), "This command takes no arguments") + + ui.OutputWriter.Reset() + ui.ErrorWriter.Reset() + + // Attempt to call it with an unsupported method type. + must.Eq(t, 1, cmd.Run([]string{"-address=" + agentURL, "-type=SAML"})) + must.StrContains(t, ui.ErrorWriter.String(), `Unsupported authentication type "SAML"`) + + ui.OutputWriter.Reset() + ui.ErrorWriter.Reset() + + // Use a valid method type but with incorrect casing so we can ensure this + // is handled. + must.Eq(t, 1, cmd.Run([]string{"-address=" + agentURL, "-type=oIdC"})) + must.StrContains(t, ui.ErrorWriter.String(), "Must specify an auth method name, no default found") + + ui.OutputWriter.Reset() + ui.ErrorWriter.Reset() + + // TODO(jrasell) find a way to test the full login flow from the CLI + // perspective. +} diff --git a/go.mod b/go.mod index b81b96280..b96b843f3 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/gosuri/uilive v0.0.4 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 + github.com/hashicorp/cap v0.2.0 github.com/hashicorp/consul-template v0.29.6-0.20221026140134-90370e07bf62 github.com/hashicorp/consul/api v1.18.0 github.com/hashicorp/consul/sdk v0.13.0 @@ -71,6 +72,7 @@ require ( github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/hcl v1.0.1-vault-3 github.com/hashicorp/hcl/v2 v2.9.2-0.20220525143345-ab3cae0737bc + github.com/hashicorp/hil v0.0.0-20210521165536-27a72121fd40 github.com/hashicorp/logutils v1.0.0 github.com/hashicorp/memberlist v0.5.0 github.com/hashicorp/net-rpc-msgpackrpc v0.0.0-20151116020338-a14192a58a69 @@ -177,6 +179,7 @@ require ( github.com/containerd/cgroups v1.0.3 // indirect github.com/containerd/console v1.0.3 // indirect github.com/containerd/containerd v1.6.12 // indirect + github.com/coreos/go-oidc/v3 v3.1.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/cyphar/filepath-securejoin v0.2.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -231,7 +234,7 @@ require ( github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect - github.com/mitchellh/pointerstructure v1.2.1 // indirect + github.com/mitchellh/pointerstructure v1.2.1 github.com/morikuni/aec v1.0.0 // indirect github.com/mrunalp/fileutils v0.5.0 // indirect github.com/muesli/reflow v0.3.0 diff --git a/go.sum b/go.sum index fc498612c..122cb1778 100644 --- a/go.sum +++ b/go.sum @@ -348,6 +348,8 @@ github.com/coreos/go-iptables v0.5.0/go.mod h1:/mVI274lEDI2ns62jHCDnCyBF9Iwsmeka github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= +github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw= +github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20161114122254-48702e0da86b/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -654,6 +656,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/cap v0.2.0 h1:Cgr1iDczX17y0PNF5VG+bWTtDiimYL8F18izMPbWNy4= +github.com/hashicorp/cap v0.2.0/go.mod h1:zb3VvIFA0lM2lbmO69NjowV9dJzJnZS89TaM9blXPJA= github.com/hashicorp/consul-template v0.29.6-0.20221026140134-90370e07bf62 h1:72EUkkdM0uFQZVHpx69lM0bBqRhmtqsCV3Up48dfw2w= github.com/hashicorp/consul-template v0.29.6-0.20221026140134-90370e07bf62/go.mod h1:oznME/M/L6XDklrE62H9R1Rp+WYtxrISywtwXpA+bgU= github.com/hashicorp/consul/api v1.18.0 h1:R7PPNzTCeN6VuQNDwwhZWJvzCtGSrNpJqfb22h3yH9g= @@ -690,6 +694,7 @@ github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrj github.com/hashicorp/go-hclog v0.12.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-hclog v0.16.2/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= +github.com/hashicorp/go-hclog v1.0.0/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-hclog v1.3.1 h1:vDwF1DFNZhntP4DAjuTpOw3uEgMUpXh1pB5fW9DqHpo= github.com/hashicorp/go-hclog v1.3.1/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= @@ -768,6 +773,8 @@ github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee h1:8B4HqvMUtYSjsGk github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee/go.mod h1:gwlu9+/P9MmKtYrMsHeFRZPXj2CTPm11TDnMeaRHS7g= github.com/hashicorp/hcl/v2 v2.9.2-0.20220525143345-ab3cae0737bc h1:32lGaCPq5JPYNgFFTjl/cTIar9UWWxCbimCs5G2hMHg= github.com/hashicorp/hcl/v2 v2.9.2-0.20220525143345-ab3cae0737bc/go.mod h1:odKNpEeZv3COD+++SQcPyACuKOlM5eBoQlzRyN5utIQ= +github.com/hashicorp/hil v0.0.0-20210521165536-27a72121fd40 h1:ExwaL+hUy1ys2AWDbsbh/lxQS2EVCYxuj0LoyLTdB3Y= +github.com/hashicorp/hil v0.0.0-20210521165536-27a72121fd40/go.mod h1:n2TSygSNwsLJ76m8qFXTSc7beTb+auJxYdqrnoqwZWE= github.com/hashicorp/logutils v1.0.0 h1:dLEQVugN8vlakKOUE3ihGLTZJRB4j+M2cdTm/ORI65Y= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= github.com/hashicorp/mdns v1.0.1/go.mod h1:4gW7WsVCke5TE7EPeYliwHlRUyBtfCwuFwuMg2DmyNY= @@ -1284,6 +1291,7 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1: github.com/xeipuuv/gojsonschema v0.0.0-20180618132009-1d523034197f/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +github.com/yhat/scrape v0.0.0-20161128144610-24b7890b0945/go.mod h1:4vRFPPNYllgCacoj+0FoKOjTW68rUhEfqPLiEJaK2w8= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -1352,6 +1360,7 @@ golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= @@ -1433,6 +1442,7 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= @@ -1481,6 +1491,7 @@ golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210805134026-6f1e6394065a/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20211005180243-6b3c2da341f1/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b h1:clP8eMhB30EHdc0bd2Twtq6kgU7yl5ub2cQLSdrv1Dg= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1919,6 +1930,7 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/testing/apitests/acl_test.go b/internal/testing/apitests/acl_test.go new file mode 100644 index 000000000..ce5e0d146 --- /dev/null +++ b/internal/testing/apitests/acl_test.go @@ -0,0 +1,181 @@ +package apitests + +import ( + "net/url" + "testing" + "time" + + capOIDC "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/nomad/api" + "github.com/hashicorp/nomad/ci" + "github.com/shoenig/test/must" +) + +func TestACLOIDC_GetAuthURL(t *testing.T) { + ci.Parallel(t) + + testClient, testServer, _ := makeACLClient(t, nil, nil) + defer testServer.Stop() + + // Set up the test OIDC provider. + oidcTestProvider := capOIDC.StartTestProvider(t) + defer oidcTestProvider.Stop() + oidcTestProvider.SetAllowedRedirectURIs([]string{"http://127.0.0.1:4649/oidc/callback"}) + + // Generate and upsert an ACL auth method for use. Certain values must be + // taken from the cap OIDC provider just like real world use. + mockedAuthMethod := api.ACLAuthMethod{ + Name: "api-test-auth-method", + Type: api.ACLAuthMethodTypeOIDC, + TokenLocality: api.ACLAuthMethodTokenLocalityGlobal, + MaxTokenTTL: 10 * time.Hour, + Default: true, + Config: &api.ACLAuthMethodConfig{ + OIDCDiscoveryURL: oidcTestProvider.Addr(), + OIDCClientID: "mock", + OIDCClientSecret: "verysecretsecret", + BoundAudiences: []string{"mock"}, + AllowedRedirectURIs: []string{"http://127.0.0.1:4649/oidc/callback"}, + DiscoveryCaPem: []string{oidcTestProvider.CACert()}, + SigningAlgs: []string{"ES256"}, + ClaimMappings: map[string]string{"foo": "bar"}, + ListClaimMappings: map[string]string{"foo": "bar"}, + }, + } + + createdAuthMethod, writeMeta, err := testClient.ACLAuthMethods().Create(&mockedAuthMethod, nil) + must.NoError(t, err) + must.NotNil(t, createdAuthMethod) + assertWriteMeta(t, writeMeta) + + // Generate and make the request. + authURLRequest := api.ACLOIDCAuthURLRequest{ + AuthMethodName: createdAuthMethod.Name, + RedirectURI: createdAuthMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "fpSPuaodKevKfDU3IeXb", + } + + authURLResp, _, err := testClient.ACLOIDC().GetAuthURL(&authURLRequest, nil) + must.NoError(t, err) + + // The response URL comes encoded, so decode this and check we have each + // component we expect. + escapedURL, err := url.PathUnescape(authURLResp.AuthURL) + must.NoError(t, err) + must.StrContains(t, escapedURL, "/authorize?client_id=mock") + must.StrContains(t, escapedURL, "&nonce=fpSPuaodKevKfDU3IeXb") + must.StrContains(t, escapedURL, "&redirect_uri=http://127.0.0.1:4649/oidc/callback") + must.StrContains(t, escapedURL, "&response_type=code") + must.StrContains(t, escapedURL, "&scope=openid") + must.StrContains(t, escapedURL, "&state=st_") +} + +func TestACLOIDC_CompleteAuth(t *testing.T) { + ci.Parallel(t) + + testClient, testServer, _ := makeACLClient(t, nil, nil) + defer testServer.Stop() + + // Set up the test OIDC provider. + oidcTestProvider := capOIDC.StartTestProvider(t) + defer oidcTestProvider.Stop() + oidcTestProvider.SetAllowedRedirectURIs([]string{"http://127.0.0.1:4649/oidc/callback"}) + + // Generate and upsert an ACL auth method for use. Certain values must be + // taken from the cap OIDC provider just like real world use. + mockedAuthMethod := api.ACLAuthMethod{ + Name: "api-test-auth-method", + Type: api.ACLAuthMethodTypeOIDC, + TokenLocality: api.ACLAuthMethodTokenLocalityGlobal, + MaxTokenTTL: 10 * time.Hour, + Default: true, + Config: &api.ACLAuthMethodConfig{ + OIDCDiscoveryURL: oidcTestProvider.Addr(), + OIDCClientID: "mock", + OIDCClientSecret: "verysecretsecret", + BoundAudiences: []string{"mock"}, + AllowedRedirectURIs: []string{"http://127.0.0.1:4649/oidc/callback"}, + DiscoveryCaPem: []string{oidcTestProvider.CACert()}, + SigningAlgs: []string{"ES256"}, + ClaimMappings: map[string]string{}, + ListClaimMappings: map[string]string{ + "http://nomad.internal/roles": "roles", + "http://nomad.internal/policies": "policies", + }, + }, + } + + createdAuthMethod, writeMeta, err := testClient.ACLAuthMethods().Create(&mockedAuthMethod, nil) + must.NoError(t, err) + must.NotNil(t, createdAuthMethod) + assertWriteMeta(t, writeMeta) + + // Set our custom data and some expected values, so we can make the call + // and use the test provider. + oidcTestProvider.SetExpectedAuthNonce("fpSPuaodKevKfDU3IeXb") + oidcTestProvider.SetExpectedAuthCode("codeABC") + oidcTestProvider.SetCustomAudience("mock") + oidcTestProvider.SetExpectedState("st_someweirdstateid") + oidcTestProvider.SetCustomClaims(map[string]interface{}{ + "azp": "mock", + "http://nomad.internal/policies": []string{"engineering"}, + "http://nomad.internal/roles": []string{"engineering"}, + }) + + // Upsert an ACL policy and role, so that we can reference this within our + // OIDC claims. + mockedACLPolicy := api.ACLPolicy{ + Name: "api-oidc-login-test", + Rules: `namespace "default" { policy = "write"}`, + } + _, err = testClient.ACLPolicies().Upsert(&mockedACLPolicy, nil) + must.NoError(t, err) + + mockedACLRole := api.ACLRole{ + Name: "api-oidc-login-test", + Policies: []*api.ACLRolePolicyLink{{Name: mockedACLPolicy.Name}}, + } + createRoleResp, _, err := testClient.ACLRoles().Create(&mockedACLRole, nil) + must.NoError(t, err) + must.NotNil(t, createRoleResp) + + // Generate and upsert two binding rules, so we can test both ACL Policy + // and Role claim mapping. + mockedBindingRule1 := api.ACLBindingRule{ + AuthMethod: mockedAuthMethod.Name, + Selector: "engineering in list.policies", + BindType: api.ACLBindingRuleBindTypePolicy, + BindName: mockedACLPolicy.Name, + } + createBindingRole1Resp, _, err := testClient.ACLBindingRules().Create(&mockedBindingRule1, nil) + must.NoError(t, err) + must.NotNil(t, createBindingRole1Resp) + + mockedBindingRule2 := api.ACLBindingRule{ + AuthMethod: mockedAuthMethod.Name, + Selector: "engineering in list.roles", + BindType: api.ACLBindingRuleBindTypeRole, + BindName: mockedACLRole.Name, + } + createBindingRole2Resp, _, err := testClient.ACLBindingRules().Create(&mockedBindingRule2, nil) + must.NoError(t, err) + must.NotNil(t, createBindingRole2Resp) + + // Generate and make the request. + authURLRequest := api.ACLOIDCCompleteAuthRequest{ + AuthMethodName: createdAuthMethod.Name, + RedirectURI: createdAuthMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "fpSPuaodKevKfDU3IeXb", + State: "st_someweirdstateid", + Code: "codeABC", + } + + completeAuthResp, _, err := testClient.ACLOIDC().CompleteAuth(&authURLRequest, nil) + must.NoError(t, err) + must.NotNil(t, completeAuthResp) + must.Len(t, 1, completeAuthResp.Policies) + must.Eq(t, mockedACLPolicy.Name, completeAuthResp.Policies[0]) + must.Len(t, 1, completeAuthResp.Roles) + must.Eq(t, mockedACLRole.Name, completeAuthResp.Roles[0].Name) + must.Eq(t, createRoleResp.ID, completeAuthResp.Roles[0].ID) +} diff --git a/lib/auth/oidc/binder.go b/lib/auth/oidc/binder.go new file mode 100644 index 000000000..3e1f74738 --- /dev/null +++ b/lib/auth/oidc/binder.go @@ -0,0 +1,203 @@ +package oidc + +import ( + "fmt" + "strings" + + "github.com/hashicorp/go-bexpr" + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/hil" + "github.com/hashicorp/hil/ast" + + "github.com/hashicorp/nomad/nomad/structs" +) + +// Binder is responsible for collecting the ACL roles and policies to be +// assigned to a token generated as a result of "logging in" via an auth method. +// +// It does so by applying the auth method's configured binding rules. +type Binder struct { + store BinderStateStore +} + +// NewBinder creates a Binder with the given state store. +func NewBinder(store BinderStateStore) *Binder { + return &Binder{store} +} + +// BinderStateStore is the subset of state store methods used by the binder. +type BinderStateStore interface { + GetACLBindingRulesByAuthMethod(ws memdb.WatchSet, authMethod string) (memdb.ResultIterator, error) + GetACLRoleByName(ws memdb.WatchSet, roleName string) (*structs.ACLRole, error) + ACLPolicyByName(ws memdb.WatchSet, name string) (*structs.ACLPolicy, error) +} + +// Bindings contains the ACL roles and policies to be assigned to the created +// token. +type Bindings struct { + Roles []*structs.ACLTokenRoleLink + Policies []string +} + +// None indicates that the resulting bindings would not give the created token +// access to any resources. +func (b *Bindings) None() bool { + if b == nil { + return true + } + + return len(b.Policies) == 0 && len(b.Roles) == 0 +} + +// Bind collects the ACL roles and policies to be assigned to the created token. +func (b *Binder) Bind(authMethod *structs.ACLAuthMethod, identity *Identity) (*Bindings, error) { + var ( + bindings Bindings + err error + ) + + // Load the auth method's binding rules. + rulesIterator, err := b.store.GetACLBindingRulesByAuthMethod(nil, authMethod.Name) + if err != nil { + return nil, err + } + + // Find the rules with selectors that match the identity's fields. + matchingRules := []*structs.ACLBindingRule{} + for { + raw := rulesIterator.Next() + if raw == nil { + break + } + rule := raw.(*structs.ACLBindingRule) + if doesSelectorMatch(rule.Selector, identity.Claims) { + matchingRules = append(matchingRules, rule) + } + } + if len(matchingRules) == 0 { + return &bindings, nil + } + + // Compute role or policy names by interpolating the identity's claim + // mappings into the rule BindName templates. + for _, rule := range matchingRules { + bindName, valid, err := computeBindName(rule.BindType, rule.BindName, identity.ClaimMappings) + switch { + case err != nil: + return nil, fmt.Errorf("cannot compute %q bind name for bind target: %w", rule.BindType, err) + case !valid: + return nil, fmt.Errorf("computed %q bind name for bind target is invalid: %q", rule.BindType, bindName) + } + + switch rule.BindType { + case structs.ACLBindingRuleBindTypeRole: + role, err := b.store.GetACLRoleByName(nil, bindName) + if err != nil { + return nil, err + } + + if role != nil { + bindings.Roles = append(bindings.Roles, &structs.ACLTokenRoleLink{ + ID: role.ID, + }) + } + case structs.ACLBindingRuleBindTypePolicy: + policy, err := b.store.ACLPolicyByName(nil, bindName) + if err != nil { + return nil, err + } + + if policy != nil { + bindings.Policies = append(bindings.Policies, policy.Name) + } + } + } + + return &bindings, nil +} + +// computeBindName processes the HIL for the provided bind type+name using the +// projected variables. +// +// - If the HIL is invalid ("", false, AN_ERROR) is returned. +// - If the computed name is not valid for the type ("INVALID_NAME", false, nil) is returned. +// - If the computed name is valid for the type ("VALID_NAME", true, nil) is returned. +func computeBindName(bindType, bindName string, claimMappings map[string]string) (string, bool, error) { + bindName, err := interpolateHIL(bindName, claimMappings, true) + if err != nil { + return "", false, err + } + + var valid bool + switch bindType { + case structs.ACLBindingRuleBindTypePolicy: + valid = structs.ValidPolicyName.MatchString(bindName) + case structs.ACLBindingRuleBindTypeRole: + valid = structs.ValidACLRoleName.MatchString(bindName) + default: + return "", false, fmt.Errorf("unknown binding rule bind type: %s", bindType) + } + + return bindName, valid, nil +} + +// doesSelectorMatch checks that a single selector matches the provided vars. +func doesSelectorMatch(selector string, selectableVars interface{}) bool { + if selector == "" { + return true // catch-all + } + + eval, err := bexpr.CreateEvaluator(selector) + if err != nil { + return false // fails to match if selector is invalid + } + + result, err := eval.Evaluate(selectableVars) + if err != nil { + return false // fails to match if evaluation fails + } + + return result +} + +// interpolateHIL processes the string as if it were HIL and interpolates only +// the provided string->string map as possible variables. +func interpolateHIL(s string, vars map[string]string, lowercase bool) (string, error) { + if !strings.Contains(s, "${") { + // Skip going to the trouble of parsing something that has no HIL. + return s, nil + } + + tree, err := hil.Parse(s) + if err != nil { + return "", err + } + + vm := make(map[string]ast.Variable) + for k, v := range vars { + if lowercase { + v = strings.ToLower(v) + } + vm[k] = ast.Variable{ + Type: ast.TypeString, + Value: v, + } + } + + config := &hil.EvalConfig{ + GlobalScope: &ast.BasicScope{ + VarMap: vm, + }, + } + + result, err := hil.Eval(tree, config) + if err != nil { + return "", err + } + + if result.Type != hil.TypeString { + return "", fmt.Errorf("generated unexpected hil type: %s", result.Type) + } + + return result.Value.(string), nil +} diff --git a/lib/auth/oidc/binder_test.go b/lib/auth/oidc/binder_test.go new file mode 100644 index 000000000..953959710 --- /dev/null +++ b/lib/auth/oidc/binder_test.go @@ -0,0 +1,181 @@ +package oidc + +import ( + "testing" + + "github.com/shoenig/test/must" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" +) + +func TestBinder_Bind(t *testing.T) { + ci.Parallel(t) + + testStore := state.TestStateStore(t) + testBind := NewBinder(testStore) + + // create an authMethod method and insert into the state store + authMethod := mock.ACLAuthMethod() + must.NoError(t, testStore.UpsertACLAuthMethods(0, []*structs.ACLAuthMethod{authMethod})) + + // create some roles and insert into the state store + targetRole := &structs.ACLRole{ + ID: uuid.Generate(), + Name: "vim-role", + } + otherRole := &structs.ACLRole{ + ID: uuid.Generate(), + Name: "frontend-engineers", + } + must.NoError(t, testStore.UpsertACLRoles( + structs.MsgTypeTestSetup, 0, []*structs.ACLRole{targetRole, otherRole}, true, + )) + + // create binding rules and insert into the state store + bindingRules := []*structs.ACLBindingRule{ + { + ID: uuid.Generate(), + Selector: "role==engineer", + BindType: structs.ACLBindingRuleBindTypeRole, + BindName: "${editor}-role", + AuthMethod: authMethod.Name, + }, + { + ID: uuid.Generate(), + Selector: "role==engineer", + BindType: structs.ACLBindingRuleBindTypeRole, + BindName: "this-role-does-not-exist", + AuthMethod: authMethod.Name, + }, + { + ID: uuid.Generate(), + Selector: "language==js", + BindType: structs.ACLBindingRuleBindTypeRole, + BindName: otherRole.Name, + AuthMethod: authMethod.Name, + }, + } + must.NoError(t, testStore.UpsertACLBindingRules(0, bindingRules, true)) + + tests := []struct { + name string + authMethod *structs.ACLAuthMethod + identity *Identity + want *Bindings + wantErr bool + }{ + { + "empty identity", + authMethod, + &Identity{}, + &Bindings{}, + false, + }, + { + "role", + authMethod, + &Identity{ + Claims: map[string]string{ + "role": "engineer", + "language": "go", + }, + ClaimMappings: map[string]string{ + "editor": "vim", + }, + }, + &Bindings{Roles: []*structs.ACLTokenRoleLink{{ID: targetRole.ID}}}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := testBind.Bind(tt.authMethod, tt.identity) + if tt.wantErr { + must.Error(t, err) + } else { + must.NoError(t, err) + } + must.Eq(t, got, tt.want) + }) + } +} + +func Test_computeBindName(t *testing.T) { + ci.Parallel(t) + tests := []struct { + name string + bindType string + bindName string + claimMappings map[string]string + wantName string + wantTrue bool + wantErr bool + }{ + { + "valid bind name and type", + structs.ACLBindingRuleBindTypeRole, + "cluster-admin", + map[string]string{"cluster-admin": "root"}, + "cluster-admin", + true, + false, + }, + { + "invalid type", + "amazing", + "cluster-admin", + map[string]string{"cluster-admin": "root"}, + "", + false, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := computeBindName(tt.bindType, tt.bindName, tt.claimMappings) + if tt.wantErr { + must.NotNil(t, err) + } + must.Eq(t, got, tt.wantName) + must.Eq(t, got1, tt.wantTrue) + }) + } +} + +func Test_doesSelectorMatch(t *testing.T) { + ci.Parallel(t) + tests := []struct { + name string + selector string + selectableVars interface{} + want bool + }{ + { + "catch-all", + "", + nil, + true, + }, + { + "valid selector but no selectable vars", + "nomad_engineering_team in Groups", + "", + false, + }, + { + "valid selector and successful evaluation", + "nomad_engineering_team in Groups", + map[string][]string{"Groups": {"nomad_sales_team", "nomad_engineering_team"}}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + must.Eq(t, doesSelectorMatch(tt.selector, tt.selectableVars), tt.want) + }) + } +} diff --git a/lib/auth/oidc/claims.go b/lib/auth/oidc/claims.go new file mode 100644 index 000000000..db2753e96 --- /dev/null +++ b/lib/auth/oidc/claims.go @@ -0,0 +1,232 @@ +package oidc + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/mitchellh/pointerstructure" + + "github.com/hashicorp/nomad/nomad/structs" +) + +// SelectorData returns the data for go-bexpr for selector evaluation. +func SelectorData( + am *structs.ACLAuthMethod, idClaims, userClaims map[string]interface{}) (*structs.ACLAuthClaims, error) { + + // Ensure the issuer and subscriber data does not get overwritten. + if len(userClaims) > 0 { + + iss, issOk := idClaims["iss"] + sub, subOk := idClaims["sub"] + + for k, v := range userClaims { + idClaims[k] = v + } + + if issOk { + idClaims["iss"] = iss + } + if subOk { + idClaims["sub"] = sub + } + } + + return extractClaims(am, idClaims) +} + +// extractClaims takes the claim mapping configuration of the OIDC auth method, +// extracts the claims, and returns a map of data that can be used with +// go-bexpr. +func extractClaims( + am *structs.ACLAuthMethod, all map[string]interface{}) (*structs.ACLAuthClaims, error) { + + values, err := extractMappings(all, am.Config.ClaimMappings) + if err != nil { + return nil, err + } + + list, err := extractListMappings(all, am.Config.ListClaimMappings) + if err != nil { + return nil, err + } + + return &structs.ACLAuthClaims{ + Value: values, + List: list, + }, nil +} + +// extractMappings extracts the string value mappings. +func extractMappings( + all map[string]interface{}, mapping map[string]string) (map[string]string, error) { + + result := make(map[string]string) + for source, target := range mapping { + rawValue := getClaim(all, source) + if rawValue == nil { + continue + } + + strValue, ok := stringifyClaimValue(rawValue) + if !ok { + return nil, fmt.Errorf("error converting claim '%s' to string from unknown type %T", + source, rawValue) + } + + result[target] = strValue + } + + return result, nil +} + +// extractListMappings builds a metadata map of string list values from a set +// of claims and claims mappings. The referenced claims must be strings and +// the claims mappings must be of the structure: +// +// { +// "/some/claim/pointer": "metadata_key1", +// "another_claim": "metadata_key2", +// ... +// } +func extractListMappings( + all map[string]interface{}, mappings map[string]string) (map[string][]string, error) { + + result := make(map[string][]string) + for source, target := range mappings { + rawValue := getClaim(all, source) + if rawValue == nil { + continue + } + + rawList, ok := normalizeList(rawValue) + if !ok { + return nil, fmt.Errorf("%q list claim could not be converted to string list", source) + } + + list := make([]string, 0, len(rawList)) + for _, raw := range rawList { + value, ok := stringifyClaimValue(raw) + if !ok { + return nil, fmt.Errorf("value %v in %q list claim could not be parsed as string", + raw, source) + } + + if value == "" { + continue + } + list = append(list, value) + } + + result[target] = list + } + + return result, nil +} + +// getClaim returns a claim value from allClaims given a provided claim string. +// If this string is a valid JSONPointer, it will be interpreted as such to +// locate the claim. Otherwise, the claim string will be used directly. +// +// There is no fixup done to the returned data type here. That happens a layer +// up in the caller. +func getClaim(all map[string]interface{}, claim string) interface{} { + if !strings.HasPrefix(claim, "/") { + return all[claim] + } + + val, err := pointerstructure.Get(all, claim) + if err != nil { + // We silently drop the error since keys that are invalid + // just have no values. + return nil + } + + return val +} + +// stringifyClaimValue will try to convert the provided raw value into a +// faithful string representation of that value per these rules: +// +// - strings => unchanged +// - bool => "true" / "false" +// - json.Number => String() +// - float32/64 => truncated to int64 and then formatted as an ascii string +// - intXX/uintXX => casted to int64 and then formatted as an ascii string +// +// If successful the string value and true are returned. otherwise an empty +// string and false are returned. +func stringifyClaimValue(rawValue interface{}) (string, bool) { + switch v := rawValue.(type) { + case string: + return v, true + case bool: + return strconv.FormatBool(v), true + case json.Number: + return v.String(), true + case float64: + // The claims unmarshalled by go-oidc don't use UseNumber, so + // they'll come in as float64 instead of an integer or json.Number. + return strconv.FormatInt(int64(v), 10), true + + // The numerical type cases following here are only here for the sake + // of numerical type completion. Everything is truncated to an integer + // before being stringified. + case float32: + return strconv.FormatInt(int64(v), 10), true + case int8: + return strconv.FormatInt(int64(v), 10), true + case int16: + return strconv.FormatInt(int64(v), 10), true + case int32: + return strconv.FormatInt(int64(v), 10), true + case int64: + return strconv.FormatInt(v, 10), true + case int: + return strconv.FormatInt(int64(v), 10), true + case uint8: + return strconv.FormatInt(int64(v), 10), true + case uint16: + return strconv.FormatInt(int64(v), 10), true + case uint32: + return strconv.FormatInt(int64(v), 10), true + case uint64: + return strconv.FormatInt(int64(v), 10), true + case uint: + return strconv.FormatInt(int64(v), 10), true + default: + return "", false + } +} + +// normalizeList takes an item or a slice and returns a slice. This is useful +// when providers are expected to return a list (typically of strings) but +// reduce it to a non-slice type when the list count is 1. +// +// There is no fixup done to elements of the returned slice here. That happens +// a layer up in the caller. +func normalizeList(raw interface{}) ([]interface{}, bool) { + switch v := raw.(type) { + case []interface{}: + return v, true + case string, // note: this list should be the same as stringifyClaimValue + bool, + json.Number, + float64, + float32, + int8, + int16, + int32, + int64, + int, + uint8, + uint16, + uint32, + uint64, + uint: + return []interface{}{v}, true + default: + return nil, false + } +} diff --git a/lib/auth/oidc/claims_test.go b/lib/auth/oidc/claims_test.go new file mode 100644 index 000000000..5f9655c00 --- /dev/null +++ b/lib/auth/oidc/claims_test.go @@ -0,0 +1,86 @@ +package oidc + +import ( + "testing" + + "github.com/shoenig/test/must" + + "github.com/hashicorp/nomad/nomad/structs" +) + +func TestSelectorData(t *testing.T) { + cases := []struct { + Name string + Mapping map[string]string + ListMapping map[string]string + Data map[string]interface{} + Expected *structs.ACLAuthClaims + }{ + { + "no mappings", + nil, + nil, + map[string]interface{}{"iss": "https://hashicorp.com"}, + &structs.ACLAuthClaims{ + Value: map[string]string{}, + List: map[string][]string{}, + }, + }, + + { + "key", + map[string]string{"iss": "issuer"}, + nil, + map[string]interface{}{"iss": "https://hashicorp.com"}, + &structs.ACLAuthClaims{ + Value: map[string]string{"issuer": "https://hashicorp.com"}, + List: map[string][]string{}, + }, + }, + + { + "key doesn't exist", + map[string]string{"iss": "issuer"}, + nil, + map[string]interface{}{"nope": "https://hashicorp.com"}, + &structs.ACLAuthClaims{ + Value: map[string]string{}, + List: map[string][]string{}, + }, + }, + + { + "list", + nil, + map[string]string{"groups": "g"}, + map[string]interface{}{ + "groups": []interface{}{ + "A", 42, false, + }, + }, + &structs.ACLAuthClaims{ + Value: map[string]string{}, + List: map[string][]string{ + "g": {"A", "42", "false"}, + }, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.Name, func(t *testing.T) { + + am := &structs.ACLAuthMethod{ + Config: &structs.ACLAuthMethodConfig{ + ClaimMappings: tt.Mapping, + ListClaimMappings: tt.ListMapping, + }, + } + + // Get real selector data + actual, err := SelectorData(am, tt.Data, nil) + must.NoError(t, err) + must.Eq(t, actual, tt.Expected) + }) + } +} diff --git a/lib/auth/oidc/identity.go b/lib/auth/oidc/identity.go new file mode 100644 index 000000000..e39856a88 --- /dev/null +++ b/lib/auth/oidc/identity.go @@ -0,0 +1,36 @@ +package oidc + +import ( + "github.com/hashicorp/nomad/nomad/structs" +) + +type Identity struct { + // Claims is the format of this Identity suitable for selection + // with a binding rule. + Claims interface{} + + // ClaimMappings is the format of this Identity suitable for interpolation in a + // bind name within a binding rule. + ClaimMappings map[string]string +} + +// NewIdentity builds a new Identity that can be used to generate bindings via +// Bind for ACL token creation. +func NewIdentity( + authMethodConfig *structs.ACLAuthMethodConfig, authClaims *structs.ACLAuthClaims) *Identity { + + claimMappings := make(map[string]string) + + // Populate claimMappings vars with empty values so HIL works. + for _, k := range authMethodConfig.ClaimMappings { + claimMappings["value."+k] = "" + } + for k, val := range authClaims.Value { + claimMappings["value."+k] = val + } + + return &Identity{ + Claims: authClaims, + ClaimMappings: claimMappings, + } +} diff --git a/lib/auth/oidc/identity_test.go b/lib/auth/oidc/identity_test.go new file mode 100644 index 000000000..c38a52fbb --- /dev/null +++ b/lib/auth/oidc/identity_test.go @@ -0,0 +1,64 @@ +package oidc + +import ( + "github.com/shoenig/test/must" + "testing" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/nomad/structs" +) + +func Test_NewIdentity(t *testing.T) { + ci.Parallel(t) + + testCases := []struct { + name string + inputAuthMethodConfig *structs.ACLAuthMethodConfig + inputAuthClaims *structs.ACLAuthClaims + expectedOutputIdentity *Identity + }{ + { + name: "identity with claims", + inputAuthMethodConfig: &structs.ACLAuthMethodConfig{ + ClaimMappings: map[string]string{"http://nomad.internal/username": "username"}, + ListClaimMappings: map[string]string{"http://nomad.internal/roles": "roles"}, + }, + inputAuthClaims: &structs.ACLAuthClaims{ + Value: map[string]string{"username": "jrasell"}, + List: map[string][]string{"roles": {"engineering"}}, + }, + expectedOutputIdentity: &Identity{ + Claims: &structs.ACLAuthClaims{ + Value: map[string]string{"username": "jrasell"}, + List: map[string][]string{"roles": {"engineering"}}, + }, + ClaimMappings: map[string]string{"value.username": "jrasell"}, + }, + }, + { + name: "identity without claims", + inputAuthMethodConfig: &structs.ACLAuthMethodConfig{ + ClaimMappings: map[string]string{"http://nomad.internal/username": "username"}, + ListClaimMappings: map[string]string{"http://nomad.internal/roles": "roles"}, + }, + inputAuthClaims: &structs.ACLAuthClaims{ + Value: map[string]string{"username": ""}, + List: map[string][]string{"roles": {""}}, + }, + expectedOutputIdentity: &Identity{ + Claims: &structs.ACLAuthClaims{ + Value: map[string]string{"username": ""}, + List: map[string][]string{"roles": {""}}, + }, + ClaimMappings: map[string]string{"value.username": ""}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actualOutput := NewIdentity(tc.inputAuthMethodConfig, tc.inputAuthClaims) + must.Eq(t, tc.expectedOutputIdentity, actualOutput) + }) + } +} diff --git a/lib/auth/oidc/provider.go b/lib/auth/oidc/provider.go new file mode 100644 index 000000000..7a086af1a --- /dev/null +++ b/lib/auth/oidc/provider.go @@ -0,0 +1,185 @@ +package oidc + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/nomad/nomad/structs" +) + +// providerConfig returns the OIDC provider configuration for an OIDC +// auth-method. +func providerConfig(authMethod *structs.ACLAuthMethod) (*oidc.Config, error) { + var algs []oidc.Alg + if len(authMethod.Config.SigningAlgs) > 0 { + for _, alg := range authMethod.Config.SigningAlgs { + algs = append(algs, oidc.Alg(alg)) + } + } else { + algs = []oidc.Alg{oidc.RS256} + } + + return oidc.NewConfig( + authMethod.Config.OIDCDiscoveryURL, + authMethod.Config.OIDCClientID, + oidc.ClientSecret(authMethod.Config.OIDCClientSecret), + algs, + authMethod.Config.AllowedRedirectURIs, + oidc.WithAudiences(authMethod.Config.BoundAudiences...), + oidc.WithProviderCA(strings.Join(authMethod.Config.DiscoveryCaPem, "\n")), + ) +} + +// ProviderCache is a cache for OIDC providers. OIDC providers are something +// you don't want to recreate per-request since they make HTTP requests +// when they're constructed. +// +// The ProviderCache purges a provider under two scenarios: (1) the +// provider config is updated, and it is different and (2) after a set +// amount of time (see cacheExpiry for value) in case the remote provider +// configuration changed. +type ProviderCache struct { + providers map[string]*oidc.Provider + mu sync.RWMutex + + // cancel is used to trigger cancellation of any routines when the cache + // has been informed its parent process is exiting. + cancel context.CancelFunc +} + +// NewProviderCache should be used to initialize a provider cache. This +// will start up background resources to manage the cache. +func NewProviderCache() *ProviderCache { + + // Create a context, so a server that is shutting down can correctly + // shut down the cache loop and OIDC provider background processes. + ctx, cancel := context.WithCancel(context.Background()) + + result := &ProviderCache{ + providers: map[string]*oidc.Provider{}, + cancel: cancel, + } + + // Start the cleanup timer + go result.runCleanupLoop(ctx) + + return result +} + +// Get returns the OIDC provider for the given auth method configuration. +// This will initialize the provider if it isn't already in the cache or +// if the configuration changed. +func (c *ProviderCache) Get(authMethod *structs.ACLAuthMethod) (*oidc.Provider, error) { + + // No matter what we'll use the config of the arg method since we'll + // use it to compare to existing (if exists) or initialize a new provider. + oidcCfg, err := providerConfig(authMethod) + if err != nil { + return nil, err + } + + // Get any current provider for the named auth-method. + var ( + current *oidc.Provider + ok bool + ) + + c.mu.RLock() + current, ok = c.providers[authMethod.Name] + c.mu.RUnlock() + + // If we have a current value, we want to compare hashes to detect changes. + if ok { + currentHash, err := current.ConfigHash() + if err != nil { + return nil, err + } + + newHash, err := oidcCfg.Hash() + if err != nil { + return nil, err + } + + // If the hashes match, this is can be classed as a cache hit. + if currentHash == newHash { + return current, nil + } + } + + // If we made it here, the provider isn't in the cache OR the config + // changed. We therefore, need to initialize a new provider. + newProvider, err := oidc.NewProvider(oidcCfg) + if err != nil { + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + + // If we have an old provider, clean up resources. + if current != nil { + current.Done() + } + + c.providers[authMethod.Name] = newProvider + + return newProvider, nil +} + +// Delete force deletes a single auth method from the cache by name. +func (c *ProviderCache) Delete(name string) { + c.mu.Lock() + defer c.mu.Unlock() + + p, ok := c.providers[name] + if ok { + p.Done() + delete(c.providers, name) + } +} + +// Shutdown stops any long-lived cache process and informs each OIDC provider +// that they are done. This should be called whenever the Nomad server is +// shutting down. +func (c *ProviderCache) Shutdown() { + c.cancel() + c.clear() +} + +// runCleanupLoop runs an infinite loop that clears the cache every cacheExpiry +// duration. This ensures that we force refresh our provider info periodically +// in case anything changes. +func (c *ProviderCache) runCleanupLoop(ctx context.Context) { + + ticker := time.NewTicker(cacheExpiry) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + + // We could be more clever and do a per-entry expiry but Nomad won't + // have more than one ot two auth methods configured, therefore it's + // not worth the added complexity. + case <-ticker.C: + c.clear() + } + } +} + +// clear is called to delete all the providers in the cache. +func (c *ProviderCache) clear() { + c.mu.Lock() + defer c.mu.Unlock() + for _, p := range c.providers { + p.Done() + } + c.providers = map[string]*oidc.Provider{} +} + +// cacheExpiry is the duration after which the provider cache is reset. +const cacheExpiry = 6 * time.Hour diff --git a/lib/auth/oidc/provider_test.go b/lib/auth/oidc/provider_test.go new file mode 100644 index 000000000..2d8c90820 --- /dev/null +++ b/lib/auth/oidc/provider_test.go @@ -0,0 +1,66 @@ +package oidc + +import ( + "testing" + "time" + + "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" +) + +func TestProviderCache(t *testing.T) { + + // Instantiate a new cache. + testCache := NewProviderCache() + defer testCache.Shutdown() + + // Create our OIDC test provider. + oidcTestProvider := oidc.StartTestProvider(t) + oidcTestProvider.SetClientCreds("bob", "ssshhhh") + _, _, tpAlg, _ := oidcTestProvider.SigningKeys() + + // Create a mocked auth-method; avoiding the mock as the hashicorp/cap lib + // performs validation on certain fields. + authMethod := structs.ACLAuthMethod{ + Name: "test-oidc-auth-method", + Type: "OIDC", + TokenLocality: "global", + MaxTokenTTL: 100 * time.Hour, + Default: true, + Config: &structs.ACLAuthMethodConfig{ + OIDCDiscoveryURL: oidcTestProvider.Addr(), + OIDCClientID: "alice", + OIDCClientSecret: "ssshhhh", + AllowedRedirectURIs: []string{"http://example.com"}, + DiscoveryCaPem: []string{oidcTestProvider.CACert()}, + SigningAlgs: []string{string(tpAlg)}, + }, + } + authMethod.SetHash() + + // Perform a lookup against the cache. This should generate a new provider + // for our auth-method. + oidcProvider1, err := testCache.Get(&authMethod) + must.NoError(t, err) + must.NotNil(t, oidcProvider1) + + // Perform another lookup, checking that the returned pointer value is the + // same. + oidcProvider2, err := testCache.Get(&authMethod) + must.NoError(t, err) + must.EqOp(t, oidcProvider1, oidcProvider2) + + // Update an aspect on the auth-method config and then perform a lookup. + // This should return a non-cached provider. + authMethod.Config.AllowedRedirectURIs = []string{"http://example.com/foo/bar/baz/haz"} + oidcProvider3, err := testCache.Get(&authMethod) + must.NoError(t, err) + must.NotEqOp(t, oidcProvider2, oidcProvider3) + + // Ensure the cache only contains a single entry to show we successfully + // replaced the stale entry. + testCache.mu.RLock() + must.MapLen(t, 1, testCache.providers) + testCache.mu.RUnlock() +} diff --git a/lib/auth/oidc/server.go b/lib/auth/oidc/server.go new file mode 100644 index 000000000..6d2ed8843 --- /dev/null +++ b/lib/auth/oidc/server.go @@ -0,0 +1,244 @@ +package oidc + +import ( + "fmt" + "net" + "net/http" + + "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/nomad/api" +) + +// CallbackServer is started with NewCallbackServer and creates an HTTP +// server for handling loopback OIDC auth redirects. +type CallbackServer struct { + ln net.Listener + url string + clientNonce string + errCh chan error + successCh chan *api.ACLOIDCCompleteAuthRequest +} + +// NewCallbackServer creates and starts a new local HTTP server for +// OIDC authentication to redirect to. This is used to capture the +// necessary information to complete the authentication. +func NewCallbackServer(addr string) (*CallbackServer, error) { + // Generate our nonce + nonce, err := oidc.NewID() + if err != nil { + return nil, err + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + // Initialize our callback server + srv := &CallbackServer{ + url: fmt.Sprintf("http://%s/oidc/callback", ln.Addr().String()), + ln: ln, + clientNonce: nonce, + errCh: make(chan error, 5), + successCh: make(chan *api.ACLOIDCCompleteAuthRequest, 5), + } + + // Register our HTTP route and start the server + mux := http.NewServeMux() + mux.Handle("/oidc/callback", srv) + go func() { + httpServer := &http.Server{Handler: mux} + if err := httpServer.Serve(ln); err != nil { + srv.errCh <- err + } + }() + + return srv, nil +} + +// Close cleans up and shuts down the server. On close, errors may be +// sent to ErrorCh and should be ignored. +func (s *CallbackServer) Close() error { return s.ln.Close() } + +// RedirectURI is the redirect URI that should be provided for the auth. +func (s *CallbackServer) RedirectURI() string { return s.url } + +// Nonce returns a generated nonce that can be used for the request. +func (s *CallbackServer) Nonce() string { return s.clientNonce } + +// ErrorCh returns a channel where any errors are sent. Errors may be +// sent after Close and should be disregarded. +func (s *CallbackServer) ErrorCh() <-chan error { return s.errCh } + +// SuccessCh returns a channel that gets sent a partially completed +// request to complete the OIDC auth with the Nomad server. +func (s *CallbackServer) SuccessCh() <-chan *api.ACLOIDCCompleteAuthRequest { return s.successCh } + +// ServeHTTP implements http.Handler and handles the callback request. This +// isn't usually used directly; use the server address instead. +func (s *CallbackServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + q := req.URL.Query() + + // Build our result + result := &api.ACLOIDCCompleteAuthRequest{ + State: q.Get("state"), + ClientNonce: s.clientNonce, + Code: q.Get("code"), + } + + // Send our result. We don't block here because the channel should be + // buffered, and otherwise we're done. + select { + case s.successCh <- result: + default: + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(serverSuccessHTMLResponse)) +} + +// serverSuccessHTMLResponse is the HTML response the OIDC callback server uses +// when the user has successfully logged in via the OIDC provider. +const serverSuccessHTMLResponse = ` + + + + + + OIDC Authentication Succeeded + + + +
+
+ + + + + + + + + + +
+ +
+
+ Signed in via your OIDC provider +
+

+ You can now close this window and start using Nomad. +

+
+
+
+

Not sure how to get started?

+

+ Check out beginner and advanced guides on HashiCorp Nomad at the HashiCorp Learn site or read more in the official documentation. +

+ + + + + + + Get started with Nomad + + + + + + + + View the official Nomad documentation + +
+
+ + +` diff --git a/lib/auth/oidc/server_test.go b/lib/auth/oidc/server_test.go new file mode 100644 index 000000000..dcbf61d34 --- /dev/null +++ b/lib/auth/oidc/server_test.go @@ -0,0 +1,20 @@ +package oidc + +import ( + "testing" + + "github.com/shoenig/test/must" +) + +func TestCallbackServer(t *testing.T) { + + testCallbackServer, err := NewCallbackServer("127.0.0.1:4649") + must.NoError(t, err) + must.NotNil(t, testCallbackServer) + + defer func() { + must.NoError(t, testCallbackServer.Close()) + }() + must.StrNotEqFold(t, "", testCallbackServer.Nonce()) + must.StrNotEqFold(t, "", testCallbackServer.RedirectURI()) +} diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index b28e31b69..9307b4b5b 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -1,6 +1,8 @@ package nomad import ( + "context" + "errors" "fmt" "io/ioutil" "net/http" @@ -10,6 +12,7 @@ import ( "time" "github.com/armon/go-metrics" + capOIDC "github.com/hashicorp/cap/oidc" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" "github.com/hashicorp/go-set" @@ -17,6 +20,7 @@ import ( policy "github.com/hashicorp/nomad/acl" "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/lib/auth/oidc" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/state/paginator" "github.com/hashicorp/nomad/nomad/structs" @@ -31,6 +35,15 @@ const ( // aclBootstrapReset is the file name to create in the data dir. It's only contents // should be the reset index aclBootstrapReset = "acl-bootstrap-reset" + + // aclOIDCAuthURLRequestExpiryTime is the deadline used when generating an + // OIDC provider authentication URL. This is used for HTTP requests to + // external APIs. + aclOIDCAuthURLRequestExpiryTime = 60 * time.Second + + // aclOIDCCallbackRequestExpiryTime is the deadline used when obtaining an + // OIDC provider token. This is used for HTTP requests to external APIs. + aclOIDCCallbackRequestExpiryTime = 60 * time.Second ) // ACL endpoint is used for manipulating ACL tokens and policies @@ -38,10 +51,20 @@ type ACL struct { srv *Server ctx *RPCContext logger hclog.Logger + + // oidcProviderCache is a cache of OIDC providers as defined by the + // hashicorp/cap library. When performing an OIDC login flow, this cache + // should be used to obtain a provider from an auth-method. + oidcProviderCache *oidc.ProviderCache } func NewACLEndpoint(srv *Server, ctx *RPCContext) *ACL { - return &ACL{srv: srv, ctx: ctx, logger: srv.logger.Named("acl")} + return &ACL{ + srv: srv, + ctx: ctx, + logger: srv.logger.Named("acl"), + oidcProviderCache: srv.oidcProviderCache, + } } // UpsertPolicies is used to create or update a set of policies @@ -1843,14 +1866,6 @@ func (a *ACL) ListAuthMethods( } defer metrics.MeasureSince([]string{"nomad", "acl", "list_auth_methods"}, time.Now()) - // Resolve the token and ensure it has some form of permissions. - acl, err := a.srv.ResolveToken(args.AuthToken) - if err != nil { - return err - } else if acl == nil { - return structs.ErrPermissionDenied - } - // Set up and return the blocking query. return a.srv.blockingRPC(&blockingOptions{ queryOpts: &args.QueryOptions, @@ -2346,3 +2361,266 @@ func (a *ACL) GetBindingRule( }, }) } + +// OIDCAuthURL starts the OIDC login workflow. The response URL should be used +// by the caller to authenticate the user. Once this has been completed, +// OIDCCompleteAuth can be used for the remainder of the workflow. +func (a *ACL) OIDCAuthURL(args *structs.ACLOIDCAuthURLRequest, reply *structs.ACLOIDCAuthURLResponse) error { + + // The OIDC flow can only be used when the Nomad cluster has ACL enabled. + if !a.srv.config.ACLEnabled { + return aclDisabled + } + + // Perform the initial forwarding within the region. This ensures we + // respect stale queries. + if done, err := a.srv.forward(structs.ACLOIDCAuthURLRPCMethod, args, args, reply); done { + return err + } + + // There is not a perfect place to run this defer since we potentially + // forward twice. It is likely there will be two distinct patterns to this + // timing in clusters that utilise a mixture of local and global with + // methods. + defer metrics.MeasureSince([]string{"nomad", "acl", "oidc_auth_url"}, time.Now()) + + // Validate the request arguments to ensure it contains all the data it + // needs. Whether the data provided is correct will be handled by the OIDC + // provider. + if err := args.Validate(); err != nil { + return structs.NewErrRPCCodedf(http.StatusBadRequest, "invalid OIDC auth-url request: %v", err) + } + + // Grab a snapshot of the state, so we can query it safely. + stateSnapshot, err := a.srv.fsm.State().Snapshot() + if err != nil { + return err + } + + // Lookup the auth method from state, so we have the entire object + // available to us. It's important to check for nil on the auth method + // object, as it is possible the request was made with an incorrectly named + // auth method. + authMethod, err := stateSnapshot.GetACLAuthMethodByName(nil, args.AuthMethodName) + if err != nil { + return err + } + if authMethod == nil { + return structs.NewErrRPCCodedf(http.StatusBadRequest, "auth-method %q not found", args.AuthMethodName) + } + + // If the authentication method generates global ACL tokens, we need to + // forward the request onto the authoritative regional leader. + if authMethod.TokenLocalityIsGlobal() { + args.Region = a.srv.config.AuthoritativeRegion + + if done, err := a.srv.forward(structs.ACLOIDCAuthURLRPCMethod, args, args, reply); done { + return err + } + } + + // Generate our OIDC request. + oidcReqOpts := []capOIDC.Option{ + capOIDC.WithNonce(args.ClientNonce), + } + + if len(authMethod.Config.OIDCScopes) > 0 { + oidcReqOpts = append(oidcReqOpts, capOIDC.WithScopes(authMethod.Config.OIDCScopes...)) + } + + oidcReq, err := capOIDC.NewRequest( + aclOIDCAuthURLRequestExpiryTime, + args.RedirectURI, + oidcReqOpts..., + ) + if err != nil { + return fmt.Errorf("failed to generate OIDC request: %v", err) + } + + // Use the cache to provide us with an OIDC provider for the auth method + // that was resolved from state. + oidcProvider, err := a.oidcProviderCache.Get(authMethod) + if err != nil { + return fmt.Errorf("failed to generate OIDC provider: %v", err) + } + + // Generate a context. This argument is required by the OIDC provider lib, + // but is not used in any way. This therefore acts for future proofing, if + // the provider lib uses the context. + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(aclOIDCAuthURLRequestExpiryTime)) + defer cancel() + + // Generate the URL, handling any error along with the URL. + authURL, err := oidcProvider.AuthURL(ctx, oidcReq) + if err != nil { + return fmt.Errorf("failed to generate auth URL: %v", err) + } + + reply.AuthURL = authURL + return nil +} + +// OIDCCompleteAuth complete the OIDC login workflow. It will exchange the OIDC +// provider token for a Nomad ACL token, using the configured ACL role and +// policy claims to provide authorization. +func (a *ACL) OIDCCompleteAuth( + args *structs.ACLOIDCCompleteAuthRequest, reply *structs.ACLOIDCCompleteAuthResponse) error { + + // The OIDC flow can only be used when the Nomad cluster has ACL enabled. + if !a.srv.config.ACLEnabled { + return aclDisabled + } + + // Perform the initial forwarding within the region. This ensures we + // respect stale queries. + if done, err := a.srv.forward(structs.ACLOIDCCompleteAuthRPCMethod, args, args, reply); done { + return err + } + + // There is not a perfect place to run this defer since we potentially + // forward twice. It is likely there will be two distinct patterns to this + // timing in clusters that utilise a mixture of local and global with + // methods. + defer metrics.MeasureSince([]string{"nomad", "acl", "oidc_complete_auth"}, time.Now()) + + // Validate the request arguments to ensure it contains all the data it + // needs. Whether the data provided is correct will be handled by the OIDC + // provider. + if err := args.Validate(); err != nil { + return structs.NewErrRPCCodedf(http.StatusBadRequest, "invalid OIDC complete-auth request: %v", err) + } + + // Grab a snapshot of the state, so we can query it safely. + stateSnapshot, err := a.srv.fsm.State().Snapshot() + if err != nil { + return err + } + + // Lookup the auth method from state, so we have the entire object + // available to us. It's important to check for nil on the auth method + // object, as it is possible the request was made with an incorrectly named + // auth method. + authMethod, err := stateSnapshot.GetACLAuthMethodByName(nil, args.AuthMethodName) + if err != nil { + return err + } + if authMethod == nil { + return structs.NewErrRPCCodedf(http.StatusBadRequest, "auth-method %q not found", args.AuthMethodName) + } + + // If the authentication method generates global ACL tokens, we need to + // forward the request onto the authoritative regional leader. + if authMethod.TokenLocalityIsGlobal() { + args.Region = a.srv.config.AuthoritativeRegion + + if done, err := a.srv.forward(structs.ACLOIDCCompleteAuthRPCMethod, args, args, reply); done { + return err + } + } + + // Use the cache to provide us with an OIDC provider for the auth method + // that was resolved from state. + oidcProvider, err := a.oidcProviderCache.Get(authMethod) + if err != nil { + return fmt.Errorf("failed to generate OIDC provider: %v", err) + } + + // Build our OIDC request options and request object. + oidcReqOpts := []capOIDC.Option{ + capOIDC.WithNonce(args.ClientNonce), + capOIDC.WithState(args.State), + } + + if len(authMethod.Config.OIDCScopes) > 0 { + oidcReqOpts = append(oidcReqOpts, capOIDC.WithScopes(authMethod.Config.OIDCScopes...)) + } + if len(authMethod.Config.BoundAudiences) > 0 { + oidcReqOpts = append(oidcReqOpts, capOIDC.WithAudiences(authMethod.Config.BoundAudiences...)) + } + + oidcReq, err := capOIDC.NewRequest(aclOIDCCallbackRequestExpiryTime, args.RedirectURI, oidcReqOpts...) + if err != nil { + return fmt.Errorf("failed to generate OIDC request: %v", err) + } + + // Generate a context with a deadline. This is passed to the OIDC provider + // and used when making remote HTTP requests. + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(aclOIDCCallbackRequestExpiryTime)) + defer cancel() + + // Exchange the state and code for an OIDC provider token. + oidcToken, err := oidcProvider.Exchange(ctx, oidcReq, args.State, args.Code) + if err != nil { + return fmt.Errorf("failed to exchange token with provider: %v", err) + } + if !oidcToken.Valid() { + return errors.New("exchanged token is not valid; potentially expired or empty") + } + + var idTokenClaims map[string]interface{} + if err := oidcToken.IDToken().Claims(&idTokenClaims); err != nil { + return fmt.Errorf("failed to retrieve the ID token claims: %v", err) + } + + var userClaims map[string]interface{} + if userTokenSource := oidcToken.StaticTokenSource(); userTokenSource != nil { + if err := oidcProvider.UserInfo(ctx, userTokenSource, idTokenClaims["sub"].(string), &userClaims); err != nil { + return fmt.Errorf("failed to retrieve the user info claims: %v", err) + } + } + + // Generate the data used by the go-bexpr selector that is an internal + // representation of the claims that can be understood by Nomad. + oidcInternalClaims, err := oidc.SelectorData(authMethod, idTokenClaims, userClaims) + if err != nil { + return err + } + + // Create a new binder object based on the current state snapshot to + // provide consistency within the RPC handler. + oidcBinder := oidc.NewBinder(stateSnapshot) + + // Generate the role and policy bindings that will be assigned to the ACL + // token. Ensure we have at least 1 role or policy, otherwise the RPC will + // fail anyway. + tokenBindings, err := oidcBinder.Bind(authMethod, oidc.NewIdentity(authMethod.Config, oidcInternalClaims)) + if err != nil { + return err + } + if tokenBindings.None() { + return structs.NewErrRPCCoded(http.StatusBadRequest, "no role or policy bindings matched") + } + + // Build our token RPC request. The RPC handler includes a lot of specific + // logic, so we do not want to call Raft directly or copy that here. In the + // future we should try and extract out the logic into an interface, or at + // least a separate function. + tokenUpsertRequest := structs.ACLTokenUpsertRequest{ + Tokens: []*structs.ACLToken{ + { + Name: "OIDC-" + authMethod.Name, + Type: structs.ACLClientToken, + Policies: tokenBindings.Policies, + Roles: tokenBindings.Roles, + Global: authMethod.TokenLocalityIsGlobal(), + ExpirationTTL: authMethod.MaxTokenTTL, + }, + }, + WriteRequest: structs.WriteRequest{ + Region: a.srv.Region(), + AuthToken: a.srv.getLeaderAcl(), + }, + } + + var tokenUpsertReply structs.ACLTokenUpsertResponse + + if err := a.srv.RPC(structs.ACLUpsertTokensRPCMethod, &tokenUpsertRequest, &tokenUpsertReply); err != nil { + return err + } + + // The way the UpsertTokens RPC currently works, if we get no error, then + // we will have exactly the same number of tokens returned as we sent. It + // is therefore safe to assume we have 1 token. + reply.ACLToken = tokenUpsertReply.Tokens[0] + return nil +} diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index 1c961204c..4121135a7 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -3,11 +3,13 @@ package nomad import ( "fmt" "io/ioutil" + "net/url" "path/filepath" "strings" "testing" "time" + capOIDC "github.com/hashicorp/cap/oidc" "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" "github.com/hashicorp/nomad/ci" @@ -3478,3 +3480,225 @@ func TestACL_GetBindingRule(t *testing.T) { must.Eq(t, aclBindingRules[0].ID, result.reply.ACLBindingRule.ID) must.Greater(t, aclBindingRuleResp4.Index, result.reply.Index) } + +func TestACL_OIDCAuthURL(t *testing.T) { + t.Parallel() + + testServer, _, testServerCleanupFn := TestACLServer(t, nil) + defer testServerCleanupFn() + codec := rpcClient(t, testServer) + testutil.WaitForLeader(t, testServer.RPC) + + // Set up the test OIDC provider. + oidcTestProvider := capOIDC.StartTestProvider(t) + defer oidcTestProvider.Stop() + oidcTestProvider.SetClientCreds("bob", "topsecretcredthing") + + // Send an empty request to ensure the RPC handler runs the validation + // func. + authURLReq1 := structs.ACLOIDCAuthURLRequest{ + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + + var authURLResp1 structs.ACLOIDCAuthURLResponse + err := msgpackrpc.CallWithCodec(codec, structs.ACLOIDCAuthURLRPCMethod, &authURLReq1, &authURLResp1) + must.Error(t, err) + must.ErrorContains(t, err, "400") + must.ErrorContains(t, err, "invalid OIDC auth-url request") + + // Send a valid request that contains an auth method name that does not + // exist within state. + authURLReq2 := structs.ACLOIDCAuthURLRequest{ + AuthMethodName: "test-oidc-auth-method", + RedirectURI: "http://127.0.0.1:4649/oidc/callback", + ClientNonce: "fsSPuaodKevKfDU3IeXa", + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + + var authURLResp2 structs.ACLOIDCAuthURLResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCAuthURLRPCMethod, &authURLReq2, &authURLResp2) + must.Error(t, err) + must.ErrorContains(t, err, "400") + must.ErrorContains(t, err, "auth-method \"test-oidc-auth-method\" not found") + + // Generate and upsert an ACL auth method for use. Certain values must be + // taken from the cap OIDC provider just like real world use. + mockedAuthMethod := mock.ACLAuthMethod() + mockedAuthMethod.Config.AllowedRedirectURIs = []string{"http://127.0.0.1:4649/oidc/callback"} + mockedAuthMethod.Config.OIDCDiscoveryURL = oidcTestProvider.Addr() + mockedAuthMethod.Config.SigningAlgs = []string{"ES256"} + mockedAuthMethod.Config.DiscoveryCaPem = []string{oidcTestProvider.CACert()} + + must.NoError(t, testServer.fsm.State().UpsertACLAuthMethods(10, []*structs.ACLAuthMethod{mockedAuthMethod})) + + // Make a new request, which contains all valid data and therefore should + // succeed. + authURLReq3 := structs.ACLOIDCAuthURLRequest{ + AuthMethodName: mockedAuthMethod.Name, + RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "fsSPuaodKevKfDU3IeXa", + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + + var authURLResp3 structs.ACLOIDCAuthURLResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCAuthURLRPCMethod, &authURLReq3, &authURLResp3) + must.NoError(t, err) + + // The response URL comes encoded, so decode this and check we have each + // component we expect. + escapedURL, err := url.PathUnescape(authURLResp3.AuthURL) + must.NoError(t, err) + must.StrContains(t, escapedURL, "/authorize?client_id=mock") + must.StrContains(t, escapedURL, "&nonce=fsSPuaodKevKfDU3IeXa") + must.StrContains(t, escapedURL, "&redirect_uri=http://127.0.0.1:4649/oidc/callback") + must.StrContains(t, escapedURL, "&response_type=code") + must.StrContains(t, escapedURL, "&scope=openid") + must.StrContains(t, escapedURL, "&state=st_") +} + +func TestACL_OIDCCompleteAuth(t *testing.T) { + t.Parallel() + + testServer, _, testServerCleanupFn := TestACLServer(t, nil) + defer testServerCleanupFn() + codec := rpcClient(t, testServer) + testutil.WaitForLeader(t, testServer.RPC) + + oidcTestProvider := capOIDC.StartTestProvider(t) + defer oidcTestProvider.Stop() + oidcTestProvider.SetAllowedRedirectURIs([]string{"http://127.0.0.1:4649/oidc/callback"}) + + // Send an empty request to ensure the RPC handler runs the validation + // func. + completeAuthReq1 := structs.ACLOIDCCompleteAuthRequest{ + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + + var completeAuthResp1 structs.ACLOIDCCompleteAuthResponse + err := msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq1, &completeAuthResp1) + must.Error(t, err) + must.ErrorContains(t, err, "400") + must.ErrorContains(t, err, "invalid OIDC complete-auth request") + + // Send a request that passes initial validation. The auth method does not + // exist meaning it will fail. + completeAuthReq2 := structs.ACLOIDCCompleteAuthRequest{ + AuthMethodName: "test-oidc-auth-method", + ClientNonce: "fsSPuaodKevKfDU3IeXa", + State: "st_", + Code: "idontknowthisyet", + RedirectURI: "http://127.0.0.1:4649/oidc/callback", + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + + var completeAuthResp2 structs.ACLOIDCCompleteAuthResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq2, &completeAuthResp2) + must.Error(t, err) + must.ErrorContains(t, err, "400") + must.ErrorContains(t, err, "auth-method \"test-oidc-auth-method\" not found") + + // Generate and upsert an ACL auth method for use. Certain values must be + // taken from the cap OIDC provider and these are validated. Others must + // match data we use later, such as the claims. + mockedAuthMethod := mock.ACLAuthMethod() + mockedAuthMethod.Config.BoundAudiences = []string{"mock"} + mockedAuthMethod.Config.AllowedRedirectURIs = []string{"http://127.0.0.1:4649/oidc/callback"} + mockedAuthMethod.Config.OIDCDiscoveryURL = oidcTestProvider.Addr() + mockedAuthMethod.Config.SigningAlgs = []string{"ES256"} + mockedAuthMethod.Config.DiscoveryCaPem = []string{oidcTestProvider.CACert()} + mockedAuthMethod.Config.ClaimMappings = map[string]string{} + mockedAuthMethod.Config.ListClaimMappings = map[string]string{ + "http://nomad.internal/roles": "roles", + "http://nomad.internal/policies": "policies", + } + + must.NoError(t, testServer.fsm.State().UpsertACLAuthMethods(10, []*structs.ACLAuthMethod{mockedAuthMethod})) + + // Set our custom data and some expected values, so we can make the RPC and + // use the test provider. + oidcTestProvider.SetExpectedAuthNonce("fsSPuaodKevKfDU3IeXa") + oidcTestProvider.SetExpectedAuthCode("codeABC") + oidcTestProvider.SetCustomAudience("mock") + oidcTestProvider.SetExpectedState("st_someweirdstateid") + oidcTestProvider.SetCustomClaims(map[string]interface{}{ + "azp": "mock", + "http://nomad.internal/policies": []string{"engineering"}, + "http://nomad.internal/roles": []string{"engineering"}, + }) + + // We should now be able to authenticate, however, we do not have any rule + // bindings that will match. + completeAuthReq3 := structs.ACLOIDCCompleteAuthRequest{ + AuthMethodName: mockedAuthMethod.Name, + ClientNonce: "fsSPuaodKevKfDU3IeXa", + State: "st_", + Code: "codeABC", + RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + + var completeAuthResp3 structs.ACLOIDCCompleteAuthResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq3, &completeAuthResp3) + must.Error(t, err) + must.ErrorContains(t, err, "400") + must.ErrorContains(t, err, "no role or policy bindings matched") + + // Upsert an ACL policy and role, so that we can reference this within our + // OIDC claims. + mockACLPolicy := mock.ACLPolicy() + must.NoError(t, testServer.fsm.State().UpsertACLPolicies( + structs.MsgTypeTestSetup, 20, []*structs.ACLPolicy{mockACLPolicy})) + + mockACLRole := mock.ACLRole() + mockACLRole.Policies = []*structs.ACLRolePolicyLink{{Name: mockACLPolicy.Name}} + must.NoError(t, testServer.fsm.State().UpsertACLRoles( + structs.MsgTypeTestSetup, 30, []*structs.ACLRole{mockACLRole}, true)) + + // Generate and upsert two binding rules, so we can test both ACL Policy + // and Role claim mapping. + mockBindingRule1 := mock.ACLBindingRule() + mockBindingRule1.AuthMethod = mockedAuthMethod.Name + mockBindingRule1.BindType = structs.ACLBindingRuleBindTypePolicy + mockBindingRule1.Selector = "engineering in list.policies" + mockBindingRule1.BindName = mockACLPolicy.Name + + mockBindingRule2 := mock.ACLBindingRule() + mockBindingRule2.AuthMethod = mockedAuthMethod.Name + mockBindingRule2.BindName = mockACLRole.Name + + must.NoError(t, testServer.fsm.State().UpsertACLBindingRules( + 40, []*structs.ACLBindingRule{mockBindingRule1, mockBindingRule2}, true)) + + completeAuthReq4 := structs.ACLOIDCCompleteAuthRequest{ + AuthMethodName: mockedAuthMethod.Name, + ClientNonce: "fsSPuaodKevKfDU3IeXa", + State: "st_someweirdstateid", + Code: "codeABC", + RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + + var completeAuthResp4 structs.ACLOIDCCompleteAuthResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq4, &completeAuthResp4) + must.NoError(t, err) + must.NotNil(t, completeAuthResp4.ACLToken) + must.Len(t, 1, completeAuthResp4.ACLToken.Policies) + must.Eq(t, mockACLPolicy.Name, completeAuthResp4.ACLToken.Policies[0]) + must.Len(t, 1, completeAuthResp4.ACLToken.Roles) + must.Eq(t, mockACLRole.Name, completeAuthResp4.ACLToken.Roles[0].Name) + must.Eq(t, mockACLRole.ID, completeAuthResp4.ACLToken.Roles[0].ID) +} diff --git a/nomad/mock/acl.go b/nomad/mock/acl.go index 8cf9e6185..15786bfe6 100644 --- a/nomad/mock/acl.go +++ b/nomad/mock/acl.go @@ -231,6 +231,7 @@ func ACLAuthMethod() *structs.ACLAuthMethod { OIDCDiscoveryURL: "http://example.com", OIDCClientID: "mock", OIDCClientSecret: "very secret secret", + OIDCScopes: []string{"groups"}, BoundAudiences: []string{"audience1", "audience2"}, AllowedRedirectURIs: []string{"foo", "bar"}, DiscoveryCaPem: []string{"foo"}, diff --git a/nomad/server.go b/nomad/server.go index a796a3931..74f4770ce 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -36,6 +36,7 @@ import ( "github.com/hashicorp/nomad/helper/pool" "github.com/hashicorp/nomad/helper/stats" "github.com/hashicorp/nomad/helper/tlsutil" + "github.com/hashicorp/nomad/lib/auth/oidc" "github.com/hashicorp/nomad/nomad/deploymentwatcher" "github.com/hashicorp/nomad/nomad/drainer" "github.com/hashicorp/nomad/nomad/state" @@ -255,6 +256,11 @@ type Server struct { // aclCache is used to maintain the parsed ACL objects aclCache *lru.TwoQueueCache + // oidcProviderCache maintains a cache of OIDC providers. This is useful as + // the provider performs background HTTP requests. When the Nomad server is + // shutting down, the oidcProviderCache.Shutdown() function must be called. + oidcProviderCache *oidc.ProviderCache + // leaderAcl is the management ACL token that is valid when resolved by the // current leader. leaderAcl string @@ -414,6 +420,11 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigEntr } s.encrypter = encrypter + // Set up the OIDC provider cache. This is needed by the setupRPC, but must + // be done separately so that the server can stop all background processes + // when it shuts down itself. + s.oidcProviderCache = oidc.NewProviderCache() + // Initialize the RPC layer if err := s.setupRPC(tlsWrap); err != nil { s.Shutdown() @@ -720,6 +731,12 @@ func (s *Server) Shutdown() error { // Stop being able to set Configuration Entries s.consulConfigEntries.Stop() + // Shutdown the OIDC provider cache which contains background resources and + // processes. + if s.oidcProviderCache != nil { + s.oidcProviderCache.Shutdown() + } + return nil } diff --git a/nomad/structs/acl.go b/nomad/structs/acl.go index f0eeb962d..34ea1c9e5 100644 --- a/nomad/structs/acl.go +++ b/nomad/structs/acl.go @@ -149,6 +149,22 @@ const ( // Args: ACLBindingRuleRequest // Reply: ACLBindingRuleResponse ACLGetBindingRuleRPCMethod = "ACL.GetBindingRule" + + // ACLOIDCAuthURLRPCMethod is the RPC method for starting the OIDC login + // workflow. It generates the OIDC provider URL which will be used for user + // authentication. + // + // Args: ACLOIDCAuthURLRequest + // Reply: ACLOIDCAuthURLResponse + ACLOIDCAuthURLRPCMethod = "ACL.OIDCAuthURL" + + // ACLOIDCCompleteAuthRPCMethod is the RPC method for completing the OIDC + // login workflow. It exchanges the OIDC provider token for a Nomad ACL + // token with roles as defined within the remote provider. + // + // Args: ACLOIDCCompleteAuthRequest + // Reply: ACLOIDCCompleteAuthResponse + ACLOIDCCompleteAuthRPCMethod = "ACL.OIDCCompleteAuth" ) const ( @@ -168,11 +184,11 @@ const ( ) var ( - // validACLRoleName is used to validate an ACL role name. - validACLRoleName = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") + // ValidACLRoleName is used to validate an ACL role name. + ValidACLRoleName = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") // validACLAuthMethodName is used to validate an ACL auth method name. - validACLAuthMethod = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") + ValidACLAuthMethod = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") ) // ACLTokenRoleLink is used to link an ACL token to an ACL role. The ACL token @@ -406,7 +422,7 @@ func (a *ACLRole) Validate() error { var mErr multierror.Error - if !validACLRoleName.MatchString(a.Name) { + if !ValidACLRoleName.MatchString(a.Name) { mErr.Errors = append(mErr.Errors, fmt.Errorf("invalid name '%s'", a.Name)) } @@ -777,7 +793,7 @@ func (a *ACLAuthMethod) Merge(b *ACLAuthMethod) { func (a *ACLAuthMethod) Validate(minTTL, maxTTL time.Duration) error { var mErr multierror.Error - if !validACLAuthMethod.MatchString(a.Name) { + if !ValidACLAuthMethod.MatchString(a.Name) { mErr.Errors = append(mErr.Errors, fmt.Errorf("invalid name '%s'", a.Name)) } @@ -800,11 +816,16 @@ func (a *ACLAuthMethod) Validate(minTTL, maxTTL time.Duration) error { return mErr.ErrorOrNil() } +// TokenLocalityIsGlobal returns whether the auth method creates global ACL +// tokens or not. +func (a *ACLAuthMethod) TokenLocalityIsGlobal() bool { return a.TokenLocality == "global" } + // ACLAuthMethodConfig is used to store configuration of an auth method type ACLAuthMethodConfig struct { OIDCDiscoveryURL string OIDCClientID string OIDCClientSecret string + OIDCScopes []string BoundAudiences []string AllowedRedirectURIs []string DiscoveryCaPem []string @@ -821,6 +842,7 @@ func (a *ACLAuthMethodConfig) Copy() *ACLAuthMethodConfig { c := new(ACLAuthMethodConfig) *c = *a + c.OIDCScopes = slices.Clone(a.OIDCScopes) c.BoundAudiences = slices.Clone(a.BoundAudiences) c.AllowedRedirectURIs = slices.Clone(a.AllowedRedirectURIs) c.DiscoveryCaPem = slices.Clone(a.DiscoveryCaPem) @@ -829,6 +851,14 @@ func (a *ACLAuthMethodConfig) Copy() *ACLAuthMethodConfig { return c } +// ACLAuthClaims is the claim mapping of the OIDC auth method in a format that +// can be used with go-bexpr. This structure is used during rule binding +// evaluation. +type ACLAuthClaims struct { + Value map[string]string `bexpr:"value"` + List map[string][]string `bexpr:"list"` +} + // ACLAuthMethodStub is used for listing ACL auth methods type ACLAuthMethodStub struct { Name string @@ -916,7 +946,7 @@ type ACLWhoAmIResponse struct { // ACL Roles and Policies. type ACLBindingRule struct { - // ID is an internally generated UUID for this role and is controlled by + // ID is an internally generated UUID for this rule and is controlled by // Nomad. ID string @@ -1194,3 +1224,107 @@ type ACLBindingRuleResponse struct { ACLBindingRule *ACLBindingRule QueryMeta } + +// ACLOIDCAuthURLRequest is the request to make when starting the OIDC +// authentication login flow. +type ACLOIDCAuthURLRequest struct { + + // AuthMethodName is the OIDC auth-method to use. This is a required + // parameter. + AuthMethodName string + + // RedirectURI is the URL that authorization should redirect to. This is a + // required parameter. + RedirectURI string + + // ClientNonce is a randomly generated string to prevent replay attacks. It + // is up to the client to generate this and Go integrations should use the + // oidc.NewID function within the hashicorp/cap library. This must then be + // passed back to ACLOIDCCompleteAuthRequest. This is a required parameter. + ClientNonce string + + // WriteRequest is used due to the requirement by the RPC forwarding + // mechanism. This request doesn't write anything to Nomad's internal + // state. + WriteRequest +} + +// Validate ensures the request object contains all the required fields in +// order to start the OIDC authentication flow. +func (a *ACLOIDCAuthURLRequest) Validate() error { + + var mErr multierror.Error + + if a.AuthMethodName == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing auth method name")) + } + if a.ClientNonce == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing client nonce")) + } + if a.RedirectURI == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing redirect URI")) + } + return mErr.ErrorOrNil() +} + +// ACLOIDCAuthURLResponse is the response when starting the OIDC authentication +// login flow. +type ACLOIDCAuthURLResponse struct { + + // AuthURL is URL to begin authorization and is where the user logging in + // should go. + AuthURL string +} + +// ACLOIDCCompleteAuthRequest is the request object to begin completing the +// OIDC auth cycle after receiving the callback from the OIDC provider. +type ACLOIDCCompleteAuthRequest struct { + + // AuthMethodName is the name of the auth method being used to login via + // OIDC. This will match ACLOIDCAuthURLRequest.AuthMethodName. This is a + // required parameter. + AuthMethodName string + + // ClientNonce, State, and Code are provided from the parameters given to + // the redirect URL. These are all required parameters. + ClientNonce string + State string + Code string + + // RedirectURI is the URL that authorization should redirect to. This is a + // required parameter. + RedirectURI string + + WriteRequest +} + +// Validate ensures the request object contains all the required fields in +// order to complete the OIDC authentication flow. +func (a *ACLOIDCCompleteAuthRequest) Validate() error { + + var mErr multierror.Error + + if a.AuthMethodName == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing auth method name")) + } + if a.ClientNonce == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing client nonce")) + } + if a.State == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing state")) + } + if a.Code == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing code")) + } + if a.RedirectURI == "" { + mErr.Errors = append(mErr.Errors, errors.New("missing redirect URI")) + } + return mErr.ErrorOrNil() +} + +// ACLOIDCCompleteAuthResponse is the response when the OIDC auth flow has been +// completed successfully. +type ACLOIDCCompleteAuthResponse struct { + ACLToken *ACLToken + WriteMeta +} diff --git a/nomad/structs/acl_test.go b/nomad/structs/acl_test.go index b80184ec2..82b29b5df 100644 --- a/nomad/structs/acl_test.go +++ b/nomad/structs/acl_test.go @@ -1085,6 +1085,7 @@ func TestACLAuthMethodConfig_Copy(t *testing.T) { OIDCDiscoveryURL: "http://example.com", OIDCClientID: "mock", OIDCClientSecret: "very secret secret", + OIDCScopes: []string{"groups"}, BoundAudiences: []string{"audience1", "audience2"}, AllowedRedirectURIs: []string{"foo", "bar"}, DiscoveryCaPem: []string{"foo"}, @@ -1136,6 +1137,16 @@ func TestACLAuthMethod_Canonicalize(t *testing.T) { } } +func TestACLAuthMethod_TokenLocalityIsGlobal(t *testing.T) { + ci.Parallel(t) + + globalAuthMethod := &ACLAuthMethod{TokenLocality: "global"} + must.True(t, globalAuthMethod.TokenLocalityIsGlobal()) + + localAuthMethod := &ACLAuthMethod{TokenLocality: "local"} + must.False(t, localAuthMethod.TokenLocalityIsGlobal()) +} + func TestACLBindingRule_Canonicalize(t *testing.T) { ci.Parallel(t) @@ -1375,3 +1386,41 @@ func Test_ACLBindingRuleRequest(t *testing.T) { req := ACLBindingRuleRequest{} require.True(t, req.IsRead()) } + +func TestACLOIDCAuthURLRequest(t *testing.T) { + ci.Parallel(t) + + req := &ACLOIDCAuthURLRequest{} + must.False(t, req.IsRead()) +} + +func TestACLOIDCAuthURLRequest_Validate(t *testing.T) { + ci.Parallel(t) + + testRequest := &ACLOIDCAuthURLRequest{} + err := testRequest.Validate() + must.Error(t, err) + must.StrContains(t, err.Error(), "missing auth method name") + must.StrContains(t, err.Error(), "missing client nonce") + must.StrContains(t, err.Error(), "missing redirect URI") +} + +func TestACLOIDCCompleteAuthRequest(t *testing.T) { + ci.Parallel(t) + + req := &ACLOIDCCompleteAuthRequest{} + must.False(t, req.IsRead()) +} + +func TestACLOIDCCompleteAuthRequest_Validate(t *testing.T) { + ci.Parallel(t) + + testRequest := &ACLOIDCCompleteAuthRequest{} + err := testRequest.Validate() + must.Error(t, err) + must.StrContains(t, err.Error(), "missing auth method name") + must.StrContains(t, err.Error(), "missing client nonce") + must.StrContains(t, err.Error(), "missing state") + must.StrContains(t, err.Error(), "missing code") + must.StrContains(t, err.Error(), "missing redirect URI") +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index ca53fcb56..542112777 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -50,8 +50,8 @@ import ( ) var ( - // validPolicyName is used to validate a policy name - validPolicyName = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") + // ValidPolicyName is used to validate a policy name + ValidPolicyName = regexp.MustCompile("^[a-zA-Z0-9-]{1,128}$") // b32 is a lowercase base32 encoding for use in URL friendly service hashes b32 = base32.NewEncoding(strings.ToLower("abcdefghijklmnopqrstuvwxyz234567")) @@ -11993,7 +11993,7 @@ func (a *ACLPolicy) Stub() *ACLPolicyListStub { func (a *ACLPolicy) Validate() error { var mErr multierror.Error - if !validPolicyName.MatchString(a.Name) { + if !ValidPolicyName.MatchString(a.Name) { err := fmt.Errorf("invalid name '%s'", a.Name) mErr.Errors = append(mErr.Errors, err) } diff --git a/ui/app/adapters/auth-method.js b/ui/app/adapters/auth-method.js index 42926c52c..586eb51ed 100644 --- a/ui/app/adapters/auth-method.js +++ b/ui/app/adapters/auth-method.js @@ -17,7 +17,7 @@ export default class AuthMethodAdapter extends ApplicationAdapter { /** * @typedef {Object} ACLOIDCAuthURLParams - * @property {string} AuthMethod + * @property {string} AuthMethodName * @property {string} RedirectUri * @property {string} ClientNonce * @property {Object[]} Meta // NOTE: unsure if array of objects or kv pairs @@ -27,11 +27,11 @@ export default class AuthMethodAdapter extends ApplicationAdapter { * @param {ACLOIDCAuthURLParams} params * @returns */ - getAuthURL({ AuthMethod, RedirectUri, ClientNonce, Meta }) { + getAuthURL({ AuthMethodName, RedirectUri, ClientNonce, Meta }) { const url = `/${this.namespace}/oidc/auth-url`; return this.ajax(url, 'POST', { data: { - AuthMethod, + AuthMethodName, RedirectUri, ClientNonce, Meta, diff --git a/ui/app/controllers/settings/tokens.js b/ui/app/controllers/settings/tokens.js index 1a5548bb1..c019087f8 100644 --- a/ui/app/controllers/settings/tokens.js +++ b/ui/app/controllers/settings/tokens.js @@ -90,13 +90,20 @@ export default class Tokens extends Controller { window.localStorage.setItem('nomadOIDCNonce', nonce); window.localStorage.setItem('nomadOIDCAuthMethod', provider); + let redirectURL; + if (Ember.testing) { + redirectURL = this.router.currentURL; + } else { + redirectURL = new URL(window.location.toString()); + redirectURL.search = ''; + redirectURL = redirectURL.href; + } + method .getAuthURL({ - AuthMethod: provider, + AuthMethodName: provider, ClientNonce: nonce, - RedirectUri: Ember.testing - ? this.router.currentURL - : window.location.toString(), + RedirectUri: redirectURL, }) .then(({ AuthURL }) => { if (Ember.testing) { @@ -111,7 +118,7 @@ export default class Tokens extends Controller { @tracked state = null; get isValidatingToken() { - if (this.code && this.state === 'success') { + if (this.code && this.state) { this.validateSSO(); return true; } else { @@ -120,25 +127,41 @@ export default class Tokens extends Controller { } async validateSSO() { + let redirectURL; + if (Ember.testing) { + redirectURL = this.router.currentURL; + } else { + redirectURL = new URL(window.location.toString()); + redirectURL.search = ''; + redirectURL = redirectURL.href; + } + const res = await this.token.authorizedRequest( '/v1/acl/oidc/complete-auth', { method: 'POST', body: JSON.stringify({ - AuthMethod: window.localStorage.getItem('nomadOIDCAuthMethod'), + AuthMethodName: window.localStorage.getItem('nomadOIDCAuthMethod'), ClientNonce: window.localStorage.getItem('nomadOIDCNonce'), Code: this.code, State: this.state, + RedirectURI: redirectURL, }), } ); if (res.ok) { const data = await res.json(); - this.token.set('secret', data.ACLToken); - this.verifyToken(); + this.clearTokenProperties(); + this.token.set('secret', data.SecretID); this.state = null; this.code = null; + + // Refetch the token and associated policies + this.get('token.fetchSelfTokenAndPolicies').perform().catch(); + + this.signInStatus = 'success'; + this.token.set('tokenNotFound', false); } else { this.state = 'failure'; this.code = null; diff --git a/ui/app/templates/settings/tokens.hbs b/ui/app/templates/settings/tokens.hbs index 33918cf82..dd8dc64c0 100644 --- a/ui/app/templates/settings/tokens.hbs +++ b/ui/app/templates/settings/tokens.hbs @@ -82,7 +82,7 @@ class="button is-primary" onclick={{action "redirectToSSO" method}} type="button" - >Sign in with with {{method.name}} + >Sign in with {{method.name}} {{/each}} diff --git a/ui/mirage/config.js b/ui/mirage/config.js index e81967249..86b7820f9 100644 --- a/ui/mirage/config.js +++ b/ui/mirage/config.js @@ -443,7 +443,7 @@ export default function () { return JSON.stringify(findLeader(schema)); }); - this.get('/acl/tokens', function ({tokens}, req) { + this.get('/acl/tokens', function ({ tokens }, req) { return this.serialize(tokens.all()); }); @@ -548,9 +548,14 @@ export default function () { this.delete('/acl/policy/:id', function (schema, request) { const { id } = request.params; - schema.tokens.all().models.filter(token => token.policyIds.includes(id)).forEach(token => { - token.update({ policyIds: token.policyIds.filter(pid => pid !== id) }); - }); + schema.tokens + .all() + .models.filter((token) => token.policyIds.includes(id)) + .forEach((token) => { + token.update({ + policyIds: token.policyIds.filter((pid) => pid !== id), + }); + }); server.db.policies.remove(id); return ''; }); @@ -566,7 +571,6 @@ export default function () { description: Description, rules: Rules, }); - }); this.get('/regions', function ({ regions }) { @@ -979,26 +983,37 @@ export default function () { return schema.authMethods.all(); }); this.post('/acl/oidc/auth-url', (schema, req) => { - const {AuthMethod, ClientNonce, RedirectUri, Meta} = JSON.parse(req.requestBody); - return new Response(200, {}, { - AuthURL: `/ui/oidc-mock?auth_method=${AuthMethod}&client_nonce=${ClientNonce}&redirect_uri=${RedirectUri}&meta=${Meta}` - }); + const { AuthMethodName, ClientNonce, RedirectUri, Meta } = JSON.parse( + req.requestBody + ); + return new Response( + 200, + {}, + { + AuthURL: `/ui/oidc-mock?auth_method=${AuthMethodName}&client_nonce=${ClientNonce}&redirect_uri=${RedirectUri}&meta=${Meta}`, + } + ); }); // Simulate an OIDC callback by assuming the code passed is the secret of an existing token, and return that token. - this.post('/acl/oidc/complete-auth', function (schema, req) { - const code = JSON.parse(req.requestBody).Code; - const token = schema.tokens.findBy({ - id: code - }); - - return new Response(200, {}, { - ACLToken: token.secretId - }); - }, {timing: 1000}); - - + this.post( + '/acl/oidc/complete-auth', + function (schema, req) { + const code = JSON.parse(req.requestBody).Code; + const token = schema.tokens.findBy({ + id: code, + }); + return new Response( + 200, + {}, + { + SecretID: token.secretId, + } + ); + }, + { timing: 1000 } + ); //#endregion SSO }