mirror of
https://github.com/kemko/nomad.git
synced 2026-01-04 01:15:43 +03:00
Introduce limits to prevent unauthorized users from exhausting all
ephemeral ports on agents:
* `{https,rpc}_handshake_timeout`
* `{http,rpc}_max_conns_per_client`
The handshake timeout closes connections that have not completed the TLS
handshake by the deadline (5s by default). For RPC connections this
timeout also applies separately to the first byte being read, so RPC
connections with TLS enabled have `rpc_handshake_timeout * 2` as their
deadline.
The connection limit per client prevents a single remote TCP peer from
exhausting all ephemeral ports. The default is 100, but can be lowered
to a minimum of 26. Since streaming RPC connections create a new TCP
connection (until MultiplexV2 is used), 20 connections are reserved for
Raft and non-streaming RPCs to prevent connection exhaustion due to
streaming RPCs.
All limits are configurable and may be disabled by setting them to `0`.
This also includes a fix that closes connections that attempt to create
TLS RPC connections recursively. While only users with valid mTLS
certificates could perform such an operation, it was added as a
safeguard to prevent programming errors before they could cause resource
exhaustion.
1097 lines
27 KiB
Go
1097 lines
27 KiB
Go
package agent
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/nomad/api"
|
|
"github.com/hashicorp/nomad/helper"
|
|
"github.com/hashicorp/nomad/helper/testlog"
|
|
"github.com/hashicorp/nomad/nomad/mock"
|
|
"github.com/hashicorp/nomad/nomad/structs"
|
|
"github.com/hashicorp/nomad/nomad/structs/config"
|
|
"github.com/hashicorp/nomad/testutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/ugorji/go/codec"
|
|
)
|
|
|
|
// makeHTTPServer returns a test server whose logs will be written to
|
|
// the passed writer. If the writer is nil, the logs are written to stderr.
|
|
func makeHTTPServer(t testing.TB, cb func(c *Config)) *TestAgent {
|
|
return NewTestAgent(t, t.Name(), cb)
|
|
}
|
|
|
|
func BenchmarkHTTPRequests(b *testing.B) {
|
|
s := makeHTTPServer(b, func(c *Config) {
|
|
c.Client.Enabled = false
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
job := mock.Job()
|
|
var allocs []*structs.Allocation
|
|
count := 1000
|
|
for i := 0; i < count; i++ {
|
|
alloc := mock.Alloc()
|
|
alloc.Job = job
|
|
alloc.JobID = job.ID
|
|
alloc.Name = fmt.Sprintf("my-job.web[%d]", i)
|
|
allocs = append(allocs, alloc)
|
|
}
|
|
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return allocs[:count], nil
|
|
}
|
|
b.ResetTimer()
|
|
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
for pb.Next() {
|
|
resp := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestRootFallthrough tests rootFallthrough handler to
|
|
// verify redirect and 404 behavior
|
|
func TestRootFallthrough(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
desc string
|
|
path string
|
|
expectedPath string
|
|
expectedCode int
|
|
}{
|
|
{
|
|
desc: "unknown endpoint 404s",
|
|
path: "/v1/unknown/endpoint",
|
|
expectedCode: 404,
|
|
},
|
|
{
|
|
desc: "root path redirects to ui",
|
|
path: "/",
|
|
expectedPath: "/ui/",
|
|
expectedCode: 307,
|
|
},
|
|
}
|
|
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
// setup a client that doesn't follow redirects
|
|
client := &http.Client{
|
|
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
|
|
reqURL := fmt.Sprintf("http://%s%s", s.Agent.config.AdvertiseAddrs.HTTP, tc.path)
|
|
|
|
resp, err := client.Get(reqURL)
|
|
require.NoError(t, err)
|
|
require.Equal(t, tc.expectedCode, resp.StatusCode)
|
|
|
|
if tc.expectedPath != "" {
|
|
loc, err := resp.Location()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tc.expectedPath, loc.Path)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetIndex(t *testing.T) {
|
|
t.Parallel()
|
|
resp := httptest.NewRecorder()
|
|
setIndex(resp, 1000)
|
|
header := resp.Header().Get("X-Nomad-Index")
|
|
if header != "1000" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
setIndex(resp, 2000)
|
|
if v := resp.Header()["X-Nomad-Index"]; len(v) != 1 {
|
|
t.Fatalf("bad: %#v", v)
|
|
}
|
|
}
|
|
|
|
func TestSetKnownLeader(t *testing.T) {
|
|
t.Parallel()
|
|
resp := httptest.NewRecorder()
|
|
setKnownLeader(resp, true)
|
|
header := resp.Header().Get("X-Nomad-KnownLeader")
|
|
if header != "true" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
resp = httptest.NewRecorder()
|
|
setKnownLeader(resp, false)
|
|
header = resp.Header().Get("X-Nomad-KnownLeader")
|
|
if header != "false" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
}
|
|
|
|
func TestSetLastContact(t *testing.T) {
|
|
t.Parallel()
|
|
resp := httptest.NewRecorder()
|
|
setLastContact(resp, 123456*time.Microsecond)
|
|
header := resp.Header().Get("X-Nomad-LastContact")
|
|
if header != "123" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
}
|
|
|
|
func TestSetMeta(t *testing.T) {
|
|
t.Parallel()
|
|
meta := structs.QueryMeta{
|
|
Index: 1000,
|
|
KnownLeader: true,
|
|
LastContact: 123456 * time.Microsecond,
|
|
}
|
|
resp := httptest.NewRecorder()
|
|
setMeta(resp, &meta)
|
|
header := resp.Header().Get("X-Nomad-Index")
|
|
if header != "1000" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
header = resp.Header().Get("X-Nomad-KnownLeader")
|
|
if header != "true" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
header = resp.Header().Get("X-Nomad-LastContact")
|
|
if header != "123" {
|
|
t.Fatalf("Bad: %v", header)
|
|
}
|
|
}
|
|
|
|
func TestSetHeaders(t *testing.T) {
|
|
t.Parallel()
|
|
s := makeHTTPServer(t, nil)
|
|
s.Agent.config.HTTPAPIResponseHeaders = map[string]string{"foo": "bar"}
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return &structs.Job{Name: "foo"}, nil
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
header := resp.Header().Get("foo")
|
|
|
|
if header != "bar" {
|
|
t.Fatalf("expected header: %v, actual: %v", "bar", header)
|
|
}
|
|
|
|
}
|
|
|
|
func TestContentTypeIsJSON(t *testing.T) {
|
|
t.Parallel()
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return &structs.Job{Name: "foo"}, nil
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
|
|
contentType := resp.Header().Get("Content-Type")
|
|
|
|
if contentType != "application/json" {
|
|
t.Fatalf("Content-Type header was not 'application/json'")
|
|
}
|
|
}
|
|
|
|
func TestWrapNonJSON(t *testing.T) {
|
|
t.Parallel()
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
|
handler := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
|
|
return []byte("test response"), nil
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrapNonJSON(handler)(resp, req)
|
|
|
|
respBody, _ := ioutil.ReadAll(resp.Body)
|
|
require.Equal(t, respBody, []byte("test response"))
|
|
|
|
}
|
|
|
|
func TestWrapNonJSON_Error(t *testing.T) {
|
|
t.Parallel()
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
handlerRPCErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
|
|
return nil, structs.NewErrRPCCoded(404, "not found")
|
|
}
|
|
|
|
handlerCodedErr := func(resp http.ResponseWriter, req *http.Request) ([]byte, error) {
|
|
return nil, CodedError(422, "unprocessable")
|
|
}
|
|
|
|
// RPC coded error
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrapNonJSON(handlerRPCErr)(resp, req)
|
|
respBody, _ := ioutil.ReadAll(resp.Body)
|
|
require.Equal(t, []byte("not found"), respBody)
|
|
require.Equal(t, 404, resp.Code)
|
|
}
|
|
|
|
// CodedError
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
|
|
s.Server.wrapNonJSON(handlerCodedErr)(resp, req)
|
|
respBody, _ := ioutil.ReadAll(resp.Body)
|
|
require.Equal(t, []byte("unprocessable"), respBody)
|
|
require.Equal(t, 422, resp.Code)
|
|
}
|
|
|
|
}
|
|
|
|
func TestPrettyPrint(t *testing.T) {
|
|
t.Parallel()
|
|
testPrettyPrint("pretty=1", true, t)
|
|
}
|
|
|
|
func TestPrettyPrintOff(t *testing.T) {
|
|
t.Parallel()
|
|
testPrettyPrint("pretty=0", false, t)
|
|
}
|
|
|
|
func TestPrettyPrintBare(t *testing.T) {
|
|
t.Parallel()
|
|
testPrettyPrint("pretty", true, t)
|
|
}
|
|
|
|
func testPrettyPrint(pretty string, prettyFmt bool, t *testing.T) {
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
r := &structs.Job{Name: "foo"}
|
|
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return r, nil
|
|
}
|
|
|
|
urlStr := "/v1/job/foo?" + pretty
|
|
req, _ := http.NewRequest("GET", urlStr, nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
|
|
var expected bytes.Buffer
|
|
var err error
|
|
if prettyFmt {
|
|
enc := codec.NewEncoder(&expected, structs.JsonHandlePretty)
|
|
err = enc.Encode(r)
|
|
expected.WriteByte('\n')
|
|
} else {
|
|
enc := codec.NewEncoder(&expected, structs.JsonHandle)
|
|
err = enc.Encode(r)
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("failed to encode: %v", err)
|
|
}
|
|
actual, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("err: %s", err)
|
|
}
|
|
|
|
if !bytes.Equal(expected.Bytes(), actual) {
|
|
t.Fatalf("bad:\nexpected:\t%q\nactual:\t\t%q", expected.String(), string(actual))
|
|
}
|
|
}
|
|
|
|
func TestPermissionDenied(t *testing.T) {
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL.Enabled = true
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return nil, structs.ErrPermissionDenied
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/job/foo", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
assert.Equal(t, resp.Code, 403)
|
|
}
|
|
|
|
// When remote RPC is used the errors have "rpc error: " prependend
|
|
{
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return nil, fmt.Errorf("rpc error: %v", structs.ErrPermissionDenied)
|
|
}
|
|
|
|
req, _ := http.NewRequest("GET", "/v1/job/foo", nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
assert.Equal(t, resp.Code, 403)
|
|
}
|
|
}
|
|
|
|
func TestTokenNotFound(t *testing.T) {
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL.Enabled = true
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
resp := httptest.NewRecorder()
|
|
handler := func(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
|
return nil, structs.ErrTokenNotFound
|
|
}
|
|
|
|
urlStr := "/v1/job/foo"
|
|
req, _ := http.NewRequest("GET", urlStr, nil)
|
|
s.Server.wrap(handler)(resp, req)
|
|
assert.Equal(t, resp.Code, 403)
|
|
}
|
|
|
|
func TestParseWait(t *testing.T) {
|
|
t.Parallel()
|
|
resp := httptest.NewRecorder()
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?wait=60s&index=1000", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
if d := parseWait(resp, req, &b); d {
|
|
t.Fatalf("unexpected done")
|
|
}
|
|
|
|
if b.MinQueryIndex != 1000 {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
if b.MaxQueryTime != 60*time.Second {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
}
|
|
|
|
func TestParseWait_InvalidTime(t *testing.T) {
|
|
t.Parallel()
|
|
resp := httptest.NewRecorder()
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?wait=60foo&index=1000", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
if d := parseWait(resp, req, &b); !d {
|
|
t.Fatalf("expected done")
|
|
}
|
|
|
|
if resp.Code != 400 {
|
|
t.Fatalf("bad code: %v", resp.Code)
|
|
}
|
|
}
|
|
|
|
func TestParseWait_InvalidIndex(t *testing.T) {
|
|
t.Parallel()
|
|
resp := httptest.NewRecorder()
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?wait=60s&index=foo", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
if d := parseWait(resp, req, &b); !d {
|
|
t.Fatalf("expected done")
|
|
}
|
|
|
|
if resp.Code != 400 {
|
|
t.Fatalf("bad code: %v", resp.Code)
|
|
}
|
|
}
|
|
|
|
func TestParseConsistency(t *testing.T) {
|
|
t.Parallel()
|
|
var b structs.QueryOptions
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/catalog/nodes?stale", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
parseConsistency(req, &b)
|
|
if !b.AllowStale {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
|
|
b = structs.QueryOptions{}
|
|
req, err = http.NewRequest("GET",
|
|
"/v1/catalog/nodes?consistent", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
parseConsistency(req, &b)
|
|
if b.AllowStale {
|
|
t.Fatalf("Bad: %v", b)
|
|
}
|
|
}
|
|
|
|
func TestParseRegion(t *testing.T) {
|
|
t.Parallel()
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
req, err := http.NewRequest("GET",
|
|
"/v1/jobs?region=foo", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
var region string
|
|
s.Server.parseRegion(req, ®ion)
|
|
if region != "foo" {
|
|
t.Fatalf("bad %s", region)
|
|
}
|
|
|
|
region = ""
|
|
req, err = http.NewRequest("GET", "/v1/jobs", nil)
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
s.Server.parseRegion(req, ®ion)
|
|
if region != "global" {
|
|
t.Fatalf("bad %s", region)
|
|
}
|
|
}
|
|
|
|
func TestParseToken(t *testing.T) {
|
|
t.Parallel()
|
|
s := makeHTTPServer(t, nil)
|
|
defer s.Shutdown()
|
|
|
|
req, err := http.NewRequest("GET", "/v1/jobs", nil)
|
|
req.Header.Add("X-Nomad-Token", "foobar")
|
|
if err != nil {
|
|
t.Fatalf("err: %v", err)
|
|
}
|
|
|
|
var token string
|
|
s.Server.parseToken(req, &token)
|
|
if token != "foobar" {
|
|
t.Fatalf("bad %s", token)
|
|
}
|
|
}
|
|
|
|
// TestHTTP_VerifyHTTPSClient asserts that a client certificate signed by the
|
|
// appropriate CA is required when VerifyHTTPSClient=true.
|
|
func TestHTTP_VerifyHTTPSClient(t *testing.T) {
|
|
t.Parallel()
|
|
const (
|
|
cafile = "../../helper/tlsutil/testdata/ca.pem"
|
|
foocert = "../../helper/tlsutil/testdata/nomad-foo.pem"
|
|
fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
|
|
)
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.Region = "foo" // match the region on foocert
|
|
c.TLSConfig = &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
VerifyHTTPSClient: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert,
|
|
KeyFile: fookey,
|
|
}
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
reqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP)
|
|
|
|
// FAIL: Requests that expect 127.0.0.1 as the name should fail
|
|
resp, err := http.Get(reqURL)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
|
|
}
|
|
urlErr, ok := err.(*url.Error)
|
|
if !ok {
|
|
t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
|
|
}
|
|
hostErr, ok := urlErr.Err.(x509.HostnameError)
|
|
if !ok {
|
|
t.Fatalf("expected a x509.HostnameError but received: %T -> %v", urlErr.Err, urlErr.Err)
|
|
}
|
|
if expected := "127.0.0.1"; hostErr.Host != expected {
|
|
t.Fatalf("expected hostname on error to be %q but found %q", expected, hostErr.Host)
|
|
}
|
|
|
|
// FAIL: Requests that specify a valid hostname but not the CA should
|
|
// fail
|
|
tlsConf := &tls.Config{
|
|
ServerName: "client.regionFoo.nomad",
|
|
}
|
|
transport := &http.Transport{TLSClientConfig: tlsConf}
|
|
client := &http.Client{Transport: transport}
|
|
req, err := http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("error creating request: %v", err)
|
|
}
|
|
resp, err = client.Do(req)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
|
|
}
|
|
urlErr, ok = err.(*url.Error)
|
|
if !ok {
|
|
t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
|
|
}
|
|
_, ok = urlErr.Err.(x509.UnknownAuthorityError)
|
|
if !ok {
|
|
t.Fatalf("expected a x509.UnknownAuthorityError but received: %T -> %v", urlErr.Err, urlErr.Err)
|
|
}
|
|
|
|
// FAIL: Requests that specify a valid hostname and CA cert but lack a
|
|
// client certificate should fail
|
|
cacertBytes, err := ioutil.ReadFile(cafile)
|
|
if err != nil {
|
|
t.Fatalf("error reading cacert: %v", err)
|
|
}
|
|
tlsConf.RootCAs = x509.NewCertPool()
|
|
tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
|
|
req, err = http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("error creating request: %v", err)
|
|
}
|
|
resp, err = client.Do(req)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
t.Fatalf("expected non-nil error but received: %v", resp.StatusCode)
|
|
}
|
|
urlErr, ok = err.(*url.Error)
|
|
if !ok {
|
|
t.Fatalf("expected a *url.Error but received: %T -> %v", err, err)
|
|
}
|
|
opErr, ok := urlErr.Err.(*net.OpError)
|
|
if !ok {
|
|
t.Fatalf("expected a *net.OpErr but received: %T -> %v", urlErr.Err, urlErr.Err)
|
|
}
|
|
const badCertificate = "tls: bad certificate" // from crypto/tls/alert.go:52 and RFC 5246 § A.3
|
|
if opErr.Err.Error() != badCertificate {
|
|
t.Fatalf("expected tls.alert bad_certificate but received: %q", opErr.Err.Error())
|
|
}
|
|
|
|
// PASS: Requests that specify a valid hostname, CA cert, and client
|
|
// certificate succeed.
|
|
tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
c, err := tls.LoadX509KeyPair(foocert, fookey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
}
|
|
transport = &http.Transport{TLSClientConfig: tlsConf}
|
|
client = &http.Client{Transport: transport}
|
|
req, err = http.NewRequest("GET", reqURL, nil)
|
|
if err != nil {
|
|
t.Fatalf("error creating request: %v", err)
|
|
}
|
|
resp, err = client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
resp.Body.Close()
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("expected 200 status code but got: %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTP_VerifyHTTPSClient_AfterConfigReload(t *testing.T) {
|
|
t.Parallel()
|
|
assert := assert.New(t)
|
|
|
|
const (
|
|
cafile = "../../helper/tlsutil/testdata/ca.pem"
|
|
foocert = "../../helper/tlsutil/testdata/nomad-bad.pem"
|
|
fookey = "../../helper/tlsutil/testdata/nomad-bad-key.pem"
|
|
foocert2 = "../../helper/tlsutil/testdata/nomad-foo.pem"
|
|
fookey2 = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
|
|
)
|
|
|
|
agentConfig := &Config{
|
|
TLSConfig: &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
VerifyHTTPSClient: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert,
|
|
KeyFile: fookey,
|
|
},
|
|
}
|
|
|
|
newConfig := &Config{
|
|
TLSConfig: &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
VerifyHTTPSClient: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert2,
|
|
KeyFile: fookey2,
|
|
},
|
|
}
|
|
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.TLSConfig = agentConfig.TLSConfig
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
// Make an initial request that should fail.
|
|
// Requests that specify a valid hostname, CA cert, and client
|
|
// certificate succeed.
|
|
tlsConf := &tls.Config{
|
|
ServerName: "client.regionFoo.nomad",
|
|
RootCAs: x509.NewCertPool(),
|
|
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
c, err := tls.LoadX509KeyPair(foocert, fookey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
},
|
|
}
|
|
|
|
// HTTPS request should succeed
|
|
httpsReqURL := fmt.Sprintf("https://%s/v1/agent/self", s.Agent.config.AdvertiseAddrs.HTTP)
|
|
|
|
cacertBytes, err := ioutil.ReadFile(cafile)
|
|
assert.Nil(err)
|
|
tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
|
|
|
|
transport := &http.Transport{TLSClientConfig: tlsConf}
|
|
client := &http.Client{Transport: transport}
|
|
req, err := http.NewRequest("GET", httpsReqURL, nil)
|
|
assert.Nil(err)
|
|
|
|
// Check that we get an error that the certificate isn't valid for the
|
|
// region we are contacting.
|
|
_, err = client.Do(req)
|
|
assert.Contains(err.Error(), "certificate is valid for")
|
|
|
|
// Reload the TLS configuration==
|
|
assert.Nil(s.Agent.Reload(newConfig))
|
|
|
|
// Requests that specify a valid hostname, CA cert, and client
|
|
// certificate succeed.
|
|
tlsConf = &tls.Config{
|
|
ServerName: "client.regionFoo.nomad",
|
|
RootCAs: x509.NewCertPool(),
|
|
GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
c, err := tls.LoadX509KeyPair(foocert2, fookey2)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
},
|
|
}
|
|
|
|
cacertBytes, err = ioutil.ReadFile(cafile)
|
|
assert.Nil(err)
|
|
tlsConf.RootCAs.AppendCertsFromPEM(cacertBytes)
|
|
|
|
transport = &http.Transport{TLSClientConfig: tlsConf}
|
|
client = &http.Client{Transport: transport}
|
|
req, err = http.NewRequest("GET", httpsReqURL, nil)
|
|
assert.Nil(err)
|
|
|
|
resp, err := client.Do(req)
|
|
if assert.Nil(err) {
|
|
resp.Body.Close()
|
|
assert.Equal(resp.StatusCode, 200)
|
|
}
|
|
}
|
|
|
|
// TestHTTPServer_Limits_Error asserts invalid Limits cause errors. This is the
|
|
// HTTP counterpart to TestAgent_ServerConfig_Limits_Error.
|
|
func TestHTTPServer_Limits_Error(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
tls bool
|
|
timeout string
|
|
limit *int
|
|
expectedErr string
|
|
}{
|
|
{
|
|
tls: true,
|
|
timeout: "",
|
|
limit: nil,
|
|
expectedErr: "error parsing https_handshake_timeout: ",
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "",
|
|
limit: nil,
|
|
expectedErr: "error parsing https_handshake_timeout: ",
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "-1s",
|
|
limit: nil,
|
|
expectedErr: "https_handshake_timeout must be >= 0",
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "-1s",
|
|
limit: nil,
|
|
expectedErr: "https_handshake_timeout must be >= 0",
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "5s",
|
|
limit: helper.IntToPtr(-1),
|
|
expectedErr: "http_max_conns_per_client must be >= 0",
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "5s",
|
|
limit: helper.IntToPtr(-1),
|
|
expectedErr: "http_max_conns_per_client must be >= 0",
|
|
},
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Use a fake agent since the HTTP server should never start
|
|
agent := &Agent{
|
|
logger: testlog.HCLogger(t),
|
|
}
|
|
|
|
conf := &Config{
|
|
normalizedAddrs: &Addresses{
|
|
HTTP: "localhost:0", // port is never used
|
|
},
|
|
TLSConfig: &config.TLSConfig{
|
|
EnableHTTP: tc.tls,
|
|
},
|
|
Limits: config.Limits{
|
|
HTTPSHandshakeTimeout: tc.timeout,
|
|
HTTPMaxConnsPerClient: tc.limit,
|
|
},
|
|
}
|
|
|
|
srv, err := NewHTTPServer(agent, conf)
|
|
require.Error(t, err)
|
|
require.Nil(t, srv)
|
|
require.Contains(t, err.Error(), tc.expectedErr)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestHTTPServer_Limits_OK asserts that all valid limits combinations
|
|
// (tls/timeout/conns) work.
|
|
func TestHTTPServer_Limits_OK(t *testing.T) {
|
|
t.Parallel()
|
|
const (
|
|
cafile = "../../helper/tlsutil/testdata/ca.pem"
|
|
foocert = "../../helper/tlsutil/testdata/nomad-foo.pem"
|
|
fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem"
|
|
maxConns = 10 // limit must be < this for testing
|
|
)
|
|
|
|
cases := []struct {
|
|
tls bool
|
|
timeout string
|
|
limit *int
|
|
assertTimeout bool
|
|
assertLimit bool
|
|
}{
|
|
{
|
|
tls: false,
|
|
timeout: "5s",
|
|
limit: nil,
|
|
assertTimeout: false,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "5s",
|
|
limit: nil,
|
|
assertTimeout: true,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "0",
|
|
limit: nil,
|
|
assertTimeout: false,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "0",
|
|
limit: nil,
|
|
assertTimeout: false,
|
|
assertLimit: false,
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "0",
|
|
limit: helper.IntToPtr(2),
|
|
assertTimeout: false,
|
|
assertLimit: true,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "0",
|
|
limit: helper.IntToPtr(2),
|
|
assertTimeout: false,
|
|
assertLimit: true,
|
|
},
|
|
{
|
|
tls: false,
|
|
timeout: "5s",
|
|
limit: helper.IntToPtr(2),
|
|
assertTimeout: false,
|
|
assertLimit: true,
|
|
},
|
|
{
|
|
tls: true,
|
|
timeout: "5s",
|
|
limit: helper.IntToPtr(2),
|
|
assertTimeout: true,
|
|
assertLimit: true,
|
|
},
|
|
}
|
|
|
|
assertTimeout := func(t *testing.T, a *TestAgent, assertTimeout bool, timeout string) {
|
|
timeoutDeadline, err := time.ParseDuration(timeout)
|
|
require.NoError(t, err)
|
|
|
|
// Increase deadline to detect timeouts
|
|
deadline := timeoutDeadline + time.Second
|
|
|
|
conn, err := net.DialTimeout("tcp", a.Server.Addr, deadline)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
buf := []byte{0}
|
|
readDeadline := time.Now().Add(deadline)
|
|
conn.SetReadDeadline(readDeadline)
|
|
n, err := conn.Read(buf)
|
|
require.Zero(t, n)
|
|
if assertTimeout {
|
|
// Server timeouts == EOF
|
|
require.Equal(t, io.EOF, err)
|
|
|
|
// Perform blocking query to assert timeout is not
|
|
// enabled post-TLS-handshake.
|
|
q := &api.QueryOptions{
|
|
WaitIndex: 10000, // wait a looong time
|
|
WaitTime: deadline,
|
|
}
|
|
|
|
// Assertions don't require certificate validation
|
|
conf := api.DefaultConfig()
|
|
conf.Address = a.HTTPAddr()
|
|
conf.TLSConfig.Insecure = true
|
|
client, err := api.NewClient(conf)
|
|
require.NoError(t, err)
|
|
|
|
// Assert a blocking query isn't timed out by the
|
|
// handshake timeout
|
|
jobs, meta, err := client.Jobs().List(q)
|
|
require.NoError(t, err)
|
|
require.Len(t, jobs, 0)
|
|
require.Truef(t, meta.RequestTime >= deadline,
|
|
"expected RequestTime (%s) >= Deadline (%s)",
|
|
meta.RequestTime, deadline)
|
|
|
|
return
|
|
}
|
|
|
|
// HTTP Server should *not* have timed out.
|
|
// Now() should always be after the read deadline, but
|
|
// isn't a sufficient assertion for correctness as slow
|
|
// tests may cause this to be true even if the server
|
|
// timed out.
|
|
require.True(t, time.Now().After(readDeadline))
|
|
|
|
testutil.RequireDeadlineErr(t, err)
|
|
}
|
|
|
|
assertNoLimit := func(t *testing.T, addr string) {
|
|
var err error
|
|
|
|
// Create max connections
|
|
conns := make([]net.Conn, maxConns)
|
|
errCh := make(chan error, maxConns)
|
|
for i := 0; i < maxConns; i++ {
|
|
conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
|
|
require.NoError(t, err)
|
|
defer conns[i].Close()
|
|
|
|
go func(i int) {
|
|
buf := []byte{0}
|
|
readDeadline := time.Now().Add(1 * time.Second)
|
|
conns[i].SetReadDeadline(readDeadline)
|
|
n, err := conns[i].Read(buf)
|
|
if n > 0 {
|
|
errCh <- fmt.Errorf("n > 0: %d", n)
|
|
return
|
|
}
|
|
errCh <- err
|
|
}(i)
|
|
}
|
|
|
|
// Now assert each error is a clientside read deadline error
|
|
for i := 0; i < maxConns; i++ {
|
|
select {
|
|
case <-time.After(1 * time.Second):
|
|
t.Fatalf("timed out waiting for conn error %d", i)
|
|
case err := <-errCh:
|
|
testutil.RequireDeadlineErr(t, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
assertLimit := func(t *testing.T, addr string, limit int) {
|
|
var err error
|
|
|
|
// Create limit connections
|
|
conns := make([]net.Conn, limit)
|
|
errCh := make(chan error, limit)
|
|
for i := range conns {
|
|
conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
|
|
require.NoError(t, err)
|
|
defer conns[i].Close()
|
|
|
|
go func(i int) {
|
|
buf := []byte{0}
|
|
n, err := conns[i].Read(buf)
|
|
if n > 0 {
|
|
errCh <- fmt.Errorf("n > 0: %d", n)
|
|
return
|
|
}
|
|
errCh <- err
|
|
}(i)
|
|
}
|
|
|
|
// Assert a new connection is dropped
|
|
conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
buf := []byte{0}
|
|
deadline := time.Now().Add(10 * time.Second)
|
|
conn.SetReadDeadline(deadline)
|
|
n, err := conn.Read(buf)
|
|
require.Zero(t, n)
|
|
require.Equal(t, io.EOF, err)
|
|
|
|
// Assert existing connections are ok
|
|
require.Len(t, errCh, 0)
|
|
|
|
// Cleanup
|
|
for _, conn := range conns {
|
|
conn.Close()
|
|
}
|
|
for range conns {
|
|
err := <-errCh
|
|
require.Contains(t, err.Error(), "use of closed network connection")
|
|
}
|
|
}
|
|
|
|
for i := range cases {
|
|
tc := cases[i]
|
|
name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if tc.limit != nil && *tc.limit >= maxConns {
|
|
t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", *tc.limit, maxConns)
|
|
}
|
|
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
if tc.tls {
|
|
c.TLSConfig = &config.TLSConfig{
|
|
EnableHTTP: true,
|
|
CAFile: cafile,
|
|
CertFile: foocert,
|
|
KeyFile: fookey,
|
|
}
|
|
}
|
|
c.Limits.HTTPSHandshakeTimeout = tc.timeout
|
|
c.Limits.HTTPMaxConnsPerClient = tc.limit
|
|
})
|
|
defer s.Shutdown()
|
|
|
|
assertTimeout(t, s, tc.assertTimeout, tc.timeout)
|
|
if tc.assertLimit {
|
|
assertLimit(t, s.Server.Addr, *tc.limit)
|
|
} else {
|
|
assertNoLimit(t, s.Server.Addr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func httpTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) {
|
|
s := makeHTTPServer(t, cb)
|
|
defer s.Shutdown()
|
|
testutil.WaitForLeader(t, s.Agent.RPC)
|
|
f(s)
|
|
}
|
|
|
|
func httpACLTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) {
|
|
s := makeHTTPServer(t, func(c *Config) {
|
|
c.ACL.Enabled = true
|
|
if cb != nil {
|
|
cb(c)
|
|
}
|
|
})
|
|
defer s.Shutdown()
|
|
testutil.WaitForLeader(t, s.Agent.RPC)
|
|
f(s)
|
|
}
|
|
|
|
func setToken(req *http.Request, token *structs.ACLToken) {
|
|
req.Header.Set("X-Nomad-Token", token.SecretID)
|
|
}
|
|
|
|
// encodeReq JSON-encodes obj into an in-memory buffer and returns that buffer
// as an io.ReadCloser suitable for use as a request body. Close is a no-op.
// An encoding error is not reported; whatever was written to the buffer
// before the failure is what the reader yields.
func encodeReq(obj interface{}) io.ReadCloser {
	payload := new(bytes.Buffer)
	json.NewEncoder(payload).Encode(obj)
	return ioutil.NopCloser(payload)
}
|