update github.com/didip/tollbooth from v6 to v7

This commit is contained in:
Dmitry Verkhoturov
2024-05-09 02:04:46 +02:00
committed by Umputun
parent d9085e7dcc
commit 0a9e489743
26 changed files with 249 additions and 292 deletions

View File

@@ -9,19 +9,19 @@ import (
// ConsulClientMock is a mock implementation of ConsulClient. // ConsulClientMock is a mock implementation of ConsulClient.
// //
// func TestSomethingThatUsesConsulClient(t *testing.T) { // func TestSomethingThatUsesConsulClient(t *testing.T) {
// //
// // make and configure a mocked ConsulClient // // make and configure a mocked ConsulClient
// mockedConsulClient := &ConsulClientMock{ // mockedConsulClient := &ConsulClientMock{
// GetFunc: func() ([]consulService, error) { // GetFunc: func() ([]consulService, error) {
// panic("mock out the Get method") // panic("mock out the Get method")
// }, // },
// } // }
// //
// // use mockedConsulClient in code that requires ConsulClient // // use mockedConsulClient in code that requires ConsulClient
// // and then make assertions. // // and then make assertions.
// //
// } // }
type ConsulClientMock struct { type ConsulClientMock struct {
// GetFunc mocks the Get method. // GetFunc mocks the Get method.
GetFunc func() ([]consulService, error) GetFunc func() ([]consulService, error)
@@ -50,7 +50,8 @@ func (mock *ConsulClientMock) Get() ([]consulService, error) {
// GetCalls gets all the calls that were made to Get. // GetCalls gets all the calls that were made to Get.
// Check the length with: // Check the length with:
// len(mockedConsulClient.GetCalls()) //
// len(mockedConsulClient.GetCalls())
func (mock *ConsulClientMock) GetCalls() []struct { func (mock *ConsulClientMock) GetCalls() []struct {
} { } {
var calls []struct { var calls []struct {

View File

@@ -9,19 +9,19 @@ import (
// DockerClientMock is a mock implementation of DockerClient. // DockerClientMock is a mock implementation of DockerClient.
// //
// func TestSomethingThatUsesDockerClient(t *testing.T) { // func TestSomethingThatUsesDockerClient(t *testing.T) {
// //
// // make and configure a mocked DockerClient // // make and configure a mocked DockerClient
// mockedDockerClient := &DockerClientMock{ // mockedDockerClient := &DockerClientMock{
// ListContainersFunc: func() ([]containerInfo, error) { // ListContainersFunc: func() ([]containerInfo, error) {
// panic("mock out the ListContainers method") // panic("mock out the ListContainers method")
// }, // },
// } // }
// //
// // use mockedDockerClient in code that requires DockerClient // // use mockedDockerClient in code that requires DockerClient
// // and then make assertions. // // and then make assertions.
// //
// } // }
type DockerClientMock struct { type DockerClientMock struct {
// ListContainersFunc mocks the ListContainers method. // ListContainersFunc mocks the ListContainers method.
ListContainersFunc func() ([]containerInfo, error) ListContainersFunc func() ([]containerInfo, error)
@@ -50,7 +50,8 @@ func (mock *DockerClientMock) ListContainers() ([]containerInfo, error) {
// ListContainersCalls gets all the calls that were made to ListContainers. // ListContainersCalls gets all the calls that were made to ListContainers.
// Check the length with: // Check the length with:
// len(mockedDockerClient.ListContainersCalls()) //
// len(mockedDockerClient.ListContainersCalls())
func (mock *DockerClientMock) ListContainersCalls() []struct { func (mock *DockerClientMock) ListContainersCalls() []struct {
} { } {
var calls []struct { var calls []struct {

View File

@@ -14,22 +14,22 @@ var _ Provider = &ProviderMock{}
// ProviderMock is a mock implementation of Provider. // ProviderMock is a mock implementation of Provider.
// //
// func TestSomethingThatUsesProvider(t *testing.T) { // func TestSomethingThatUsesProvider(t *testing.T) {
// //
// // make and configure a mocked Provider // // make and configure a mocked Provider
// mockedProvider := &ProviderMock{ // mockedProvider := &ProviderMock{
// EventsFunc: func(ctx context.Context) <-chan ProviderID { // EventsFunc: func(ctx context.Context) <-chan ProviderID {
// panic("mock out the Events method") // panic("mock out the Events method")
// }, // },
// ListFunc: func() ([]URLMapper, error) { // ListFunc: func() ([]URLMapper, error) {
// panic("mock out the List method") // panic("mock out the List method")
// }, // },
// } // }
// //
// // use mockedProvider in code that requires Provider // // use mockedProvider in code that requires Provider
// // and then make assertions. // // and then make assertions.
// //
// } // }
type ProviderMock struct { type ProviderMock struct {
// EventsFunc mocks the Events method. // EventsFunc mocks the Events method.
EventsFunc func(ctx context.Context) <-chan ProviderID EventsFunc func(ctx context.Context) <-chan ProviderID
@@ -70,7 +70,8 @@ func (mock *ProviderMock) Events(ctx context.Context) <-chan ProviderID {
// EventsCalls gets all the calls that were made to Events. // EventsCalls gets all the calls that were made to Events.
// Check the length with: // Check the length with:
// len(mockedProvider.EventsCalls()) //
// len(mockedProvider.EventsCalls())
func (mock *ProviderMock) EventsCalls() []struct { func (mock *ProviderMock) EventsCalls() []struct {
Ctx context.Context Ctx context.Context
} { } {
@@ -98,7 +99,8 @@ func (mock *ProviderMock) List() ([]URLMapper, error) {
// ListCalls gets all the calls that were made to List. // ListCalls gets all the calls that were made to List.
// Check the length with: // Check the length with:
// len(mockedProvider.ListCalls()) //
// len(mockedProvider.ListCalls())
func (mock *ProviderMock) ListCalls() []struct { func (mock *ProviderMock) ListCalls() []struct {
} { } {
var calls []struct { var calls []struct {

View File

@@ -15,19 +15,19 @@ var _ Informer = &InformerMock{}
// InformerMock is a mock implementation of Informer. // InformerMock is a mock implementation of Informer.
// //
// func TestSomethingThatUsesInformer(t *testing.T) { // func TestSomethingThatUsesInformer(t *testing.T) {
// //
// // make and configure a mocked Informer // // make and configure a mocked Informer
// mockedInformer := &InformerMock{ // mockedInformer := &InformerMock{
// MappersFunc: func() []discovery.URLMapper { // MappersFunc: func() []discovery.URLMapper {
// panic("mock out the Mappers method") // panic("mock out the Mappers method")
// }, // },
// } // }
// //
// // use mockedInformer in code that requires Informer // // use mockedInformer in code that requires Informer
// // and then make assertions. // // and then make assertions.
// //
// } // }
type InformerMock struct { type InformerMock struct {
// MappersFunc mocks the Mappers method. // MappersFunc mocks the Mappers method.
MappersFunc func() []discovery.URLMapper MappersFunc func() []discovery.URLMapper
@@ -56,7 +56,8 @@ func (mock *InformerMock) Mappers() []discovery.URLMapper {
// MappersCalls gets all the calls that were made to Mappers. // MappersCalls gets all the calls that were made to Mappers.
// Check the length with: // Check the length with:
// len(mockedInformer.MappersCalls()) //
// len(mockedInformer.MappersCalls())
func (mock *InformerMock) MappersCalls() []struct { func (mock *InformerMock) MappersCalls() []struct {
} { } {
var calls []struct { var calls []struct {

View File

@@ -13,19 +13,19 @@ var _ RPCClient = &RPCClientMock{}
// RPCClientMock is a mock implementation of RPCClient. // RPCClientMock is a mock implementation of RPCClient.
// //
// func TestSomethingThatUsesRPCClient(t *testing.T) { // func TestSomethingThatUsesRPCClient(t *testing.T) {
// //
// // make and configure a mocked RPCClient // // make and configure a mocked RPCClient
// mockedRPCClient := &RPCClientMock{ // mockedRPCClient := &RPCClientMock{
// CallFunc: func(serviceMethod string, args interface{}, reply interface{}) error { // CallFunc: func(serviceMethod string, args interface{}, reply interface{}) error {
// panic("mock out the Call method") // panic("mock out the Call method")
// }, // },
// } // }
// //
// // use mockedRPCClient in code that requires RPCClient // // use mockedRPCClient in code that requires RPCClient
// // and then make assertions. // // and then make assertions.
// //
// } // }
type RPCClientMock struct { type RPCClientMock struct {
// CallFunc mocks the Call method. // CallFunc mocks the Call method.
CallFunc func(serviceMethod string, args interface{}, reply interface{}) error CallFunc func(serviceMethod string, args interface{}, reply interface{}) error
@@ -67,7 +67,8 @@ func (mock *RPCClientMock) Call(serviceMethod string, args interface{}, reply in
// CallCalls gets all the calls that were made to Call. // CallCalls gets all the calls that were made to Call.
// Check the length with: // Check the length with:
// len(mockedRPCClient.CallCalls()) //
// len(mockedRPCClient.CallCalls())
func (mock *RPCClientMock) CallCalls() []struct { func (mock *RPCClientMock) CallCalls() []struct {
ServiceMethod string ServiceMethod string
Args interface{} Args interface{}

View File

@@ -13,19 +13,19 @@ var _ RPCDialer = &RPCDialerMock{}
// RPCDialerMock is a mock implementation of RPCDialer. // RPCDialerMock is a mock implementation of RPCDialer.
// //
// func TestSomethingThatUsesRPCDialer(t *testing.T) { // func TestSomethingThatUsesRPCDialer(t *testing.T) {
// //
// // make and configure a mocked RPCDialer // // make and configure a mocked RPCDialer
// mockedRPCDialer := &RPCDialerMock{ // mockedRPCDialer := &RPCDialerMock{
// DialFunc: func(network string, address string) (RPCClient, error) { // DialFunc: func(network string, address string) (RPCClient, error) {
// panic("mock out the Dial method") // panic("mock out the Dial method")
// }, // },
// } // }
// //
// // use mockedRPCDialer in code that requires RPCDialer // // use mockedRPCDialer in code that requires RPCDialer
// // and then make assertions. // // and then make assertions.
// //
// } // }
type RPCDialerMock struct { type RPCDialerMock struct {
// DialFunc mocks the Dial method. // DialFunc mocks the Dial method.
DialFunc func(network string, address string) (RPCClient, error) DialFunc func(network string, address string) (RPCClient, error)
@@ -63,7 +63,8 @@ func (mock *RPCDialerMock) Dial(network string, address string) (RPCClient, erro
// DialCalls gets all the calls that were made to Dial. // DialCalls gets all the calls that were made to Dial.
// Check the length with: // Check the length with:
// len(mockedRPCDialer.DialCalls()) //
// len(mockedRPCDialer.DialCalls())
func (mock *RPCDialerMock) DialCalls() []struct { func (mock *RPCDialerMock) DialCalls() []struct {
Network string Network string
Address string Address string

View File

@@ -7,8 +7,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/didip/tollbooth/v6" "github.com/didip/tollbooth/v7"
"github.com/didip/tollbooth/v6/libstring" "github.com/didip/tollbooth/v7/libstring"
log "github.com/go-pkgz/lgr" log "github.com/go-pkgz/lgr"
R "github.com/go-pkgz/rest" R "github.com/go-pkgz/rest"
"github.com/gorilla/handlers" "github.com/gorilla/handlers"

View File

@@ -15,28 +15,28 @@ var _ Matcher = &MatcherMock{}
// MatcherMock is a mock implementation of Matcher. // MatcherMock is a mock implementation of Matcher.
// //
// func TestSomethingThatUsesMatcher(t *testing.T) { // func TestSomethingThatUsesMatcher(t *testing.T) {
// //
// // make and configure a mocked Matcher // // make and configure a mocked Matcher
// mockedMatcher := &MatcherMock{ // mockedMatcher := &MatcherMock{
// CheckHealthFunc: func() map[string]error { // CheckHealthFunc: func() map[string]error {
// panic("mock out the CheckHealth method") // panic("mock out the CheckHealth method")
// }, // },
// MappersFunc: func() []discovery.URLMapper { // MappersFunc: func() []discovery.URLMapper {
// panic("mock out the Mappers method") // panic("mock out the Mappers method")
// }, // },
// MatchFunc: func(srv string, src string) discovery.Matches { // MatchFunc: func(srv string, src string) discovery.Matches {
// panic("mock out the Match method") // panic("mock out the Match method")
// }, // },
// ServersFunc: func() []string { // ServersFunc: func() []string {
// panic("mock out the Servers method") // panic("mock out the Servers method")
// }, // },
// } // }
// //
// // use mockedMatcher in code that requires Matcher // // use mockedMatcher in code that requires Matcher
// // and then make assertions. // // and then make assertions.
// //
// } // }
type MatcherMock struct { type MatcherMock struct {
// CheckHealthFunc mocks the CheckHealth method. // CheckHealthFunc mocks the CheckHealth method.
CheckHealthFunc func() map[string]error CheckHealthFunc func() map[string]error
@@ -90,7 +90,8 @@ func (mock *MatcherMock) CheckHealth() map[string]error {
// CheckHealthCalls gets all the calls that were made to CheckHealth. // CheckHealthCalls gets all the calls that were made to CheckHealth.
// Check the length with: // Check the length with:
// len(mockedMatcher.CheckHealthCalls()) //
// len(mockedMatcher.CheckHealthCalls())
func (mock *MatcherMock) CheckHealthCalls() []struct { func (mock *MatcherMock) CheckHealthCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -116,7 +117,8 @@ func (mock *MatcherMock) Mappers() []discovery.URLMapper {
// MappersCalls gets all the calls that were made to Mappers. // MappersCalls gets all the calls that were made to Mappers.
// Check the length with: // Check the length with:
// len(mockedMatcher.MappersCalls()) //
// len(mockedMatcher.MappersCalls())
func (mock *MatcherMock) MappersCalls() []struct { func (mock *MatcherMock) MappersCalls() []struct {
} { } {
var calls []struct { var calls []struct {
@@ -147,7 +149,8 @@ func (mock *MatcherMock) Match(srv string, src string) discovery.Matches {
// MatchCalls gets all the calls that were made to Match. // MatchCalls gets all the calls that were made to Match.
// Check the length with: // Check the length with:
// len(mockedMatcher.MatchCalls()) //
// len(mockedMatcher.MatchCalls())
func (mock *MatcherMock) MatchCalls() []struct { func (mock *MatcherMock) MatchCalls() []struct {
Srv string Srv string
Src string Src string
@@ -177,7 +180,8 @@ func (mock *MatcherMock) Servers() []string {
// ServersCalls gets all the calls that were made to Servers. // ServersCalls gets all the calls that were made to Servers.
// Check the length with: // Check the length with:
// len(mockedMatcher.ServersCalls()) //
// len(mockedMatcher.ServersCalls())
func (mock *MatcherMock) ServersCalls() []struct { func (mock *MatcherMock) ServersCalls() []struct {
} { } {
var calls []struct { var calls []struct {

4
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/umputun/reproxy
go 1.22 go 1.22
require ( require (
github.com/didip/tollbooth/v6 v6.1.2 github.com/didip/tollbooth/v7 v7.0.1
github.com/go-pkgz/lgr v0.11.1 github.com/go-pkgz/lgr v0.11.1
github.com/go-pkgz/repeater v1.1.3 github.com/go-pkgz/repeater v1.1.3
github.com/go-pkgz/rest v1.19.0 github.com/go-pkgz/rest v1.19.0
@@ -22,6 +22,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-pkgz/expirable-cache v1.0.0 // indirect github.com/go-pkgz/expirable-cache v1.0.0 // indirect
github.com/kr/text v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.0 // indirect github.com/prometheus/client_model v0.6.0 // indirect
github.com/prometheus/common v0.50.0 // indirect github.com/prometheus/common v0.50.0 // indirect
@@ -29,6 +30,5 @@ require (
golang.org/x/net v0.22.0 // indirect golang.org/x/net v0.22.0 // indirect
golang.org/x/sys v0.18.0 // indirect golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect google.golang.org/protobuf v1.33.0 // indirect
) )

15
go.sum
View File

@@ -5,11 +5,12 @@ github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/didip/tollbooth/v6 v6.1.2 h1:Kdqxmqw9YTv0uKajBUiWQg+GURL/k4vy9gmLCL01PjQ= github.com/didip/tollbooth/v7 v7.0.1 h1:TkT4sBKoQoHQFPf7blQ54iHrZiTDnr8TceU+MulVAog=
github.com/didip/tollbooth/v6 v6.1.2/go.mod h1:xjcse6CTHCLuOkzsWrEgdy9WPJFv+p/x6v+MyfP+O9s= github.com/didip/tollbooth/v7 v7.0.1/go.mod h1:VZhDSGl5bDSPj4wPsih3PFa4Uh9Ghv8hgacaTm5PRT4=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-pkgz/expirable-cache v0.0.3/go.mod h1:+IauqN00R2FqNRLCLA+X5YljQJrwB179PfiAoMPlTlQ= github.com/go-pkgz/expirable-cache v0.1.0 h1:3bw0m8vlTK8qlwz5KXuygNBTkiKRTPrAGXU0Ej2AC1g=
github.com/go-pkgz/expirable-cache v0.1.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs=
github.com/go-pkgz/expirable-cache v1.0.0 h1:ns5+1hjY8hntGv8bPaQd9Gr7Jyo+Uw5SLyII40aQdtA= github.com/go-pkgz/expirable-cache v1.0.0 h1:ns5+1hjY8hntGv8bPaQd9Gr7Jyo+Uw5SLyII40aQdtA=
github.com/go-pkgz/expirable-cache v1.0.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs= github.com/go-pkgz/expirable-cache v1.0.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs=
github.com/go-pkgz/lgr v0.11.1 h1:hXFhZcznehI6imLhEa379oMOKFz7TQUmisAqb3oLOSM= github.com/go-pkgz/lgr v0.11.1 h1:hXFhZcznehI6imLhEa379oMOKFz7TQUmisAqb3oLOSM=
@@ -22,13 +23,11 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU=
@@ -43,7 +42,6 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
@@ -57,18 +55,13 @@ golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,7 +1,7 @@
linters: linters:
enable: enable:
- megacheck - megacheck
- golint - revive
- govet - govet
- unconvert - unconvert
- megacheck - megacheck
@@ -18,7 +18,7 @@ linters:
- varcheck - varcheck
- stylecheck - stylecheck
- gochecknoinits - gochecknoinits
- scopelint - exportloopref
- gocritic - gocritic
- nakedret - nakedret
- gosimple - gosimple

View File

@@ -23,6 +23,8 @@ This is a generic middleware to rate-limit HTTP requests.
**v6.x.x:** Replaced `go-cache` with `github.com/go-pkgz/expirable-cache` because `go-cache` leaks goroutines. **v6.x.x:** Replaced `go-cache` with `github.com/go-pkgz/expirable-cache` because `go-cache` leaks goroutines.
**v7.x.x:** Replaced `time/rate` with `embedded time/rate` so that we can support more rate limit headers.
## Five Minute Tutorial ## Five Minute Tutorial
```go ```go
@@ -31,7 +33,7 @@ package main
import ( import (
"net/http" "net/http"
"github.com/didip/tollbooth/v6" "github.com/didip/tollbooth/v7"
) )
func HelloHandler(w http.ResponseWriter, req *http.Request) { func HelloHandler(w http.ResponseWriter, req *http.Request) {
@@ -52,8 +54,8 @@ func main() {
import ( import (
"time" "time"
"github.com/didip/tollbooth/v6" "github.com/didip/tollbooth/v7"
"github.com/didip/tollbooth/v6/limiter" "github.com/didip/tollbooth/v7/limiter"
) )
lmt := tollbooth.NewLimiter(1, nil) lmt := tollbooth.NewLimiter(1, nil)
@@ -122,6 +124,13 @@ func main() {
* `X-Rate-Limit-Request-Remote-Addr` The rejected request `RemoteAddr`. * `X-Rate-Limit-Request-Remote-Addr` The rejected request `RemoteAddr`.
Upon both success and rejection [RateLimit](https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers) headers are sent:
* `RateLimit-Limit` The maximum request limit within the time window (1s).
* `RateLimit-Reset` The rate-limiter time window duration in seconds (always 1s).
* `RateLimit-Remaining` The remaining tokens.
5. Customize your own message or function when limit is reached. 5. Customize your own message or function when limit is reached.

View File

@@ -0,0 +1,3 @@
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at http://tip.golang.org/AUTHORS.

View File

@@ -0,0 +1,3 @@
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at http://tip.golang.org/CONTRIBUTORS.

View File

@@ -52,8 +52,6 @@ func Every(interval time.Duration) Limit {
// or its associated context.Context is canceled. // or its associated context.Context is canceled.
// //
// The methods AllowN, ReserveN, and WaitN consume n tokens. // The methods AllowN, ReserveN, and WaitN consume n tokens.
//
// Limiter is safe for simultaneous use by multiple goroutines.
type Limiter struct { type Limiter struct {
mu sync.Mutex mu sync.Mutex
limit Limit limit Limit
@@ -82,19 +80,6 @@ func (lim *Limiter) Burst() int {
return lim.burst return lim.burst
} }
// TokensAt returns the number of tokens available at time t.
func (lim *Limiter) TokensAt(t time.Time) float64 {
lim.mu.Lock()
_, tokens := lim.advance(t) // does not mutate lim
lim.mu.Unlock()
return tokens
}
// Tokens returns the number of tokens available now.
func (lim *Limiter) Tokens() float64 {
return lim.TokensAt(time.Now())
}
// NewLimiter returns a new Limiter that allows events up to rate r and permits // NewLimiter returns a new Limiter that allows events up to rate r and permits
// bursts of at most b tokens. // bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter { func NewLimiter(r Limit, b int) *Limiter {
@@ -104,16 +89,24 @@ func NewLimiter(r Limit, b int) *Limiter {
} }
} }
// Allow reports whether an event may happen now. // Allow is shorthand for AllowN(time.Now(), 1).
func (lim *Limiter) Allow() bool { func (lim *Limiter) Allow() bool {
return lim.AllowN(time.Now(), 1) return lim.AllowN(time.Now(), 1)
} }
// AllowN reports whether n events may happen at time t. // TokensAt returns the number of tokens available for the given time.
func (lim *Limiter) TokensAt(t time.Time) float64 {
lim.mu.Lock()
_, _, tokens := lim.advance(t) // does not mutate lim
lim.mu.Unlock()
return tokens
}
// AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate limit. // Use this method if you intend to drop / skip events that exceed the rate limit.
// Otherwise use Reserve or Wait. // Otherwise use Reserve or Wait.
func (lim *Limiter) AllowN(t time.Time, n int) bool { func (lim *Limiter) AllowN(now time.Time, n int) bool {
return lim.reserveN(t, n, 0).ok return lim.reserveN(now, n, 0).ok
} }
// A Reservation holds information about events that are permitted by a Limiter to happen after a delay. // A Reservation holds information about events that are permitted by a Limiter to happen after a delay.
@@ -140,17 +133,17 @@ func (r *Reservation) Delay() time.Duration {
} }
// InfDuration is the duration returned by Delay when a Reservation is not OK. // InfDuration is the duration returned by Delay when a Reservation is not OK.
const InfDuration = time.Duration(math.MaxInt64) const InfDuration = time.Duration(1<<63 - 1)
// DelayFrom returns the duration for which the reservation holder must wait // DelayFrom returns the duration for which the reservation holder must wait
// before taking the reserved action. Zero duration means act immediately. // before taking the reserved action. Zero duration means act immediately.
// InfDuration means the limiter cannot grant the tokens requested in this // InfDuration means the limiter cannot grant the tokens requested in this
// Reservation within the maximum wait time. // Reservation within the maximum wait time.
func (r *Reservation) DelayFrom(t time.Time) time.Duration { func (r *Reservation) DelayFrom(now time.Time) time.Duration {
if !r.ok { if !r.ok {
return InfDuration return InfDuration
} }
delay := r.timeToAct.Sub(t) delay := r.timeToAct.Sub(now)
if delay < 0 { if delay < 0 {
return 0 return 0
} }
@@ -165,7 +158,7 @@ func (r *Reservation) Cancel() {
// CancelAt indicates that the reservation holder will not perform the reserved action // CancelAt indicates that the reservation holder will not perform the reserved action
// and reverses the effects of this Reservation on the rate limit as much as possible, // and reverses the effects of this Reservation on the rate limit as much as possible,
// considering that other reservations may have already been made. // considering that other reservations may have already been made.
func (r *Reservation) CancelAt(t time.Time) { func (r *Reservation) CancelAt(now time.Time) {
if !r.ok { if !r.ok {
return return
} }
@@ -173,7 +166,7 @@ func (r *Reservation) CancelAt(t time.Time) {
r.lim.mu.Lock() r.lim.mu.Lock()
defer r.lim.mu.Unlock() defer r.lim.mu.Unlock()
if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(t) { if r.lim.limit == Inf || r.tokens == 0 || r.timeToAct.Before(now) {
return return
} }
@@ -185,18 +178,18 @@ func (r *Reservation) CancelAt(t time.Time) {
return return
} }
// advance time to now // advance time to now
t, tokens := r.lim.advance(t) now, _, tokens := r.lim.advance(now)
// calculate new number of tokens // calculate new number of tokens
tokens += restoreTokens tokens += restoreTokens
if burst := float64(r.lim.burst); tokens > burst { if burst := float64(r.lim.burst); tokens > burst {
tokens = burst tokens = burst
} }
// update state // update state
r.lim.last = t r.lim.last = now
r.lim.tokens = tokens r.lim.tokens = tokens
if r.timeToAct == r.lim.lastEvent { if r.timeToAct == r.lim.lastEvent {
prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens))) prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens)))
if !prevEvent.Before(t) { if !prevEvent.Before(now) {
r.lim.lastEvent = prevEvent r.lim.lastEvent = prevEvent
} }
} }
@@ -211,20 +204,18 @@ func (lim *Limiter) Reserve() *Reservation {
// The Limiter takes this Reservation into account when allowing future events. // The Limiter takes this Reservation into account when allowing future events.
// The returned Reservations OK() method returns false if n exceeds the Limiter's burst size. // The returned Reservations OK() method returns false if n exceeds the Limiter's burst size.
// Usage example: // Usage example:
// // r := lim.ReserveN(time.Now(), 1)
// r := lim.ReserveN(time.Now(), 1) // if !r.OK() {
// if !r.OK() { // // Not allowed to act! Did you remember to set lim.burst to be > 0 ?
// // Not allowed to act! Did you remember to set lim.burst to be > 0 ? // return
// return // }
// } // time.Sleep(r.Delay())
// time.Sleep(r.Delay()) // Act()
// Act()
//
// Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events. // Use this method if you wish to wait and slow down in accordance with the rate limit without dropping events.
// If you need to respect a deadline or cancel the delay, use Wait instead. // If you need to respect a deadline or cancel the delay, use Wait instead.
// To drop or skip events exceeding rate limit, use Allow instead. // To drop or skip events exceeding rate limit, use Allow instead.
func (lim *Limiter) ReserveN(t time.Time, n int) *Reservation { func (lim *Limiter) ReserveN(now time.Time, n int) *Reservation {
r := lim.reserveN(t, n, InfDuration) r := lim.reserveN(now, n, InfDuration)
return &r return &r
} }
@@ -238,18 +229,6 @@ func (lim *Limiter) Wait(ctx context.Context) (err error) {
// canceled, or the expected wait time exceeds the Context's Deadline. // canceled, or the expected wait time exceeds the Context's Deadline.
// The burst limit is ignored if the rate limit is Inf. // The burst limit is ignored if the rate limit is Inf.
func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) { func (lim *Limiter) WaitN(ctx context.Context, n int) (err error) {
// The test code calls lim.wait with a fake timer generator.
// This is the real timer generator.
newTimer := func(d time.Duration) (<-chan time.Time, func() bool, func()) {
timer := time.NewTimer(d)
return timer.C, timer.Stop, func() {}
}
return lim.wait(ctx, n, time.Now(), newTimer)
}
// wait is the internal implementation of WaitN.
func (lim *Limiter) wait(ctx context.Context, n int, t time.Time, newTimer func(d time.Duration) (<-chan time.Time, func() bool, func())) error {
lim.mu.Lock() lim.mu.Lock()
burst := lim.burst burst := lim.burst
limit := lim.limit limit := lim.limit
@@ -265,25 +244,25 @@ func (lim *Limiter) wait(ctx context.Context, n int, t time.Time, newTimer func(
default: default:
} }
// Determine wait limit // Determine wait limit
now := time.Now()
waitLimit := InfDuration waitLimit := InfDuration
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
waitLimit = deadline.Sub(t) waitLimit = deadline.Sub(now)
} }
// Reserve // Reserve
r := lim.reserveN(t, n, waitLimit) r := lim.reserveN(now, n, waitLimit)
if !r.ok { if !r.ok {
return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n) return fmt.Errorf("rate: Wait(n=%d) would exceed context deadline", n)
} }
// Wait if necessary // Wait if necessary
delay := r.DelayFrom(t) delay := r.DelayFrom(now)
if delay == 0 { if delay == 0 {
return nil return nil
} }
ch, stop, advance := newTimer(delay) t := time.NewTimer(delay)
defer stop() defer t.Stop()
advance() // only has an effect when testing
select { select {
case <-ch: case <-t.C:
// We can proceed. // We can proceed.
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
@@ -302,13 +281,13 @@ func (lim *Limiter) SetLimit(newLimit Limit) {
// SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated // SetLimitAt sets a new Limit for the limiter. The new Limit, and Burst, may be violated
// or underutilized by those which reserved (using Reserve or Wait) but did not yet act // or underutilized by those which reserved (using Reserve or Wait) but did not yet act
// before SetLimitAt was called. // before SetLimitAt was called.
func (lim *Limiter) SetLimitAt(t time.Time, newLimit Limit) { func (lim *Limiter) SetLimitAt(now time.Time, newLimit Limit) {
lim.mu.Lock() lim.mu.Lock()
defer lim.mu.Unlock() defer lim.mu.Unlock()
t, tokens := lim.advance(t) now, _, tokens := lim.advance(now)
lim.last = t lim.last = now
lim.tokens = tokens lim.tokens = tokens
lim.limit = newLimit lim.limit = newLimit
} }
@@ -319,13 +298,13 @@ func (lim *Limiter) SetBurst(newBurst int) {
} }
// SetBurstAt sets a new burst size for the limiter. // SetBurstAt sets a new burst size for the limiter.
func (lim *Limiter) SetBurstAt(t time.Time, newBurst int) { func (lim *Limiter) SetBurstAt(now time.Time, newBurst int) {
lim.mu.Lock() lim.mu.Lock()
defer lim.mu.Unlock() defer lim.mu.Unlock()
t, tokens := lim.advance(t) now, _, tokens := lim.advance(now)
lim.last = t lim.last = now
lim.tokens = tokens lim.tokens = tokens
lim.burst = newBurst lim.burst = newBurst
} }
@@ -333,32 +312,20 @@ func (lim *Limiter) SetBurstAt(t time.Time, newBurst int) {
// reserveN is a helper method for AllowN, ReserveN, and WaitN. // reserveN is a helper method for AllowN, ReserveN, and WaitN.
// maxFutureReserve specifies the maximum reservation wait duration allowed. // maxFutureReserve specifies the maximum reservation wait duration allowed.
// reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN. // reserveN returns Reservation, not *Reservation, to avoid allocation in AllowN and WaitN.
func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) Reservation { func (lim *Limiter) reserveN(now time.Time, n int, maxFutureReserve time.Duration) Reservation {
lim.mu.Lock() lim.mu.Lock()
defer lim.mu.Unlock()
if lim.limit == Inf { if lim.limit == Inf {
lim.mu.Unlock()
return Reservation{ return Reservation{
ok: true, ok: true,
lim: lim, lim: lim,
tokens: n, tokens: n,
timeToAct: t, timeToAct: now,
}
} else if lim.limit == 0 {
var ok bool
if lim.burst >= n {
ok = true
lim.burst -= n
}
return Reservation{
ok: ok,
lim: lim,
tokens: lim.burst,
timeToAct: t,
} }
} }
t, tokens := lim.advance(t) now, last, tokens := lim.advance(now)
// Calculate the remaining number of tokens resulting from the request. // Calculate the remaining number of tokens resulting from the request.
tokens -= float64(n) tokens -= float64(n)
@@ -380,42 +347,44 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
} }
if ok { if ok {
r.tokens = n r.tokens = n
r.timeToAct = t.Add(waitDuration) r.timeToAct = now.Add(waitDuration)
// Update state
lim.last = t
lim.tokens = tokens
lim.lastEvent = r.timeToAct
} }
// Update state
if ok {
lim.last = now
lim.tokens = tokens
lim.lastEvent = r.timeToAct
} else {
lim.last = last
}
lim.mu.Unlock()
return r return r
} }
// advance calculates and returns an updated state for lim resulting from the passage of time. // advance calculates and returns an updated state for lim resulting from the passage of time.
// lim is not changed. // lim is not changed.
// advance requires that lim.mu is held. // advance requires that lim.mu is held.
func (lim *Limiter) advance(t time.Time) (newT time.Time, newTokens float64) { func (lim *Limiter) advance(now time.Time) (newNow time.Time, newLast time.Time, newTokens float64) {
last := lim.last last := lim.last
if t.Before(last) { if now.Before(last) {
last = t last = now
} }
// Calculate the new number of tokens, due to time that passed. // Calculate the new number of tokens, due to time that passed.
elapsed := t.Sub(last) elapsed := now.Sub(last)
delta := lim.limit.tokensFromDuration(elapsed) delta := lim.limit.tokensFromDuration(elapsed)
tokens := lim.tokens + delta tokens := lim.tokens + delta
if burst := float64(lim.burst); tokens > burst { if burst := float64(lim.burst); tokens > burst {
tokens = burst tokens = burst
} }
return t, tokens return now, last, tokens
} }
// durationFromTokens is a unit conversion function from the number of tokens to the duration // durationFromTokens is a unit conversion function from the number of tokens to the duration
// of time it takes to accumulate them at a rate of limit tokens per second. // of time it takes to accumulate them at a rate of limit tokens per second.
func (limit Limit) durationFromTokens(tokens float64) time.Duration { func (limit Limit) durationFromTokens(tokens float64) time.Duration {
if limit <= 0 {
return InfDuration
}
seconds := tokens / float64(limit) seconds := tokens / float64(limit)
return time.Duration(float64(time.Second) * seconds) return time.Duration(float64(time.Second) * seconds)
} }
@@ -423,8 +392,5 @@ func (limit Limit) durationFromTokens(tokens float64) time.Duration {
// tokensFromDuration is a unit conversion function from a time duration to the number of tokens // tokensFromDuration is a unit conversion function from a time duration to the number of tokens
// which could be accumulated during that duration at a rate of limit tokens per second. // which could be accumulated during that duration at a rate of limit tokens per second.
func (limit Limit) tokensFromDuration(d time.Duration) float64 { func (limit Limit) tokensFromDuration(d time.Duration) float64 {
if limit <= 0 {
return 0
}
return d.Seconds() * float64(limit) return d.Seconds() * float64(limit)
} }

View File

@@ -69,7 +69,6 @@ func CanonicalizeIP(ip string) string {
case ':': case ':':
// IPv6 // IPv6
isIPv6 = true isIPv6 = true
break
} }
} }
if !isIPv6 { if !isIPv6 {

View File

@@ -7,7 +7,8 @@ import (
"time" "time"
cache "github.com/go-pkgz/expirable-cache" cache "github.com/go-pkgz/expirable-cache"
"golang.org/x/time/rate"
"github.com/didip/tollbooth/v7/internal/time/rate"
) )
// New is a constructor for Limiter. // New is a constructor for Limiter.
@@ -597,3 +598,13 @@ func (l *Limiter) LimitReached(key string) bool {
return l.limitReachedWithTokenBucketTTL(key, ttl) return l.limitReachedWithTokenBucketTTL(key, ttl)
} }
// Tokens returns current amount of tokens left in the Bucket identified by key.
func (l *Limiter) Tokens(key string) int {
expiringMap, found := l.tokenBuckets.Get(key)
if !found {
return 0
}
return int(expiringMap.(*rate.Limiter).TokensAt(time.Now()))
}

View File

@@ -7,9 +7,9 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/didip/tollbooth/v6/errors" "github.com/didip/tollbooth/v7/errors"
"github.com/didip/tollbooth/v6/libstring" "github.com/didip/tollbooth/v7/libstring"
"github.com/didip/tollbooth/v6/limiter" "github.com/didip/tollbooth/v7/limiter"
) )
// setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration // setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration
@@ -25,6 +25,14 @@ func setResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Req
w.Header().Add("X-Rate-Limit-Request-Remote-Addr", r.RemoteAddr) w.Header().Add("X-Rate-Limit-Request-Remote-Addr", r.RemoteAddr)
} }
// setRateLimitResponseHeaders configures RateLimit-Limit, RateLimit-Remaining and RateLimit-Reset
// as seen at https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers
func setRateLimitResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, tokensLeft int) {
w.Header().Add("RateLimit-Limit", fmt.Sprintf("%d", int(math.Round(lmt.GetMax()))))
w.Header().Add("RateLimit-Reset", "1")
w.Header().Add("RateLimit-Remaining", fmt.Sprintf("%d", tokensLeft))
}
// NewLimiter is a convenience function to limiter.New. // NewLimiter is a convenience function to limiter.New.
func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter { func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter {
return limiter.New(tbOptions). return limiter.New(tbOptions).
@@ -36,11 +44,18 @@ func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limit
// LimitByKeys keeps track number of request made by keys separated by pipe. // LimitByKeys keeps track number of request made by keys separated by pipe.
// It returns HTTPError when limit is exceeded. // It returns HTTPError when limit is exceeded.
func LimitByKeys(lmt *limiter.Limiter, keys []string) *errors.HTTPError { func LimitByKeys(lmt *limiter.Limiter, keys []string) *errors.HTTPError {
err, _ := LimitByKeysAndReturn(lmt, keys)
return err
}
// LimitByKeysAndReturn keeps track number of request made by keys separated by pipe.
// It returns HTTPError when limit is exceeded, and also returns the current limit value.
func LimitByKeysAndReturn(lmt *limiter.Limiter, keys []string) (*errors.HTTPError, int) {
if lmt.LimitReached(strings.Join(keys, "|")) { if lmt.LimitReached(strings.Join(keys, "|")) {
return &errors.HTTPError{Message: lmt.GetMessage(), StatusCode: lmt.GetStatusCode()} return &errors.HTTPError{Message: lmt.GetMessage(), StatusCode: lmt.GetStatusCode()}, 0
} }
return nil return nil, lmt.Tokens(strings.Join(keys, "|"))
} }
// ShouldSkipLimiter is a series of filter that decides if request should be limited or not. // ShouldSkipLimiter is a series of filter that decides if request should be limited or not.
@@ -97,6 +112,10 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool {
requestHeadersDefinedInLimiter = false requestHeadersDefinedInLimiter = false
for headerKey, headerValues := range lmtHeaders { for headerKey, headerValues := range lmtHeaders {
if len(headerValues) == 0 {
requestHeadersDefinedInLimiter = true
continue
}
for _, headerValue := range headerValues { for _, headerValue := range headerValues {
if r.Header.Get(headerKey) == headerValue { if r.Header.Get(headerKey) == headerValue {
requestHeadersDefinedInLimiter = true requestHeadersDefinedInLimiter = true
@@ -281,14 +300,24 @@ func LimitByRequest(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request
sliceKeys := BuildKeys(lmt, r) sliceKeys := BuildKeys(lmt, r)
// Get the lowest value over all keys to return in headers.
// Start with high arbitrary number so that any limit returned would be lower and would
// overwrite the value we start with.
var tokensLeft = math.MaxInt32
// Loop sliceKeys and check if one of them has error. // Loop sliceKeys and check if one of them has error.
for _, keys := range sliceKeys { for _, keys := range sliceKeys {
httpError := LimitByKeys(lmt, keys) httpError, keysLimit := LimitByKeysAndReturn(lmt, keys)
if tokensLeft > keysLimit {
tokensLeft = keysLimit
}
if httpError != nil { if httpError != nil {
setRateLimitResponseHeaders(lmt, w, tokensLeft)
return httpError return httpError
} }
} }
setRateLimitResponseHeaders(lmt, w, tokensLeft)
return nil return nil
} }

View File

@@ -1,67 +0,0 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rate
import (
"sync"
"time"
)
// Sometimes will perform an action occasionally. The First, Every, and
// Interval fields govern the behavior of Do, which performs the action.
// A zero Sometimes value will perform an action exactly once.
//
// # Example: logging with rate limiting
//
// var sometimes = rate.Sometimes{First: 3, Interval: 10*time.Second}
// func Spammy() {
// sometimes.Do(func() { log.Info("here I am!") })
// }
type Sometimes struct {
First int // if non-zero, the first N calls to Do will run f.
Every int // if non-zero, every Nth call to Do will run f.
Interval time.Duration // if non-zero and Interval has elapsed since f's last run, Do will run f.
mu sync.Mutex
count int // number of Do calls
last time.Time // last time f was run
}
// Do runs the function f as allowed by First, Every, and Interval.
//
// The model is a union (not intersection) of filters. The first call to Do
// always runs f. Subsequent calls to Do run f if allowed by First or Every or
// Interval.
//
// A non-zero First:N causes the first N Do(f) calls to run f.
//
// A non-zero Every:M causes every Mth Do(f) call, starting with the first, to
// run f.
//
// A non-zero Interval causes Do(f) to run f if Interval has elapsed since
// Do last ran f.
//
// Specifying multiple filters produces the union of these execution streams.
// For example, specifying both First:N and Every:M causes the first N Do(f)
// calls and every Mth Do(f) call, starting with the first, to run f. See
// Examples for more.
//
// If Do is called multiple times simultaneously, the calls will block and run
// serially. Therefore, Do is intended for lightweight operations.
//
// Because a call to Do may block until f returns, if f causes Do to be called,
// it will deadlock.
func (s *Sometimes) Do(f func()) {
s.mu.Lock()
defer s.mu.Unlock()
if s.count == 0 ||
(s.First > 0 && s.count < s.First) ||
(s.Every > 0 && s.count%s.Every == 0) ||
(s.Interval > 0 && time.Since(s.last) >= s.Interval) {
f()
s.last = time.Now()
}
s.count++
}

16
vendor/modules.txt vendored
View File

@@ -7,12 +7,13 @@ github.com/cespare/xxhash/v2
# github.com/davecgh/go-spew v1.1.1 # github.com/davecgh/go-spew v1.1.1
## explicit ## explicit
github.com/davecgh/go-spew/spew github.com/davecgh/go-spew/spew
# github.com/didip/tollbooth/v6 v6.1.2 # github.com/didip/tollbooth/v7 v7.0.1
## explicit; go 1.12 ## explicit; go 1.12
github.com/didip/tollbooth/v6 github.com/didip/tollbooth/v7
github.com/didip/tollbooth/v6/errors github.com/didip/tollbooth/v7/errors
github.com/didip/tollbooth/v6/libstring github.com/didip/tollbooth/v7/internal/time/rate
github.com/didip/tollbooth/v6/limiter github.com/didip/tollbooth/v7/libstring
github.com/didip/tollbooth/v7/limiter
# github.com/felixge/httpsnoop v1.0.4 # github.com/felixge/httpsnoop v1.0.4
## explicit; go 1.13 ## explicit; go 1.13
github.com/felixge/httpsnoop github.com/felixge/httpsnoop
@@ -34,6 +35,8 @@ github.com/go-pkgz/rest/realip
# github.com/gorilla/handlers v1.5.2 # github.com/gorilla/handlers v1.5.2
## explicit; go 1.20 ## explicit; go 1.20
github.com/gorilla/handlers github.com/gorilla/handlers
# github.com/kr/text v0.1.0
## explicit
# github.com/pmezard/go-difflib v1.0.0 # github.com/pmezard/go-difflib v1.0.0
## explicit ## explicit
github.com/pmezard/go-difflib/difflib github.com/pmezard/go-difflib/difflib
@@ -81,9 +84,6 @@ golang.org/x/text/secure/bidirule
golang.org/x/text/transform golang.org/x/text/transform
golang.org/x/text/unicode/bidi golang.org/x/text/unicode/bidi
golang.org/x/text/unicode/norm golang.org/x/text/unicode/norm
# golang.org/x/time v0.5.0
## explicit; go 1.18
golang.org/x/time/rate
# google.golang.org/protobuf v1.33.0 # google.golang.org/protobuf v1.33.0
## explicit; go 1.17 ## explicit; go 1.17
google.golang.org/protobuf/encoding/protodelim google.golang.org/protobuf/encoding/protodelim