Merge pull request #9608 from hashicorp/f-go-connlimit

Use go-connlimit to ratelimit with 429 responses
This commit is contained in:
Seth Hoenig
2020-12-10 11:05:07 -06:00
committed by GitHub
9 changed files with 486 additions and 33 deletions

View File

@@ -40,6 +40,10 @@ const (
// MissingRequestID is a placeholder if we cannot retrieve a request
// UUID from context
MissingRequestID = "<missing request id>"
// HTTPConnStateFuncWriteTimeout is how long to try to write conn state errors
// before closing the connection
HTTPConnStateFuncWriteTimeout = 10 * time.Millisecond
)
var (
@@ -171,7 +175,7 @@ func makeConnState(isTLS bool, handshakeTimeout time.Duration, connLimit int) fu
// Still return the connection limiter
return connlimit.NewLimiter(connlimit.Config{
MaxConnsPerClientIP: connLimit,
}).HTTPConnStateFunc()
}).HTTPConnStateFuncWithDefault429Handler(HTTPConnStateFuncWriteTimeout)
}
return nil
@@ -183,7 +187,7 @@ func makeConnState(isTLS bool, handshakeTimeout time.Duration, connLimit int) fu
connLimiter := connlimit.NewLimiter(connlimit.Config{
MaxConnsPerClientIP: connLimit,
}).HTTPConnStateFunc()
}).HTTPConnStateFuncWithDefault429Handler(HTTPConnStateFuncWriteTimeout)
return func(conn net.Conn, state http.ConnState) {
switch state {

View File

@@ -14,6 +14,7 @@ import (
"net/http/httptest"
"net/url"
"os"
"strconv"
"strings"
"testing"
"time"
@@ -869,15 +870,24 @@ func TestHTTPServer_Limits_Error(t *testing.T) {
}
}
func limitStr(limit *int) string {
if limit == nil {
return "none"
}
return strconv.Itoa(*limit)
}
// TestHTTPServer_Limits_OK asserts that all valid limits combinations
// (tls/timeout/conns) work.
func TestHTTPServer_Limits_OK(t *testing.T) {
t.Parallel()
const (
cafile = "../../helper/tlsutil/testdata/ca.pem"
foocert = "../../helper/tlsutil/testdata/nomad-foo.pem"
fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
maxConns = 10 // limit must be < this for testing
maxConns = 10 // limit must be < this for testing
bufSize = 1 * 1024 // enough for 429 error message
)
cases := []struct {
@@ -954,11 +964,14 @@ func TestHTTPServer_Limits_OK(t *testing.T) {
conn, err := net.DialTimeout("tcp", a.Server.Addr, deadline)
require.NoError(t, err)
defer conn.Close()
defer func() {
require.NoError(t, conn.Close())
}()
buf := []byte{0}
readDeadline := time.Now().Add(deadline)
conn.SetReadDeadline(readDeadline)
err = conn.SetReadDeadline(readDeadline)
require.NoError(t, err)
n, err := conn.Read(buf)
require.Zero(t, n)
if assertTimeout {
@@ -1011,12 +1024,12 @@ func TestHTTPServer_Limits_OK(t *testing.T) {
for i := 0; i < maxConns; i++ {
conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
require.NoError(t, err)
defer conns[i].Close()
go func(i int) {
buf := []byte{0}
readDeadline := time.Now().Add(1 * time.Second)
conns[i].SetReadDeadline(readDeadline)
err = conns[i].SetReadDeadline(readDeadline)
require.NoError(t, err)
n, err := conns[i].Read(buf)
if n > 0 {
errCh <- fmt.Errorf("n > 0: %d", n)
@@ -1036,18 +1049,37 @@ func TestHTTPServer_Limits_OK(t *testing.T) {
"error does not wrap os.ErrDeadlineExceeded: (%T) %v", err, err)
}
}
for i := 0; i < maxConns; i++ {
require.NoError(t, conns[i].Close())
}
}
assertLimit := func(t *testing.T, addr string, limit int) {
dial := func(t *testing.T, addr string, useTLS bool) net.Conn {
if useTLS {
cert, err := tls.LoadX509KeyPair(foocert, fookey)
require.NoError(t, err)
conn, err := tls.Dial("tcp", addr, &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true, // good enough
})
require.NoError(t, err)
return conn
} else {
conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
require.NoError(t, err)
return conn
}
}
assertLimit := func(t *testing.T, addr string, limit int, useTLS bool) {
var err error
// Create limit connections
conns := make([]net.Conn, limit)
errCh := make(chan error, limit)
for i := range conns {
conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
require.NoError(t, err)
defer conns[i].Close()
conns[i] = dial(t, addr, useTLS)
go func(i int) {
buf := []byte{0}
@@ -1067,26 +1099,30 @@ func TestHTTPServer_Limits_OK(t *testing.T) {
}
// Assert a new connection is dropped
conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
require.NoError(t, err)
defer conn.Close()
conn := dial(t, addr, useTLS)
defer func() {
require.NoError(t, conn.Close())
}()
buf := []byte{0}
deadline := time.Now().Add(10 * time.Second)
conn.SetReadDeadline(deadline)
n, err := conn.Read(buf)
require.Zero(t, n)
require.NoError(t, conn.SetReadDeadline(deadline))
// Soft-fail as following assertion helps with debugging
assert.Equal(t, io.EOF, err)
buf := make([]byte, bufSize)
n, err := conn.Read(buf)
require.NoError(t, err)
require.NotZero(t, n)
require.True(t, strings.HasPrefix(string(buf), "HTTP/1.1 429 Too Many Requests"))
// Assert existing connections are ok
require.Len(t, errCh, 0)
// Cleanup
for _, conn := range conns {
conn.Close()
require.NoError(t, conn.Close())
}
for range conns {
err := <-errCh
require.Contains(t, err.Error(), "use of closed network connection")
@@ -1095,7 +1131,7 @@ func TestHTTPServer_Limits_OK(t *testing.T) {
for i := range cases {
tc := cases[i]
name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, limitStr(tc.limit))
t.Run(name, func(t *testing.T) {
t.Parallel()
@@ -1114,21 +1150,24 @@ func TestHTTPServer_Limits_OK(t *testing.T) {
}
c.Limits.HTTPSHandshakeTimeout = tc.timeout
c.Limits.HTTPMaxConnsPerClient = tc.limit
c.LogLevel = "ERROR"
})
defer s.Shutdown()
defer func() {
require.NoError(t, s.Shutdown())
}()
assertTimeout(t, s, tc.assertTimeout, tc.timeout)
if tc.assertLimit {
// There's a race between assertTimeout(false) closing
// its connection and the HTTP server noticing and
// untracking it. Since there's no way to coordiante
// untracking it. Since there's no way to coordinate
// when this occurs, sleeping is the only way to avoid
// asserting limits before the timed out connection is
// untracked.
time.Sleep(1 * time.Second)
assertLimit(t, s.Server.Addr, *tc.limit)
assertLimit(t, s.Server.Addr, *tc.limit, tc.tls)
} else {
assertNoLimit(t, s.Server.Addr)
}