diff --git a/nomad/rpc_test.go b/nomad/rpc_test.go index 4ac6ffb4c..1a6e02160 100644 --- a/nomad/rpc_test.go +++ b/nomad/rpc_test.go @@ -15,10 +15,10 @@ import ( "os" "path" "path/filepath" + "regexp" "testing" "time" - "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-msgpack/codec" "github.com/hashicorp/go-sockaddr" msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc" @@ -1145,19 +1145,14 @@ func TestRPC_TLS_Enforcement_Raft(t *testing.T) { func TestRPC_TLS_Enforcement_RPC(t *testing.T) { ci.Parallel(t) - defer func() { - //TODO Avoid panics from logging during shutdown - time.Sleep(1 * time.Second) - }() - tlsHelper := newTLSTestHelper(t) - defer tlsHelper.cleanup() + t.Cleanup(tlsHelper.cleanup) - standardRPCs := map[string]interface{}{ + standardRPCs := map[string]any{ "Status.Ping": &structs.GenericRequest{}, } - localServersOnlyRPCs := map[string]interface{}{ + localServersOnlyRPCs := map[string]any{ "Eval.Update": &structs.EvalUpdateRequest{ WriteRequest: structs.WriteRequest{Region: "global"}, }, @@ -1187,7 +1182,7 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { }, } - localClientsOnlyRPCs := map[string]interface{}{ + localClientsOnlyRPCs := map[string]any{ "Alloc.GetAllocs": &structs.AllocsGetRequest{ QueryOptions: structs.QueryOptions{Region: "global"}, }, @@ -1210,7 +1205,7 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { cases := []struct { name string cn string - rpcs map[string]interface{} + rpcs map[string]any canRPC bool }{ // Local server. @@ -1325,11 +1320,20 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { if tc.canRPC { if err != nil { - require.NotContains(t, err, "certificate") + // note: lots of these RPCs will return + // validation errors after connection b/c we're + // focusing on testing TLS here + must.StrNotContains(t, err.Error(), "certificate") } } else { - require.Error(t, err) - require.Contains(t, err.Error(), "certificate") + // We expect "bad certificate" for these failures, + // but locally the error can return before the error + // message bytes have been received, in which case + // we immediately write on the pipe that was just + // closed by the client + must.Error(t, err) + must.RegexMatch(t, + regexp.MustCompile("(certificate|broken pipe)"), err.Error()) } }) } @@ -1337,7 +1341,7 @@ func TestRPC_TLS_Enforcement_RPC(t *testing.T) { t.Run(fmt.Sprintf("nomad RPC: rpc=%s verify_hostname=false", method), func(t *testing.T) { err := tlsHelper.nomadRPC(t, tlsHelper.nonVerifyServer, cfg, method, arg) if err != nil { - require.NotContains(t, err, "certificate") + must.StrNotContains(t, "certificate", err.Error()) } }) } @@ -1387,7 +1391,7 @@ func newTLSTestHelper(t *testing.T) tlsTestHelper { makeServer := func(bootstrapExpect int, verifyServerHostname bool) (*Server, func()) { return TestServer(t, func(c *Config) { - c.Logger.SetLevel(hclog.Off) + c.NumSchedulers = 0 c.BootstrapExpect = bootstrapExpect c.TLSConfig = &config.TLSConfig{ EnableRPC: true, @@ -1441,32 +1445,34 @@ func (h tlsTestHelper) newCert(t *testing.T, name string) string { } func (h tlsTestHelper) connect(t *testing.T, s *Server, c *config.TLSConfig) net.Conn { + t.Helper() conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second) - require.NoError(t, err) + must.NoError(t, err) // configure TLS _, err = conn.Write([]byte{byte(pool.RpcTLS)}) - require.NoError(t, err) + must.NoError(t, err) // Client TLS verification isn't necessary for // our assertions tlsConf, err := tlsutil.NewTLSConfiguration(c, true, true) - require.NoError(t, err) + must.NoError(t, err) outTLSConf, err := tlsConf.OutgoingTLSConfig() - require.NoError(t, err) + must.NoError(t, err) outTLSConf.InsecureSkipVerify = true tlsConn := tls.Client(conn, outTLSConf) - require.NoError(t, tlsConn.Handshake()) + must.NoError(t, tlsConn.Handshake()) return tlsConn } func (h tlsTestHelper) nomadRPC(t *testing.T, s *Server, c *config.TLSConfig, method string, arg interface{}) error { + t.Helper() conn := h.connect(t, s, c) defer conn.Close() _, err := conn.Write([]byte{byte(pool.RpcNomad)}) - require.NoError(t, err) + must.NoError(t, err) codec := pool.NewClientCodec(conn)