diff --git a/nomad/leader.go b/nomad/leader.go index 9dd3792e3..7db6cc57e 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -1314,22 +1314,26 @@ func (s *Server) removeRaftPeer(m serf.Member, parts *serverParts) error { // Pick which remove API to use based on how the server was added. for _, server := range configFuture.Configuration().Servers { - // If we understand the new add/remove APIs and the server was added by ID, use the new remove API - if minRaftProtocol >= 2 && server.ID == raft.ServerID(parts.ID) { - s.logger.Info("removing server by ID", "id", server.ID) - future := s.raft.RemoveServer(raft.ServerID(parts.ID), 0, 0) - if err := future.Error(); err != nil { - s.logger.Error("failed to remove raft peer", "id", server.ID, "error", err) - return err - } - break - } else if server.Address == raft.ServerAddress(addr) { - // If not, use the old remove API - s.logger.Info("removing server by address", "address", server.Address) - future := s.raft.RemovePeer(raft.ServerAddress(addr)) - if err := future.Error(); err != nil { - s.logger.Error("failed to remove raft peer", "address", addr, "error", err) - return err + // Check if this is the server to remove based on how it was registered. + // Raft v2 servers are registered by address. + // Raft v3 servers are registered by ID. + if server.ID == raft.ServerID(parts.ID) || server.Address == raft.ServerAddress(addr) { + // Use the new add/remove APIs if we understand them. + if minRaftProtocol >= 2 { + s.logger.Info("removing server by ID", "id", server.ID) + future := s.raft.RemoveServer(server.ID, 0, 0) + if err := future.Error(); err != nil { + s.logger.Error("failed to remove raft peer", "id", server.ID, "error", err) + return err + } + } else { + // If not, use the old remove API + s.logger.Info("removing server by address", "address", server.Address) + future := s.raft.RemovePeer(raft.ServerAddress(addr)) + if err := future.Error(); err != nil { + s.logger.Error("failed to remove raft peer", "address", addr, "error", err) + return err + } } break } diff --git a/nomad/leader_test.go b/nomad/leader_test.go index b244273b0..7ee5e90fd 100644 --- a/nomad/leader_test.go +++ b/nomad/leader_test.go @@ -1216,7 +1216,9 @@ func TestLeader_RollRaftServer(t *testing.T) { // Kill the first v2 server s1.Shutdown() - for _, s := range []*Server{s1, s3} { + for _, s := range []*Server{s2, s3} { + s.RemoveFailedNode(s1.config.NodeID) + retry.Run(t, func(r *retry.R) { minVer, err := s.autopilot.MinRaftProtocol() if err != nil { @@ -1225,6 +1227,14 @@ func TestLeader_RollRaftServer(t *testing.T) { if got, want := minVer, 2; got != want { r.Fatalf("got min raft version %d want %d", got, want) } + + configFuture := s.raft.GetConfiguration() + if err != nil { + r.Fatal(err) + } + if len(configFuture.Configuration().Servers) != 2 { + r.Fatalf("expected 2 servers, got %d", len(configFuture.Configuration().Servers)) + } }) } @@ -1234,14 +1244,19 @@ func TestLeader_RollRaftServer(t *testing.T) { c.RaftConfig.ProtocolVersion = 3 }) defer cleanupS4() - TestJoin(t, s4, s2) + TestJoin(t, s2, s3, s4) servers[0] = s4 // Kill the second v2 server s2.Shutdown() for _, s := range []*Server{s3, s4} { - retry.Run(t, func(r *retry.R) { + s.RemoveFailedNode(s2.config.NodeID) + + retry.RunWith(&retry.Counter{ + Count: int(10 * testutil.TestMultiplier()), + Wait: time.Duration(testutil.TestMultiplier()) * time.Second, + }, t, func(r *retry.R) { minVer, err := s.autopilot.MinRaftProtocol() if err != nil { r.Fatal(err) @@ -1249,6 +1264,14 @@ func TestLeader_RollRaftServer(t *testing.T) { if got, want := minVer, 2; got != want { r.Fatalf("got min raft version %d want %d", got, want) } + + configFuture := s.raft.GetConfiguration() + if err != nil { + r.Fatal(err) + } + if len(configFuture.Configuration().Servers) != 2 { + r.Fatalf("expected 2 servers, got %d", len(configFuture.Configuration().Servers)) + } }) } // Replace another dead server with one running raft protocol v3 @@ -1257,14 +1280,19 @@ func TestLeader_RollRaftServer(t *testing.T) { c.RaftConfig.ProtocolVersion = 3 }) defer cleanupS5() - TestJoin(t, s5, s4) + TestJoin(t, s3, s4, s5) servers[1] = s5 // Kill the last v2 server, now minRaftProtocol should be 3 s3.Shutdown() for _, s := range []*Server{s4, s5} { - retry.Run(t, func(r *retry.R) { + s.RemoveFailedNode(s2.config.NodeID) + + retry.RunWith(&retry.Counter{ + Count: int(10 * testutil.TestMultiplier()), + Wait: time.Duration(testutil.TestMultiplier()) * time.Second, + }, t, func(r *retry.R) { minVer, err := s.autopilot.MinRaftProtocol() if err != nil { r.Fatal(err) @@ -1272,6 +1300,14 @@ func TestLeader_RollRaftServer(t *testing.T) { if got, want := minVer, 3; got != want { r.Fatalf("got min raft version %d want %d", got, want) } + + configFuture := s.raft.GetConfiguration() + if err != nil { + r.Fatal(err) + } + if len(configFuture.Configuration().Servers) != 2 { + r.Fatalf("expected 2 servers, got %d", len(configFuture.Configuration().Servers)) + } }) }