Skip to content

Commit

Permalink
minor bitset refactoring, more funcs inlineable
Browse files Browse the repository at this point in the history
  • Loading branch information
gaissmai committed Jan 4, 2025
1 parent 2d7613f commit 7dd230e
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 170 deletions.
26 changes: 12 additions & 14 deletions internal/bitset/bitset.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ func (b BitSet) Clear(i uint) BitSet {
}

// Test if bit i is set.
func (b BitSet) Test(i uint) bool {
func (b BitSet) Test(i uint) (ok bool) {
if x := int(i >> 6); x < len(b) {
return b[x]&(1<<(i&63)) != 0
}
return false
return
}

// Clone this BitSet, returning a new BitSet that has the same bits set.
Expand Down Expand Up @@ -226,9 +226,8 @@ func (b BitSet) Size() int {

// Rank returns the number of bits set in the bitset
// up to and including index i.
func (b BitSet) Rank(i uint) int {
// inlined popcount to make Rank inlineable
var rnk int
func (b BitSet) Rank(i uint) (rnk int) {
// with inlined popcount to make Rank inlineable

i++ // Rank count is inclusive
wordIdx := i >> 6
Expand All @@ -239,7 +238,7 @@ func (b BitSet) Rank(i uint) int {
for _, x := range b {
rnk += bits.OnesCount64(x)
}
return rnk
return
}

// inlined popcount, partial slice
Expand All @@ -248,29 +247,28 @@ func (b BitSet) Rank(i uint) int {
}

if bitsIdx == 0 {
return rnk
return
}

// plus partial word
return rnk + bits.OnesCount64(b[wordIdx]<<(64-bitsIdx))
rnk += bits.OnesCount64(b[wordIdx] << (64 - bitsIdx))
return
}

// popcount
func popcount(s []uint64) int {
var cnt int
func popcount(s []uint64) (cnt int) {
for _, x := range s {
// count all the bits set in slice.
cnt += bits.OnesCount64(x)
}
return cnt
return
}

// popcountAnd
func popcountAnd(s, m []uint64) int {
var cnt int
func popcountAnd(s, m []uint64) (cnt int) {
for j := 0; j < len(s) && j < len(m); j++ {
// words are bitwise & followed by popcount.
cnt += bits.OnesCount64(s[j] & m[j])
}
return cnt
return
}
29 changes: 0 additions & 29 deletions internal/bitset/bitset_iter.go

This file was deleted.

100 changes: 0 additions & 100 deletions internal/bitset/bitset_iter_test.go

This file was deleted.

1 change: 0 additions & 1 deletion internal/sparse/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ func (s *Array[T]) Get(i uint) (val T, ok bool) {
if s.Test(i) {
return s.Items[s.Rank(i)-1], true
}

return
}

Expand Down
20 changes: 10 additions & 10 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,46 +228,46 @@ func (n *node[V]) purgeAndCompress(parentStack []*node[V], childPath []byte, is4
}
}

// lpm does a route lookup for idx in the 8-bit (stride) routing table
// lpmGet does a route lookup for idx in the 8-bit (stride) routing table
// at this depth and returns (baseIdx, value, true) if a matching
// longest prefix exists, or ok=false otherwise.
//
// Backtracking is fast: it's just a bitset test and, if found, one popcount.
// The maximum number of backtracking steps is the stride length.
func (n *node[V]) lpm(idx uint) (baseIdx uint, val V, ok bool) {
// shortcut optimization
func (n *node[V]) lpmGet(idx uint) (baseIdx uint, val V, ok bool) {
// shortcut optimization, perhaps reduces the backtracking iterations
minIdx, ok := n.prefixes.FirstSet()
if !ok {
return 0, val, false
}

// backtracking the CBT
for baseIdx = idx; baseIdx >= minIdx; baseIdx >>= 1 {
for ; idx >= minIdx; idx >>= 1 {
// practically it's get, but get is not inlined
if n.prefixes.Test(baseIdx) {
return baseIdx, n.prefixes.MustGet(baseIdx), true
if n.prefixes.Test(idx) {
return idx, n.prefixes.MustGet(idx), true
}
}

// not found (on this level)
return 0, val, false
}

// lpmTest for faster lpm tests without value returns
// lpmTest is a faster lpm variant that only reports whether a match exists, without returning the value.
func (n *node[V]) lpmTest(idx uint) bool {
// shortcut optimization
// shortcut optimization, perhaps reduces the backtracking iterations
minIdx, ok := n.prefixes.FirstSet()
if !ok {
return false
}

// backtracking the CBT
for idx := idx; idx >= minIdx; idx >>= 1 {
for ; idx >= minIdx; idx >>= 1 {
if n.prefixes.Test(idx) {
// no need for MustGet()
return true
}
}

return false
}

Expand Down
6 changes: 3 additions & 3 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestPrefixInsert(t *testing.T) {
octet := byte(i)
addr := uint(i)
goldVal, goldOK := gold.lpm(octet)
_, fastVal, fastOK := fast.lpm(hostIndex(addr))
_, fastVal, fastOK := fast.lpmGet(hostIndex(addr))
if !getsEqual(fastVal, fastOK, goldVal, goldOK) {
t.Fatalf("get(%d) = (%v, %v), want (%v, %v)", octet, fastVal, fastOK, goldVal, goldOK)
}
Expand Down Expand Up @@ -92,7 +92,7 @@ func TestPrefixDelete(t *testing.T) {
octet := byte(i)
addr := uint(i)
goldVal, goldOK := gold.lpm(octet)
_, fastVal, fastOK := fast.lpm(hostIndex(addr))
_, fastVal, fastOK := fast.lpmGet(hostIndex(addr))
if !getsEqual(fastVal, fastOK, goldVal, goldOK) {
t.Fatalf("get(%d) = (%v, %v), want (%v, %v)", octet, fastVal, fastOK, goldVal, goldOK)
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func BenchmarkNodePrefixLPM(b *testing.B) {

b.ResetTimer()
for range b.N {
_, writeSink, _ = this.lpm(pfxToIdx(route.octet, route.bits))
_, writeSink, _ = this.lpmGet(pfxToIdx(route.octet, route.bits))
}
})
}
Expand Down
4 changes: 2 additions & 2 deletions stringify.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (n *node[V]) getKidsRec(parentIdx uint, path [16]byte, depth int, is4 bool)
}

// check if the lpmIdx for the parent of this idx is equal to parentIdx
lpmIdx, _, _ := n.lpm(idx >> 1)
lpmIdx, _, _ := n.lpmGet(idx >> 1)

// is idx a direct kid?
if lpmIdx == parentIdx {
Expand All @@ -202,7 +202,7 @@ func (n *node[V]) getKidsRec(parentIdx uint, path [16]byte, depth int, is4 bool)
allChildAddrs := n.children.AsSlice(make([]uint, 0, maxNodeChildren))
for i, addr := range allChildAddrs {
// do a longest-prefix-match
lpmIdx, _, _ := n.lpm(hostIndex(addr))
lpmIdx, _, _ := n.lpmGet(hostIndex(addr))
if lpmIdx == parentIdx {
switch k := n.children.Items[i].(type) {
case *node[V]:
Expand Down
4 changes: 2 additions & 2 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ LOOP:

// longest prefix match, skip if node has no prefixes
if n.prefixes.Len() != 0 {
if _, val, ok = n.lpm(hostIndex(uint(octets[depth]))); ok {
if _, val, ok = n.lpmGet(hostIndex(uint(octets[depth]))); ok {
return val, ok
}
}
Expand Down Expand Up @@ -500,7 +500,7 @@ LOOP:
idx = hostIndex(uint(octet))
}

if baseIdx, val, ok := n.lpm(idx); ok {
if baseIdx, val, ok := n.lpmGet(idx); ok {
// calculate the bits from depth and idx
bits := depth*strideLen + int(baseIdxLookupTbl[baseIdx].bits)

Expand Down
Loading

0 comments on commit 7dd230e

Please sign in to comment.