Skip to content

Commit

Permalink
accelerate the to_mont multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
austinabell committed Dec 19, 2024
1 parent fb939a8 commit d7454af
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 20 deletions.
5 changes: 5 additions & 0 deletions p256/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,9 @@ impl PrimeCurveParams for NistP256 {
#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
const EQUATION_B_LE: FieldElement256<NistP256> =
FieldElement256::new_unchecked(crate::risc0::SECP256R1_EQUATION_B_LE);

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
fn from_u32_words_le(words: [u32; 8]) -> elliptic_curve::subtle::CtOption<FieldElement> {
FieldElement::from_words_le(words)
}
}
32 changes: 28 additions & 4 deletions p256/src/arithmetic/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ pub const MODULUS: U256 = U256::from_be_hex(MODULUS_HEX);
const R_2: U256 =
U256::from_be_hex("00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003");

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
use primeorder::risc0::FieldElement256;

#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
const R_2_LE: FieldElement256<NistP256> = FieldElement256::new_unchecked([
0x00000001, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFE, 0x00000000,
]);

/// An element in the finite field modulo p = 2^{224}(2^{32} − 1) + 2^{192} + 2^{96} − 1.
///
/// The internal representation is in little-endian order. Elements are always in
Expand All @@ -54,6 +62,25 @@ primeorder::impl_mont_field_element!(
);

impl FieldElement {
#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))]
#[inline(never)]
pub(crate) fn from_words_le(fe: [u32; 8]) -> CtOption<Self> {
// use elliptic_curve::bigint::Encoding;
// println!("r2: {:0X?}", fe_from_montgomery(R_2.as_words()));

let fe = FieldElement256::new_unchecked(fe);
let mut mont = FieldElement256::default();
fe.mul_unchecked(&R_2_LE, &mut mont);

let buffer: [u32; 8] = mont.data;

use crate::elliptic_curve::subtle::ConstantTimeLess as _;
let uint = U256::from_le_slice(bytemuck::cast_slice::<u32, u8>(&buffer));
let is_within_modulus = uint.ct_lt(&MODULUS);

CtOption::new(Self(uint), is_within_modulus)
}

/// Returns the multiplicative inverse of self, if self is non-zero.
#[inline(never)]
pub fn invert(&self) -> CtOption<Self> {
Expand All @@ -70,8 +97,7 @@ impl FieldElement {
&crate::risc0::SECP256R1_PRIME,
&mut output,
);
let bytes = bytemuck::cast_slice::<u32, u8>(&output);
FieldElement::from_uint(U256::from_le_slice(bytes))
FieldElement::from_words_le(output)
}

#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))]
Expand Down Expand Up @@ -114,8 +140,6 @@ impl FieldElement {
// alpha = ± beta^((p + 1) / 4) mod p
//
// Thus sqrt can be implemented with a single exponentiation.

// TODO apply acceleration

let t11 = self.mul(&self.square());
let t1111 = t11.mul(&t11.sqn(2));
Expand Down
2 changes: 2 additions & 0 deletions primeorder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,6 @@ pub trait PrimeCurveParams:
/// Coefficient `b` in the curve equation in little-endian words to be compatible with risc0
/// expected layout.
const EQUATION_B_LE: risc0::FieldElement256<Self>;

fn from_u32_words_le(words: [u32; 8]) -> elliptic_curve::subtle::CtOption<Self::FieldElement>;
}
30 changes: 14 additions & 16 deletions primeorder/src/risc0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::PrimeCurveParams;
/// Representation of a field element in raw bytes form. This is not in montgomery form.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
pub struct FieldElement256<C> {
pub(crate) data: [u32; 8],
pub data: [u32; 8],
_phantom: PhantomData<C>,
}

Expand Down Expand Up @@ -73,7 +73,7 @@ impl<C> FieldElement256<C>
where
C: PrimeCurveParams,
{
pub(crate) fn mul_unchecked(&self, rhs: &Self, result: &mut Self) {
pub fn mul_unchecked(&self, rhs: &Self, result: &mut Self) {
risc0_bigint2::field::modmul_256_unchecked(
&self.data,
&rhs.data,
Expand All @@ -82,7 +82,7 @@ where
);
}

pub(crate) fn add_unchecked(&self, rhs: &Self, result: &mut Self) {
pub fn add_unchecked(&self, rhs: &Self, result: &mut Self) {
risc0_bigint2::field::modadd_256_unchecked(
&self.data,
&rhs.data,
Expand Down Expand Up @@ -188,6 +188,7 @@ where
let mut y_bytes_arr: [u8; 32] = y_bytes.as_slice().try_into().unwrap();
x_bytes_arr.reverse();
y_bytes_arr.reverse();
// TODO make more alignment safe
let x = bytemuck::cast::<_, [u32; 8]>(x_bytes_arr);
let y = bytemuck::cast::<_, [u32; 8]>(y_bytes_arr);
ec::AffinePoint::new_unchecked(x, y)
Expand All @@ -200,19 +201,16 @@ where
C: PrimeCurveParams,
{
if let Some(value) = affine.as_u32s() {
// TODO a lot of potentially unnecessary copying here.
let mut x = bytemuck::cast::<_, [u8; 32]>(value[0]);
let mut y = bytemuck::cast::<_, [u8; 32]>(value[1]);
x.reverse();
y.reverse();
let x_arr = GenericArray::from_slice(&x);
let y_arr = GenericArray::from_slice(&y);
let affine = AffinePoint {
x: C::FieldElement::from_repr(x_arr.clone()).unwrap(),
y: C::FieldElement::from_repr(y_arr.clone()).unwrap(),
infinity: 0,
};
ProjectivePoint::from(affine)
let x = C::from_u32_words_le(value[0]);
let y = C::from_u32_words_le(value[1]);

x.and_then(|x| {
y.map(|y| {
let affine = AffinePoint { x, y, infinity: 0 };
ProjectivePoint::from(affine)
})
})
.unwrap_or(ProjectivePoint::IDENTITY)
} else {
ProjectivePoint::IDENTITY
}
Expand Down

0 comments on commit d7454af

Please sign in to comment.