diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index bb376f990..b34fed21a 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -4922,6 +4922,20 @@ type RescheduleTracker struct { Events []*RescheduleEvent } +func (rt *RescheduleTracker) Copy() *RescheduleTracker { + if rt == nil { + return nil + } + nt := &RescheduleTracker{} + *nt = *rt + rescheduleEvents := make([]*RescheduleEvent, 0, len(rt.Events)) + for _, tracker := range rt.Events { + rescheduleEvents = append(rescheduleEvents, tracker.Copy()) + } + nt.Events = rescheduleEvents + return nt +} + // RescheduleEvent is used to keep track of previous attempts at rescheduling an allocation type RescheduleEvent struct { // RescheduleTime is the timestamp of a reschedule attempt @@ -4934,12 +4948,12 @@ type RescheduleEvent struct { PrevNodeID string } -func (rt *RescheduleEvent) Copy() *RescheduleEvent { - if rt == nil { +func (re *RescheduleEvent) Copy() *RescheduleEvent { + if re == nil { return nil } copy := new(RescheduleEvent) - *copy = *rt + *copy = *re return copy } @@ -5102,12 +5116,7 @@ func (a *Allocation) copyImpl(job bool) *Allocation { na.TaskStates = ts } - if a.RescheduleTracker != nil { - var rescheduleTrackers []*RescheduleEvent - for _, tracker := range a.RescheduleTracker.Events { - rescheduleTrackers = append(rescheduleTrackers, tracker.Copy()) - } - } + na.RescheduleTracker = a.RescheduleTracker.Copy() return na } diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index a2612cbac..cdf4c091b 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -2800,6 +2800,28 @@ func TestAllocation_ShouldReschedule(t *testing.T) { } } +func TestRescheduleTracker_Copy(t *testing.T) { + type testCase struct { + original *RescheduleTracker + expected *RescheduleTracker + } + + cases := []testCase{ + {nil, nil}, + {&RescheduleTracker{Events: []*RescheduleEvent{ + {2, "12", "12"}, + }}, &RescheduleTracker{Events: []*RescheduleEvent{ + {2, "12", "12"}, + }}}, + } + + for _, tc := range cases { + if got := tc.original.Copy(); !reflect.DeepEqual(got, tc.expected) { + t.Fatalf("expected %v but got %v", *tc.expected, *got) + } + } +} + func TestVault_Validate(t *testing.T) { v := &Vault{ Env: true,