Files
reproxy/vendor/github.com/didip/tollbooth/v7/limiter/limiter.go
dependabot[bot] 1e81d22fef Bump the go-modules-updates group with 2 updates
Bumps the go-modules-updates group with 2 updates: [github.com/didip/tollbooth/v7](https://github.com/didip/tollbooth) and [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `github.com/didip/tollbooth/v7` from 7.0.1 to 7.0.2
- [Commits](https://github.com/didip/tollbooth/compare/v7.0.1...v7.0.2)

Updates `golang.org/x/crypto` from 0.23.0 to 0.24.0
- [Commits](https://github.com/golang/crypto/compare/v0.23.0...v0.24.0)

---
updated-dependencies:
- dependency-name: github.com/didip/tollbooth/v7
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: go-modules-updates
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: go-modules-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-07-28 02:22:59 -05:00

611 lines
15 KiB
Go

// Package limiter provides the data structures used to configure a rate limiter.
package limiter
import (
"net/http"
"sync"
"time"
cache "github.com/go-pkgz/expirable-cache/v3"
"github.com/didip/tollbooth/v7/internal/time/rate"
)
// New builds a Limiter populated with its default settings.
//
// Defaults applied here: a "text/plain; charset=utf-8" rejection message of
// "You have reached maximum request limit.", status code 429, no
// on-limit-reached callback, IP lookups in the order RemoteAddr,
// X-Forwarded-For, X-Real-IP, forwarded-for index 0, empty header and
// context-value rule maps, and URLs included in rate-limit keys.
//
// generalExpirableOptions may be nil, in which case a fresh ExpirableOptions
// is used. A DefaultExpirationTTL that is <= 0 is replaced with 87600 hours
// (10 years); when the caller passed a non-nil options struct, that struct is
// modified in place.
func New(generalExpirableOptions *ExpirableOptions) *Limiter {
	lmt := &Limiter{}

	lmt.SetMessageContentType("text/plain; charset=utf-8")
	lmt.SetMessage("You have reached maximum request limit.")
	lmt.SetStatusCode(429)
	lmt.SetOnLimitReached(nil)
	lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"})
	lmt.SetForwardedForIndexFromBehind(0)
	lmt.SetHeaders(make(map[string][]string))
	lmt.SetContextValues(make(map[string][]string))
	lmt.SetIgnoreURL(false)

	opts := generalExpirableOptions
	if opts == nil {
		opts = &ExpirableOptions{}
	}
	// Default for DefaultExpirationTTL is 10 years.
	if opts.DefaultExpirationTTL <= 0 {
		opts.DefaultExpirationTTL = 87600 * time.Hour
	}
	lmt.generalExpirableOptions = opts

	defaultTTL := opts.DefaultExpirationTTL
	lmt.tokenBuckets = cache.NewCache[string, *rate.Limiter]().WithTTL(defaultTTL)
	lmt.basicAuthUsers = cache.NewCache[string, bool]().WithTTL(defaultTTL)

	return lmt
}
// Limiter is a config struct to limit a particular request handler.
// Instances are expected to be created through New, which initializes the
// caches and maps held here.
type Limiter struct {
// Maximum number of requests to limit per second.
max float64
// Limiter burst size
burst int
// HTTP message when limit is reached.
message string
// Content-Type for Message
messageContentType string
// HTTP status code when limit is reached.
statusCode int
// A function to call when a request is rejected.
onLimitReached func(w http.ResponseWriter, r *http.Request)
// An option to write back what you want upon reaching a limit.
overrideDefaultResponseWriter bool
// List of places to look up IP address.
// Default is "RemoteAddr", "X-Forwarded-For", "X-Real-IP".
// You can rearrange the order as you like.
ipLookups []string
// Which X-Forwarded-For index to choose; see SetForwardedForIndexFromBehind.
forwardedForIndex int
// List of HTTP Methods to limit (GET, POST, PUT, etc.).
// Empty means limit all methods.
methods []string
// Able to configure token bucket expirations.
// Its DefaultExpirationTTL is the fallback whenever one of the specific
// TTL fields below is <= 0.
generalExpirableOptions *ExpirableOptions
// List of basic auth usernames to limit.
basicAuthUsers cache.Cache[string, bool]
// Map of HTTP headers to limit.
// Empty means skip headers checking.
headers map[string]cache.Cache[string, bool]
// Map of Context values to limit.
contextValues map[string]cache.Cache[string, bool]
// Map of limiters with TTL
tokenBuckets cache.Cache[string, *rate.Limiter]
// Ignore URL on the rate limiter keys
ignoreURL bool
// Custom TTL for token bucket entries, consulted by LimitReached.
tokenBucketExpirationTTL time.Duration
// Custom TTL for basic auth user entries, consulted by SetBasicAuthUsers.
basicAuthExpirationTTL time.Duration
// Custom TTL for header entries, consulted by SetHeader and RemoveHeader.
headerEntryExpirationTTL time.Duration
// Custom TTL for Context value entries, consulted by SetContextValue and
// RemoveContextValue.
contextEntryExpirationTTL time.Duration
// Embedded lock taken by the accessor methods when they read or write the
// fields of this struct.
sync.RWMutex
}
// SetTokenBucketExpirationTTL stores a custom expiration TTL for token
// buckets while holding the write lock, and returns the receiver so calls can
// be chained. LimitReached falls back to the default TTL when this value is
// <= 0.
func (l *Limiter) SetTokenBucketExpirationTTL(ttl time.Duration) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.tokenBucketExpirationTTL = ttl
	return l
}
// GetTokenBucketExpirationTTL returns the custom token bucket expiration TTL,
// read while holding the read lock.
func (l *Limiter) GetTokenBucketExpirationTTL() time.Duration {
	l.RLock()
	ttl := l.tokenBucketExpirationTTL
	l.RUnlock()

	return ttl
}
// SetBasicAuthExpirationTTL stores a custom expiration TTL for basic auth
// user entries while holding the write lock, and returns the receiver so
// calls can be chained. SetBasicAuthUsers falls back to the default TTL when
// this value is <= 0.
func (l *Limiter) SetBasicAuthExpirationTTL(ttl time.Duration) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.basicAuthExpirationTTL = ttl
	return l
}
// GetBasicAuthExpirationTTL returns the custom basic auth expiration TTL,
// read while holding the read lock.
func (l *Limiter) GetBasicAuthExpirationTTL() time.Duration {
	l.RLock()
	ttl := l.basicAuthExpirationTTL
	l.RUnlock()

	return ttl
}
// SetHeaderEntryExpirationTTL stores a custom expiration TTL for header
// entries while holding the write lock, and returns the receiver so calls can
// be chained. SetHeader and RemoveHeader fall back to the default TTL when
// this value is <= 0.
func (l *Limiter) SetHeaderEntryExpirationTTL(ttl time.Duration) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.headerEntryExpirationTTL = ttl
	return l
}
// GetHeaderEntryExpirationTTL returns the custom header entry expiration TTL,
// read while holding the read lock.
func (l *Limiter) GetHeaderEntryExpirationTTL() time.Duration {
	l.RLock()
	ttl := l.headerEntryExpirationTTL
	l.RUnlock()

	return ttl
}
// SetContextValueEntryExpirationTTL stores a custom expiration TTL for
// Context value entries while holding the write lock, and returns the
// receiver so calls can be chained. SetContextValue and RemoveContextValue
// fall back to the default TTL when this value is <= 0.
func (l *Limiter) SetContextValueEntryExpirationTTL(ttl time.Duration) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.contextEntryExpirationTTL = ttl
	return l
}
// GetContextValueEntryExpirationTTL returns the custom Context value entry
// expiration TTL, read while holding the read lock.
func (l *Limiter) GetContextValueEntryExpirationTTL() time.Duration {
	l.RLock()
	ttl := l.contextEntryExpirationTTL
	l.RUnlock()

	return ttl
}
// SetMax stores the maximum number of requests to limit per second while
// holding the write lock, and returns the receiver so calls can be chained.
func (l *Limiter) SetMax(max float64) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.max = max
	return l
}
// GetMax returns the maximum number of requests to limit per second, read
// while holding the read lock.
func (l *Limiter) GetMax() float64 {
	l.RLock()
	limit := l.max
	l.RUnlock()

	return limit
}
// SetBurst stores the maximum burst size while holding the write lock, and
// returns the receiver so calls can be chained.
func (l *Limiter) SetBurst(burst int) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.burst = burst
	return l
}
// GetBurst returns the maximum burst size, read while holding the read lock.
func (l *Limiter) GetBurst() int {
	l.RLock()
	size := l.burst
	l.RUnlock()

	return size
}
// SetMessage stores the HTTP message used when the limit is reached while
// holding the write lock, and returns the receiver so calls can be chained.
func (l *Limiter) SetMessage(msg string) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.message = msg
	return l
}
// GetMessage returns the HTTP message used when the limit is reached, read
// while holding the read lock.
func (l *Limiter) GetMessage() string {
	l.RLock()
	msg := l.message
	l.RUnlock()

	return msg
}
// SetMessageContentType stores the Content-Type of the HTTP message used when
// the limit is reached while holding the write lock, and returns the receiver
// so calls can be chained.
func (l *Limiter) SetMessageContentType(contentType string) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.messageContentType = contentType
	return l
}
// GetMessageContentType returns the Content-Type of the HTTP message used
// when the limit is reached, read while holding the read lock.
func (l *Limiter) GetMessageContentType() string {
	l.RLock()
	contentType := l.messageContentType
	l.RUnlock()

	return contentType
}
// SetStatusCode stores the HTTP status code used when the limit is reached
// while holding the write lock, and returns the receiver so calls can be
// chained.
func (l *Limiter) SetStatusCode(statusCode int) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.statusCode = statusCode
	return l
}
// GetStatusCode returns the HTTP status code used when the limit is reached,
// read while holding the read lock.
func (l *Limiter) GetStatusCode() int {
	l.RLock()
	code := l.statusCode
	l.RUnlock()

	return code
}
// SetOnLimitReached stores the function to call when a request is rejected
// while holding the write lock, and returns the receiver so calls can be
// chained. Passing nil clears the callback.
func (l *Limiter) SetOnLimitReached(fn func(w http.ResponseWriter, r *http.Request)) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.onLimitReached = fn
	return l
}
// ExecOnLimitReached invokes the configured on-limit-reached callback with w
// and r. It does nothing when no callback is set.
func (l *Limiter) ExecOnLimitReached(w http.ResponseWriter, r *http.Request) {
	// Copy the callback under the read lock and invoke it after releasing the
	// lock, so the callback itself never runs while the lock is held.
	l.RLock()
	callback := l.onLimitReached
	l.RUnlock()

	if callback == nil {
		return
	}
	callback(w, r)
}
// SetOverrideDefaultResponseWriter stores the response writer override flag
// while holding the write lock. Unlike most setters on Limiter it returns
// nothing, so it cannot be chained.
func (l *Limiter) SetOverrideDefaultResponseWriter(override bool) {
	l.Lock()
	defer l.Unlock()

	l.overrideDefaultResponseWriter = override
}
// GetOverrideDefaultResponseWriter returns the response writer override flag,
// read while holding the read lock.
func (l *Limiter) GetOverrideDefaultResponseWriter() bool {
	l.RLock()
	override := l.overrideDefaultResponseWriter
	l.RUnlock()

	return override
}
// SetIPLookups stores the list of places to look up the IP address while
// holding the write lock, and returns the receiver so calls can be chained.
// The slice is stored as given, without copying.
func (l *Limiter) SetIPLookups(ipLookups []string) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.ipLookups = ipLookups
	return l
}
// GetIPLookups returns the list of places to look up the IP address, read
// while holding the read lock. The stored slice itself is returned, not a
// copy.
func (l *Limiter) GetIPLookups() []string {
	l.RLock()
	lookups := l.ipLookups
	l.RUnlock()

	return lookups
}
// SetIgnoreURL stores whether the URL is ignored on rate limit keys while
// holding the write lock, and returns the receiver so calls can be chained.
func (l *Limiter) SetIgnoreURL(enabled bool) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.ignoreURL = enabled
	return l
}
// GetIgnoreURL reports whether the URL is ignored on rate limit keys, read
// while holding the read lock.
func (l *Limiter) GetIgnoreURL() bool {
	l.RLock()
	ignored := l.ignoreURL
	l.RUnlock()

	return ignored
}
// SetForwardedForIndexFromBehind stores which X-Forwarded-For index to choose
// while holding the write lock, and returns the receiver so calls can be
// chained.
// NOTE(review): per the method name the index is presumably counted from the
// end of the header's list; confirm against the code that consumes it.
func (l *Limiter) SetForwardedForIndexFromBehind(forwardedForIndex int) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.forwardedForIndex = forwardedForIndex
	return l
}
// GetForwardedForIndexFromBehind returns which X-Forwarded-For index to
// choose, read while holding the read lock.
func (l *Limiter) GetForwardedForIndexFromBehind() int {
	l.RLock()
	index := l.forwardedForIndex
	l.RUnlock()

	return index
}
// SetMethods stores the list of HTTP methods to limit (GET, POST, PUT, etc.)
// while holding the write lock, and returns the receiver so calls can be
// chained. The slice is stored as given, without copying.
func (l *Limiter) SetMethods(methods []string) *Limiter {
	l.Lock()
	defer l.Unlock()

	l.methods = methods
	return l
}
// GetMethods returns the list of HTTP methods to limit (GET, POST, PUT,
// etc.), read while holding the read lock. The stored slice itself is
// returned, not a copy.
func (l *Limiter) GetMethods() []string {
	l.RLock()
	methods := l.methods
	l.RUnlock()

	return methods
}
// SetBasicAuthUsers adds each given basic auth username to the set of users
// to limit and returns the receiver so calls can be chained. Existing users
// are not removed. Entries are stored with the custom basic auth TTL, or with
// the default expiration TTL when the custom one is <= 0.
func (l *Limiter) SetBasicAuthUsers(basicAuthUsers []string) *Limiter {
	entryTTL := l.GetBasicAuthExpirationTTL()
	if entryTTL <= 0 {
		entryTTL = l.generalExpirableOptions.DefaultExpirationTTL
	}

	for i := range basicAuthUsers {
		l.basicAuthUsers.Set(basicAuthUsers[i], true, entryTTL)
	}

	return l
}
// GetBasicAuthUsers returns the keys of the basic auth user cache, i.e. the
// usernames currently being limited.
func (l *Limiter) GetBasicAuthUsers() []string {
	users := l.basicAuthUsers.Keys()
	return users
}
// RemoveBasicAuthUsers invalidates each given username in the basic auth user
// cache and returns the receiver so calls can be chained.
func (l *Limiter) RemoveBasicAuthUsers(basicAuthUsers []string) *Limiter {
	for i := range basicAuthUsers {
		l.basicAuthUsers.Invalidate(basicAuthUsers[i])
	}

	return l
}
// DeleteExpiredTokenBuckets asks the token bucket cache to delete its expired
// entries.
func (l *Limiter) DeleteExpiredTokenBuckets() {
	buckets := l.tokenBuckets
	buckets.DeleteExpired()
}
// SetHeaders registers entries for every HTTP header in the given map by
// calling SetHeader once per key, and returns the receiver so calls can be
// chained. The internal header map is created first if it does not exist
// yet. Headers already registered but absent from the argument are left
// untouched.
func (l *Limiter) SetHeaders(headers map[string][]string) *Limiter {
	if l.headers == nil {
		l.headers = make(map[string]cache.Cache[string, bool])
	}

	for name := range headers {
		l.SetHeader(name, headers[name])
	}

	return l
}
// GetHeaders returns a newly built map from each registered HTTP header to
// the keys currently held in that header's entry cache. The read lock is held
// while the map is assembled.
func (l *Limiter) GetHeaders() map[string][]string {
	l.RLock()
	defer l.RUnlock()

	snapshot := make(map[string][]string)
	for name, entryCache := range l.headers {
		snapshot[name] = entryCache.Keys()
	}

	return snapshot
}
// SetHeader adds entries to the set of values limited for one HTTP header and
// returns the receiver so calls can be chained.
//
// If the header has no entry cache yet, one is created; an empty entries
// slice therefore still registers the header with an empty cache. Entries are
// stored with the custom header entry TTL, or with the default expiration TTL
// when the custom one is <= 0.
//
// Lookup, creation, population and publication all happen under a single
// write lock. Doing the lookup under a read lock and the publish under a
// separate write lock would let two concurrent callers for the same new
// header each build their own cache, with the later write silently discarding
// the other's entries.
func (l *Limiter) SetHeader(header string, entries []string) *Limiter {
	// Resolve the TTL before taking the write lock: the getter acquires the
	// read lock itself, and sync.RWMutex is not reentrant.
	ttl := l.GetHeaderEntryExpirationTTL()
	if ttl <= 0 {
		ttl = l.generalExpirableOptions.DefaultExpirationTTL
	}

	l.Lock()
	defer l.Unlock()

	existing, found := l.headers[header]
	if !found {
		existing = cache.NewCache[string, bool]().WithTTL(ttl)
	}
	for _, entry := range entries {
		existing.Set(entry, true, ttl)
	}
	l.headers[header] = existing

	return l
}
// GetHeader returns the entries currently registered for one HTTP header.
//
// It returns nil when the header has never been registered. Without that
// check the map lookup yields a nil cache interface, and calling Keys on it
// panics.
func (l *Limiter) GetHeader(header string) []string {
	l.RLock()
	entriesAsGoCache, found := l.headers[header]
	l.RUnlock()

	if !found || entriesAsGoCache == nil {
		return nil
	}
	return entriesAsGoCache.Keys()
}
// RemoveHeader discards all entries of one HTTP header by replacing its entry
// cache with a fresh, empty one, and returns the receiver so calls can be
// chained. The header key itself stays in the header map. The new cache uses
// the custom header entry TTL, or the default expiration TTL when the custom
// one is <= 0.
func (l *Limiter) RemoveHeader(header string) *Limiter {
	entryTTL := l.GetHeaderEntryExpirationTTL()
	if entryTTL <= 0 {
		entryTTL = l.generalExpirableOptions.DefaultExpirationTTL
	}
	emptied := cache.NewCache[string, bool]().WithTTL(entryTTL)

	l.Lock()
	l.headers[header] = emptied
	l.Unlock()

	return l
}
// RemoveHeaderEntries invalidates the given entries in the entry cache of one
// HTTP header and returns the receiver so calls can be chained. It is a no-op
// when the header is not registered.
func (l *Limiter) RemoveHeaderEntries(header string, entriesForRemoval []string) *Limiter {
	l.RLock()
	entryCache, registered := l.headers[header]
	l.RUnlock()

	if registered {
		for i := range entriesForRemoval {
			entryCache.Invalidate(entriesForRemoval[i])
		}
	}

	return l
}
// SetContextValues registers entries for every Context value in the given map
// by calling SetContextValue once per key, and returns the receiver so calls
// can be chained. The internal Context value map is created first if it does
// not exist yet. Context values already registered but absent from the
// argument are left untouched.
func (l *Limiter) SetContextValues(contextValues map[string][]string) *Limiter {
	if l.contextValues == nil {
		l.contextValues = make(map[string]cache.Cache[string, bool])
	}

	for name := range contextValues {
		l.SetContextValue(name, contextValues[name])
	}

	return l
}
// GetContextValues returns a newly built map from each registered Context
// value to the keys currently held in that value's entry cache. The read lock
// is held while the map is assembled.
func (l *Limiter) GetContextValues() map[string][]string {
	l.RLock()
	defer l.RUnlock()

	snapshot := make(map[string][]string)
	for name, entryCache := range l.contextValues {
		snapshot[name] = entryCache.Keys()
	}

	return snapshot
}
// SetContextValue adds entries to the set of values limited for one Context
// value and returns the receiver so calls can be chained.
//
// If the Context value has no entry cache yet, one is created; an empty
// entries slice therefore still registers it with an empty cache. Entries are
// stored with the custom Context value entry TTL, or with the default
// expiration TTL when the custom one is <= 0.
//
// Lookup, creation, population and publication all happen under a single
// write lock. Doing the lookup under a read lock and the publish under a
// separate write lock would let two concurrent callers for the same new
// Context value each build their own cache, with the later write silently
// discarding the other's entries.
func (l *Limiter) SetContextValue(contextValue string, entries []string) *Limiter {
	// Resolve the TTL before taking the write lock: the getter acquires the
	// read lock itself, and sync.RWMutex is not reentrant.
	ttl := l.GetContextValueEntryExpirationTTL()
	if ttl <= 0 {
		ttl = l.generalExpirableOptions.DefaultExpirationTTL
	}

	l.Lock()
	defer l.Unlock()

	existing, found := l.contextValues[contextValue]
	if !found {
		existing = cache.NewCache[string, bool]().WithTTL(ttl)
	}
	for _, entry := range entries {
		existing.Set(entry, true, ttl)
	}
	l.contextValues[contextValue] = existing

	return l
}
// GetContextValue returns the entries currently registered for one Context
// value.
//
// It returns nil when the Context value has never been registered. Without
// that check the map lookup yields a nil cache interface, and calling Keys on
// it panics.
func (l *Limiter) GetContextValue(contextValue string) []string {
	l.RLock()
	entriesAsGoCache, found := l.contextValues[contextValue]
	l.RUnlock()

	if !found || entriesAsGoCache == nil {
		return nil
	}
	return entriesAsGoCache.Keys()
}
// RemoveContextValue discards all entries of one Context value by replacing
// its entry cache with a fresh, empty one, and returns the receiver so calls
// can be chained. The key itself stays in the Context value map. The new
// cache uses the custom Context value entry TTL, or the default expiration
// TTL when the custom one is <= 0.
func (l *Limiter) RemoveContextValue(contextValue string) *Limiter {
	entryTTL := l.GetContextValueEntryExpirationTTL()
	if entryTTL <= 0 {
		entryTTL = l.generalExpirableOptions.DefaultExpirationTTL
	}
	emptied := cache.NewCache[string, bool]().WithTTL(entryTTL)

	l.Lock()
	l.contextValues[contextValue] = emptied
	l.Unlock()

	return l
}
// RemoveContextValuesEntries invalidates the given entries in the entry cache
// of one Context value and returns the receiver so calls can be chained. It
// is a no-op when the Context value is not registered.
func (l *Limiter) RemoveContextValuesEntries(contextValue string, entriesForRemoval []string) *Limiter {
	l.RLock()
	entryCache, registered := l.contextValues[contextValue]
	l.RUnlock()

	if registered {
		for i := range entriesForRemoval {
			entryCache.Invalidate(entriesForRemoval[i])
		}
	}

	return l
}
// limitReachedWithTokenBucketTTL reports whether the token bucket identified
// by key refuses the current request.
//
// When no bucket exists for key, one is created from the current max rate and
// burst size and stored with tokenBucketTTL. The bucket is then fetched from
// the cache again and asked to allow one event. If it cannot be fetched even
// after being stored, the request is treated as not limited and false is
// returned.
func (l *Limiter) limitReachedWithTokenBucketTTL(key string, tokenBucketTTL time.Duration) bool {
	// Both getters take the read lock, so they must run before the write lock
	// below is acquired.
	ratePerSecond := l.GetMax()
	burstSize := l.GetBurst()

	l.Lock()
	defer l.Unlock()

	_, exists := l.tokenBuckets.Get(key)
	if !exists {
		fresh := rate.NewLimiter(rate.Limit(ratePerSecond), burstSize)
		l.tokenBuckets.Set(key, fresh, tokenBucketTTL)
	}

	bucket, ok := l.tokenBuckets.Get(key)
	if ok {
		return !bucket.Allow()
	}
	return false
}
// LimitReached reports whether the bucket identified by key has run out of
// tokens. Buckets created by this call use the custom token bucket TTL, or
// the default expiration TTL when the custom one is <= 0.
func (l *Limiter) LimitReached(key string) bool {
	customTTL := l.GetTokenBucketExpirationTTL()
	if customTTL > 0 {
		return l.limitReachedWithTokenBucketTTL(key, customTTL)
	}
	return l.limitReachedWithTokenBucketTTL(key, l.generalExpirableOptions.DefaultExpirationTTL)
}
// Tokens returns the number of tokens currently left in the bucket identified
// by key, truncated to an int. It returns 0 when no bucket exists for key.
func (l *Limiter) Tokens(key string) int {
	if bucket, found := l.tokenBuckets.Get(key); found {
		return int(bucket.TokensAt(time.Now()))
	}
	return 0
}