implement shortest-prefix-match
gaissmai committed Feb 3, 2024
1 parent d6034bd commit eed783c
Showing 8 changed files with 585 additions and 349 deletions.
6 changes: 2 additions & 4 deletions README.md
@@ -46,17 +46,15 @@ the backtracking algorithm is as fast as possible.

func (t *Table[V]) Get(ip netip.Addr) (val V, ok bool)
func (t *Table[V]) Lookup(ip netip.Addr) (lpm netip.Prefix, val V, ok bool)
func (t *Table[V]) LookupShortest(ip netip.Addr) (spm netip.Prefix, val V, ok bool)

func (t *Table[V]) String() string
func (t *Table[V]) Fprint(w io.Writer) error

func (t *Table[V]) Dump(w io.Writer)
```
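
For orientation, a minimal usage sketch (illustrative only; it assumes the module path github.com/gaissmai/bart and example prefixes):

```golang
package main

import (
	"fmt"
	"net/netip"

	"github.com/gaissmai/bart" // assumed module path
)

func main() {
	var t bart.Table[string]
	t.Insert(netip.MustParsePrefix("10.0.0.0/8"), "coarse")
	t.Insert(netip.MustParsePrefix("10.1.0.0/16"), "fine")

	ip := netip.MustParseAddr("10.1.2.3")

	// longest-prefix-match: the most specific covering route wins.
	if lpm, val, ok := t.Lookup(ip); ok {
		fmt.Println(lpm, val) // 10.1.0.0/16 fine
	}

	// shortest-prefix-match: the least specific covering route wins.
	if spm, val, ok := t.LookupShortest(ip); ok {
		fmt.Println(spm, val) // 10.0.0.0/8 coarse
	}
}
```

Both lookups walk the same per-stride backtracking path; Lookup stops at the first (most specific) match, while LookupShortest keeps the last (least specific) one.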

# TODO

- [x] try the simplest implementation with a hashmap for level compression (done, see branch hashmap)
- [ ] try multi level strides
- [ ] implement Overlaps ...

# CREDIT

17 changes: 12 additions & 5 deletions dumper.go
@@ -10,10 +10,17 @@ import (
"strings"
)

// Dump the IPv4 and IPv6 tables to w.
// Needed during development and debugging, especially for the
// hierarchical tree in the String method, OMG.
// Will be removed when the API stabilizes.
// dumpString is just a wrapper for dump.
func (t *Table[V]) dumpString() string {
w := new(strings.Builder)
if err := t.dump(w); err != nil {
panic(err)
}
return w.String()
}

// dump the IPv4 and IPv6 tables to w.
// Useful during development and debugging.
//
// Output:
//
@@ -40,7 +47,7 @@ import (
// ...prefxs(#1): 1/8
//
// ...
func (t *Table[V]) Dump(w io.Writer) error {
func (t *Table[V]) dump(w io.Writer) error {
t.init()

if err := t.dump4(w); err != nil {
4 changes: 2 additions & 2 deletions dumper_test.go
@@ -21,7 +21,7 @@ func TestDumperPanic(t *testing.T) {
p := netip.MustParsePrefix
tbl := new(Table[any])
tbl.Insert(p("1.2.3.4/32"), nil)
_ = tbl.Dump(nil)
_ = tbl.dump(nil)
}

func TestDumperEmpty(t *testing.T) {
@@ -263,7 +263,7 @@ func checkDump(t *testing.T, tbl *Table[any], tt dumpTest) {
tbl.Insert(cidr, nil)
}
w := new(strings.Builder)
if err := tbl.Dump(w); err != nil {
if err := tbl.dump(w); err != nil {
t.Errorf("Dump() unexpected err: %v", err)
}
got := w.String()
104 changes: 36 additions & 68 deletions fulltable_test.go
@@ -49,14 +49,6 @@ func TestFullNew(t *testing.T) {
runtime.ReadMemStats(&endMem)
bartBytes := endMem.TotalAlloc - startMem.TotalAlloc

// rtArt := art.Table[any]{}
// runtime.ReadMemStats(&startMem)
// for _, route := range nRoutes {
// rtArt.Insert(route.CIDR, nil)
// }
// runtime.ReadMemStats(&endMem)
// artBytes := endMem.TotalAlloc - startMem.TotalAlloc

t.Logf("BART: n: %d routes, raw: %d KBytes, bart: %6d KBytes, mult: %.2f (bart/raw)",
len(nRoutes), rawBytes/(2<<10), bartBytes/(2<<10), float32(bartBytes)/float32(rawBytes))

@@ -80,14 +72,6 @@ func TestFullNewV4(t *testing.T) {
runtime.ReadMemStats(&endMem)
bartBytes := endMem.TotalAlloc - startMem.TotalAlloc

// rtArt := art.Table[any]{}
// runtime.ReadMemStats(&startMem)
// for _, route := range nRoutes {
// rtArt.Insert(route.CIDR, nil)
// }
// runtime.ReadMemStats(&endMem)
// artBytes := endMem.TotalAlloc - startMem.TotalAlloc

t.Logf("BART: n: %d routes, raw: %d KBytes, bart: %6d KBytes, mult: %.2f (bart/raw)",
len(nRoutes), rawBytes/(2<<10), bartBytes/(2<<10), float32(bartBytes)/float32(rawBytes))

@@ -111,14 +95,6 @@ func TestFullNewV6(t *testing.T) {
runtime.ReadMemStats(&endMem)
bartBytes := endMem.TotalAlloc - startMem.TotalAlloc

// rtArt := art.Table[any]{}
// runtime.ReadMemStats(&startMem)
// for _, route := range nRoutes {
// rtArt.Insert(route.CIDR, nil)
// }
// runtime.ReadMemStats(&endMem)
// artBytes := endMem.TotalAlloc - startMem.TotalAlloc

t.Logf("BART: n: %d routes, raw: %d KBytes, bart: %6d KBytes, mult: %.2f (bart/raw)",
len(nRoutes), rawBytes/(2<<10), bartBytes/(2<<10), float32(bartBytes)/float32(rawBytes))

@@ -133,11 +109,9 @@ var (

func BenchmarkFullMatchV4(b *testing.B) {
var rtBart bart.Table[int]
// var rtArt art.Table[int]

for i, route := range routes {
rtBart.Insert(route.CIDR, i)
// rtArt.Insert(route.CIDR, i)
}

var ip netip.Addr
@@ -151,35 +125,33 @@ }
}
}

// b.Run("ARTGet", func(b *testing.B) {
// b.ResetTimer()
// for k := 0; k < b.N; k++ {
// intSink, okSink = rtArt.Get(ip)
// }
// })

b.Run("BARTGet", func(b *testing.B) {
b.Run("Get", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
intSink, okSink = rtBart.Get(ip)
}
})

b.Run("BARTLookup", func(b *testing.B) {
b.Run("Lookup", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, intSink, okSink = rtBart.Lookup(ip)
}
})

b.Run("LookupSCP", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, _, okSink = rtBart.LookupShortest(ip)
}
})
}

func BenchmarkFullMatchV6(b *testing.B) {
var rtBart bart.Table[int]
// var rtArt art.Table[int]

for i, route := range routes {
rtBart.Insert(route.CIDR, i)
// rtArt.Insert(route.CIDR, i)
}

var ip netip.Addr
@@ -193,35 +165,33 @@ }
}
}

// b.Run("ARTGet", func(b *testing.B) {
// b.ResetTimer()
// for k := 0; k < b.N; k++ {
// intSink, okSink = rtArt.Get(ip)
// }
// })

b.Run("BARTGet", func(b *testing.B) {
b.Run("Get", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
intSink, okSink = rtBart.Get(ip)
}
})

b.Run("BARTLookup", func(b *testing.B) {
b.Run("Lookup", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, intSink, okSink = rtBart.Lookup(ip)
}
})

b.Run("LookupSCP", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, _, okSink = rtBart.LookupShortest(ip)
}
})
}

func BenchmarkFullMissV4(b *testing.B) {
var rtBart bart.Table[int]
// var rtArt art.Table[int]

for i, route := range routes {
rtBart.Insert(route.CIDR, i)
// rtArt.Insert(route.CIDR, i)
}

var ip netip.Addr
@@ -234,35 +204,33 @@ }
}
}

// b.Run("ARTGet", func(b *testing.B) {
// b.ResetTimer()
// for k := 0; k < b.N; k++ {
// intSink, okSink = rtArt.Get(ip)
// }
// })

b.Run("BARTGet", func(b *testing.B) {
b.Run("Get", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
intSink, okSink = rtBart.Get(ip)
}
})

b.Run("BARTLookup", func(b *testing.B) {
b.Run("Lookup", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, intSink, okSink = rtBart.Lookup(ip)
}
})

b.Run("LookupSCP", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, _, okSink = rtBart.LookupShortest(ip)
}
})
}

func BenchmarkFullMissV6(b *testing.B) {
var rtBart bart.Table[int]
// var rtArt art.Table[int]

for i, route := range routes {
rtBart.Insert(route.CIDR, i)
// rtArt.Insert(route.CIDR, i)
}

var ip netip.Addr
@@ -275,26 +243,26 @@ }
}
}

// b.Run("ARTGet", func(b *testing.B) {
// b.ResetTimer()
// for k := 0; k < b.N; k++ {
// intSink, okSink = rtArt.Get(ip)
// }
// })

b.Run("BARTGet", func(b *testing.B) {
b.Run("Get", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
intSink, okSink = rtBart.Get(ip)
}
})

b.Run("BARTLookup", func(b *testing.B) {
b.Run("Lookup", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, intSink, okSink = rtBart.Lookup(ip)
}
})

b.Run("LookupSCP", func(b *testing.B) {
b.ResetTimer()
for k := 0; k < b.N; k++ {
_, _, okSink = rtBart.LookupShortest(ip)
}
})
}

func fillRouteTables() {
39 changes: 38 additions & 1 deletion node.go
@@ -122,7 +122,7 @@ func (p *prefixCBTree[V]) delete(addr uint, prefixLen int) (wasPresent bool) {
// at this depth and returns (baseIdx, value, true) if a matching
// longest prefix exists, or ok=false otherwise.
//
// backtracking is fast, it's just a bitset test and, if found, a popcount.
// backtracking is fast, it's just a bitset test and, if found, one popcount.
func (p *prefixCBTree[V]) lpmByIndex(idx uint) (baseIdx uint, val V, ok bool) {
// max steps in backtracking is the stride length.
for {
@@ -158,6 +158,43 @@ func (p *prefixCBTree[V]) lpmByPrefix(addr uint, prefixLen int) (baseIdx uint, val V, ok bool) {
return p.lpmByIndex(prefixToBaseIndex(addr, prefixLen))
}

// spmByIndex does a shortest-prefix-match for idx in the 8-bit (stride) routing table
// at this depth and returns (baseIdx, value, true) if a matching
// shortest prefix exists, or ok=false otherwise.
//
// backtracking costs stride-many bitset tests and, if found, one popcount.
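//
// An illustrative walk (assuming the usual ART index mapping, where a stride
// prefix addr/len maps to baseIdx = (addr >> (8-len)) + (1 << len)): halving
// the index steps to the next shorter covering prefix, e.g. for addr 10:
//
//	idx(10/8)=266 -> idx(10/7)=133 -> idx(10/6)=66 -> ... -> idx(0/0)=1
//
// lpmByIndex stops at the first set bit on this walk (most specific prefix),
// spmByIndex walks to the end and keeps the last set bit (least specific).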
func (p *prefixCBTree[V]) spmByIndex(idx uint) (baseIdx uint, val V, ok bool) {
var shortest uint
// the number of backtracking steps for spm is always the full stride length.
for {
if p.indexes.Test(idx) {
shortest = idx
// no fast exit on match for shortest-prefix-match.
}

if idx == 0 {
break
}

// cache-friendly backtracking to the next less specific route.
// thanks to the complete binary tree it's just a shift operation.
idx >>= 1
}

if shortest != 0 {
return shortest, *(p.values[p.rank(shortest)]), true
}

// not found (on this level)
return 0, val, false
}

// spmByAddr does a shortest-prefix-match for addr in the 8-bit (stride) routing table.
// It's an adapter to spmByIndex.
func (p *prefixCBTree[V]) spmByAddr(addr uint) (baseIdx uint, val V, ok bool) {
return p.spmByIndex(addrToBaseIndex(addr))
}

// getVal for baseIdx.
func (p *prefixCBTree[V]) getVal(baseIdx uint) *V {
if p.indexes.Test(baseIdx) {