From d7454afb1111e531689ee4823a606eb8974cefbb Mon Sep 17 00:00:00 2001 From: Austin Abell Date: Wed, 18 Dec 2024 20:25:42 -0500 Subject: [PATCH] accelerate the to_mont multiply --- p256/src/arithmetic.rs | 5 +++++ p256/src/arithmetic/field.rs | 32 ++++++++++++++++++++++++++++---- primeorder/src/lib.rs | 2 ++ primeorder/src/risc0.rs | 30 ++++++++++++++---------------- 4 files changed, 49 insertions(+), 20 deletions(-) diff --git a/p256/src/arithmetic.rs b/p256/src/arithmetic.rs index 9fcea6ac..1932efbb 100644 --- a/p256/src/arithmetic.rs +++ b/p256/src/arithmetic.rs @@ -73,4 +73,9 @@ impl PrimeCurveParams for NistP256 { #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] const EQUATION_B_LE: FieldElement256 = FieldElement256::new_unchecked(crate::risc0::SECP256R1_EQUATION_B_LE); + + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + fn from_u32_words_le(words: [u32; 8]) -> elliptic_curve::subtle::CtOption { + FieldElement::from_words_le(words) + } } diff --git a/p256/src/arithmetic/field.rs b/p256/src/arithmetic/field.rs index 8b9d6b45..b848c9bb 100644 --- a/p256/src/arithmetic/field.rs +++ b/p256/src/arithmetic/field.rs @@ -30,6 +30,14 @@ pub const MODULUS: U256 = U256::from_be_hex(MODULUS_HEX); const R_2: U256 = U256::from_be_hex("00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003"); +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use primeorder::risc0::FieldElement256; + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +const R_2_LE: FieldElement256 = FieldElement256::new_unchecked([ + 0x00000001, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFE, 0x00000000, +]); + /// An element in the finite field modulo p = 2^{224}(2^{32} − 1) + 2^{192} + 2^{96} − 1. /// /// The internal representation is in little-endian order. 
Elements are always in @@ -54,6 +62,25 @@ primeorder::impl_mont_field_element!( ); impl FieldElement { + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + #[inline(never)] + pub(crate) fn from_words_le(fe: [u32; 8]) -> CtOption { + // use elliptic_curve::bigint::Encoding; + // println!("r2: {:0X?}", fe_from_montgomery(R_2.as_words())); + + let fe = FieldElement256::new_unchecked(fe); + let mut mont = FieldElement256::default(); + fe.mul_unchecked(&R_2_LE, &mut mont); + + let buffer: [u32; 8] = mont.data; + + use crate::elliptic_curve::subtle::ConstantTimeLess as _; + let uint = U256::from_le_slice(bytemuck::cast_slice::(&buffer)); + let is_within_modulus = uint.ct_lt(&MODULUS); + + CtOption::new(Self(uint), is_within_modulus) + } + /// Returns the multiplicative inverse of self, if self is non-zero. #[inline(never)] pub fn invert(&self) -> CtOption { @@ -70,8 +97,7 @@ impl FieldElement { &crate::risc0::SECP256R1_PRIME, &mut output, ); - let bytes = bytemuck::cast_slice::(&output); - FieldElement::from_uint(U256::from_le_slice(bytes)) + FieldElement::from_words_le(output) } #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] @@ -114,8 +140,6 @@ impl FieldElement { // alpha = ± beta^((p + 1) / 4) mod p // // Thus sqrt can be implemented with a single exponentiation. - - // TODO apply acceleration let t11 = self.mul(&self.square()); let t1111 = t11.mul(&t11.sqn(2)); diff --git a/primeorder/src/lib.rs b/primeorder/src/lib.rs index 1121d792..8f6bbfe3 100644 --- a/primeorder/src/lib.rs +++ b/primeorder/src/lib.rs @@ -92,4 +92,6 @@ pub trait PrimeCurveParams: /// Coefficient `b` in the curve equation in little-endian words to be compatible with risc0 /// expected layout. 
const EQUATION_B_LE: risc0::FieldElement256; + + fn from_u32_words_le(words: [u32; 8]) -> elliptic_curve::subtle::CtOption; } diff --git a/primeorder/src/risc0.rs b/primeorder/src/risc0.rs index bb61b0d5..241f966e 100644 --- a/primeorder/src/risc0.rs +++ b/primeorder/src/risc0.rs @@ -12,7 +12,7 @@ use crate::PrimeCurveParams; /// Representation of a field element in raw bytes form. This is not in montgomery form. #[derive(Copy, Clone, Default, Debug, PartialEq, Eq)] pub struct FieldElement256 { - pub(crate) data: [u32; 8], + pub data: [u32; 8], _phantom: PhantomData, } @@ -73,7 +73,7 @@ impl FieldElement256 where C: PrimeCurveParams, { - pub(crate) fn mul_unchecked(&self, rhs: &Self, result: &mut Self) { + pub fn mul_unchecked(&self, rhs: &Self, result: &mut Self) { risc0_bigint2::field::modmul_256_unchecked( &self.data, &rhs.data, @@ -82,7 +82,7 @@ where ); } - pub(crate) fn add_unchecked(&self, rhs: &Self, result: &mut Self) { + pub fn add_unchecked(&self, rhs: &Self, result: &mut Self) { risc0_bigint2::field::modadd_256_unchecked( &self.data, &rhs.data, @@ -188,6 +188,7 @@ where let mut y_bytes_arr: [u8; 32] = y_bytes.as_slice().try_into().unwrap(); x_bytes_arr.reverse(); y_bytes_arr.reverse(); + // TODO: make this cast more alignment-safe let x = bytemuck::cast::<_, [u32; 8]>(x_bytes_arr); let y = bytemuck::cast::<_, [u32; 8]>(y_bytes_arr); ec::AffinePoint::new_unchecked(x, y) @@ -200,19 +201,16 @@ where C: PrimeCurveParams, { if let Some(value) = affine.as_u32s() { - // TODO a lot of potentially unnecessary copying here.
- let mut x = bytemuck::cast::<_, [u8; 32]>(value[0]); - let mut y = bytemuck::cast::<_, [u8; 32]>(value[1]); - x.reverse(); - y.reverse(); - let x_arr = GenericArray::from_slice(&x); - let y_arr = GenericArray::from_slice(&y); - let affine = AffinePoint { - x: C::FieldElement::from_repr(x_arr.clone()).unwrap(), - y: C::FieldElement::from_repr(y_arr.clone()).unwrap(), - infinity: 0, - }; - ProjectivePoint::from(affine) + let x = C::from_u32_words_le(value[0]); + let y = C::from_u32_words_le(value[1]); + + x.and_then(|x| { + y.map(|y| { + let affine = AffinePoint { x, y, infinity: 0 }; + ProjectivePoint::from(affine) + }) + }) + .unwrap_or(ProjectivePoint::IDENTITY) } else { ProjectivePoint::IDENTITY }