auth: oidc request lru cache (#25336)

Use the expirable LRU from hashicorp/golang-lru/v2 instead of the previous hand-rolled cache.
Author:   Daniel Bennett
Date:     2025-03-11 09:46:23 -04:00 (committed by GitHub)
Parent:   61bbff9c24
Commit:   38f063a341
Changed:  7 files, 63 additions and 81 deletions
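
For readers who have not used the library: the expirable LRU in hashicorp/golang-lru/v2 bundles the size cap and TTL eviction that the old cache built by hand from a sync.Map plus a timeout goroutine per entry. Below is a minimal, self-contained sketch of that API as the new cache uses it; the key/value types and the numbers are illustrative, not taken from the commit.

```go
package main

import (
	"fmt"
	"time"

	"github.com/hashicorp/golang-lru/v2/expirable"
)

func main() {
	// At most 100 entries, no eviction callback, entries expire after 5 minutes.
	cache := expirable.NewLRU[string, string](100, nil, 5*time.Minute)

	cache.Add("nonce-123", "pending auth request")

	if v, ok := cache.Get("nonce-123"); ok {
		fmt.Println("found:", v) // still present until the TTL elapses
	}

	cache.Remove("nonce-123") // explicit delete once the request completes
}
```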

go.mod (2 changed lines)

@@ -75,7 +75,7 @@ require (
 	github.com/hashicorp/go-syslog v1.0.0
 	github.com/hashicorp/go-uuid v1.0.3
 	github.com/hashicorp/go-version v1.7.0
-	github.com/hashicorp/golang-lru/v2 v2.0.1
+	github.com/hashicorp/golang-lru/v2 v2.0.7
 	github.com/hashicorp/hcl v1.0.1-vault-3
 	github.com/hashicorp/hcl/v2 v2.20.2-0.20240517235513-55d9c02d147d
 	github.com/hashicorp/hil v0.0.0-20210521165536-27a72121fd40

go.sum (4 changed lines)

@@ -1269,8 +1269,8 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
 github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
 github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c=
 github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
-github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4=
-github.com/hashicorp/golang-lru/v2 v2.0.1/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
+github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
+github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
 github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee h1:8B4HqvMUtYSjsGkYjiQGStc9pXffY2J+Z2SPQAj+wMY=
 github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee/go.mod h1:gwlu9+/P9MmKtYrMsHeFRZPXj2CTPm11TDnMeaRHS7g=
 github.com/hashicorp/hcl/v2 v2.20.2-0.20240517235513-55d9c02d147d h1:7abftkc86B+tlA/0cDy5f6C4LgWfFOCpsGg3RJZsfbw=

OIDC request cache implementation (Go source file)

@@ -4,80 +4,64 @@
 package oidc
 import (
-	"context"
 	"errors"
-	"sync"
 	"time"
 	//"github.com/coreos/go-oidc/v3/oidc"
 	"github.com/hashicorp/cap/oidc"
+	"github.com/hashicorp/golang-lru/v2/expirable"
 )
-var ErrNonceReuse = errors.New("nonce reuse detected")
+var (
+	ErrNonceReuse      = errors.New("nonce reuse detected")
+	ErrTooManyRequests = errors.New("too many auth requests")
+)
-// expiringRequest ensures that OIDC requests that are only partially fulfilled
-// do not get stuck in memory forever.
-type expiringRequest struct {
-	// req is what we actually care about
-	req *oidc.Req
-	// ctx lets us clean up stale requests automatically
-	ctx    context.Context
-	cancel context.CancelFunc
-}
+// MaxRequests is how many requests are allowed to be stored at a time.
+// It needs to be large enough for legitimate user traffic, but small enough
+// to prevent a DOS from eating up server memory.
+const MaxRequests = 10000
 // NewRequestCache creates a cache for OIDC requests.
-func NewRequestCache() *RequestCache {
+// The JWT expiration time in the cap library is 5 minutes,
+// so timeout should be around that long.
+func NewRequestCache(timeout time.Duration) *RequestCache {
 	return &RequestCache{
-		m: sync.Map{},
-		// the JWT expiration time in cap library is 5 minutes,
-		// so auto-delete from our request cache after 6.
-		timeout: 6 * time.Minute,
+		c: expirable.NewLRU[string, *oidc.Req](MaxRequests, nil, timeout),
 	}
 }
 type RequestCache struct {
-	m       sync.Map
-	timeout time.Duration
+	c *expirable.LRU[string, *oidc.Req]
 }
 // Store saves the request, to be Loaded later with its Nonce.
 // If LoadAndDelete is not called, the stale request will be auto-deleted.
-func (rc *RequestCache) Store(ctx context.Context, req *oidc.Req) error {
-	ctx, cancel := context.WithTimeout(ctx, rc.timeout)
-	er := &expiringRequest{
-		req:    req,
-		ctx:    ctx,
-		cancel: cancel,
+func (rc *RequestCache) Store(req *oidc.Req) error {
+	if rc.c.Len() >= MaxRequests {
+		return ErrTooManyRequests
 	}
-	if _, loaded := rc.m.LoadOrStore(req.Nonce(), er); loaded {
-		// we already had a request for this nonce, which should never happen,
-		// so cancel the new request and error to notify caller of a bug.
-		cancel()
+	if _, ok := rc.c.Get(req.Nonce()); ok {
+		// we already had a request for this nonce (should never happen)
 		return ErrNonceReuse
 	}
-	// auto-delete after timeout or context canceled
-	go func() {
-		<-ctx.Done()
-		rc.m.Delete(req.Nonce())
-	}()
+	rc.c.Add(req.Nonce(), req)
 	return nil
 }
 func (rc *RequestCache) Load(nonce string) *oidc.Req {
-	if er, ok := rc.m.Load(nonce); ok {
-		return er.(*expiringRequest).req
+	if req, ok := rc.c.Get(nonce); ok {
+		return req
 	}
 	return nil
 }
 func (rc *RequestCache) LoadAndDelete(nonce string) *oidc.Req {
-	if er, loaded := rc.m.LoadAndDelete(nonce); loaded {
-		// there is a tiny race condition here. if by massive coincidence,
-		// or a bug, the same nonce makes its way in here, this cancel()
-		// triggers a map Delete() up in Store(), which could delete a request
-		// out from under a subsequent Store()
-		er.(*expiringRequest).cancel()
-		return er.(*expiringRequest).req
+	if req, ok := rc.c.Get(nonce); ok {
+		rc.c.Remove(nonce)
+		return req
 	}
 	return nil
 }
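
Taken together, the cache keeps its three-method surface (Store, Load, LoadAndDelete), but Store no longer takes a context, since expiry is now handled by the LRU's TTL rather than a per-entry goroutine, and it gains a second failure mode when the cache is full. A hedged usage sketch in the style of the tests below; the function name, redirect URL, and nonce are placeholders, and it assumes it sits in the same package as the cache, where oidc refers to hashicorp/cap/oidc.

```go
// completeLoginSketch is illustrative only; it is not part of the commit.
func completeLoginSketch() error {
	rc := NewRequestCache(6 * time.Minute)

	// Build a request the same way the tests do.
	req, err := oidc.NewRequest(time.Minute, "https://example.com/oidc/callback",
		oidc.WithNonce("example-nonce"))
	if err != nil {
		return err
	}

	// Store can now fail two ways: a duplicate nonce or a full cache.
	if err := rc.Store(req); err != nil {
		return err // ErrNonceReuse or ErrTooManyRequests
	}

	// Later, when the provider redirects back with the nonce, fetch and
	// remove the pending request in one step.
	if pending := rc.LoadAndDelete("example-nonce"); pending != nil {
		_ = pending // finish the login flow with the original request
	}
	return nil
}
```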

OIDC request cache tests (Go source file)

@@ -4,7 +4,7 @@
 package oidc
 import (
-	"context"
+	"fmt"
 	"testing"
 	"time"
@@ -16,41 +16,25 @@ import (
 func TestRequestCache(t *testing.T) {
 	// using a top-level cache and running each sub-test in parallel exercises
 	// a little bit of thread safety.
-	rc := NewRequestCache()
+	rc := NewRequestCache(time.Minute)
 	t.Run("reuse nonce", func(t *testing.T) {
 		t.Parallel()
 		req := getRequest(t)
-		ctx, cancel := context.WithCancel(context.Background())
-		defer cancel()
-		must.NoError(t, rc.Store(ctx, req))
-		must.ErrorIs(t, rc.Store(ctx, req), ErrNonceReuse)
-	})
-	t.Run("cancel parent ctx", func(t *testing.T) {
-		t.Parallel()
-		req := getRequest(t)
-		ctx, cancel := context.WithCancel(context.Background())
-		defer cancel()
-		must.NoError(t, rc.Store(ctx, req))
-		must.Eq(t, req, rc.Load(req.Nonce()))
-		cancel() // triggers delete
-		waitUntilGone(t, rc, req.Nonce())
+		must.NoError(t, rc.Store(req))
+		must.ErrorIs(t, rc.Store(req), ErrNonceReuse)
 	})
 	t.Run("load and delete", func(t *testing.T) {
 		t.Parallel()
 		req := getRequest(t)
-		ctx, cancel := context.WithCancel(context.Background())
-		defer cancel()
-		must.NoError(t, rc.Store(ctx, req))
+		must.NoError(t, rc.Store(req))
 		must.Eq(t, req, rc.Load(req.Nonce()))
-		must.Eq(t, req, rc.LoadAndDelete(req.Nonce())) // triggers delete
+		must.Eq(t, req, rc.LoadAndDelete(req.Nonce()))
-		waitUntilGone(t, rc, req.Nonce())
+		must.Nil(t, rc.LoadAndDelete(req.Nonce()))
 	})
@@ -58,18 +42,34 @@ func TestRequestCache(t *testing.T) {
 	t.Run("timeout", func(t *testing.T) {
 		// this test needs its own cache to reduce the timeout
 		// without affecting any other tests.
-		rc := NewRequestCache()
-		rc.timeout = time.Millisecond
+		rc := NewRequestCache(50 * time.Millisecond)
 		req := getRequest(t)
-		ctx, cancel := context.WithCancel(context.Background())
-		defer cancel()
-		must.NoError(t, rc.Store(ctx, req))
+		must.NoError(t, rc.Store(req))
 		// timeout triggers delete behind the scenes
 		waitUntilGone(t, rc, req.Nonce())
 	})
+	t.Run("too many requests", func(t *testing.T) {
+		// not Parallel, would make other tests flaky
+		defer rc.c.Purge()
+		var gotErr error
+		for i := range MaxRequests + 5 {
+			req, err := oidc.NewRequest(time.Minute, "test-redirect-url",
+				oidc.WithNonce(fmt.Sprintf("too-many-cooks-%d", i)))
+			must.NoError(t, err)
+			if err := rc.Store(req); err != nil {
+				gotErr = err
+				break
+			}
+		}
+		must.ErrorIs(t, gotErr, ErrTooManyRequests)
+	})
 }
 func getRequest(t *testing.T) *oidc.Req {

ACL RPC endpoint: OIDCAuthURL (Go source file)

@@ -2625,7 +2625,7 @@ func (a *ACL) OIDCAuthURL(args *structs.ACLOIDCAuthURLRequest, reply *structs.AC
 	if err != nil {
 		return err
 	}
-	if err = a.oidcRequestCache.Store(a.srv.shutdownCtx, oidcReq); err != nil {
+	if err = a.oidcRequestCache.Store(oidcReq); err != nil {
 		return fmt.Errorf("error storing OIDC request: %w", err)
 	}
 }

ACL endpoint tests (package nomad)

@@ -5,7 +5,6 @@ package nomad
 import (
 	"bytes"
-	"context"
 	"fmt"
 	"io"
 	"net/url"
@@ -4218,9 +4217,7 @@ func cacheOIDCRequest(t *testing.T, cache *oidc.RequestCache, req structs.ACLOID
 		opts...,
 	)
 	must.NoError(t, err)
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	t.Cleanup(cancel)
 	// make sure the cache is clean first
 	cache.LoadAndDelete(req.ClientNonce)
-	must.NoError(t, cache.Store(ctx, oidcReq))
+	must.NoError(t, cache.Store(oidcReq))
 }

Server setup: NewServer (Go source file)

@@ -440,7 +440,8 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigFunc
 	// ACL.OIDCAuthURL and ACL.OIDCCompleteAuth.
 	// It needs no special handling to handle agent shutdowns (its Store method
 	// handles this lifecycle).
-	s.oidcRequestCache = oidc.NewRequestCache()
+	// 6 minutes is 1 minute longer than the JWT expiration time in the cap lib.
+	s.oidcRequestCache = oidc.NewRequestCache(6 * time.Minute)
 	// Initialize the RPC layer
 	if err := s.setupRPC(tlsWrap); err != nil {
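
One note on the constructor used above: expirable.NewLRU takes a maximum size, an optional eviction callback, and a TTL, and the commit passes nil for the callback, so expired or evicted requests are silently dropped. If eviction visibility were ever wanted, a callback could be supplied instead; the fragment below is purely illustrative and not part of this change.

```go
// Illustrative fragment only; the commit passes nil rather than a callback.
// Assumes the same imports as the request cache file, plus "log".
onEvict := func(nonce string, req *oidc.Req) {
	// called whenever a pending request falls out of the cache
	log.Printf("pending OIDC request for nonce %q expired or was evicted", nonce)
}
c := expirable.NewLRU[string, *oidc.Req](MaxRequests, onEvict, 6*time.Minute)
_ = c
```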