diff --git a/internal/servicecheck/neighbours.go b/internal/servicecheck/neighbours.go index 78e1f01..aca6032 100644 --- a/internal/servicecheck/neighbours.go +++ b/internal/servicecheck/neighbours.go @@ -1,11 +1,12 @@ package servicecheck import ( + "container/heap" "context" "crypto/sha256" + "encoding/binary" "fmt" "os" - "slices" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/labels" @@ -25,7 +26,7 @@ type Neighbour struct { PodIP string HostIP string NodeName string - NodeHash string + NodeHash uint64 } // GetNeighbours returns a slice of neighbour kubenurses for the given namespace and labelSelector. @@ -74,7 +75,7 @@ func (c *Checker) GetNeighbours(ctx context.Context, namespace, labelSelector st PodIP: pod.Status.PodIP, HostIP: pod.Status.HostIP, NodeName: pod.Spec.NodeName, - NodeHash: sha256String(pod.Spec.NodeName), + NodeHash: sha256Uint64(pod.Spec.NodeName), } neighbours = append(neighbours, &n) } @@ -102,31 +103,55 @@ func (c *Checker) checkNeighbours(nh []*Neighbour) { } } +type Uint64Heap []uint64 + +func (h Uint64Heap) Len() int { return len(h) } +func (h Uint64Heap) Less(i, j int) bool { return h[i] > h[j] } // we want a max-heap, therefore the inversed condition +func (h Uint64Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *Uint64Heap) Push(x any) { + *h = append(*h, x.(uint64)) +} + +func (h *Uint64Heap) Pop() any { + n := len(*h) + x := (*h)[n-1] + *h = (*h)[0 : n-1] + + return x +} + func (c *Checker) filterNeighbours(nh []*Neighbour) []*Neighbour { - m := make(map[string]*Neighbour, len(nh)) - l := make([]string, 0, len(nh)) + m := make(map[uint64]*Neighbour, c.NeighbourLimit+1) + + sl := make(Uint64Heap, 0, c.NeighbourLimit+1) + h := &sl + currentNodeHash := sha256Uint64(currentNode) + + heap.Init(h) for _, n := range nh { - m[n.NodeHash] = n - l = append(l, n.NodeHash) - } + adjHash := n.NodeHash - currentNodeHash + m[adjHash] = n - slices.Sort(l) + heap.Push(h, adjHash) - currentNodeHash := sha256String(currentNode) - idx, _ := slices.BinarySearch(l, currentNodeHash) + if len(*h) > c.NeighbourLimit { + p := heap.Pop(h).(uint64) + delete(m, p) + } + } filteredNeighbours := make([]*Neighbour, 0, c.NeighbourLimit) - for i := 0; i < c.NeighbourLimit; i++ { - hash := l[(idx+i)%len(l)] - filteredNeighbours = append(filteredNeighbours, m[hash]) + for _, n := range m { + filteredNeighbours = append(filteredNeighbours, n) } return filteredNeighbours } -func sha256String(s string) string { +func sha256Uint64(s string) uint64 { h := sha256.Sum256([]byte(s)) - return string(h[:]) + return binary.BigEndian.Uint64(h[:8]) } diff --git a/internal/servicecheck/neighbours_test.go b/internal/servicecheck/neighbours_test.go index e2f1192..dd17007 100644 --- a/internal/servicecheck/neighbours_test.go +++ b/internal/servicecheck/neighbours_test.go @@ -15,7 +15,7 @@ func generateNeighbours(n int) (nh []*Neighbour) { nodeName := fmt.Sprintf("a1-k8s-abcd%03d.domain.tld", i) neigh := Neighbour{ NodeName: nodeName, - NodeHash: sha256String(nodeName), + NodeHash: sha256Uint64(nodeName), } nh = append(nh, &neigh) } @@ -23,6 +23,25 @@ func generateNeighbours(n int) (nh []*Neighbour) { return } +func BenchmarkNodeFiltering(b *testing.B) { + n := 10_000 + neighbourLimit := 10 + nh := generateNeighbours(n) + require.NotNil(b, nh) + + checker := Checker{ + NeighbourLimit: neighbourLimit, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + currentNode = nh[i%len(nh)].NodeName + b.StartTimer() + checker.filterNeighbours(nh) + b.StopTimer() + } +} + func TestNodeFiltering(t *testing.T) { n := 1_000