Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
altergui committed Nov 21, 2024
1 parent ec10056 commit 8562a79
Show file tree
Hide file tree
Showing 18 changed files with 680 additions and 219 deletions.
20 changes: 20 additions & 0 deletions smt/emulated/hash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package emulated

import (
"github.com/consensys/gnark/std/math/emulated"

poseidon "github.com/mdehoog/poseidon/circuits/poseidon/emulated"
)

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom

func Hash1[T emulated.FieldParams](field *emulated.Field[T], key, value *emulated.Element[T]) *emulated.Element[T] {
one := emulated.ValueOf[T](1)
inputs := []*emulated.Element[T]{key, value, &one}
return poseidon.Hash(field, inputs)
}

func Hash2[T emulated.FieldParams](field *emulated.Field[T], l, r *emulated.Element[T]) *emulated.Element[T] {
inputs := []*emulated.Element[T]{l, r}
return poseidon.Hash(field, inputs)
}
29 changes: 29 additions & 0 deletions smt/emulated/lev_ins.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smtlevins.circom

func LevIns[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], enabled frontend.Variable, siblings []*emulated.Element[T]) (levIns []frontend.Variable) {
levels := len(siblings)
levIns = make([]frontend.Variable, levels)
done := make([]frontend.Variable, levels-1)

isZero := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
isZero[i] = field.IsZero(siblings[i])
}
api.AssertIsEqual(api.Mul(api.Sub(isZero[levels-1], 1), enabled), 0)

levIns[levels-1] = api.Sub(1, isZero[levels-2])
done[levels-2] = levIns[levels-1]
for i := levels - 2; i > 0; i-- {
levIns[i] = api.Mul(api.Sub(1, done[i]), api.Sub(1, isZero[i-1]))
done[i-1] = api.Add(levIns[i], done[i])
}
levIns[0] = api.Sub(1, done[0])
return levIns
}
67 changes: 67 additions & 0 deletions smt/emulated/processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"

"github.com/mdehoog/gnark-circom-smt/circuits/smt"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessor.circom

func Processor[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], oldRoot *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, newKey, newValue *emulated.Element[T], fnc0, fnc1 frontend.Variable) (newRoot *emulated.Element[T]) {
levels := len(siblings)
enabled := api.Sub(api.Add(fnc0, fnc1), api.Mul(fnc0, fnc1))
hash1Old := Hash1(field, oldKey, oldValue)
hash1New := Hash1(field, newKey, newValue)
n2bOld := field.ToBits(oldKey)
n2bNew := field.ToBits(newKey)
smtLevIns := LevIns(api, field, enabled, siblings)

xors := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
xors[i] = api.Xor(n2bOld[i], n2bNew[i])
}

stTop := make([]frontend.Variable, levels)
stOld0 := make([]frontend.Variable, levels)
stBot := make([]frontend.Variable, levels)
stNew1 := make([]frontend.Variable, levels)
stNa := make([]frontend.Variable, levels)
stUpd := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
if i == 0 {
stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = smt.ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, enabled, 0, 0, 0, api.Sub(1, enabled), 0)
} else {
stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = smt.ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, stTop[i-1], stOld0[i-1], stBot[i-1], stNew1[i-1], stNa[i-1], stUpd[i-1])
}
}

api.AssertIsEqual(api.Add(api.Add(stNa[levels-1], stNew1[levels-1]), api.Add(stOld0[levels-1], stUpd[levels-1])), 1)

levelsOldRoot := make([]*emulated.Element[T], levels)
levelsNewRoot := make([]*emulated.Element[T], levels)
for i := levels - 1; i >= 0; i-- {
if i == levels-1 {
zero := emulated.ValueOf[T](0)
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, field, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], &zero, &zero)
} else {
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, field, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1])
}
}

topSwitcherL, topSwitcherR := Switcher(field, api.Mul(fnc0, fnc1), levelsOldRoot[0], levelsNewRoot[0])
ForceEqualIfEnabled(field, oldRoot, topSwitcherL, enabled)

newRoot = field.Select(enabled, topSwitcherR, oldRoot)

areKeyEquals := IsEqual(field, oldKey, newKey)
in := []frontend.Variable{
api.Sub(1, fnc0),
fnc1,
api.Sub(1, areKeyEquals),
}
keysOk := smt.MultiAnd(api, in)
api.AssertIsEqual(keysOk, 0)
return
}
27 changes: 27 additions & 0 deletions smt/emulated/processor_level.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorlevel.circom

func ProcessorLevel[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], stTop, stOld0, stBot, stNew1, stUpd frontend.Variable, sibling, old1leaf, new1leaf *emulated.Element[T], newlrbit frontend.Variable, oldChild, newChild *emulated.Element[T]) (oldRoot, newRoot *emulated.Element[T]) {
oldProofHashL, oldProofHashR := Switcher(field, newlrbit, oldChild, sibling)
oldProofHash := Hash2(field, oldProofHashL, oldProofHashR)

am := api.Add(api.Add(stBot, stNew1), stUpd)
oldRoot = mux2(api, field, am, stTop, old1leaf, oldProofHash)

am = api.Add(stTop, stBot)
a := mux2(api, field, am, stNew1, newChild, new1leaf)
b := mux2(api, field, stTop, stNew1, sibling, old1leaf)
newProofHashL, newProofHashR := Switcher(field, newlrbit, a, b)
newProofHash := Hash2(field, newProofHashL, newProofHashR)

am = api.Add(api.Add(stTop, stBot), stNew1)
bm := api.Add(stOld0, stUpd)
newRoot = mux2(api, field, am, bm, newProofHash, new1leaf)
return
}
34 changes: 34 additions & 0 deletions smt/emulated/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

func IsEqual[T emulated.FieldParams](field *emulated.Field[T], a, b *emulated.Element[T]) frontend.Variable {
return field.IsZero(field.Sub(a, b))
}

func ForceEqualIfEnabled[T emulated.FieldParams](field *emulated.Field[T], a, b *emulated.Element[T], enabled frontend.Variable) {
c := field.Select(enabled, a, b)
field.AssertIsEqual(c, b)
}

// Switcher is [out1, out2] = sel ? [r, l] : [l, r]
func Switcher[T emulated.FieldParams](field *emulated.Field[T], sel frontend.Variable, l, r *emulated.Element[T]) (*emulated.Element[T], *emulated.Element[T]) {
return field.Select(sel, r, l), field.Select(sel, l, r)
}

// mux2 is (out = as ? a : bs ? b : 0)
func mux2[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], as, bs frontend.Variable, a, b *emulated.Element[T]) *emulated.Element[T] {
sel := api.FromBinary(as, bs)
zero := emulated.ValueOf[T](0)
return field.Mux(sel, &zero, a, b, a)
}

// mux3 is (out = as ? a : bs ? b : cs ? c : 0)
func mux3[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], as, bs, cs frontend.Variable, a, b, c *emulated.Element[T]) *emulated.Element[T] {
sel := api.FromBinary(as, bs, cs)
zero := emulated.ValueOf[T](0)
return field.Mux(sel, &zero, a, b, a, c, a, b, a)
}
54 changes: 54 additions & 0 deletions smt/emulated/verifier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"

"github.com/mdehoog/gnark-circom-smt/circuits/smt"
)

func InclusionVerifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], root *emulated.Element[T], siblings []*emulated.Element[T], key, value *emulated.Element[T]) {
Verifier[T](api, field, 1, root, siblings, key, value, 0, key, value, 0)
}

func ExclusionVerifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], root *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, key *emulated.Element[T]) {
zero := emulated.ValueOf[T](0)
Verifier[T](api, field, 1, root, siblings, oldKey, oldValue, isOld0, key, &zero, 1)
}

func Verifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], enabled frontend.Variable, root *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, key, value *emulated.Element[T], fnc frontend.Variable) {
nLevels := len(siblings)
hash1Old := Hash1(field, oldKey, oldValue)
hash1New := Hash1(field, key, value)
n2bNew := field.ToBits(key)
smtLevIns := LevIns(api, field, enabled, siblings)

stTop := make([]frontend.Variable, nLevels)
stI0 := make([]frontend.Variable, nLevels)
stIOld := make([]frontend.Variable, nLevels)
stINew := make([]frontend.Variable, nLevels)
stNa := make([]frontend.Variable, nLevels)
for i := 0; i < nLevels; i++ {
if i == 0 {
stTop[i], stI0[i], stIOld[i], stINew[i], stNa[i] = smt.VerifierSM(api, isOld0, smtLevIns[i], fnc, enabled, 0, 0, 0, api.Sub(1, enabled))
} else {
stTop[i], stI0[i], stIOld[i], stINew[i], stNa[i] = smt.VerifierSM(api, isOld0, smtLevIns[i], fnc, stTop[i-1], stI0[i-1], stIOld[i-1], stINew[i-1], stNa[i-1])
}
}
api.AssertIsEqual(api.Add(api.Add(api.Add(stNa[nLevels-1], stIOld[nLevels-1]), stINew[nLevels-1]), stI0[nLevels-1]), 1)

levels := make([]*emulated.Element[T], nLevels)
for i := nLevels - 1; i >= 0; i-- {
if i == nLevels-1 {
zero := emulated.ValueOf[T](0)
levels[i] = VerifierLevel(api, field, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], &zero)
} else {
levels[i] = VerifierLevel(api, field, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1])
}
}

areKeyEquals := IsEqual(field, oldKey, key)
keysOk := smt.MultiAnd(api, []frontend.Variable{fnc, api.Sub(1, isOld0), areKeyEquals, enabled})
api.AssertIsEqual(keysOk, 0)
ForceEqualIfEnabled(field, levels[0], root, enabled)
}
13 changes: 13 additions & 0 deletions smt/emulated/verifier_level.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package emulated

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

func VerifierLevel[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], stTop, stIOld, stINew frontend.Variable, sibling, old1leaf, new1leaf *emulated.Element[T], lrbit frontend.Variable, child *emulated.Element[T]) (root *emulated.Element[T]) {
proofHashL, proofHashR := Switcher(field, lrbit, child, sibling)
proofHash := Hash2(field, proofHashL, proofHashR)
root = mux3(api, field, stTop, stIOld, stINew, proofHash, old1leaf, new1leaf)
return
}
19 changes: 19 additions & 0 deletions smt/hash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package smt

import (
"github.com/consensys/gnark/frontend"

"github.com/mdehoog/poseidon/circuits/poseidon"
)

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom

func Hash1(api frontend.API, key, value frontend.Variable) frontend.Variable {
inputs := []frontend.Variable{key, value, 1}
return poseidon.Hash(api, inputs)
}

func Hash2(api frontend.API, l, r frontend.Variable) frontend.Variable {
inputs := []frontend.Variable{l, r}
return poseidon.Hash(api, inputs)
}
26 changes: 26 additions & 0 deletions smt/lev_ins.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package smt

import "github.com/consensys/gnark/frontend"

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smtlevins.circom

func LevIns(api frontend.API, enabled frontend.Variable, siblings []frontend.Variable) (levIns []frontend.Variable) {
levels := len(siblings)
levIns = make([]frontend.Variable, levels)
done := make([]frontend.Variable, levels-1)

isZero := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
isZero[i] = api.IsZero(siblings[i])
}
api.AssertIsEqual(api.Mul(api.Sub(isZero[levels-1], 1), enabled), 0)

levIns[levels-1] = api.Sub(1, isZero[levels-2])
done[levels-2] = levIns[levels-1]
for i := levels - 2; i > 0; i-- {
levIns[i] = api.Mul(api.Sub(1, done[i]), api.Sub(1, isZero[i-1]))
done[i-1] = api.Add(levIns[i], done[i])
}
levIns[0] = api.Sub(1, done[0])
return levIns
}
63 changes: 63 additions & 0 deletions smt/processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package smt

import (
"github.com/consensys/gnark/frontend"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessor.circom

func Processor(api frontend.API, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, newKey, newValue, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) {
levels := len(siblings)
enabled := api.Sub(api.Add(fnc0, fnc1), api.Mul(fnc0, fnc1))
hash1Old := Hash1(api, oldKey, oldValue)
hash1New := Hash1(api, newKey, newValue)
n2bOld := api.ToBinary(oldKey, api.Compiler().FieldBitLen())
n2bNew := api.ToBinary(newKey, api.Compiler().FieldBitLen())
smtLevIns := LevIns(api, enabled, siblings)

xors := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
xors[i] = api.Xor(n2bOld[i], n2bNew[i])
}

stTop := make([]frontend.Variable, levels)
stOld0 := make([]frontend.Variable, levels)
stBot := make([]frontend.Variable, levels)
stNew1 := make([]frontend.Variable, levels)
stNa := make([]frontend.Variable, levels)
stUpd := make([]frontend.Variable, levels)
for i := 0; i < levels; i++ {
if i == 0 {
stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, enabled, 0, 0, 0, api.Sub(1, enabled), 0)
} else {
stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, stTop[i-1], stOld0[i-1], stBot[i-1], stNew1[i-1], stNa[i-1], stUpd[i-1])
}
}

api.AssertIsEqual(api.Add(api.Add(stNa[levels-1], stNew1[levels-1]), api.Add(stOld0[levels-1], stUpd[levels-1])), 1)

levelsOldRoot := make([]frontend.Variable, levels)
levelsNewRoot := make([]frontend.Variable, levels)
for i := levels - 1; i >= 0; i-- {
if i == levels-1 {
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0, 0)
} else {
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1])
}
}

topSwitcherL, topSwitcherR := Switcher(api, api.Mul(fnc0, fnc1), levelsOldRoot[0], levelsNewRoot[0])
ForceEqualIfEnabled(api, oldRoot, topSwitcherL, enabled)

newRoot = api.Add(api.Mul(enabled, api.Sub(topSwitcherR, oldRoot)), oldRoot)

areKeyEquals := IsEqual(api, oldKey, newKey)
in := []frontend.Variable{
api.Sub(1, fnc0),
fnc1,
api.Sub(1, areKeyEquals),
}
keysOk := MultiAnd(api, in)
api.AssertIsEqual(keysOk, 0)
return
}
20 changes: 20 additions & 0 deletions smt/processor_level.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package smt

import (
"github.com/consensys/gnark/frontend"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorlevel.circom

func ProcessorLevel(api frontend.API, stTop, stOld0, stBot, stNew1, stUpd, sibling, old1leaf, new1leaf, newlrbit, oldChild, newChild frontend.Variable) (oldRoot, newRoot frontend.Variable) {
oldProofHashL, oldProofHashR := Switcher(api, newlrbit, oldChild, sibling)
oldProofHash := Hash2(api, oldProofHashL, oldProofHashR)

oldRoot = api.Add(api.Mul(old1leaf, api.Add(api.Add(stBot, stNew1), stUpd)), api.Mul(oldProofHash, stTop))

newProofHashL, newProofHashR := Switcher(api, newlrbit, api.Add(api.Mul(newChild, api.Add(stTop, stBot)), api.Mul(new1leaf, stNew1)), api.Add(api.Mul(sibling, stTop), api.Mul(old1leaf, stNew1)))
newProofHash := Hash2(api, newProofHashL, newProofHashR)

newRoot = api.Add(api.Mul(newProofHash, api.Add(api.Add(stTop, stBot), stNew1)), api.Mul(new1leaf, api.Add(stOld0, stUpd)))
return
}
Loading

0 comments on commit 8562a79

Please sign in to comment.