diff --git a/rangeproof/innerproduct/innerproduct.go b/rangeproof/innerproduct/innerproduct.go index 74b211e..d6f9cc8 100644 --- a/rangeproof/innerproduct/innerproduct.go +++ b/rangeproof/innerproduct/innerproduct.go @@ -40,7 +40,7 @@ func Generate(GVec, HVec []ristretto.Point, aVec, bVec, HprimeFactors []ristrett H := make([]ristretto.Point, len(HVec)) copy(H, HVec) - hs := fiatshamir.HashCacher{[]byte{}} + hs := fiatshamir.HashCacher{Cache: []byte{}} lgN := bits.TrailingZeros(nextPow2(uint(n))) @@ -390,24 +390,25 @@ func (proof *Proof) Verify(G, H, L, R []ristretto.Point, HprimeFactor []ristrett return have.Equals(&P) } -func (p *Proof) Encode(w io.Writer) error { +// Encode a proof +func (proof *Proof) Encode(w io.Writer) error { - err := binary.Write(w, binary.BigEndian, p.A.Bytes()) + err := binary.Write(w, binary.BigEndian, proof.A.Bytes()) if err != nil { return err } - err = binary.Write(w, binary.BigEndian, p.B.Bytes()) + err = binary.Write(w, binary.BigEndian, proof.B.Bytes()) if err != nil { return err } - lenL := uint32(len(p.L)) + lenL := uint32(len(proof.L)) for i := uint32(0); i < lenL; i++ { - err = binary.Write(w, binary.BigEndian, p.L[i].Bytes()) + err = binary.Write(w, binary.BigEndian, proof.L[i].Bytes()) if err != nil { return err } - err = binary.Write(w, binary.BigEndian, p.R[i].Bytes()) + err = binary.Write(w, binary.BigEndian, proof.R[i].Bytes()) if err != nil { return err } @@ -415,8 +416,9 @@ func (p *Proof) Encode(w io.Writer) error { return nil } -func (p *Proof) Decode(r io.Reader) error { - if p == nil { +// Decode a proof +func (proof *Proof) Decode(r io.Reader) error { + if proof == nil { return errors.New("struct is nil") } @@ -429,8 +431,8 @@ func (p *Proof) Decode(r io.Reader) error { if err != nil { return err } - p.A.SetBytes(&ABytes) - p.B.SetBytes(&BBytes) + proof.A.SetBytes(&ABytes) + proof.B.SetBytes(&BBytes) buf := &bytes.Buffer{} _, err = buf.ReadFrom(r) @@ -443,8 +445,8 @@ func (p *Proof) Decode(r io.Reader) error { } lenL := uint32(numBytes / 64) - p.L = make([]ristretto.Point, lenL) - p.R = make([]ristretto.Point, lenL) + proof.L = make([]ristretto.Point, lenL) + proof.R = make([]ristretto.Point, lenL) for i := uint32(0); i < lenL; i++ { var LBytes, RBytes [32]byte @@ -456,31 +458,32 @@ func (p *Proof) Decode(r io.Reader) error { if err != nil { return err } - p.L[i].SetBytes(&LBytes) - p.R[i].SetBytes(&RBytes) + proof.L[i].SetBytes(&LBytes) + proof.R[i].SetBytes(&RBytes) } return nil } -func (p *Proof) Equals(other Proof) bool { - ok := p.A.Equals(&other.A) +// Equals tests for Equality two proofs +func (proof *Proof) Equals(other Proof) bool { + ok := proof.A.Equals(&other.A) if !ok { return ok } - ok = p.B.Equals(&other.B) + ok = proof.B.Equals(&other.B) if !ok { return ok } - for i := range p.L { - ok := p.L[i].Equals(&other.L[i]) + for i := range proof.L { + ok := proof.L[i].Equals(&other.L[i]) if !ok { return ok } - ok = p.R[i].Equals(&other.R[i]) + ok = proof.R[i].Equals(&other.R[i]) if !ok { return ok } @@ -503,6 +506,8 @@ func isPower2(n uint32) bool { return (n & (n - 1)) == 0 } +// DiffNextPow2 returns the difference between a given number and the next +// power of two func DiffNextPow2(n uint32) uint32 { pow2 := nextPow2(uint(n)) padAmount := uint32(pow2) - n + 1 diff --git a/rangeproof/rangeproof.go b/rangeproof/rangeproof.go index 98c5934..5de3984 100644 --- a/rangeproof/rangeproof.go +++ b/rangeproof/rangeproof.go @@ -69,7 +69,7 @@ func Prove(v []ristretto.Scalar, debug bool) (Proof, error) { ped.BaseVector.Compute(uint32((N * M))) // Hash for Fiat-Shamir - hs := fiatshamir.HashCacher{[]byte{}} + hs := fiatshamir.HashCacher{Cache: []byte{}} for _, amount := range v { // compute commmitment to v @@ -335,7 +335,7 @@ func Verify(p Proof) (bool, error) { H := ped2.BaseVector.Bases // Reconstruct the challenges - hs := fiatshamir.HashCacher{[]byte{}} + hs := fiatshamir.HashCacher{Cache: []byte{}} for _, V := range p.V { hs.Append(V.Value.Bytes()) } @@ -492,6 +492,7 @@ func megacheckWithC(ipproof *innerproduct.Proof, mu, x, y, z, t, taux, w ristret return true, nil } +// Encode a Proof func (p *Proof) Encode(w io.Writer, includeCommits bool) error { if includeCommits { @@ -532,6 +533,7 @@ func (p *Proof) Encode(w io.Writer, includeCommits bool) error { return p.IPProof.Encode(w) } +// Decode a Proof func (p *Proof) Decode(r io.Reader, includeCommits bool) error { if p == nil { @@ -578,6 +580,7 @@ func (p *Proof) Decode(r io.Reader, includeCommits bool) error { return p.IPProof.Decode(r) } +// Equals tests proof for equality func (p *Proof) Equals(other Proof, includeCommits bool) bool { if len(p.V) != len(other.V) && includeCommits { return false @@ -619,7 +622,7 @@ func (p *Proof) Equals(other Proof, includeCommits bool) bool { return ok } return true - return p.IPProof.Equals(*other.IPProof) + //return p.IPProof.Equals(*other.IPProof) } func readerToPoint(r io.Reader, p *ristretto.Point) error {