Skip to content

Commit

Permalink
Reconcile the two copies of scalar_vector.rs in monero-serai
Browse files Browse the repository at this point in the history
  • Loading branch information
kayabaNerve committed Mar 2, 2024
1 parent b427f4b commit 5629c94
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 226 deletions.
41 changes: 27 additions & 14 deletions coins/monero/src/ringct/bulletproofs/original.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,28 @@ use curve25519_dalek::{scalar::Scalar as DalekScalar, edwards::EdwardsPoint as D
use group::{ff::Field, Group};
use dalek_ff_group::{ED25519_BASEPOINT_POINT as G, Scalar, EdwardsPoint};

use multiexp::BatchVerifier;
use multiexp::{BatchVerifier, multiexp};

use crate::{Commitment, ringct::bulletproofs::core::*};

include!(concat!(env!("OUT_DIR"), "/generators.rs"));

static IP12_CELL: OnceLock<Scalar> = OnceLock::new();
pub(crate) fn IP12() -> Scalar {
*IP12_CELL.get_or_init(|| inner_product(&ScalarVector(vec![Scalar::ONE; N]), TWO_N()))
*IP12_CELL.get_or_init(|| ScalarVector(vec![Scalar::ONE; N]).inner_product(TWO_N()))
}

pub(crate) fn hadamard_fold(
l: &[EdwardsPoint],
r: &[EdwardsPoint],
a: Scalar,
b: Scalar,
) -> Vec<EdwardsPoint> {
let mut res = Vec::with_capacity(l.len() / 2);
for i in 0 .. l.len() {
res.push(multiexp(&[(a, l[i]), (b, r[i])]));
}
res
}

#[derive(Clone, PartialEq, Eq, Debug)]
Expand Down Expand Up @@ -57,7 +70,7 @@ impl OriginalStruct {
let mut cache = hash_to_scalar(&y.to_bytes());
let z = cache;

let l0 = &aL - z;
let l0 = aL - z;
let l1 = sL;

let mut zero_twos = Vec::with_capacity(MN);
Expand All @@ -69,12 +82,12 @@ impl OriginalStruct {
}

let yMN = ScalarVector::powers(y, MN);
let r0 = (&(aR + z) * &yMN) + ScalarVector(zero_twos);
let r1 = yMN * sR;
let r0 = ((aR + z) * &yMN) + &ScalarVector(zero_twos);
let r1 = yMN * &sR;

let (T1, T2, x, mut taux) = {
let t1 = inner_product(&l0, &r1) + inner_product(&l1, &r0);
let t2 = inner_product(&l1, &r1);
let t1 = l0.clone().inner_product(&r1) + r0.clone().inner_product(&l1);
let t2 = l1.clone().inner_product(&r1);

let mut tau1 = Scalar::random(&mut *rng);
let mut tau2 = Scalar::random(&mut *rng);
Expand All @@ -100,10 +113,10 @@ impl OriginalStruct {
taux += zpow[i + 2] * gamma;
}

let l = &l0 + &(l1 * x);
let r = &r0 + &(r1 * x);
let l = l0 + &(l1 * x);
let r = r0 + &(r1 * x);

let t = inner_product(&l, &r);
let t = l.clone().inner_product(&r);

let x_ip =
hash_cache(&mut cache, &[x.to_bytes(), taux.to_bytes(), mu.to_bytes(), t.to_bytes()]);
Expand All @@ -126,8 +139,8 @@ impl OriginalStruct {
let (aL, aR) = a.split();
let (bL, bR) = b.split();

let cL = inner_product(&aL, &bR);
let cR = inner_product(&aR, &bL);
let cL = aL.clone().inner_product(&bR);
let cR = aR.clone().inner_product(&bL);

let (G_L, G_R) = G_proof.split_at(aL.len());
let (H_L, H_R) = H_proof.split_at(aL.len());
Expand All @@ -140,8 +153,8 @@ impl OriginalStruct {
let w = hash_cache(&mut cache, &[L_i.compress().to_bytes(), R_i.compress().to_bytes()]);
let winv = w.invert().unwrap();

a = (aL * w) + (aR * winv);
b = (bL * winv) + (bR * w);
a = (aL * w) + &(aR * winv);
b = (bL * winv) + &(bR * w);

if a.len() != 1 {
G_proof = hadamard_fold(G_L, G_R, winv, w);
Expand Down
26 changes: 17 additions & 9 deletions coins/monero/src/ringct/bulletproofs/plus/aggregate_range_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl AggregateRangeStatement {
let mut d = ScalarVector::new(mn);
for j in 1 ..= V.len() {
z_pow.push(z.pow(Scalar::from(2 * u64::try_from(j).unwrap()))); // TODO: Optimize this
d = d.add_vec(&Self::d_j(j, V.len()).mul(z_pow[j - 1]));
d = d + &(Self::d_j(j, V.len()) * (z_pow[j - 1]));
}

let mut ascending_y = ScalarVector(vec![y]);
Expand All @@ -124,7 +124,8 @@ impl AggregateRangeStatement {
let mut descending_y = ascending_y.clone();
descending_y.0.reverse();

let d_descending_y = d.mul_vec(&descending_y);
let d_descending_y = d.clone() * &descending_y;
let d_descending_y_plus_z = d_descending_y + z;

let y_mn_plus_one = descending_y[0] * y;

Expand All @@ -135,17 +136,24 @@ impl AggregateRangeStatement {

let neg_z = -z;
let mut A_terms = Vec::with_capacity((generators.len() * 2) + 2);
for (i, d_y_z) in d_descending_y.add(z).0.drain(..).enumerate() {
for (i, d_y_z) in d_descending_y_plus_z.0.iter().enumerate() {
A_terms.push((neg_z, generators.generator(GeneratorsList::GBold1, i)));
A_terms.push((d_y_z, generators.generator(GeneratorsList::HBold1, i)));
A_terms.push((*d_y_z, generators.generator(GeneratorsList::HBold1, i)));
}
A_terms.push((y_mn_plus_one, commitment_accum));
A_terms.push((
((y_pows * z) - (d.sum() * y_mn_plus_one * z) - (y_pows * z.square())),
Generators::g(),
));

(y, d_descending_y, y_mn_plus_one, z, ScalarVector(z_pow), A + multiexp_vartime(&A_terms))
(
y,
d_descending_y_plus_z,
y_mn_plus_one,
z,
ScalarVector(z_pow),
A + multiexp_vartime(&A_terms),
)
}

pub(crate) fn prove<R: RngCore + CryptoRng>(
Expand Down Expand Up @@ -191,7 +199,7 @@ impl AggregateRangeStatement {
a_l.0.append(&mut u64_decompose(*witness.values.get(j - 1).unwrap_or(&0)).0);
}

let a_r = a_l.sub(Scalar::ONE);
let a_r = a_l.clone() - Scalar::ONE;

let alpha = Scalar::random(&mut *rng);

Expand All @@ -209,11 +217,11 @@ impl AggregateRangeStatement {
// Multiply by INV_EIGHT per earlier commentary
A.0 *= crate::INV_EIGHT();

let (y, d_descending_y, y_mn_plus_one, z, z_pow, A_hat) =
let (y, d_descending_y_plus_z, y_mn_plus_one, z, z_pow, A_hat) =
Self::compute_A_hat(PointVector(V), &generators, &mut transcript, A);

let a_l = a_l.sub(z);
let a_r = a_r.add_vec(&d_descending_y).add(z);
let a_l = a_l - z;
let a_r = a_r + &d_descending_y_plus_z;
let mut alpha = alpha;
for j in 1 ..= witness.gammas.len() {
alpha += z_pow[j - 1] * witness.gammas[j - 1] * y_mn_plus_one;
Expand Down
3 changes: 1 addition & 2 deletions coins/monero/src/ringct/bulletproofs/plus/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
use group::Group;
use dalek_ff_group::{Scalar, EdwardsPoint};

mod scalar_vector;
pub(crate) use scalar_vector::{ScalarVector, weighted_inner_product};
pub(crate) use crate::ringct::bulletproofs::scalar_vector::ScalarVector;
mod point_vector;
pub(crate) use point_vector::PointVector;

Expand Down
114 changes: 0 additions & 114 deletions coins/monero/src/ringct/bulletproofs/plus/scalar_vector.rs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ use rand_core::{RngCore, CryptoRng};

use zeroize::{Zeroize, ZeroizeOnDrop};

use multiexp::{multiexp, multiexp_vartime, BatchVerifier};
use multiexp::{BatchVerifier, multiexp, multiexp_vartime};
use group::{
ff::{Field, PrimeField},
GroupEncoding,
};
use dalek_ff_group::{Scalar, EdwardsPoint};

use crate::ringct::bulletproofs::plus::{
ScalarVector, PointVector, GeneratorsList, Generators, padded_pow_of_2, weighted_inner_product,
transcript::*,
ScalarVector, PointVector, GeneratorsList, Generators, padded_pow_of_2, transcript::*,
};

// Figure 1
Expand Down Expand Up @@ -219,7 +218,7 @@ impl WipStatement {
.zip(g_bold.0.iter().copied())
.chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied()))
.collect::<Vec<_>>();
P_terms.push((weighted_inner_product(&witness.a, &witness.b, &y), g));
P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g));
P_terms.push((witness.alpha, h));
debug_assert_eq!(multiexp(&P_terms), P);
P_terms.zeroize();
Expand Down Expand Up @@ -258,14 +257,13 @@ impl WipStatement {
let d_l = Scalar::random(&mut *rng);
let d_r = Scalar::random(&mut *rng);

let c_l = weighted_inner_product(&a1, &b2, &y);
let c_r = weighted_inner_product(&(a2.mul(y_n_hat)), &b1, &y);
let c_l = a1.clone().weighted_inner_product(&b2, &y);
let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);

// TODO: Calculate these with a batch inversion
let y_inv_n_hat = y_n_hat.invert().unwrap();

let mut L_terms = a1
.mul(y_inv_n_hat)
let mut L_terms = (a1.clone() * y_inv_n_hat)
.0
.drain(..)
.zip(g_bold2.0.iter().copied())
Expand All @@ -277,8 +275,7 @@ impl WipStatement {
L_vec.push(L);
L_terms.zeroize();

let mut R_terms = a2
.mul(y_n_hat)
let mut R_terms = (a2.clone() * y_n_hat)
.0
.drain(..)
.zip(g_bold1.0.iter().copied())
Expand All @@ -294,8 +291,8 @@ impl WipStatement {
(e, inv_e, e_square, inv_e_square, g_bold, h_bold) =
Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat);

a = a1.mul(e).add_vec(&a2.mul(y_n_hat * inv_e));
b = b1.mul(inv_e).add_vec(&b2.mul(e));
a = (a1 * e) + &(a2 * (y_n_hat * inv_e));
b = (b1 * inv_e) + &(b2 * e);
alpha += (d_l * e_square) + (d_r * inv_e_square);

debug_assert_eq!(g_bold.len(), a.len());
Expand Down
Loading

0 comments on commit 5629c94

Please sign in to comment.