From 8af4f1ebab640405c799e65d9873847a4acf04f8 Mon Sep 17 00:00:00 2001 From: kilic Date: Fri, 12 Apr 2024 14:40:41 +0300 Subject: [PATCH] MSM optimisations: CycloneMSM (#130) * impl msm with batch addition * bring back multiexp serial * parallelize coeffs to repr Co-authored-by: Han * parallelize bases to affine Co-authored-by: Han * add missing dependency * bring back old implementation postfix new one as `_independent_points` --------- Co-authored-by: Han --- src/msm.rs | 473 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 321 insertions(+), 152 deletions(-) diff --git a/src/msm.rs b/src/msm.rs index ae964cf7..25af9711 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -1,8 +1,14 @@ use std::ops::Neg; +use crate::CurveAffine; +use ff::Field; use ff::PrimeField; use group::Group; -use pasta_curves::arithmetic::CurveAffine; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; + +const BATCH_SIZE: usize = 64; fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { // Booth encoding: @@ -48,6 +54,238 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { } } +fn batch_add( + size: usize, + buckets: &mut [BucketAffine], + points: &[SchedulePoint], + bases: &[Affine], +) { + let mut t = vec![C::Base::ZERO; size]; + let mut z = vec![C::Base::ZERO; size]; + let mut acc = C::Base::ONE; + + for ( + ( + SchedulePoint { + base_idx, + buck_idx, + sign, + }, + t, + ), + z, + ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut()) + { + *z = buckets[*buck_idx].x() - bases[*base_idx].x; + if *sign { + *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y); + } else { + *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y); + } + acc *= *z; + } + + acc = acc.invert().unwrap(); + + for ( + ( + SchedulePoint { + base_idx, + buck_idx, + sign, + }, + t, + ), + z, + ) in points.iter().zip(t.iter()).zip(z.iter()).rev() + { + let lambda = acc * t; + acc *= z; + + let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); + if *sign { + buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y)); + } else { + buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y)); + } + buckets[*buck_idx].set_x(&x); + } +} + +#[derive(Debug, Clone, Copy)] +struct Affine { + x: C::Base, + y: C::Base, +} + +impl Affine { + fn from(point: &C) -> Self { + let coords = point.coordinates().unwrap(); + + Self { + x: *coords.x(), + y: *coords.y(), + } + } + + fn neg(&self) -> Self { + Self { + x: self.x, + y: -self.y, + } + } + + fn eval(&self) -> C { + C::from_xy(self.x, self.y).unwrap() + } +} + +#[derive(Debug, Clone)] +enum BucketAffine { + None, + Point(Affine), +} + +#[derive(Debug, Clone)] +enum Bucket { + None, + Point(C::Curve), +} + +impl Bucket { + fn add_assign(&mut self, point: &C, sign: bool) { + *self = match *self { + Bucket::None => Bucket::Point({ + if sign { + point.to_curve() + } else { + point.to_curve().neg() + } + }), + Bucket::Point(a) => { + if sign { + Self::Point(a + point) + } else { + Self::Point(a - point) + } + } + } + } + + fn add(&self, other: &BucketAffine) -> C::Curve { + match (self, other) { + (Self::Point(this), BucketAffine::Point(other)) => *this + other.eval(), + (Self::Point(this), BucketAffine::None) => *this, + (Self::None, BucketAffine::Point(other)) => other.eval().to_curve(), + (Self::None, BucketAffine::None) => C::Curve::identity(), + } + } +} + +impl BucketAffine { + fn assign(&mut self, point: &Affine, sign: bool) -> bool { + match *self { + Self::None => { + *self = Self::Point(if sign { *point } else { point.neg() }); + true + } + Self::Point(_) => false, + } + } + + fn x(&self) -> C::Base { + match self { + Self::None => panic!("::x None"), + Self::Point(a) => a.x, + } + } + + fn y(&self) -> C::Base { + match self { + Self::None => panic!("::y None"), + Self::Point(a) => a.y, + } + } + + fn set_x(&mut self, x: &C::Base) { + match self { + Self::None => panic!("::set_x None"), + Self::Point(ref mut a) => a.x = *x, + } + } + + fn set_y(&mut self, y: &C::Base) { + match self { + Self::None => panic!("::set_y None"), + Self::Point(ref mut a) => a.y = *y, + } + } +} + +struct Schedule { + buckets: Vec>, + set: [SchedulePoint; BATCH_SIZE], + ptr: usize, +} + +#[derive(Debug, Clone, Default)] +struct SchedulePoint { + base_idx: usize, + buck_idx: usize, + sign: bool, +} + +impl SchedulePoint { + fn new(base_idx: usize, buck_idx: usize, sign: bool) -> Self { + Self { + base_idx, + buck_idx, + sign, + } + } +} + +impl Schedule { + fn new(c: usize) -> Self { + let set = (0..BATCH_SIZE) + .map(|_| SchedulePoint::default()) + .collect::>() + .try_into() + .unwrap(); + + Self { + buckets: vec![BucketAffine::None; 1 << (c - 1)], + set, + ptr: 0, + } + } + + fn contains(&self, buck_idx: usize) -> bool { + self.set.iter().any(|sch| sch.buck_idx == buck_idx) + } + + fn execute(&mut self, bases: &[Affine]) { + if self.ptr != 0 { + batch_add(self.ptr, &mut self.buckets, &self.set, bases); + self.ptr = 0; + self.set + .iter_mut() + .for_each(|sch| *sch = SchedulePoint::default()); + } + } + + fn add(&mut self, bases: &[Affine], base_idx: usize, buck_idx: usize, sign: bool) { + if !self.buckets[buck_idx].assign(&bases[base_idx], sign) { + self.set[self.ptr] = SchedulePoint::new(base_idx, buck_idx, sign); + self.ptr += 1; + } + + if self.ptr == self.set.len() { + self.execute(bases); + } + } +} + pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); @@ -121,30 +359,6 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & } } -/// Performs a small multi-exponentiation operation. -/// Uses the double-and-add algorithm with doublings shared across points. -pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { - let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); - let mut acc = C::Curve::identity(); - - // for byte idx - for byte_idx in (0..32).rev() { - // for bit idx - for bit_idx in (0..8).rev() { - acc = acc.double(); - // for each coeff - for coeff_idx in 0..coeffs.len() { - let byte = coeffs[coeff_idx].as_ref()[byte_idx]; - if ((byte >> bit_idx) & 1) != 0 { - acc += bases[coeff_idx]; - } - } - } - } - - acc -} - /// Performs a multi-exponentiation operation. /// /// This function will panic if coeffs and bases have a different length. @@ -178,139 +392,96 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu acc } } +/// +/// This function will panic if coeffs and bases have a different length. +/// +/// This will use multithreading if beneficial. +pub fn best_multiexp_independent_points( + coeffs: &[C::Scalar], + bases: &[C], +) -> C::Curve { + assert_eq!(coeffs.len(), bases.len()); -#[cfg(test)] -mod test { - - use std::ops::Neg; - - use crate::bn256::{Fr, G1Affine, G1}; - use ark_std::{end_timer, start_timer}; - use ff::{Field, PrimeField}; - use group::{Curve, Group}; - use pasta_curves::arithmetic::CurveAffine; - use rand_core::OsRng; - - // keeping older implementation it here for baseline comparison, debugging & benchmarking - fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { - assert_eq!(coeffs.len(), bases.len()); + // TODO: consider adjusting it with emprical data? + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; - let num_threads = rayon::current_num_threads(); - if coeffs.len() > num_threads { - let chunk = coeffs.len() / num_threads; - let num_chunks = coeffs.chunks(chunk).len(); - let mut results = vec![C::Curve::identity(); num_chunks]; - rayon::scope(|scope| { - let chunk = coeffs.len() / num_threads; - - for ((coeffs, bases), acc) in coeffs - .chunks(chunk) - .zip(bases.chunks(chunk)) - .zip(results.iter_mut()) - { - scope.spawn(move |_| { - multiexp_serial(coeffs, bases, acc); - }); - } - }); - results.iter().fold(C::Curve::identity(), |a, b| a + b) - } else { - let mut acc = C::Curve::identity(); - multiexp_serial(coeffs, bases, &mut acc); - acc - } + if c < 10 { + return best_multiexp(coeffs, bases); } - // keeping older implementation it here for baseline comparison, debugging & benchmarking - fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { - let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); - - let c = if bases.len() < 4 { - 1 - } else if bases.len() < 32 { - 3 - } else { - (f64::from(bases.len() as u32)).ln().ceil() as usize - }; - - fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { - let skip_bits = segment * c; - let skip_bytes = skip_bits / 8; + // coeffs to byte representation + let coeffs: Vec<_> = coeffs.par_iter().map(|a| a.to_repr()).collect(); + // copy bases into `Affine` to skip in on curve check for every access + let bases_local: Vec<_> = bases.par_iter().map(Affine::from).collect(); - if skip_bytes >= 32 { - return 0; - } - - let mut v = [0; 8]; - for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { - *v = *o; + // number of windows + let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1; + // accumumator for each window + let mut acc = vec![C::Curve::identity(); number_of_windows]; + acc.par_iter_mut().enumerate().rev().for_each(|(w, acc)| { + // jacobian buckets for already scheduled points + let mut j_bucks = vec![Bucket::::None; 1 << (c - 1)]; + + // schedular for affine addition + let mut sched = Schedule::new(c); + + for (base_idx, coeff) in coeffs.iter().enumerate() { + let buck_idx = get_booth_index(w, c, coeff.as_ref()); + + if buck_idx != 0 { + // parse bucket index + let sign = buck_idx.is_positive(); + let buck_idx = buck_idx.unsigned_abs() as usize - 1; + + if sched.contains(buck_idx) { + // greedy accumulation + // we use original bases here + j_bucks[buck_idx].add_assign(&bases[base_idx], sign); + } else { + // also flushes the schedule if full + sched.add(&bases_local, base_idx, buck_idx, sign); + } } - - let mut tmp = u64::from_le_bytes(v); - tmp >>= skip_bits - (skip_bytes * 8); - tmp %= 1 << c; - - tmp as usize } - let segments = (256 / c) + 1; - - for current_segment in (0..segments).rev() { - for _ in 0..c { - *acc = acc.double(); - } - - #[derive(Clone, Copy)] - enum Bucket { - None, - Affine(C), - Projective(C::Curve), - } + // flush the schedule + sched.execute(&bases_local); - impl Bucket { - fn add_assign(&mut self, other: &C) { - *self = match *self { - Bucket::None => Bucket::Affine(*other), - Bucket::Affine(a) => Bucket::Projective(a + *other), - Bucket::Projective(mut a) => { - a += *other; - Bucket::Projective(a) - } - } - } + // summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for (j_buck, a_buck) in j_bucks.iter().zip(sched.buckets.iter()).rev() { + running_sum += j_buck.add(a_buck); + *acc += running_sum; + } - fn add(self, mut other: C::Curve) -> C::Curve { - match self { - Bucket::None => other, - Bucket::Affine(a) => { - other += a; - other - } - Bucket::Projective(a) => other + a, - } - } - } + // shift accumulator to the window position + for _ in 0..c * w { + *acc = acc.double(); + } + }); + acc.into_iter().sum::<_>() +} - let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; +#[cfg(test)] +mod test { - for (coeff, base) in coeffs.iter().zip(bases.iter()) { - let coeff = get_at::(current_segment, c, coeff); - if coeff != 0 { - buckets[coeff - 1].add_assign(base); - } - } + use std::ops::Neg; - // Summation by parts - // e.g. 3a + 2b + 1c = a + - // (a) + b + - // ((a) + b) + c - let mut running_sum = C::Curve::identity(); - for exp in buckets.into_iter().rev() { - running_sum = exp.add(running_sum); - *acc += &running_sum; - } - } - } + use crate::bn256::{Fr, G1Affine, G1}; + use ark_std::{end_timer, start_timer}; + use ff::{Field, PrimeField}; + use group::{Curve, Group}; + use pasta_curves::arithmetic::CurveAffine; + use rand_core::OsRng; #[test] fn test_booth_encoding() { @@ -374,21 +545,19 @@ mod test { let points = &points[..1 << k]; let scalars = &scalars[..1 << k]; - let t0 = start_timer!(|| format!("w/ booth k={}", k)); - let e0 = super::best_multiexp(scalars, points); + let t0 = start_timer!(|| format!("cyclone k={}", k)); + let e0 = super::best_multiexp_independent_points(scalars, points); end_timer!(t0); - let t1 = start_timer!(|| format!("w/o booth k={}", k)); - let e1 = best_multiexp(scalars, points); + let t1 = start_timer!(|| format!("older k={}", k)); + let e1 = super::best_multiexp(scalars, points); end_timer!(t1); - assert_eq!(e0, e1); } } #[test] fn test_msm_cross() { - run_msm_cross::(10, 18); - // run_msm_cross::(19, 23); + run_msm_cross::(14, 22); } }