From 7788f095674af72eed300bb78d1a27292cc53fc3 Mon Sep 17 00:00:00 2001 From: Seth Hoenig Date: Wed, 24 Aug 2022 17:46:45 -0500 Subject: [PATCH] cleanup: create pointer.Compare helper function This PR creates a pointer.Compare helper for comparing equality of two pointers. Strictly only works with primitive types we know are safe to derefence and compare using '=='. --- helper/funcs.go | 8 ------ helper/funcs_test.go | 21 --------------- helper/pointer/pointer.go | 22 +++++++++++++++ helper/pointer/pointer_test.go | 49 ++++++++++++++++++++++++++++++++++ nomad/structs/services.go | 6 ++--- 5 files changed, 74 insertions(+), 32 deletions(-) diff --git a/helper/funcs.go b/helper/funcs.go index 93cb5b9b2..1ab2792d1 100644 --- a/helper/funcs.go +++ b/helper/funcs.go @@ -74,14 +74,6 @@ func HashUUID(input string) (output string, hashed bool) { return output, true } -// CompareTimePtrs return true if a is the same as b. -func CompareTimePtrs(a, b *time.Duration) bool { - if a == nil || b == nil { - return a == b - } - return *a == *b -} - // Min returns the minimum of a and b. func Min[T constraints.Ordered](a, b T) T { if a < b { diff --git a/helper/funcs_test.go b/helper/funcs_test.go index 5d882061f..792bd85ea 100644 --- a/helper/funcs_test.go +++ b/helper/funcs_test.go @@ -6,9 +6,7 @@ import ( "reflect" "sort" "testing" - "time" - "github.com/hashicorp/nomad/helper/pointer" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" ) @@ -132,25 +130,6 @@ func TestStringHasPrefixInSlice(t *testing.T) { } -func TestCompareTimePtrs(t *testing.T) { - t.Run("nil", func(t *testing.T) { - a := (*time.Duration)(nil) - b := (*time.Duration)(nil) - require.True(t, CompareTimePtrs(a, b)) - c := pointer.Of(3 * time.Second) - require.False(t, CompareTimePtrs(a, c)) - require.False(t, CompareTimePtrs(c, a)) - }) - - t.Run("not nil", func(t *testing.T) { - a := pointer.Of(1 * time.Second) - b := pointer.Of(1 * time.Second) - c := pointer.Of(2 * time.Second) - require.True(t, CompareTimePtrs(a, b)) - require.False(t, CompareTimePtrs(a, c)) - }) -} - func TestCompareSliceSetString(t *testing.T) { cases := []struct { A []string diff --git a/helper/pointer/pointer.go b/helper/pointer/pointer.go index 8fa960caf..0e806c0bb 100644 --- a/helper/pointer/pointer.go +++ b/helper/pointer/pointer.go @@ -1,6 +1,10 @@ // Package pointer provides helper functions related to Go pointers. package pointer +import ( + "golang.org/x/exp/constraints" +) + // Of returns a pointer to a. func Of[A any](a A) *A { return &a @@ -14,3 +18,21 @@ func Copy[A any](a *A) *A { na := *a return &na } + +// Primitive represents basic types that are safe to do basic comparisons by +// pointer dereference (checking nullity first). +type Primitive interface { + constraints.Ordered // just so happens to be the types we want +} + +// Eq returns whether a and b are equal in underlying value. +// +// May only be used on pointers to primitive types, where the comparison is +// guaranteed to be sensible. For complex types (i.e. structs) consider implementing +// an Equals method. +func Eq[P Primitive](a, b *P) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} diff --git a/helper/pointer/pointer_test.go b/helper/pointer/pointer_test.go index e656cce59..fca664834 100644 --- a/helper/pointer/pointer_test.go +++ b/helper/pointer/pointer_test.go @@ -2,6 +2,7 @@ package pointer import ( "testing" + "time" "github.com/shoenig/test/must" ) @@ -16,3 +17,51 @@ func Test_Of(t *testing.T) { sPtr = &b must.NotEq(t, s, *sPtr) } + +func Test_Copy(t *testing.T) { + orig := Of(1) + dup := Copy(orig) + orig = Of(7) + must.EqOp(t, 7, *orig) + must.EqOp(t, 1, *dup) +} + +func Test_Compare(t *testing.T) { + t.Run("int", func(t *testing.T) { + a := 1 + b := 2 + c := 1 + var n *int // nil + must.False(t, Eq(&a, &b)) + must.True(t, Eq(&a, &c)) + must.False(t, Eq(nil, &a)) + must.False(t, Eq(n, &a)) + must.True(t, Eq(n, nil)) + }) + + t.Run("string", func(t *testing.T) { + a := "cat" + b := "dog" + c := "cat" + var n *string + + must.False(t, Eq(&a, &b)) + must.True(t, Eq(&a, &c)) + must.False(t, Eq(nil, &a)) + must.False(t, Eq(n, &a)) + must.True(t, Eq(n, nil)) + }) + + t.Run("duration", func(t *testing.T) { + a := time.Duration(1) + b := time.Duration(2) + c := time.Duration(1) + var n *time.Duration + + must.False(t, Eq(&a, &b)) + must.True(t, Eq(&a, &c)) + must.False(t, Eq(nil, &a)) + must.False(t, Eq(n, &a)) + must.True(t, Eq(n, nil)) + }) +} diff --git a/nomad/structs/services.go b/nomad/structs/services.go index 38ddd5970..632391361 100644 --- a/nomad/structs/services.go +++ b/nomad/structs/services.go @@ -1210,7 +1210,7 @@ func (t *SidecarTask) Equals(o *SidecarTask) bool { return false } - if !helper.CompareTimePtrs(t.KillTimeout, o.KillTimeout) { + if !pointer.Eq(t.KillTimeout, o.KillTimeout) { return false } @@ -1218,7 +1218,7 @@ func (t *SidecarTask) Equals(o *SidecarTask) bool { return false } - if !helper.CompareTimePtrs(t.ShutdownDelay, o.ShutdownDelay) { + if !pointer.Eq(t.ShutdownDelay, o.ShutdownDelay) { return false } @@ -1811,7 +1811,7 @@ func (p *ConsulGatewayProxy) Equals(o *ConsulGatewayProxy) bool { return p == o } - if !helper.CompareTimePtrs(p.ConnectTimeout, o.ConnectTimeout) { + if !pointer.Eq(p.ConnectTimeout, o.ConnectTimeout) { return false }