From 8562a79f4ddba1de6c07b8f3a4307c70c98c66fa Mon Sep 17 00:00:00 2001 From: Gui Iribarren Date: Thu, 21 Nov 2024 11:29:32 +0100 Subject: [PATCH] replace package smt with https://github.com/mdehoog/gnark-circom-smt/ --- smt/emulated/hash.go | 20 +++ smt/emulated/lev_ins.go | 29 +++++ smt/emulated/processor.go | 67 ++++++++++ smt/emulated/processor_level.go | 27 ++++ smt/emulated/utils.go | 34 +++++ smt/emulated/verifier.go | 54 ++++++++ smt/emulated/verifier_level.go | 13 ++ smt/hash.go | 19 +++ smt/lev_ins.go | 26 ++++ smt/processor.go | 63 +++++++++ smt/processor_level.go | 20 +++ smt/processor_sm.go | 17 +++ smt/utils.go | 51 +++----- smt/verifier.go | 218 +++++--------------------------- smt/verifier_level.go | 12 ++ smt/verifier_sm.go | 14 ++ smt/wrapper.go | 31 +++++ smt/wrapper_arbo.go | 184 +++++++++++++++++++++++++++ 18 files changed, 680 insertions(+), 219 deletions(-) create mode 100644 smt/emulated/hash.go create mode 100644 smt/emulated/lev_ins.go create mode 100644 smt/emulated/processor.go create mode 100644 smt/emulated/processor_level.go create mode 100644 smt/emulated/utils.go create mode 100644 smt/emulated/verifier.go create mode 100644 smt/emulated/verifier_level.go create mode 100644 smt/hash.go create mode 100644 smt/lev_ins.go create mode 100644 smt/processor.go create mode 100644 smt/processor_level.go create mode 100644 smt/processor_sm.go create mode 100644 smt/verifier_level.go create mode 100644 smt/verifier_sm.go create mode 100644 smt/wrapper.go create mode 100644 smt/wrapper_arbo.go diff --git a/smt/emulated/hash.go b/smt/emulated/hash.go new file mode 100644 index 0000000..aa668fb --- /dev/null +++ b/smt/emulated/hash.go @@ -0,0 +1,20 @@ +package emulated + +import ( + "github.com/consensys/gnark/std/math/emulated" + + poseidon "github.com/mdehoog/poseidon/circuits/poseidon/emulated" +) + +// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom + +func Hash1[T emulated.FieldParams](field *emulated.Field[T], key, value *emulated.Element[T]) *emulated.Element[T] { + one := emulated.ValueOf[T](1) + inputs := []*emulated.Element[T]{key, value, &one} + return poseidon.Hash(field, inputs) +} + +func Hash2[T emulated.FieldParams](field *emulated.Field[T], l, r *emulated.Element[T]) *emulated.Element[T] { + inputs := []*emulated.Element[T]{l, r} + return poseidon.Hash(field, inputs) +} diff --git a/smt/emulated/lev_ins.go b/smt/emulated/lev_ins.go new file mode 100644 index 0000000..926c8d2 --- /dev/null +++ b/smt/emulated/lev_ins.go @@ -0,0 +1,29 @@ +package emulated + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smtlevins.circom + +func LevIns[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], enabled frontend.Variable, siblings []*emulated.Element[T]) (levIns []frontend.Variable) { + levels := len(siblings) + levIns = make([]frontend.Variable, levels) + done := make([]frontend.Variable, levels-1) + + isZero := make([]frontend.Variable, levels) + for i := 0; i < levels; i++ { + isZero[i] = field.IsZero(siblings[i]) + } + api.AssertIsEqual(api.Mul(api.Sub(isZero[levels-1], 1), enabled), 0) + + levIns[levels-1] = api.Sub(1, isZero[levels-2]) + done[levels-2] = levIns[levels-1] + for i := levels - 2; i > 0; i-- { + levIns[i] = api.Mul(api.Sub(1, done[i]), api.Sub(1, isZero[i-1])) + done[i-1] = api.Add(levIns[i], done[i]) + } + levIns[0] = api.Sub(1, done[0]) + return levIns +} diff --git a/smt/emulated/processor.go b/smt/emulated/processor.go new file mode 100644 index 0000000..ad75b6a --- /dev/null +++ b/smt/emulated/processor.go @@ -0,0 +1,67 @@ +package emulated + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + + "github.com/mdehoog/gnark-circom-smt/circuits/smt" +) + +// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessor.circom + +func Processor[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], oldRoot *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, newKey, newValue *emulated.Element[T], fnc0, fnc1 frontend.Variable) (newRoot *emulated.Element[T]) { + levels := len(siblings) + enabled := api.Sub(api.Add(fnc0, fnc1), api.Mul(fnc0, fnc1)) + hash1Old := Hash1(field, oldKey, oldValue) + hash1New := Hash1(field, newKey, newValue) + n2bOld := field.ToBits(oldKey) + n2bNew := field.ToBits(newKey) + smtLevIns := LevIns(api, field, enabled, siblings) + + xors := make([]frontend.Variable, levels) + for i := 0; i < levels; i++ { + xors[i] = api.Xor(n2bOld[i], n2bNew[i]) + } + + stTop := make([]frontend.Variable, levels) + stOld0 := make([]frontend.Variable, levels) + stBot := make([]frontend.Variable, levels) + stNew1 := make([]frontend.Variable, levels) + stNa := make([]frontend.Variable, levels) + stUpd := make([]frontend.Variable, levels) + for i := 0; i < levels; i++ { + if i == 0 { + stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = smt.ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, enabled, 0, 0, 0, api.Sub(1, enabled), 0) + } else { + stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = smt.ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, stTop[i-1], stOld0[i-1], stBot[i-1], stNew1[i-1], stNa[i-1], stUpd[i-1]) + } + } + + api.AssertIsEqual(api.Add(api.Add(stNa[levels-1], stNew1[levels-1]), api.Add(stOld0[levels-1], stUpd[levels-1])), 1) + + levelsOldRoot := make([]*emulated.Element[T], levels) + levelsNewRoot := make([]*emulated.Element[T], levels) + for i := levels - 1; i >= 0; i-- { + if i == levels-1 { + zero := emulated.ValueOf[T](0) + levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, field, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], &zero, &zero) + } else { + levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, field, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1]) + } + } + + topSwitcherL, topSwitcherR := Switcher(field, api.Mul(fnc0, fnc1), levelsOldRoot[0], levelsNewRoot[0]) + ForceEqualIfEnabled(field, oldRoot, topSwitcherL, enabled) + + newRoot = field.Select(enabled, topSwitcherR, oldRoot) + + areKeyEquals := IsEqual(field, oldKey, newKey) + in := []frontend.Variable{ + api.Sub(1, fnc0), + fnc1, + api.Sub(1, areKeyEquals), + } + keysOk := smt.MultiAnd(api, in) + api.AssertIsEqual(keysOk, 0) + return +} diff --git a/smt/emulated/processor_level.go b/smt/emulated/processor_level.go new file mode 100644 index 0000000..d40b7bd --- /dev/null +++ b/smt/emulated/processor_level.go @@ -0,0 +1,27 @@ +package emulated + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorlevel.circom + +func ProcessorLevel[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], stTop, stOld0, stBot, stNew1, stUpd frontend.Variable, sibling, old1leaf, new1leaf *emulated.Element[T], newlrbit frontend.Variable, oldChild, newChild *emulated.Element[T]) (oldRoot, newRoot *emulated.Element[T]) { + oldProofHashL, oldProofHashR := Switcher(field, newlrbit, oldChild, sibling) + oldProofHash := Hash2(field, oldProofHashL, oldProofHashR) + + am := api.Add(api.Add(stBot, stNew1), stUpd) + oldRoot = mux2(api, field, am, stTop, old1leaf, oldProofHash) + + am = api.Add(stTop, stBot) + a := mux2(api, field, am, stNew1, newChild, new1leaf) + b := mux2(api, field, stTop, stNew1, sibling, old1leaf) + newProofHashL, newProofHashR := Switcher(field, newlrbit, a, b) + newProofHash := Hash2(field, newProofHashL, newProofHashR) + + am = api.Add(api.Add(stTop, stBot), stNew1) + bm := api.Add(stOld0, stUpd) + newRoot = mux2(api, field, am, bm, newProofHash, new1leaf) + return +} diff --git a/smt/emulated/utils.go b/smt/emulated/utils.go new file mode 100644 index 0000000..fb2dd72 --- /dev/null +++ b/smt/emulated/utils.go @@ -0,0 +1,34 @@ +package emulated + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +func IsEqual[T emulated.FieldParams](field *emulated.Field[T], a, b *emulated.Element[T]) frontend.Variable { + return field.IsZero(field.Sub(a, b)) +} + +func ForceEqualIfEnabled[T emulated.FieldParams](field *emulated.Field[T], a, b *emulated.Element[T], enabled frontend.Variable) { + c := field.Select(enabled, a, b) + field.AssertIsEqual(c, b) +} + +// Switcher is [out1, out2] = sel ? [r, l] : [l, r] +func Switcher[T emulated.FieldParams](field *emulated.Field[T], sel frontend.Variable, l, r *emulated.Element[T]) (*emulated.Element[T], *emulated.Element[T]) { + return field.Select(sel, r, l), field.Select(sel, l, r) +} + +// mux2 is (out = as ? a : bs ? b : 0) +func mux2[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], as, bs frontend.Variable, a, b *emulated.Element[T]) *emulated.Element[T] { + sel := api.FromBinary(as, bs) + zero := emulated.ValueOf[T](0) + return field.Mux(sel, &zero, a, b, a) +} + +// mux3 is (out = as ? a : bs ? b : cs ? c : 0) +func mux3[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], as, bs, cs frontend.Variable, a, b, c *emulated.Element[T]) *emulated.Element[T] { + sel := api.FromBinary(as, bs, cs) + zero := emulated.ValueOf[T](0) + return field.Mux(sel, &zero, a, b, a, c, a, b, a) +} diff --git a/smt/emulated/verifier.go b/smt/emulated/verifier.go new file mode 100644 index 0000000..bf8c241 --- /dev/null +++ b/smt/emulated/verifier.go @@ -0,0 +1,54 @@ +package emulated + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + + "github.com/mdehoog/gnark-circom-smt/circuits/smt" +) + +func InclusionVerifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], root *emulated.Element[T], siblings []*emulated.Element[T], key, value *emulated.Element[T]) { + Verifier[T](api, field, 1, root, siblings, key, value, 0, key, value, 0) +} + +func ExclusionVerifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], root *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, key *emulated.Element[T]) { + zero := emulated.ValueOf[T](0) + Verifier[T](api, field, 1, root, siblings, oldKey, oldValue, isOld0, key, &zero, 1) +} + +func Verifier[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], enabled frontend.Variable, root *emulated.Element[T], siblings []*emulated.Element[T], oldKey, oldValue *emulated.Element[T], isOld0 frontend.Variable, key, value *emulated.Element[T], fnc frontend.Variable) { + nLevels := len(siblings) + hash1Old := Hash1(field, oldKey, oldValue) + hash1New := Hash1(field, key, value) + n2bNew := field.ToBits(key) + smtLevIns := LevIns(api, field, enabled, siblings) + + stTop := make([]frontend.Variable, nLevels) + stI0 := make([]frontend.Variable, nLevels) + stIOld := make([]frontend.Variable, nLevels) + stINew := make([]frontend.Variable, nLevels) + stNa := make([]frontend.Variable, nLevels) + for i := 0; i < nLevels; i++ { + if i == 0 { + stTop[i], stI0[i], stIOld[i], stINew[i], stNa[i] = smt.VerifierSM(api, isOld0, smtLevIns[i], fnc, enabled, 0, 0, 0, api.Sub(1, enabled)) + } else { + stTop[i], stI0[i], stIOld[i], stINew[i], stNa[i] = smt.VerifierSM(api, isOld0, smtLevIns[i], fnc, stTop[i-1], stI0[i-1], stIOld[i-1], stINew[i-1], stNa[i-1]) + } + } + api.AssertIsEqual(api.Add(api.Add(api.Add(stNa[nLevels-1], stIOld[nLevels-1]), stINew[nLevels-1]), stI0[nLevels-1]), 1) + + levels := make([]*emulated.Element[T], nLevels) + for i := nLevels - 1; i >= 0; i-- { + if i == nLevels-1 { + zero := emulated.ValueOf[T](0) + levels[i] = VerifierLevel(api, field, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], &zero) + } else { + levels[i] = VerifierLevel(api, field, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1]) + } + } + + areKeyEquals := IsEqual(field, oldKey, key) + keysOk := smt.MultiAnd(api, []frontend.Variable{fnc, api.Sub(1, isOld0), areKeyEquals, enabled}) + api.AssertIsEqual(keysOk, 0) + ForceEqualIfEnabled(field, levels[0], root, enabled) +} diff --git a/smt/emulated/verifier_level.go b/smt/emulated/verifier_level.go new file mode 100644 index 0000000..d8899a4 --- /dev/null +++ b/smt/emulated/verifier_level.go @@ -0,0 +1,13 @@ +package emulated + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +func VerifierLevel[T emulated.FieldParams](api frontend.API, field *emulated.Field[T], stTop, stIOld, stINew frontend.Variable, sibling, old1leaf, new1leaf *emulated.Element[T], lrbit frontend.Variable, child *emulated.Element[T]) (root *emulated.Element[T]) { + proofHashL, proofHashR := Switcher(field, lrbit, child, sibling) + proofHash := Hash2(field, proofHashL, proofHashR) + root = mux3(api, field, stTop, stIOld, stINew, proofHash, old1leaf, new1leaf) + return +} diff --git a/smt/hash.go b/smt/hash.go new file mode 100644 index 0000000..2fc896b --- /dev/null +++ b/smt/hash.go @@ -0,0 +1,19 @@ +package smt + +import ( + "github.com/consensys/gnark/frontend" + + "github.com/mdehoog/poseidon/circuits/poseidon" +) + +// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom + +func Hash1(api frontend.API, key, value frontend.Variable) frontend.Variable { + inputs := []frontend.Variable{key, value, 1} + return poseidon.Hash(api, inputs) +} + +func Hash2(api frontend.API, l, r frontend.Variable) frontend.Variable { + inputs := []frontend.Variable{l, r} + return poseidon.Hash(api, inputs) +} diff --git a/smt/lev_ins.go b/smt/lev_ins.go new file mode 100644 index 0000000..f34b460 --- /dev/null +++ b/smt/lev_ins.go @@ -0,0 +1,26 @@ +package smt + +import "github.com/consensys/gnark/frontend" + +// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smtlevins.circom + +func LevIns(api frontend.API, enabled frontend.Variable, siblings []frontend.Variable) (levIns []frontend.Variable) { + levels := len(siblings) + levIns = make([]frontend.Variable, levels) + done := make([]frontend.Variable, levels-1) + + isZero := make([]frontend.Variable, levels) + for i := 0; i < levels; i++ { + isZero[i] = api.IsZero(siblings[i]) + } + api.AssertIsEqual(api.Mul(api.Sub(isZero[levels-1], 1), enabled), 0) + + levIns[levels-1] = api.Sub(1, isZero[levels-2]) + done[levels-2] = levIns[levels-1] + for i := levels - 2; i > 0; i-- { + levIns[i] = api.Mul(api.Sub(1, done[i]), api.Sub(1, isZero[i-1])) + done[i-1] = api.Add(levIns[i], done[i]) + } + levIns[0] = api.Sub(1, done[0]) + return levIns +} diff --git a/smt/processor.go b/smt/processor.go new file mode 100644 index 0000000..e27c864 --- /dev/null +++ b/smt/processor.go @@ -0,0 +1,63 @@ +package smt + +import ( + "github.com/consensys/gnark/frontend" +) + +// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessor.circom + +func Processor(api frontend.API, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, newKey, newValue, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) { + levels := len(siblings) + enabled := api.Sub(api.Add(fnc0, fnc1), api.Mul(fnc0, fnc1)) + hash1Old := Hash1(api, oldKey, oldValue) + hash1New := Hash1(api, newKey, newValue) + n2bOld := api.ToBinary(oldKey, api.Compiler().FieldBitLen()) + n2bNew := api.ToBinary(newKey, api.Compiler().FieldBitLen()) + smtLevIns := LevIns(api, enabled, siblings) + + xors := make([]frontend.Variable, levels) + for i := 0; i < levels; i++ { + xors[i] = api.Xor(n2bOld[i], n2bNew[i]) + } + + stTop := make([]frontend.Variable, levels) + stOld0 := make([]frontend.Variable, levels) + stBot := make([]frontend.Variable, levels) + stNew1 := make([]frontend.Variable, levels) + stNa := make([]frontend.Variable, levels) + stUpd := make([]frontend.Variable, levels) + for i := 0; i < levels; i++ { + if i == 0 { + stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, enabled, 0, 0, 0, api.Sub(1, enabled), 0) + } else { + stTop[i], stOld0[i], stBot[i], stNew1[i], stNa[i], stUpd[i] = ProcessorSM(api, xors[i], isOld0, smtLevIns[i], fnc0, stTop[i-1], stOld0[i-1], stBot[i-1], stNew1[i-1], stNa[i-1], stUpd[i-1]) + } + } + + api.AssertIsEqual(api.Add(api.Add(stNa[levels-1], stNew1[levels-1]), api.Add(stOld0[levels-1], stUpd[levels-1])), 1) + + levelsOldRoot := make([]frontend.Variable, levels) + levelsNewRoot := make([]frontend.Variable, levels) + for i := levels - 1; i >= 0; i-- { + if i == levels-1 { + levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0, 0) + } else { + levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1]) + } + } + + topSwitcherL, topSwitcherR := Switcher(api, api.Mul(fnc0, fnc1), levelsOldRoot[0], levelsNewRoot[0]) + ForceEqualIfEnabled(api, oldRoot, topSwitcherL, enabled) + + newRoot = api.Add(api.Mul(enabled, api.Sub(topSwitcherR, oldRoot)), oldRoot) + + areKeyEquals := IsEqual(api, oldKey, newKey) + in := []frontend.Variable{ + api.Sub(1, fnc0), + fnc1, + api.Sub(1, areKeyEquals), + } + keysOk := MultiAnd(api, in) + api.AssertIsEqual(keysOk, 0) + return +} diff --git a/smt/processor_level.go b/smt/processor_level.go new file mode 100644 index 0000000..46f8d0a --- /dev/null +++ b/smt/processor_level.go @@ -0,0 +1,20 @@ +package smt + +import ( + "github.com/consensys/gnark/frontend" +) + +// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorlevel.circom + +func ProcessorLevel(api frontend.API, stTop, stOld0, stBot, stNew1, stUpd, sibling, old1leaf, new1leaf, newlrbit, oldChild, newChild frontend.Variable) (oldRoot, newRoot frontend.Variable) { + oldProofHashL, oldProofHashR := Switcher(api, newlrbit, oldChild, sibling) + oldProofHash := Hash2(api, oldProofHashL, oldProofHashR) + + oldRoot = api.Add(api.Mul(old1leaf, api.Add(api.Add(stBot, stNew1), stUpd)), api.Mul(oldProofHash, stTop)) + + newProofHashL, newProofHashR := Switcher(api, newlrbit, api.Add(api.Mul(newChild, api.Add(stTop, stBot)), api.Mul(new1leaf, stNew1)), api.Add(api.Mul(sibling, stTop), api.Mul(old1leaf, stNew1))) + newProofHash := Hash2(api, newProofHashL, newProofHashR) + + newRoot = api.Add(api.Mul(newProofHash, api.Add(api.Add(stTop, stBot), stNew1)), api.Mul(new1leaf, api.Add(stOld0, stUpd))) + return +} diff --git a/smt/processor_sm.go b/smt/processor_sm.go new file mode 100644 index 0000000..af6c4fe --- /dev/null +++ b/smt/processor_sm.go @@ -0,0 +1,17 @@ +package smt + +import "github.com/consensys/gnark/frontend" + +// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorsm.circom + +func ProcessorSM(api frontend.API, xor, is0, levIns, fnc0, prevTop, prevOld0, prevBot, prevNew1, prevNa, prevUpd frontend.Variable) (stTop, stOld0, stBot, stNew1, stNa, stUpd frontend.Variable) { + aux1 := api.Mul(prevTop, levIns) + aux2 := api.Mul(aux1, fnc0) + stTop = api.Sub(prevTop, aux1) + stOld0 = api.Mul(aux2, is0) + stNew1 = api.Mul(api.Add(api.Sub(aux2, stOld0), prevBot), xor) + stBot = api.Mul(api.Sub(1, xor), api.Add(api.Sub(aux2, stOld0), prevBot)) + stUpd = api.Sub(aux1, aux2) + stNa = api.Add(api.Add(api.Add(prevNew1, prevOld0), prevNa), prevUpd) + return +} diff --git a/smt/utils.go b/smt/utils.go index 46e3054..1d1136e 100644 --- a/smt/utils.go +++ b/smt/utils.go @@ -1,48 +1,31 @@ package smt import ( - "github.com/vocdoni/gnark-crypto-primitives/poseidon" - "github.com/consensys/gnark/frontend" ) -// endLeafValue returns the encoded childless leaf value for the key-value pair -// provided, hashing it with the predefined hashing function 'H': -// -// newLeafValue = H(key | value | 1) -func endLeafValue(api frontend.API, key, value frontend.Variable) (frontend.Variable, error) { - return poseidon.Hash(api, key, value, 1) +func IsEqual(api frontend.API, a, b frontend.Variable) frontend.Variable { + return api.IsZero(api.Sub(a, b)) } -// intermediateLeafValue returns the encoded intermediate leaf value for the -// key-value pair provided, hashing it with the predefined hashing function 'H': -// -// intermediateLeafValue = H(l | r) -func intermediateLeafValue(api frontend.API, l, r frontend.Variable) (frontend.Variable, error) { - return poseidon.Hash(api, l, r) +func ForceEqualIfEnabled(api frontend.API, a, b, enabled frontend.Variable) { + c := api.IsZero(api.Sub(a, b)) + api.AssertIsEqual(api.Mul(api.Sub(1, c), enabled), 0) } -func switcher(api frontend.API, sel, l, r frontend.Variable) (outL, outR frontend.Variable) { - // aux <== (R-L)*sel; - aux := api.Mul(api.Sub(r, l), sel) - // outL <== aux + L; - outL = api.Add(aux, l) - // outR <== -aux + R; - outR = api.Sub(r, aux) - return +func MultiAnd(api frontend.API, in []frontend.Variable) frontend.Variable { + out := frontend.Variable(1) + for i := 0; i < len(in); i++ { + out = api.And(out, in[i]) + } + return out } -func multiAnd(api frontend.API, inputs ...frontend.Variable) frontend.Variable { - if len(inputs) == 0 { - return 0 - } - if len(inputs) == 1 { - return inputs[0] - } +func Switcher(api frontend.API, sel, l, r frontend.Variable) (frontend.Variable, frontend.Variable) { + aux := api.Mul(api.Sub(r, l), sel) - res := inputs[0] - for i := 1; i < len(inputs); i++ { - res = api.And(res, inputs[i]) - } - return res + outL := api.Add(aux, l) + outR := api.Sub(r, aux) + + return outL, outR } diff --git a/smt/verifier.go b/smt/verifier.go index 4ef173f..ce88232 100644 --- a/smt/verifier.go +++ b/smt/verifier.go @@ -1,201 +1,49 @@ -// smt is a port of Circom SMTVerifier. It attempts to check a proof -// of a Sparse Merkle Tree (compatible with Arbo Merkle Tree implementation). -// Check the original implementation from Iden3: -// - https://github.com/iden3/circomlib/tree/a8cdb6cd1ad652cca1a409da053ec98f19de6c9d/circuits/smt package smt -import "github.com/consensys/gnark/frontend" +import ( + "github.com/consensys/gnark/frontend" +) -func Verifier(api frontend.API, - root, key, value frontend.Variable, siblings []frontend.Variable) error { - return smtverifier(api, 1, root, 0, 0, 0, key, value, 0, siblings) +func InclusionVerifier(api frontend.API, root frontend.Variable, siblings []frontend.Variable, key, value frontend.Variable) { + Verifier(api, 1, root, siblings, key, value, 0, key, value, 0) } -func smtverifier(api frontend.API, - enabled, root, oldKey, oldValue, isOld0, key, value, fnc frontend.Variable, - siblings []frontend.Variable) error { - nLevels := len(siblings) - - // Steps: - // 1. Get the hash of both key-value pairs, old and new one. - // 2. Get the binary representation of the key new. - // 3. Get the path of the current key. - // 4. Calculate the root with the siblings provided. - // 5. Compare the calculated root with the provided one. - - // [STEP 1] - // hash1Old = H(oldKey | oldValue | 1) - hash1Old, err := endLeafValue(api, oldKey, oldValue) - if err != nil { - return err - } - // hash1New = H(key | value | 1) - hash1New, err := endLeafValue(api, key, value) - if err != nil { - return err - } +func ExclusionVerifier(api frontend.API, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key frontend.Variable) { + Verifier(api, 1, root, siblings, oldKey, oldValue, isOld0, key, 0, 1) +} - // [STEP 2] - // component n2bNew = Num2Bits_strict(); - // n2bNew.in <== key; +func Verifier(api frontend.API, enabled, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key, value, fnc frontend.Variable) { + nLevels := len(siblings) + hash1Old := Hash1(api, oldKey, oldValue) + hash1New := Hash1(api, key, value) n2bNew := api.ToBinary(key, api.Compiler().FieldBitLen()) + smtLevIns := LevIns(api, enabled, siblings) - // [STEP 3] - // component smtLevIns = SMTLevIns(nLevels); - // for (i=0; i= 0; i-- { + if i == nLevels-1 { + levels[i] = VerifierLevel(api, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0) + } else { + levels[i] = VerifierLevel(api, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1]) } } - // component areKeyEquals = IsEqual(); - // areKeyEquals.in[0] <== oldKey; - // areKeyEquals.in[1] <== key; - keysEqual := frontend.Variable(0) - if api.Cmp(oldKey, key) == 0 { - keysEqual = 1 - } - - // component keysOk = MultiAND(4); - // keysOk.in[0] <== fnc; - // keysOk.in[1] <== 1-isOld0; - // keysOk.in[2] <== areKeyEquals.out; - // keysOk.in[3] <== enabled; - keysOk := multiAnd(api, fnc, api.Sub(1, isOld0), keysEqual, enabled) - // keysOk.out === 0; + areKeyEquals := IsEqual(api, oldKey, key) + keysOk := MultiAnd(api, []frontend.Variable{fnc, api.Sub(1, isOld0), areKeyEquals, enabled}) api.AssertIsEqual(keysOk, 0) - - // [STEP 5] - // component checkRoot = ForceEqualIfEnabled(); - // checkRoot.enabled <== enabled; - // checkRoot.in[0] <== levels[0].root; - // checkRoot.in[1] <== root; - api.AssertIsEqual(root, levels[0]) - return nil -} - -func smtLevIns(api frontend.API, siblings []frontend.Variable, enabled frontend.Variable) []frontend.Variable { - nLevels := len(siblings) - // The last level must always have a sibling of 0. If not, then it cannot be inserted. - // (isZero[nLevels-1].out - 1) * enabled === 0; - if api.IsZero(enabled) == 0 { - api.AssertIsEqual(siblings[nLevels-1], 0) - } - - // for (i=0; i 0; i-- { - // levIns[i] = (1-isDone[i])*(1-isZero[i-1]) - levIns[i] = api.Mul(api.Sub(1, isDone[i]), api.Sub(1, isZero[i-1])) - // done[i-1] = levIns[i] + done[i] - isDone[i-1] = api.Add(levIns[i], isDone[i]) - } - // levIns[0] <== (1-done[0]); - levIns[0] = api.Sub(1, isDone[0]) - - return levIns -} - -func smtVerifierSM(api frontend.API, - is0, levIns, fnc, prevTop, prevI0, prevIold, prevInew, prevNa frontend.Variable) ( - stTop, stIold, stI0, stInew, stNa frontend.Variable) { - // prev_top_lev_ins <== prev_top * levIns; - prevTopLevIns := api.Mul(prevTop, levIns) - // prev_top_lev_ins_fnc <== prev_top_lev_ins*fnc - prevTopLevInsFnc := api.Mul(prevTopLevIns, fnc) - - stTop = api.Sub(prevTop, prevTopLevIns) // st_top <== prev_top - prev_top_lev_ins - stIold = api.Mul(prevTopLevInsFnc, api.Sub(1, is0)) // st_iold <== prev_top_lev_ins_fnc * (1 - is0) - stI0 = api.Mul(prevTopLevIns, is0) // st_i0 <== prev_top_lev_ins * is0; - stInew = api.Sub(prevTopLevIns, prevTopLevInsFnc) // st_inew <== prev_top_lev_ins - prev_top_lev_ins_fnc - stNa = api.Add(prevNa, prevInew, prevIold, prevI0) // st_na <== prev_na + prev_inew + prev_iold + prev_i0 - return -} - -func smtVerifierLevel(api frontend.API, stTop, stIold, stInew, sibling, - old1leaf, new1leaf, lrbit, child frontend.Variable) (frontend.Variable, error) { - // component switcher = Switcher(); - // switcher.sel <== lrbit; - // switcher.L <== child; - // switcher.R <== sibling; - l, r := switcher(api, lrbit, child, sibling) - // component proofHash = SMTHash2(); - // proofHash.L <== switcher.outL; - // proofHash.R <== switcher.outR; - proofHash, err := intermediateLeafValue(api, l, r) - if err != nil { - return 0, err - } - // aux[0] <== proofHash.out * st_top; - aux0 := api.Mul(proofHash, stTop) - // aux[1] <== old1leaf * st_iold; - aux1 := api.Mul(old1leaf, stIold) - // root <== aux[0] + aux[1] + new1leaf * st_inew; - return api.Add(aux0, aux1, api.Mul(new1leaf, stInew)), nil + ForceEqualIfEnabled(api, levels[0], root, enabled) } diff --git a/smt/verifier_level.go b/smt/verifier_level.go new file mode 100644 index 0000000..b68cc8c --- /dev/null +++ b/smt/verifier_level.go @@ -0,0 +1,12 @@ +package smt + +import ( + "github.com/consensys/gnark/frontend" +) + +func VerifierLevel(api frontend.API, stTop, stIOld, stINew, sibling, old1leaf, new1leaf, lrbit, child frontend.Variable) (root frontend.Variable) { + proofHashL, proofHashR := Switcher(api, lrbit, child, sibling) + proofHash := Hash2(api, proofHashL, proofHashR) + root = api.Add(api.Add(api.Mul(proofHash, stTop), api.Mul(old1leaf, stIOld)), api.Mul(new1leaf, stINew)) + return +} diff --git a/smt/verifier_sm.go b/smt/verifier_sm.go new file mode 100644 index 0000000..e0dfd61 --- /dev/null +++ b/smt/verifier_sm.go @@ -0,0 +1,14 @@ +package smt + +import "github.com/consensys/gnark/frontend" + +func VerifierSM(api frontend.API, is0, levIns, fnc, prevTop, prevI0, prevIOld, prevINew, prevNa frontend.Variable) (stTop, stI0, stIOld, stINew, stNa frontend.Variable) { + aux1 := api.Mul(prevTop, levIns) + aux2 := api.Mul(aux1, fnc) + stTop = api.Sub(prevTop, aux1) + stINew = api.Sub(aux1, aux2) + stIOld = api.Mul(aux2, api.Sub(1, is0)) + stI0 = api.Mul(aux1, is0) + stNa = api.Add(api.Add(api.Add(prevNa, prevINew), prevIOld), prevI0) + return +} diff --git a/smt/wrapper.go b/smt/wrapper.go new file mode 100644 index 0000000..4ab27a3 --- /dev/null +++ b/smt/wrapper.go @@ -0,0 +1,31 @@ +package smt + +import ( + "math/big" + + "go.vocdoni.io/dvote/db" +) + +// Wrapper defines methods for wrapping existing SMT implementations, useful for +// generating circuit assignments for generating proof witnesses. See WrapperArbo +// for a concrete example that wrappers the arbo.Tree implementation. +type Wrapper interface { + Proof(key *big.Int) (Assignment, error) + ProofWithTx(tx db.Reader, key *big.Int) (Assignment, error) + SetProof(key, value *big.Int) (Assignment, error) + Set(key, value *big.Int) (Assignment, error) + SetWithTx(tx db.WriteTx, key, value *big.Int) (Assignment, error) +} + +type Assignment struct { + Fnc0 uint8 + Fnc1 uint8 + OldKey *big.Int + NewKey *big.Int + IsOld0 uint8 + OldValue *big.Int + NewValue *big.Int + OldRoot *big.Int + NewRoot *big.Int + Siblings []*big.Int +} diff --git a/smt/wrapper_arbo.go b/smt/wrapper_arbo.go new file mode 100644 index 0000000..2c20146 --- /dev/null +++ b/smt/wrapper_arbo.go @@ -0,0 +1,184 @@ +package smt + +import ( + "errors" + "math/big" + + "go.vocdoni.io/dvote/db" + "go.vocdoni.io/dvote/tree/arbo" +) + +// WrapperArbo wraps an arbo.Tree, generating circuit assignments for certain +// tree operations like Add and Update. +type WrapperArbo struct { + *arbo.Tree + database db.Database + levels uint8 +} + +func NewWrapperArbo(tree *arbo.Tree, database db.Database, levels uint8) Wrapper { + return &WrapperArbo{ + Tree: tree, + database: database, + levels: levels, + } +} + +func (t *WrapperArbo) Proof(key *big.Int) (Assignment, error) { + return t.ProofWithTx(t.database, key) +} + +func (t *WrapperArbo) ProofWithTx(tx db.Reader, key *big.Int) (Assignment, error) { + assignment := Assignment{ + NewKey: key, + } + + rootBytes, err := t.RootWithTx(tx) + if err != nil { + return assignment, err + } + assignment.OldRoot = arbo.BytesToBigInt(rootBytes) + assignment.NewRoot = arbo.BytesToBigInt(rootBytes) + + bLen := t.HashFunction().Len() + keyBytes := arbo.BigIntToBytes(bLen, key) + oldKeyBytes, oldValueBytes, siblingsPacked, exists, err := t.GenProofWithTx(tx, keyBytes) + if err != nil { + return assignment, err + } + + if exists { + assignment.Fnc0 = 0 + assignment.NewValue = arbo.BytesToBigInt(oldValueBytes) + } else { + assignment.Fnc0 = 1 + } + assignment.OldKey = arbo.BytesToBigInt(oldKeyBytes) + assignment.OldValue = arbo.BytesToBigInt(oldValueBytes) + if len(oldKeyBytes) > 0 { + assignment.IsOld0 = 0 + } else { + assignment.IsOld0 = 1 + } + + siblingsUnpacked, err := arbo.UnpackSiblings(t.HashFunction(), siblingsPacked) + if err != nil { + return assignment, err + } + + assignment.Siblings = make([]*big.Int, t.levels) + for i := 0; i < len(assignment.Siblings); i++ { + if i < len(siblingsUnpacked) { + assignment.Siblings[i] = arbo.BytesToBigInt(siblingsUnpacked[i]) + } else { + assignment.Siblings[i] = big.NewInt(0) + } + } + + return assignment, nil +} + +func (t *WrapperArbo) SetProof(key, value *big.Int) (Assignment, error) { + tx := t.database.WriteTx() + defer tx.Discard() + return t.SetWithTx(tx, key, value) +} + +func (t *WrapperArbo) Set(key, value *big.Int) (Assignment, error) { + tx := t.database.WriteTx() + defer tx.Discard() + assignment, err := t.SetWithTx(tx, key, value) + if err == nil { + err = tx.Commit() + } + return assignment, err +} + +func (t *WrapperArbo) SetWithTx(tx db.WriteTx, key, value *big.Int) (Assignment, error) { + return t.addOrUpdate(tx, key, value, func(k, v []byte, exists bool, assignment *Assignment) error { + if exists { + return t.update(tx, k, v, exists, assignment) + } else { + return t.add(tx, k, v, exists, assignment) + } + }) +} + +func (t *WrapperArbo) add(tx db.WriteTx, k, v []byte, _ bool, assignment *Assignment) error { + assignment.Fnc0 = 1 + assignment.Fnc1 = 0 + return t.Tree.AddWithTx(tx, k, v) +} + +func (t *WrapperArbo) update(tx db.WriteTx, k, v []byte, _ bool, assignment *Assignment) error { + assignment.Fnc0 = 0 + assignment.Fnc1 = 1 + return t.Tree.UpdateWithTx(tx, k, v) +} + +func (t *WrapperArbo) addOrUpdate(tx db.WriteTx, key, value *big.Int, action func(k, v []byte, exists bool, assignment *Assignment) error) (Assignment, error) { + assignment := Assignment{ + NewKey: key, + NewValue: value, + } + + oldRootBytes, err := t.RootWithTx(tx) + if err != nil { + return assignment, err + } + assignment.OldRoot = arbo.BytesToBigInt(oldRootBytes) + + bLen := t.HashFunction().Len() + keyBytes := arbo.BigIntToBytes(bLen, key) + valueBytes := arbo.BigIntToBytes(bLen, value) + + oldKeyBytes, oldValueBytes, err := t.Tree.GetWithTx(tx, keyBytes) + if err != nil && !errors.Is(err, arbo.ErrKeyNotFound) { + return assignment, err + } + err = action(keyBytes, valueBytes, err == nil, &assignment) + if err != nil { + return assignment, err + } + + assignment.OldKey = arbo.BytesToBigInt(oldKeyBytes) + assignment.OldValue = arbo.BytesToBigInt(oldValueBytes) + if len(oldKeyBytes) > 0 { + assignment.IsOld0 = 0 + } else { + assignment.IsOld0 = 1 + } + + newRootBytes, err := t.RootWithTx(tx) + if err != nil { + return assignment, err + } + assignment.NewRoot = arbo.BytesToBigInt(newRootBytes) + + _, _, siblingsPacked, exists, err := t.GenProofWithTx(tx, keyBytes) + if !exists { + return assignment, errors.New("key not found") + } + if err != nil { + return assignment, err + } + + siblingsUnpacked, err := arbo.UnpackSiblings(t.HashFunction(), siblingsPacked) + if err != nil { + return assignment, err + } + if assignment.IsOld0 == 0 && assignment.Fnc1 == 0 { + siblingsUnpacked = siblingsUnpacked[0 : len(siblingsUnpacked)-1] + } + + assignment.Siblings = make([]*big.Int, t.levels) + for i := 0; i < len(assignment.Siblings); i++ { + if i < len(siblingsUnpacked) { + assignment.Siblings[i] = arbo.BytesToBigInt(siblingsUnpacked[i]) + } else { + assignment.Siblings[i] = big.NewInt(0) + } + } + + return assignment, nil +}