mirror of
https://github.com/kemko/nomad.git
synced 2026-01-05 18:05:42 +03:00
auth: oidc request lru cache (#25336)
use hashicorp/golang-lru instead of my hand-rolled cache
This commit is contained in:
2
go.mod
2
go.mod
@@ -75,7 +75,7 @@ require (
|
||||
github.com/hashicorp/go-syslog v1.0.0
|
||||
github.com/hashicorp/go-uuid v1.0.3
|
||||
github.com/hashicorp/go-version v1.7.0
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.1
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/hashicorp/hcl v1.0.1-vault-3
|
||||
github.com/hashicorp/hcl/v2 v2.20.2-0.20240517235513-55d9c02d147d
|
||||
github.com/hashicorp/hil v0.0.0-20210521165536-27a72121fd40
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1269,8 +1269,8 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
|
||||
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
|
||||
github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c=
|
||||
github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.1/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee h1:8B4HqvMUtYSjsGkYjiQGStc9pXffY2J+Z2SPQAj+wMY=
|
||||
github.com/hashicorp/hcl v1.0.1-0.20201016140508-a07e7d50bbee/go.mod h1:gwlu9+/P9MmKtYrMsHeFRZPXj2CTPm11TDnMeaRHS7g=
|
||||
github.com/hashicorp/hcl/v2 v2.20.2-0.20240517235513-55d9c02d147d h1:7abftkc86B+tlA/0cDy5f6C4LgWfFOCpsGg3RJZsfbw=
|
||||
|
||||
@@ -4,80 +4,64 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
//"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/hashicorp/cap/oidc"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
)
|
||||
|
||||
var ErrNonceReuse = errors.New("nonce reuse detected")
|
||||
var (
|
||||
ErrNonceReuse = errors.New("nonce reuse detected")
|
||||
ErrTooManyRequests = errors.New("too many auth requests")
|
||||
)
|
||||
|
||||
// expiringRequest ensures that OIDC requests that are only partially fulfilled
|
||||
// do not get stuck in memory forever.
|
||||
type expiringRequest struct {
|
||||
// req is what we actually care about
|
||||
req *oidc.Req
|
||||
// ctx lets us clean up stale requests automatically
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
// MaxRequests is how many requests are allowed to be stored at a time.
|
||||
// It needs to be large enough for legitimate user traffic, but small enough
|
||||
// to prevent a DoS attack from eating up server memory.
|
||||
const MaxRequests = 10000
|
||||
|
||||
// NewRequestCache creates a cache for OIDC requests.
|
||||
func NewRequestCache() *RequestCache {
|
||||
// The JWT expiration time in the cap library is 5 minutes,
|
||||
// so timeout should be around that long.
|
||||
func NewRequestCache(timeout time.Duration) *RequestCache {
|
||||
return &RequestCache{
|
||||
m: sync.Map{},
|
||||
// the JWT expiration time in cap library is 5 minutes,
|
||||
// so auto-delete from our request cache after 6.
|
||||
timeout: 6 * time.Minute,
|
||||
c: expirable.NewLRU[string, *oidc.Req](MaxRequests, nil, timeout),
|
||||
}
|
||||
}
|
||||
|
||||
type RequestCache struct {
|
||||
m sync.Map
|
||||
timeout time.Duration
|
||||
c *expirable.LRU[string, *oidc.Req]
|
||||
}
|
||||
|
||||
// Store saves the request, to be Loaded later with its Nonce.
|
||||
// If LoadAndDelete is not called, the stale request will be auto-deleted.
|
||||
func (rc *RequestCache) Store(ctx context.Context, req *oidc.Req) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, rc.timeout)
|
||||
er := &expiringRequest{
|
||||
req: req,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
func (rc *RequestCache) Store(req *oidc.Req) error {
|
||||
if rc.c.Len() >= MaxRequests {
|
||||
return ErrTooManyRequests
|
||||
}
|
||||
if _, loaded := rc.m.LoadOrStore(req.Nonce(), er); loaded {
|
||||
// we already had a request for this nonce, which should never happen,
|
||||
// so cancel the new request and error to notify caller of a bug.
|
||||
cancel()
|
||||
|
||||
if _, ok := rc.c.Get(req.Nonce()); ok {
|
||||
// we already had a request for this nonce (should never happen)
|
||||
return ErrNonceReuse
|
||||
}
|
||||
// auto-delete after timeout or context canceled
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
rc.m.Delete(req.Nonce())
|
||||
}()
|
||||
|
||||
rc.c.Add(req.Nonce(), req)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rc *RequestCache) Load(nonce string) *oidc.Req {
|
||||
if er, ok := rc.m.Load(nonce); ok {
|
||||
return er.(*expiringRequest).req
|
||||
if req, ok := rc.c.Get(nonce); ok {
|
||||
return req
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rc *RequestCache) LoadAndDelete(nonce string) *oidc.Req {
|
||||
if er, loaded := rc.m.LoadAndDelete(nonce); loaded {
|
||||
// there is a tiny race condition here. if by massive coincidence,
|
||||
// or a bug, the same nonce makes its way in here, this cancel()
|
||||
// triggers a map Delete() up in Store(), which could delete a request
|
||||
// out from under a subsequent Store()
|
||||
er.(*expiringRequest).cancel()
|
||||
return er.(*expiringRequest).req
|
||||
if req, ok := rc.c.Get(nonce); ok {
|
||||
rc.c.Remove(nonce)
|
||||
return req
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,41 +16,25 @@ import (
|
||||
func TestRequestCache(t *testing.T) {
|
||||
// using a top-level cache and running each sub-test in parallel exercises
|
||||
// a little bit of thread safety.
|
||||
rc := NewRequestCache()
|
||||
rc := NewRequestCache(time.Minute)
|
||||
|
||||
t.Run("reuse nonce", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := getRequest(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
must.NoError(t, rc.Store(ctx, req))
|
||||
must.ErrorIs(t, rc.Store(ctx, req), ErrNonceReuse)
|
||||
})
|
||||
|
||||
t.Run("cancel parent ctx", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := getRequest(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
must.NoError(t, rc.Store(ctx, req))
|
||||
must.Eq(t, req, rc.Load(req.Nonce()))
|
||||
|
||||
cancel() // triggers delete
|
||||
waitUntilGone(t, rc, req.Nonce())
|
||||
must.NoError(t, rc.Store(req))
|
||||
must.ErrorIs(t, rc.Store(req), ErrNonceReuse)
|
||||
})
|
||||
|
||||
t.Run("load and delete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := getRequest(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
must.NoError(t, rc.Store(ctx, req))
|
||||
must.NoError(t, rc.Store(req))
|
||||
must.Eq(t, req, rc.Load(req.Nonce()))
|
||||
|
||||
must.Eq(t, req, rc.LoadAndDelete(req.Nonce())) // triggers delete
|
||||
must.Eq(t, req, rc.LoadAndDelete(req.Nonce()))
|
||||
|
||||
waitUntilGone(t, rc, req.Nonce())
|
||||
must.Nil(t, rc.LoadAndDelete(req.Nonce()))
|
||||
})
|
||||
@@ -58,18 +42,34 @@ func TestRequestCache(t *testing.T) {
|
||||
t.Run("timeout", func(t *testing.T) {
|
||||
// this test needs its own cache to reduce the timeout
|
||||
// without affecting any other tests.
|
||||
rc := NewRequestCache()
|
||||
rc.timeout = time.Millisecond
|
||||
rc := NewRequestCache(50 * time.Millisecond)
|
||||
|
||||
req := getRequest(t)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
must.NoError(t, rc.Store(ctx, req))
|
||||
must.NoError(t, rc.Store(req))
|
||||
|
||||
// timeout triggers delete behind the scenes
|
||||
waitUntilGone(t, rc, req.Nonce())
|
||||
})
|
||||
|
||||
t.Run("too many requests", func(t *testing.T) {
|
||||
// not Parallel, would make other tests flaky
|
||||
defer rc.c.Purge()
|
||||
|
||||
var gotErr error
|
||||
for i := range MaxRequests + 5 {
|
||||
req, err := oidc.NewRequest(time.Minute, "test-redirect-url",
|
||||
oidc.WithNonce(fmt.Sprintf("too-many-cooks-%d", i)))
|
||||
must.NoError(t, err)
|
||||
|
||||
if err := rc.Store(req); err != nil {
|
||||
gotErr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
must.ErrorIs(t, gotErr, ErrTooManyRequests)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func getRequest(t *testing.T) *oidc.Req {
|
||||
|
||||
@@ -2625,7 +2625,7 @@ func (a *ACL) OIDCAuthURL(args *structs.ACLOIDCAuthURLRequest, reply *structs.AC
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = a.oidcRequestCache.Store(a.srv.shutdownCtx, oidcReq); err != nil {
|
||||
if err = a.oidcRequestCache.Store(oidcReq); err != nil {
|
||||
return fmt.Errorf("error storing OIDC request: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ package nomad
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
@@ -4218,9 +4217,7 @@ func cacheOIDCRequest(t *testing.T, cache *oidc.RequestCache, req structs.ACLOID
|
||||
opts...,
|
||||
)
|
||||
must.NoError(t, err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
t.Cleanup(cancel)
|
||||
// make sure the cache is clean first
|
||||
cache.LoadAndDelete(req.ClientNonce)
|
||||
must.NoError(t, cache.Store(ctx, oidcReq))
|
||||
must.NoError(t, cache.Store(oidcReq))
|
||||
}
|
||||
|
||||
@@ -440,7 +440,8 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigFunc
|
||||
// ACL.OIDCAuthURL and ACL.OIDCCompleteAuth.
|
||||
// It needs no special handling for agent shutdowns (stored requests simply
|
||||
// expire from the cache once its timeout elapses).
|
||||
s.oidcRequestCache = oidc.NewRequestCache()
|
||||
// 6 minutes is 1 minute longer than the JWT expiration time in the cap lib.
|
||||
s.oidcRequestCache = oidc.NewRequestCache(6 * time.Minute)
|
||||
|
||||
// Initialize the RPC layer
|
||||
if err := s.setupRPC(tlsWrap); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user