diff --git a/go.mod b/go.mod index bc10bb74b..aed39d2c1 100644 --- a/go.mod +++ b/go.mod @@ -75,7 +75,7 @@ require ( github.com/hashicorp/go-syslog v1.0.0 github.com/hashicorp/go-uuid v1.0.3 github.com/hashicorp/go-version v1.7.0 - github.com/hashicorp/golang-lru/v2 v2.0.1 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/hcl v1.0.1-vault-3 github.com/hashicorp/hcl/v2 v2.20.2-0.20240517235513-55d9c02d147d github.com/hashicorp/hil v0.0.0-20210521165536-27a72121fd40 diff --git a/go.sum b/go.sum index 5261310e6..782a7f1eb 100644 --- a/go.sum +++ b/go.sum @@ -1269,8 +1269,8 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= -github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4= -github.com/hashicorp/golang-lru/v2 v2.0.1/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee h1:8B4HqvMUtYSjsGkYjiQGStc9pXffY2J+Z2SPQAj+wMY= github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee/go.mod h1:gwlu9+/P9MmKtYrMsHeFRZPXj2CTPm11TDnMeaRHS7g= github.com/hashicorp/hcl/v2 v2.20.2-0.20240517235513-55d9c02d147d h1:7abftkc86B+tlA/0cDy5f6C4LgWfFOCpsGg3RJZsfbw= diff --git a/lib/auth/oidc/request.go b/lib/auth/oidc/request.go index c9d43bf1a..9bb43b619 100644 --- a/lib/auth/oidc/request.go +++ b/lib/auth/oidc/request.go @@ -4,80 +4,64 @@ package oidc import ( - "context" "errors" - "sync" "time" - //"github.com/coreos/go-oidc/v3/oidc" "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/golang-lru/v2/expirable" ) -var ErrNonceReuse = errors.New("nonce reuse detected") +var ( + ErrNonceReuse = errors.New("nonce reuse detected") + ErrTooManyRequests = errors.New("too many auth requests") +) -// expiringRequest ensures that OIDC requests that are only partially fulfilled -// do not get stuck in memory forever. -type expiringRequest struct { - // req is what we actually care about - req *oidc.Req - // ctx lets us clean up stale requests automatically - ctx context.Context - cancel context.CancelFunc -} +// MaxRequests is how many requests are allowed to be stored at a time. +// It needs to be large enough for legitimate user traffic, but small enough +// to prevent a DOS from eating up server memory. +const MaxRequests = 10000 // NewRequestCache creates a cache for OIDC requests. -func NewRequestCache() *RequestCache { +// The JWT expiration time in the cap library is 5 minutes, +// so timeout should be around that long. +func NewRequestCache(timeout time.Duration) *RequestCache { return &RequestCache{ - m: sync.Map{}, - // the JWT expiration time in cap library is 5 minutes, - // so auto-delete from our request cache after 6. - timeout: 6 * time.Minute, + c: expirable.NewLRU[string, *oidc.Req](MaxRequests, nil, timeout), } } type RequestCache struct { - m sync.Map - timeout time.Duration + c *expirable.LRU[string, *oidc.Req] } // Store saves the request, to be Loaded later with its Nonce. // If LoadAndDelete is not called, the stale request will be auto-deleted. -func (rc *RequestCache) Store(ctx context.Context, req *oidc.Req) error { - ctx, cancel := context.WithTimeout(ctx, rc.timeout) - er := &expiringRequest{ - req: req, - ctx: ctx, - cancel: cancel, +func (rc *RequestCache) Store(req *oidc.Req) error { + if rc.c.Len() >= MaxRequests { + return ErrTooManyRequests } - if _, loaded := rc.m.LoadOrStore(req.Nonce(), er); loaded { - // we already had a request for this nonce, which should never happen, - // so cancel the new request and error to notify caller of a bug. - cancel() + + if _, ok := rc.c.Get(req.Nonce()); ok { + // we already had a request for this nonce (should never happen) return ErrNonceReuse } - // auto-delete after timeout or context canceled - go func() { - <-ctx.Done() - rc.m.Delete(req.Nonce()) - }() + + rc.c.Add(req.Nonce(), req) + return nil } func (rc *RequestCache) Load(nonce string) *oidc.Req { - if er, ok := rc.m.Load(nonce); ok { - return er.(*expiringRequest).req + if req, ok := rc.c.Get(nonce); ok { + return req } return nil } func (rc *RequestCache) LoadAndDelete(nonce string) *oidc.Req { - if er, loaded := rc.m.LoadAndDelete(nonce); loaded { - // there is a tiny race condition here. if by massive coincidence, - // or a bug, the same nonce makes its way in here, this cancel() - // triggers a map Delete() up in Store(), which could delete a request - // out from under a subsequent Store() - er.(*expiringRequest).cancel() - return er.(*expiringRequest).req + if req, ok := rc.c.Get(nonce); ok { + rc.c.Remove(nonce) + return req } return nil } diff --git a/lib/auth/oidc/request_test.go b/lib/auth/oidc/request_test.go index 81780724c..683104354 100644 --- a/lib/auth/oidc/request_test.go +++ b/lib/auth/oidc/request_test.go @@ -4,7 +4,7 @@ package oidc import ( - "context" + "fmt" "testing" "time" @@ -16,41 +16,25 @@ import ( func TestRequestCache(t *testing.T) { // using a top-level cache and running each sub-test in parallel exercises // a little bit of thread safety. - rc := NewRequestCache() + rc := NewRequestCache(time.Minute) t.Run("reuse nonce", func(t *testing.T) { t.Parallel() req := getRequest(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - must.NoError(t, rc.Store(ctx, req)) - must.ErrorIs(t, rc.Store(ctx, req), ErrNonceReuse) - }) - - t.Run("cancel parent ctx", func(t *testing.T) { - t.Parallel() - req := getRequest(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - must.NoError(t, rc.Store(ctx, req)) - must.Eq(t, req, rc.Load(req.Nonce())) - - cancel() // triggers delete - waitUntilGone(t, rc, req.Nonce()) + must.NoError(t, rc.Store(req)) + must.ErrorIs(t, rc.Store(req), ErrNonceReuse) }) t.Run("load and delete", func(t *testing.T) { t.Parallel() req := getRequest(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - must.NoError(t, rc.Store(ctx, req)) + must.NoError(t, rc.Store(req)) must.Eq(t, req, rc.Load(req.Nonce())) - must.Eq(t, req, rc.LoadAndDelete(req.Nonce())) // triggers delete + must.Eq(t, req, rc.LoadAndDelete(req.Nonce())) + waitUntilGone(t, rc, req.Nonce()) must.Nil(t, rc.LoadAndDelete(req.Nonce())) }) @@ -58,18 +42,34 @@ func TestRequestCache(t *testing.T) { t.Run("timeout", func(t *testing.T) { // this test needs its own cache to reduce the timeout // without affecting any other tests. - rc := NewRequestCache() - rc.timeout = time.Millisecond + rc := NewRequestCache(50 * time.Millisecond) req := getRequest(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - must.NoError(t, rc.Store(ctx, req)) + must.NoError(t, rc.Store(req)) // timeout triggers delete behind the scenes waitUntilGone(t, rc, req.Nonce()) }) + + t.Run("too many requests", func(t *testing.T) { + // not Parallel, would make other tests flaky + defer rc.c.Purge() + + var gotErr error + for i := range MaxRequests + 5 { + req, err := oidc.NewRequest(time.Minute, "test-redirect-url", + oidc.WithNonce(fmt.Sprintf("too-many-cooks-%d", i))) + must.NoError(t, err) + + if err := rc.Store(req); err != nil { + gotErr = err + break + } + } + must.ErrorIs(t, gotErr, ErrTooManyRequests) + }) + } func getRequest(t *testing.T) *oidc.Req { diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index af46a50d6..955b9680a 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -2625,7 +2625,7 @@ func (a *ACL) OIDCAuthURL(args *structs.ACLOIDCAuthURLRequest, reply *structs.AC if err != nil { return err } - if err = a.oidcRequestCache.Store(a.srv.shutdownCtx, oidcReq); err != nil { + if err = a.oidcRequestCache.Store(oidcReq); err != nil { return fmt.Errorf("error storing OIDC request: %w", err) } } diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index ddfbffed4..a7d7e466d 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -5,7 +5,6 @@ package nomad import ( "bytes" - "context" "fmt" "io" "net/url" @@ -4218,9 +4217,7 @@ func cacheOIDCRequest(t *testing.T, cache *oidc.RequestCache, req structs.ACLOID opts..., ) must.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - t.Cleanup(cancel) // make sure the cache is clean first cache.LoadAndDelete(req.ClientNonce) - must.NoError(t, cache.Store(ctx, oidcReq)) + must.NoError(t, cache.Store(oidcReq)) } diff --git a/nomad/server.go b/nomad/server.go index d415effcd..6383938de 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -440,7 +440,8 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigFunc // ACL.OIDCAuthURL and ACL.OIDCCompleteAuth. // It needs no special handling to handle agent shutdowns (its Store method // handles this lifecycle). - s.oidcRequestCache = oidc.NewRequestCache() + // 6 minutes is 1 minute longer than the JWT expiration time in the cap lib. + s.oidcRequestCache = oidc.NewRequestCache(6 * time.Minute) // Initialize the RPC layer if err := s.setupRPC(tlsWrap); err != nil {