make LBSelector interface and implement all the current methods plus roundrobin

This commit is contained in:
Umputun
2023-11-27 12:05:17 -06:00
parent 8bde167226
commit fa23778d42
6 changed files with 184 additions and 13 deletions

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"math"
"math/rand"
"net/http"
"net/rpc"
"os"
@@ -36,7 +35,7 @@ var opts struct {
DropHeaders []string `long:"drop-header" env:"DROP_HEADERS" description:"incoming headers to drop" env-delim:","`
AuthBasicHtpasswd string `long:"basic-htpasswd" env:"BASIC_HTPASSWD" description:"htpasswd file for basic auth"`
RemoteLookupHeaders bool `long:"remote-lookup-headers" env:"REMOTE_LOOKUP_HEADERS" description:"enable remote lookup headers"`
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" default:"random"` // nolint
LBType string `long:"lb-type" env:"LB_TYPE" description:"load balancer type" choice:"random" choice:"failover" choice:"roundrobin" default:"random"` // nolint
SSL struct {
Type string `long:"type" env:"TYPE" description:"ssl (auto) support" choice:"none" choice:"static" choice:"auto" default:"none"` // nolint
@@ -414,14 +413,16 @@ func makeSSLConfig() (config proxy.SSLConfig, err error) {
return config, err
}
func makeLBSelector() func(len int) int {
func makeLBSelector() proxy.LBSelector {
switch opts.LBType {
case "random":
return rand.Intn
return &proxy.RandomSelector{}
case "failover":
return func(int) int { return 0 } // dead server won't be in the list, we can safely pick the first one
return &proxy.FailoverSelector{}
case "roundrobin":
return &proxy.RoundRobinSelector{}
default:
return func(int) int { return 0 }
return &proxy.FailoverSelector{}
}
}

45
app/proxy/lb_selector.go Normal file
View File

@@ -0,0 +1,45 @@
package proxy
import (
"math/rand"
"sync"
)
// RoundRobinSelector is a simple round-robin selector, thread-safe
type RoundRobinSelector struct {
lastSelected int
mu sync.Mutex
}
// Select returns next backend index
func (r *RoundRobinSelector) Select(n int) int {
r.mu.Lock()
defer r.mu.Unlock()
selected := r.lastSelected
r.lastSelected = (r.lastSelected + 1) % n
return selected
}
// RandomSelector is a random selector, thread-safe
type RandomSelector struct{}
// Select returns random backend index
func (r *RandomSelector) Select(n int) int {
return rand.Intn(n) //nolint:gosec // no need for crypto/rand here
}
// FailoverSelector is a selector with failover, thread-safe
type FailoverSelector struct{}
// Select returns next backend index
func (r *FailoverSelector) Select(_ int) int {
return 0 // dead server won't be in the list, we can safely pick the first one
}
// LBSelectorFunc is a functional adapted for LBSelector to select backend from the list
type LBSelectorFunc func(n int) int
// Select returns backend index
func (f LBSelectorFunc) Select(n int) int {
return f(n)
}

View File

@@ -0,0 +1,121 @@
package proxy
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRoundRobinSelector_Select(t *testing.T) {
selector := &RoundRobinSelector{}
testCases := []struct {
name string
len int
expected int
}{
{"First call", 3, 0},
{"Second call", 3, 1},
{"Third call", 3, 2},
{"Back to zero", 3, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.Equal(t, tc.expected, result)
})
}
}
func TestRoundRobinSelector_SelectConcurrent(t *testing.T) {
selector := &RoundRobinSelector{}
l := 3
numGoroutines := 1000
var wg sync.WaitGroup
wg.Add(numGoroutines)
results := &sync.Map{}
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
result := selector.Select(l)
results.Store(result, struct{}{})
}()
}
wg.Wait()
// check that all possible results are present in the map.
for i := 0; i < l; i++ {
_, ok := results.Load(i)
assert.True(t, ok, "expected to find %d in the results", i)
}
}
func TestRandomSelector_Select(t *testing.T) {
selector := &RandomSelector{}
testCases := []struct {
name string
len int
}{
{"First call", 5},
{"Second call", 5},
{"Third call", 5},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.True(t, result >= 0 && result < tc.len)
})
}
}
func TestFailoverSelector_Select(t *testing.T) {
selector := &FailoverSelector{}
testCases := []struct {
name string
len int
expected int
}{
{"First call", 5, 0},
{"Second call", 5, 0},
{"Third call", 5, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.Equal(t, tc.expected, result)
})
}
}
func TestLBSelectorFunc_Select(t *testing.T) {
selector := LBSelectorFunc(func(n int) int {
return n - 1 // simple selection logic for testing
})
testCases := []struct {
name string
len int
expected int
}{
{"First call", 5, 4},
{"Second call", 3, 2},
{"Third call", 1, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := selector.Select(tc.len)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
@@ -47,7 +46,7 @@ type Http struct { // nolint golint
Metrics MiddlewareProvider
PluginConductor MiddlewareProvider
Reporter Reporter
LBSelector func(len int) int
LBSelector LBSelector
OnlyFrom *OnlyFrom
BasicAuthEnabled bool
BasicAuthAllowed []string
@@ -75,6 +74,11 @@ type Reporter interface {
Report(w http.ResponseWriter, code int)
}
// LBSelector defines load balancer strategy
type LBSelector interface {
Select(len int) int // return index of picked server
}
// Timeouts consolidate timeouts for both server and transport
type Timeouts struct {
// server timeouts
@@ -101,7 +105,7 @@ func (h *Http) Run(ctx context.Context) error {
}
if h.LBSelector == nil {
h.LBSelector = rand.Intn
h.LBSelector = &RandomSelector{}
}
var httpServer, httpsServer *http.Server
@@ -277,7 +281,7 @@ func (h *Http) proxyHandler() http.HandlerFunc {
// and if match found sets it to the request context. Context used by proxy handler as well as by plugin conductor
func (h *Http) matchHandler(next http.Handler) http.Handler {
getMatch := func(mm discovery.Matches, picker func(len int) int) (m discovery.MatchedRoute, ok bool) {
getMatch := func(mm discovery.Matches, picker LBSelector) (m discovery.MatchedRoute, ok bool) {
if len(mm.Routes) == 0 {
return m, false
}
@@ -294,7 +298,7 @@ func (h *Http) matchHandler(next http.Handler) http.Handler {
case 1:
return matches[0], true
default:
return matches[picker(len(matches))], true
return matches[picker.Select(len(matches))], true
}
}

View File

@@ -874,7 +874,7 @@ func TestHttp_matchHandler(t *testing.T) {
client := http.Client{}
for _, tt := range tbl {
t.Run(tt.name, func(t *testing.T) {
h := Http{Matcher: matcherMock, LBSelector: func(len int) int { return 0 }}
h := Http{Matcher: matcherMock, LBSelector: &FailoverSelector{}}
handler := h.matchHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("req: %+v", r)
t.Logf("dst: %v", r.Context().Value(ctxURL))