diff --git a/helper/lib/score_heap.go b/helper/lib/score_heap.go new file mode 100644 index 000000000..485d04eb6 --- /dev/null +++ b/helper/lib/score_heap.go @@ -0,0 +1,63 @@ +package lib + +import ( + "container/heap" +) + +// An HeapItem represents elements being managed in the Score heap +type HeapItem struct { + Value string // The Value of the item; arbitrary. + Score float64 // The Score of the item in the heap +} + +// A ScoreHeap implements heap.Interface and is a min heap +// that keeps the top K elements by Score. Push can be called +// with an arbitrary number of values but only the top K are stored +type ScoreHeap struct { + items []*HeapItem + capacity int +} + +func NewScoreHeap(capacity uint32) *ScoreHeap { + return &ScoreHeap{capacity: int(capacity)} +} + +func (pq ScoreHeap) Len() int { return len(pq.items) } + +func (pq ScoreHeap) Less(i, j int) bool { + return pq.items[i].Score < pq.items[j].Score +} + +func (pq ScoreHeap) Swap(i, j int) { + pq.items[i], pq.items[j] = pq.items[j], pq.items[i] +} + +// Push implements heap.Interface and only stores +// the top K elements by Score +func (pq *ScoreHeap) Push(x interface{}) { + item := x.(*HeapItem) + if len(pq.items) < pq.capacity { + pq.items = append(pq.items, item) + } else { + // Pop the lowest scoring element if this item's Score is + // greater than the min Score so far + minIndex := 0 + min := pq.items[minIndex] + if item.Score > min.Score { + // Replace min and heapify + pq.items[minIndex] = item + heap.Fix(pq, minIndex) + } + } +} + +// Push implements heap.Interface and returns the top K scoring +// elements in increasing order of Score. Callers must reverse the order +// of returned elements to get the top K scoring elements in descending order +func (pq *ScoreHeap) Pop() interface{} { + old := pq.items + n := len(old) + item := old[n-1] + pq.items = old[0 : n-1] + return item +} diff --git a/helper/lib/score_heap_test.go b/helper/lib/score_heap_test.go new file mode 100644 index 000000000..9061150c0 --- /dev/null +++ b/helper/lib/score_heap_test.go @@ -0,0 +1,79 @@ +package lib + +import ( + "container/heap" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestScoreHeap(t *testing.T) { + type testCase struct { + desc string + items map[string]float64 + expected []*HeapItem + } + + cases := []testCase{ + { + desc: "More than K elements", + items: map[string]float64{ + "banana": 3.0, + "apple": 2.25, + "pear": 2.32, + "watermelon": 5.45, + "orange": 0.20, + "strawberry": 9.03, + "blueberry": 0.44, + "lemon": 3.9, + "cherry": 0.03, + }, + expected: []*HeapItem{ + {Value: "pear", Score: 2.32}, + {Value: "banana", Score: 3.0}, + {Value: "lemon", Score: 3.9}, + {Value: "watermelon", Score: 5.45}, + {Value: "strawberry", Score: 9.03}, + }, + }, + { + desc: "Less than K elements", + items: map[string]float64{ + "eggplant": 9.0, + "okra": -1.0, + "corn": 0.25, + }, + expected: []*HeapItem{ + {Value: "okra", Score: -1.0}, + {Value: "corn", Score: 0.25}, + {Value: "eggplant", Score: 9.0}, + }, + }, + } + + for _, tc := range cases { + t.Run("", func(t *testing.T) { + // Create Score heap, push elements into it + pq := NewScoreHeap(5) + for value, score := range tc.items { + heapItem := &HeapItem{ + Value: value, + Score: score, + } + heap.Push(pq, heapItem) + } + + // Take the items out; they arrive in increasing Score order + require := require.New(t) + require.Equal(len(tc.expected), pq.Len()) + + i := 0 + for pq.Len() > 0 { + item := heap.Pop(pq).(*HeapItem) + require.Equal(tc.expected[i], item) + i++ + } + }) + } + +}