diff --git a/command/agent/agent.go b/command/agent/agent.go index 205d100c5..c927a7b90 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -340,6 +340,26 @@ func convertServerConfig(agentConfig *Config) (*nomad.Config, error) { conf.DisableDispatchedJobSummaryMetrics = agentConfig.Telemetry.DisableDispatchedJobSummaryMetrics conf.BackwardsCompatibleMetrics = agentConfig.Telemetry.BackwardsCompatibleMetrics + // Parse Limits timeout from a string into durations + if d, err := time.ParseDuration(agentConfig.Limits.RPCHandshakeTimeout); err != nil { + return nil, fmt.Errorf("error parsing rpc_handshake_timeout: %v", err) + } else if d < 0 { + return nil, fmt.Errorf("rpc_handshake_timeout must be >= 0") + } else { + conf.RPCHandshakeTimeout = d + } + + // Set max rpc conns; nil/0 == unlimited + // Leave a little room for streaming RPCs + minLimit := config.LimitsNonStreamingConnsPerClient + 5 + if agentConfig.Limits.RPCMaxConnsPerClient == nil || *agentConfig.Limits.RPCMaxConnsPerClient == 0 { + conf.RPCMaxConnsPerClient = 0 + } else if limit := *agentConfig.Limits.RPCMaxConnsPerClient; limit <= minLimit { + return nil, fmt.Errorf("rpc_max_conns_per_client must be > %d; found: %d", minLimit, limit) + } else { + conf.RPCMaxConnsPerClient = limit + } + return conf, nil } diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 8472782cc..09a5cde4e 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/nomad/structs/config" sconfig "github.com/hashicorp/nomad/nomad/structs/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -53,34 +54,21 @@ func TestAgent_ServerConfig(t *testing.T) { t.Fatalf("error normalizing config: %v", err) } out, err := a.serverConfig() - if err != nil { - t.Fatalf("err: %s", err) - } + require.NoError(t, err) + serfAddr := out.SerfConfig.MemberlistConfig.AdvertiseAddr - if serfAddr != "127.0.0.1" { - t.Fatalf("expect 127.0.0.1, got: %s", serfAddr) - } + require.Equal(t, "127.0.0.1", serfAddr) + serfPort := out.SerfConfig.MemberlistConfig.AdvertisePort - if serfPort != 4000 { - t.Fatalf("expected 4000, got: %d", serfPort) - } - if out.AuthoritativeRegion != "global" { - t.Fatalf("bad: %#v", out.AuthoritativeRegion) - } - if !out.ACLEnabled { - t.Fatalf("ACL not enabled") - } + require.Equal(t, 4000, serfPort) + + require.Equal(t, "global", out.AuthoritativeRegion) + require.True(t, out.ACLEnabled) // Assert addresses weren't changed - if addr := conf.AdvertiseAddrs.RPC; addr != "127.0.0.1:4001" { - t.Fatalf("bad rpc advertise addr: %#v", addr) - } - if addr := conf.AdvertiseAddrs.HTTP; addr != "10.10.11.1:4005" { - t.Fatalf("expect 10.11.11.1:4005, got: %v", addr) - } - if addr := conf.Addresses.RPC; addr != "0.0.0.0" { - t.Fatalf("expect 0.0.0.0, got: %v", addr) - } + require.Equal(t, "127.0.0.1:4001", conf.AdvertiseAddrs.RPC) + require.Equal(t, "10.10.11.1:4005", conf.AdvertiseAddrs.HTTP) + require.Equal(t, "0.0.0.0", conf.Addresses.RPC) // Sets up the ports properly conf.Addresses.RPC = "" @@ -88,19 +76,12 @@ func TestAgent_ServerConfig(t *testing.T) { conf.Ports.RPC = 4003 conf.Ports.Serf = 4004 - if err := conf.normalizeAddrs(); err != nil { - t.Fatalf("error normalizing config: %v", err) - } + require.NoError(t, conf.normalizeAddrs()) + out, err = a.serverConfig() - if err != nil { - t.Fatalf("err: 
%s", err) - } - if addr := out.RPCAddr.Port; addr != 4003 { - t.Fatalf("expect 4003, got: %d", out.RPCAddr.Port) - } - if port := out.SerfConfig.MemberlistConfig.BindPort; port != 4004 { - t.Fatalf("expect 4004, got: %d", port) - } + require.NoError(t, err) + require.Equal(t, 4003, out.RPCAddr.Port) + require.Equal(t, 4004, out.SerfConfig.MemberlistConfig.BindPort) // Prefers advertise over bind addr conf.BindAddr = "127.0.0.3" @@ -111,100 +92,51 @@ func TestAgent_ServerConfig(t *testing.T) { conf.AdvertiseAddrs.RPC = "" conf.AdvertiseAddrs.Serf = "10.0.0.12:4004" - if err := conf.normalizeAddrs(); err != nil { - t.Fatalf("error normalizing config: %v", err) - } + require.NoError(t, conf.normalizeAddrs()) + out, err = a.serverConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - if addr := out.RPCAddr.IP.String(); addr != "127.0.0.2" { - t.Fatalf("expect 127.0.0.2, got: %s", addr) - } - if port := out.RPCAddr.Port; port != 4003 { - t.Fatalf("expect 4647, got: %d", port) - } - if addr := out.SerfConfig.MemberlistConfig.BindAddr; addr != "127.0.0.2" { - t.Fatalf("expect 127.0.0.2, got: %s", addr) - } - if port := out.SerfConfig.MemberlistConfig.BindPort; port != 4004 { - t.Fatalf("expect 4648, got: %d", port) - } - if addr := conf.Addresses.HTTP; addr != "127.0.0.2" { - t.Fatalf("expect 127.0.0.2, got: %s", addr) - } - if addr := conf.Addresses.RPC; addr != "127.0.0.2" { - t.Fatalf("expect 127.0.0.2, got: %s", addr) - } - if addr := conf.Addresses.Serf; addr != "127.0.0.2" { - t.Fatalf("expect 10.0.0.12, got: %s", addr) - } - if addr := conf.normalizedAddrs.HTTP; addr != "127.0.0.2:4646" { - t.Fatalf("expect 127.0.0.2:4646, got: %s", addr) - } - if addr := conf.normalizedAddrs.RPC; addr != "127.0.0.2:4003" { - t.Fatalf("expect 127.0.0.2:4003, got: %s", addr) - } - if addr := conf.normalizedAddrs.Serf; addr != "127.0.0.2:4004" { - t.Fatalf("expect 10.0.0.12:4004, got: %s", addr) - } - if addr := conf.AdvertiseAddrs.HTTP; addr != "10.0.0.10:4646" { - t.Fatalf("expect 10.0.0.10:4646, got: %s", addr) - } - if addr := conf.AdvertiseAddrs.RPC; addr != "127.0.0.2:4003" { - t.Fatalf("expect 127.0.0.2:4003, got: %s", addr) - } - if addr := conf.AdvertiseAddrs.Serf; addr != "10.0.0.12:4004" { - t.Fatalf("expect 10.0.0.12:4004, got: %s", addr) - } + require.Equal(t, "127.0.0.2", out.RPCAddr.IP.String()) + require.Equal(t, 4003, out.RPCAddr.Port) + require.Equal(t, "127.0.0.2", out.SerfConfig.MemberlistConfig.BindAddr) + require.Equal(t, 4004, out.SerfConfig.MemberlistConfig.BindPort) + require.Equal(t, "127.0.0.2", conf.Addresses.HTTP) + require.Equal(t, "127.0.0.2", conf.Addresses.RPC) + require.Equal(t, "127.0.0.2", conf.Addresses.Serf) + require.Equal(t, "127.0.0.2:4646", conf.normalizedAddrs.HTTP) + require.Equal(t, "127.0.0.2:4003", conf.normalizedAddrs.RPC) + require.Equal(t, "127.0.0.2:4004", conf.normalizedAddrs.Serf) + require.Equal(t, "10.0.0.10:4646", conf.AdvertiseAddrs.HTTP) + require.Equal(t, "127.0.0.2:4003", conf.AdvertiseAddrs.RPC) + require.Equal(t, "10.0.0.12:4004", conf.AdvertiseAddrs.Serf) conf.Server.NodeGCThreshold = "42g" - if err := conf.normalizeAddrs(); err != nil { - t.Fatalf("error normalizing config: %v", err) - } + require.NoError(t, conf.normalizeAddrs()) + _, err = a.serverConfig() if err == nil || !strings.Contains(err.Error(), "unknown unit") { t.Fatalf("expected unknown unit error, got: %#v", err) } conf.Server.NodeGCThreshold = "10s" - if err := conf.normalizeAddrs(); err != nil { - t.Fatalf("error normalizing config: %v", err) - } + require.NoError(t, 
conf.normalizeAddrs()) out, err = a.serverConfig() - if err != nil { - t.Fatalf("error getting server config: %s", err) - } - if threshold := out.NodeGCThreshold; threshold != time.Second*10 { - t.Fatalf("expect 10s, got: %s", threshold) - } + require.NoError(t, err) + require.Equal(t, 10*time.Second, out.NodeGCThreshold) conf.Server.HeartbeatGrace = 37 * time.Second out, err = a.serverConfig() - if err != nil { - t.Fatalf("error getting server config: %s", err) - } - if threshold := out.HeartbeatGrace; threshold != time.Second*37 { - t.Fatalf("expect 37s, got: %s", threshold) - } + require.NoError(t, err) + require.Equal(t, 37*time.Second, out.HeartbeatGrace) conf.Server.MinHeartbeatTTL = 37 * time.Second out, err = a.serverConfig() - if err != nil { - t.Fatalf("error getting server config: %s", err) - } - if min := out.MinHeartbeatTTL; min != time.Second*37 { - t.Fatalf("expect 37s, got: %s", min) - } + require.NoError(t, err) + require.Equal(t, 37*time.Second, out.MinHeartbeatTTL) conf.Server.MaxHeartbeatsPerSecond = 11.0 out, err = a.serverConfig() - if err != nil { - t.Fatalf("error getting server config: %s", err) - } - if max := out.MaxHeartbeatsPerSecond; max != 11.0 { - t.Fatalf("expect 11, got: %v", max) - } + require.NoError(t, err) + require.Equal(t, float64(11.0), out.MaxHeartbeatsPerSecond) // Defaults to the global bind addr conf.Addresses.RPC = "" @@ -216,62 +148,32 @@ func TestAgent_ServerConfig(t *testing.T) { conf.Ports.HTTP = 4646 conf.Ports.RPC = 4647 conf.Ports.Serf = 4648 - if err := conf.normalizeAddrs(); err != nil { - t.Fatalf("error normalizing config: %v", err) - } + require.NoError(t, conf.normalizeAddrs()) + out, err = a.serverConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - if addr := out.RPCAddr.IP.String(); addr != "127.0.0.3" { - t.Fatalf("expect 127.0.0.3, got: %s", addr) - } - if addr := out.SerfConfig.MemberlistConfig.BindAddr; addr != "127.0.0.3" { - t.Fatalf("expect 127.0.0.3, got: %s", addr) - } - if addr := conf.Addresses.HTTP; addr != "127.0.0.3" { - t.Fatalf("expect 127.0.0.3, got: %s", addr) - } - if addr := conf.Addresses.RPC; addr != "127.0.0.3" { - t.Fatalf("expect 127.0.0.3, got: %s", addr) - } - if addr := conf.Addresses.Serf; addr != "127.0.0.3" { - t.Fatalf("expect 127.0.0.3, got: %s", addr) - } - if addr := conf.normalizedAddrs.HTTP; addr != "127.0.0.3:4646" { - t.Fatalf("expect 127.0.0.3:4646, got: %s", addr) - } - if addr := conf.normalizedAddrs.RPC; addr != "127.0.0.3:4647" { - t.Fatalf("expect 127.0.0.3:4647, got: %s", addr) - } - if addr := conf.normalizedAddrs.Serf; addr != "127.0.0.3:4648" { - t.Fatalf("expect 127.0.0.3:4648, got: %s", addr) - } + require.NoError(t, err) + + require.Equal(t, "127.0.0.3", out.RPCAddr.IP.String()) + require.Equal(t, "127.0.0.3", out.SerfConfig.MemberlistConfig.BindAddr) + require.Equal(t, "127.0.0.3", conf.Addresses.HTTP) + require.Equal(t, "127.0.0.3", conf.Addresses.RPC) + require.Equal(t, "127.0.0.3", conf.Addresses.Serf) + require.Equal(t, "127.0.0.3:4646", conf.normalizedAddrs.HTTP) + require.Equal(t, "127.0.0.3:4647", conf.normalizedAddrs.RPC) + require.Equal(t, "127.0.0.3:4648", conf.normalizedAddrs.Serf) // Properly handles the bootstrap flags conf.Server.BootstrapExpect = 1 out, err = a.serverConfig() - if err != nil { - t.Fatalf("err: %s", err) - } - if !out.Bootstrap { - t.Fatalf("should have set bootstrap mode") - } - if out.BootstrapExpect != 0 { - t.Fatalf("bootstrap expect should be 0") - } + require.NoError(t, err) + require.True(t, out.Bootstrap) + require.Equal(t, 
int32(0), out.BootstrapExpect)
 
 	conf.Server.BootstrapExpect = 3
 	out, err = a.serverConfig()
-	if err != nil {
-		t.Fatalf("err: %s", err)
-	}
-	if out.Bootstrap {
-		t.Fatalf("bootstrap mode should be disabled")
-	}
-	if out.BootstrapExpect != 3 {
-		t.Fatalf("should have bootstrap-expect = 3")
-	}
+	require.NoError(t, err)
+	require.False(t, out.Bootstrap)
+	require.Equal(t, int32(3), out.BootstrapExpect)
 }
 
 func TestAgent_ServerConfig_SchedulerFlags(t *testing.T) {
@@ -335,6 +237,132 @@ func TestAgent_ServerConfig_SchedulerFlags(t *testing.T) {
 		})
 	}
 }
+
+// TestAgent_ServerConfig_Limits_Error asserts invalid Limits configurations
+// cause errors. This is the server-only (RPC) counterpart to
+// TestHTTPServer_Limits_Error.
+func TestAgent_ServerConfig_Limits_Error(t *testing.T) {
+	t.Parallel()
+
+	cases := []struct {
+		name        string
+		expectedErr string
+		limits      sconfig.Limits
+	}{
+		{
+			name:        "Negative Timeout",
+			expectedErr: "rpc_handshake_timeout must be >= 0",
+			limits: sconfig.Limits{
+				RPCHandshakeTimeout:  "-5s",
+				RPCMaxConnsPerClient: helper.IntToPtr(100),
+			},
+		},
+		{
+			name:        "Invalid Timeout",
+			expectedErr: "error parsing rpc_handshake_timeout",
+			limits: sconfig.Limits{
+				RPCHandshakeTimeout:  "s",
+				RPCMaxConnsPerClient: helper.IntToPtr(100),
+			},
+		},
+		{
+			name:        "Missing Timeout",
+			expectedErr: "error parsing rpc_handshake_timeout",
+			limits: sconfig.Limits{
+				RPCHandshakeTimeout:  "",
+				RPCMaxConnsPerClient: helper.IntToPtr(100),
+			},
+		},
+		{
+			name:        "Negative Connection Limit",
+			expectedErr: "rpc_max_conns_per_client must be > 25; found: -100",
+			limits: sconfig.Limits{
+				RPCHandshakeTimeout:  "5s",
+				RPCMaxConnsPerClient: helper.IntToPtr(-100),
+			},
+		},
+		{
+			name:        "Low Connection Limit",
+			expectedErr: "rpc_max_conns_per_client must be > 25; found: 20",
+			limits: sconfig.Limits{
+				RPCHandshakeTimeout:  "5s",
+				RPCMaxConnsPerClient: helper.IntToPtr(sconfig.LimitsNonStreamingConnsPerClient),
+			},
+		},
+	}
+
+	for i := range cases {
+		tc := cases[i]
+		t.Run(tc.name, func(t *testing.T) {
+			conf := DevConfig(nil)
+			require.NoError(t, conf.normalizeAddrs())
+
+			conf.Limits = tc.limits
+			serverConf, err := convertServerConfig(conf)
+			assert.Nil(t, serverConf)
+			require.Contains(t, err.Error(), tc.expectedErr)
+		})
+	}
+}
+
+// TestAgent_ServerConfig_Limits_OK asserts valid Limits configurations do not
+// cause errors. This is the server-only (RPC) counterpart to
+// TestHTTPServer_Limits_OK.
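+//
+// The validation itself lives in convertServerConfig, which both agent
+// startup and these table cases exercise directly.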
+func TestAgent_ServerConfig_Limits_OK(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + limits sconfig.Limits + }{ + { + name: "Default", + limits: config.DefaultLimits(), + }, + { + name: "Zero+nil is valid to disable", + limits: sconfig.Limits{ + RPCHandshakeTimeout: "0", + RPCMaxConnsPerClient: nil, + }, + }, + { + name: "Zeros are valid", + limits: sconfig.Limits{ + RPCHandshakeTimeout: "0s", + RPCMaxConnsPerClient: helper.IntToPtr(0), + }, + }, + { + name: "Low limits are valid", + limits: sconfig.Limits{ + RPCHandshakeTimeout: "1ms", + RPCMaxConnsPerClient: helper.IntToPtr(26), + }, + }, + { + name: "High limits are valid", + limits: sconfig.Limits{ + RPCHandshakeTimeout: "5h", + RPCMaxConnsPerClient: helper.IntToPtr(100000), + }, + }, + } + + for i := range cases { + tc := cases[i] + t.Run(tc.name, func(t *testing.T) { + conf := DevConfig(nil) + require.NoError(t, conf.normalizeAddrs()) + + conf.Limits = tc.limits + serverConf, err := convertServerConfig(conf) + assert.NoError(t, err) + require.NotNil(t, serverConf) + }) + } +} + func TestAgent_ClientConfig(t *testing.T) { t.Parallel() conf := DefaultConfig() @@ -380,7 +408,7 @@ func TestAgent_ClientConfig(t *testing.T) { } // Clients should inherit telemetry configuration -func TestAget_Client_TelemetryConfiguration(t *testing.T) { +func TestAgent_Client_TelemetryConfiguration(t *testing.T) { assert := assert.New(t) conf := DefaultConfig() diff --git a/command/agent/config.go b/command/agent/config.go index 07e59703b..2e560ac42 100644 --- a/command/agent/config.go +++ b/command/agent/config.go @@ -166,6 +166,9 @@ type Config struct { // Plugins is the set of configured plugins Plugins []*config.PluginConfig `hcl:"plugin"` + // Limits contains the configuration for timeouts. 
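+	// It also caps connections per client IP. The agent stanza is `limits`,
+	// with field names taken from the hcl tags on config.Limits, e.g.:
+	//
+	//	limits {
+	//	  https_handshake_timeout   = "5s"
+	//	  http_max_conns_per_client = 100
+	//	  rpc_handshake_timeout     = "5s"
+	//	  rpc_max_conns_per_client  = 100
+	//	}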
+ Limits config.Limits `hcl:"limits"` + // ExtraKeysHCL is used by hcl to surface unexpected keys ExtraKeysHCL []string `hcl:",unusedKeys" json:"-"` } @@ -862,6 +865,7 @@ func DefaultConfig() *Config { Version: version.GetVersion(), Autopilot: config.DefaultAutopilotConfig(), DisableUpdateCheck: helper.BoolToPtr(false), + Limits: config.DefaultLimits(), } } @@ -1066,6 +1070,8 @@ func (c *Config) Merge(b *Config) *Config { result.HTTPAPIResponseHeaders[k] = v } + result.Limits = c.Limits.Merge(b.Limits) + return &result } diff --git a/command/agent/http.go b/command/agent/http.go index 5cab3efea..910bb74ef 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -16,6 +16,7 @@ import ( "github.com/NYTimes/gziphandler" assetfs "github.com/elazarl/go-bindata-assetfs" "github.com/gorilla/websocket" + "github.com/hashicorp/go-connlimit" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/helper/tlsutil" "github.com/hashicorp/nomad/nomad/structs" @@ -112,14 +113,102 @@ func NewHTTPServer(agent *Agent, config *Config) (*HTTPServer, error) { return nil, err } + // Get connection handshake timeout limit + handshakeTimeout, err := time.ParseDuration(config.Limits.HTTPSHandshakeTimeout) + if err != nil { + return nil, fmt.Errorf("error parsing https_handshake_timeout: %v", err) + } else if handshakeTimeout < 0 { + return nil, fmt.Errorf("https_handshake_timeout must be >= 0") + } + + // Get max connection limit + maxConns := 0 + if mc := config.Limits.HTTPMaxConnsPerClient; mc != nil { + maxConns = *mc + } + if maxConns < 0 { + return nil, fmt.Errorf("http_max_conns_per_client must be >= 0") + } + + // Create HTTP server with timeouts + httpServer := http.Server{ + Addr: srv.Addr, + Handler: gzip(mux), + ConnState: makeConnState(config.TLSConfig.EnableHTTP, handshakeTimeout, maxConns), + } + go func() { defer close(srv.listenerCh) - http.Serve(ln, gzip(mux)) + httpServer.Serve(ln) }() return srv, nil } +// makeConnState returns a ConnState func for use in an http.Server. If +// isTLS=true and handshakeTimeout>0 then the handshakeTimeout will be applied +// as a connection deadline to new connections and removed when the connection +// is active (meaning it has successfully completed the TLS handshake). +// +// If limit > 0, a per-address connection limit will be enabled regardless of +// TLS. If connLimit == 0 there is no connection limit. +func makeConnState(isTLS bool, handshakeTimeout time.Duration, connLimit int) func(conn net.Conn, state http.ConnState) { + if !isTLS || handshakeTimeout == 0 { + if connLimit > 0 { + // Still return the connection limiter + return connlimit.NewLimiter(connlimit.Config{ + MaxConnsPerClientIP: connLimit, + }).HTTPConnStateFunc() + } + + return nil + } + + if connLimit > 0 { + // Return conn state callback with connection limiting and a + // handshake timeout. + + connLimiter := connlimit.NewLimiter(connlimit.Config{ + MaxConnsPerClientIP: connLimit, + }).HTTPConnStateFunc() + + return func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateNew: + // Set deadline to prevent slow send before TLS handshake or first + // byte of request. + conn.SetDeadline(time.Now().Add(handshakeTimeout)) + case http.StateActive: + // Clear read deadline. We should maybe set read timeouts more + // generally but that's a bigger task as some HTTP endpoints may + // stream large requests and responses (e.g. snapshot) so we can't + // set sensible blanket timeouts here. 
+ conn.SetDeadline(time.Time{}) + } + + // Call connection limiter + connLimiter(conn, state) + } + } + + // Return conn state callback with just a handshake timeout + // (connection limiting disabled). + return func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateNew: + // Set deadline to prevent slow send before TLS handshake or first + // byte of request. + conn.SetDeadline(time.Now().Add(handshakeTimeout)) + case http.StateActive: + // Clear read deadline. We should maybe set read timeouts more + // generally but that's a bigger task as some HTTP endpoints may + // stream large requests and responses (e.g. snapshot) so we can't + // set sensible blanket timeouts here. + conn.SetDeadline(time.Time{}) + } + } +} + // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used by NewHttpServer so // dead TCP connections eventually go away. diff --git a/command/agent/http_test.go b/command/agent/http_test.go index ab26ed4f9..e0cdc46ae 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -15,6 +15,9 @@ import ( "testing" "time" + "github.com/hashicorp/nomad/api" + "github.com/hashicorp/nomad/helper" + "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" @@ -733,6 +736,335 @@ func TestHTTP_VerifyHTTPSClient_AfterConfigReload(t *testing.T) { } } +// TestHTTPServer_Limits_Error asserts invalid Limits cause errors. This is the +// HTTP counterpart to TestAgent_ServerConfig_Limits_Error. +func TestHTTPServer_Limits_Error(t *testing.T) { + t.Parallel() + + cases := []struct { + tls bool + timeout string + limit *int + expectedErr string + }{ + { + tls: true, + timeout: "", + limit: nil, + expectedErr: "error parsing https_handshake_timeout: ", + }, + { + tls: false, + timeout: "", + limit: nil, + expectedErr: "error parsing https_handshake_timeout: ", + }, + { + tls: true, + timeout: "-1s", + limit: nil, + expectedErr: "https_handshake_timeout must be >= 0", + }, + { + tls: false, + timeout: "-1s", + limit: nil, + expectedErr: "https_handshake_timeout must be >= 0", + }, + { + tls: true, + timeout: "5s", + limit: helper.IntToPtr(-1), + expectedErr: "http_max_conns_per_client must be >= 0", + }, + { + tls: false, + timeout: "5s", + limit: helper.IntToPtr(-1), + expectedErr: "http_max_conns_per_client must be >= 0", + }, + } + + for i := range cases { + tc := cases[i] + name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit) + t.Run(name, func(t *testing.T) { + t.Parallel() + + // Use a fake agent since the HTTP server should never start + agent := &Agent{ + logger: testlog.HCLogger(t), + } + + conf := &Config{ + normalizedAddrs: &Addresses{ + HTTP: "localhost:0", // port is never used + }, + TLSConfig: &config.TLSConfig{ + EnableHTTP: tc.tls, + }, + Limits: config.Limits{ + HTTPSHandshakeTimeout: tc.timeout, + HTTPMaxConnsPerClient: tc.limit, + }, + } + + srv, err := NewHTTPServer(agent, conf) + require.Error(t, err) + require.Nil(t, srv) + require.Contains(t, err.Error(), tc.expectedErr) + }) + } +} + +// TestHTTPServer_Limits_OK asserts that all valid limits combinations +// (tls/timeout/conns) work. 
+func TestHTTPServer_Limits_OK(t *testing.T) { + t.Parallel() + const ( + cafile = "../../helper/tlsutil/testdata/ca.pem" + foocert = "../../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../../helper/tlsutil/testdata/nomad-foo-key.pem" + maxConns = 10 // limit must be < this for testing + ) + + cases := []struct { + tls bool + timeout string + limit *int + assertTimeout bool + assertLimit bool + }{ + { + tls: false, + timeout: "5s", + limit: nil, + assertTimeout: false, + assertLimit: false, + }, + { + tls: true, + timeout: "5s", + limit: nil, + assertTimeout: true, + assertLimit: false, + }, + { + tls: false, + timeout: "0", + limit: nil, + assertTimeout: false, + assertLimit: false, + }, + { + tls: true, + timeout: "0", + limit: nil, + assertTimeout: false, + assertLimit: false, + }, + { + tls: false, + timeout: "0", + limit: helper.IntToPtr(2), + assertTimeout: false, + assertLimit: true, + }, + { + tls: true, + timeout: "0", + limit: helper.IntToPtr(2), + assertTimeout: false, + assertLimit: true, + }, + { + tls: false, + timeout: "5s", + limit: helper.IntToPtr(2), + assertTimeout: false, + assertLimit: true, + }, + { + tls: true, + timeout: "5s", + limit: helper.IntToPtr(2), + assertTimeout: true, + assertLimit: true, + }, + } + + assertTimeout := func(t *testing.T, a *TestAgent, assertTimeout bool, timeout string) { + timeoutDeadline, err := time.ParseDuration(timeout) + require.NoError(t, err) + + // Increase deadline to detect timeouts + deadline := timeoutDeadline + time.Second + + conn, err := net.DialTimeout("tcp", a.Server.Addr, deadline) + require.NoError(t, err) + defer conn.Close() + + buf := []byte{0} + readDeadline := time.Now().Add(deadline) + conn.SetReadDeadline(readDeadline) + n, err := conn.Read(buf) + require.Zero(t, n) + if assertTimeout { + // Server timeouts == EOF + require.Equal(t, io.EOF, err) + + // Perform blocking query to assert timeout is not + // enabled post-TLS-handshake. + q := &api.QueryOptions{ + WaitIndex: 10000, // wait a looong time + WaitTime: deadline, + } + + // Assertions don't require certificate validation + conf := api.DefaultConfig() + conf.Address = a.HTTPAddr() + conf.TLSConfig.Insecure = true + client, err := api.NewClient(conf) + require.NoError(t, err) + + // Assert a blocking query isn't timed out by the + // handshake timeout + jobs, meta, err := client.Jobs().List(q) + require.NoError(t, err) + require.Len(t, jobs, 0) + require.Truef(t, meta.RequestTime >= deadline, + "expected RequestTime (%s) >= Deadline (%s)", + meta.RequestTime, deadline) + + return + } + + // HTTP Server should *not* have timed out. + // Now() should always be after the read deadline, but + // isn't a sufficient assertion for correctness as slow + // tests may cause this to be true even if the server + // timed out. 
+ require.True(t, time.Now().After(readDeadline)) + + testutil.RequireDeadlineErr(t, err) + } + + assertNoLimit := func(t *testing.T, addr string) { + var err error + + // Create max connections + conns := make([]net.Conn, maxConns) + errCh := make(chan error, maxConns) + for i := 0; i < maxConns; i++ { + conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second) + require.NoError(t, err) + defer conns[i].Close() + + go func(i int) { + buf := []byte{0} + readDeadline := time.Now().Add(1 * time.Second) + conns[i].SetReadDeadline(readDeadline) + n, err := conns[i].Read(buf) + if n > 0 { + errCh <- fmt.Errorf("n > 0: %d", n) + return + } + errCh <- err + }(i) + } + + // Now assert each error is a clientside read deadline error + for i := 0; i < maxConns; i++ { + select { + case <-time.After(1 * time.Second): + t.Fatalf("timed out waiting for conn error %d", i) + case err := <-errCh: + testutil.RequireDeadlineErr(t, err) + } + } + } + + assertLimit := func(t *testing.T, addr string, limit int) { + var err error + + // Create limit connections + conns := make([]net.Conn, limit) + errCh := make(chan error, limit) + for i := range conns { + conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second) + require.NoError(t, err) + defer conns[i].Close() + + go func(i int) { + buf := []byte{0} + n, err := conns[i].Read(buf) + if n > 0 { + errCh <- fmt.Errorf("n > 0: %d", n) + return + } + errCh <- err + }(i) + } + + // Assert a new connection is dropped + conn, err := net.DialTimeout("tcp", addr, 1*time.Second) + require.NoError(t, err) + defer conn.Close() + + buf := []byte{0} + deadline := time.Now().Add(10 * time.Second) + conn.SetReadDeadline(deadline) + n, err := conn.Read(buf) + require.Zero(t, n) + require.Equal(t, io.EOF, err) + + // Assert existing connections are ok + require.Len(t, errCh, 0) + + // Cleanup + for _, conn := range conns { + conn.Close() + } + for range conns { + err := <-errCh + require.Contains(t, err.Error(), "use of closed network connection") + } + } + + for i := range cases { + tc := cases[i] + name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit) + t.Run(name, func(t *testing.T) { + t.Parallel() + + if tc.limit != nil && *tc.limit >= maxConns { + t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", *tc.limit, maxConns) + } + + s := makeHTTPServer(t, func(c *Config) { + if tc.tls { + c.TLSConfig = &config.TLSConfig{ + EnableHTTP: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + } + c.Limits.HTTPSHandshakeTimeout = tc.timeout + c.Limits.HTTPMaxConnsPerClient = tc.limit + }) + defer s.Shutdown() + + assertTimeout(t, s, tc.assertTimeout, tc.timeout) + if tc.assertLimit { + assertLimit(t, s.Server.Addr, *tc.limit) + } else { + assertNoLimit(t, s.Server.Addr) + } + }) + } +} + func httpTest(t testing.TB, cb func(c *Config), f func(srv *TestAgent)) { s := makeHTTPServer(t, cb) defer s.Shutdown() diff --git a/command/agent/testagent.go b/command/agent/testagent.go index 4f5b9ffe9..71d40e9e2 100644 --- a/command/agent/testagent.go +++ b/command/agent/testagent.go @@ -9,7 +9,6 @@ import ( "net/http/httptest" "os" "path/filepath" - "runtime" "strings" "time" @@ -142,28 +141,27 @@ RETRY: } // we need the err var in the next exit condition - if agent, err := a.start(); err == nil { + agent, err := a.start() + if err == nil { a.Agent = agent break } else if i == 0 { - a.T.Logf("%s: Error starting agent: %v", a.Name, err) - runtime.Goexit() - } else { - if agent != nil { - agent.Shutdown() - } - wait := 
time.Duration(rand.Int31n(2000)) * time.Millisecond
-			a.T.Logf("%s: retrying in %v", a.Name, wait)
-			time.Sleep(wait)
+			a.T.Fatalf("%s: Error starting agent: %v", a.Name, err)
 		}
+		if agent != nil {
+			agent.Shutdown()
+		}
+		wait := time.Duration(rand.Int31n(2000)) * time.Millisecond
+		a.T.Logf("%s: retrying in %v", a.Name, wait)
+		time.Sleep(wait)
+
 		// Clean out the data dir if we are responsible for it before we
 		// try again, since the old ports may have gotten written to
 		// the data dir, such as in the Raft configuration.
 		if a.DataDir != "" {
 			if err := os.RemoveAll(a.DataDir); err != nil {
-				a.T.Logf("%s: Error resetting data dir: %v", a.Name, err)
-				runtime.Goexit()
+				a.T.Fatalf("%s: Error resetting data dir: %v", a.Name, err)
 			}
 		}
 	}
@@ -273,7 +271,11 @@ func (a *TestAgent) HTTPAddr() string {
 	if a.Server == nil {
 		return ""
 	}
-	return "http://" + a.Server.Addr
+	proto := "http://"
+	if a.Config.TLSConfig != nil && a.Config.TLSConfig.EnableHTTP {
+		proto = "https://"
+	}
+	return proto + a.Server.Addr
 }
 
 func (a *TestAgent) Client() *api.Client {
diff --git a/nomad/config.go b/nomad/config.go
index 1460b17fa..f706a3277 100644
--- a/nomad/config.go
+++ b/nomad/config.go
@@ -316,6 +316,20 @@ type Config struct {
 	// PluginSingletonLoader is a plugin loader that will returns singleton
 	// instances of the plugins.
 	PluginSingletonLoader loader.PluginCatalog
+
+	// RPCHandshakeTimeout is the deadline by which RPC handshakes must
+	// complete. The RPC handshake includes the first byte read as well as
+	// the TLS handshake and subsequent byte read if TLS is enabled.
+	//
+	// The deadline is reset after the first byte is read so when TLS is
+	// enabled RPC connections may take (timeout * 2) to complete.
+	//
+	// 0 means no timeout.
+	RPCHandshakeTimeout time.Duration
+
+	// RPCMaxConnsPerClient is the maximum number of concurrent RPC
+	// connections from a single IP address. 0 means no limit.
+	RPCMaxConnsPerClient int
 }
 
 // CheckVersion is used to check if the ProtocolVersion is valid
@@ -330,7 +344,8 @@ func (c *Config) CheckVersion() error {
 	return nil
 }
 
-// DefaultConfig returns the default configuration
+// DefaultConfig returns the default configuration. Only used as the basis for
+// merging agent or test parameters.
 func DefaultConfig() *Config {
 	hostname, err := os.Hostname()
 	if err != nil {
diff --git a/nomad/rpc.go b/nomad/rpc.go
index 228157d4c..f3986bc25 100644
--- a/nomad/rpc.go
+++ b/nomad/rpc.go
@@ -16,6 +16,7 @@ import (
 	golog "log"
 
 	metrics "github.com/armon/go-metrics"
+	"github.com/hashicorp/go-connlimit"
 	log "github.com/hashicorp/go-hclog"
 	memdb "github.com/hashicorp/go-memdb"
 
@@ -23,6 +24,7 @@ import (
 	"github.com/hashicorp/nomad/helper/pool"
 	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
+	"github.com/hashicorp/nomad/nomad/structs/config"
 	"github.com/hashicorp/raft"
 	"github.com/hashicorp/yamux"
 	"github.com/ugorji/go/codec"
@@ -49,17 +51,48 @@ const (
 
 type rpcHandler struct {
 	*Server
+
+	// connLimiter is used to limit the number of RPC connections per
+	// remote address. It is distinct from the HTTP connection limit.
+	//
+	// nil if limiting is disabled
+	connLimiter *connlimit.Limiter
+	connLimit   int
+
+	// streamLimiter is used to limit the number of *streaming* RPC
+	// connections per remote address. It is lower than the overall
+	// connection limit to ensure there are free connections for Raft and
+	// other RPCs.
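+	//
+	// For example, with the default rpc_max_conns_per_client of 100 the
+	// streaming limit comes out to 100 - 20 = 80 conns per client IP (see
+	// config.LimitsNonStreamingConnsPerClient in newRpcHandler below).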
+ streamLimiter *connlimit.Limiter + streamLimit int + logger log.Logger gologger *golog.Logger } func newRpcHandler(s *Server) *rpcHandler { logger := s.logger.NamedIntercept("rpc") - return &rpcHandler{ - Server: s, - logger: logger, - gologger: logger.StandardLoggerIntercept(&log.StandardLoggerOptions{InferLevels: true}), + + r := rpcHandler{ + Server: s, + connLimit: s.config.RPCMaxConnsPerClient, + logger: logger, + gologger: logger.StandardLoggerIntercept(&log.StandardLoggerOptions{InferLevels: true}), } + + // Setup connection limits + if r.connLimit > 0 { + r.connLimiter = connlimit.NewLimiter(connlimit.Config{ + MaxConnsPerClientIP: r.connLimit, + }) + + r.streamLimit = r.connLimit - config.LimitsNonStreamingConnsPerClient + r.streamLimiter = connlimit.NewLimiter(connlimit.Config{ + MaxConnsPerClientIP: r.streamLimit, + }) + } + + return &r } // RPCContext provides metadata about the RPC connection. @@ -106,6 +139,24 @@ func (r *rpcHandler) listen(ctx context.Context) { // No error, reset loop delay acceptLoopDelay = 0 + // Apply per-connection limits (if enabled) *prior* to launching + // goroutine to block further Accept()s until limits are checked. + if r.connLimiter != nil { + free, err := r.connLimiter.Accept(conn) + if err != nil { + r.logger.Error("rejecting client for exceeding maximum RPC connections", + "remote_addr", conn.RemoteAddr(), "limit", r.connLimit) + conn.Close() + continue + } + + // Wrap the connection so that conn.Close calls free() as well. + // This is required for libraries like raft which handoff the + // net.Conn to another goroutine and therefore can't be tracked + // within this func. + conn = connlimit.Wrap(conn, free) + } + go r.handleConn(ctx, conn, &RPCContext{Conn: conn}) metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1) } @@ -145,7 +196,16 @@ func (r *rpcHandler) handleAcceptErr(ctx context.Context, err error, loopDelay * // handleConn is used to determine if this is a Raft or // Nomad type RPC connection and invoke the correct handler +// +// **Cannot** use defer conn.Close in this method because the Raft handler uses +// the conn beyond the scope of this func. func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) { + // Limit how long an unauthenticated client can hold the connection + // open before they send the magic byte. + if !rpcCtx.TLS && r.config.RPCHandshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(r.config.RPCHandshakeTimeout)) + } + // Read a single byte buf := make([]byte, 1) if _, err := conn.Read(buf); err != nil { @@ -156,6 +216,12 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC return } + // Reset the deadline as we aren't sure what is expected next - it depends on + // the protocol. + if !rpcCtx.TLS && r.config.RPCHandshakeTimeout > 0 { + conn.SetDeadline(time.Time{}) + } + // Enforce TLS if EnableRPC is set if r.config.TLSConfig.EnableRPC && !rpcCtx.TLS && pool.RPCType(buf[0]) != pool.RpcTLS { if !r.config.TLSConfig.RPCUpgradeMode { @@ -190,6 +256,14 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC conn.Close() return } + + // Don't allow malicious client to create TLS-in-TLS forever. 
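+	// rpcCtx.TLS is only set below after a TLS handshake has completed, so
+	// receiving another RpcTLS byte on an already-TLS connection means the
+	// client is trying to nest a second TLS session inside the first.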
+ if rpcCtx.TLS { + r.logger.Error("TLS connection attempting to establish inner TLS connection", "remote_addr", conn.RemoteAddr()) + conn.Close() + return + } + conn = tls.Server(conn, r.rpcTLS) // Force a handshake so we can get information about the TLS connection @@ -201,12 +275,24 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC return } + // Enforce handshake timeout during TLS handshake to prevent + // unauthenticated users from holding connections open + // indefinitely. + if r.config.RPCHandshakeTimeout > 0 { + tlsConn.SetDeadline(time.Now().Add(r.config.RPCHandshakeTimeout)) + } + if err := tlsConn.Handshake(); err != nil { r.logger.Warn("failed TLS handshake", "remote_addr", tlsConn.RemoteAddr(), "error", err) conn.Close() return } + // Reset the deadline as unauthenticated users have now been rejected. + if r.config.RPCHandshakeTimeout > 0 { + tlsConn.SetDeadline(time.Time{}) + } + // Update the connection context with the fact that the connection is // using TLS rpcCtx.TLS = true @@ -218,6 +304,20 @@ func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCC r.handleConn(ctx, conn, rpcCtx) case pool.RpcStreaming: + // Apply a lower limit to streaming RPCs to avoid denial of + // service by repeatedly starting streaming RPCs. + // + // TODO Remove once MultiplexV2 is used. + if r.streamLimiter != nil { + free, err := r.streamLimiter.Accept(conn) + if err != nil { + r.logger.Error("rejecting client for exceeding maximum streaming RPC connections", + "remote_addr", conn.RemoteAddr(), "stream_limit", r.streamLimit) + conn.Close() + return + } + defer free() + } r.handleStreamingConn(conn) case pool.RpcMultiplexV2: diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index e32fba14a..647ef53c7 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -2,6 +2,9 @@ package nomad import ( "context" + "crypto/tls" + "fmt" + "io" "net" "net/rpc" "os" @@ -13,6 +16,7 @@ import ( cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/helper/pool" "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/helper/tlsutil" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" @@ -532,3 +536,448 @@ func TestRPC_handleMultiplexV2(t *testing.T) { require.True(structs.IsErrUnknownMethod(err)) } + +// TestRPC_TLS_in_TLS asserts that trying to nest TLS connections fails. 
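+// It drives the rpcHandler.handleConn guard directly: after one successful
+// TLS handshake the client sends a second RpcTLS byte and expects the server
+// to drop the connection (io.EOF) rather than start another handshake.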
+func TestRPC_TLS_in_TLS(t *testing.T) { + t.Parallel() + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + ) + + s, cleanup := TestServer(t, func(c *Config) { + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + }) + defer func() { + cleanup() + + //TODO Avoid panics from logging during shutdown + time.Sleep(1 * time.Second) + }() + + conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{byte(pool.RpcTLS)}) + require.NoError(t, err) + + // Client TLS verification isn't necessary for + // our assertions + tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true) + require.NoError(t, err) + outTLSConf, err := tlsConf.OutgoingTLSConfig() + require.NoError(t, err) + outTLSConf.InsecureSkipVerify = true + + // Do initial handshake + tlsConn := tls.Client(conn, outTLSConf) + require.NoError(t, tlsConn.Handshake()) + conn = tlsConn + + // Try to create a nested TLS connection + _, err = conn.Write([]byte{byte(pool.RpcTLS)}) + require.NoError(t, err) + + // Attempts at nested TLS connections should cause a disconnect + buf := []byte{0} + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, err := conn.Read(buf) + require.Zero(t, n) + require.Equal(t, io.EOF, err) +} + +// TestRPC_Limits_OK asserts that all valid limits combinations +// (tls/timeout/conns) work. +// +// Invalid limits are tested in command/agent/agent_test.go +func TestRPC_Limits_OK(t *testing.T) { + t.Parallel() + + const ( + cafile = "../helper/tlsutil/testdata/ca.pem" + foocert = "../helper/tlsutil/testdata/nomad-foo.pem" + fookey = "../helper/tlsutil/testdata/nomad-foo-key.pem" + maxConns = 10 // limit must be < this for testing + ) + + cases := []struct { + tls bool + timeout time.Duration + limit int + assertTimeout bool + assertLimit bool + }{ + { + tls: false, + timeout: 5 * time.Second, + limit: 0, + assertTimeout: true, + assertLimit: false, + }, + { + tls: true, + timeout: 5 * time.Second, + limit: 0, + assertTimeout: true, + assertLimit: false, + }, + { + tls: false, + timeout: 0, + limit: 0, + assertTimeout: false, + assertLimit: false, + }, + { + tls: true, + timeout: 0, + limit: 0, + assertTimeout: false, + assertLimit: false, + }, + { + tls: false, + timeout: 0, + limit: 2, + assertTimeout: false, + assertLimit: true, + }, + { + tls: true, + timeout: 0, + limit: 2, + assertTimeout: false, + assertLimit: true, + }, + { + tls: false, + timeout: 5 * time.Second, + limit: 2, + assertTimeout: true, + assertLimit: true, + }, + { + tls: true, + timeout: 5 * time.Second, + limit: 2, + assertTimeout: true, + assertLimit: true, + }, + } + + assertTimeout := func(t *testing.T, s *Server, useTLS bool, timeout time.Duration) { + // Increase timeout to detect timeouts + clientTimeout := timeout + time.Second + + conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second) + require.NoError(t, err) + defer conn.Close() + + buf := []byte{0} + readDeadline := time.Now().Add(clientTimeout) + conn.SetReadDeadline(readDeadline) + n, err := conn.Read(buf) + require.Zero(t, n) + if timeout == 0 { + // Server should *not* have timed out. 
+ // Now() should always be after the client read deadline, but + // isn't a sufficient assertion for correctness as slow tests + // may cause this to be true even if the server timed out. + now := time.Now() + require.Truef(t, now.After(readDeadline), + "Client read deadline (%s) should be in the past (before %s)", readDeadline, now) + + testutil.RequireDeadlineErr(t, err) + return + } + + // Server *should* have timed out (EOF) + require.Equal(t, io.EOF, err) + + // Create a new connection to assert timeout doesn't + // apply after first byte. + conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second) + require.NoError(t, err) + defer conn.Close() + + if useTLS { + _, err := conn.Write([]byte{byte(pool.RpcTLS)}) + require.NoError(t, err) + + // Client TLS verification isn't necessary for + // our assertions + tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true) + require.NoError(t, err) + outTLSConf, err := tlsConf.OutgoingTLSConfig() + require.NoError(t, err) + outTLSConf.InsecureSkipVerify = true + + tlsConn := tls.Client(conn, outTLSConf) + require.NoError(t, tlsConn.Handshake()) + + conn = tlsConn + } + + // Writing the Nomad RPC byte should be sufficient to + // disable the handshake timeout + n, err = conn.Write([]byte{byte(pool.RpcNomad)}) + require.NoError(t, err) + require.Equal(t, 1, n) + + // Read should timeout due to client timeout, not + // server's timeout + readDeadline = time.Now().Add(clientTimeout) + conn.SetReadDeadline(readDeadline) + n, err = conn.Read(buf) + require.Zero(t, n) + testutil.RequireDeadlineErr(t, err) + } + + assertNoLimit := func(t *testing.T, addr string) { + var err error + + // Create max connections + conns := make([]net.Conn, maxConns) + errCh := make(chan error, maxConns) + for i := 0; i < maxConns; i++ { + conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second) + require.NoError(t, err) + defer conns[i].Close() + + go func(i int) { + buf := []byte{0} + readDeadline := time.Now().Add(1 * time.Second) + conns[i].SetReadDeadline(readDeadline) + n, err := conns[i].Read(buf) + if n > 0 { + errCh <- fmt.Errorf("n > 0: %d", n) + return + } + errCh <- err + }(i) + } + + // Now assert each error is a clientside read deadline error + deadline := time.After(10 * time.Second) + for i := 0; i < maxConns; i++ { + select { + case <-deadline: + t.Fatalf("timed out waiting for conn error %d/%d", i+1, maxConns) + case err := <-errCh: + testutil.RequireDeadlineErr(t, err) + } + } + } + + assertLimit := func(t *testing.T, addr string, limit int) { + var err error + + // Create limit connections + conns := make([]net.Conn, limit) + errCh := make(chan error, limit) + for i := range conns { + conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second) + require.NoError(t, err) + defer conns[i].Close() + + go func(i int) { + buf := []byte{0} + n, err := conns[i].Read(buf) + if n > 0 { + errCh <- fmt.Errorf("n > 0: %d", n) + return + } + errCh <- err + }(i) + } + + // Assert a new connection is dropped + conn, err := net.DialTimeout("tcp", addr, 1*time.Second) + require.NoError(t, err) + defer conn.Close() + + buf := []byte{0} + deadline := time.Now().Add(10 * time.Second) + conn.SetReadDeadline(deadline) + n, err := conn.Read(buf) + require.Zero(t, n) + require.Equal(t, io.EOF, err) + + // Assert existing connections are ok + ERRCHECK: + select { + case err := <-errCh: + t.Errorf("unexpected error from idle connection: (%T) %v", err, err) + goto ERRCHECK + default: + } + + // Cleanup + for _, conn := range conns { + 
conn.Close() + } + for range conns { + err := <-errCh + require.Contains(t, err.Error(), "use of closed network connection") + } + } + + for i := range cases { + tc := cases[i] + name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit) + t.Run(name, func(t *testing.T) { + t.Parallel() + + if tc.limit >= maxConns { + t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", tc.limit, maxConns) + } + if tc.assertTimeout && tc.timeout == 0 { + t.Fatalf("test fixture failure: cannot assert timeout when no timeout set (0)") + } + + s, cleanup := TestServer(t, func(c *Config) { + if tc.tls { + c.TLSConfig = &config.TLSConfig{ + EnableRPC: true, + CAFile: cafile, + CertFile: foocert, + KeyFile: fookey, + } + } + c.RPCHandshakeTimeout = tc.timeout + c.RPCMaxConnsPerClient = tc.limit + }) + defer func() { + cleanup() + + //TODO Avoid panics from logging during shutdown + time.Sleep(1 * time.Second) + }() + + assertTimeout(t, s, tc.tls, tc.timeout) + if tc.assertLimit { + assertLimit(t, s.config.RPCAddr.String(), tc.limit) + } else { + assertNoLimit(t, s.config.RPCAddr.String()) + } + }) + } +} + +// TestRPC_Limits_Streaming asserts that the streaming RPC limit is lower than +// the overall connection limit to prevent DOS via server-routed streaming API +// calls. +func TestRPC_Limits_Streaming(t *testing.T) { + t.Parallel() + + s, cleanup := TestServer(t, func(c *Config) { + limits := config.DefaultLimits() + c.RPCMaxConnsPerClient = *limits.RPCMaxConnsPerClient + }) + defer func() { + cleanup() + + //TODO Avoid panics from logging during shutdown + time.Sleep(1 * time.Second) + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error, 1) + + // Create a streaming connection + dialStreamer := func() net.Conn { + conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second) + require.NoError(t, err) + + _, err = conn.Write([]byte{byte(pool.RpcStreaming)}) + require.NoError(t, err) + return conn + } + + // Create up to the limit streaming connections + streamers := make([]net.Conn, s.config.RPCMaxConnsPerClient-config.LimitsNonStreamingConnsPerClient) + for i := range streamers { + streamers[i] = dialStreamer() + + go func(i int) { + // Streamer should never die until test exits + buf := []byte{0} + _, err := streamers[i].Read(buf) + if ctx.Err() != nil { + // Error is expected when test finishes + return + } + + t.Logf("connection %d died with error: (%T) %v", i, err, err) + + // Send unexpected errors back + if err != nil { + select { + case errCh <- err: + case <-ctx.Done(): + default: + // Only send first error + } + } + }(i) + } + + defer func() { + cancel() + for _, conn := range streamers { + conn.Close() + } + }() + + // Assert no streamer errors have occurred + select { + case err := <-errCh: + t.Fatalf("unexpected error from blocking streaming RPCs: (%T) %v", err, err) + case <-time.After(500 * time.Millisecond): + // Ok! No connections were rejected immediately. 
+ } + + // Assert subsequent streaming RPC are rejected + conn := dialStreamer() + t.Logf("expect connection to be rejected due to limit") + buf := []byte{0} + conn.SetReadDeadline(time.Now().Add(3 * time.Second)) + _, err := conn.Read(buf) + require.Equalf(t, io.EOF, err, "expected io.EOF but found: (%T) %v", err, err) + + // Assert no streamer errors have occurred + select { + case err := <-errCh: + t.Fatalf("unexpected error from blocking streaming RPCs: %v", err) + default: + } + + // Subsequent non-streaming RPC should be OK + conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second) + require.NoError(t, err) + _, err = conn.Write([]byte{byte(pool.RpcNomad)}) + require.NoError(t, err) + + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + _, err = conn.Read(buf) + testutil.RequireDeadlineErr(t, err) + + // Close 1 streamer and assert another is allowed + t.Logf("expect streaming connection 0 to exit with error") + streamers[0].Close() + <-errCh + conn = dialStreamer() + + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + _, err = conn.Read(buf) + testutil.RequireDeadlineErr(t, err) +} diff --git a/nomad/structs/config/limits.go b/nomad/structs/config/limits.go new file mode 100644 index 000000000..5c17bc99e --- /dev/null +++ b/nomad/structs/config/limits.go @@ -0,0 +1,87 @@ +package config + +import "github.com/hashicorp/nomad/helper" + +const ( + // LimitsNonStreamingConnsPerClient is the number of connections per + // peer to reserve for non-streaming RPC connections. Since streaming + // RPCs require their own TCP connection, they have their own limit + // this amount lower than the overall limit. This reserves a number of + // connections for Raft and other RPCs. + // + // TODO Remove limit once MultiplexV2 is used. + LimitsNonStreamingConnsPerClient = 20 +) + +// Limits configures timeout limits similar to Consul's limits configuration +// parameters. Limits is the internal version with the fields parsed. +type Limits struct { + // HTTPSHandshakeTimeout is the deadline by which HTTPS TLS handshakes + // must complete. + // + // 0 means no timeout. + HTTPSHandshakeTimeout string `hcl:"https_handshake_timeout"` + + // HTTPMaxConnsPerClient is the maximum number of concurrent HTTP + // connections from a single IP address. nil/0 means no limit. + HTTPMaxConnsPerClient *int `hcl:"http_max_conns_per_client"` + + // RPCHandshakeTimeout is the deadline by which RPC handshakes must + // complete. The RPC handshake includes the first byte read as well as + // the TLS handshake and subsequent byte read if TLS is enabled. + // + // The deadline is reset after the first byte is read so when TLS is + // enabled RPC connections may take (timeout * 2) to complete. + // + // The RPC handshake timeout only applies to servers. 0 means no + // timeout. + RPCHandshakeTimeout string `hcl:"rpc_handshake_timeout"` + + // RPCMaxConnsPerClient is the maximum number of concurrent RPC + // connections from a single IP address. nil/0 means no limit. + RPCMaxConnsPerClient *int `hcl:"rpc_max_conns_per_client"` +} + +// DefaultLimits returns the default limits values. User settings should be +// merged into these defaults. +func DefaultLimits() Limits { + return Limits{ + HTTPSHandshakeTimeout: "5s", + HTTPMaxConnsPerClient: helper.IntToPtr(100), + RPCHandshakeTimeout: "5s", + RPCMaxConnsPerClient: helper.IntToPtr(100), + } +} + +// Merge returns a new Limits where non-empty/nil fields in the argument have +// precedence. 
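+//
+// An illustrative sketch of the precedence rules (values from DefaultLimits):
+//
+//	base := DefaultLimits()                              // RPCHandshakeTimeout "5s"
+//	m := base.Merge(Limits{RPCHandshakeTimeout: "10s"})  // only one field set
+//	// m.RPCHandshakeTimeout == "10s"; all other fields keep base's values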
+func (l *Limits) Merge(o Limits) Limits { + m := *l + + if o.HTTPSHandshakeTimeout != "" { + m.HTTPSHandshakeTimeout = o.HTTPSHandshakeTimeout + } + if o.HTTPMaxConnsPerClient != nil { + m.HTTPMaxConnsPerClient = helper.IntToPtr(*o.HTTPMaxConnsPerClient) + } + if o.RPCHandshakeTimeout != "" { + m.RPCHandshakeTimeout = o.RPCHandshakeTimeout + } + if o.RPCMaxConnsPerClient != nil { + m.RPCMaxConnsPerClient = helper.IntToPtr(*o.RPCMaxConnsPerClient) + } + + return m +} + +// Copy returns a new deep copy of a Limits struct. +func (l *Limits) Copy() Limits { + c := *l + if l.HTTPMaxConnsPerClient != nil { + c.HTTPMaxConnsPerClient = helper.IntToPtr(*l.HTTPMaxConnsPerClient) + } + if l.RPCMaxConnsPerClient != nil { + c.RPCMaxConnsPerClient = helper.IntToPtr(*l.RPCMaxConnsPerClient) + } + return c +} diff --git a/nomad/structs/config/limits_test.go b/nomad/structs/config/limits_test.go new file mode 100644 index 000000000..e4bd9d598 --- /dev/null +++ b/nomad/structs/config/limits_test.go @@ -0,0 +1,82 @@ +package config + +import ( + "testing" + "time" + + "github.com/hashicorp/nomad/helper" + "github.com/stretchr/testify/require" +) + +// TestLimits_Defaults asserts the default limits are valid. +func TestLimits_Defaults(t *testing.T) { + t.Parallel() + + l := DefaultLimits() + d, err := time.ParseDuration(l.HTTPSHandshakeTimeout) + require.NoError(t, err) + require.True(t, d > 0) + + d, err = time.ParseDuration(l.RPCHandshakeTimeout) + require.NoError(t, err) + require.True(t, d > 0) +} + +// TestLimits_Copy asserts Limits structs are deep copied. +func TestLimits_Copy(t *testing.T) { + t.Parallel() + + o := DefaultLimits() + c := o.Copy() + + // Assert changes to copy are not propagated to the original + c.HTTPSHandshakeTimeout = "1s" + c.HTTPMaxConnsPerClient = helper.IntToPtr(50) + c.RPCHandshakeTimeout = "1s" + c.RPCMaxConnsPerClient = helper.IntToPtr(50) + + require.NotEqual(t, c.HTTPSHandshakeTimeout, o.HTTPSHandshakeTimeout) + + // Pointers should be different + require.True(t, c.HTTPMaxConnsPerClient != o.HTTPMaxConnsPerClient) + + require.NotEqual(t, c.HTTPMaxConnsPerClient, o.HTTPMaxConnsPerClient) + require.NotEqual(t, c.RPCHandshakeTimeout, o.RPCHandshakeTimeout) + + // Pointers should be different + require.True(t, c.RPCMaxConnsPerClient != o.RPCMaxConnsPerClient) + + require.NotEqual(t, c.RPCMaxConnsPerClient, o.RPCMaxConnsPerClient) +} + +// TestLimits_Merge asserts non-zero fields from the method argument take +// precedence over the existing limits. 
+func TestLimits_Merge(t *testing.T) {
+	t.Parallel()
+
+	l := Limits{}
+	o := DefaultLimits()
+	m := l.Merge(o)
+
+	// Operands should not change
+	require.Equal(t, Limits{}, l)
+	require.Equal(t, DefaultLimits(), o)
+
+	// m == o
+	require.Equal(t, m, DefaultLimits())
+
+	o.HTTPSHandshakeTimeout = "10s"
+	m2 := m.Merge(o)
+
+	// Operands should not change
+	require.Equal(t, m, DefaultLimits())
+
+	// Use short struct initialization style so it fails to compile if
+	// fields are added
+	expected := Limits{"10s", helper.IntToPtr(100), "5s", helper.IntToPtr(100)}
+	require.Equal(t, expected, m2)
+
+	// Merging in 0 values should not change anything
+	m3 := m2.Merge(Limits{})
+	require.Equal(t, m2, m3)
+}
diff --git a/testutil/net.go b/testutil/net.go
new file mode 100644
index 000000000..bd8813305
--- /dev/null
+++ b/testutil/net.go
@@ -0,0 +1,22 @@
+package testutil
+
+import (
+	"net"
+
+	testing "github.com/mitchellh/go-testing-interface"
+	"github.com/stretchr/testify/require"
+)
+
+// RequireDeadlineErr requires that an error be caused by a net.Conn's deadline
+// being reached (after being set by conn.Set{Read,Write}Deadline or
+// SetDeadline).
+func RequireDeadlineErr(t testing.T, err error) {
+	t.Helper()
+
+	require.NotNil(t, err)
+	netErr, ok := err.(net.Error)
+	require.Truef(t, ok, "error does not implement net.Error: (%T) %v", err, err)
+	require.Contains(t, netErr.Error(), ": i/o timeout")
+	require.True(t, netErr.Timeout())
+	require.True(t, netErr.Temporary())
+}
diff --git a/vendor/github.com/hashicorp/go-connlimit/README.md b/vendor/github.com/hashicorp/go-connlimit/README.md
new file mode 100644
index 000000000..99d2d95b0
--- /dev/null
+++ b/vendor/github.com/hashicorp/go-connlimit/README.md
@@ -0,0 +1,77 @@
+# Go Server Client Connection Tracking
+
+This package provides a library for network servers to track how many
+concurrent connections they have from a given client address.
+
+It's designed to be very simple and shared between several HashiCorp products
+that provide network servers and need this kind of control to impose limits on
+the resources that can be consumed by a single client.
+
+## Usage
+
+### TCP Server
+
+```
+// During server setup:
+s.limiter = NewLimiter(Config{
+  MaxConnsPerClientIP: 10,
+})
+
+```
+
+```
+// handleConn is called in its own goroutine for each net.Conn accepted by
+// a net.Listener.
+func (s *Server) handleConn(conn net.Conn) {
+  defer conn.Close()
+
+  // Track the connection
+  free, err := s.limiter.Accept(conn)
+  if err != nil {
+    // Not accepted as limit has been reached (or some other error), log error
+    // or warning and close.
+
+    // The standard err.Error() message when limit is reached is generic so it
+    // doesn't leak information which may potentially be sensitive (e.g. current
+    // limits set or number of connections). This also allows comparison to
+    // ErrPerClientIPLimitReached if it's important to handle it differently
+    // from an internal library or io error (currently not possible but might be
+    // in the future if additional functionality is added).
+
+    // If you would like to log more information, the current limits can be
+    // obtained with s.limiter.Config().
+    return
+  }
+  // Defer a call to free to decrement the counter for this client IP once we
+  // are done with this conn.
+  defer free()
+
+
+  // Handle the conn
+}
+```
+
+### HTTP Server
+
+```
+lim := NewLimiter(Config{
+  MaxConnsPerClientIP: 10,
+})
+s := http.Server{
+  // Other config here
+  ConnState: lim.HTTPConnStateFunc(),
+}
+```
+
+### Dynamic Configuration
+
+The limiter supports dynamic reconfiguration. At any time, any goroutine may
+call `limiter.SetConfig(c Config)` which will atomically update the config. All
+subsequent calls to `Accept` will use the newly configured limits in their
+decisions and calls to `limiter.Config()` will return the new config.
+
+Note that if the limits are reduced that will only prevent further connections
+beyond the new limit - existing connections are not actively closed to meet the
+limit. In cases where this is critical it's often preferable to mitigate in a
+more focused way e.g. by adding an iptables rule that blocks all connections
+from one malicious client without affecting the whole server.
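+
+For example, a minimal sketch of tightening the limit at runtime, assuming
+`lim` is the `*Limiter` created with `NewLimiter` above:
+
+```
+// Only conns accepted after this call are checked against the new limit;
+// existing conns over it stay open until they close naturally.
+lim.SetConfig(Config{
+  MaxConnsPerClientIP: 5,
+})
+```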
diff --git a/vendor/github.com/hashicorp/go-connlimit/connlimit.go b/vendor/github.com/hashicorp/go-connlimit/connlimit.go
new file mode 100644
index 000000000..cefcb8dad
--- /dev/null
+++ b/vendor/github.com/hashicorp/go-connlimit/connlimit.go
@@ -0,0 +1,180 @@
+package connlimit
+
+import (
+	"errors"
+	"net"
+	"net/http"
+	"sync"
+	"sync/atomic"
+)
+
+var (
+	// ErrPerClientIPLimitReached is returned if accepting a new conn would exceed
+	// the per-client-ip limit set.
+	ErrPerClientIPLimitReached = errors.New("client connection limit reached")
+)
+
+// Limiter implements a simple limiter that tracks the number of connections
+// from each client IP. It may be used in its zero value although no limits
+// will be configured initially - they can be set later with SetConfig.
+type Limiter struct {
+	// cs stores the map of active connections by IP address. We store a set of
+	// conn pointers not just a counter because the http.Server.ConnState hook
+	// only gives us a connection object between calls, so we need to know if a
+	// closed conn is one that was previously accepted or one we've just closed
+	// in the ConnState hook because the client has hit its limit.
+	cs map[string]map[net.Conn]struct{}
+
+	// l protects access to cs
+	l sync.Mutex
+
+	// cfg is stored atomically to provide non-blocking reads via Config. This
+	// might be important if this is called regularly in a health or metrics
+	// endpoint and shouldn't block new connections being established.
+	cfg atomic.Value
+}
+
+// Config is the configuration for the limiter.
+type Config struct {
+	// MaxConnsPerClientIP limits how many concurrent connections are allowed from
+	// a given client IP. The IP is the one reported by the connection so cannot
+	// be relied upon if clients are connecting through multiple proxies or able
+	// to spoof their source IP address in some way. Similarly, multiple clients
+	// connected via a proxy or NAT gateway or similar will all be seen as coming
+	// from the same IP and so limited as one client.
+	MaxConnsPerClientIP int
+}
+
+// NewLimiter returns a limiter with the specified config.
+func NewLimiter(cfg Config) *Limiter {
+	l := &Limiter{}
+	l.SetConfig(cfg)
+	return l
+}
+
+// Accept is called as early as possible when handling a new conn. If the
+// connection should be accepted according to the Limiter's Config, it will
+// return a free func and nil error. The free func must be called when the
diff --git a/vendor/github.com/hashicorp/go-connlimit/connlimit.go b/vendor/github.com/hashicorp/go-connlimit/connlimit.go
new file mode 100644
index 000000000..cefcb8dad
--- /dev/null
+++ b/vendor/github.com/hashicorp/go-connlimit/connlimit.go
@@ -0,0 +1,180 @@
+package connlimit
+
+import (
+	"errors"
+	"net"
+	"net/http"
+	"sync"
+	"sync/atomic"
+)
+
+var (
+	// ErrPerClientIPLimitReached is returned if accepting a new conn would exceed
+	// the per-client-ip limit set.
+	ErrPerClientIPLimitReached = errors.New("client connection limit reached")
+)
+
+// Limiter implements a simple limiter that tracks the number of connections
+// from each client IP. It may be used in its zero value, although no limits
+// will be configured initially; they can be set later with SetConfig.
+type Limiter struct {
+	// cs stores the map of active connections by IP address. We store a set of
+	// conn pointers, not just a counter, because the http.Server.ConnState hook
+	// only gives us a connection object between calls, so we need to know
+	// whether a closed conn is one that was previously accepted or one we've
+	// just closed in the ConnState hook because the client has hit its limit.
+	cs map[string]map[net.Conn]struct{}
+
+	// l protects access to cs
+	l sync.Mutex
+
+	// cfg is stored atomically to provide non-blocking reads via Config. This
+	// might be important if this is called regularly in a health or metrics
+	// endpoint and shouldn't block new connections being established.
+	cfg atomic.Value
+}
+
+// Config is the configuration for the limiter.
+type Config struct {
+	// MaxConnsPerClientIP limits how many concurrent connections are allowed from
+	// a given client IP. The IP is the one reported by the connection, so it
+	// cannot be relied upon if clients are connecting through multiple proxies
+	// or are able to spoof their source IP address in some way. Similarly,
+	// multiple clients connected via a proxy or NAT gateway or similar will all
+	// be seen as coming from the same IP and so limited as one client.
+	MaxConnsPerClientIP int
+}
+
+// NewLimiter returns a limiter with the specified config.
+func NewLimiter(cfg Config) *Limiter {
+	l := &Limiter{}
+	l.SetConfig(cfg)
+	return l
+}
+
+// Accept is called as early as possible when handling a new conn. If the
+// connection should be accepted according to the Limiter's Config, it will
+// return a free func and nil error. The free func must be called when the
+// connection is no longer being handled, typically in a defer statement in
+// the main connection handling goroutine; it decrements the counter for that
+// client IP. If the configured limit has been reached,
+// ErrPerClientIPLimitReached is returned along with a no-op free func that
+// does not need to be called.
+//
+// If any other error is returned, it signifies something wrong with the
+// config or a transient failure to read or parse the remote IP. The free func
+// will be a no-op in this case and need not be called.
+func (l *Limiter) Accept(conn net.Conn) (func(), error) {
+	addrKey := addrKey(conn)
+
+	// Load the config outside the locked section since it's not updated under
+	// the lock anyway, and the atomic Load might be slower/contended, so it's
+	// better done outside the lock.
+	cfg := l.Config()
+
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	if l.cs == nil {
+		l.cs = make(map[string]map[net.Conn]struct{})
+	}
+
+	cs := l.cs[addrKey]
+	if cs == nil {
+		cs = make(map[net.Conn]struct{})
+		l.cs[addrKey] = cs
+	}
+
+	n := len(cs)
+
+	// Might be greater than the limit since the config is dynamic.
+	if cfg.MaxConnsPerClientIP > 0 && n >= cfg.MaxConnsPerClientIP {
+		return func() {}, ErrPerClientIPLimitReached
+	}
+
+	// Add the conn to the map
+	cs[conn] = struct{}{}
+
+	// Create a free func over the address key we used
+	free := func() {
+		l.freeConn(conn)
+	}
+
+	return free, nil
+}
+
+func addrKey(conn net.Conn) string {
+	addr := conn.RemoteAddr()
+	switch a := addr.(type) {
+	case *net.TCPAddr:
+		return "ip:" + a.IP.String()
+	case *net.UDPAddr:
+		return "ip:" + a.IP.String()
+	case *net.IPAddr:
+		return "ip:" + a.IP.String()
+	default:
+		// Not sure what to do with this; just assume the whole Addr is relevant.
+		return addr.Network() + "/" + addr.String()
+	}
+}
+
+// freeConn removes a connection from the map if it's present. It is a no-op if
+// the conn was never accepted by Accept.
+func (l *Limiter) freeConn(conn net.Conn) {
+	addrKey := addrKey(conn)
+
+	l.l.Lock()
+	defer l.l.Unlock()
+
+	cs, ok := l.cs[addrKey]
+	if !ok {
+		return
+	}
+
+	delete(cs, conn)
+	if len(cs) == 0 {
+		delete(l.cs, addrKey)
+	}
+}
+
+// Config returns the current limiter configuration. It is safe to call from
+// any goroutine and does not block new connections being accepted.
+func (l *Limiter) Config() Config {
+	cfgRaw := l.cfg.Load()
+	if cfg, ok := cfgRaw.(Config); ok {
+		return cfg
+	}
+	return Config{}
+}
+
+// SetConfig dynamically updates the limiter configuration. It is safe to call
+// from any goroutine. Note that if the limit is lowered, active conns will not
+// be closed and may remain over the limit until they close naturally.
+func (l *Limiter) SetConfig(c Config) {
+	l.cfg.Store(c)
+}
+
+// HTTPConnStateFunc returns a func that can be passed as the ConnState field
+// of an http.Server. This intercepts new HTTP connections to the server and
+// applies the limit to them.
+//
+// Note that if the conn is hijacked from the HTTP server then it will be freed
+// in the limiter as if it were closed. Servers that use hijacking must perform
+// their own tracking if they need to continue limiting the number of
+// concurrent hijacked connections.
+func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) {
+	return func(conn net.Conn, state http.ConnState) {
+		switch state {
+		case http.StateNew:
+			_, err := l.Accept(conn)
+			if err != nil {
+				conn.Close()
+			}
+		case http.StateHijacked:
+			l.freeConn(conn)
+		case http.StateClosed:
+			// Maybe free the conn. This might be a conn we closed in the case
+			// above that was never counted, as it was over the limit, but
+			// freeConn will be a no-op in that case.
+			l.freeConn(conn)
+		}
+	}
+}
diff --git a/vendor/github.com/hashicorp/go-connlimit/go.mod b/vendor/github.com/hashicorp/go-connlimit/go.mod
new file mode 100644
index 000000000..d01785e84
--- /dev/null
+++ b/vendor/github.com/hashicorp/go-connlimit/go.mod
@@ -0,0 +1,8 @@
+module github.com/hashicorp/go-connlimit
+
+go 1.12
+
+require (
+	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/stretchr/testify v1.4.0
+)
diff --git a/vendor/github.com/hashicorp/go-connlimit/go.sum b/vendor/github.com/hashicorp/go-connlimit/go.sum
new file mode 100644
index 000000000..3216266c6
--- /dev/null
+++ b/vendor/github.com/hashicorp/go-connlimit/go.sum
@@ -0,0 +1,13 @@
+github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
diff --git a/vendor/github.com/hashicorp/go-connlimit/wrap.go b/vendor/github.com/hashicorp/go-connlimit/wrap.go
new file mode 100644
index 000000000..b9922609b
--- /dev/null
+++ b/vendor/github.com/hashicorp/go-connlimit/wrap.go
@@ -0,0 +1,27 @@
+package connlimit
+
+import "net"
+
+// WrappedConn wraps a net.Conn and the free() func returned by Limiter.Accept
+// so that when the wrapped connection's Close method is called, its free func
+// is also called.
+type WrappedConn struct {
+	net.Conn
+	free func()
+}
+
+// Wrap wraps a net.Conn's Close method so free() is called when Close is
+// called. Useful when handing off tracked connections to libraries that close
+// them.
+func Wrap(conn net.Conn, free func()) net.Conn {
+	return &WrappedConn{
+		Conn: conn,
+		free: free,
+	}
+}
+
+// Close frees the tracked connection and closes the underlying net.Conn.
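+// Calling free more than once is safe with respect to the limiter: it invokes
+// freeConn, which is a no-op once the conn has been removed from tracking.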
+func (w *WrappedConn) Close() error {
+	w.free()
+	return w.Conn.Close()
+}
diff --git a/vendor/vendor.json b/vendor/vendor.json
index 6712feea4..2ba65dc21 100644
--- a/vendor/vendor.json
+++ b/vendor/vendor.json
@@ -216,6 +216,7 @@
 {"path":"github.com/hashicorp/errwrap","checksumSHA1":"cdOCt0Yb+hdErz8NAQqayxPmRsY=","revision":"7554cd9344cec97297fa6649b055a8c98c2a1e55"},
 {"path":"github.com/hashicorp/go-checkpoint","checksumSHA1":"D267IUMW2rcb+vNe3QU+xhfSrgY=","revision":"1545e56e46dec3bba264e41fde2c1e2aa65b5dd4","revisionTime":"2017-10-09T17:35:28Z"},
 {"path":"github.com/hashicorp/go-cleanhttp","checksumSHA1":"6ihdHMkDfFx/rJ1A36com2F6bQk=","revision":"a45970658e51fea2c41445ff0f7e07106d007617","revisionTime":"2017-02-11T00:33:01Z"},
+{"path":"github.com/hashicorp/go-connlimit","checksumSHA1":"hWFJgo9OJD+vwel31yaS1u7k3OU=","revision":"7b54d3380815c9b127c3d841df45951807b79ab8","revisionTime":"2020-01-28T15:55:23Z"},
 {"path":"github.com/hashicorp/go-discover","checksumSHA1":"3m3SRZczpDY+fSN7oEUqoPJSZMg=","revision":"7698de1390a18e1d38f55ad02d4cab8917b9219d","revisionTime":"2020-01-08T19:47:35Z"},
 {"path":"github.com/hashicorp/go-discover/provider/aliyun","checksumSHA1":"Jww5zrDwjMoFF31RqBapilTdi18=","revision":"7698de1390a18e1d38f55ad02d4cab8917b9219d","revisionTime":"2020-01-08T19:47:35Z","tree":true},
 {"path":"github.com/hashicorp/go-discover/provider/aws","checksumSHA1":"mSoObM5f8c2FJW/09mNqDrMqvpw=","revision":"7698de1390a18e1d38f55ad02d4cab8917b9219d","revisionTime":"2020-01-08T19:47:35Z","tree":true},
diff --git a/website/source/docs/configuration/index.html.md b/website/source/docs/configuration/index.html.md
index 51dca38bb..8176cf5f3 100644
--- a/website/source/docs/configuration/index.html.md
+++ b/website/source/docs/configuration/index.html.md
@@ -177,6 +177,57 @@ testing.
   server agents if it is expected that a terminated server instance will never
   join the cluster again.
 
+- `limits` - Available in Nomad 0.10.3 and later, this is a nested object that
+  configures limits enforced by the agent. The following parameters are
+  available:
+
+  - `https_handshake_timeout` `(string: "5s")` - Configures the limit for how
+    long the HTTPS server in both client and server agents will wait for a
+    client to complete a TLS handshake. This should be kept conservative, as it
+    limits how many connections an unauthenticated attacker can open if
+    [`tls.http = true`][tls] is being used (strongly recommended in
+    production). Default value is `5s`. `0` disables HTTP handshake timeouts.
+
+  - `http_max_conns_per_client` `(int: 100)` - Configures a limit on how many
+    concurrent TCP connections a single client IP address is allowed to open to
+    the agent's HTTP server. This affects the HTTP servers in both client and
+    server agents. Default value is `100`. `0` disables HTTP connection limits.
+
+  - `rpc_handshake_timeout` `(string: "5s")` - Configures how long servers will
+    wait after a client TCP connection is established for the connection
+    handshake to complete. When TLS is used, the same timeout applies to the
+    TLS handshake separately from the initial protocol negotiation. All Nomad
+    clients should perform this immediately on establishing a new connection.
+    This should be kept conservative, as it limits how many connections an
+    unauthenticated attacker can open if TLS is being used to authenticate
+    clients (strongly recommended in production). When `tls.rpc` is true on
+    servers, this limits how long the connection and associated goroutines will
+    be held open before the client successfully authenticates. Default value is
+    `5s`. `0` disables RPC handshake timeouts.
+
+  - `rpc_max_conns_per_client` `(int: 100)` - Configures a limit on how many
+    concurrent TCP connections a single source IP address is allowed to open to
+    a single server. Client agents do not accept RPC TCP connections directly
+    and therefore are not affected. The limit applies to both client
+    connections and connections from other servers. Nomad clients multiplex
+    many RPC calls over a single TCP connection, except for streaming endpoints
+    such as [log streaming][log-api], which require their own connection when
+    routed through servers. A server needs at least 2 TCP connections (1 Raft,
+    1 RPC) per peer server locally and in any federated region. Servers also
+    need a TCP connection per routed streaming endpoint concurrently in use.
+    Only operators use streaming endpoints; as of 0.10.3 Nomad client code does
+    not. A reasonably low limit significantly reduces the ability of an
+    unauthenticated attacker to consume unbounded resources by holding open
+    many connections. You may need to increase this if WAN federated servers
+    connect via proxies or NAT gateways or similar, causing many legitimate
+    connections to originate from a single source IP. Default value is `100`,
+    which is designed to support the majority of users. `0` disables RPC
+    connection limits. `26` is the minimum, as `20` connections are always
+    reserved for non-streaming connections (Raft and RPC) to ensure streaming
+    RPCs do not prevent normal server operation. This minimum may be lowered in
+    the future when streaming RPCs no longer require their own TCP connection.
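+
+  For example (the values below are illustrative only, not recommendations), a
+  cluster whose WAN-federated servers are reached through a NAT gateway that
+  funnels many legitimate connections through one source IP might raise the
+  connection limits while keeping the conservative handshake timeouts:
+
+  ```hcl
+  limits {
+    https_handshake_timeout = "5s"
+    http_max_conns_per_client = 250
+    rpc_handshake_timeout = "5s"
+    rpc_max_conns_per_client = 250
+  }
+  ```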
+
 - `log_level` `(string: "INFO")` - Specifies the verbosity of logs the Nomad
   agent will output. Valid log levels include `WARN`, `INFO`, or `DEBUG` in
   increasing order of verbosity.
@@ -250,7 +301,7 @@ testing.
 - `syslog_facility` `(string: "LOCAL0")` - Specifies the syslog facility to
   write to. This has no effect unless `enable_syslog` is true.
 
-- `tls` `(`[`TLS`]`: nil)` - Specifies configuration for TLS.
+- `tls` `(`[`TLS`][tls]`: nil)` - Specifies configuration for TLS.
 
 - `vault` `(`[`Vault`]`: nil)` - Specifies configuration for connecting to
   Vault.
@@ -283,7 +334,8 @@ http_api_response_headers {
 [`Plugin`]: /docs/configuration/plugin.html "Nomad Agent Plugin Configuration"
 [`Sentinel`]: /docs/configuration/sentinel.html "Nomad Agent sentinel Configuration"
 [`Server`]: /docs/configuration/server.html "Nomad Agent server Configuration"
-[`TLS`]: /docs/configuration/tls.html "Nomad Agent tls Configuration"
+[tls]: /docs/configuration/tls.html "Nomad Agent tls Configuration"
 [`Vault`]: /docs/configuration/vault.html "Nomad Agent vault Configuration"
 [go-sockaddr/template]: https://godoc.org/github.com/hashicorp/go-sockaddr/template
+[log-api]: /api/client.html#stream-logs
 [hcl]: https://github.com/hashicorp/hcl "HashiCorp Configuration Language"
diff --git a/website/source/guides/upgrade/upgrade-specific.html.md b/website/source/guides/upgrade/upgrade-specific.html.md
index cad2335d4..cff4ceb10 100644
--- a/website/source/guides/upgrade/upgrade-specific.html.md
+++ b/website/source/guides/upgrade/upgrade-specific.html.md
@@ -15,6 +15,30 @@ details provided for their upgrades as a result of new features or changed
 behavior. This page is used to document those details separately from the
 standard upgrade flow.
 
+## Nomad 0.10.3
+
+### Connection Limits Added
+
+Nomad 0.10.3 introduces the [limits][limits] agent configuration parameters for
+mitigating denial of service attacks from users who are not authenticated via
+mTLS. The default limits stanza is:
+
+```hcl
+limits {
+  https_handshake_timeout = "5s"
+  http_max_conns_per_client = 100
+  rpc_handshake_timeout = "5s"
+  rpc_max_conns_per_client = 100
+}
+```
+
+If your Nomad agent's endpoints are protected from unauthenticated users via
+other mechanisms, these limits may be safely disabled by setting them to `0`.
+
+However, the defaults were chosen to be safe for a wide variety of Nomad
+deployments and may protect against accidental abuses of the Nomad API that
+could cause unintended resource usage.
+
 ## Nomad 0.10.2
 
 ### Preemption Panic Fixed
@@ -385,6 +409,7 @@ deleted and then Nomad 0.3.0 can be launched.
 [dangling-containers]: /docs/drivers/docker.html#dangling-containers
 [gh-6787]: https://github.com/hashicorp/nomad/issues/6787
 [hcl2]: https://github.com/hashicorp/hcl2
+[limits]: /docs/configuration/index.html#limits
 [lxc]: /docs/drivers/external/lxc.html
 [migrate]: /docs/job-specification/migrate.html
 [plugins]: /docs/drivers/external/index.html
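
For illustration, the following sketch shows how the vendored go-connlimit package and its Wrap helper fit together in a bare TCP accept loop; the listener address and handler are hypothetical, not part of this changeset.

```go
package main

import (
	"log"
	"net"

	connlimit "github.com/hashicorp/go-connlimit"
)

func handle(conn net.Conn) {
	// Closing the wrapped conn also releases its slot in the limiter.
	defer conn.Close()
	// ... handle the connection ...
}

func main() {
	limiter := connlimit.NewLimiter(connlimit.Config{MaxConnsPerClientIP: 100})

	ln, err := net.Listen("tcp", "127.0.0.1:4647") // hypothetical address
	if err != nil {
		log.Fatal(err)
	}
	for {
		conn, err := ln.Accept()
		if err != nil {
			log.Fatal(err)
		}
		free, err := limiter.Accept(conn)
		if err != nil {
			// Per-client-IP limit reached; close without handling.
			conn.Close()
			continue
		}
		// Wrap ties free() to Close so downstream code that closes the conn
		// also frees it in the limiter.
		go handle(connlimit.Wrap(conn, free))
	}
}
```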