diff --git a/nomad/encrypter_test.go b/nomad/encrypter_test.go index a14513ea0..a126f4683 100644 --- a/nomad/encrypter_test.go +++ b/nomad/encrypter_test.go @@ -18,6 +18,7 @@ import ( msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2" "github.com/shoenig/test" "github.com/shoenig/test/must" + "github.com/shoenig/test/wait" "github.com/stretchr/testify/require" "github.com/hashicorp/nomad/ci" @@ -200,7 +201,7 @@ func TestEncrypter_KeyringReplication(t *testing.T) { c.BootstrapExpect = 3 c.NumSchedulers = 0 }) - defer cleanupSRV1() + t.Cleanup(cleanupSRV1) // add two more servers after we've bootstrapped @@ -208,21 +209,19 @@ func TestEncrypter_KeyringReplication(t *testing.T) { c.BootstrapExpect = 3 c.NumSchedulers = 0 }) - defer cleanupSRV2() + t.Cleanup(cleanupSRV2) srv3, cleanupSRV3 := TestServer(t, func(c *Config) { c.BootstrapExpect = 3 c.NumSchedulers = 0 }) - defer cleanupSRV3() - - TestJoin(t, srv1, srv2) - TestJoin(t, srv1, srv3) + t.Cleanup(cleanupSRV3) + servers := []*Server{srv1, srv2, srv3} + TestJoin(t, servers...) testutil.WaitForKeyring(t, srv1.RPC, "global") testutil.WaitForKeyring(t, srv2.RPC, "global") testutil.WaitForKeyring(t, srv3.RPC, "global") - servers := []*Server{srv1, srv2, srv3} var leader *Server for _, srv := range servers { @@ -230,7 +229,7 @@ func TestEncrypter_KeyringReplication(t *testing.T) { leader = srv } } - require.NotNil(t, leader, "expected there to be a leader") + must.NotNil(t, leader, must.Sprint("expected there to be a leader")) codec := rpcClient(t, leader) t.Logf("leader is %s", leader.config.NodeName) @@ -243,17 +242,20 @@ func TestEncrypter_KeyringReplication(t *testing.T) { } var listResp structs.KeyringListRootKeyMetaResponse - require.Eventually(t, func() bool { - msgpackrpc.CallWithCodec(codec, "Keyring.List", listReq, &listResp) - return len(listResp.Keys) == 1 - }, time.Second*5, time.Second, "expected keyring to be initialized") + must.Wait(t, wait.InitialSuccess( + wait.BoolFunc(func() bool { + msgpackrpc.CallWithCodec(codec, "Keyring.List", listReq, &listResp) + return len(listResp.Keys) == 1 + }), + wait.Timeout(time.Second*5), wait.Gap(200*time.Millisecond)), + must.Sprint("expected keyring to be initialized")) keyID1 := listResp.Keys[0].KeyID keyPath := filepath.Join(leader.GetConfig().DataDir, "keystore", keyID1+nomadKeystoreExtension) _, err := os.Stat(keyPath) - require.NoError(t, err, "expected key to be found in leader keystore") + must.NoError(t, err, must.Sprint("expected key to be found in leader keystore")) // Helper function for checking that a specific key has been // replicated to followers @@ -272,12 +274,12 @@ func TestEncrypter_KeyringReplication(t *testing.T) { } // Assert that the bootstrap key has been replicated to followers - require.Eventually(t, checkReplicationFn(keyID1), - time.Second*5, time.Second, - "expected keys to be replicated to followers after bootstrap") + must.Wait(t, wait.InitialSuccess( + wait.BoolFunc(checkReplicationFn(keyID1)), + wait.Timeout(time.Second*5), wait.Gap(200*time.Millisecond)), + must.Sprint("expected keys to be replicated to followers after bootstrap")) // Assert that key rotations are replicated to followers - rotateReq := &structs.KeyringRotateRootKeyRequest{ WriteRequest: structs.WriteRequest{ Region: "global", @@ -285,7 +287,7 @@ func TestEncrypter_KeyringReplication(t *testing.T) { } var rotateResp structs.KeyringRotateRootKeyResponse err = msgpackrpc.CallWithCodec(codec, "Keyring.Rotate", rotateReq, &rotateResp) - require.NoError(t, err) + must.NoError(t, err) keyID2 := rotateResp.Key.KeyID getReq := &structs.KeyringGetRootKeyRequest{ @@ -296,32 +298,32 @@ func TestEncrypter_KeyringReplication(t *testing.T) { } var getResp structs.KeyringGetRootKeyResponse err = msgpackrpc.CallWithCodec(codec, "Keyring.Get", getReq, &getResp) - require.NoError(t, err) - require.NotNil(t, getResp.Key, "expected key to be found on leader") + must.NoError(t, err) + must.NotNil(t, getResp.Key, must.Sprint("expected key to be found on leader")) keyPath = filepath.Join(leader.GetConfig().DataDir, "keystore", keyID2+nomadKeystoreExtension) _, err = os.Stat(keyPath) - require.NoError(t, err, "expected key to be found in leader keystore") + must.NoError(t, err, must.Sprint("expected key to be found in leader keystore")) - require.Eventually(t, checkReplicationFn(keyID2), - time.Second*5, time.Second, - "expected keys to be replicated to followers after rotation") + must.Wait(t, wait.InitialSuccess( + wait.BoolFunc(checkReplicationFn(keyID2)), + wait.Timeout(time.Second*5), wait.Gap(200*time.Millisecond)), + must.Sprint("expected keys to be replicated to followers after rotation")) // Scenario: simulate a key rotation that doesn't get replicated // before a leader election by stopping replication, rotating the // key, and triggering a leader election. - for _, srv := range servers { srv.keyringReplicator.stop() } err = msgpackrpc.CallWithCodec(codec, "Keyring.Rotate", rotateReq, &rotateResp) - require.NoError(t, err) + must.NoError(t, err) keyID3 := rotateResp.Key.KeyID err = leader.leadershipTransfer() - require.NoError(t, err) + must.NoError(t, err) testutil.WaitForLeader(t, leader.RPC) @@ -336,9 +338,10 @@ func TestEncrypter_KeyringReplication(t *testing.T) { go srv.keyringReplicator.run(ctx) } - require.Eventually(t, checkReplicationFn(keyID3), - time.Second*5, time.Second, - "expected keys to be replicated to followers after election") + must.Wait(t, wait.InitialSuccess( + wait.BoolFunc(checkReplicationFn(keyID3)), + wait.Timeout(time.Second*5), wait.Gap(200*time.Millisecond)), + must.Sprint("expected keys to be replicated to followers after election")) // Scenario: new members join the cluster @@ -353,16 +356,15 @@ func TestEncrypter_KeyringReplication(t *testing.T) { }) defer cleanupSRV5() - TestJoin(t, srv4, srv5) - TestJoin(t, srv5, srv1) servers = []*Server{srv1, srv2, srv3, srv4, srv5} - + TestJoin(t, servers...) testutil.WaitForLeader(t, srv4.RPC) testutil.WaitForLeader(t, srv5.RPC) - require.Eventually(t, checkReplicationFn(keyID3), - time.Second*5, time.Second, - "expected new servers to get replicated keys") + must.Wait(t, wait.InitialSuccess( + wait.BoolFunc(checkReplicationFn(keyID3)), + wait.Timeout(time.Second*5), wait.Gap(200*time.Millisecond)), + must.Sprint("expected new servers to get replicated key")) // Scenario: reload a snapshot @@ -377,19 +379,18 @@ func TestEncrypter_KeyringReplication(t *testing.T) { buf := bytes.NewBuffer(nil) sink := &MockSink{buf, false} must.NoError(t, snapshot.Persist(sink)) - must.NoError(t, srv5.fsm.Restore(sink)) // rotate the key err = msgpackrpc.CallWithCodec(codec, "Keyring.Rotate", rotateReq, &rotateResp) - require.NoError(t, err) + must.NoError(t, err) keyID4 := rotateResp.Key.KeyID - require.Eventually(t, checkReplicationFn(keyID4), - time.Second*5, time.Second, - "expected new servers to get replicated keys after snapshot restore") - + must.Wait(t, wait.InitialSuccess( + wait.BoolFunc(checkReplicationFn(keyID4)), + wait.Timeout(time.Second*5), wait.Gap(200*time.Millisecond)), + must.Sprint("expected new servers to get replicated keys after snapshot restore")) } func TestEncrypter_EncryptDecrypt(t *testing.T) { diff --git a/nomad/testing.go b/nomad/testing.go index 8ffd808f7..65f44b310 100644 --- a/nomad/testing.go +++ b/nomad/testing.go @@ -75,6 +75,7 @@ func TestConfigForServer(t testing.TB) *Config { config.SerfConfig.MemberlistConfig.ProbeTimeout = 50 * time.Millisecond config.SerfConfig.MemberlistConfig.ProbeInterval = 100 * time.Millisecond config.SerfConfig.MemberlistConfig.GossipInterval = 100 * time.Millisecond + config.SerfConfig.MemberlistConfig.PushPullInterval = 500 * time.Millisecond // Tighten the Raft timing config.RaftConfig.LeaderLeaseTimeout = 50 * time.Millisecond @@ -182,14 +183,16 @@ func TestServerErr(t testing.TB, cb func(*Config)) (*Server, func(), error) { } func TestJoin(t testing.TB, servers ...*Server) { - for i := 0; i < len(servers)-1; i++ { + addrs := make([]string, len(servers)) + for i := 0; i < len(servers); i++ { addr := fmt.Sprintf("127.0.0.1:%d", servers[i].config.SerfConfig.MemberlistConfig.BindPort) + addrs[i] = addr + } - for j := i + 1; j < len(servers); j++ { - num, err := servers[j].Join([]string{addr}) - must.NoError(t, err) - must.Eq(t, 1, num) - } + for i := 0; i < len(servers); i++ { + num, err := servers[i].Join(addrs) + must.NoError(t, err) + must.Eq(t, len(addrs), num) } }