diff --git a/.changelog/25231.txt b/.changelog/25231.txt new file mode 100644 index 000000000..e60e78122 --- /dev/null +++ b/.changelog/25231.txt @@ -0,0 +1,3 @@ +```release-note:improvement +oidc: Add private key JWT / client assertion option to auth flow, and default enable PKCE +``` diff --git a/api/acl.go b/api/acl.go index 88a64a107..dcedee07a 100644 --- a/api/acl.go +++ b/api/acl.go @@ -826,6 +826,10 @@ type ACLAuthMethodConfig struct { OIDCClientID string // The OAuth Client Secret configured with the OIDC provider OIDCClientSecret string + // Optionally send a signed JWT ("private key jwt") as a client assertion + OIDCClientAssertion *OIDCClientAssertion + // Disable S256 PKCE challenge verification + OIDCDisablePKCE *bool // Disable claims from the OIDC UserInfo endpoint OIDCDisableUserInfo bool // List of OIDC scopes @@ -950,6 +954,121 @@ func (c *ACLAuthMethodConfig) UnmarshalJSON(data []byte) error { return nil } +// OIDCClientAssertionKeySource specifies what key material should be used +// to sign an OIDCClientAssertion. +type OIDCClientAssertionKeySource string + +const ( + // OIDCKeySourceNomad signs the OIDCClientAssertion JWT with Nomad's + // internal private key. Its public key is exposed at /.well-known/jwks.json + OIDCKeySourceNomad OIDCClientAssertionKeySource = "nomad" + // OIDCKeySourcePrivateKey signs the OIDCClientAssertion JWT with + // key material defined in OIDCClientAssertion.PrivateKey + OIDCKeySourcePrivateKey OIDCClientAssertionKeySource = "private_key" + // OIDCKeySourceClientSecret signs the OIDCClientAssertion JWT with + // ACLAuthMethod.ClientSecret + OIDCKeySourceClientSecret OIDCClientAssertionKeySource = "client_secret" +) + +// OIDCClientAssertion (a.k.a private_key_jwt) is used to send +// a client_assertion along with an OIDC token request. +// See also: structs.OIDCClientAssertion +type OIDCClientAssertion struct { + // KeySource is where to get the private key to sign the JWT. + // It is the one field that *must* be set to enable client assertions. + // Available sources: + // * "nomad" = Use current active key in Nomad's keyring + // * "private_key" = Use key material in the PrivateKey field of this struct + // * "client_secret" = Use the OIDCClientSecret inherited from the parent + // ACLAuthMethodConfig struct + KeySource OIDCClientAssertionKeySource + + // Audience is/are who will be processing the assertion. + // Defaults to the parent ACLAuthMethodConfig's OIDCDiscoveryURL + Audience []string + + // PrivateKey contains external key material provided by users. + // KeySource must be "private_key" to enable this. + PrivateKey *OIDCClientAssertionKey + + // KeyAlgorithm is the key's algorithm. + // Its default values are based on the KeySource: + // * nomad = "RS256" -- pulled from the keyring + // * private_key = "RS256" + // * client_secret = "HS256" + // Only RSA algorithms are supported for nomad and private_key. + KeyAlgorithm string + + // ExtraHeaders are added to the JWT headers, alongside "kid" and "type" + // Setting the "kid" header here is not allowed; use PrivateKey.KeyID. + ExtraHeaders map[string]string +} + +// OIDCClientAssertionKeyIDHeader is the header that the OIDC provider will use +// to look up the certificate or public key that it needs to verify the +// private key JWT signature. +type OIDCClientAssertionKeyIDHeader string + +const ( + OIDCClientAssertionHeaderKid OIDCClientAssertionKeyIDHeader = "kid" + OIDCClientAssertionHeaderX5t OIDCClientAssertionKeyIDHeader = "x5t" + OIDCClientAssertionHeaderX5tS256 OIDCClientAssertionKeyIDHeader = "x5t#S256" +) + +// OIDCClientAssertionKey contains key material provided by users for Nomad +// to use to sign the private key JWT. +// +// PemKey or PemKeyFile must contain an RSA private key in PEM format. +// +// PemCert, PemCertFile may contain an x509 certificate created with +// the Key, used to derive the KeyID. Alternatively, KeyID may be set manually. +// +// PemKeyFile and PemCertFile, if set, must be present on disk on any Nomad +// servers that may become cluster leaders. +type OIDCClientAssertionKey struct { + // PemKey is the private key, in pem format. It is used to sign the JWT. + // Mutually exclusive with PemKeyFile. + PemKey string + // PemKeyFile is the path to a private key on server disk, in pem format. + // It is used to sign the JWT. + // Mutually exclusive with PemKey. + PemKeyFile string + + // KeyIDHeader is which header to set for they provider to identify the + // public key to use to verify the signed JWT. Its default values vary + // based on which of the other required fields is set: + // KeyID: "kid" + // PemCert: "x5t#S256" + // PemCertFile: "x5t#S256" + // + // Valid values are: "kid", "x5t", "x5t#S256" + // If "x5t" is selected, Nomad uses sha1 to derive the x5t header + // from the provided certificate. + // + // Refer to the RFC for more information on JWT key headers: + // "kid": https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.4 + // "x5t": https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.7 + // "x5t#S256": https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.8 + // + // If you need to set some other header not supported here, + // you may use OIDCClientAssertion.ExtraHeaders. + KeyIDHeader OIDCClientAssertionKeyIDHeader + // KeyID may be set manually and becomes the "kid" header. + // Mutually exclusive with PemCert and PemCertFile. + // Allowed KeyIDHeader values: "kid" (the default) + KeyID string + // PemCert is a certificate, signed by the private key or a CA, + // in pem format. It is used to derive an x5t-style KeyID. + // Mutually exclusive with PemCertFile and KeyID. + // Allowed KeyIDHeader values: "x5t", "x5t#S256" (default "x5t#S256") + PemCert string + // PemCertFile is a certificate, signed by the private key or a CA, + // on server disk, in pem format. It is used to derive an x5t-style KeyID. + // Mutually exclusive with PemCert and KeyID. + // Allowed KeyIDHeader values: "x5t", "x5t#S256" (default "x5t#S256") + PemCertFile string +} + // ACLAuthMethodListStub is the stub object returned when performing a listing // of ACL auth-methods. It is intentionally minimal due to the unauthenticated // nature of the list endpoint. diff --git a/api/acl_test.go b/api/acl_test.go index 27f058dbe..636ebe01f 100644 --- a/api/acl_test.go +++ b/api/acl_test.go @@ -604,6 +604,11 @@ func TestACLAuthMethods(t *testing.T) { TokenLocality: ACLAuthMethodTokenLocalityLocal, MaxTokenTTL: 15 * time.Minute, Default: true, + Config: &ACLAuthMethodConfig{ + BoundAudiences: []string{"test-aud"}, + OIDCDiscoveryURL: "https://example.com", + OIDCClientID: "test-client-id", + }, } _, writeMeta, err := testClient.ACLAuthMethods().Create(&authMethod, nil) must.NoError(t, err) @@ -664,6 +669,11 @@ func TestACLBindingRules(t *testing.T) { TokenLocality: ACLAuthMethodTokenLocalityGlobal, MaxTokenTTL: 10 * time.Hour, Default: true, + Config: &ACLAuthMethodConfig{ + BoundAudiences: []string{"test-aud"}, + OIDCDiscoveryURL: "https://example.com", + OIDCClientID: "test-client-id", + }, } _, _, err := testClient.ACLAuthMethods().Create(&aclAuthMethod, nil) must.NoError(t, err) diff --git a/command/acl_auth_method_create_test.go b/command/acl_auth_method_create_test.go index b7a46808a..0b5e69c59 100644 --- a/command/acl_auth_method_create_test.go +++ b/command/acl_auth_method_create_test.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/command/agent" "github.com/hashicorp/nomad/testutil" + "github.com/shoenig/test" "github.com/shoenig/test/must" ) @@ -67,9 +68,10 @@ func TestACLAuthMethodCreateCommand_Run(t *testing.T) { args := []string{ "-address=" + url, "-token=" + rootACLToken.SecretID, "-name=acl-auth-method-cli-test", "-type=OIDC", "-token-locality=global", "-default=true", "-max-token-ttl=3600s", - "-config={\"OIDCDiscoveryURL\":\"http://example.com\", \"ExpirationLeeway\": \"1h\"}", + `-config={"OIDCDiscoveryURL":"http://example.com", "OIDCClientID": "example-id", "BoundAudiences": ["example-aud"], "ExpirationLeeway": "1h"}`, } - must.Eq(t, 0, cmd.Run(args)) + test.Eq(t, 0, cmd.Run(args)) + test.Eq(t, "", ui.ErrorWriter.String()) s := ui.OutputWriter.String() must.StrContains(t, s, "acl-auth-method-cli-test") @@ -81,7 +83,11 @@ func TestACLAuthMethodCreateCommand_Run(t *testing.T) { defer os.Remove(configFile.Name()) must.Nil(t, err) - conf := map[string]interface{}{"OIDCDiscoveryURL": "http://example.com"} + conf := map[string]interface{}{ + "OIDCDiscoveryURL": "http://example.com", + "OIDCClientID": "example-id", + "BoundAudiences": []string{"example-aud"}, + } jsonData, err := json.Marshal(conf) must.Nil(t, err) @@ -93,7 +99,8 @@ func TestACLAuthMethodCreateCommand_Run(t *testing.T) { "-type=OIDC", "-token-locality=global", "-default=false", "-max-token-ttl=3600s", fmt.Sprintf("-config=@%s", configFile.Name()), } - must.Eq(t, 0, cmd.Run(args)) + test.Eq(t, 0, cmd.Run(args)) + test.Eq(t, "", ui.ErrorWriter.String()) s = ui.OutputWriter.String() must.StrContains(t, s, "acl-auth-method-cli-test") diff --git a/command/acl_auth_method_update_test.go b/command/acl_auth_method_update_test.go index 2fc9ebd40..9ec070926 100644 --- a/command/acl_auth_method_update_test.go +++ b/command/acl_auth_method_update_test.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/nomad/command/agent" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" + "github.com/shoenig/test" "github.com/shoenig/test/must" ) @@ -64,6 +65,8 @@ func TestACLAuthMethodUpdateCommand_Run(t *testing.T) { TokenLocality: "local", Config: &structs.ACLAuthMethodConfig{ OIDCDiscoveryURL: "http://example.com", + OIDCClientID: "example-id", + BoundAudiences: []string{"example-aud"}, }, } method.SetHash() @@ -80,7 +83,8 @@ func TestACLAuthMethodUpdateCommand_Run(t *testing.T) { // Update the token locality code = cmd.Run([]string{ "-address=" + url, "-token=" + rootACLToken.SecretID, "-token-locality=global", method.Name}) - must.Zero(t, code) + test.Zero(t, code) + test.Eq(t, "", ui.ErrorWriter.String()) s := ui.OutputWriter.String() must.StrContains(t, s, method.Name) @@ -92,7 +96,11 @@ func TestACLAuthMethodUpdateCommand_Run(t *testing.T) { defer os.Remove(configFile.Name()) must.Nil(t, err) - conf := map[string]interface{}{"OIDCDiscoveryURL": "http://example.com"} + conf := map[string]interface{}{ + "OIDCDiscoveryURL": "http://example.com", + "OIDCClientID": "example-id", + "BoundAudiences": []string{"example-aud"}, + } jsonData, err := json.Marshal(conf) must.Nil(t, err) @@ -105,7 +113,8 @@ func TestACLAuthMethodUpdateCommand_Run(t *testing.T) { fmt.Sprintf("-config=@%s", configFile.Name()), method.Name, }) - must.Zero(t, code) + test.Zero(t, code) + test.Eq(t, "", ui.ErrorWriter.String()) s = ui.OutputWriter.String() must.StrContains(t, s, method.Name) diff --git a/command/agent/acl_endpoint_test.go b/command/agent/acl_endpoint_test.go index 7123a0784..4e80c79eb 100644 --- a/command/agent/acl_endpoint_test.go +++ b/command/agent/acl_endpoint_test.go @@ -1823,7 +1823,6 @@ func TestHTTPServer_ACLOIDCCompleteAuthRequest(t *testing.T) { oidcTestProvider.SetExpectedAuthNonce("fpSPuaodKevKfDU3IeXa") oidcTestProvider.SetExpectedAuthCode("codeABC") oidcTestProvider.SetCustomAudience("mock") - oidcTestProvider.SetExpectedState("st_someweirdstateid") oidcTestProvider.SetCustomClaims(map[string]interface{}{ "azp": "mock", "http://nomad.internal/policies": []string{"engineering"}, @@ -1834,7 +1833,7 @@ func TestHTTPServer_ACLOIDCCompleteAuthRequest(t *testing.T) { requestBody := structs.ACLOIDCCompleteAuthRequest{ AuthMethodName: mockedAuthMethod.Name, ClientNonce: "fpSPuaodKevKfDU3IeXa", - State: "st_someweirdstateid", + State: "overwrite me", Code: "codeABC", RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], WriteRequest: structs.WriteRequest{ @@ -1842,6 +1841,10 @@ func TestHTTPServer_ACLOIDCCompleteAuthRequest(t *testing.T) { }, } + // Request Auth URL first, as a user would. This primes the + // request cache on the server, returns the expected state. + requestBody.State = requestAuthState(t, testAgent.Server, mockedAuthMethod, requestBody.ClientNonce) + // Build the HTTP request. req, err := http.NewRequest(http.MethodPost, "/v1/acl/oidc/complete-auth", encodeReq(&requestBody)) must.NoError(t, err) @@ -1877,6 +1880,10 @@ func TestHTTPServer_ACLOIDCCompleteAuthRequest(t *testing.T) { must.NoError(t, testAgent.server.State().UpsertACLBindingRules( 40, []*structs.ACLBindingRule{mockBindingRule1, mockBindingRule2}, true)) + // Request Auth URL first, as a user would. This primes the + // request cache on the server, returns the expected state. + requestBody.State = requestAuthState(t, testAgent.Server, mockedAuthMethod, requestBody.ClientNonce) + // Build the HTTP request. req, err = http.NewRequest(http.MethodPost, "/v1/acl/oidc/complete-auth", encodeReq(&requestBody)) must.NoError(t, err) @@ -2036,3 +2043,22 @@ func TestHTTPServer_ACLLoginRequest(t *testing.T) { }) } } + +// requestAuthState hits the oidc/auth-url endpoint, as a user would during +// normal login, before a subsequent call to oidc/complete-auth. Returns the +// "state" generated by the server, to be used in ACLOIDCCompleteAuthRequest. +func requestAuthState(t *testing.T, server *HTTPServer, authMethod *structs.ACLAuthMethod, nonce string) string { + t.Helper() + urlReq, err := http.NewRequest(http.MethodPost, "/v1/acl/oidc/auth-url", encodeReq(&structs.ACLOIDCAuthURLRequest{ + AuthMethodName: authMethod.Name, + RedirectURI: authMethod.Config.AllowedRedirectURIs[0], + ClientNonce: nonce, + WriteRequest: structs.WriteRequest{Region: "global"}, + })) + must.NoError(t, err) + authURLResp, err := server.ACLOIDCAuthURLRequest(httptest.NewRecorder(), urlReq) + must.NoError(t, err) + u, err := url.Parse(authURLResp.(structs.ACLOIDCAuthURLResponse).AuthURL) + must.NoError(t, err) + return u.Query().Get("state") +} diff --git a/internal/testing/apitests/acl_test.go b/internal/testing/apitests/acl_test.go index b4969e3cf..a84aaf716 100644 --- a/internal/testing/apitests/acl_test.go +++ b/internal/testing/apitests/acl_test.go @@ -11,6 +11,7 @@ import ( capOIDC "github.com/hashicorp/cap/oidc" "github.com/hashicorp/nomad/api" "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/pointer" "github.com/shoenig/test/must" ) @@ -94,9 +95,12 @@ func TestACLOIDC_CompleteAuth(t *testing.T) { MaxTokenTTL: 10 * time.Hour, Default: true, Config: &api.ACLAuthMethodConfig{ - OIDCDiscoveryURL: oidcTestProvider.Addr(), - OIDCClientID: "mock", - OIDCClientSecret: "verysecretsecret", + OIDCDiscoveryURL: oidcTestProvider.Addr(), + OIDCClientID: "mock", + OIDCClientSecret: "verysecretsecret", + // PKCE is hard to test at this level, because the verifier only + // exists on the server. this functionality is covered elsewhere. + OIDCDisablePKCE: pointer.Of(true), OIDCDisableUserInfo: false, BoundAudiences: []string{"mock"}, AllowedRedirectURIs: []string{"http://127.0.0.1:4649/oidc/callback"}, @@ -120,7 +124,6 @@ func TestACLOIDC_CompleteAuth(t *testing.T) { oidcTestProvider.SetExpectedAuthNonce("fpSPuaodKevKfDU3IeXb") oidcTestProvider.SetExpectedAuthCode("codeABC") oidcTestProvider.SetCustomAudience("mock") - oidcTestProvider.SetExpectedState("st_someweirdstateid") oidcTestProvider.SetCustomClaims(map[string]interface{}{ "azp": "mock", "http://nomad.internal/policies": []string{"engineering"}, @@ -166,12 +169,24 @@ func TestACLOIDC_CompleteAuth(t *testing.T) { must.NoError(t, err) must.NotNil(t, createBindingRole2Resp) + // Request Auth URL first, as a user would. This primes the request cache + // on the server, and the response includes the expected state. + authURLresp, _, err := testClient.ACLAuth().GetAuthURL(&api.ACLOIDCAuthURLRequest{ + AuthMethodName: createdAuthMethod.Name, + RedirectURI: createdAuthMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "fpSPuaodKevKfDU3IeXb", + }, nil) + must.NoError(t, err) + u, err := url.Parse(authURLresp.AuthURL) + must.NoError(t, err) + state := u.Query().Get("state") + // Generate and make the request. authURLRequest := api.ACLOIDCCompleteAuthRequest{ AuthMethodName: createdAuthMethod.Name, RedirectURI: createdAuthMethod.Config.AllowedRedirectURIs[0], ClientNonce: "fpSPuaodKevKfDU3IeXb", - State: "st_someweirdstateid", + State: state, Code: "codeABC", } diff --git a/lib/auth/oidc/client_assertion.go b/lib/auth/oidc/client_assertion.go new file mode 100644 index 000000000..7b2054aa4 --- /dev/null +++ b/lib/auth/oidc/client_assertion.go @@ -0,0 +1,178 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package oidc + +import ( + "crypto/rsa" + // sha1 is used to derive an "x5t" jwt header from an x509 certificate, + // per the OIDC JWS spec: + // https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.7 + // sha1 is not a security risk here, but it is less reliable than sha256 + // (for "x5t#S256" headers) in terms of possible value collisions, + // so "x5t" must be set explicitly by the user in their auth method config + // if their provider does not allow "x5t#S256" (the default). + // None of this applies if the user sets the KeyID ("kid" header) manually. + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "errors" + "fmt" + "hash" + "os" + "time" + + gojwt "github.com/golang-jwt/jwt/v5" + cass "github.com/hashicorp/cap/oidc/clientassertion" + "github.com/hashicorp/nomad/nomad/structs" +) + +func BuildClientAssertionJWT(config *structs.ACLAuthMethodConfig, nomadKey *rsa.PrivateKey, nomadKID string) (*cass.JWT, error) { + // should already be validated by caller, but just in case. + if config == nil || config.OIDCClientAssertion == nil { + return nil, errors.New("no auth method config or client assertion") + } + + // this is all we use config for + clientID := config.OIDCClientID + // client assertion-specific info is in here + as := config.OIDCClientAssertion + + // this should have also happened long before, but again, just in case. + if err := as.Validate(); err != nil { + return nil, err + } + + opts := []cass.Option{ + cass.WithHeaders(as.ExtraHeaders), + } + + switch as.KeySource { + + case structs.OIDCKeySourceClientSecret: + algo := cass.HSAlgorithm(as.KeyAlgorithm) + return cass.NewJWTWithHMAC(clientID, as.Audience, algo, as.ClientSecret, opts...) + + case structs.OIDCKeySourceNomad: + opts = append(opts, cass.WithKeyID(nomadKID)) + return cass.NewJWTWithRSAKey(clientID, as.Audience, cass.RS256, nomadKey, opts...) + + case structs.OIDCKeySourcePrivateKey: + algo := cass.RSAlgorithm(as.KeyAlgorithm) + rsaKey, err := getCassPrivateKey(as.PrivateKey) + if err != nil { + return nil, err + } + + if as.PrivateKey.KeyID != "" { + // if the user provides a verbatim KeyID, set it as "kid" header + opts = append(opts, + cass.WithKeyID(as.PrivateKey.KeyID), + ) + } else { + // otherwise, derive it from the cert + cert, err := getCassCert(as.PrivateKey) + if err != nil { + return nil, err + } + keyID, err := hashKeyID(cert, as.PrivateKey.KeyIDHeader) + if err != nil { + return nil, err + } + opts = append(opts, cass.WithHeaders(map[string]string{ + string(as.PrivateKey.KeyIDHeader): keyID, + })) + } + return cass.NewJWTWithRSAKey(clientID, as.Audience, algo, rsaKey, opts...) + + default: // this shouldn't happen, but just in case + return nil, fmt.Errorf("unknown OIDC KeySource %q", as.KeySource) + } +} + +// getCassPrivateKey parses the structs.OIDCClientAssertionKey PemKeyFile +// or PemKey, depending on which is set. +func getCassPrivateKey(k *structs.OIDCClientAssertionKey) (key *rsa.PrivateKey, err error) { + var bts []byte + var source string // for informative error messages + + // pem file on disk + if k.PemKeyFile != "" { + source = "PemKeyFile" + bts, err = os.ReadFile(k.PemKeyFile) + if err != nil { + return nil, fmt.Errorf("error reading %s: %w", source, err) + } + } + // or pem string + if k.PemKey != "" { + source = "PemKey" + bts = []byte(k.PemKey) + } + + key, err = gojwt.ParseRSAPrivateKeyFromPEM(bts) + if err != nil { + return nil, fmt.Errorf("error parsing %s: %w", source, err) + } + if err := key.Validate(); err != nil { + return nil, fmt.Errorf("error validating %s: %w", source, err) + } + return key, nil +} + +// getCassCert parses the structs.OIDCClientAssertionKey PemCertFile +// or PemCert, depending on which is set. +func getCassCert(k *structs.OIDCClientAssertionKey) (*x509.Certificate, error) { + var bts []byte + var err error + var source string // for informative error messages + + // pem file on disk + if k.PemCertFile != "" { + source = "PemCertFile" + bts, err = os.ReadFile(k.PemCertFile) + if err != nil { + return nil, fmt.Errorf("error reading %s: %w", source, err) + } + } + // or pem string + if k.PemCert != "" { + source = "PemCert" + bts = []byte(k.PemCert) + } + + block, _ := pem.Decode(bts) + if block == nil { + return nil, fmt.Errorf("failed to decode %s PEM block", source) + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse %s bytes: %w", source, err) + } + now := time.Now() + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return nil, errors.New("cert expired or not yet valid") + } + return cert, nil +} + +// hashKeyID derives a "certificate thumbprint" that the OIDC provider uses +// to find the certificate to verify the private key JWT signature. +// https://datatracker.ietf.org/doc/html/rfc7515#section-4.1.7 +func hashKeyID(cert *x509.Certificate, header structs.OIDCClientAssertionKeyIDHeader) (string, error) { + var hasher hash.Hash + switch header { + case structs.OIDCClientAssertionHeaderX5t: + hasher = sha1.New() + case structs.OIDCClientAssertionHeaderX5tS256: + hasher = sha256.New() + default: + // this should be validated long before here, at upsert + return "", fmt.Errorf(`%w; must be one of: "x5t", "x5t#S256"`, structs.ErrInvalidKeyIDHeader) + } + hasher.Write(cert.Raw) + hashed := hasher.Sum(nil) + return base64.RawURLEncoding.EncodeToString(hashed), nil +} diff --git a/lib/auth/oidc/request.go b/lib/auth/oidc/request.go new file mode 100644 index 000000000..c9d43bf1a --- /dev/null +++ b/lib/auth/oidc/request.go @@ -0,0 +1,83 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package oidc + +import ( + "context" + "errors" + "sync" + "time" + + //"github.com/coreos/go-oidc/v3/oidc" + "github.com/hashicorp/cap/oidc" +) + +var ErrNonceReuse = errors.New("nonce reuse detected") + +// expiringRequest ensures that OIDC requests that are only partially fulfilled +// do not get stuck in memory forever. +type expiringRequest struct { + // req is what we actually care about + req *oidc.Req + // ctx lets us clean up stale requests automatically + ctx context.Context + cancel context.CancelFunc +} + +// NewRequestCache creates a cache for OIDC requests. +func NewRequestCache() *RequestCache { + return &RequestCache{ + m: sync.Map{}, + // the JWT expiration time in cap library is 5 minutes, + // so auto-delete from our request cache after 6. + timeout: 6 * time.Minute, + } +} + +type RequestCache struct { + m sync.Map + timeout time.Duration +} + +// Store saves the request, to be Loaded later with its Nonce. +// If LoadAndDelete is not called, the stale request will be auto-deleted. +func (rc *RequestCache) Store(ctx context.Context, req *oidc.Req) error { + ctx, cancel := context.WithTimeout(ctx, rc.timeout) + er := &expiringRequest{ + req: req, + ctx: ctx, + cancel: cancel, + } + if _, loaded := rc.m.LoadOrStore(req.Nonce(), er); loaded { + // we already had a request for this nonce, which should never happen, + // so cancel the new request and error to notify caller of a bug. + cancel() + return ErrNonceReuse + } + // auto-delete after timeout or context canceled + go func() { + <-ctx.Done() + rc.m.Delete(req.Nonce()) + }() + return nil +} + +func (rc *RequestCache) Load(nonce string) *oidc.Req { + if er, ok := rc.m.Load(nonce); ok { + return er.(*expiringRequest).req + } + return nil +} + +func (rc *RequestCache) LoadAndDelete(nonce string) *oidc.Req { + if er, loaded := rc.m.LoadAndDelete(nonce); loaded { + // there is a tiny race condition here. if by massive coincidence, + // or a bug, the same nonce makes its way in here, this cancel() + // triggers a map Delete() up in Store(), which could delete a request + // out from under a subsequent Store() + er.(*expiringRequest).cancel() + return er.(*expiringRequest).req + } + return nil +} diff --git a/lib/auth/oidc/request_test.go b/lib/auth/oidc/request_test.go new file mode 100644 index 000000000..81780724c --- /dev/null +++ b/lib/auth/oidc/request_test.go @@ -0,0 +1,96 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package oidc + +import ( + "context" + "testing" + "time" + + "github.com/hashicorp/cap/oidc" + "github.com/shoenig/test/must" + "github.com/shoenig/test/wait" +) + +func TestRequestCache(t *testing.T) { + // using a top-level cache and running each sub-test in parallel exercises + // a little bit of thread safety. + rc := NewRequestCache() + + t.Run("reuse nonce", func(t *testing.T) { + t.Parallel() + req := getRequest(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + must.NoError(t, rc.Store(ctx, req)) + must.ErrorIs(t, rc.Store(ctx, req), ErrNonceReuse) + }) + + t.Run("cancel parent ctx", func(t *testing.T) { + t.Parallel() + req := getRequest(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + must.NoError(t, rc.Store(ctx, req)) + must.Eq(t, req, rc.Load(req.Nonce())) + + cancel() // triggers delete + waitUntilGone(t, rc, req.Nonce()) + }) + + t.Run("load and delete", func(t *testing.T) { + t.Parallel() + req := getRequest(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + must.NoError(t, rc.Store(ctx, req)) + must.Eq(t, req, rc.Load(req.Nonce())) + + must.Eq(t, req, rc.LoadAndDelete(req.Nonce())) // triggers delete + waitUntilGone(t, rc, req.Nonce()) + must.Nil(t, rc.LoadAndDelete(req.Nonce())) + }) + + t.Run("timeout", func(t *testing.T) { + // this test needs its own cache to reduce the timeout + // without affecting any other tests. + rc := NewRequestCache() + rc.timeout = time.Millisecond + + req := getRequest(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + must.NoError(t, rc.Store(ctx, req)) + + // timeout triggers delete behind the scenes + waitUntilGone(t, rc, req.Nonce()) + }) +} + +func getRequest(t *testing.T) *oidc.Req { + t.Helper() + nonce := t.Name() + req, err := oidc.NewRequest(time.Minute, "test-redirect-url", + oidc.WithNonce(nonce)) + must.NoError(t, err) + return req +} + +func waitUntilGone(t *testing.T, rc *RequestCache, nonce string) { + t.Helper() + must.Wait(t, + wait.InitialSuccess( + wait.Timeout(100*time.Millisecond), // should be much faster + wait.Gap(10*time.Millisecond), + wait.BoolFunc(func() bool { + return rc.Load(nonce) == nil + }), + ), + must.Sprint("request should have gone away"), + ) +} diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index a80cb5837..af46a50d6 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -15,10 +15,12 @@ import ( "time" capOIDC "github.com/hashicorp/cap/oidc" + cass "github.com/hashicorp/cap/oidc/clientassertion" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-memdb" metrics "github.com/hashicorp/go-metrics/compat" "github.com/hashicorp/go-set/v3" + "github.com/hashicorp/nomad/helper/pointer" policy "github.com/hashicorp/nomad/acl" "github.com/hashicorp/nomad/helper" @@ -69,6 +71,11 @@ type ACL struct { // hashicorp/cap library. When performing an OIDC login flow, this cache // should be used to obtain a provider from an auth-method. oidcProviderCache *oidc.ProviderCache + + // oidcRequestCache stores a cache of OIDC requests, so request state + // (mainly PKCE challenge/verification) can persist between calls to + // OIDCAuthURL and OIDCCompleteAuth. + oidcRequestCache *oidc.RequestCache } func NewACLEndpoint(srv *Server, ctx *RPCContext) *ACL { @@ -77,6 +84,7 @@ func NewACLEndpoint(srv *Server, ctx *RPCContext) *ACL { ctx: ctx, logger: srv.logger.Named("acl"), oidcProviderCache: srv.oidcProviderCache, + oidcRequestCache: srv.oidcRequestCache, } } @@ -1884,6 +1892,8 @@ func (a *ACL) UpsertAuthMethods( existingMethod, _ := stateSnapshot.GetACLAuthMethodByName(nil, authMethod.Name) authMethod.Merge(existingMethod) + authMethod.Canonicalize() + if err := authMethod.Validate( a.srv.config.ACLTokenMinExpirationTTL, a.srv.config.ACLTokenMaxExpirationTTL); err != nil { @@ -1901,7 +1911,21 @@ func (a *ACL) UpsertAuthMethods( ) } } - authMethod.Canonicalize() + + // if PKCE is not explicitly disabled, enable it. + if authMethod.Config.OIDCDisablePKCE == nil { + authMethod.Config.OIDCDisablePKCE = pointer.Of(false) + } + // if there is a client assertion, ensure it is valid. + if authMethod.Config.OIDCClientAssertion.IsSet() { + _, err := a.oidcClientAssertion(authMethod.Config) + if err != nil { + return structs.NewErrRPCCodedf( + http.StatusBadRequest, "invalid OIDCClientAssertion: %s", err, + ) + } + } + authMethod.SetHash() } @@ -2595,21 +2619,15 @@ func (a *ACL) OIDCAuthURL(args *structs.ACLOIDCAuthURLRequest, reply *structs.AC } // Generate our OIDC request. - oidcReqOpts := []capOIDC.Option{ - capOIDC.WithNonce(args.ClientNonce), - } - - if len(authMethod.Config.OIDCScopes) > 0 { - oidcReqOpts = append(oidcReqOpts, capOIDC.WithScopes(authMethod.Config.OIDCScopes...)) - } - - oidcReq, err := capOIDC.NewRequest( - aclOIDCAuthURLRequestExpiryTime, - args.RedirectURI, - oidcReqOpts..., - ) - if err != nil { - return fmt.Errorf("failed to generate OIDC request: %v", err) + oidcReq := a.oidcRequestCache.Load(args.ClientNonce) + if oidcReq == nil { + oidcReq, err = a.oidcRequest(args.ClientNonce, args.RedirectURI, authMethod.Config) + if err != nil { + return err + } + if err = a.oidcRequestCache.Store(a.srv.shutdownCtx, oidcReq); err != nil { + return fmt.Errorf("error storing OIDC request: %w", err) + } } // Use the cache to provide us with an OIDC provider for the auth method @@ -2708,22 +2726,12 @@ func (a *ACL) OIDCCompleteAuth( return fmt.Errorf("failed to generate OIDC provider: %v", err) } - // Build our OIDC request options and request object. - oidcReqOpts := []capOIDC.Option{ - capOIDC.WithNonce(args.ClientNonce), - capOIDC.WithState(args.State), - } - - if len(authMethod.Config.OIDCScopes) > 0 { - oidcReqOpts = append(oidcReqOpts, capOIDC.WithScopes(authMethod.Config.OIDCScopes...)) - } - if len(authMethod.Config.BoundAudiences) > 0 { - oidcReqOpts = append(oidcReqOpts, capOIDC.WithAudiences(authMethod.Config.BoundAudiences...)) - } - - oidcReq, err := capOIDC.NewRequest(aclOIDCCallbackRequestExpiryTime, args.RedirectURI, oidcReqOpts...) - if err != nil { - return fmt.Errorf("failed to generate OIDC request: %v", err) + // Retrieve the request generated in OIDCAuthURL() + oidcReq := a.oidcRequestCache.LoadAndDelete(args.ClientNonce) // I am so done with this NONCENSE + if oidcReq == nil { + // note: this may happen if there is a leader election between getting + // the auth url and completing the login flow here. + return errors.New("no OIDC request found for client nonce") } // Generate a context with a deadline. This is passed to the OIDC provider @@ -3045,3 +3053,73 @@ func formatTokenName(format, authType, authName string, claims map[string]string return tokenName, nil } + +// oidcRequest builds the request to send to the cap library. +// The way the cap lib is structured, you can build the request once, +// and use it for different request types. +func (a *ACL) oidcRequest(nonce, redirect string, config *structs.ACLAuthMethodConfig) (*capOIDC.Req, error) { + opts := []capOIDC.Option{ + capOIDC.WithNonce(nonce), + } + + if len(config.OIDCScopes) > 0 { + opts = append(opts, capOIDC.WithScopes(config.OIDCScopes...)) + } + if len(config.BoundAudiences) > 0 { + opts = append(opts, capOIDC.WithAudiences(config.BoundAudiences...)) + } + + if config.OIDCDisablePKCE != nil && !*config.OIDCDisablePKCE { + verifier, err := capOIDC.NewCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to make pkce verifier: %w", err) + } + opts = append(opts, capOIDC.WithPKCE(verifier)) + } + + if config.OIDCClientAssertion.IsSet() { + j, err := a.oidcClientAssertion(config) + if err != nil { + return nil, err + } + opts = append(opts, capOIDC.WithClientAssertionJWT(j)) + } + + req, err := capOIDC.NewRequest( + aclOIDCAuthURLRequestExpiryTime, + redirect, + opts..., + ) + if err != nil { + return nil, fmt.Errorf("failed to create OIDC request: %v", err) + } + + return req, nil +} + +func (a *ACL) oidcClientAssertion(config *structs.ACLAuthMethodConfig) (*cass.JWT, error) { + // this nomad key will only actually be used if the client assertion config + // KeySource = "nomad", but we get it here to avoid exposing more of the + // codebase to the encrypter. + nomadKey, nomadKID, err := a.srv.encrypter.GetActiveKey() + if err != nil { + return nil, fmt.Errorf("failed to get active nomad key: %w", err) + } + j, err := oidc.BuildClientAssertionJWT(config, nomadKey, nomadKID) + if err != nil { + return nil, fmt.Errorf("failed to build client_assertion jwt: %w", err) + } + if config.VerboseLogging { + // a user initially setting up the auth method, as one might with + // VerboseLogging enabled, may benefit from not having to do a full + // login flow to see the jwt (and any possible Serialize() error). + // we say "example" in the log, because the cap library will run + // Serialize() again internally, so it won't use this same jwt. + token, err := j.Serialize() + if err != nil { + return nil, fmt.Errorf("failed to serialize client_assertion jwt: %w", err) + } + a.logger.Debug("example client_assertion", "oidc_client_id", config.OIDCClientID, "jwt", token) + } + return j, nil +} diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index 3520b4fd0..ddfbffed4 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -5,6 +5,7 @@ package nomad import ( "bytes" + "context" "fmt" "io" "net/url" @@ -20,7 +21,9 @@ import ( "github.com/hashicorp/go-memdb" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/lib/auth/oidc" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" @@ -3630,6 +3633,84 @@ func TestACL_OIDCAuthURL(t *testing.T) { must.StrContains(t, escapedURL, "&response_type=code") must.StrContains(t, escapedURL, "&scope=openid") must.StrContains(t, escapedURL, "&state=st_") + + t.Run("pkce", func(t *testing.T) { + authMethod := mockedAuthMethod.Copy() + authMethod.Name = mockedAuthMethod.Name + "-pkce" + authMethod.Config.OIDCDisablePKCE = pointer.Of(false) + authMethod.SetHash() + must.NoError(t, testServer.fsm.State().UpsertACLAuthMethods(20, []*structs.ACLAuthMethod{authMethod})) + + urlReq := structs.ACLOIDCAuthURLRequest{ + AuthMethodName: authMethod.Name, + RedirectURI: authMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "pkce_nonce", + WriteRequest: structs.WriteRequest{ + Region: DefaultRegion, + }, + } + var resp structs.ACLOIDCAuthURLResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCAuthURLRPCMethod, &urlReq, &resp) + must.NoError(t, err) + + // The oidc.Req should have been cached with a PKCE verifier + // for use later in OIDCCompleteAuth + cachedReq := testServer.oidcRequestCache.Load(urlReq.ClientNonce) + must.NotNil(t, cachedReq, must.Sprint("oidc Req should be cached")) + must.NotNil(t, cachedReq.PKCEVerifier(), must.Sprint("cached req should have a PKCE verifier")) + }) + + t.Run("client assertion", func(t *testing.T) { + authMethod := mockedAuthMethod.Copy() + authMethod.Config.VerboseLogging = true + authMethod.Name = mockedAuthMethod.Name + "-client-assertion" + // we'll test a representative input variant of + // oidc.BuildClientAssertionJWT to make sure it gets called. + cassConfig := &structs.OIDCClientAssertion{ + // different key sources are tested in oidc helper lib. + KeySource: "nomad", + Audience: []string{"test-audience"}, + } + authMethod.Config.OIDCClientAssertion = cassConfig + must.NoError(t, testServer.fsm.State().UpsertACLAuthMethods(30, []*structs.ACLAuthMethod{authMethod})) + + // Ensure keyring is initialized, so we can assert that we pass the + // right key into getClientAssertionJWT + testutil.WaitForKeyring(t, testServer.RPC, DefaultRegion) + nomadKey, _, err := testServer.encrypter.GetActiveKey() + must.NoError(t, err) + + // Make the RPC call + urlReq := structs.ACLOIDCAuthURLRequest{ + AuthMethodName: authMethod.Name, + RedirectURI: authMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "client_assertion_nonce", + WriteRequest: structs.WriteRequest{Region: DefaultRegion}, + } + var urlResp structs.ACLOIDCAuthURLResponse + must.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLOIDCAuthURLRPCMethod, &urlReq, &urlResp)) + + // The oidc.Req should have been cached with a PKCE verifier + // for use later in OIDCCompleteAuth + oidcReq := testServer.oidcRequestCache.Load(urlReq.ClientNonce) + must.NotNil(t, oidcReq) + cachedJWT := oidcReq.ClientAssertionJWT() + must.NotNil(t, cachedJWT) + + // The cap library will run this method internally. + signed, err := cachedJWT.Serialize() + must.NoError(t, err) + + // This just verifies that it was signed by our public key. + // Extra validation of the headers/claims on the JWT are tested + // in our oidc helper lib. + token, err := jwt.Parse(signed, func(tok *jwt.Token) (any, error) { + return &nomadKey.PublicKey, nil + }) + must.NoError(t, err) + must.NotNil(t, token, must.Sprint("nil parsed token")) + must.True(t, token.Valid, must.Sprint("parsed token invalid")) + }) } func TestACL_OIDCCompleteAuth(t *testing.T) { @@ -3706,19 +3787,16 @@ func TestACL_OIDCCompleteAuth(t *testing.T) { oidcTestProvider.SetExpectedAuthNonce("fsSPuaodKevKfDU3IeXa") oidcTestProvider.SetExpectedAuthCode("codeABC") oidcTestProvider.SetCustomAudience("mock") - oidcTestProvider.SetExpectedState("st_someweirdstateid") oidcTestProvider.SetCustomClaims(map[string]interface{}{ "azp": "mock", "http://nomad.internal/policies": []string{"engineering"}, "http://nomad.internal/roles": []string{"engineering"}, }) - // We should now be able to authenticate, however, we do not have any rule - // bindings that will match. completeAuthReq3 := structs.ACLOIDCCompleteAuthRequest{ AuthMethodName: mockedAuthMethod.Name, ClientNonce: "fsSPuaodKevKfDU3IeXa", - State: "st_", + State: "st_someweirdstateid", Code: "codeABC", RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], WriteRequest: structs.WriteRequest{ @@ -3726,9 +3804,21 @@ func TestACL_OIDCCompleteAuth(t *testing.T) { }, } + // Simulate a case where OIDCAuthURL was never called, or expired, + // or a leadership transfer occurred between it and OIDCCompleteAuth. var completeAuthResp3 structs.ACLLoginResponse err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq3, &completeAuthResp3) must.Error(t, err) + must.ErrorContains(t, err, "no OIDC request found for client nonce") + must.False(t, strings.Contains(buf.String(), verboseLoggingMessage)) + + // Pretend that OIDCAuthURL was called as a separate request. + cacheOIDCRequest(t, testServer.oidcRequestCache, completeAuthReq3) + + // We should now be able to authenticate, however, we do not have any rule + // bindings that will match. + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq3, &completeAuthResp3) + must.Error(t, err) must.ErrorContains(t, err, "400") must.ErrorContains(t, err, "no role or policy bindings matched") must.False(t, strings.Contains(buf.String(), verboseLoggingMessage)) @@ -3775,6 +3865,9 @@ func TestACL_OIDCCompleteAuth(t *testing.T) { }, } + // Pretend that OIDCAuthURL was called as a separate request. + cacheOIDCRequest(t, testServer.oidcRequestCache, completeAuthReq4) + var completeAuthResp4 structs.ACLLoginResponse err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq4, &completeAuthResp4) must.NoError(t, err) @@ -3809,13 +3902,112 @@ func TestACL_OIDCCompleteAuth(t *testing.T) { }, } + // Pretend that OIDCAuthURL was called as a separate request. + cacheOIDCRequest(t, testServer.oidcRequestCache, completeAuthReq5) + var completeAuthResp5 structs.ACLLoginResponse err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &completeAuthReq5, &completeAuthResp5) must.NoError(t, err) - must.NotNil(t, completeAuthResp4.ACLToken) + must.NotNil(t, completeAuthResp5.ACLToken) must.Len(t, 0, completeAuthResp5.ACLToken.Policies) must.Len(t, 0, completeAuthResp5.ACLToken.Roles) must.Eq(t, structs.ACLManagementToken, completeAuthResp5.ACLToken.Type) + + // Now that we have a happy setup, test additional features. + // Note: these mutate mockedAuthMethod and oidcTestProvider. + + // PKCE will apply to all subsequent tests + pkceVerifier, err := capOIDC.NewCodeVerifier() + must.NoError(t, err) + // because this does not allow setting it back to `nil` + oidcTestProvider.SetPKCEVerifier(pkceVerifier) + + t.Run("pkce", func(t *testing.T) { + + mockedAuthMethod.Config.OIDCDisablePKCE = pointer.Of(false) + must.NoError(t, testServer.fsm.State().UpsertACLAuthMethods(60, []*structs.ACLAuthMethod{mockedAuthMethod})) + + req := structs.ACLOIDCCompleteAuthRequest{ + AuthMethodName: mockedAuthMethod.Name, + RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "pkce_nonce", + State: "pkce_state", + Code: "pkce_code", + WriteRequest: structs.WriteRequest{Region: DefaultRegion}, + } + oidcTestProvider.SetExpectedAuthNonce(req.ClientNonce) + oidcTestProvider.SetExpectedState(req.State) + oidcTestProvider.SetExpectedAuthCode(req.Code) + + // Pretend that OIDCAuthURL was called as a separate request. + cacheOIDCRequest(t, testServer.oidcRequestCache, req, + // this is what we are here to test + capOIDC.WithPKCE(pkceVerifier)) + + var resp structs.ACLLoginResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &req, &resp) + must.NoError(t, err) + must.NotNil(t, resp.ACLToken) + must.Eq(t, structs.ACLManagementToken, resp.ACLToken.Type) + }) + + // We've already tested the actual JWT logic in TestACL_OIDCAuthURL, + // so here we can use a bogus serializer for the test provider to check. + mockJWT := &mockSerializer{s: "mock-it-to-me"} + // Client assertions apply to all subsequent tests, + // because this does not allow setting it back to "". + oidcTestProvider.SetClientAssertionJWT(mockJWT.s) + + t.Run("client assertion", func(t *testing.T) { + mockedAuthMethod.Config.OIDCClientAssertion = &structs.OIDCClientAssertion{ + // these fields will not be used, they just need to be valid. + KeySource: "nomad", + Audience: []string{"mock-aud"}, + } + // there's some extra logging if verbose, so toggle it on to make sure + // there's no errors or panics. + mockedAuthMethod.Config.VerboseLogging = true + t.Cleanup(func() { + mockedAuthMethod.Config.VerboseLogging = false + }) + must.NoError(t, testServer.fsm.State().UpsertACLAuthMethods(70, []*structs.ACLAuthMethod{mockedAuthMethod})) + + req := structs.ACLOIDCCompleteAuthRequest{ + AuthMethodName: mockedAuthMethod.Name, + RedirectURI: mockedAuthMethod.Config.AllowedRedirectURIs[0], + ClientNonce: "cass_nonce", + State: "cass_state", + Code: "cass_code", + WriteRequest: structs.WriteRequest{Region: DefaultRegion}, + } + oidcTestProvider.SetExpectedAuthNonce(req.ClientNonce) + oidcTestProvider.SetExpectedState(req.State) + oidcTestProvider.SetExpectedAuthCode(req.Code) + + // Pretend that OIDCAuthURL was called as a separate request. + cacheOIDCRequest(t, testServer.oidcRequestCache, req, + capOIDC.WithPKCE(pkceVerifier), // needed from previous test + // this is what we care about + capOIDC.WithClientAssertionJWT(mockJWT), + ) + + var resp structs.ACLLoginResponse + err = msgpackrpc.CallWithCodec(codec, structs.ACLOIDCCompleteAuthRPCMethod, &req, &resp) + must.NoError(t, err) + must.NotNil(t, resp.ACLToken) + must.Eq(t, structs.ACLManagementToken, resp.ACLToken.Type) + + }) +} + +// mockSerializer implements the capOIDC.JWTSerializer interface, +// which is used to provide a client assertion JWT. +type mockSerializer struct { + s string +} + +func (s *mockSerializer) Serialize() (string, error) { + return s.s, nil } func TestACL_Login(t *testing.T) { @@ -4009,3 +4201,26 @@ func TestACL_Login(t *testing.T) { must.NotNil(t, completeAuthResp6.ACLToken) must.Eq(t, mockedAuthMethod.Type+"-"+mockedAuthMethod.Name+"-"+user, completeAuthResp6.ACLToken.Name) } + +// cacheOIDCRequest primes the oidc.Request cache, as OIDCAuthURL usually would, +// to prepare for a subsequent OIDCCompleteAuth call. +func cacheOIDCRequest(t *testing.T, cache *oidc.RequestCache, req structs.ACLOIDCCompleteAuthRequest, opts ...capOIDC.Option) { + t.Helper() + opts = append(opts, + capOIDC.WithNonce(req.ClientNonce), + capOIDC.WithState(req.State), + capOIDC.WithNow(func() time.Time { + return time.Now().Add(time.Minute) // expire in the future + }), + ) + oidcReq, err := capOIDC.NewRequest( + time.Second, "http://127.0.0.1:4649/oidc/callback", + opts..., + ) + must.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + // make sure the cache is clean first + cache.LoadAndDelete(req.ClientNonce) + must.NoError(t, cache.Store(ctx, oidcReq)) +} diff --git a/nomad/encrypter.go b/nomad/encrypter.go index 0409f5c30..f3cc5e18e 100644 --- a/nomad/encrypter.go +++ b/nomad/encrypter.go @@ -589,6 +589,15 @@ func (e *Encrypter) waitForKey(ctx context.Context, keyID string) (*cipherSet, e return ks, nil } +// GetActiveKey returns the active private key and its kid (key id) +func (e *Encrypter) GetActiveKey() (*rsa.PrivateKey, string, error) { + c, err := e.activeCipherSet() + if err != nil { + return nil, "", err + } + return c.rsaPrivateKey, c.rootKey.Meta.KeyID, nil +} + // GetKey retrieves the key material by ID from the keyring. func (e *Encrypter) GetKey(keyID string) (*structs.UnwrappedRootKey, error) { e.lock.Lock() diff --git a/nomad/mock/acl.go b/nomad/mock/acl.go index 6bec60495..5824d3045 100644 --- a/nomad/mock/acl.go +++ b/nomad/mock/acl.go @@ -16,6 +16,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/hashicorp/nomad/helper/pointer" testing "github.com/mitchellh/go-testing-interface" "github.com/stretchr/testify/assert" @@ -272,9 +273,12 @@ func ACLOIDCAuthMethod() *structs.ACLAuthMethod { MaxTokenTTL: maxTokenTTL, Default: false, Config: &structs.ACLAuthMethodConfig{ - OIDCDiscoveryURL: "http://example.com", - OIDCClientID: "mock", - OIDCClientSecret: "very secret secret", + OIDCDiscoveryURL: "http://example.com", + OIDCClientID: "mock", + OIDCClientSecret: "very secret secret", + // PKCE is hard to test outside the server/RPC layer, + // because the verifier is only accessible there. + OIDCDisablePKCE: pointer.Of(true), OIDCDisableUserInfo: false, OIDCScopes: []string{"groups"}, BoundAudiences: []string{"sales", "engineering"}, diff --git a/nomad/server.go b/nomad/server.go index 12e6cf41d..d415effcd 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -270,6 +270,11 @@ type Server struct { // shutting down, the oidcProviderCache.Shutdown() function must be called. oidcProviderCache *oidc.ProviderCache + // oidcRequestCache stores a cache of OIDC requests, so request state + // (mainly PKCE challenge/verification) can persist between calls to + // OIDCAuthURL and OIDCCompleteAuth. + oidcRequestCache *oidc.RequestCache + // lockTTLTimer and lockDelayTimer are used to track variable lock timers. // These are held in memory on the leader rather than in state to avoid // large amount of Raft writes. @@ -431,6 +436,12 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigFunc // processes when it shuts down itself. s.oidcProviderCache = oidc.NewProviderCache() + // Set up OIDC requests cache for state that persists between calls to + // ACL.OIDCAuthURL and ACL.OIDCCompleteAuth. + // It needs no special handling to handle agent shutdowns (its Store method + // handles this lifecycle). + s.oidcRequestCache = oidc.NewRequestCache() + // Initialize the RPC layer if err := s.setupRPC(tlsWrap); err != nil { s.Shutdown() diff --git a/nomad/structs/acl.go b/nomad/structs/acl.go index 41aea10e7..188f82f89 100644 --- a/nomad/structs/acl.go +++ b/nomad/structs/acl.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "maps" "regexp" "slices" "strconv" @@ -794,6 +795,9 @@ func (a *ACLAuthMethod) SetHash() []byte { _, _ = hash.Write([]byte(a.Config.OIDCDiscoveryURL)) _, _ = hash.Write([]byte(a.Config.OIDCClientID)) _, _ = hash.Write([]byte(a.Config.OIDCClientSecret)) + if a.Config.OIDCDisablePKCE != nil { + _, _ = hash.Write([]byte(strconv.FormatBool(*a.Config.OIDCDisablePKCE))) + } _, _ = hash.Write([]byte(strconv.FormatBool(a.Config.OIDCDisableUserInfo))) _, _ = hash.Write([]byte(strconv.FormatBool(a.Config.VerboseLogging))) _, _ = hash.Write([]byte(a.Config.ExpirationLeeway.String())) @@ -828,6 +832,25 @@ func (a *ACLAuthMethod) SetHash() []byte { _, _ = hash.Write([]byte(k)) _, _ = hash.Write([]byte(v)) } + if a.Config.OIDCClientAssertion != nil { + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.KeySource)) + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.KeyAlgorithm)) + for _, aud := range a.Config.OIDCClientAssertion.Audience { + _, _ = hash.Write([]byte(aud)) + } + for k, v := range a.Config.OIDCClientAssertion.ExtraHeaders { + _, _ = hash.Write([]byte(k)) + _, _ = hash.Write([]byte(v)) + } + if a.Config.OIDCClientAssertion.PrivateKey != nil { + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.PrivateKey.KeyIDHeader)) + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.PrivateKey.KeyID)) + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.PrivateKey.PemKey)) + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.PrivateKey.PemKeyFile)) + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.PrivateKey.PemCert)) + _, _ = hash.Write([]byte(a.Config.OIDCClientAssertion.PrivateKey.PemCertFile)) + } + } } // Finalize the hash. @@ -934,6 +957,8 @@ func (a *ACLAuthMethod) Canonicalize() { if a.TokenNameFormat == "" { a.TokenNameFormat = DefaultACLAuthMethodTokenNameFormat } + + a.Config.Canonicalize() } // Merge merges auth method a with method b. It sets all required empty fields @@ -975,6 +1000,10 @@ func (a *ACLAuthMethod) Validate(minTTL, maxTTL time.Duration) error { a.MaxTokenTTL.String(), minTTL.String(), maxTTL.String())) } + if err := a.Config.Validate(); err != nil { + mErr.Errors = append(mErr.Errors, fmt.Errorf("invalid config: %w", err)) + } + return mErr.ErrorOrNil() } @@ -990,6 +1019,17 @@ func (a *ACLAuthMethod) Sanitize() *ACLAuthMethod { if clean.Config.OIDCClientSecret != "" { clean.Config.OIDCClientSecret = "redacted" } + if clean.Config.OIDCClientAssertion != nil { + // this ClientSecret gets inherited by the above one + if clean.Config.OIDCClientAssertion.ClientSecret != "" { + clean.Config.OIDCClientAssertion.ClientSecret = "redacted" + } + if clean.Config.OIDCClientAssertion.PrivateKey != nil && + clean.Config.OIDCClientAssertion.PrivateKey.PemKey != "" { + clean.Config.OIDCClientAssertion.PrivateKey.PemKey = "redacted" + } + } + return clean } @@ -1017,6 +1057,12 @@ type ACLAuthMethodConfig struct { // The OAuth Client Secret configured with the OIDC provider OIDCClientSecret string + // Optional client assertion ("private key jwt") config + OIDCClientAssertion *OIDCClientAssertion + + // Disable PKCE challenge verification + OIDCDisablePKCE *bool + // Disable claims from the OIDC UserInfo endpoint OIDCDisableUserInfo bool @@ -1065,6 +1111,40 @@ type ACLAuthMethodConfig struct { VerboseLogging bool } +func (a *ACLAuthMethodConfig) Canonicalize() { + if a == nil { + return + } + if a.OIDCClientAssertion != nil { + // client assertions inherit certain values from auth method + if len(a.OIDCClientAssertion.Audience) == 0 { + a.OIDCClientAssertion.Audience = []string{a.OIDCDiscoveryURL} + } + a.OIDCClientAssertion.ClientSecret = a.OIDCClientSecret + a.OIDCClientAssertion.Canonicalize() + } +} + +func (a *ACLAuthMethodConfig) Validate() error { + if a == nil { + return errors.New("missing auth method Config") + } + mErr := &multierror.Error{} + if a.OIDCDiscoveryURL == "" { + mErr = multierror.Append(mErr, errors.New("missing OIDCDiscoveryURL")) + } + if a.OIDCClientID == "" { + mErr = multierror.Append(mErr, errors.New("missing OIDCClientID")) + } + if len(a.BoundAudiences) == 0 || a.BoundAudiences[0] == "" { + mErr = multierror.Append(mErr, errors.New("missing BoundAudiences")) + } + if err := a.OIDCClientAssertion.Validate(); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("invalid client assertion config: %w", err)) + } + return helper.FlattenMultierror(mErr) +} + func (a *ACLAuthMethodConfig) Copy() *ACLAuthMethodConfig { if a == nil { return nil @@ -1080,6 +1160,7 @@ func (a *ACLAuthMethodConfig) Copy() *ACLAuthMethodConfig { c.AllowedRedirectURIs = slices.Clone(a.AllowedRedirectURIs) c.DiscoveryCaPem = slices.Clone(a.DiscoveryCaPem) c.SigningAlgs = slices.Clone(a.SigningAlgs) + c.OIDCClientAssertion = a.OIDCClientAssertion.Copy() return c } @@ -1171,6 +1252,187 @@ func (a *ACLAuthMethodConfig) UnmarshalJSON(data []byte) (err error) { return nil } +type OIDCClientAssertionKeySource string + +const ( + OIDCKeySourceNomad OIDCClientAssertionKeySource = "nomad" + OIDCKeySourceClientSecret OIDCClientAssertionKeySource = "client_secret" + OIDCKeySourcePrivateKey OIDCClientAssertionKeySource = "private_key" +) + +// OIDCClientAssertion (a.k.a private_key_jwt) is used to send +// a client_assertion along with an OIDC token request. +// See api.OIDCClientAssertion for full field descriptions. +type OIDCClientAssertion struct { + KeySource OIDCClientAssertionKeySource + Audience []string + PrivateKey *OIDCClientAssertionKey + ExtraHeaders map[string]string + KeyAlgorithm string + // ClientSecret here is not part of the public api; it's inherited from the + // parent ACLAuthMethodConfig struct via ACLAuthMethodConfig.Canonicalize. + // It's exported mainly so that it gets saved across msgpack in raft state. + ClientSecret string +} + +func (c *OIDCClientAssertion) Copy() *OIDCClientAssertion { + if c == nil { + return nil + } + n := new(OIDCClientAssertion) + *n = *c + n.Audience = slices.Clone(c.Audience) + n.PrivateKey = c.PrivateKey.Copy() + n.ExtraHeaders = maps.Clone(c.ExtraHeaders) + return n +} + +func (c *OIDCClientAssertion) Canonicalize() { + if c == nil { + return + } + // default KeyAlgorithm to "RS256" for nomad and user keys, "HS256" for client_secret + if c.KeyAlgorithm == "" { + switch c.KeySource { + case OIDCKeySourceClientSecret: + c.KeyAlgorithm = "HS256" + case OIDCKeySourceNomad, OIDCKeySourcePrivateKey: + c.KeyAlgorithm = "RS256" + } + } + c.PrivateKey.Canonicalize() +} + +func (c *OIDCClientAssertion) IsSet() bool { + return c != nil && c.KeySource != "" +} + +func (c *OIDCClientAssertion) Validate() error { + if c == nil { + return nil + } + if len(c.Audience) == 0 || c.Audience[0] == "" { + return errors.New("missing Audience") + } + switch c.KeySource { + case OIDCKeySourceNomad: + case OIDCKeySourcePrivateKey: + if c.PrivateKey == nil { + return errors.New("PrivateKey is required for `private_key` KeySource") + } + if err := c.PrivateKey.Validate(); err != nil { + return fmt.Errorf("invalid PrivateKey: %w", err) + } + case OIDCKeySourceClientSecret: + if c.ClientSecret == "" { + return errors.New("OIDCClientSecret is required for `client_secret` KeySource") + } + default: + return fmt.Errorf("invalid KeySource %q", c.KeySource) + } + return nil +} + +type OIDCClientAssertionKeyIDHeader string + +const ( + OIDCClientAssertionHeaderKid OIDCClientAssertionKeyIDHeader = "kid" + OIDCClientAssertionHeaderX5t OIDCClientAssertionKeyIDHeader = "x5t" + OIDCClientAssertionHeaderX5tS256 OIDCClientAssertionKeyIDHeader = "x5t#S256" +) + +// OIDCClientAssertionKey contains key material provided by users for Nomad +// to use to sign the private key JWT. +// See api.OIDCClientAssertionKey for full field descriptions. +type OIDCClientAssertionKey struct { + PemKey string + PemKeyFile string + + KeyIDHeader OIDCClientAssertionKeyIDHeader + PemCert string + PemCertFile string + KeyID string +} + +func (k *OIDCClientAssertionKey) Copy() *OIDCClientAssertionKey { + if k == nil { + return nil + } + n := new(OIDCClientAssertionKey) + *n = *k + return n +} + +func (k *OIDCClientAssertionKey) Canonicalize() { + if k == nil { + return + } + if k.KeyIDHeader == "" { + if k.KeyID != "" { + k.KeyIDHeader = OIDCClientAssertionHeaderKid + } + if k.PemCert != "" || k.PemCertFile != "" { + k.KeyIDHeader = OIDCClientAssertionHeaderX5tS256 + } + } +} + +var ( + ErrMissingClientAssertionKey = errors.New("missing PemKey or PemKeyFile") + ErrAmbiguousClientAssertionKey = errors.New("require only one of PemKey or PemKeyFile") + ErrMissingClientAssertionKeyID = errors.New("missing PemCert, PemCertFile, or KeyID") + ErrAmbiguousClientAssertionKeyID = errors.New("require only one of PemCert, PemCertFile, or KeyID") + ErrInvalidKeyIDHeader = errors.New("invalid KeyIDHeader") +) + +// Validate ensures that one Key and one Cert or KeyID are provided, +// and that the key ID header is valid for the provided KeyID or cert. +func (k *OIDCClientAssertionKey) Validate() error { + if k == nil { + return nil + } + + // mutually exclusive key fields + // must have key file or base64, but not both + if k.PemKey == "" && k.PemKeyFile == "" { + return ErrMissingClientAssertionKey + } + if k.PemKey != "" && k.PemKeyFile != "" { + return ErrAmbiguousClientAssertionKey + } + + // mutually exclusive cert fields + // must have exactly one of: cert file or base64, or keyid + if k.PemCert == "" && k.PemCertFile == "" && k.KeyID == "" { + return ErrMissingClientAssertionKeyID + } + if k.PemCert != "" && (k.PemCertFile != "" || k.KeyID != "") { + return ErrAmbiguousClientAssertionKeyID + } + if k.PemCertFile != "" && (k.PemCert != "" || k.KeyID != "") { + return ErrAmbiguousClientAssertionKeyID + } + if k.KeyID != "" && (k.PemCert != "" || k.PemCertFile != "") { + return ErrAmbiguousClientAssertionKeyID + } + + // only allow certain key id headers + // only "kid" for KeyID + if k.KeyID != "" && k.KeyIDHeader != OIDCClientAssertionHeaderKid { + return fmt.Errorf("%w; key header for key ID must be %q", + ErrInvalidKeyIDHeader, OIDCClientAssertionHeaderKid) + } + // only "x5t*" for certs + if k.PemCert != "" || k.PemCertFile != "" { + if k.KeyIDHeader != OIDCClientAssertionHeaderX5t && k.KeyIDHeader != OIDCClientAssertionHeaderX5tS256 { + return fmt.Errorf("%w; certificate-derived key header must be one of: %q, %q", + ErrInvalidKeyIDHeader, OIDCClientAssertionHeaderX5tS256, OIDCClientAssertionHeaderX5t) + } + } + + return nil +} + // ACLAuthClaims is the claim mapping of the OIDC auth method in a format that // can be used with go-bexpr. This structure is used during rule binding // evaluation. diff --git a/nomad/structs/acl_test.go b/nomad/structs/acl_test.go index 90bc6b804..ac67e54d3 100644 --- a/nomad/structs/acl_test.go +++ b/nomad/structs/acl_test.go @@ -1139,6 +1139,7 @@ func TestACLAuthMethod_Equal(t *testing.T) { OIDCDiscoveryURL: "http://example.com", OIDCClientID: "mock", OIDCClientSecret: "very secret secret", + OIDCClientAssertion: validClientAssertion(), OIDCDisableUserInfo: false, BoundAudiences: []string{"audience1", "audience2"}, AllowedRedirectURIs: []string{"foo", "bar"}, @@ -1192,6 +1193,7 @@ func TestACLAuthMethod_Copy(t *testing.T) { OIDCDiscoveryURL: "http://example.com", OIDCClientID: "mock", OIDCClientSecret: "very secret secret", + OIDCClientAssertion: validClientAssertion(), OIDCDisableUserInfo: false, BoundAudiences: []string{"audience1", "audience2"}, AllowedRedirectURIs: []string{"foo", "bar"}, @@ -1219,6 +1221,8 @@ func TestACLAuthMethod_Copy(t *testing.T) { func TestACLAuthMethod_Validate(t *testing.T) { ci.Parallel(t) + minTTL, _ := time.ParseDuration("10s") + maxTTL, _ := time.ParseDuration("10h") goodTTL, _ := time.ParseDuration("3600s") badTTL, _ := time.ParseDuration("3600h") @@ -1235,6 +1239,11 @@ func TestACLAuthMethod_Validate(t *testing.T) { Type: "OIDC", TokenLocality: "local", MaxTokenTTL: goodTTL, + Config: &ACLAuthMethodConfig{ + OIDCDiscoveryURL: "mock-discovery-url", + OIDCClientID: "mock-client-id", + BoundAudiences: []string{"mock-aud"}, + }, }, false, "", @@ -1246,8 +1255,6 @@ func TestACLAuthMethod_Validate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - minTTL, _ := time.ParseDuration("10s") - maxTTL, _ := time.ParseDuration("10h") got := tt.method.Validate(minTTL, maxTTL) if tt.wantErr { must.Error(t, got, must.Sprintf( @@ -1261,6 +1268,41 @@ func TestACLAuthMethod_Validate(t *testing.T) { } }) } + + // test that Validate calls relevant fields' Validate, and so on + t.Run("validate cascade", func(t *testing.T) { + goodMethod := &ACLAuthMethod{ + // min need to pass top-level validation + Name: "test-name", + Type: "OIDC", + TokenLocality: "local", + MaxTokenTTL: goodTTL, + Config: &ACLAuthMethodConfig{ + OIDCDiscoveryURL: "mock-discovery-url", + OIDCClientID: "mock-client-id", + BoundAudiences: []string{"mock-aud"}, + }, + } + goodMethod.Canonicalize() + err := goodMethod.Validate(minTTL, maxTTL) + must.NoError(t, err) + + deeplyBadClientAssertion := goodMethod.Copy() + deeplyBadClientAssertion.Config = &ACLAuthMethodConfig{ + OIDCClientAssertion: &OIDCClientAssertion{ + // need these to pass + Audience: []string{"test-audience"}, + KeySource: OIDCKeySourcePrivateKey, + // then fail validation nested way down in here + PrivateKey: &OIDCClientAssertionKey{ + // cannot set both of these + PemKey: "test-b64", + PemKeyFile: "test-file", + }}, + } + err = deeplyBadClientAssertion.Validate(minTTL, maxTTL) + must.ErrorIs(t, err, ErrAmbiguousClientAssertionKey) + }) } // Sanitize method should redact sensitive values @@ -1279,6 +1321,30 @@ func TestACLAuthMethod_Sanitize(t *testing.T) { must.Eq(t, "very private secret", dirty) must.Eq(t, "redacted", clean) }) + + t.Run("client assertion", func(t *testing.T) { + am := am.Copy() + am.Config.OIDCClientAssertion = &OIDCClientAssertion{} + am.Sanitize() // no nil panic + // client secret gets inherited + am.Config.OIDCClientSecret = "very private secret" + am.Canonicalize() + dirty := am.Config.OIDCClientAssertion.ClientSecret + clean := am.Sanitize().Config.OIDCClientAssertion.ClientSecret + must.Eq(t, "very private secret", dirty) + must.Eq(t, "redacted", clean) + // private key material + am.Config.OIDCClientAssertion.PrivateKey = &OIDCClientAssertionKey{ + PemKey: "very private key", + } + // dirty should remain dirty, because it only cleans a copy. + dirty = am.Config.OIDCClientAssertion.PrivateKey.PemKey + am.Sanitize() + clean = am.Sanitize().Config.OIDCClientAssertion.PrivateKey.PemKey + must.Eq(t, "very private key", dirty) + must.Eq(t, "redacted", clean) + }) + } func TestACLAuthMethod_Merge(t *testing.T) { @@ -1303,6 +1369,7 @@ func TestACLAuthMethod_Merge(t *testing.T) { OIDCDiscoveryURL: "http://example.com", OIDCClientID: "mock", OIDCClientSecret: "very secret secret", + OIDCClientAssertion: validClientAssertion(), OIDCDisableUserInfo: false, BoundAudiences: []string{"audience1", "audience2"}, AllowedRedirectURIs: []string{"foo", "bar"}, @@ -1322,6 +1389,7 @@ func TestACLAuthMethod_Merge(t *testing.T) { minTTL, _ := time.ParseDuration("10s") maxTTL, _ := time.ParseDuration("10h") must.NoError(t, am1.Validate(minTTL, maxTTL)) + must.Eq(t, am1.Config.OIDCClientAssertion.PrivateKey.KeyID, "test-key-id") } func TestACLAuthMethodConfig_Copy(t *testing.T) { @@ -1331,6 +1399,7 @@ func TestACLAuthMethodConfig_Copy(t *testing.T) { OIDCDiscoveryURL: "http://example.com", OIDCClientID: "mock", OIDCClientSecret: "very secret secret", + OIDCClientAssertion: validClientAssertion(), OIDCDisableUserInfo: false, OIDCScopes: []string{"groups"}, BoundAudiences: []string{"audience1", "audience2"}, @@ -1346,6 +1415,7 @@ func TestACLAuthMethodConfig_Copy(t *testing.T) { amc3 := amc1.Copy() amc3.AllowedRedirectURIs = []string{"new", "urls"} + amc3.OIDCClientAssertion.PrivateKey.KeyID = "new-key-id" must.NotEq(t, amc1, amc3) } @@ -1403,6 +1473,233 @@ func TestACLAuthMethod_TokenLocalityIsGlobal(t *testing.T) { must.False(t, localAuthMethod.TokenLocalityIsGlobal()) } +func TestOIDCClientAssertion_Copy(t *testing.T) { + ci.Parallel(t) + ca1 := &OIDCClientAssertion{ + KeySource: "keyy", // plain value + Audience: []string{"aud"}, // slice + ExtraHeaders: map[string]string{"foo": "bar"}, // map + PrivateKey: &OIDCClientAssertionKey{KeyID: "kid"}, // struct + } + ca2 := ca1.Copy() + must.Eq(t, ca1, ca2) + must.Eq(t, ca2.KeySource, "keyy") + must.Eq(t, ca2.Audience, []string{"aud"}) + must.Eq(t, ca2.PrivateKey.KeyID, "kid") + must.Eq(t, ca2.ExtraHeaders, map[string]string{"foo": "bar"}) + ca2.KeySource = "another" + must.NotEq(t, ca1, ca2) +} + +func TestOIDCClientAssertion_Canonicalize(t *testing.T) { + ci.Parallel(t) + cases := []struct { + keySource OIDCClientAssertionKeySource + expectAlgo string // varies based on he key source + }{ + {OIDCKeySourcePrivateKey, "RS256"}, + {OIDCKeySourceNomad, "RS256"}, + {OIDCKeySourceClientSecret, "HS256"}, + } + for _, tc := range cases { + t.Run(string(tc.keySource), func(t *testing.T) { + ca := &OIDCClientAssertion{ + KeyAlgorithm: "", // explicitly empty + KeySource: tc.keySource, + } + ca.Canonicalize() + must.Eq(t, tc.expectAlgo, ca.KeyAlgorithm) + }) + } +} + +func TestOIDCClientAssertion_Validate(t *testing.T) { + ci.Parallel(t) + ca := validClientAssertion() + must.NoError(t, ca.Validate()) + + cases := []struct { + err string + mod func(*OIDCClientAssertion) + }{ + { + err: "missing Audience", + mod: func(ca *OIDCClientAssertion) { + ca.Audience = nil + }, + }, + { + err: "PrivateKey is required", + mod: func(ca *OIDCClientAssertion) { + ca.KeySource = OIDCKeySourcePrivateKey + ca.PrivateKey = nil + }, + }, + { + err: "invalid PrivateKey", + mod: func(ca *OIDCClientAssertion) { + ca.KeySource = OIDCKeySourcePrivateKey + ca.PrivateKey.KeyID = "" + }, + }, + { + err: "OIDCClientSecret is required", + mod: func(ca *OIDCClientAssertion) { + ca.KeySource = OIDCKeySourceClientSecret + ca.ClientSecret = "" + }, + }, + { + err: "invalid KeySource", + mod: func(ca *OIDCClientAssertion) { + ca.KeySource = "bogus" + }, + }, + } + + for _, tc := range cases { + t.Run(tc.err, func(t *testing.T) { + ca := ca.Copy() + if tc.mod != nil { + tc.mod(ca) + } + err := ca.Validate() + must.ErrorContains(t, err, tc.err) + }) + } +} + +func TestOIDCClientAssertionKey_Copy(t *testing.T) { + ci.Parallel(t) + k1 := &OIDCClientAssertionKey{ + KeyID: "kid", + } + k2 := k1.Copy() + must.Eq(t, k1, k2) + k2.KeyID = "another" + must.NotEq(t, k1, k2) +} + +func TestOIDCClientAssertionKey_Validate(t *testing.T) { + ci.Parallel(t) + cases := []struct { + name string + key *OIDCClientAssertionKey + err error + }{ + { + name: "ok files", + // Validate only checks that they are set to something. + // their existence and contents are validated later. + key: &OIDCClientAssertionKey{ + PemKeyFile: "/any.key", + PemCertFile: "/any.crt", + }, + }, + { + name: "ok pems", + key: &OIDCClientAssertionKey{ + PemKey: "anykey", + PemCert: "anycert", + }, + }, + { + name: "ok keyid", + key: &OIDCClientAssertionKey{ + PemKeyFile: "/any.key", + KeyID: "key-id", + }, + }, + { + name: "missing key", + key: &OIDCClientAssertionKey{ + KeyID: "key-id", + }, + err: ErrMissingClientAssertionKey, + }, + { + name: "missing kid or cert", + key: &OIDCClientAssertionKey{ + PemKeyFile: "/any.key", + }, + err: ErrMissingClientAssertionKeyID, + }, + { + name: "ambiguous key", + key: &OIDCClientAssertionKey{ + PemKeyFile: "/any.key", + PemKey: "anykey", + }, + err: ErrAmbiguousClientAssertionKey, + }, + { + name: "ambiguous keyid - cert file and pem", + key: &OIDCClientAssertionKey{ + PemKey: "anykey", // checked before cert + PemCertFile: "/any.cert", + PemCert: "anycert", + }, + err: ErrAmbiguousClientAssertionKeyID, + }, + { + name: "ambiguous keyid - cert file and keyid", + key: &OIDCClientAssertionKey{ + PemKey: "anykey", // checked before cert + PemCertFile: "/any.cert", + KeyID: "key-id", + }, + err: ErrAmbiguousClientAssertionKeyID, + }, + { + name: "ambiguous keyid - cert pem and keyid", + key: &OIDCClientAssertionKey{ + PemKey: "anykey", // checked before cert + PemCert: "anycert", + KeyID: "key-id", + }, + err: ErrAmbiguousClientAssertionKeyID, + }, + { + name: "bad key header for keyid", + key: &OIDCClientAssertionKey{ + PemKeyFile: "/any.key", + KeyID: "key-id", + KeyIDHeader: OIDCClientAssertionHeaderX5t, + }, + err: ErrInvalidKeyIDHeader, + }, + { + name: "bad key header for cert", + key: &OIDCClientAssertionKey{ + PemKeyFile: "/any.key", + PemCert: "anycert", + KeyIDHeader: OIDCClientAssertionHeaderKid, + }, + err: ErrInvalidKeyIDHeader, + }, + { + name: "bad key header for cert file", + key: &OIDCClientAssertionKey{ + PemKeyFile: "/any.key", + PemCertFile: "/any.crt", + KeyIDHeader: OIDCClientAssertionHeaderKid, + }, + err: ErrInvalidKeyIDHeader, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.key.Canonicalize() + err := tc.key.Validate() + if tc.err == nil { + must.NoError(t, err) + } else { + must.ErrorIs(t, err, tc.err) + } + }) + } +} + func TestACLBindingRule_Canonicalize(t *testing.T) { ci.Parallel(t) @@ -1786,3 +2083,19 @@ func TestACLOIDCCompleteAuthRequest_Validate(t *testing.T) { must.StrContains(t, err.Error(), "missing code") must.StrContains(t, err.Error(), "missing redirect URI") } + +func validClientAssertion() *OIDCClientAssertion { + return &OIDCClientAssertion{ + KeySource: OIDCKeySourcePrivateKey, + Audience: []string{"test-audience"}, + PrivateKey: &OIDCClientAssertionKey{ + PemKeyFile: "test-key-file", + KeyIDHeader: OIDCClientAssertionHeaderKid, + KeyID: "test-key-id", + }, + ExtraHeaders: map[string]string{"test-header": "test-value"}, + KeyAlgorithm: "test-key-algo", + // clientSecret is ordinarily inherited from parent ACLAuthMethodConfig + ClientSecret: "test-client-secret", + } +}