Skip to content

Commit

Permalink
Add more checks in range proof
Browse files Browse the repository at this point in the history
  • Loading branch information
yycen committed Aug 17, 2023
1 parent f6d255b commit 91bcc64
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 32 deletions.
16 changes: 16 additions & 0 deletions crypto/mta/proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/bnb-chain/tss-lib/common"
"github.com/bnb-chain/tss-lib/crypto"
"github.com/bnb-chain/tss-lib/crypto/paillier"
"github.com/bnb-chain/tss-lib/tss"
)

const (
Expand Down Expand Up @@ -246,6 +247,18 @@ func (pf *ProofBobWC) Verify(ec elliptic.Curve, pk *paillier.PublicKey, NTilde,
if gcd.GCD(nil, nil, pf.V, pk.N).Cmp(one) != 0 {
return false
}
if pf.S1.Cmp(q) == -1 {
return false
}
if pf.S2.Cmp(q) == -1 {
return false
}
if pf.T1.Cmp(q) == -1 {
return false
}
if pf.T2.Cmp(q) == -1 {
return false
}

// 3.
if pf.S1.Cmp(q3) > 0 {
Expand All @@ -263,6 +276,9 @@ func (pf *ProofBobWC) Verify(ec elliptic.Curve, pk *paillier.PublicKey, NTilde,
if X == nil {
eHash = common.SHA512_256i(append(pk.AsInts(), c1, c2, pf.Z, pf.ZPrm, pf.T, pf.V, pf.W)...)
} else {
if !tss.SameCurve(ec, X.Curve()) {
return false
}
eHash = common.SHA512_256i(append(pk.AsInts(), X.X(), X.Y(), c1, c2, pf.U.X(), pf.U.Y(), pf.Z, pf.ZPrm, pf.T, pf.V, pf.W)...)
}
e = common.RejectionSample(q, eHash)
Expand Down
6 changes: 6 additions & 0 deletions crypto/mta/range_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ func (pf *RangeProofAlice) Verify(ec elliptic.Curve, pk *paillier.PublicKey, NTi
if new(big.Int).GCD(nil, nil, pf.W, NTilde).Cmp(one) != 0 {
return false
}
if pf.S1.Cmp(q) == -1 {
return false
}
if pf.S2.Cmp(q) == -1 {
return false
}

// 3.
if pf.S1.Cmp(q3) == 1 {
Expand Down
70 changes: 38 additions & 32 deletions crypto/mta/range_proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,63 +55,69 @@ func TestProveRangeAliceBypassed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()

sk_0, pk_0, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength)
sk0, pk0, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength)
assert.NoError(t, err)

m_0 := common.GetRandomPositiveInt(q)
c_0, r_0, err := sk_0.EncryptAndReturnRandomness(m_0)
m0 := common.GetRandomPositiveInt(q)
c0, r0, err := sk0.EncryptAndReturnRandomness(m0)
assert.NoError(t, err)

primes_0 := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)}
NTildei_0, h1i_0, h2i_0, err := crypto.GenerateNTildei(primes_0)
primes0 := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)}
Ntildei0, h1i0, h2i0, err := crypto.GenerateNTildei(primes0)
assert.NoError(t, err)
proof_0, err := ProveRangeAlice(tss.EC(), pk_0, c_0, NTildei_0, h1i_0, h2i_0, m_0, r_0)
proof0, err := ProveRangeAlice(tss.EC(), pk0, c0, Ntildei0, h1i0, h2i0, m0, r0)
assert.NoError(t, err)

ok_0 := proof_0.Verify(tss.EC(), pk_0, NTildei_0, h1i_0, h2i_0, c_0)
assert.True(t, ok_0, "proof must verify")
ok0 := proof0.Verify(tss.EC(), pk0, Ntildei0, h1i0, h2i0, c0)
assert.True(t, ok0, "proof must verify")

//proof 2
sk_1, pk_1, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength)
sk1, pk1, err := paillier.GenerateKeyPair(ctx, testPaillierKeyLength)
assert.NoError(t, err)

m_1 := common.GetRandomPositiveInt(q)
c_1, r_1, err := sk_1.EncryptAndReturnRandomness(m_1)
m1 := common.GetRandomPositiveInt(q)
c1, r1, err := sk1.EncryptAndReturnRandomness(m1)
assert.NoError(t, err)

primes_1 := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)}
NTildei_1, h1i_1, h2i_1, err := crypto.GenerateNTildei(primes_1)
primes1 := [2]*big.Int{common.GetRandomPrimeInt(testSafePrimeBits), common.GetRandomPrimeInt(testSafePrimeBits)}
Ntildei1, h1i1, h2i1, err := crypto.GenerateNTildei(primes1)
assert.NoError(t, err)
proof_1, err := ProveRangeAlice(tss.EC(), pk_1, c_1, NTildei_1, h1i_1, h2i_1, m_1, r_1)
proof1, err := ProveRangeAlice(tss.EC(), pk1, c1, Ntildei1, h1i1, h2i1, m1, r1)
assert.NoError(t, err)

ok_1 := proof_1.Verify(tss.EC(), pk_1, NTildei_1, h1i_1, h2i_1, c_1)
assert.True(t, ok_1, "proof must verify")
ok1 := proof1.Verify(tss.EC(), pk1, Ntildei1, h1i1, h2i1, c1)
assert.True(t, ok1, "proof must verify")

cross_0 := proof_0.Verify(tss.EC(), pk_1, NTildei_1, h1i_1, h2i_1, c_1)
assert.False(t, cross_0, "proof must not verify")
cross0 := proof0.Verify(tss.EC(), pk1, Ntildei1, h1i1, h2i1, c1)
assert.False(t, cross0, "proof must not verify")

cross_1 := proof_1.Verify(tss.EC(), pk_0, NTildei_0, h1i_0, h2i_0, c_0)
assert.False(t, cross_1, "proof must not verify")
cross1 := proof1.Verify(tss.EC(), pk0, Ntildei0, h1i0, h2i0, c0)
assert.False(t, cross1, "proof must not verify")

fmt.Println("Did verify proof 0 with data from 0?", ok_0)
fmt.Println("Did verify proof 1 with data from 1?", ok_1)
fmt.Println("Did verify proof 0 with data from 0?", ok0)
fmt.Println("Did verify proof 1 with data from 1?", ok1)

fmt.Println("Did verify proof 0 with data from 1?", cross_0)
fmt.Println("Did verify proof 1 with data from 0?", cross_1)
fmt.Println("Did verify proof 0 with data from 1?", cross0)
fmt.Println("Did verify proof 1 with data from 0?", cross1)

//always passes
bypassedProof := &RangeProofAlice{
S: big.NewInt(0),
//new bypass
bypassedproofNew := &RangeProofAlice{
S: big.NewInt(1),
S1: big.NewInt(0),
S2: big.NewInt(0),
Z: big.NewInt(1),
U: big.NewInt(0),
U: big.NewInt(1),
W: big.NewInt(1),
}

bypassResult_1 := bypassedProof.Verify(tss.EC(), pk_0, NTildei_0, h1i_0, h2i_0, c_0)
fmt.Println("Did we bypass proof 1?", bypassResult_1)
bypassResult_2 := bypassedProof.Verify(tss.EC(), pk_1, NTildei_1, h1i_1, h2i_1, c_1)
fmt.Println("Did we bypass proof 2?", bypassResult_2)
cBogus := big.NewInt(1)
proofBogus, _ := ProveRangeAlice(tss.EC(), pk1, cBogus, Ntildei1, h1i1, h2i1, m1, r1)

ok2 := proofBogus.Verify(tss.EC(), pk1, Ntildei1, h1i1, h2i1, cBogus)
bypassresult3 := bypassedproofNew.Verify(tss.EC(), pk1, Ntildei1, h1i1, h2i1, cBogus)

//c = 1 is not valid, even though we can find a range proof for it that passes!
//this also means that the homo mul and add needs to be checked with this!
fmt.Println("Did verify proof bogus with data from bogus?", ok2)
fmt.Println("Did we bypass proof 3?", bypassresult3)
}
11 changes: 11 additions & 0 deletions tss/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ func GetCurveName(curve elliptic.Curve) (CurveName, bool) {
return "", false
}

// SameCurve returns true if both lhs and rhs are the same known curve
func SameCurve(lhs, rhs elliptic.Curve) bool {
lName, lOk := GetCurveName(lhs)
rName, rOk := GetCurveName(rhs)
if lOk && rOk {
return lName == rName
}
// if lhs/rhs not exist, return false
return false
}

// EC returns the current elliptic curve in use. The default is secp256k1
func EC() elliptic.Curve {
return ec
Expand Down

0 comments on commit 91bcc64

Please sign in to comment.