Skip to content

Commit

Permalink
Fix: Trie iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
pnowosie committed Sep 17, 2024
1 parent 75678cc commit 3564bf0
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 62 deletions.
4 changes: 2 additions & 2 deletions adapters/p2p2core/felt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/p2p/starknet/spec"
"github.com/ethereum/go-ethereum/common"
"reflect"
)

func AdaptHash(h *spec.Hash) *felt.Felt {
Expand All @@ -23,10 +24,9 @@ func AdaptFelt(f *spec.Felt252) *felt.Felt {
}

func adapt(v interface{ GetElements() []byte }) *felt.Felt {
if v == nil {
if v == nil || reflect.ValueOf(v).IsNil() {
return nil
}

return new(felt.Felt).SetBytes(v.GetElements())
}

Expand Down
43 changes: 43 additions & 0 deletions core/trie/snap_support.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package trie

import (
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/utils"
)

func (t *Trie) IterateAndGenerateProof(startValue *felt.Felt, consumer func(key, value *felt.Felt) (bool, error),
Expand Down Expand Up @@ -56,6 +57,48 @@ func (t *Trie) IterateAndGenerateProof(startValue *felt.Felt, consumer func(key,
return proofs, finished, nil
}

func (t *Trie) IterateWithLimit(
startAddr *felt.Felt,
limitAddr *felt.Felt,
maxNodes uint32,
// TODO: remove the logger - and move to the tree
logger utils.SimpleLogger,
consumer func(key, value *felt.Felt) error,
) ([]ProofNode, bool, error) {
pathes := make([]*felt.Felt, 0)
hashes := make([]*felt.Felt, 0)

count := uint32(0)
proof, finished, err := t.IterateAndGenerateProof(startAddr, func(key *felt.Felt, value *felt.Felt) (bool, error) {
// Need at least one.
if limitAddr != nil && key.Cmp(limitAddr) > 0 {
return true, nil
}

pathes = append(pathes, key)
hashes = append(hashes, value)

err := consumer(key, value)
if err != nil {
logger.Errorw("error from consumer function", "err", err)
return false, err
}

count++
if count >= maxNodes {
logger.Infow("Max nodes reached", "count", count)
return false, nil
}
return true, nil
})
if err != nil {
logger.Errorw("IterateAndGenerateProof", "err", err, "finished", finished)
return nil, finished, err
}

return proof, finished, err
}

func VerifyRange(root, startKey *felt.Felt, keys, values []*felt.Felt, proofs []ProofNode, hash hashFunc,
treeHeight uint8,
) (hasMore, valid bool, oerr error) {
Expand Down
220 changes: 220 additions & 0 deletions core/trie/snap_support_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package trie_test

import (
"fmt"
"github.com/NethermindEth/juno/db/pebble"
"github.com/NethermindEth/juno/utils"
"github.com/stretchr/testify/require"
"math"
"testing"

"github.com/NethermindEth/juno/core/crypto"
Expand Down Expand Up @@ -276,3 +281,218 @@ func TestRangeAndVerifyReject(t *testing.T) {
})
}
}

func TestIterateOverTrie(t *testing.T) {
memdb := pebble.NewMemTest(t)
txn, err := memdb.NewTransaction(true)
require.NoError(t, err)
logger := utils.NewNopZapLogger()

tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251)
require.NoError(t, err)

// key ranges
var (
bigPrefix = uint64(1000 * 1000 * 1000 * 1000)
count = 100
ranges = 5
fstInt, lstInt uint64
fstKey, lstKey *felt.Felt
)
for i := range ranges {
for j := range count {
lstInt = bigPrefix*uint64(i) + uint64(count+j)
lstKey = new(felt.Felt).SetUint64(lstInt)
value := new(felt.Felt).SetUint64(uint64(10*count + j + i))

if fstKey == nil {
fstKey = lstKey
fstInt = lstInt
}

_, err := tempTrie.Put(lstKey, value)
require.NoError(t, err)
}
}

maxNodes := uint32(ranges*count + 1)
startZero := felt.Zero.Clone()

visitor := func(start, limit *felt.Felt, max uint32) (int, bool, *felt.Felt, *felt.Felt) {
visited := 0
var fst, lst *felt.Felt
_, finish, err := tempTrie.IterateWithLimit(
start,
limit,
max,
logger,
func(key, value *felt.Felt) error {
if fst == nil {
fst = key
}
lst = key
visited++
return nil
})
require.NoError(t, err)
return visited, finish, fst, lst
}

t.Run("iterate without limit", func(t *testing.T) {
expectedLeaves := ranges * count
visited, finish, fst, lst := visitor(nil, nil, maxNodes)
require.Equal(t, expectedLeaves, visited)
require.True(t, finish)
require.Equal(t, fstKey, fst)
require.Equal(t, lstKey, lst)
fmt.Println("Visited:", visited, "\tFinish:", finish, "\tRange:", fst.Uint64(), "-", lst.Uint64())
})

t.Run("iterate over trie im chunks", func(t *testing.T) {
chunkSize := 77
lstChunkSize := int(math.Mod(float64(ranges*count), float64(chunkSize)))
startKey := startZero
for {
visited, finish, fst, lst := visitor(startKey, nil, uint32(chunkSize))
fmt.Println("Finish:", finish, "\tstart:", startKey.Uint64(), "\trange:", fst.Uint64(), "-", lst.Uint64())
if finish {
require.Equal(t, lstChunkSize, visited)
break
}
require.Equal(t, chunkSize, visited)
require.False(t, finish)
startKey = new(felt.Felt).SetUint64(lst.Uint64() + 1)
}
})

t.Run("iterate over trie im groups", func(t *testing.T) {
startKey := startZero
for {
visited, finish, fst, lst := visitor(startKey, nil, uint32(count))
if finish {
require.Equal(t, 0, visited)
fmt.Println("Finish:", finish, "\tstart:", startKey.Uint64(), "\trange: <empty>")
break
}
fmt.Println("Finish:", finish, "\tstart:", startKey.Uint64(), "\trange:", fst.Uint64(), "-", lst.Uint64())
require.Equal(t, count, visited)
require.False(t, finish)
if lst != nil {
startKey = new(felt.Felt).SetUint64(lst.Uint64() + 1)
}
}
})

t.Run("stop before first key", func(t *testing.T) {
lowerBound := new(felt.Felt).SetUint64(fstInt - 1)
visited, finish, _, _ := visitor(startZero, lowerBound, maxNodes)
require.True(t, finish)
require.Equal(t, 0, visited)
})

t.Run("first key is a limit", func(t *testing.T) {
visited, finish, fst, lst := visitor(startZero, fstKey, maxNodes)
require.Equal(t, 1, visited)
require.True(t, finish)
require.Equal(t, fstKey, fst)
require.Equal(t, fstKey, lst)
})

t.Run("start is the last key", func(t *testing.T) {
visited, finish, fst, lst := visitor(lstKey, nil, maxNodes)
require.Equal(t, 1, visited)
require.True(t, finish)
require.Equal(t, lstKey, fst)
require.Equal(t, lstKey, lst)
})

t.Run("start and limit are the last key", func(t *testing.T) {
visited, finish, fst, lst := visitor(lstKey, lstKey, maxNodes)
require.Equal(t, 1, visited)
require.True(t, finish)
require.Equal(t, lstKey, fst)
require.Equal(t, lstKey, lst)
})

t.Run("iterate after last key yields no key", func(t *testing.T) {
upperBound := new(felt.Felt).SetUint64(lstInt + 1)
visited, finish, fst, _ := visitor(upperBound, nil, maxNodes)
require.Equal(t, 0, visited)
require.True(t, finish)
require.Nil(t, fst)
})

t.Run("iterate with reversed bounds yields no key", func(t *testing.T) {
visited, finish, fst, _ := visitor(lstKey, fstKey, maxNodes)
require.Equal(t, 0, visited)
require.True(t, finish)
require.Nil(t, fst)
})

t.Run("iterate over the first group", func(t *testing.T) {
fstGrpBound := new(felt.Felt).SetUint64(fstInt + uint64(count-1))
visited, finish, fst, lst := visitor(fstKey, fstGrpBound, maxNodes)
require.Equal(t, count, visited)
require.True(t, finish)
require.Equal(t, fstKey, fst)
require.Equal(t, fstGrpBound, lst)
})

t.Run("iterate over the first group no lower bound", func(t *testing.T) {
fstGrpBound := new(felt.Felt).SetUint64(fstInt + uint64(count-1))
visited, finish, fst, lst := visitor(nil, fstGrpBound, maxNodes)
require.Equal(t, count, visited)
require.True(t, finish)
require.Equal(t, fstKey, fst)
require.Equal(t, fstGrpBound, lst)
})

t.Run("iterate over the first group by max nodes", func(t *testing.T) {
fstGrpBound := new(felt.Felt).SetUint64(fstInt + uint64(count-1))
visited, finish, fst, lst := visitor(fstKey, nil, uint32(count))
require.Equal(t, count, visited)
require.False(t, finish)
require.Equal(t, fstKey, fst)
require.Equal(t, fstGrpBound, lst)
})

t.Run("iterate over the last group, start before group bound", func(t *testing.T) {
lstGrpStartInt := lstInt - uint64(count-1)
lstGrpFstKey := new(felt.Felt).SetUint64(lstGrpStartInt)
startKey := new(felt.Felt).SetUint64(lstGrpStartInt - uint64(count))

visited, finish, fst, lst := visitor(startKey, nil, maxNodes)
require.Equal(t, count, visited)
require.True(t, finish)
require.Equal(t, lstGrpFstKey, fst)
require.Equal(t, lstKey, lst)
})

sndGrpFstKey := new(felt.Felt).SetUint64(bigPrefix + uint64(count))
sndGrpLstKey := new(felt.Felt).SetUint64(bigPrefix + uint64(2*count-1))
t.Run("second group key selection", func(t *testing.T) {
visited, _, _, lst := visitor(fstKey, nil, uint32(count+1))
require.Equal(t, count+1, visited)
require.Equal(t, sndGrpFstKey, lst)

visited, finish, fst, lst := visitor(sndGrpFstKey, sndGrpLstKey, maxNodes)
require.Equal(t, count, visited)
require.True(t, finish)
require.Equal(t, sndGrpFstKey, fst)
require.Equal(t, sndGrpLstKey, lst)
})

t.Run("second group key selection 2", func(t *testing.T) {
nodeAfterFstGrp := new(felt.Felt).SetUint64(fstInt + uint64(count+1))
visited, _, fst, lst := visitor(nodeAfterFstGrp, nil, 1)
require.Equal(t, 1, visited)
require.Equal(t, sndGrpFstKey, fst)
require.Equal(t, fst, lst)

visited, finish, fst, lst := visitor(sndGrpFstKey, nil, uint32(count))
require.Equal(t, count, visited)
require.False(t, finish)
require.Equal(t, sndGrpFstKey, fst)
require.Equal(t, sndGrpLstKey, lst)
})
}
Loading

0 comments on commit 3564bf0

Please sign in to comment.