Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arbo: add CheckProofBatch and CalculateProofNodes #1398

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions api/censuses.go
Original file line number Diff line number Diff line change
Expand Up @@ -945,15 +945,15 @@ func (a *API) censusVerifyHandler(msg *apirest.APIdata, ctx *httprouter.HTTPCont
}
}

valid, err := ref.Tree().VerifyProof(leafKey, cdata.Value, cdata.CensusProof, cdata.CensusRoot)
if err != nil {
if err := ref.Tree().VerifyProof(leafKey, cdata.Value, cdata.CensusProof, cdata.CensusRoot); err != nil {
if strings.Contains(err.Error(), "calculated vs expected root mismatch") {
return ctx.Send(nil, apirest.HTTPstatusBadRequest)
}
return ErrCensusProofVerificationFailed.WithErr(err)
}
if !valid {
return ctx.Send(nil, apirest.HTTPstatusBadRequest)
}

response := Census{
Valid: valid,
Valid: true,
}
var data []byte
if data, err = json.Marshal(&response); err != nil {
Expand Down
9 changes: 6 additions & 3 deletions censustree/censustree.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ func (t *Tree) Get(key []byte) ([]byte, error) {
// VerifyProof verifies a census proof.
// If the census is indexed key can be nil (value provides the key already).
// If root is nil the last merkle root is used for verify.
func (t *Tree) VerifyProof(key, value, proof, root []byte) (bool, error) {
func (t *Tree) VerifyProof(key, value, proof, root []byte) error {
var err error
if root == nil {
root, err = t.Root()
if err != nil {
return false, fmt.Errorf("cannot get tree root: %w", err)
return fmt.Errorf("cannot get tree root: %w", err)
}
}
// If the provided key is longer than the defined maximum length truncate it
Expand All @@ -176,7 +176,10 @@ func (t *Tree) VerifyProof(key, value, proof, root []byte) (bool, error) {
if len(leafKey) > DefaultMaxKeyLen {
leafKey = leafKey[:DefaultMaxKeyLen]
}
return t.tree.VerifyProof(leafKey, value, proof, root)
if err := t.tree.VerifyProof(leafKey, value, proof, root); err != nil {
return err
}
return nil
}

// GenProof generates a census proof for the provided key.
Expand Down
3 changes: 1 addition & 2 deletions censustree/censustree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ func TestWeightedProof(t *testing.T) {
root, err := censusTree.Root()
qt.Assert(t, err, qt.IsNil)

verified, err := censusTree.VerifyProof(userKey, value, siblings, root)
err = censusTree.VerifyProof(userKey, value, siblings, root)
qt.Assert(t, err, qt.IsNil)
qt.Assert(t, verified, qt.IsTrue)
}

func TestGetCensusWeight(t *testing.T) {
Expand Down
14 changes: 5 additions & 9 deletions tree/arbo/addbatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -998,27 +998,23 @@ func TestAddKeysWithEmptyValues(t *testing.T) {
// check with empty array
root, err := tree.Root()
c.Assert(err, qt.IsNil)
verif, err := CheckProof(tree.hashFunction, keys[9], []byte{}, root, siblings)
err = CheckProof(tree.hashFunction, keys[9], []byte{}, root, siblings)
c.Assert(err, qt.IsNil)
c.Check(verif, qt.IsTrue)

// check with array with only 1 zero
verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0}, root, siblings)
err = CheckProof(tree.hashFunction, keys[9], []byte{0}, root, siblings)
c.Assert(err, qt.IsNil)
c.Check(verif, qt.IsTrue)

// check with array with 32 zeroes
e32 := []byte{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}
c.Assert(len(e32), qt.Equals, 32)
verif, err = CheckProof(tree.hashFunction, keys[9], e32, root, siblings)
err = CheckProof(tree.hashFunction, keys[9], e32, root, siblings)
c.Assert(err, qt.IsNil)
c.Check(verif, qt.IsTrue)

// check with array with value!=0 returns false at verification
verif, err = CheckProof(tree.hashFunction, keys[9], []byte{0, 1}, root, siblings)
c.Assert(err, qt.IsNil)
c.Check(verif, qt.IsFalse)
err = CheckProof(tree.hashFunction, keys[9], []byte{0, 1}, root, siblings)
c.Assert(err, qt.ErrorMatches, "calculated vs expected root mismatch")
}
33 changes: 33 additions & 0 deletions tree/arbo/circomproofs.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package arbo

import (
"bytes"
"encoding/json"
"fmt"
"slices"
)

// CircomVerifierProof contains the needed data to check a Circom Verifier Proof
Expand Down Expand Up @@ -89,3 +92,33 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro

return &cp, nil
}

// CalculateProofNodes calculates the chain of hashes in the path of the proof.
// In the returned list, first item is the root, and last item is the hash of the leaf.
func (cvp CircomVerifierProof) CalculateProofNodes(hashFunc HashFunction) ([][]byte, error) {
paddedSiblings := slices.Clone(cvp.Siblings)
for k, v := range paddedSiblings {
if bytes.Equal(v, []byte{0}) {
paddedSiblings[k] = make([]byte, hashFunc.Len())
}
}
packedSiblings, err := PackSiblings(hashFunc, paddedSiblings)
if err != nil {
return nil, err
}
return CalculateProofNodes(hashFunc, cvp.Key, cvp.Value, packedSiblings, cvp.OldKey, (cvp.Fnc == 1))
}

// CheckProof verifies the given proof. The proof verification depends on the
// HashFunction passed as parameter.
// Returns nil if the proof is valid, or an error otherwise.
func (cvp CircomVerifierProof) CheckProof(hashFunc HashFunction) error {
hashes, err := cvp.CalculateProofNodes(hashFunc)
if err != nil {
return err
}
if !bytes.Equal(hashes[0], cvp.Root) {
return fmt.Errorf("calculated vs expected root mismatch")
}
return nil
}
106 changes: 96 additions & 10 deletions tree/arbo/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package arbo
import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"slices"
Expand Down Expand Up @@ -159,30 +160,115 @@ func bytesToBitmap(b []byte) []bool {

// CheckProof verifies the given proof. The proof verification depends on the
// HashFunction passed as parameter.
func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
// Returns nil if the proof is valid, or an error otherwise.
func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) error {
hashes, err := CalculateProofNodes(hashFunc, k, v, packedSiblings, nil, false)
if err != nil {
return err
}
if !bytes.Equal(hashes[0], root) {
return fmt.Errorf("calculated vs expected root mismatch")
}
return nil
}

// CalculateProofNodes calculates the chain of hashes in the path of the given proof.
// In the returned list, first item is the root, and last item is the hash of the leaf.
func CalculateProofNodes(hashFunc HashFunction, k, v, packedSiblings, oldKey []byte, exclusion bool) ([][]byte, error) {
siblings, err := UnpackSiblings(hashFunc, packedSiblings)
if err != nil {
return false, err
return nil, err
}

keyPath := make([]byte, int(math.Ceil(float64(len(siblings))/float64(8))))
copy(keyPath, k)
path := getPath(len(siblings), keyPath)

key, _, err := newLeafValue(hashFunc, k, v)
if err != nil {
return false, err
key := slices.Clone(k)

if exclusion {
if slices.Equal(k, oldKey) {
return nil, fmt.Errorf("exclusion proof invalid, key and oldKey are equal")
}
// we'll prove the path to the existing key (passed as oldKey)
key = slices.Clone(oldKey)
}

path := getPath(len(siblings), keyPath)
hash, _, err := newLeafValue(hashFunc, key, v)
if err != nil {
return nil, err
}
hashes := [][]byte{hash}
for i, sibling := range slices.Backward(siblings) {
if path[i] {
key, _, err = newIntermediate(hashFunc, sibling, key)
hash, _, err = newIntermediate(hashFunc, sibling, hash)
} else {
key, _, err = newIntermediate(hashFunc, key, sibling)
hash, _, err = newIntermediate(hashFunc, hash, sibling)
}
if err != nil {
return nil, err
}
hashes = append(hashes, hash)
}
slices.Reverse(hashes)
return hashes, nil
}

// CheckProofBatch verifies a batch of N proofs pairs (old and new). The proof verification depends on the
// HashFunction passed as parameter.
// Returns nil if the batch is valid, or an error otherwise.
//
// TODO: doesn't support removing leaves (newProofs can only update or add new leaves)
func CheckProofBatch(hashFunc HashFunction, oldProofs, newProofs []*CircomVerifierProof) error {
newBranches := make(map[string]int)
newSiblings := make(map[string]int)

if len(oldProofs) != len(newProofs) {
return fmt.Errorf("batch of proofs incomplete")
}

if len(oldProofs) == 0 {
return fmt.Errorf("empty batch")
}

for i := range oldProofs {
// Map all old branches
oldNodes, err := oldProofs[i].CalculateProofNodes(hashFunc)
if err != nil {
return fmt.Errorf("old proof invalid: %w", err)
}
// and check they are valid
if !bytes.Equal(oldProofs[i].Root, oldNodes[0]) {
return fmt.Errorf("old proof invalid: root doesn't match")
}

// Map all new branches
newNodes, err := newProofs[i].CalculateProofNodes(hashFunc)
if err != nil {
return false, err
return fmt.Errorf("new proof invalid: %w", err)
}
// and check they are valid
if !bytes.Equal(newProofs[i].Root, newNodes[0]) {
return fmt.Errorf("new proof invalid: root doesn't match")
}

for level, hash := range newNodes {
newBranches[hex.EncodeToString(hash)] = level
}

for level := range newProofs[i].Siblings {
if !slices.Equal(oldProofs[i].Siblings[level], newProofs[i].Siblings[level]) {
// since in newBranch the root is level 0, we shift siblings to level + 1
newSiblings[hex.EncodeToString(newProofs[i].Siblings[level])] = level + 1
}
}
}
return bytes.Equal(key, root), nil

for hash, level := range newSiblings {
if newBranches[hash] != newSiblings[hash] {
return fmt.Errorf("sibling %s (at level %d) changed but there's no proof why", hash, level)
}
}

return nil
}
Loading
Loading